Skip to main content

agent_diva_manager/
server.rs

1use axum::{
2    extract::DefaultBodyLimit,
3    routing::{delete, get, post},
4    Router,
5};
6use std::net::SocketAddr;
7use tokio::sync::{broadcast, watch};
8use tower_http::cors::CorsLayer;
9use tower_http::trace::TraceLayer;
10
11use crate::handlers::{
12    add_provider_model_handler, chat_handler, create_cron_job_handler, create_mcp_handler,
13    create_provider_handler, delete_cron_job_handler, delete_mcp_handler, delete_provider_handler,
14    delete_provider_model_handler, delete_session_handler, delete_skill_handler, events_handler,
15    get_channels_handler, get_config_handler, get_cron_job_handler, get_mcps_handler,
16    get_provider_handler, get_provider_models_handler, get_providers_handler,
17    get_session_history_handler, get_sessions_handler, get_skills_handler, get_tools_handler,
18    heartbeat_handler, list_cron_jobs_handler, refresh_mcp_status_handler, reset_session_handler,
19    resolve_provider_handler, run_cron_job_handler, set_cron_job_enabled_handler,
20    set_mcp_enabled_handler, stop_chat_handler, stop_cron_job_handler, update_channel_handler,
21    update_config_handler, update_cron_job_handler, update_mcp_handler, update_provider_handler,
22    update_tools_handler, upload_file_handler, upload_skill_handler,
23};
24use crate::state::AppState;
25
26pub async fn run_server(
27    state: AppState,
28    port: u16,
29    mut shutdown_rx: broadcast::Receiver<()>,
30) -> anyhow::Result<()> {
31    let addr = SocketAddr::from(([127, 0, 0, 1], port));
32    tracing::info!("Listening on {}", addr);
33
34    let listener = tokio::net::TcpListener::bind(addr).await?;
35    run_server_with_listener_and_signal(state, listener, async move {
36        let _ = shutdown_rx.recv().await;
37        tracing::info!("Server shutting down signal received");
38    })
39    .await
40}
41
42pub async fn run_server_with_listener(
43    state: AppState,
44    listener: tokio::net::TcpListener,
45    mut shutdown_rx: watch::Receiver<bool>,
46) -> anyhow::Result<()> {
47    let listen_addr = listener.local_addr()?;
48    tracing::info!("Listening on {}", listen_addr);
49
50    run_server_with_listener_and_signal(state, listener, async move {
51        let _ = shutdown_rx.wait_for(|value| *value).await;
52        tracing::info!("Server shutting down signal received");
53    })
54    .await
55}
56
57async fn run_server_with_listener_and_signal<F>(
58    state: AppState,
59    listener: tokio::net::TcpListener,
60    shutdown_signal: F,
61) -> anyhow::Result<()>
62where
63    F: std::future::Future<Output = ()> + Send + 'static,
64{
65    let app = build_router(state);
66    axum::serve(listener, app)
67        .with_graceful_shutdown(async move {
68            shutdown_signal.await;
69        })
70        .await?;
71    Ok(())
72}
73
74/// Build the axum router with all manager HTTP routes.
75pub fn build_router(state: AppState) -> Router {
76    Router::new()
77        .merge(runtime_routes())
78        .merge(provider_routes())
79        .merge(misc_routes())
80        .layer(CorsLayer::permissive())
81        .layer(TraceLayer::new_for_http())
82        .with_state(state)
83}
84
85fn runtime_routes() -> Router<AppState> {
86    Router::new()
87        .route("/api/chat", post(chat_handler))
88        .route("/api/chat/stop", post(stop_chat_handler))
89        .route("/api/events", get(events_handler))
90        .route("/api/sessions", get(get_sessions_handler))
91        .route(
92            "/api/sessions/:id",
93            get(get_session_history_handler)
94                .delete(delete_session_handler)
95                .post(delete_session_handler),
96        )
97        .route("/api/sessions/reset", post(reset_session_handler))
98        .route(
99            "/api/config",
100            get(get_config_handler).post(update_config_handler),
101        )
102        .route(
103            "/api/channels",
104            get(get_channels_handler).post(update_channel_handler),
105        )
106        .route(
107            "/api/tools",
108            get(get_tools_handler).post(update_tools_handler),
109        )
110        .route(
111            "/api/skills",
112            get(get_skills_handler).post(upload_skill_handler),
113        )
114        .route("/api/skills/:name", delete(delete_skill_handler))
115        .route(
116            "/api/files/upload",
117            post(upload_file_handler).layer(DefaultBodyLimit::max(50 * 1024 * 1024)),
118        ) // 50MB limit
119        .route("/api/mcps", get(get_mcps_handler).post(create_mcp_handler))
120        .route(
121            "/api/mcps/:name",
122            axum::routing::put(update_mcp_handler).delete(delete_mcp_handler),
123        )
124        .route("/api/mcps/:name/enable", post(set_mcp_enabled_handler))
125        .route("/api/mcps/:name/refresh", post(refresh_mcp_status_handler))
126        .route(
127            "/api/cron/jobs",
128            get(list_cron_jobs_handler).post(create_cron_job_handler),
129        )
130        .route(
131            "/api/cron/jobs/:id",
132            get(get_cron_job_handler)
133                .put(update_cron_job_handler)
134                .delete(delete_cron_job_handler),
135        )
136        .route(
137            "/api/cron/jobs/:id/enable",
138            post(set_cron_job_enabled_handler),
139        )
140        .route("/api/cron/jobs/:id/run", post(run_cron_job_handler))
141        .route("/api/cron/jobs/:id/stop", post(stop_cron_job_handler))
142}
143
144fn provider_routes() -> Router<AppState> {
145    Router::new()
146        .route(
147            "/api/providers",
148            get(get_providers_handler).post(create_provider_handler),
149        )
150        .route("/api/providers/resolve", post(resolve_provider_handler))
151        .route(
152            "/api/providers/:name",
153            get(get_provider_handler)
154                .put(update_provider_handler)
155                .delete(delete_provider_handler),
156        )
157        .route(
158            "/api/providers/:name/models",
159            get(get_provider_models_handler).post(add_provider_model_handler),
160        )
161        .route(
162            "/api/providers/:name/models/:model_id",
163            delete(delete_provider_model_handler),
164        )
165}
166
167fn misc_routes() -> Router<AppState> {
168    Router::new().route("/api/health", get(heartbeat_handler))
169}
170
#[cfg(test)]
mod tests {
    use super::build_router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::util::ServiceExt;

