1use axum::{
2 extract::DefaultBodyLimit,
3 routing::{delete, get, post},
4 Router,
5};
6use std::net::SocketAddr;
7use tokio::sync::{broadcast, watch};
8use tower_http::cors::CorsLayer;
9use tower_http::trace::TraceLayer;
10
11use crate::handlers::{
12 add_provider_model_handler, chat_handler, create_cron_job_handler, create_mcp_handler,
13 create_provider_handler, delete_cron_job_handler, delete_mcp_handler, delete_provider_handler,
14 delete_provider_model_handler, delete_session_handler, delete_skill_handler, events_handler,
15 get_channels_handler, get_config_handler, get_cron_job_handler, get_mcps_handler,
16 get_provider_handler, get_provider_models_handler, get_providers_handler,
17 get_session_history_handler, get_sessions_handler, get_skills_handler, get_tools_handler,
18 heartbeat_handler, list_cron_jobs_handler, refresh_mcp_status_handler, reset_session_handler,
19 resolve_provider_handler, run_cron_job_handler, set_cron_job_enabled_handler,
20 set_mcp_enabled_handler, stop_chat_handler, stop_cron_job_handler, update_channel_handler,
21 update_config_handler, update_cron_job_handler, update_mcp_handler, update_provider_handler,
22 update_tools_handler, upload_file_handler, upload_skill_handler,
23};
24use crate::state::AppState;
25
26pub async fn run_server(
27 state: AppState,
28 port: u16,
29 mut shutdown_rx: broadcast::Receiver<()>,
30) -> anyhow::Result<()> {
31 let addr = SocketAddr::from(([127, 0, 0, 1], port));
32 tracing::info!("Listening on {}", addr);
33
34 let listener = tokio::net::TcpListener::bind(addr).await?;
35 run_server_with_listener_and_signal(state, listener, async move {
36 let _ = shutdown_rx.recv().await;
37 tracing::info!("Server shutting down signal received");
38 })
39 .await
40}
41
42pub async fn run_server_with_listener(
43 state: AppState,
44 listener: tokio::net::TcpListener,
45 mut shutdown_rx: watch::Receiver<bool>,
46) -> anyhow::Result<()> {
47 let listen_addr = listener.local_addr()?;
48 tracing::info!("Listening on {}", listen_addr);
49
50 run_server_with_listener_and_signal(state, listener, async move {
51 let _ = shutdown_rx.wait_for(|value| *value).await;
52 tracing::info!("Server shutting down signal received");
53 })
54 .await
55}
56
57async fn run_server_with_listener_and_signal<F>(
58 state: AppState,
59 listener: tokio::net::TcpListener,
60 shutdown_signal: F,
61) -> anyhow::Result<()>
62where
63 F: std::future::Future<Output = ()> + Send + 'static,
64{
65 let app = build_router(state);
66 axum::serve(listener, app)
67 .with_graceful_shutdown(async move {
68 shutdown_signal.await;
69 })
70 .await?;
71 Ok(())
72}
73
74pub fn build_router(state: AppState) -> Router {
76 Router::new()
77 .merge(runtime_routes())
78 .merge(provider_routes())
79 .merge(misc_routes())
80 .layer(CorsLayer::permissive())
81 .layer(TraceLayer::new_for_http())
82 .with_state(state)
83}
84
85fn runtime_routes() -> Router<AppState> {
86 Router::new()
87 .route("/api/chat", post(chat_handler))
88 .route("/api/chat/stop", post(stop_chat_handler))
89 .route("/api/events", get(events_handler))
90 .route("/api/sessions", get(get_sessions_handler))
91 .route(
92 "/api/sessions/:id",
93 get(get_session_history_handler)
94 .delete(delete_session_handler)
95 .post(delete_session_handler),
96 )
97 .route("/api/sessions/reset", post(reset_session_handler))
98 .route(
99 "/api/config",
100 get(get_config_handler).post(update_config_handler),
101 )
102 .route(
103 "/api/channels",
104 get(get_channels_handler).post(update_channel_handler),
105 )
106 .route(
107 "/api/tools",
108 get(get_tools_handler).post(update_tools_handler),
109 )
110 .route(
111 "/api/skills",
112 get(get_skills_handler).post(upload_skill_handler),
113 )
114 .route("/api/skills/:name", delete(delete_skill_handler))
115 .route(
116 "/api/files/upload",
117 post(upload_file_handler).layer(DefaultBodyLimit::max(50 * 1024 * 1024)),
118 ) .route("/api/mcps", get(get_mcps_handler).post(create_mcp_handler))
120 .route(
121 "/api/mcps/:name",
122 axum::routing::put(update_mcp_handler).delete(delete_mcp_handler),
123 )
124 .route("/api/mcps/:name/enable", post(set_mcp_enabled_handler))
125 .route("/api/mcps/:name/refresh", post(refresh_mcp_status_handler))
126 .route(
127 "/api/cron/jobs",
128 get(list_cron_jobs_handler).post(create_cron_job_handler),
129 )
130 .route(
131 "/api/cron/jobs/:id",
132 get(get_cron_job_handler)
133 .put(update_cron_job_handler)
134 .delete(delete_cron_job_handler),
135 )
136 .route(
137 "/api/cron/jobs/:id/enable",
138 post(set_cron_job_enabled_handler),
139 )
140 .route("/api/cron/jobs/:id/run", post(run_cron_job_handler))
141 .route("/api/cron/jobs/:id/stop", post(stop_cron_job_handler))
142}
143
144fn provider_routes() -> Router<AppState> {
145 Router::new()
146 .route(
147 "/api/providers",
148 get(get_providers_handler).post(create_provider_handler),
149 )
150 .route("/api/providers/resolve", post(resolve_provider_handler))
151 .route(
152 "/api/providers/:name",
153 get(get_provider_handler)
154 .put(update_provider_handler)
155 .delete(delete_provider_handler),
156 )
157 .route(
158 "/api/providers/:name/models",
159 get(get_provider_models_handler).post(add_provider_model_handler),
160 )
161 .route(
162 "/api/providers/:name/models/:model_id",
163 delete(delete_provider_model_handler),
164 )
165}
166
167fn misc_routes() -> Router<AppState> {
168 Router::new().route("/api/health", get(heartbeat_handler))
169}
170
#[cfg(test)]
mod tests {
    use super::build_router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::util::ServiceExt;

