1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use algocline_core::{
11 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
12 TerminalState,
13};
14use mlua_isle::{AsyncIsleDriver, AsyncTask};
15use serde_json::json;
16use tokio::sync::Mutex;
17
18use crate::llm_bridge::LlmRequest;
19
#[derive(Debug, thiserror::Error)]
/// Errors returned by session and registry operations in this module.
pub enum SessionError {
    /// No session is registered under the requested id; the payload is that id.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The execution state rejected a fed response; wraps the core feed error unchanged.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// A state transition (complete / pause / fail / resume) was refused, or the
    /// pending-query set did not allow the requested operation. The payload is
    /// the human-readable reason.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
31
#[derive(serde::Serialize)]
/// Final outcome of an execution: how it ended plus the metrics gathered along the way.
pub struct ExecutionResult {
    /// Terminal state reached (completed with a value, failed with an error, or cancelled).
    pub state: TerminalState,
    /// Metrics taken out of the session when it finished.
    pub metrics: ExecutionMetrics,
}
40
#[derive(serde::Serialize)]
/// What happened after starting an execution or feeding it a response.
pub enum FeedResult {
    /// The response was stored, but `remaining` queries of the current batch
    /// are still unanswered, so the execution has not resumed yet.
    Accepted { remaining: usize },
    /// The execution is waiting for answers to these LLM queries.
    Paused { queries: Vec<LlmQuery> },
    /// The execution reached a terminal state.
    Finished(ExecutionResult),
}
51
52impl FeedResult {
53 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
55 match self {
56 Self::Accepted { remaining } => json!({
57 "status": "accepted",
58 "remaining": remaining,
59 }),
60 Self::Paused { queries } => {
61 if queries.len() == 1 {
62 let q = &queries[0];
63 let mut obj = json!({
64 "status": "needs_response",
65 "session_id": session_id,
66 "query_id": q.id.as_str(),
67 "prompt": q.prompt,
68 "system": q.system,
69 "max_tokens": q.max_tokens,
70 });
71 if q.grounded {
72 obj["grounded"] = json!(true);
73 }
74 if q.underspecified {
75 obj["underspecified"] = json!(true);
76 }
77 obj
78 } else {
79 let qs: Vec<_> = queries
80 .iter()
81 .map(|q| {
82 let mut obj = json!({
83 "id": q.id.as_str(),
84 "prompt": q.prompt,
85 "system": q.system,
86 "max_tokens": q.max_tokens,
87 });
88 if q.grounded {
89 obj["grounded"] = json!(true);
90 }
91 if q.underspecified {
92 obj["underspecified"] = json!(true);
93 }
94 obj
95 })
96 .collect();
97 json!({
98 "status": "needs_response",
99 "session_id": session_id,
100 "queries": qs,
101 })
102 }
103 }
104 Self::Finished(result) => match &result.state {
105 TerminalState::Completed { result: val } => json!({
106 "status": "completed",
107 "result": val,
108 "stats": result.metrics.to_json(),
109 }),
110 TerminalState::Failed { error } => json!({
111 "status": "error",
112 "error": error,
113 }),
114 TerminalState::Cancelled => json!({
115 "status": "cancelled",
116 "stats": result.metrics.to_json(),
117 }),
118 },
119 }
120 }
121}
122
/// One in-flight execution together with the channels needed to drive it.
pub struct Session {
    /// Current position in the execution state machine (running, paused, or terminal).
    state: ExecutionState,
    /// Metrics store for this execution; moved out when the session finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics`; notified on every state change.
    observer: MetricsObserver,
    /// Receives LLM requests raised by the execution.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// Handle that resolves with the execution's final output (a JSON string) or an error.
    exec_task: AsyncTask,
    /// Reply channels for the currently pending queries, keyed by query id.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Stored but never read in this file.
    /// NOTE(review): presumably held so the VM stays alive as long as the session — confirm.
    _vm_driver: AsyncIsleDriver,
    /// Timestamp of the most recent activity on this session; `is_expired`
    /// measures idle time from it.
    last_active: std::time::Instant,
}
145
146impl Session {
147 pub fn new(
148 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
149 exec_task: AsyncTask,
150 metrics: ExecutionMetrics,
151 vm_driver: AsyncIsleDriver,
152 ) -> Self {
153 let observer = metrics.create_observer();
154 Self {
155 state: ExecutionState::Running,
156 metrics,
157 observer,
158 llm_rx,
159 exec_task,
160 resp_txs: HashMap::new(),
161 _vm_driver: vm_driver,
162 last_active: std::time::Instant::now(),
163 }
164 }
165
166 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
171 tokio::select! {
172 result = &mut self.exec_task => {
173 match result {
174 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
175 Ok(v) => {
176 self.state.complete(v.clone()).map_err(|e| {
177 SessionError::InvalidTransition(e.to_string())
178 })?;
179 self.observer.on_completed(&v);
180 Ok(FeedResult::Finished(ExecutionResult {
181 state: TerminalState::Completed { result: v },
182 metrics: self.take_metrics(),
183 }))
184 }
185 Err(e) => self.fail_with(format!("JSON parse: {e}")),
186 },
187 Err(e) => self.fail_with(e.to_string()),
188 }
189 }
190 Some(req) = self.llm_rx.recv() => {
191 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
192 id: qr.id.clone(),
193 prompt: qr.prompt.clone(),
194 system: qr.system.clone(),
195 max_tokens: qr.max_tokens,
196 grounded: qr.grounded,
197 underspecified: qr.underspecified,
198 }).collect();
199
200 for qr in req.queries {
201 self.resp_txs.insert(qr.id, qr.resp_tx);
202 }
203
204 self.state.pause(queries.clone()).map_err(|e| {
205 SessionError::InvalidTransition(e.to_string())
206 })?;
207 self.observer.on_paused(&queries);
208 Ok(FeedResult::Paused { queries })
209 }
210 }
211 }
212
213 fn feed_one(
217 &mut self,
218 query_id: &QueryId,
219 response: String,
220 usage: Option<&algocline_core::TokenUsage>,
221 ) -> Result<bool, SessionError> {
222 self.last_active = std::time::Instant::now();
224
225 self.observer.on_response_fed(query_id, &response, usage);
227
228 if let Some(tx) = self.resp_txs.remove(query_id) {
230 let _ = tx.send(Ok(response.clone()));
231 }
232
233 let complete = self
235 .state
236 .feed(query_id, response)
237 .map_err(SessionError::Feed)?;
238
239 if complete {
240 self.state
242 .take_responses()
243 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
244 self.observer.on_resumed();
245 } else {
246 self.observer
247 .on_partial_feed(query_id, self.state.remaining());
248 }
249
250 Ok(complete)
251 }
252
253 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
254 self.state
255 .fail(msg.clone())
256 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
257 self.observer.on_failed(&msg);
258 Ok(FeedResult::Finished(ExecutionResult {
259 state: TerminalState::Failed { error: msg },
260 metrics: self.take_metrics(),
261 }))
262 }
263
264 fn take_metrics(&mut self) -> ExecutionMetrics {
265 std::mem::take(&mut self.metrics)
266 }
267
268 pub fn snapshot(&self) -> serde_json::Value {
273 let state_label = match &self.state {
274 ExecutionState::Running => "running",
275 ExecutionState::Paused(_) => "paused",
276 _ => "terminal",
277 };
278
279 let mut json = serde_json::json!({
280 "state": state_label,
281 });
282
283 let metrics = self.metrics.snapshot();
284 if !metrics.is_null() {
285 json["metrics"] = metrics;
286 }
287
288 if let ExecutionState::Paused(_) = &self.state {
290 json["pending_queries"] = self.state.remaining().into();
291 }
292
293 json
294 }
295
296 pub fn is_expired(&self, ttl: Duration) -> bool {
301 is_expired_impl(self.last_active, ttl)
302 }
303}
304
/// Returns `true` once at least `ttl` has passed since `last_active`.
///
/// The idle time saturates at zero, so a `last_active` that lies in the future
/// counts as "no time elapsed" instead of panicking. A zero `ttl` is therefore
/// always expired.
fn is_expired_impl(last_active: std::time::Instant, ttl: Duration) -> bool {
    let current = std::time::Instant::now();
    let idle = current.saturating_duration_since(last_active);
    ttl <= idle
}
309
/// Shared table of sessions that are paused and waiting for responses.
pub struct SessionRegistry {
    /// Sessions keyed by session id, guarded by an async mutex. The `Arc`
    /// lets the background GC task reach the same map.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
347
348impl Default for SessionRegistry {
349 fn default() -> Self {
350 Self::new()
351 }
352}
353
354impl SessionRegistry {
355 pub fn new() -> Self {
356 Self {
357 sessions: Arc::new(Mutex::new(HashMap::new())),
358 }
359 }
360
361 pub async fn start_execution(
363 &self,
364 mut session: Session,
365 ) -> Result<(String, FeedResult), SessionError> {
366 let session_id = gen_session_id();
367 let result = session.wait_event().await?;
368
369 if matches!(result, FeedResult::Paused { .. }) {
370 self.sessions
371 .lock()
372 .await
373 .insert(session_id.clone(), session);
374 }
375
376 Ok((session_id, result))
377 }
378
379 pub async fn feed_response(
385 &self,
386 session_id: &str,
387 query_id: &QueryId,
388 response: String,
389 usage: Option<&algocline_core::TokenUsage>,
390 ) -> Result<FeedResult, SessionError> {
391 let complete = {
393 let mut map = self.sessions.lock().await;
394 let session = map
395 .get_mut(session_id)
396 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
397
398 let complete = session.feed_one(query_id, response, usage)?;
399
400 if !complete {
401 return Ok(FeedResult::Accepted {
402 remaining: session.state.remaining(),
403 });
404 }
405
406 complete
407 };
408
409 debug_assert!(complete);
411 let mut session = {
412 let mut map = self.sessions.lock().await;
413 map.remove(session_id)
414 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
415 };
416
417 let result = session.wait_event().await?;
418
419 if matches!(result, FeedResult::Paused { .. }) {
420 self.sessions
421 .lock()
422 .await
423 .insert(session_id.into(), session);
424 }
425
426 Ok(result)
427 }
428
429 pub async fn resolve_sole_pending_id(&self, session_id: &str) -> Result<QueryId, SessionError> {
435 let map = self.sessions.lock().await;
436 let session = map
437 .get(session_id)
438 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
439 let keys: Vec<QueryId> = session.resp_txs.keys().cloned().collect();
440 match keys.len() {
441 0 => Err(SessionError::InvalidTransition("no pending queries".into())),
442 1 => keys
443 .into_iter()
444 .next()
445 .ok_or_else(|| SessionError::InvalidTransition("unexpected empty keys".into())),
446 n => Err(SessionError::InvalidTransition(format!(
447 "{n} queries pending; specify query_id explicitly"
448 ))),
449 }
450 }
451
452 pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
458 let map = self.sessions.lock().await;
459 map.iter()
460 .map(|(id, session)| (id.clone(), session.snapshot()))
461 .collect()
462 }
463
464 pub fn spawn_gc_task(&self, ttl: Duration) {
470 let sessions = Arc::clone(&self.sessions);
471 tokio::spawn(async move {
472 let mut interval = tokio::time::interval(Duration::from_secs(60));
473 loop {
474 interval.tick().await;
475 let mut map = sessions.lock().await;
476 let expired: Vec<String> = map
477 .iter()
478 .filter(|(_, s)| s.is_expired(ttl))
479 .map(|(id, _)| id.clone())
480 .collect();
481 for id in &expired {
482 tracing::info!(session_id = %id, "GC: reaping expired session");
483 map.remove(id);
484 }
485 }
486 });
487 }
488}
489
/// Produces a session id of the form `s-<nanos-hex>-<16-hex-digits>`.
///
/// The first component is the wall-clock time in nanoseconds since the Unix
/// epoch (zero if the clock reads earlier than the epoch). The second is that
/// timestamp hashed with a freshly seeded `RandomState`, which keeps ids apart
/// even when two calls observe the same timestamp. The ids are meant to be
/// unique, not cryptographically unguessable.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        // Clock set before the epoch: fall back to a zero duration.
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let entropy: u64 = hasher.finish();

    format!("s-{nanos:x}-{entropy:016x}")
}
525
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Plain batch query number `index`: no system prompt, no flags.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// JSON produced for a paused session holding `queries`.
    fn paused_json(session_id: &str, queries: Vec<LlmQuery>) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// JSON produced for a session that finished in `state` with fresh metrics.
    fn finished_json(session_id: &str, state: TerminalState) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    /// An instant lying `secs` seconds in the past.
    fn seconds_ago(secs: u64) -> std::time::Instant {
        std::time::Instant::now()
            .checked_sub(Duration::from_secs(secs))
            .expect("checked_sub should succeed with sane duration")
    }

    // ---- FeedResult::to_json ----

    #[test]
    fn to_json_accepted() {
        let payload = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(payload["status"], "accepted");
        assert_eq!(payload["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let payload = paused_json(
            "s-abc",
            vec![LlmQuery {
                id: QueryId::single(),
                prompt: "What is 2+2?".into(),
                system: Some("You are a calculator.".into()),
                max_tokens: 50,
                grounded: false,
                underspecified: false,
            }],
        );

        assert_eq!(payload["status"], "needs_response");
        assert_eq!(payload["session_id"], "s-abc");
        assert_eq!(payload["prompt"], "What is 2+2?");
        assert_eq!(payload["system"], "You are a calculator.");
        assert_eq!(payload["max_tokens"], 50);
        // The single-query form is flat: no "queries" array.
        assert!(payload.get("queries").is_none());
        assert!(
            payload.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            payload.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let payload = paused_json(
            "s-grounded",
            vec![LlmQuery {
                id: QueryId::single(),
                prompt: "verify this claim".into(),
                system: None,
                max_tokens: 200,
                grounded: true,
                underspecified: false,
            }],
        );

        assert_eq!(payload["status"], "needs_response");
        assert_eq!(
            payload["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let payload = paused_json(
            "s-underspec",
            vec![LlmQuery {
                id: QueryId::single(),
                prompt: "what output format do you need?".into(),
                system: None,
                max_tokens: 200,
                grounded: false,
                underspecified: true,
            }],
        );

        assert_eq!(payload["status"], "needs_response");
        assert_eq!(
            payload["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            payload.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let flagged = LlmQuery {
            id: QueryId::batch(0),
            prompt: "verify".into(),
            system: None,
            max_tokens: 100,
            grounded: true,
            underspecified: false,
        };
        let plain = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let payload = paused_json("s-batch", vec![flagged, plain]);

        let entries = payload["queries"]
            .as_array()
            .expect("queries should be array");
        assert_eq!(
            entries[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            entries[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let flagged = LlmQuery {
            id: QueryId::batch(0),
            prompt: "clarify intent".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: true,
        };
        let plain = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let payload = paused_json("s-batch-us", vec![flagged, plain]);

        let entries = payload["queries"]
            .as_array()
            .expect("queries should be array");
        assert_eq!(
            entries[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            entries[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let payload = paused_json(
            "s-x",
            vec![LlmQuery {
                id: QueryId::single(),
                prompt: "hello".into(),
                system: None,
                max_tokens: 1024,
                grounded: false,
                underspecified: false,
            }],
        );

        assert_eq!(payload["status"], "needs_response");
        assert!(payload["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let batch: Vec<LlmQuery> = (0..3).map(make_query).collect();
        let payload = paused_json("s-multi", batch);

        assert_eq!(payload["status"], "needs_response");
        assert_eq!(payload["session_id"], "s-multi");

        let entries = payload["queries"]
            .as_array()
            .expect("queries should be array");
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0]["id"], "q-0");
        assert_eq!(entries[0]["prompt"], "prompt-0");
        assert_eq!(entries[1]["id"], "q-1");
        assert_eq!(entries[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let payload = finished_json(
            "s-done",
            TerminalState::Completed {
                result: json!({"answer": 42}),
            },
        );

        assert_eq!(payload["status"], "completed");
        assert_eq!(payload["result"]["answer"], 42);
        assert!(payload.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let payload = finished_json(
            "s-err",
            TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
        );

        assert_eq!(payload["status"], "error");
        assert_eq!(payload["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let payload = finished_json("s-cancel", TerminalState::Cancelled);

        assert_eq!(payload["status"], "cancelled");
        assert!(payload.get("stats").is_some());
    }

    // ---- gen_session_id ----

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let unique: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(unique.len(), 10, "10 IDs should all be unique");
    }

    // ---- is_expired_impl ----

    #[test]
    fn is_expired_impl_fresh_instant_not_expired() {
        let just_now = std::time::Instant::now();
        assert!(!is_expired_impl(just_now, Duration::from_secs(1)));
    }

    #[test]
    fn is_expired_impl_old_instant_expired() {
        // Idle for two hours against a one-hour TTL.
        assert!(is_expired_impl(seconds_ago(7200), Duration::from_secs(3600)));
    }

    #[test]
    fn is_expired_impl_not_yet_expired() {
        // Idle for one hour against a three-hour TTL.
        assert!(!is_expired_impl(
            seconds_ago(3600),
            Duration::from_secs(10800)
        ));
    }

    #[test]
    fn is_expired_impl_zero_ttl_always_expired() {
        let just_now = std::time::Instant::now();
        assert!(is_expired_impl(just_now, Duration::ZERO));
    }
}