1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
/// Errors surfaced by [`Session`] and [`SessionRegistry`] operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session with the given id is currently held by the registry.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The core state machine rejected a fed response; its message is passed
    /// through unchanged.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// The execution state machine refused a transition (complete / pause /
    /// fail / resume); carries the underlying error's text.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
/// Outcome of a session that has reached a terminal state.
pub struct ExecutionResult {
    /// How the execution ended (completed with a value, failed, or cancelled).
    pub state: TerminalState,
    /// Metrics accumulated over the whole execution, moved out of the session
    /// when it finished.
    pub metrics: ExecutionMetrics,
}
38
/// Result of advancing a session, either at start or after feeding a response.
pub enum FeedResult {
    /// A response was accepted but other queries of the same batch are still
    /// outstanding; `remaining` is how many.
    Accepted { remaining: usize },
    /// The script is blocked waiting for responses to these LLM queries.
    Paused { queries: Vec<LlmQuery> },
    /// The execution reached a terminal state.
    Finished(ExecutionResult),
}
48
49impl FeedResult {
50 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52 match self {
53 Self::Accepted { remaining } => json!({
54 "status": "accepted",
55 "remaining": remaining,
56 }),
57 Self::Paused { queries } => {
58 if queries.len() == 1 {
59 let q = &queries[0];
60 let mut obj = json!({
61 "status": "needs_response",
62 "session_id": session_id,
63 "prompt": q.prompt,
64 "system": q.system,
65 "max_tokens": q.max_tokens,
66 });
67 if q.grounded {
68 obj["grounded"] = json!(true);
69 }
70 if q.underspecified {
71 obj["underspecified"] = json!(true);
72 }
73 obj
74 } else {
75 let qs: Vec<_> = queries
76 .iter()
77 .map(|q| {
78 let mut obj = json!({
79 "id": q.id.as_str(),
80 "prompt": q.prompt,
81 "system": q.system,
82 "max_tokens": q.max_tokens,
83 });
84 if q.grounded {
85 obj["grounded"] = json!(true);
86 }
87 if q.underspecified {
88 obj["underspecified"] = json!(true);
89 }
90 obj
91 })
92 .collect();
93 json!({
94 "status": "needs_response",
95 "session_id": session_id,
96 "queries": qs,
97 })
98 }
99 }
100 Self::Finished(result) => match &result.state {
101 TerminalState::Completed { result: val } => json!({
102 "status": "completed",
103 "result": val,
104 "stats": result.metrics.to_json(),
105 }),
106 TerminalState::Failed { error } => json!({
107 "status": "error",
108 "error": error,
109 }),
110 TerminalState::Cancelled => json!({
111 "status": "cancelled",
112 "stats": result.metrics.to_json(),
113 }),
114 },
115 }
116 }
117}
118
/// One in-flight script execution plus the bookkeeping needed to pause it on
/// LLM queries and resume it when responses are fed back.
pub struct Session {
    /// Execution state machine (running / paused on queries / terminal).
    state: ExecutionState,
    /// Metrics for this execution; moved out when the session finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics` and notified of every lifecycle event.
    observer: MetricsObserver,
    /// Channel on which the running script submits batches of LLM queries.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// The execution itself; resolves to the script's result as a JSON string.
    exec_task: AsyncTask,
    /// Reply channels for outstanding queries, keyed by query id; each response
    /// fed back is delivered through the matching sender.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Never read in this file. NOTE(review): presumably held only so the VM
    /// driver lives as long as the session — confirm before removing.
    _vm_driver: AsyncIsleDriver,
}
138
139impl Session {
140 pub fn new(
141 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
142 exec_task: AsyncTask,
143 metrics: ExecutionMetrics,
144 vm_driver: AsyncIsleDriver,
145 ) -> Self {
146 let observer = metrics.create_observer();
147 Self {
148 state: ExecutionState::Running,
149 metrics,
150 observer,
151 llm_rx,
152 exec_task,
153 resp_txs: HashMap::new(),
154 _vm_driver: vm_driver,
155 }
156 }
157
158 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
163 tokio::select! {
164 result = &mut self.exec_task => {
165 match result {
166 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
167 Ok(v) => {
168 self.state.complete(v.clone()).map_err(|e| {
169 SessionError::InvalidTransition(e.to_string())
170 })?;
171 self.observer.on_completed(&v);
172 Ok(FeedResult::Finished(ExecutionResult {
173 state: TerminalState::Completed { result: v },
174 metrics: self.take_metrics(),
175 }))
176 }
177 Err(e) => self.fail_with(format!("JSON parse: {e}")),
178 },
179 Err(e) => self.fail_with(e.to_string()),
180 }
181 }
182 Some(req) = self.llm_rx.recv() => {
183 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
184 id: qr.id.clone(),
185 prompt: qr.prompt.clone(),
186 system: qr.system.clone(),
187 max_tokens: qr.max_tokens,
188 grounded: qr.grounded,
189 underspecified: qr.underspecified,
190 }).collect();
191
192 for qr in req.queries {
193 self.resp_txs.insert(qr.id, qr.resp_tx);
194 }
195
196 self.state.pause(queries.clone()).map_err(|e| {
197 SessionError::InvalidTransition(e.to_string())
198 })?;
199 self.observer.on_paused(&queries);
200 Ok(FeedResult::Paused { queries })
201 }
202 }
203 }
204
205 fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
209 self.observer.on_response_fed(query_id, &response);
211
212 if let Some(tx) = self.resp_txs.remove(query_id) {
214 let _ = tx.send(Ok(response.clone()));
215 }
216
217 let complete = self
219 .state
220 .feed(query_id, response)
221 .map_err(SessionError::Feed)?;
222
223 if complete {
224 self.state
226 .take_responses()
227 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
228 self.observer.on_resumed();
229 } else {
230 self.observer
231 .on_partial_feed(query_id, self.state.remaining());
232 }
233
234 Ok(complete)
235 }
236
237 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
238 self.state
239 .fail(msg.clone())
240 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
241 self.observer.on_failed(&msg);
242 Ok(FeedResult::Finished(ExecutionResult {
243 state: TerminalState::Failed { error: msg },
244 metrics: self.take_metrics(),
245 }))
246 }
247
248 fn take_metrics(&mut self) -> ExecutionMetrics {
249 std::mem::take(&mut self.metrics)
250 }
251
252 pub fn snapshot(&self) -> serde_json::Value {
257 let state_label = match &self.state {
258 ExecutionState::Running => "running",
259 ExecutionState::Paused(_) => "paused",
260 _ => "terminal",
261 };
262
263 let mut json = serde_json::json!({
264 "state": state_label,
265 });
266
267 let metrics = self.metrics.snapshot();
268 if !metrics.is_null() {
269 json["metrics"] = metrics;
270 }
271
272 if let ExecutionState::Paused(_) = &self.state {
274 json["pending_queries"] = self.state.remaining().into();
275 }
276
277 json
278 }
279}
280
/// Shared table of sessions that are paused waiting for LLM responses.
///
/// Sessions are inserted only when they pause; finished sessions are never
/// stored, and a session being resumed is temporarily absent from the table.
pub struct SessionRegistry {
    /// Paused sessions keyed by session id, behind an async mutex so the
    /// registry can be shared across tasks.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
318
319impl Default for SessionRegistry {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl SessionRegistry {
326 pub fn new() -> Self {
327 Self {
328 sessions: Arc::new(Mutex::new(HashMap::new())),
329 }
330 }
331
332 pub async fn start_execution(
334 &self,
335 mut session: Session,
336 ) -> Result<(String, FeedResult), SessionError> {
337 let session_id = gen_session_id();
338 let result = session.wait_event().await?;
339
340 if matches!(result, FeedResult::Paused { .. }) {
341 self.sessions
342 .lock()
343 .await
344 .insert(session_id.clone(), session);
345 }
346
347 Ok((session_id, result))
348 }
349
350 pub async fn feed_response(
356 &self,
357 session_id: &str,
358 query_id: &QueryId,
359 response: String,
360 ) -> Result<FeedResult, SessionError> {
361 let complete = {
363 let mut map = self.sessions.lock().await;
364 let session = map
365 .get_mut(session_id)
366 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
367
368 let complete = session.feed_one(query_id, response)?;
369
370 if !complete {
371 return Ok(FeedResult::Accepted {
372 remaining: session.state.remaining(),
373 });
374 }
375
376 complete
377 };
378
379 debug_assert!(complete);
381 let mut session = {
382 let mut map = self.sessions.lock().await;
383 map.remove(session_id)
384 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
385 };
386
387 let result = session.wait_event().await?;
388
389 if matches!(result, FeedResult::Paused { .. }) {
390 self.sessions
391 .lock()
392 .await
393 .insert(session_id.into(), session);
394 }
395
396 Ok(result)
397 }
398
399 pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
405 let map = self.sessions.lock().await;
406 map.iter()
407 .map(|(id, session)| (id.clone(), session.snapshot()))
408 .collect()
409 }
410}
411
/// Produces a session id of the form `s-{nanos:x}-{suffix:016x}`.
///
/// `nanos` is the current time since the Unix epoch in nanoseconds (0 if the
/// clock reads before the epoch). `suffix` is that timestamp hashed with a
/// freshly built `RandomState`, whose keys differ between instances. Ids are
/// therefore unique with high probability, but they are not secrets.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let suffix: u64 = hasher.finish();