    use crate::state::{AppState, ManagerCommand};

    /// Builds a request for `uri` with an empty body and the builder's
    /// default method.
    fn empty_request(uri: &str) -> Request<Body> {
        Request::builder()
            .uri(uri)
            .body(Body::empty())
            .expect("a literal URI with an empty body must form a valid request")
    }

    /// Both `/api/health` and `/api/skills` must answer `200 OK` on the
    /// fully merged router.
    #[tokio::test]
    async fn build_router_keeps_health_and_skills_routes_without_overlap() {
        let (api_tx, mut api_rx) = tokio::sync::mpsc::channel(1);

        // Stand-in for the manager: answers `GetSkills` with an empty list
        // and ignores every other command.
        tokio::spawn(async move {
            while let Some(cmd) = api_rx.recv().await {
                if let ManagerCommand::GetSkills(tx) = cmd {
                    let _ = tx.send(Ok(vec![]));
                }
            }
        });

        let state = AppState {
            api_tx,
            bus: agent_diva_core::bus::MessageBus::new(),
        };

        // `state` is not needed afterwards, so it is moved rather than cloned.
        let app = build_router(state);

        let health_response = app
            .clone()
            .oneshot(empty_request("/api/health"))
            .await
            .expect("router should produce a response for /api/health");
        assert_eq!(health_response.status(), StatusCode::OK);

        let skills_response = app
            .oneshot(empty_request("/api/skills"))
            .await
            .expect("router should produce a response for /api/skills");
        assert_eq!(skills_response.status(), StatusCode::OK);
    }
}