1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
/// Errors returned by session lifecycle operations.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No registered session has this id. Only paused sessions are kept in the
    /// registry, so this also covers sessions that already finished.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// A fed response was rejected by the core execution state machine.
    /// `#[from]` lets `?` convert `FeedError` automatically.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// The state machine refused a transition (complete / pause / fail /
    /// resume), or the request did not match the session's pending queries.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
/// Terminal outcome of a session: how it ended plus the metrics collected
/// while it ran.
#[derive(serde::Serialize)]
pub struct ExecutionResult {
    /// Completed (with a JSON result), failed (with an error message) or cancelled.
    pub state: TerminalState,
    /// Metrics moved out of the session when it reached the terminal state.
    pub metrics: ExecutionMetrics,
}
39
/// Outcome of driving a session one step, as reported back to the caller.
#[derive(serde::Serialize)]
pub enum FeedResult {
    /// A response was recorded but other queries of the same batch are still
    /// outstanding; `remaining` is how many.
    Accepted { remaining: usize },
    /// Execution is suspended until every one of `queries` has a response.
    Paused { queries: Vec<LlmQuery> },
    /// Execution reached a terminal state.
    Finished(ExecutionResult),
}
50
51impl FeedResult {
52 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
54 match self {
55 Self::Accepted { remaining } => json!({
56 "status": "accepted",
57 "remaining": remaining,
58 }),
59 Self::Paused { queries } => {
60 if queries.len() == 1 {
61 let q = &queries[0];
62 let mut obj = json!({
63 "status": "needs_response",
64 "session_id": session_id,
65 "query_id": q.id.as_str(),
66 "prompt": q.prompt,
67 "system": q.system,
68 "max_tokens": q.max_tokens,
69 });
70 if q.grounded {
71 obj["grounded"] = json!(true);
72 }
73 if q.underspecified {
74 obj["underspecified"] = json!(true);
75 }
76 obj
77 } else {
78 let qs: Vec<_> = queries
79 .iter()
80 .map(|q| {
81 let mut obj = json!({
82 "id": q.id.as_str(),
83 "prompt": q.prompt,
84 "system": q.system,
85 "max_tokens": q.max_tokens,
86 });
87 if q.grounded {
88 obj["grounded"] = json!(true);
89 }
90 if q.underspecified {
91 obj["underspecified"] = json!(true);
92 }
93 obj
94 })
95 .collect();
96 json!({
97 "status": "needs_response",
98 "session_id": session_id,
99 "queries": qs,
100 })
101 }
102 }
103 Self::Finished(result) => match &result.state {
104 TerminalState::Completed { result: val } => json!({
105 "status": "completed",
106 "result": val,
107 "stats": result.metrics.to_json(),
108 }),
109 TerminalState::Failed { error } => json!({
110 "status": "error",
111 "error": error,
112 }),
113 TerminalState::Cancelled => json!({
114 "status": "cancelled",
115 "stats": result.metrics.to_json(),
116 }),
117 },
118 }
119 }
120}
121
/// One script execution that can pause on LLM queries and be resumed by
/// feeding responses.
pub struct Session {
    /// Execution state machine (running / paused on queries / terminal).
    state: ExecutionState,
    /// Metrics for this run; moved out via `take_metrics` when the session finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics` in `new`; receives lifecycle events.
    observer: MetricsObserver,
    /// Incoming LLM requests issued by the running script.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// Handle to the running execution; resolves to the JSON-encoded result.
    exec_task: AsyncTask,
    /// One-shot senders that deliver each pending query's response back to the
    /// script; an entry is removed once its response is fed.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Held only to keep the VM driver alive; never read.
    /// Rust drops struct fields in declaration order, so keeping this field
    /// last means it is dropped after the task and channels above.
    /// NOTE(review): presumably that ordering is intentional — confirm before
    /// reordering fields.
    _vm_driver: AsyncIsleDriver,
}
141
142impl Session {
143 pub fn new(
144 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
145 exec_task: AsyncTask,
146 metrics: ExecutionMetrics,
147 vm_driver: AsyncIsleDriver,
148 ) -> Self {
149 let observer = metrics.create_observer();
150 Self {
151 state: ExecutionState::Running,
152 metrics,
153 observer,
154 llm_rx,
155 exec_task,
156 resp_txs: HashMap::new(),
157 _vm_driver: vm_driver,
158 }
159 }
160
161 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
166 tokio::select! {
167 result = &mut self.exec_task => {
168 match result {
169 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
170 Ok(v) => {
171 self.state.complete(v.clone()).map_err(|e| {
172 SessionError::InvalidTransition(e.to_string())
173 })?;
174 self.observer.on_completed(&v);
175 Ok(FeedResult::Finished(ExecutionResult {
176 state: TerminalState::Completed { result: v },
177 metrics: self.take_metrics(),
178 }))
179 }
180 Err(e) => self.fail_with(format!("JSON parse: {e}")),
181 },
182 Err(e) => self.fail_with(e.to_string()),
183 }
184 }
185 Some(req) = self.llm_rx.recv() => {
186 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
187 id: qr.id.clone(),
188 prompt: qr.prompt.clone(),
189 system: qr.system.clone(),
190 max_tokens: qr.max_tokens,
191 grounded: qr.grounded,
192 underspecified: qr.underspecified,
193 }).collect();
194
195 for qr in req.queries {
196 self.resp_txs.insert(qr.id, qr.resp_tx);
197 }
198
199 self.state.pause(queries.clone()).map_err(|e| {
200 SessionError::InvalidTransition(e.to_string())
201 })?;
202 self.observer.on_paused(&queries);
203 Ok(FeedResult::Paused { queries })
204 }
205 }
206 }
207
208 fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
212 self.observer.on_response_fed(query_id, &response);
214
215 if let Some(tx) = self.resp_txs.remove(query_id) {
217 let _ = tx.send(Ok(response.clone()));
218 }
219
220 let complete = self
222 .state
223 .feed(query_id, response)
224 .map_err(SessionError::Feed)?;
225
226 if complete {
227 self.state
229 .take_responses()
230 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
231 self.observer.on_resumed();
232 } else {
233 self.observer
234 .on_partial_feed(query_id, self.state.remaining());
235 }
236
237 Ok(complete)
238 }
239
240 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
241 self.state
242 .fail(msg.clone())
243 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
244 self.observer.on_failed(&msg);
245 Ok(FeedResult::Finished(ExecutionResult {
246 state: TerminalState::Failed { error: msg },
247 metrics: self.take_metrics(),
248 }))
249 }
250
251 fn take_metrics(&mut self) -> ExecutionMetrics {
252 std::mem::take(&mut self.metrics)
253 }
254
255 pub fn snapshot(&self) -> serde_json::Value {
260 let state_label = match &self.state {
261 ExecutionState::Running => "running",
262 ExecutionState::Paused(_) => "paused",
263 _ => "terminal",
264 };
265
266 let mut json = serde_json::json!({
267 "state": state_label,
268 });
269
270 let metrics = self.metrics.snapshot();
271 if !metrics.is_null() {
272 json["metrics"] = metrics;
273 }
274
275 if let ExecutionState::Paused(_) = &self.state {
277 json["pending_queries"] = self.state.remaining().into();
278 }
279
280 json
281 }
282}
283
/// Registry of sessions that are paused waiting for LLM responses, keyed by
/// session id.
///
/// Sessions are inserted only when they pause. A session is taken out of the
/// map while it is resumed and runs to its next event, and it is re-inserted
/// only if it pauses again. Finished sessions are never stored.
pub struct SessionRegistry {
    /// Shared map guarded by an async (tokio) mutex.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
321
322impl Default for SessionRegistry {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328impl SessionRegistry {
329 pub fn new() -> Self {
330 Self {
331 sessions: Arc::new(Mutex::new(HashMap::new())),
332 }
333 }
334
335 pub async fn start_execution(
337 &self,
338 mut session: Session,
339 ) -> Result<(String, FeedResult), SessionError> {
340 let session_id = gen_session_id();
341 let result = session.wait_event().await?;
342
343 if matches!(result, FeedResult::Paused { .. }) {
344 self.sessions
345 .lock()
346 .await
347 .insert(session_id.clone(), session);
348 }
349
350 Ok((session_id, result))
351 }
352
353 pub async fn feed_response(
359 &self,
360 session_id: &str,
361 query_id: &QueryId,
362 response: String,
363 ) -> Result<FeedResult, SessionError> {
364 let complete = {
366 let mut map = self.sessions.lock().await;
367 let session = map
368 .get_mut(session_id)
369 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
370
371 let complete = session.feed_one(query_id, response)?;
372
373 if !complete {
374 return Ok(FeedResult::Accepted {
375 remaining: session.state.remaining(),
376 });
377 }
378
379 complete
380 };
381
382 debug_assert!(complete);
384 let mut session = {
385 let mut map = self.sessions.lock().await;
386 map.remove(session_id)
387 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
388 };
389
390 let result = session.wait_event().await?;
391
392 if matches!(result, FeedResult::Paused { .. }) {
393 self.sessions
394 .lock()
395 .await
396 .insert(session_id.into(), session);
397 }
398
399 Ok(result)
400 }
401
402 pub async fn resolve_sole_pending_id(&self, session_id: &str) -> Result<QueryId, SessionError> {
408 let map = self.sessions.lock().await;
409 let session = map
410 .get(session_id)
411 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
412 let keys: Vec<QueryId> = session.resp_txs.keys().cloned().collect();
413 match keys.len() {
414 0 => Err(SessionError::InvalidTransition("no pending queries".into())),
415 1 => keys
416 .into_iter()
417 .next()
418 .ok_or_else(|| SessionError::InvalidTransition("unexpected empty keys".into())),
419 n => Err(SessionError::InvalidTransition(format!(
420 "{n} queries pending; specify query_id explicitly"
421 ))),
422 }
423 }
424
425 pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
431 let map = self.sessions.lock().await;
432 map.iter()
433 .map(|(id, session)| (id.clone(), session.snapshot()))
434 .collect()
435 }
436}
437
/// Build a fresh session identifier of the form `s-{timestamp_hex}-{salt_hex16}`.
///
/// `timestamp_hex` is the wall-clock time since the Unix epoch in nanoseconds,
/// or 0 if the system clock reports a time before the epoch. `salt_hex16` is
/// that timestamp hashed through a newly constructed `RandomState` (randomly
/// keyed), zero-padded to 16 hex digits. Ids minted in the same nanosecond are
/// therefore still expected to differ.
/// NOTE(review): this is meant for uniqueness; do not rely on it as an
/// unguessable secret without confirming that is acceptable.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt = hasher.finish();