    use crate::state::{AppState, ManagerCommand};

    /// Build a body-less GET request for `uri` (GET is the builder's default method).
    fn empty_get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .body(Body::empty())
            .expect("static test request is well-formed")
    }

    /// Building the full router runs axum's route-overlap checks (it panics on
    /// conflicting routes). Both `/api/health` and `/api/skills` must then answer
    /// `200 OK`.
    #[tokio::test]
    async fn build_router_keeps_health_and_skills_routes_without_overlap() {
        let (api_tx, mut api_rx) = tokio::sync::mpsc::channel(1);
        // get_skills_handler sends GetSkills on api_tx and awaits a oneshot reply; without a
        // consumer the request would hang forever. Other commands are dropped unanswered.
        tokio::spawn(async move {
            while let Some(cmd) = api_rx.recv().await {
                if let ManagerCommand::GetSkills(tx) = cmd {
                    let _ = tx.send(Ok(vec![]));
                }
            }
        });
        let state = AppState {
            api_tx,
            bus: agent_diva_core::bus::MessageBus::new(),
        };

        // `state` is not used again, so it is moved instead of cloned.
        let app = build_router(state);

        let health_response = app
            .clone()
            .oneshot(empty_get_request("/api/health"))
            .await
            .unwrap();
        assert_eq!(health_response.status(), StatusCode::OK);

        let skills_response = app
            .oneshot(empty_get_request("/api/skills"))
            .await
            .unwrap();
        assert_eq!(skills_response.status(), StatusCode::OK);
    }
}