1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use axum::{
6 extract::{Path, State},
7 http::{HeaderMap, StatusCode},
8 response::{
9 sse::{Event, KeepAlive, Sse},
10 IntoResponse, Json,
11 },
12 routing::{delete, get, post},
13 Router,
14};
15use serde::{Deserialize, Serialize};
16use tokio::sync::RwLock;
17use tokio_util::sync::CancellationToken;
18
19use rho_core::agent_loop::{agent_loop, AgentLoopConfig};
20use rho_core::compaction;
21use rho_core::models::{ModelConfig, ModelRegistry};
22use rho_core::tool::AgentTool;
23use rho_core::types::*;
24
25#[derive(Clone)]
28pub struct ServerConfig {
29 pub model_config: ModelConfig,
30 pub api_key: String,
31 pub system_prompt: String,
32 pub tools_factory: Arc<dyn Fn() -> Vec<Arc<dyn AgentTool>> + Send + Sync>,
33 pub thinking: ThinkingLevel,
34 pub bearer_token: Option<String>,
35 pub compact_threshold: Option<f64>,
36 pub cwd: PathBuf,
37}
38
39struct SessionHandle {
40 messages: Vec<Message>,
41 model_id: String,
42 cancel: CancellationToken,
43 created_at: u64,
44}
45
46struct AppState {
47 sessions: RwLock<HashMap<String, SessionHandle>>,
48 config: ServerConfig,
49}
50
51fn check_bearer_token(state: &AppState, headers: &HeaderMap) -> Result<(), StatusCode> {
52 if let Some(ref expected) = state.config.bearer_token {
53 let auth = headers
54 .get("authorization")
55 .and_then(|v| v.to_str().ok())
56 .and_then(|v| v.strip_prefix("Bearer "));
57 match auth {
58 Some(token) if token == expected => Ok(()),
59 _ => Err(StatusCode::UNAUTHORIZED),
60 }
61 } else {
62 Ok(())
63 }
64}
65
/// Runs the bearer-token check and returns early from the enclosing handler
/// with the failure status when it does not pass; otherwise execution
/// continues after the macro call.
macro_rules! require_auth {
    ($state:expr, $headers:expr) => {
        match check_bearer_token(&$state, &$headers) {
            Ok(()) => {}
            Err(status) => return Err(status),
        }
    };
}
75
/// Liveness probe handler: always responds with the static body `ok`.
async fn health() -> &'static str {
    const HEALTHY: &str = "ok";
    HEALTHY
}
81
82#[derive(Deserialize)]
83struct CreateSessionRequest {
84 model: Option<String>,
85}
86
87#[derive(Serialize)]
88struct SessionInfo {
89 id: String,
90 model: String,
91 message_count: usize,
92 created_at: u64,
93}
94
95async fn create_session(
96 State(state): State<Arc<AppState>>,
97 headers: HeaderMap,
98 Json(body): Json<CreateSessionRequest>,
99) -> Result<Json<SessionInfo>, StatusCode> {
100 require_auth!(state, headers);
101
102 let id = uuid::Uuid::new_v4().to_string();
103 let model_id = body
104 .model
105 .unwrap_or_else(|| state.config.model_config.id.clone());
106 let now = std::time::SystemTime::now()
107 .duration_since(std::time::UNIX_EPOCH)
108 .unwrap()
109 .as_millis() as u64;
110
111 let handle = SessionHandle {
112 messages: Vec::new(),
113 model_id: model_id.clone(),
114 cancel: CancellationToken::new(),
115 created_at: now,
116 };
117
118 let info = SessionInfo {
119 id: id.clone(),
120 model: model_id,
121 message_count: 0,
122 created_at: now,
123 };
124
125 state.sessions.write().await.insert(id, handle);
126 Ok(Json(info))
127}
128
129async fn list_sessions(
130 State(state): State<Arc<AppState>>,
131 headers: HeaderMap,
132) -> Result<Json<Vec<SessionInfo>>, StatusCode> {
133 require_auth!(state, headers);
134
135 let sessions = state.sessions.read().await;
136 let list: Vec<SessionInfo> = sessions
137 .iter()
138 .map(|(id, h)| SessionInfo {
139 id: id.clone(),
140 model: h.model_id.clone(),
141 message_count: h.messages.len(),
142 created_at: h.created_at,
143 })
144 .collect();
145 Ok(Json(list))
146}
147
148async fn get_session(
149 State(state): State<Arc<AppState>>,
150 headers: HeaderMap,
151 Path(id): Path<String>,
152) -> Result<Json<SessionInfo>, StatusCode> {
153 require_auth!(state, headers);
154
155 let sessions = state.sessions.read().await;
156 let handle = sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?;
157 Ok(Json(SessionInfo {
158 id: id.clone(),
159 model: handle.model_id.clone(),
160 message_count: handle.messages.len(),
161 created_at: handle.created_at,
162 }))
163}
164
165async fn delete_session_handler(
166 State(state): State<Arc<AppState>>,
167 headers: HeaderMap,
168 Path(id): Path<String>,
169) -> Result<StatusCode, StatusCode> {
170 require_auth!(state, headers);
171
172 let mut sessions = state.sessions.write().await;
173 if let Some(handle) = sessions.remove(&id) {
174 handle.cancel.cancel();
175 Ok(StatusCode::NO_CONTENT)
176 } else {
177 Err(StatusCode::NOT_FOUND)
178 }
179}
180
181#[derive(Deserialize)]
182struct SendMessageRequest {
183 message: String,
184}
185
186async fn send_message(
187 State(state): State<Arc<AppState>>,
188 headers: HeaderMap,
189 Path(id): Path<String>,
190 Json(body): Json<SendMessageRequest>,
191) -> Result<impl IntoResponse, StatusCode> {
192 require_auth!(state, headers);
193
194 let (messages, cancel) = {
195 let mut sessions = state.sessions.write().await;
196 let handle = sessions.get_mut(&id).ok_or(StatusCode::NOT_FOUND)?;
197
198 let now = std::time::SystemTime::now()
199 .duration_since(std::time::UNIX_EPOCH)
200 .unwrap()
201 .as_millis() as u64;
202
203 handle.messages.push(Message::User {
204 content: UserContent::Text(body.message),
205 timestamp: now,
206 });
207
208 let cancel = CancellationToken::new();
209 handle.cancel = cancel.clone();
210 (handle.messages.clone(), cancel)
211 };
212
213 let config = &state.config;
214 let model = ModelRegistry::to_model(&config.model_config);
215 let tools = (config.tools_factory)();
216 let transform_messages = config
217 .compact_threshold
218 .map(compaction::make_compaction_transform);
219
220 let loop_config = AgentLoopConfig {
221 model,
222 api_key: config.api_key.clone(),
223 system_prompt: config.system_prompt.clone(),
224 tools,
225 thinking: config.thinking,
226 max_tokens: None,
227 stream_fn: rho_provider::stream_fn_for_model(&config.model_config),
228 get_steering_messages: None,
229 get_follow_up_messages: None,
230 transform_messages,
231 post_tools_hooks: Vec::new(),
232 };
233
234 let mut consumer = agent_loop(messages, loop_config, cancel);
235
236 let state_clone = state.clone();
237 let session_id = id.clone();
238
239 let stream = async_stream::stream! {
240 while let Some(event) = consumer.next().await {
241 let sse = match &event {
242 AgentEvent::MessageUpdate { event: stream_event, .. } => match stream_event {
243 AssistantStreamEvent::TextDelta { delta, .. } => {
244 Some(Event::default()
245 .event("text_delta")
246 .json_data(serde_json::json!({"text": delta})))
247 }
248 _ => None,
249 },
250 AgentEvent::ToolExecutionStart { tool_call_id, tool_name, args } => {
251 Some(Event::default()
252 .event("tool_start")
253 .json_data(serde_json::json!({
254 "tool_id": tool_call_id,
255 "tool_name": tool_name,
256 "args": args
257 })))
258 }
259 AgentEvent::ToolExecutionEnd { tool_call_id, tool_name, is_error, .. } => {
260 Some(Event::default()
261 .event("tool_end")
262 .json_data(serde_json::json!({
263 "tool_id": tool_call_id,
264 "tool_name": tool_name,
265 "is_error": is_error
266 })))
267 }
268 AgentEvent::TurnStart => {
269 Some(Event::default()
270 .event("turn_start")
271 .json_data(serde_json::json!({})))
272 }
273 AgentEvent::TurnEnd { .. } => {
274 Some(Event::default()
275 .event("turn_end")
276 .json_data(serde_json::json!({})))
277 }
278 AgentEvent::ContextCompacted { original_estimate, compacted_estimate, .. } => {
279 Some(Event::default()
280 .event("context_compacted")
281 .json_data(serde_json::json!({
282 "original_estimate": original_estimate,
283 "compacted_estimate": compacted_estimate
284 })))
285 }
286 AgentEvent::AgentEnd { messages } => {
287 let mut sessions = state_clone.sessions.write().await;
289 if let Some(handle) = sessions.get_mut(&session_id) {
290 handle.messages = messages.clone();
291 }
292 Some(Event::default()
293 .event("done")
294 .json_data(serde_json::json!({"message_count": messages.len()})))
295 }
296 _ => None,
297 };
298
299 if let Some(Ok(e)) = sse {
300 yield Ok::<_, std::convert::Infallible>(e);
301 }
302 }
303 };
304
305 Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
306}
307
308async fn get_events(
309 State(state): State<Arc<AppState>>,
310 headers: HeaderMap,
311 Path(id): Path<String>,
312) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
313 require_auth!(state, headers);
314
315 let sessions = state.sessions.read().await;
316 let handle = sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?;
317
318 let events: Vec<serde_json::Value> = handle
319 .messages
320 .iter()
321 .map(|msg| serde_json::to_value(msg).unwrap_or_default())
322 .collect();
323
324 Ok(Json(events))
325}
326
327pub async fn start_server_with_addr(
330 config: ServerConfig,
331 host: &str,
332 port: u16,
333) -> anyhow::Result<()> {
334 let state = Arc::new(AppState {
335 sessions: RwLock::new(HashMap::new()),
336 config,
337 });
338
339 let app = Router::new()
340 .route("/health", get(health))
341 .route("/v1/sessions", post(create_session))
342 .route("/v1/sessions", get(list_sessions))
343 .route("/v1/sessions/{id}", get(get_session))
344 .route("/v1/sessions/{id}", delete(delete_session_handler))
345 .route("/v1/sessions/{id}/send", post(send_message))
346 .route("/v1/sessions/{id}/events", get(get_events))
347 .layer(tower_http::cors::CorsLayer::permissive())
348 .with_state(state);
349
350 let addr = format!("{host}:{port}");
351 tracing::info!("Starting rho server on {}", addr);
352 let listener = tokio::net::TcpListener::bind(&addr).await?;
353 axum::serve(listener, app).await?;
354 Ok(())
355}