1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::AsyncTask;
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
/// Errors surfaced by the session operations in this module.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session is registered under the given id; the payload is that id.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The execution state rejected a fed response. Wraps the core
    /// `FeedError` and forwards its message unchanged.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// An execution-state transition was rejected; the payload is the
    /// stringified underlying error.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
/// Outcome of an execution that has reached a terminal state.
pub struct ExecutionResult {
    /// The terminal state the execution ended in.
    pub state: TerminalState,
    /// Metrics handed over by the session when the execution ended.
    pub metrics: ExecutionMetrics,
}
38
/// What a caller gets back after starting an execution or feeding it a response.
pub enum FeedResult {
    /// The response was recorded, but `remaining` queries are still unanswered.
    Accepted { remaining: usize },
    /// The execution is waiting for responses to these LLM queries.
    Paused { queries: Vec<LlmQuery> },
    /// The execution reached a terminal state.
    Finished(ExecutionResult),
}
48
49impl FeedResult {
50 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52 match self {
53 Self::Accepted { remaining } => json!({
54 "status": "accepted",
55 "remaining": remaining,
56 }),
57 Self::Paused { queries } => {
58 if queries.len() == 1 {
59 let q = &queries[0];
60 json!({
61 "status": "needs_response",
62 "session_id": session_id,
63 "prompt": q.prompt,
64 "system": q.system,
65 "max_tokens": q.max_tokens,
66 })
67 } else {
68 let qs: Vec<_> = queries
69 .iter()
70 .map(|q| {
71 json!({
72 "id": q.id.as_str(),
73 "prompt": q.prompt,
74 "system": q.system,
75 "max_tokens": q.max_tokens,
76 })
77 })
78 .collect();
79 json!({
80 "status": "needs_response",
81 "session_id": session_id,
82 "queries": qs,
83 })
84 }
85 }
86 Self::Finished(result) => match &result.state {
87 TerminalState::Completed { result: val } => json!({
88 "status": "completed",
89 "result": val,
90 "stats": result.metrics.to_json(),
91 }),
92 TerminalState::Failed { error } => json!({
93 "status": "error",
94 "error": error,
95 }),
96 TerminalState::Cancelled => json!({
97 "status": "cancelled",
98 "stats": result.metrics.to_json(),
99 }),
100 },
101 }
102 }
103}
104
/// A single execution together with the channels used to talk to it.
pub struct Session {
    /// Execution state machine; starts out as `Running`.
    state: ExecutionState,
    /// Metrics for this execution; moved into the `ExecutionResult` when the
    /// execution finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics`; notified on every state change.
    observer: MetricsObserver,
    /// Incoming LLM requests raised by the running task.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// The task being executed; awaited for its final (JSON string) result.
    exec_task: AsyncTask,
    /// Reply channels for queries that are still awaiting a response,
    /// keyed by query id.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
}
117
118impl Session {
119 pub fn new(
120 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
121 exec_task: AsyncTask,
122 metrics: ExecutionMetrics,
123 ) -> Self {
124 let observer = metrics.create_observer();
125 Self {
126 state: ExecutionState::Running,
127 metrics,
128 observer,
129 llm_rx,
130 exec_task,
131 resp_txs: HashMap::new(),
132 }
133 }
134
135 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
140 tokio::select! {
141 result = &mut self.exec_task => {
142 match result {
143 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
144 Ok(v) => {
145 self.state.complete(v.clone()).map_err(|e| {
146 SessionError::InvalidTransition(e.to_string())
147 })?;
148 self.observer.on_completed(&v);
149 Ok(FeedResult::Finished(ExecutionResult {
150 state: TerminalState::Completed { result: v },
151 metrics: self.take_metrics(),
152 }))
153 }
154 Err(e) => self.fail_with(format!("JSON parse: {e}")),
155 },
156 Err(e) => self.fail_with(e.to_string()),
157 }
158 }
159 Some(req) = self.llm_rx.recv() => {
160 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
161 id: qr.id.clone(),
162 prompt: qr.prompt.clone(),
163 system: qr.system.clone(),
164 max_tokens: qr.max_tokens,
165 }).collect();
166
167 for qr in req.queries {
168 self.resp_txs.insert(qr.id, qr.resp_tx);
169 }
170
171 self.state.pause(queries.clone()).map_err(|e| {
172 SessionError::InvalidTransition(e.to_string())
173 })?;
174 self.observer.on_paused(&queries);
175 Ok(FeedResult::Paused { queries })
176 }
177 }
178 }
179
180 fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
184 if let Some(tx) = self.resp_txs.remove(query_id) {
186 let _ = tx.send(Ok(response.clone()));
187 }
188
189 let complete = self
191 .state
192 .feed(query_id, response)
193 .map_err(SessionError::Feed)?;
194
195 if complete {
196 self.state
198 .take_responses()
199 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
200 self.observer.on_resumed();
201 } else {
202 self.observer
203 .on_partial_feed(query_id, self.state.remaining());
204 }
205
206 Ok(complete)
207 }
208
209 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
210 self.state
211 .fail(msg.clone())
212 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
213 self.observer.on_failed(&msg);
214 Ok(FeedResult::Finished(ExecutionResult {
215 state: TerminalState::Failed { error: msg },
216 metrics: self.take_metrics(),
217 }))
218 }
219
220 fn take_metrics(&mut self) -> ExecutionMetrics {
221 std::mem::take(&mut self.metrics)
222 }
223}
224
/// Registry of sessions that are paused and waiting for responses,
/// keyed by session id.
pub struct SessionRegistry {
    /// Shared map from session id to its session, guarded by an async mutex.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
231
232impl Default for SessionRegistry {
233 fn default() -> Self {
234 Self::new()
235 }
236}
237
238impl SessionRegistry {
239 pub fn new() -> Self {
240 Self {
241 sessions: Arc::new(Mutex::new(HashMap::new())),
242 }
243 }
244
245 pub async fn start_execution(
247 &self,
248 mut session: Session,
249 ) -> Result<(String, FeedResult), SessionError> {
250 let session_id = gen_session_id();
251 let result = session.wait_event().await?;
252
253 if matches!(result, FeedResult::Paused { .. }) {
254 self.sessions
255 .lock()
256 .await
257 .insert(session_id.clone(), session);
258 }
259
260 Ok((session_id, result))
261 }
262
263 pub async fn feed_response(
269 &self,
270 session_id: &str,
271 query_id: &QueryId,
272 response: String,
273 ) -> Result<FeedResult, SessionError> {
274 let complete = {
276 let mut map = self.sessions.lock().await;
277 let session = map
278 .get_mut(session_id)
279 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
280
281 let complete = session.feed_one(query_id, response)?;
282
283 if !complete {
284 return Ok(FeedResult::Accepted {
285 remaining: session.state.remaining(),
286 });
287 }
288
289 complete
290 };
291
292 debug_assert!(complete);
294 let mut session = {
295 let mut map = self.sessions.lock().await;
296 map.remove(session_id)
297 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
298 };
299
300 let result = session.wait_event().await?;
301
302 if matches!(result, FeedResult::Paused { .. }) {
303 self.sessions
304 .lock()
305 .await
306 .insert(session_id.into(), session);
307 }
308
309 Ok(result)
310 }
311}
312
/// Generates a session id of the form `s-<nanos hex>-<16 hex digits>`.
///
/// The first part is the current time in nanoseconds since the Unix epoch
/// (0 if the clock reads earlier than the epoch). The second part is that
/// timestamp hashed with a freshly created `RandomState`, which adds
/// per-call variation without pulling in a random-number crate.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("s-{nanos:x}-{salt:016x}")
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Builds the `index`-th query of a batch with a recognisable prompt.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    /// Wraps a single, non-batch query in a `Paused` result.
    fn paused_single(prompt: &str, system: Option<&str>, max_tokens: u32) -> FeedResult {
        FeedResult::Paused {
            queries: vec![LlmQuery {
                id: QueryId::single(),
                prompt: prompt.into(),
                system: system.map(Into::into),
                max_tokens,
            }],
        }
    }

