1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
/// Errors returned by [`Session`] and [`SessionRegistry`] operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session is registered under the given session id.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The execution state rejected a fed response; the underlying
    /// `algocline_core::FeedError` is surfaced unchanged.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// The execution state machine rejected a transition, or a precondition on
    /// the pending queries was not met (e.g. none or several are pending when
    /// exactly one was required).
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
/// Terminal outcome of an execution together with the metrics collected for it.
#[derive(serde::Serialize)]
pub struct ExecutionResult {
    /// How the execution ended (completed, failed, or cancelled).
    pub state: TerminalState,
    /// Metrics accumulated for the execution; sessions move theirs out here when
    /// they finish.
    pub metrics: ExecutionMetrics,
}
39
/// Outcome of driving a session forward: starting it, or feeding it a response.
#[derive(serde::Serialize)]
pub enum FeedResult {
    /// A response was accepted; `remaining` queries of the current batch are
    /// still outstanding.
    Accepted { remaining: usize },
    /// Execution is paused until every query in `queries` has been answered.
    Paused { queries: Vec<LlmQuery> },
    /// Execution reached a terminal state.
    Finished(ExecutionResult),
}
50
51impl FeedResult {
52 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
54 match self {
55 Self::Accepted { remaining } => json!({
56 "status": "accepted",
57 "remaining": remaining,
58 }),
59 Self::Paused { queries } => {
60 if queries.len() == 1 {
61 let q = &queries[0];
62 let mut obj = json!({
63 "status": "needs_response",
64 "session_id": session_id,
65 "query_id": q.id.as_str(),
66 "prompt": q.prompt,
67 "system": q.system,
68 "max_tokens": q.max_tokens,
69 });
70 if q.grounded {
71 obj["grounded"] = json!(true);
72 }
73 if q.underspecified {
74 obj["underspecified"] = json!(true);
75 }
76 obj
77 } else {
78 let qs: Vec<_> = queries
79 .iter()
80 .map(|q| {
81 let mut obj = json!({
82 "id": q.id.as_str(),
83 "prompt": q.prompt,
84 "system": q.system,
85 "max_tokens": q.max_tokens,
86 });
87 if q.grounded {
88 obj["grounded"] = json!(true);
89 }
90 if q.underspecified {
91 obj["underspecified"] = json!(true);
92 }
93 obj
94 })
95 .collect();
96 json!({
97 "status": "needs_response",
98 "session_id": session_id,
99 "queries": qs,
100 })
101 }
102 }
103 Self::Finished(result) => match &result.state {
104 TerminalState::Completed { result: val } => json!({
105 "status": "completed",
106 "result": val,
107 "stats": result.metrics.to_json(),
108 }),
109 TerminalState::Failed { error } => json!({
110 "status": "error",
111 "error": error,
112 }),
113 TerminalState::Cancelled => json!({
114 "status": "cancelled",
115 "stats": result.metrics.to_json(),
116 }),
117 },
118 }
119 }
120}
121
/// One script execution that can pause on LLM queries and be resumed by
/// feeding responses.
pub struct Session {
    /// Execution state machine (running / paused / terminal).
    state: ExecutionState,
    /// Metrics for this execution; moved out (leaving a default) when the
    /// session finishes.
    metrics: ExecutionMetrics,
    /// Observer derived from `metrics` at construction; notified of pauses,
    /// fed responses, resumption, completion and failure.
    observer: MetricsObserver,
    /// Receives batches of LLM queries; each received batch pauses the session.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// The running execution; its output string is parsed as the JSON result.
    exec_task: AsyncTask,
    /// Response channels for outstanding queries, keyed by query id. An entry
    /// is removed when its response is delivered.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Never read; presumably held so the VM driver lives as long as the
    /// session (TODO confirm against `mlua_isle`). Declared last, so it is
    /// dropped after the fields above — keep it last.
    _vm_driver: AsyncIsleDriver,
}
141
142impl Session {
143 pub fn new(
144 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
145 exec_task: AsyncTask,
146 metrics: ExecutionMetrics,
147 vm_driver: AsyncIsleDriver,
148 ) -> Self {
149 let observer = metrics.create_observer();
150 Self {
151 state: ExecutionState::Running,
152 metrics,
153 observer,
154 llm_rx,
155 exec_task,
156 resp_txs: HashMap::new(),
157 _vm_driver: vm_driver,
158 }
159 }
160
161 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
166 tokio::select! {
167 result = &mut self.exec_task => {
168 match result {
169 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
170 Ok(v) => {
171 self.state.complete(v.clone()).map_err(|e| {
172 SessionError::InvalidTransition(e.to_string())
173 })?;
174 self.observer.on_completed(&v);
175 Ok(FeedResult::Finished(ExecutionResult {
176 state: TerminalState::Completed { result: v },
177 metrics: self.take_metrics(),
178 }))
179 }
180 Err(e) => self.fail_with(format!("JSON parse: {e}")),
181 },
182 Err(e) => self.fail_with(e.to_string()),
183 }
184 }
185 Some(req) = self.llm_rx.recv() => {
186 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
187 id: qr.id.clone(),
188 prompt: qr.prompt.clone(),
189 system: qr.system.clone(),
190 max_tokens: qr.max_tokens,
191 grounded: qr.grounded,
192 underspecified: qr.underspecified,
193 }).collect();
194
195 for qr in req.queries {
196 self.resp_txs.insert(qr.id, qr.resp_tx);
197 }
198
199 self.state.pause(queries.clone()).map_err(|e| {
200 SessionError::InvalidTransition(e.to_string())
201 })?;
202 self.observer.on_paused(&queries);
203 Ok(FeedResult::Paused { queries })
204 }
205 }
206 }
207
208 fn feed_one(
212 &mut self,
213 query_id: &QueryId,
214 response: String,
215 usage: Option<&algocline_core::TokenUsage>,
216 ) -> Result<bool, SessionError> {
217 self.observer.on_response_fed(query_id, &response, usage);
219
220 if let Some(tx) = self.resp_txs.remove(query_id) {
222 let _ = tx.send(Ok(response.clone()));
223 }
224
225 let complete = self
227 .state
228 .feed(query_id, response)
229 .map_err(SessionError::Feed)?;
230
231 if complete {
232 self.state
234 .take_responses()
235 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
236 self.observer.on_resumed();
237 } else {
238 self.observer
239 .on_partial_feed(query_id, self.state.remaining());
240 }
241
242 Ok(complete)
243 }
244
245 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
246 self.state
247 .fail(msg.clone())
248 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
249 self.observer.on_failed(&msg);
250 Ok(FeedResult::Finished(ExecutionResult {
251 state: TerminalState::Failed { error: msg },
252 metrics: self.take_metrics(),
253 }))
254 }
255
256 fn take_metrics(&mut self) -> ExecutionMetrics {
257 std::mem::take(&mut self.metrics)
258 }
259
260 pub fn snapshot(&self) -> serde_json::Value {
265 let state_label = match &self.state {
266 ExecutionState::Running => "running",
267 ExecutionState::Paused(_) => "paused",
268 _ => "terminal",
269 };
270
271 let mut json = serde_json::json!({
272 "state": state_label,
273 });
274
275 let metrics = self.metrics.snapshot();
276 if !metrics.is_null() {
277 json["metrics"] = metrics;
278 }
279
280 if let ExecutionState::Paused(_) = &self.state {
282 json["pending_queries"] = self.state.remaining().into();
283 }
284
285 json
286 }
287}
288
/// Registry of sessions paused awaiting LLM responses, keyed by session id.
///
/// Only paused sessions are stored. A session is removed while it runs to its
/// next event and is re-inserted only if it pauses again.
pub struct SessionRegistry {
    /// Session id → paused session, guarded by an async (`tokio`) mutex.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
326
327impl Default for SessionRegistry {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
333impl SessionRegistry {
334 pub fn new() -> Self {
335 Self {
336 sessions: Arc::new(Mutex::new(HashMap::new())),
337 }
338 }
339
340 pub async fn start_execution(
342 &self,
343 mut session: Session,
344 ) -> Result<(String, FeedResult), SessionError> {
345 let session_id = gen_session_id();
346 let result = session.wait_event().await?;
347
348 if matches!(result, FeedResult::Paused { .. }) {
349 self.sessions
350 .lock()
351 .await
352 .insert(session_id.clone(), session);
353 }
354
355 Ok((session_id, result))
356 }
357
358 pub async fn feed_response(
364 &self,
365 session_id: &str,
366 query_id: &QueryId,
367 response: String,
368 usage: Option<&algocline_core::TokenUsage>,
369 ) -> Result<FeedResult, SessionError> {
370 let complete = {
372 let mut map = self.sessions.lock().await;
373 let session = map
374 .get_mut(session_id)
375 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
376
377 let complete = session.feed_one(query_id, response, usage)?;
378
379 if !complete {
380 return Ok(FeedResult::Accepted {
381 remaining: session.state.remaining(),
382 });
383 }
384
385 complete
386 };
387
388 debug_assert!(complete);
390 let mut session = {
391 let mut map = self.sessions.lock().await;
392 map.remove(session_id)
393 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
394 };
395
396 let result = session.wait_event().await?;
397
398 if matches!(result, FeedResult::Paused { .. }) {
399 self.sessions
400 .lock()
401 .await
402 .insert(session_id.into(), session);
403 }
404
405 Ok(result)
406 }
407
408 pub async fn resolve_sole_pending_id(&self, session_id: &str) -> Result<QueryId, SessionError> {
414 let map = self.sessions.lock().await;
415 let session = map
416 .get(session_id)
417 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
418 let keys: Vec<QueryId> = session.resp_txs.keys().cloned().collect();
419 match keys.len() {
420 0 => Err(SessionError::InvalidTransition("no pending queries".into())),
421 1 => keys
422 .into_iter()
423 .next()
424 .ok_or_else(|| SessionError::InvalidTransition("unexpected empty keys".into())),
425 n => Err(SessionError::InvalidTransition(format!(
426 "{n} queries pending; specify query_id explicitly"
427 ))),
428 }
429 }
430
431 pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
437 let map = self.sessions.lock().await;
438 map.iter()
439 .map(|(id, session)| (id.clone(), session.snapshot()))
440 .collect()
441 }
442}
443
/// Generates a session id of the form `s-<hex nanos since epoch>-<16 hex digits>`.
///
/// The suffix is the current timestamp hashed with a freshly constructed
/// `RandomState`. This is enough to tell ids apart within a process, but it is
/// not a cryptographically secure token. A clock before the Unix epoch yields
/// a timestamp of 0.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("s-{nanos:x}-{salt:016x}")
}
479
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Plain query: no system prompt, `max_tokens = 100`, both flags cleared.
    /// Tests override individual fields with struct-update syntax.
    fn query(id: QueryId, prompt: &str) -> LlmQuery {
        LlmQuery {
            id,
            prompt: prompt.into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Batch query `index` whose prompt is `prompt-{index}`.
    fn make_query(index: usize) -> LlmQuery {
        query(QueryId::batch(index), &format!("prompt-{index}"))
    }

    /// Renders a `Paused` result holding `queries` for `session_id`.
    fn paused_json(queries: Vec<LlmQuery>, session_id: &str) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// Renders a `Finished` result with terminal `state` and fresh metrics.
    fn finished_json(state: TerminalState, session_id: &str) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    // ---- Accepted ----

    #[test]
    fn to_json_accepted() {
        let json = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    // ---- Paused: single query (flattened form) ----

    #[test]
    fn to_json_paused_single_query() {
        let q = LlmQuery {
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            ..query(QueryId::single(), "What is 2+2?")
        };
        let json = paused_json(vec![q], "s-abc");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["query_id"], QueryId::single().as_str());
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        assert!(json.get("queries").is_none());
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let q = LlmQuery {
            max_tokens: 200,
            grounded: true,
            ..query(QueryId::single(), "verify this claim")
        };
        let json = paused_json(vec![q], "s-grounded");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let q = LlmQuery {
            max_tokens: 200,
            underspecified: true,
            ..query(QueryId::single(), "what output format do you need?")
        };
        let json = paused_json(vec![q], "s-underspec");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let q = LlmQuery {
            max_tokens: 1024,
            ..query(QueryId::single(), "hello")
        };
        let json = paused_json(vec![q], "s-x");

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    // ---- Paused: array form ----

    #[test]
    fn to_json_paused_multiple_queries() {
        let json = paused_json(vec![make_query(0), make_query(1), make_query(2)], "s-multi");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded = LlmQuery {
            grounded: true,
            ..query(QueryId::batch(0), "verify")
        };
        let normal = query(QueryId::batch(1), "generate");
        let json = paused_json(vec![grounded, normal], "s-batch");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec = LlmQuery {
            underspecified: true,
            ..query(QueryId::batch(0), "clarify intent")
        };
        let normal = query(QueryId::batch(1), "generate");
        let json = paused_json(vec![underspec, normal], "s-batch-us");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_no_queries_uses_array_form() {
        let json = paused_json(Vec::new(), "s-empty");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-empty");
        assert!(json.get("query_id").is_none());
        let qs = json["queries"].as_array().expect("queries should be array");
        assert!(qs.is_empty());
    }

    // ---- Finished ----

    #[test]
    fn to_json_finished_completed() {
        let json = finished_json(
            TerminalState::Completed {
                result: json!({"answer": 42}),
            },
            "s-done",
        );

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let json = finished_json(
            TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
            "s-err",
        );

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let json = finished_json(TerminalState::Cancelled, "s-cancel");

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // ---- Session ids ----

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_has_hex_timestamp_and_random_suffix() {
        let id = gen_session_id();
        let parts: Vec<&str> = id.split('-').collect();
        assert_eq!(parts.len(), 3, "unexpected id shape: {id}");
        assert_eq!(parts[0], "s");
        assert!(!parts[1].is_empty() && parts[1].chars().all(|c| c.is_ascii_hexdigit()));
        assert_eq!(parts[2].len(), 16, "suffix must be 16 hex digits: {id}");
        assert!(parts[2].chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn session_id_uniqueness() {
        let ids: Vec<String> = (0..10).map(|_| gen_session_id()).collect();
        let set: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(set.len(), 10, "10 IDs should all be unique");
    }
}