Skip to main content

rho_server/
lib.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use axum::{
6    extract::{Path, State},
7    http::{HeaderMap, StatusCode},
8    response::{
9        sse::{Event, KeepAlive, Sse},
10        IntoResponse, Json,
11    },
12    routing::{delete, get, post},
13    Router,
14};
15use serde::{Deserialize, Serialize};
16use tokio::sync::RwLock;
17use tokio_util::sync::CancellationToken;
18
19use rho_core::agent_loop::{agent_loop, AgentLoopConfig};
20use rho_core::compaction;
21use rho_core::models::{ModelConfig, ModelRegistry};
22use rho_core::tool::AgentTool;
23use rho_core::types::*;
24
25// === Config & State ===
26
27#[derive(Clone)]
28pub struct ServerConfig {
29    pub model_config: ModelConfig,
30    pub api_key: String,
31    pub system_prompt: String,
32    pub tools_factory: Arc<dyn Fn() -> Vec<Arc<dyn AgentTool>> + Send + Sync>,
33    pub thinking: ThinkingLevel,
34    pub bearer_token: Option<String>,
35    pub compact_threshold: Option<f64>,
36    pub cwd: PathBuf,
37}
38
39struct SessionHandle {
40    messages: Vec<Message>,
41    model_id: String,
42    cancel: CancellationToken,
43    created_at: u64,
44}
45
46struct AppState {
47    sessions: RwLock<HashMap<String, SessionHandle>>,
48    config: ServerConfig,
49}
50
51fn check_bearer_token(state: &AppState, headers: &HeaderMap) -> Result<(), StatusCode> {
52    if let Some(ref expected) = state.config.bearer_token {
53        let auth = headers
54            .get("authorization")
55            .and_then(|v| v.to_str().ok())
56            .and_then(|v| v.strip_prefix("Bearer "));
57        match auth {
58            Some(token) if token == expected => Ok(()),
59            _ => Err(StatusCode::UNAUTHORIZED),
60        }
61    } else {
62        Ok(())
63    }
64}
65
/// Authentication guard placed at the top of a handler body.
///
/// Expands to a call to `check_bearer_token`; on failure the enclosing
/// function returns early with `Err(status)`, otherwise execution continues.
/// The health endpoint does not use this macro and is unauthenticated.
macro_rules! require_auth {
    ($state:expr, $headers:expr) => {
        match check_bearer_token(&$state, &$headers) {
            Ok(_) => {}
            Err(status) => return Err(status),
        }
    };
}
75
76// === Route handlers ===
77
/// Liveness probe: always answers with the plain-text body `ok`.
///
/// Takes neither state nor headers, so no authentication is performed.
async fn health() -> &'static str {
    const BODY: &str = "ok";
    BODY
}
81
82#[derive(Deserialize)]
83struct CreateSessionRequest {
84    model: Option<String>,
85}
86
87#[derive(Serialize)]
88struct SessionInfo {
89    id: String,
90    model: String,
91    message_count: usize,
92    created_at: u64,
93}
94
95async fn create_session(
96    State(state): State<Arc<AppState>>,
97    headers: HeaderMap,
98    Json(body): Json<CreateSessionRequest>,
99) -> Result<Json<SessionInfo>, StatusCode> {
100    require_auth!(state, headers);
101
102    let id = uuid::Uuid::new_v4().to_string();
103    let model_id = body
104        .model
105        .unwrap_or_else(|| state.config.model_config.id.clone());
106    let now = std::time::SystemTime::now()
107        .duration_since(std::time::UNIX_EPOCH)
108        .unwrap()
109        .as_millis() as u64;
110
111    let handle = SessionHandle {
112        messages: Vec::new(),
113        model_id: model_id.clone(),
114        cancel: CancellationToken::new(),
115        created_at: now,
116    };
117
118    let info = SessionInfo {
119        id: id.clone(),
120        model: model_id,
121        message_count: 0,
122        created_at: now,
123    };
124
125    state.sessions.write().await.insert(id, handle);
126    Ok(Json(info))
127}
128
129async fn list_sessions(
130    State(state): State<Arc<AppState>>,
131    headers: HeaderMap,
132) -> Result<Json<Vec<SessionInfo>>, StatusCode> {
133    require_auth!(state, headers);
134
135    let sessions = state.sessions.read().await;
136    let list: Vec<SessionInfo> = sessions
137        .iter()
138        .map(|(id, h)| SessionInfo {
139            id: id.clone(),
140            model: h.model_id.clone(),
141            message_count: h.messages.len(),
142            created_at: h.created_at,
143        })
144        .collect();
145    Ok(Json(list))
146}
147
148async fn get_session(
149    State(state): State<Arc<AppState>>,
150    headers: HeaderMap,
151    Path(id): Path<String>,
152) -> Result<Json<SessionInfo>, StatusCode> {
153    require_auth!(state, headers);
154
155    let sessions = state.sessions.read().await;
156    let handle = sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?;
157    Ok(Json(SessionInfo {
158        id: id.clone(),
159        model: handle.model_id.clone(),
160        message_count: handle.messages.len(),
161        created_at: handle.created_at,
162    }))
163}
164
165async fn delete_session_handler(
166    State(state): State<Arc<AppState>>,
167    headers: HeaderMap,
168    Path(id): Path<String>,
169) -> Result<StatusCode, StatusCode> {
170    require_auth!(state, headers);
171
172    let mut sessions = state.sessions.write().await;
173    if let Some(handle) = sessions.remove(&id) {
174        handle.cancel.cancel();
175        Ok(StatusCode::NO_CONTENT)
176    } else {
177        Err(StatusCode::NOT_FOUND)
178    }
179}
180
181#[derive(Deserialize)]
182struct SendMessageRequest {
183    message: String,
184}
185
186async fn send_message(
187    State(state): State<Arc<AppState>>,
188    headers: HeaderMap,
189    Path(id): Path<String>,
190    Json(body): Json<SendMessageRequest>,
191) -> Result<impl IntoResponse, StatusCode> {
192    require_auth!(state, headers);
193
194    let (messages, cancel) = {
195        let mut sessions = state.sessions.write().await;
196        let handle = sessions.get_mut(&id).ok_or(StatusCode::NOT_FOUND)?;
197
198        let now = std::time::SystemTime::now()
199            .duration_since(std::time::UNIX_EPOCH)
200            .unwrap()
201            .as_millis() as u64;
202
203        handle.messages.push(Message::User {
204            content: UserContent::Text(body.message),
205            timestamp: now,
206        });
207
208        let cancel = CancellationToken::new();
209        handle.cancel = cancel.clone();
210        (handle.messages.clone(), cancel)
211    };
212
213    let config = &state.config;
214    let model = ModelRegistry::to_model(&config.model_config);
215    let tools = (config.tools_factory)();
216    let transform_messages = config
217        .compact_threshold
218        .map(compaction::make_compaction_transform);
219
220    let loop_config = AgentLoopConfig {
221        model,
222        api_key: config.api_key.clone(),
223        system_prompt: config.system_prompt.clone(),
224        tools,
225        thinking: config.thinking,
226        max_tokens: None,
227        stream_fn: rho_provider::stream_fn_for_model(&config.model_config),
228        get_steering_messages: None,
229        get_follow_up_messages: None,
230        transform_messages,
231        post_tools_hooks: Vec::new(),
232    };
233
234    let mut consumer = agent_loop(messages, loop_config, cancel);
235
236    let state_clone = state.clone();
237    let session_id = id.clone();
238
239    let stream = async_stream::stream! {
240        while let Some(event) = consumer.next().await {
241            let sse = match &event {
242                AgentEvent::MessageUpdate { event: stream_event, .. } => match stream_event {
243                    AssistantStreamEvent::TextDelta { delta, .. } => {
244                        Some(Event::default()
245                            .event("text_delta")
246                            .json_data(serde_json::json!({"text": delta})))
247                    }
248                    _ => None,
249                },
250                AgentEvent::ToolExecutionStart { tool_call_id, tool_name, args } => {
251                    Some(Event::default()
252                        .event("tool_start")
253                        .json_data(serde_json::json!({
254                            "tool_id": tool_call_id,
255                            "tool_name": tool_name,
256                            "args": args
257                        })))
258                }
259                AgentEvent::ToolExecutionEnd { tool_call_id, tool_name, is_error, .. } => {
260                    Some(Event::default()
261                        .event("tool_end")
262                        .json_data(serde_json::json!({
263                            "tool_id": tool_call_id,
264                            "tool_name": tool_name,
265                            "is_error": is_error
266                        })))
267                }
268                AgentEvent::TurnStart => {
269                    Some(Event::default()
270                        .event("turn_start")
271                        .json_data(serde_json::json!({})))
272                }
273                AgentEvent::TurnEnd { .. } => {
274                    Some(Event::default()
275                        .event("turn_end")
276                        .json_data(serde_json::json!({})))
277                }
278                AgentEvent::ContextCompacted { original_estimate, compacted_estimate, .. } => {
279                    Some(Event::default()
280                        .event("context_compacted")
281                        .json_data(serde_json::json!({
282                            "original_estimate": original_estimate,
283                            "compacted_estimate": compacted_estimate
284                        })))
285                }
286                AgentEvent::AgentEnd { messages } => {
287                    // Store final messages back into session
288                    let mut sessions = state_clone.sessions.write().await;
289                    if let Some(handle) = sessions.get_mut(&session_id) {
290                        handle.messages = messages.clone();
291                    }
292                    Some(Event::default()
293                        .event("done")
294                        .json_data(serde_json::json!({"message_count": messages.len()})))
295                }
296                _ => None,
297            };
298
299            if let Some(Ok(e)) = sse {
300                yield Ok::<_, std::convert::Infallible>(e);
301            }
302        }
303    };
304
305    Ok(Sse::new(stream).keep_alive(KeepAlive::default()))
306}
307
308async fn get_events(
309    State(state): State<Arc<AppState>>,
310    headers: HeaderMap,
311    Path(id): Path<String>,
312) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
313    require_auth!(state, headers);
314
315    let sessions = state.sessions.read().await;
316    let handle = sessions.get(&id).ok_or(StatusCode::NOT_FOUND)?;
317
318    let events: Vec<serde_json::Value> = handle
319        .messages
320        .iter()
321        .map(|msg| serde_json::to_value(msg).unwrap_or_default())
322        .collect();
323
324    Ok(Json(events))
325}
326
327// === Server entry point ===
328
329pub async fn start_server_with_addr(
330    config: ServerConfig,
331    host: &str,
332    port: u16,
333) -> anyhow::Result<()> {
334    let state = Arc::new(AppState {
335        sessions: RwLock::new(HashMap::new()),
336        config,
337    });
338
339    let app = Router::new()
340        .route("/health", get(health))
341        .route("/v1/sessions", post(create_session))
342        .route("/v1/sessions", get(list_sessions))
343        .route("/v1/sessions/{id}", get(get_session))
344        .route("/v1/sessions/{id}", delete(delete_session_handler))
345        .route("/v1/sessions/{id}/send", post(send_message))
346        .route("/v1/sessions/{id}/events", get(get_events))
347        .layer(tower_http::cors::CorsLayer::permissive())
348        .with_state(state);
349
350    let addr = format!("{host}:{port}");
351    tracing::info!("Starting rho server on {}", addr);
352    let listener = tokio::net::TcpListener::bind(&addr).await?;
353    axum::serve(listener, app).await?;
354    Ok(())
355}