use axum::{
extract::Json,
http::StatusCode,
response::{
sse::{Event, Sse},
IntoResponse, Response,
},
routing::{get, post},
Extension, Router,
};
use futures::Stream;
use serde_json::json;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{error, info};
use tracing_subscriber;
use crate::agent::{AgentRegistry, MCPError, MCPMessage};
use crate::auth::AuthConfig;
use crate::conversation::ConversationManager;
/// State shared by every route handler in this module.
///
/// Must be `Clone` so it can be used with axum's `State` extractor; the `Arc`
/// fields make clones cheap and point at the same registry / conversation manager.
#[derive(Clone)]
pub struct AppState {
    /// Agent registry that processes MCP messages. Handlers in this module only
    /// take read access (`.read().await`).
    registry: Arc<RwLock<AgentRegistry>>,
    /// Authentication settings, when configured.
    // No handler in this module reads this field yet, hence the `dead_code`
    // allowance; it is kept so auth-aware handlers can reach it through the state.
    #[allow(dead_code)]
    auth_config: Option<AuthConfig>,
    /// Conversation manager; when `None`, the conversation endpoints answer
    /// `501 Not Implemented`.
    conversation_manager: Option<Arc<ConversationManager>>,
}
/// JSON error body, serialized as `{"error": "<message>"}`, used when an
/// `MCPError` is turned into an HTTP response.
#[derive(serde::Deserialize, serde::Serialize)]
struct ErrorResponse {
    /// Human-readable description of what went wrong.
    error: String,
}
/// Renders every [`MCPError`] as `400 Bad Request` with a JSON body of the form
/// `{"error": "<Display text of the error>"}`.
impl IntoResponse for MCPError {
    fn into_response(self) -> Response {
        let payload = ErrorResponse {
            error: self.to_string(),
        };
        (StatusCode::BAD_REQUEST, Json(payload)).into_response()
    }
}
pub async fn run_http_server(registry: AgentRegistry, addr: SocketAddr) {
tracing_subscriber::fmt::init();
let app_state = AppState {
registry: Arc::new(RwLock::new(registry)),
auth_config: None,
conversation_manager: None,
};
let app = Router::new()
.route("/mcp", post(handle_mcp))
.route("/health", get(|| async { "OK" }))
.with_state(app_state);
info!("Servidor MCP rodando em {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
pub async fn run_http_server_with_auth(
registry: AgentRegistry,
auth_config: AuthConfig,
conversation_manager: ConversationManager,
addr: SocketAddr,
) {
tracing_subscriber::fmt::init();
let app_state = AppState {
registry: Arc::new(RwLock::new(registry)),
auth_config: Some(auth_config.clone()),
conversation_manager: Some(Arc::new(conversation_manager)),
};
let app = Router::new()
.route("/mcp", post(handle_mcp))
.route("/mcp/stream", get(handle_stream_mcp))
.route("/conversation", post(create_conversation))
.route("/conversation/:id", get(get_conversation))
.route("/health", get(|| async { "OK" }))
.with_state(app_state)
.layer(Extension(auth_config));
info!("Servidor MCP avançado rodando em {}", addr);
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await
.unwrap();
}
async fn handle_mcp(
axum::extract::State(state): axum::extract::State<AppState>,
Json(payload): Json<MCPMessage>,
) -> Result<Json<MCPMessage>, MCPError> {
if payload.magic != "MCP0" {
error!("Magic inválido: {}", payload.magic);
return Ok(Json(MCPMessage::new(
"error",
json!({"message": "Magic inválido"}),
)));
}
let response = {
let reg = state.registry.read().await;
reg.process(payload).await?
};
Ok(Json(response))
}
async fn handle_stream_mcp(
axum::extract::State(state): axum::extract::State<AppState>,
Json(payload): Json<MCPMessage>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let (tx, rx) = tokio::sync::mpsc::channel(100);
tokio::spawn(async move {
if payload.magic != "MCP0" {
let _ = tx
.send(Ok(Event::default().data("Error: Invalid magic")))
.await;
return;
}
let reg = state.registry.read().await;
match reg.process(payload).await {
Ok(response) => {
let _ = tx
.send(Ok(
Event::default().data(serde_json::to_string(&response).unwrap_or_default())
))
.await;
}
Err(error) => {
let _ = tx
.send(Ok(Event::default().data(format!("Error: {}", error))))
.await;
}
}
});
Sse::new(ReceiverStream::new(rx))
}
/// Handles `POST /conversation`: creates a new conversation.
///
/// Responses:
/// - `201 Created` with `conversation_id` and `created_at` (whole seconds since
///   the Unix epoch);
/// - `500 Internal Server Error` with `error` if the manager fails to create it;
/// - `501 Not Implemented` when conversation management is not configured.
async fn create_conversation(
    axum::extract::State(state): axum::extract::State<AppState>,
) -> impl IntoResponse {
    if let Some(ref conversation_manager) = state.conversation_manager {
        match conversation_manager.create_conversation() {
            Ok(conversation) => (
                StatusCode::CREATED,
                Json(json!({
                    "conversation_id": conversation.id,
                    // `duration_since(UNIX_EPOCH)` gives an absolute timestamp;
                    // `elapsed()` would report the conversation's age (≈0 here).
                    // A clock set before the epoch falls back to 0.
                    "created_at": conversation
                        .created_at
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap_or_default()
                        .as_secs()
                })),
            ),
            Err(e) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(json!({ "error": e })),
            ),
        }
    } else {
        (
            StatusCode::NOT_IMPLEMENTED,
            Json(json!({ "error": "Gerenciamento de conversas não está habilitado" })),
        )
    }
}
/// Handles `GET /conversation/:id`: returns a conversation and its messages.
///
/// All timestamps (`created_at`, `updated_at` and each message's `timestamp`)
/// are reported as whole seconds since the Unix epoch; a clock value before the
/// epoch is reported as 0.
///
/// Responses:
/// - `200 OK` with `conversation_id`, `messages`, `metadata`, `created_at`, `updated_at`;
/// - `404 Not Found` when no conversation matches `id`;
/// - `501 Not Implemented` when conversation management is not configured.
async fn get_conversation(
    axum::extract::State(state): axum::extract::State<AppState>,
    axum::extract::Path(id): axum::extract::Path<String>,
) -> impl IntoResponse {
    if let Some(ref conversation_manager) = state.conversation_manager {
        match conversation_manager.get_conversation(&id) {
            Some(conversation) => {
                // `duration_since(UNIX_EPOCH)` yields absolute timestamps;
                // `elapsed()` would instead report how long ago each event happened.
                let messages: Vec<_> = conversation
                    .messages
                    .iter()
                    .map(|msg| {
                        json!({
                            "role": msg.role,
                            "content": msg.content,
                            "timestamp": msg
                                .timestamp
                                .duration_since(std::time::UNIX_EPOCH)
                                .unwrap_or_default()
                                .as_secs()
                        })
                    })
                    .collect();
                (
                    StatusCode::OK,
                    Json(json!({
                        "conversation_id": conversation.id,
                        "messages": messages,
                        "metadata": conversation.metadata,
                        "created_at": conversation
                            .created_at
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_secs(),
                        "updated_at": conversation
                            .updated_at
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_secs()
                    })),
                )
            }
            None => (
                StatusCode::NOT_FOUND,
                Json(json!({ "error": "Conversa não encontrada" })),
            ),
        }
    } else {
        (
            StatusCode::NOT_IMPLEMENTED,
            Json(json!({ "error": "Gerenciamento de conversas não está habilitado" })),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::DummyAgent;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use serde_json::json;
    use tower::ServiceExt;

    /// Builds a router exposing only `POST /mcp`, backed by a registry holding a
    /// single `DummyAgent` and no auth config or conversation manager.
    async fn build_test_app() -> Router {
        let mut agents = AgentRegistry::new();
        agents.register_agent(Box::new(DummyAgent {
            api_key: "test_key".to_string(),
        }));
        let state = AppState {
            registry: Arc::new(RwLock::new(agents)),
            auth_config: None,
            conversation_manager: None,
        };
        Router::new().route("/mcp", post(handle_mcp)).with_state(state)
    }

    /// Builds a JSON `POST /mcp` request whose body is `message`.
    fn mcp_request(message: &MCPMessage) -> Request<Body> {
        let body = serde_json::to_string(message).unwrap();
        Request::builder()
            .uri("/mcp")
            .method("POST")
            .header("Content-Type", "application/json")
            .body(Body::from(body))
            .unwrap()
    }

    /// Sends `message` through a fresh test app, asserts the response status is
    /// `expected`, and decodes the JSON response body as `T`.
    async fn send_and_decode<T: serde::de::DeserializeOwned>(
        message: &MCPMessage,
        expected: StatusCode,
    ) -> T {
        let response = build_test_app()
            .await
            .oneshot(mcp_request(message))
            .await
            .unwrap();
        assert_eq!(response.status(), expected);
        let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[tokio::test]
    async fn test_handle_mcp_valid_request() {
        let message = MCPMessage::new("dummy:test", json!({"test": "value"}));
        let reply: MCPMessage = send_and_decode(&message, StatusCode::OK).await;
        assert_eq!(reply.command, "dummy_response");
        assert_eq!(reply.payload, json!({"test": "value"}));
    }

    #[tokio::test]
    async fn test_handle_mcp_invalid_magic() {
        let mut message = MCPMessage::new("dummy:test", json!({"test": "value"}));
        message.magic = "INVALID".to_string();
        let reply: MCPMessage = send_and_decode(&message, StatusCode::OK).await;
        assert_eq!(reply.command, "error");
        let text = reply.payload["message"].as_str().unwrap();
        assert!(text.contains("inválido"));
    }

    #[tokio::test]
    async fn test_handle_mcp_agent_not_found() {
        let message = MCPMessage::new("nonexistent:test", json!({"test": "value"}));
        let failure: ErrorResponse = send_and_decode(&message, StatusCode::BAD_REQUEST).await;
        assert!(failure.error.contains("não foi encontrado"));
    }
}