use std::sync::Arc;
use axum::{
routing::{post, get},
Router,
Json,
extract::{State, ws::{WebSocketUpgrade, WebSocket, Message}},
};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tower_http::cors::CorsLayer;
use crate::traits::Result;
use super::orchestrator::{SwarmOrchestrator, SwarmPattern};
use super::patterns::SwarmResult;
/// Shared state made available to every gateway handler.
///
/// axum requires router state to be `Clone`; keeping the orchestrator behind
/// an `Arc` means each clone is only a reference-count bump.
#[derive(Clone)]
pub struct GatewayState {
    /// Orchestrator used by the handlers to run swarm patterns.
    pub orchestrator: Arc<SwarmOrchestrator>,
}
/// JSON body accepted by `POST /api/swarm/execute`.
#[derive(Deserialize)]
pub struct ExecuteRequest {
    /// Swarm pattern the orchestrator should run.
    pub pattern: SwarmPattern,
    /// Input text handed to the selected pattern.
    pub input: String,
}
/// JSON body returned by `POST /api/swarm/execute`.
///
/// The execute handler in this module sets `status` to `"success"` (with
/// `result` filled in) or `"error"` (with `error` holding the error text).
/// Fields left as `None` are serialized as JSON `null`.
#[derive(Serialize)]
pub struct ExecuteResponse {
    /// Outcome label for the request.
    pub status: String,
    /// Swarm output when the run succeeded.
    pub result: Option<SwarmResult>,
    /// Human-readable error message when the run failed.
    pub error: Option<String>,
}
/// Builds the router for the swarm gateway.
///
/// Routes:
/// - `POST /api/swarm/execute` — runs a swarm pattern and returns the outcome
///   as JSON.
/// - `GET /api/swarm/stream` — upgrades the connection to a WebSocket.
///
/// Every route is wrapped in `CorsLayer::permissive()`, which accepts requests
/// from any origin with any method and headers.
/// NOTE(review): consider restricting origins before exposing this publicly.
pub fn create_router(orchestrator: Arc<SwarmOrchestrator>) -> Router {
    let cors = CorsLayer::permissive();
    let shared = GatewayState { orchestrator };

    Router::new()
        .route("/api/swarm/execute", post(handle_execute))
        .route("/api/swarm/stream", get(handle_stream))
        .layer(cors)
        .with_state(shared)
}
async fn handle_execute(
State(state): State<GatewayState>,
Json(req): Json<ExecuteRequest>,
) -> Json<ExecuteResponse> {
match state.orchestrator.execute(req.pattern, &req.input).await {
Ok(res) => Json(ExecuteResponse {
status: "success".into(),
result: Some(res),
error: None,
}),
Err(e) => Json(ExecuteResponse {
status: "error".into(),
result: None,
error: Some(e.to_string()),
}),
}
}
async fn handle_stream(ws: WebSocketUpgrade) -> axum::response::Response {
ws.on_upgrade(|socket| async move {
handle_websocket(socket).await;
})
}
/// Drives one `/api/swarm/stream` WebSocket session.
///
/// Sends a greeting, then answers each data frame (text or binary) with a
/// placeholder notice. The session ends when the peer closes the connection
/// or when the socket reports an error.
///
/// Control frames never get a data reply:
/// - Pings are answered automatically by axum's WebSocket implementation.
/// - Pongs are ignored.
/// - A close frame ends the loop, because the closing handshake forbids
///   sending further data messages.
async fn handle_websocket(mut socket: WebSocket) {
    if let Err(e) = socket
        .send(Message::Text("Connected to Cerebro Swarm Stream".into()))
        .await
    {
        eprintln!("Websocket error: {}", e);
        return;
    }

    while let Some(msg) = socket.recv().await {
        match msg {
            // Peer initiated the closing handshake: stop sending and let the
            // socket drop.
            Ok(Message::Close(_)) => break,
            // Keep-alive traffic; replying with text here would spam clients
            // that ping periodically.
            Ok(Message::Ping(_)) | Ok(Message::Pong(_)) => {}
            Ok(_) => {
                let reply =
                    Message::Text("Streaming not fully implemented. Check trace logs.".into());
                if let Err(e) = socket.send(reply).await {
                    // A failed send means the connection is unusable; stop
                    // instead of continuing to receive on a dead socket.
                    eprintln!("Websocket error: {}", e);
                    break;
                }
            }
            Err(e) => {
                eprintln!("Websocket error: {}", e);
                break;
            }
        }
    }
}