1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::AsyncTask;
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
/// Errors produced by session operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session is registered under the given id.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The execution state machine rejected a fed response
    /// (wraps `algocline_core::FeedError`, converted via `From`).
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// The execution state machine refused a transition
    /// (complete / fail / pause / take_responses); carries the underlying error text.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
/// Terminal outcome of an execution together with the metrics collected for it.
pub struct ExecutionResult {
    /// How the execution ended (completed, failed, or cancelled).
    pub state: TerminalState,
    /// Metrics accumulated over the execution, moved out of the session when it finished.
    pub metrics: ExecutionMetrics,
}
38
/// Event reported to the caller after starting an execution or feeding a response.
pub enum FeedResult {
    /// The response was recorded; `remaining` queries of the current batch are still outstanding.
    Accepted { remaining: usize },
    /// The execution is blocked until responses for these LLM queries are fed.
    Paused { queries: Vec<LlmQuery> },
    /// The execution reached a terminal state.
    Finished(ExecutionResult),
}
48
49impl FeedResult {
50 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52 match self {
53 Self::Accepted { remaining } => json!({
54 "status": "accepted",
55 "remaining": remaining,
56 }),
57 Self::Paused { queries } => {
58 if queries.len() == 1 {
59 let q = &queries[0];
60 json!({
61 "status": "needs_response",
62 "session_id": session_id,
63 "prompt": q.prompt,
64 "system": q.system,
65 "max_tokens": q.max_tokens,
66 })
67 } else {
68 let qs: Vec<_> = queries
69 .iter()
70 .map(|q| {
71 json!({
72 "id": q.id.as_str(),
73 "prompt": q.prompt,
74 "system": q.system,
75 "max_tokens": q.max_tokens,
76 })
77 })
78 .collect();
79 json!({
80 "status": "needs_response",
81 "session_id": session_id,
82 "queries": qs,
83 })
84 }
85 }
86 Self::Finished(result) => match &result.state {
87 TerminalState::Completed { result: val } => json!({
88 "status": "completed",
89 "result": val,
90 "stats": result.metrics.to_json(),
91 }),
92 TerminalState::Failed { error } => json!({
93 "status": "error",
94 "error": error,
95 }),
96 TerminalState::Cancelled => json!({
97 "status": "cancelled",
98 "stats": result.metrics.to_json(),
99 }),
100 },
101 }
102 }
103}
104
/// A single running execution and the bookkeeping needed to pause and resume it.
pub struct Session {
    /// Execution state machine (running / paused on queries / terminal).
    state: ExecutionState,
    /// Metrics accumulator; moved out by `take_metrics` when the execution finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics` in `Session::new`; receives lifecycle events.
    observer: MetricsObserver,
    /// LLM requests issued by the running script.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// Task that resolves with the script's JSON-encoded result string (or an error).
    exec_task: AsyncTask,
    /// Reply channels for outstanding queries; each is removed when its response is fed.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
}
117
118impl Session {
119 pub fn new(
120 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
121 exec_task: AsyncTask,
122 metrics: ExecutionMetrics,
123 ) -> Self {
124 let observer = metrics.create_observer();
125 Self {
126 state: ExecutionState::Running,
127 metrics,
128 observer,
129 llm_rx,
130 exec_task,
131 resp_txs: HashMap::new(),
132 }
133 }
134
135 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
140 tokio::select! {
141 result = &mut self.exec_task => {
142 match result {
143 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
144 Ok(v) => {
145 self.state.complete(v.clone()).map_err(|e| {
146 SessionError::InvalidTransition(e.to_string())
147 })?;
148 self.observer.on_completed(&v);
149 Ok(FeedResult::Finished(ExecutionResult {
150 state: TerminalState::Completed { result: v },
151 metrics: self.take_metrics(),
152 }))
153 }
154 Err(e) => self.fail_with(format!("JSON parse: {e}")),
155 },
156 Err(e) => self.fail_with(e.to_string()),
157 }
158 }
159 Some(req) = self.llm_rx.recv() => {
160 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
161 id: qr.id.clone(),
162 prompt: qr.prompt.clone(),
163 system: qr.system.clone(),
164 max_tokens: qr.max_tokens,
165 }).collect();
166
167 for qr in req.queries {
168 self.resp_txs.insert(qr.id, qr.resp_tx);
169 }
170
171 self.state.pause(queries.clone()).map_err(|e| {
172 SessionError::InvalidTransition(e.to_string())
173 })?;
174 self.observer.on_paused(&queries);
175 Ok(FeedResult::Paused { queries })
176 }
177 }
178 }
179
180 fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
184 self.observer.on_response_fed(query_id, &response);
186
187 if let Some(tx) = self.resp_txs.remove(query_id) {
189 let _ = tx.send(Ok(response.clone()));
190 }
191
192 let complete = self
194 .state
195 .feed(query_id, response)
196 .map_err(SessionError::Feed)?;
197
198 if complete {
199 self.state
201 .take_responses()
202 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
203 self.observer.on_resumed();
204 } else {
205 self.observer
206 .on_partial_feed(query_id, self.state.remaining());
207 }
208
209 Ok(complete)
210 }
211
212 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
213 self.state
214 .fail(msg.clone())
215 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
216 self.observer.on_failed(&msg);
217 Ok(FeedResult::Finished(ExecutionResult {
218 state: TerminalState::Failed { error: msg },
219 metrics: self.take_metrics(),
220 }))
221 }
222
223 fn take_metrics(&mut self) -> ExecutionMetrics {
224 std::mem::take(&mut self.metrics)
225 }
226}
227
/// Registry of sessions that are paused and waiting for LLM responses.
///
/// Only sessions whose latest event was `FeedResult::Paused` are stored.
/// A session is taken out of the map while it runs between pauses.
pub struct SessionRegistry {
    /// Paused sessions keyed by session id; guarded by an async mutex.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
234
235impl Default for SessionRegistry {
236 fn default() -> Self {
237 Self::new()
238 }
239}
240
241impl SessionRegistry {
242 pub fn new() -> Self {
243 Self {
244 sessions: Arc::new(Mutex::new(HashMap::new())),
245 }
246 }
247
248 pub async fn start_execution(
250 &self,
251 mut session: Session,
252 ) -> Result<(String, FeedResult), SessionError> {
253 let session_id = gen_session_id();
254 let result = session.wait_event().await?;
255
256 if matches!(result, FeedResult::Paused { .. }) {
257 self.sessions
258 .lock()
259 .await
260 .insert(session_id.clone(), session);
261 }
262
263 Ok((session_id, result))
264 }
265
266 pub async fn feed_response(
272 &self,
273 session_id: &str,
274 query_id: &QueryId,
275 response: String,
276 ) -> Result<FeedResult, SessionError> {
277 let complete = {
279 let mut map = self.sessions.lock().await;
280 let session = map
281 .get_mut(session_id)
282 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
283
284 let complete = session.feed_one(query_id, response)?;
285
286 if !complete {
287 return Ok(FeedResult::Accepted {
288 remaining: session.state.remaining(),
289 });
290 }
291
292 complete
293 };
294
295 debug_assert!(complete);
297 let mut session = {
298 let mut map = self.sessions.lock().await;
299 map.remove(session_id)
300 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
301 };
302
303 let result = session.wait_event().await?;
304
305 if matches!(result, FeedResult::Paused { .. }) {
306 self.sessions
307 .lock()
308 .await
309 .insert(session_id.into(), session);
310 }
311
312 Ok(result)
313 }
314}
315
/// Builds a session id of the form `s-<unix-nanos hex>-<16 hex digit salt>`.
///
/// The salt is the current timestamp hashed with a freshly constructed
/// `RandomState`. A clock before the Unix epoch yields a timestamp of 0.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("s-{nanos:x}-{salt:016x}")
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Builds the `index`-th query of a batch with a predictable prompt.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    /// Wraps a terminal state in a `Finished` result carrying fresh metrics.
    fn finished(state: TerminalState) -> FeedResult {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
    }