    /// Wraps a terminal state in a `Finished` result with fresh metrics.
    fn finished(state: TerminalState) -> FeedResult {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
    }

    // ─── FeedResult::to_json ───

    #[test]
    fn to_json_accepted() {
        let payload = FeedResult::Accepted { remaining: 3 }.to_json("s-123");

        assert_eq!(payload["status"], "accepted");
        assert_eq!(payload["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let payload =
            paused_single("What is 2+2?", Some("You are a calculator."), 50).to_json("s-abc");

        assert_eq!(payload["status"], "needs_response");
        assert_eq!(payload["session_id"], "s-abc");
        assert_eq!(payload["prompt"], "What is 2+2?");
        assert_eq!(payload["system"], "You are a calculator.");
        assert_eq!(payload["max_tokens"], 50);
        // The single-query form is flat: there must be no "queries" array.
        assert!(payload.get("queries").is_none());
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let payload = paused_single("hello", None, 1024).to_json("s-x");

        assert_eq!(payload["status"], "needs_response");
        assert!(payload["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let batch: Vec<LlmQuery> = (0..3).map(make_query).collect();
        let payload = FeedResult::Paused { queries: batch }.to_json("s-multi");

        assert_eq!(payload["status"], "needs_response");
        assert_eq!(payload["session_id"], "s-multi");

        let entries = payload["queries"]
            .as_array()
            .expect("queries should be array");
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0]["id"], "q-0");
        assert_eq!(entries[0]["prompt"], "prompt-0");
        assert_eq!(entries[1]["id"], "q-1");
        assert_eq!(entries[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let payload = finished(TerminalState::Completed {
            result: json!({"answer": 42}),
        })
        .to_json("s-done");

        assert_eq!(payload["status"], "completed");
        assert_eq!(payload["result"]["answer"], 42);
        assert!(payload.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let payload = finished(TerminalState::Failed {
            error: "lua error: bad argument".into(),
        })
        .to_json("s-err");

        assert_eq!(payload["status"], "error");
        assert_eq!(payload["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let payload = finished(TerminalState::Cancelled).to_json("s-cancel");

        assert_eq!(payload["status"], "cancelled");
        assert!(payload.get("stats").is_some());
    }

    // ─── gen_session_id ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let distinct: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(distinct.len(), 10, "10 IDs should all be unique");
    }
}