    format!("s-{:x}-{:016x}", nanos, salt)
}
473
/// Unit tests for the JSON wire format of `FeedResult` and for session-id
/// generation. These run without a VM or async runtime.
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Build a batch query with id `QueryId::batch(index)` and prompt
    /// `prompt-{index}`. It has no system prompt, a 100-token budget, and both
    /// optional flags off.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    // `Accepted` serializes to its status plus the outstanding-response count.
    #[test]
    fn to_json_accepted() {
        let result = FeedResult::Accepted { remaining: 3 };
        let json = result.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    // A single pending query is flattened to top level (no `queries` array).
    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-abc");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        assert!(json.get("queries").is_none());
        // Optional flags must be omitted entirely (not `false`) when unset.
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    // A set `grounded` flag shows up in the flattened single-query payload.
    #[test]
    fn to_json_paused_single_query_grounded() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "verify this claim".into(),
            system: None,
            max_tokens: 200,
            grounded: true,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-grounded");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    // A set `underspecified` flag shows up on its own; unset flags stay absent.
    #[test]
    fn to_json_paused_single_query_underspecified() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "what output format do you need?".into(),
            system: None,
            max_tokens: 200,
            grounded: false,
            underspecified: true,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-underspec");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    // In a batch, `grounded` is emitted per query, only on the query that has it.
    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded_query = LlmQuery {
            id: QueryId::batch(0),
            prompt: "verify".into(),
            system: None,
            max_tokens: 100,
            grounded: true,
            underspecified: false,
        };
        let normal_query = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![grounded_query, normal_query],
        };
        let json = result.to_json("s-batch");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    // In a batch, `underspecified` is emitted per query, only where set.
    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec_query = LlmQuery {
            id: QueryId::batch(0),
            prompt: "clarify intent".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: true,
        };
        let normal_query = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![underspec_query, normal_query],
        };
        let json = result.to_json("s-batch-us");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    // A missing system prompt serializes as JSON null (the key is still present).
    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "hello".into(),
            system: None,
            max_tokens: 1024,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-x");

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    // More than one query is listed under `queries`, in order, each with its `id`.
    #[test]
    fn to_json_paused_multiple_queries() {
        let queries = vec![make_query(0), make_query(1), make_query(2)];
        let result = FeedResult::Paused { queries };
        let json = result.to_json("s-multi");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    // A completed run carries the script result and a `stats` block.
    #[test]
    fn to_json_finished_completed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Completed {
                result: json!({"answer": 42}),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-done");

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    // A failed run reports status "error" with the error message.
    #[test]
    fn to_json_finished_failed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-err");

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    // A cancelled run still reports `stats`.
    #[test]
    fn to_json_finished_cancelled() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Cancelled,
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-cancel");

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // Generated session ids carry the `s-` prefix.
    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    // Ids generated back-to-back are distinct.
    #[test]
    fn session_id_uniqueness() {
        let ids: Vec<String> = (0..10).map(|_| gen_session_id()).collect();
        let set: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(set.len(), 10, "10 IDs should all be unique");
    }
}