    #[test]
    fn to_json_accepted() {
        let out = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(out["status"], "accepted");
        assert_eq!(out["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let lone = LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
        };
        let out = FeedResult::Paused {
            queries: vec![lone],
        }
        .to_json("s-abc");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-abc");
        assert_eq!(out["prompt"], "What is 2+2?");
        assert_eq!(out["system"], "You are a calculator.");
        assert_eq!(out["max_tokens"], 50);
        assert!(out.get("queries").is_none());
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let lone = LlmQuery {
            id: QueryId::single(),
            prompt: "hello".into(),
            system: None,
            max_tokens: 1024,
        };
        let out = FeedResult::Paused {
            queries: vec![lone],
        }
        .to_json("s-x");

        assert_eq!(out["status"], "needs_response");
        assert!(out["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let batch: Vec<LlmQuery> = (0..3).map(make_query).collect();
        let out = FeedResult::Paused { queries: batch }.to_json("s-multi");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-multi");

        let listed = out["queries"].as_array().expect("queries should be array");
        assert_eq!(listed.len(), 3);
        assert_eq!(listed[0]["id"], "q-0");
        assert_eq!(listed[0]["prompt"], "prompt-0");
        assert_eq!(listed[1]["id"], "q-1");
        assert_eq!(listed[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let out = finished(TerminalState::Completed {
            result: json!({"answer": 42}),
        })
        .to_json("s-done");

        assert_eq!(out["status"], "completed");
        assert_eq!(out["result"]["answer"], 42);
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let out = finished(TerminalState::Failed {
            error: "lua error: bad argument".into(),
        })
        .to_json("s-err");

        assert_eq!(out["status"], "error");
        assert_eq!(out["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let out = finished(TerminalState::Cancelled).to_json("s-cancel");

        assert_eq!(out["status"], "cancelled");
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let distinct: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(distinct.len(), 10, "10 IDs should all be unique");
    }
}