Skip to main content

wraith_server/
lib.rs

1use std::collections::HashMap;
2use std::convert::Infallible;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
7use async_stream::stream;
8use axum::extract::{Path, State};
9use axum::http::StatusCode;
10use axum::response::sse::{Event, KeepAlive, Sse};
11use axum::response::IntoResponse;
12use axum::routing::{get, post};
13use axum::{Json, Router};
14use runtime::{ConversationMessage, Session as RuntimeSession};
15use serde::{Deserialize, Serialize};
16use tokio::sync::{broadcast, RwLock};
17
18pub type SessionId = String;
19pub type SessionStore = Arc<RwLock<HashMap<SessionId, Session>>>;
20
21const BROADCAST_CAPACITY: usize = 64;
22
23#[derive(Clone)]
24pub struct AppState {
25    sessions: SessionStore,
26    next_session_id: Arc<AtomicU64>,
27}
28
29impl AppState {
30    #[must_use]
31    pub fn new() -> Self {
32        Self {
33            sessions: Arc::new(RwLock::new(HashMap::new())),
34            next_session_id: Arc::new(AtomicU64::new(1)),
35        }
36    }
37
38    fn allocate_session_id(&self) -> SessionId {
39        let id = self.next_session_id.fetch_add(1, Ordering::Relaxed);
40        format!("session-{id}")
41    }
42}
43
44impl Default for AppState {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50#[derive(Clone)]
51pub struct Session {
52    pub id: SessionId,
53    pub created_at: u64,
54    pub conversation: RuntimeSession,
55    events: broadcast::Sender<SessionEvent>,
56}
57
58impl Session {
59    fn new(id: SessionId) -> Self {
60        let (events, _) = broadcast::channel(BROADCAST_CAPACITY);
61        Self {
62            id,
63            created_at: unix_timestamp_millis(),
64            conversation: RuntimeSession::new(),
65            events,
66        }
67    }
68
69    fn subscribe(&self) -> broadcast::Receiver<SessionEvent> {
70        self.events.subscribe()
71    }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
75#[serde(tag = "type", rename_all = "snake_case")]
76enum SessionEvent {
77    Snapshot {
78        session_id: SessionId,
79        session: RuntimeSession,
80    },
81    Message {
82        session_id: SessionId,
83        message: ConversationMessage,
84    },
85}
86
87impl SessionEvent {
88    fn event_name(&self) -> &'static str {
89        match self {
90            Self::Snapshot { .. } => "snapshot",
91            Self::Message { .. } => "message",
92        }
93    }
94
95    fn to_sse_event(&self) -> Result<Event, serde_json::Error> {
96        Ok(Event::default()
97            .event(self.event_name())
98            .data(serde_json::to_string(self)?))
99    }
100}
101
102#[derive(Debug, Serialize)]
103struct ErrorResponse {
104    error: String,
105}
106
107type ApiError = (StatusCode, Json<ErrorResponse>);
108type ApiResult<T> = Result<T, ApiError>;
109
110#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
111pub struct CreateSessionResponse {
112    pub session_id: SessionId,
113}
114
115#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
116pub struct SessionSummary {
117    pub id: SessionId,
118    pub created_at: u64,
119    pub message_count: usize,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
123pub struct ListSessionsResponse {
124    pub sessions: Vec<SessionSummary>,
125}
126
127#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
128pub struct SessionDetailsResponse {
129    pub id: SessionId,
130    pub created_at: u64,
131    pub session: RuntimeSession,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
135pub struct SendMessageRequest {
136    pub message: String,
137}
138
139pub fn app(state: AppState) -> Router {
140    Router::new()
141        .route("/sessions", post(create_session).get(list_sessions))
142        .route("/sessions/{id}", get(get_session))
143        .route("/sessions/{id}/events", get(stream_session_events))
144        .route("/sessions/{id}/message", post(send_message))
145        .with_state(state)
146}
147
148async fn create_session(
149    State(state): State<AppState>,
150) -> (StatusCode, Json<CreateSessionResponse>) {
151    let session_id = state.allocate_session_id();
152    let session = Session::new(session_id.clone());
153
154    state
155        .sessions
156        .write()
157        .await
158        .insert(session_id.clone(), session);
159
160    (
161        StatusCode::CREATED,
162        Json(CreateSessionResponse { session_id }),
163    )
164}
165
166async fn list_sessions(State(state): State<AppState>) -> Json<ListSessionsResponse> {
167    let sessions = state.sessions.read().await;
168    let mut summaries = sessions
169        .values()
170        .map(|session| SessionSummary {
171            id: session.id.clone(),
172            created_at: session.created_at,
173            message_count: session.conversation.messages.len(),
174        })
175        .collect::<Vec<_>>();
176    summaries.sort_by(|left, right| left.id.cmp(&right.id));
177
178    Json(ListSessionsResponse {
179        sessions: summaries,
180    })
181}
182
183async fn get_session(
184    State(state): State<AppState>,
185    Path(id): Path<SessionId>,
186) -> ApiResult<Json<SessionDetailsResponse>> {
187    let sessions = state.sessions.read().await;
188    let session = sessions
189        .get(&id)
190        .ok_or_else(|| not_found(format!("session `{id}` not found")))?;
191
192    Ok(Json(SessionDetailsResponse {
193        id: session.id.clone(),
194        created_at: session.created_at,
195        session: session.conversation.clone(),
196    }))
197}
198
199async fn send_message(
200    State(state): State<AppState>,
201    Path(id): Path<SessionId>,
202    Json(payload): Json<SendMessageRequest>,
203) -> ApiResult<StatusCode> {
204    let message = ConversationMessage::user_text(payload.message);
205    let broadcaster = {
206        let mut sessions = state.sessions.write().await;
207        let session = sessions
208            .get_mut(&id)
209            .ok_or_else(|| not_found(format!("session `{id}` not found")))?;
210        session.conversation.messages.push(message.clone());
211        session.events.clone()
212    };
213
214    let _ = broadcaster.send(SessionEvent::Message {
215        session_id: id,
216        message,
217    });
218
219    Ok(StatusCode::NO_CONTENT)
220}
221
222async fn stream_session_events(
223    State(state): State<AppState>,
224    Path(id): Path<SessionId>,
225) -> ApiResult<impl IntoResponse> {
226    let (snapshot, mut receiver) = {
227        let sessions = state.sessions.read().await;
228        let session = sessions
229            .get(&id)
230            .ok_or_else(|| not_found(format!("session `{id}` not found")))?;
231        (
232            SessionEvent::Snapshot {
233                session_id: session.id.clone(),
234                session: session.conversation.clone(),
235            },
236            session.subscribe(),
237        )
238    };
239
240    let stream = stream! {
241        if let Ok(event) = snapshot.to_sse_event() {
242            yield Ok::<Event, Infallible>(event);
243        }
244
245        loop {
246            match receiver.recv().await {
247                Ok(event) => {
248                    if let Ok(sse_event) = event.to_sse_event() {
249                        yield Ok::<Event, Infallible>(sse_event);
250                    }
251                }
252                Err(broadcast::error::RecvError::Lagged(_)) => {},
253                Err(broadcast::error::RecvError::Closed) => break,
254            }
255        }
256    };
257
258    Ok(Sse::new(stream).keep_alive(KeepAlive::new().interval(Duration::from_secs(15))))
259}
260
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Never panics: a system clock that reads earlier than the epoch yields `0`,
/// and a millisecond count that does not fit in `u64` yields `u64::MAX`.
fn unix_timestamp_millis() -> u64 {
    // A misconfigured clock is an environment problem, not a programming error,
    // so it must not take down the request handler that asks for a timestamp.
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
270
271fn not_found(message: String) -> ApiError {
272    (
273        StatusCode::NOT_FOUND,
274        Json(ErrorResponse { error: message }),
275    )
276}
277
#[cfg(test)]
mod tests {
    use super::{
        app, AppState, CreateSessionResponse, ListSessionsResponse, SessionDetailsResponse,
    };
    use reqwest::Client;
    use std::net::SocketAddr;
    use std::time::Duration;
    use tokio::net::TcpListener;
    use tokio::task::JoinHandle;
    use tokio::time::timeout;

    /// The app served from a spawned task on a loopback port picked by the OS.
    /// The task is aborted when the value is dropped.
    struct TestServer {
        address: SocketAddr,
        handle: JoinHandle<()>,
    }

    impl TestServer {
        /// Binds the listener, starts serving a fresh `AppState` and returns
        /// the bound address together with the task handle.
        async fn spawn() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0")
                .await
                .expect("test listener should bind");
            let address = listener
                .local_addr()
                .expect("listener should report local address");
            let router = app(AppState::default());
            let handle = tokio::spawn(async move {
                axum::serve(listener, router)
                    .await
                    .expect("server should run");
            });

            Self { address, handle }
        }

        /// Absolute URL of `path` on this server.
        fn url(&self, path: &str) -> String {
            format!("http://{address}{path}", address = self.address)
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.handle.abort();
        }
    }

    /// Issues `POST /sessions` and parses the response body.
    async fn create_session(client: &Client, server: &TestServer) -> CreateSessionResponse {
        let response = client
            .post(server.url("/sessions"))
            .send()
            .await
            .expect("create request should succeed")
            .error_for_status()
            .expect("create request should return success");

        response
            .json::<CreateSessionResponse>()
            .await
            .expect("create response should parse")
    }

    /// Issues `GET /sessions/{session_id}` and parses the response body.
    async fn fetch_details(
        client: &Client,
        server: &TestServer,
        session_id: &str,
    ) -> SessionDetailsResponse {
        let response = client
            .get(server.url(&format!("/sessions/{session_id}")))
            .send()
            .await
            .expect("details request should succeed")
            .error_for_status()
            .expect("details request should return success");

        response
            .json::<SessionDetailsResponse>()
            .await
            .expect("details response should parse")
    }

    /// Returns the next frame of the SSE response, without its terminating
    /// blank line. Bytes read past the frame stay in `buffer` for the next call.
    async fn next_sse_frame(response: &mut reqwest::Response, buffer: &mut String) -> String {
        loop {
            if let Some(boundary) = buffer.find("\n\n") {
                // Keep everything after the separator in `buffer`, hand back the rest.
                let remainder = buffer.split_off(boundary + 2);
                let mut frame = std::mem::replace(buffer, remainder);
                frame.truncate(boundary);
                return frame;
            }

            let chunk = timeout(Duration::from_secs(5), response.chunk())
                .await
                .expect("SSE stream should yield within timeout")
                .expect("SSE stream should remain readable")
                .expect("SSE stream should stay open");
            buffer.push_str(&String::from_utf8_lossy(&chunk));
        }
    }

    #[tokio::test]
    async fn creates_and_lists_sessions() {
        let server = TestServer::spawn().await;
        let client = Client::new();

        // given
        let created = create_session(&client, &server).await;

        // when
        let listing = client
            .get(server.url("/sessions"))
            .send()
            .await
            .expect("list request should succeed")
            .error_for_status()
            .expect("list request should return success")
            .json::<ListSessionsResponse>()
            .await
            .expect("list response should parse");
        let details = fetch_details(&client, &server, &created.session_id).await;

        // then
        assert_eq!(created.session_id, "session-1");
        assert_eq!(listing.sessions.len(), 1);
        assert_eq!(listing.sessions[0].id, created.session_id);
        assert_eq!(listing.sessions[0].message_count, 0);
        assert_eq!(details.id, "session-1");
        assert!(details.session.messages.is_empty());
    }

    #[tokio::test]
    async fn streams_message_events_and_persists_message_flow() {
        let server = TestServer::spawn().await;
        let client = Client::new();

        // given
        let created = create_session(&client, &server).await;
        let session_id = created.session_id.as_str();
        let mut events = client
            .get(server.url(&format!("/sessions/{session_id}/events")))
            .send()
            .await
            .expect("events request should succeed")
            .error_for_status()
            .expect("events request should return success");
        let mut buffer = String::new();
        let snapshot_frame = next_sse_frame(&mut events, &mut buffer).await;

        // when
        let request = super::SendMessageRequest {
            message: "hello from test".to_string(),
        };
        let send_status = client
            .post(server.url(&format!("/sessions/{session_id}/message")))
            .json(&request)
            .send()
            .await
            .expect("message request should succeed")
            .status();
        let message_frame = next_sse_frame(&mut events, &mut buffer).await;
        let details = fetch_details(&client, &server, session_id).await;

        // then
        assert_eq!(send_status, reqwest::StatusCode::NO_CONTENT);
        assert!(snapshot_frame.contains("event: snapshot"));
        assert!(snapshot_frame.contains("\"session_id\":\"session-1\""));
        assert!(message_frame.contains("event: message"));
        assert!(message_frame.contains("hello from test"));
        assert_eq!(details.session.messages.len(), 1);
        assert_eq!(
            details.session.messages[0],
            runtime::ConversationMessage::user_text("hello from test")
        );
    }
}