    format!("s-{:x}-{:016x}", nanos, suffix)
}
434
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Builds a plain batch query: id `QueryId::batch(index)`, prompt
    /// `prompt-{index}`, no system prompt, and both optional flags off.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    // ── FeedResult::to_json ──────────────────────────────────────────

    #[test]
    fn to_json_accepted() {
        let result = FeedResult::Accepted { remaining: 3 };
        let json = result.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    /// A single query is flattened into the top-level object, with no `queries` array.
    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-abc");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        assert!(json.get("queries").is_none());
        // Optional flags must be omitted entirely (not `false`) when unset.
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "verify this claim".into(),
            system: None,
            max_tokens: 200,
            grounded: true,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-grounded");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "what output format do you need?".into(),
            system: None,
            max_tokens: 200,
            grounded: false,
            underspecified: true,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-underspec");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        // Setting one flag must not leak the other into the payload.
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    /// In batch form, flags are emitted per entry, not per payload.
    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded_query = LlmQuery {
            id: QueryId::batch(0),
            prompt: "verify".into(),
            system: None,
            max_tokens: 100,
            grounded: true,
            underspecified: false,
        };
        let normal_query = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![grounded_query, normal_query],
        };
        let json = result.to_json("s-batch");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec_query = LlmQuery {
            id: QueryId::batch(0),
            prompt: "clarify intent".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: true,
        };
        let normal_query = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![underspec_query, normal_query],
        };
        let json = result.to_json("s-batch-us");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    /// A missing system prompt serializes as JSON `null`, not as an absent key.
    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "hello".into(),
            system: None,
            max_tokens: 1024,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-x");

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    /// Batches keep their order and expose each query's id (`q-{index}`).
    #[test]
    fn to_json_paused_multiple_queries() {
        let queries = vec![make_query(0), make_query(1), make_query(2)];
        let result = FeedResult::Paused { queries };
        let json = result.to_json("s-multi");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Completed {
                result: json!({"answer": 42}),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-done");

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    /// Failures are reported with status "error" and the raw message.
    #[test]
    fn to_json_finished_failed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-err");

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Cancelled,
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-cancel");

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // ── gen_session_id ───────────────────────────────────────────────

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    /// Back-to-back calls must not collide (probabilistic, but practically certain).
    #[test]
    fn session_id_uniqueness() {
        let ids: Vec<String> = (0..10).map(|_| gen_session_id()).collect();
        let set: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(set.len(), 10, "10 IDs should all be unique");
    }
}