1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19#[derive(Debug, thiserror::Error)]
22pub enum SessionError {
23 #[error("session '{0}' not found")]
24 NotFound(String),
25 #[error(transparent)]
26 Feed(#[from] algocline_core::FeedError),
27 #[error("invalid transition: {0}")]
28 InvalidTransition(String),
29}
30
31pub struct ExecutionResult {
35 pub state: TerminalState,
36 pub metrics: ExecutionMetrics,
37}
38
39pub enum FeedResult {
41 Accepted { remaining: usize },
43 Paused { queries: Vec<LlmQuery> },
45 Finished(ExecutionResult),
47}
48
49impl FeedResult {
50 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52 match self {
53 Self::Accepted { remaining } => json!({
54 "status": "accepted",
55 "remaining": remaining,
56 }),
57 Self::Paused { queries } => {
58 if queries.len() == 1 {
59 let q = &queries[0];
60 let mut obj = json!({
61 "status": "needs_response",
62 "session_id": session_id,
63 "prompt": q.prompt,
64 "system": q.system,
65 "max_tokens": q.max_tokens,
66 });
67 if q.grounded {
68 obj["grounded"] = json!(true);
69 }
70 if q.underspecified {
71 obj["underspecified"] = json!(true);
72 }
73 obj
74 } else {
75 let qs: Vec<_> = queries
76 .iter()
77 .map(|q| {
78 let mut obj = json!({
79 "id": q.id.as_str(),
80 "prompt": q.prompt,
81 "system": q.system,
82 "max_tokens": q.max_tokens,
83 });
84 if q.grounded {
85 obj["grounded"] = json!(true);
86 }
87 if q.underspecified {
88 obj["underspecified"] = json!(true);
89 }
90 obj
91 })
92 .collect();
93 json!({
94 "status": "needs_response",
95 "session_id": session_id,
96 "queries": qs,
97 })
98 }
99 }
100 Self::Finished(result) => match &result.state {
101 TerminalState::Completed { result: val } => json!({
102 "status": "completed",
103 "result": val,
104 "stats": result.metrics.to_json(),
105 }),
106 TerminalState::Failed { error } => json!({
107 "status": "error",
108 "error": error,
109 }),
110 TerminalState::Cancelled => json!({
111 "status": "cancelled",
112 "stats": result.metrics.to_json(),
113 }),
114 },
115 }
116 }
117}
118
119pub struct Session {
127 state: ExecutionState,
128 metrics: ExecutionMetrics,
129 observer: MetricsObserver,
130 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
131 exec_task: AsyncTask,
132 resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
134 _vm_driver: AsyncIsleDriver,
137}
138
139impl Session {
140 pub fn new(
141 llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
142 exec_task: AsyncTask,
143 metrics: ExecutionMetrics,
144 vm_driver: AsyncIsleDriver,
145 ) -> Self {
146 let observer = metrics.create_observer();
147 Self {
148 state: ExecutionState::Running,
149 metrics,
150 observer,
151 llm_rx,
152 exec_task,
153 resp_txs: HashMap::new(),
154 _vm_driver: vm_driver,
155 }
156 }
157
158 async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
163 tokio::select! {
164 result = &mut self.exec_task => {
165 match result {
166 Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
167 Ok(v) => {
168 self.state.complete(v.clone()).map_err(|e| {
169 SessionError::InvalidTransition(e.to_string())
170 })?;
171 self.observer.on_completed(&v);
172 Ok(FeedResult::Finished(ExecutionResult {
173 state: TerminalState::Completed { result: v },
174 metrics: self.take_metrics(),
175 }))
176 }
177 Err(e) => self.fail_with(format!("JSON parse: {e}")),
178 },
179 Err(e) => self.fail_with(e.to_string()),
180 }
181 }
182 Some(req) = self.llm_rx.recv() => {
183 let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
184 id: qr.id.clone(),
185 prompt: qr.prompt.clone(),
186 system: qr.system.clone(),
187 max_tokens: qr.max_tokens,
188 grounded: qr.grounded,
189 underspecified: qr.underspecified,
190 }).collect();
191
192 for qr in req.queries {
193 self.resp_txs.insert(qr.id, qr.resp_tx);
194 }
195
196 self.state.pause(queries.clone()).map_err(|e| {
197 SessionError::InvalidTransition(e.to_string())
198 })?;
199 self.observer.on_paused(&queries);
200 Ok(FeedResult::Paused { queries })
201 }
202 }
203 }
204
205 fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
209 self.observer.on_response_fed(query_id, &response);
211
212 if let Some(tx) = self.resp_txs.remove(query_id) {
214 let _ = tx.send(Ok(response.clone()));
215 }
216
217 let complete = self
219 .state
220 .feed(query_id, response)
221 .map_err(SessionError::Feed)?;
222
223 if complete {
224 self.state
226 .take_responses()
227 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
228 self.observer.on_resumed();
229 } else {
230 self.observer
231 .on_partial_feed(query_id, self.state.remaining());
232 }
233
234 Ok(complete)
235 }
236
237 fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
238 self.state
239 .fail(msg.clone())
240 .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
241 self.observer.on_failed(&msg);
242 Ok(FeedResult::Finished(ExecutionResult {
243 state: TerminalState::Failed { error: msg },
244 metrics: self.take_metrics(),
245 }))
246 }
247
248 fn take_metrics(&mut self) -> ExecutionMetrics {
249 std::mem::take(&mut self.metrics)
250 }
251}
252
253pub struct SessionRegistry {
257 sessions: Arc<Mutex<HashMap<String, Session>>>,
258}
259
260impl Default for SessionRegistry {
261 fn default() -> Self {
262 Self::new()
263 }
264}
265
266impl SessionRegistry {
267 pub fn new() -> Self {
268 Self {
269 sessions: Arc::new(Mutex::new(HashMap::new())),
270 }
271 }
272
273 pub async fn start_execution(
275 &self,
276 mut session: Session,
277 ) -> Result<(String, FeedResult), SessionError> {
278 let session_id = gen_session_id();
279 let result = session.wait_event().await?;
280
281 if matches!(result, FeedResult::Paused { .. }) {
282 self.sessions
283 .lock()
284 .await
285 .insert(session_id.clone(), session);
286 }
287
288 Ok((session_id, result))
289 }
290
291 pub async fn feed_response(
297 &self,
298 session_id: &str,
299 query_id: &QueryId,
300 response: String,
301 ) -> Result<FeedResult, SessionError> {
302 let complete = {
304 let mut map = self.sessions.lock().await;
305 let session = map
306 .get_mut(session_id)
307 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
308
309 let complete = session.feed_one(query_id, response)?;
310
311 if !complete {
312 return Ok(FeedResult::Accepted {
313 remaining: session.state.remaining(),
314 });
315 }
316
317 complete
318 };
319
320 debug_assert!(complete);
322 let mut session = {
323 let mut map = self.sessions.lock().await;
324 map.remove(session_id)
325 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
326 };
327
328 let result = session.wait_event().await?;
329
330 if matches!(result, FeedResult::Paused { .. }) {
331 self.sessions
332 .lock()
333 .await
334 .insert(session_id.into(), session);
335 }
336
337 Ok(result)
338 }
339}
340
/// Builds a session id of the form `s-<timestamp nanos in hex>-<16 hex digits>`.
///
/// The suffix is the timestamp hashed with a freshly created std
/// `RandomState`, whose hasher keys are randomly initialised. This makes two
/// ids created within the same clock tick very likely to differ. If the
/// system clock reads before the Unix epoch, the timestamp part is `0`.
///
/// NOTE(review): this is not a CSPRNG; do not rely on ids being unguessable.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("s-{nanos:x}-{salt:016x}")
}
363
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// A query with no system prompt, `max_tokens = 100` and both flags off;
    /// tests override individual fields with struct-update syntax.
    fn plain_query(id: QueryId, prompt: &str) -> LlmQuery {
        LlmQuery {
            id,
            prompt: prompt.into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Batch query number `index` with prompt `prompt-<index>`.
    fn make_query(index: usize) -> LlmQuery {
        plain_query(QueryId::batch(index), &format!("prompt-{index}"))
    }

    /// Wraps a terminal state into a `Finished` result with empty metrics.
    fn finished(state: TerminalState) -> FeedResult {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
    }

    #[test]
    fn to_json_accepted() {
        let out = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(out["status"], "accepted");
        assert_eq!(out["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let q = LlmQuery {
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            ..plain_query(QueryId::single(), "What is 2+2?")
        };
        let out = FeedResult::Paused { queries: vec![q] }.to_json("s-abc");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-abc");
        assert_eq!(out["prompt"], "What is 2+2?");
        assert_eq!(out["system"], "You are a calculator.");
        assert_eq!(out["max_tokens"], 50);
        assert!(out.get("queries").is_none());
        assert!(
            out.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            out.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let q = LlmQuery {
            max_tokens: 200,
            grounded: true,
            ..plain_query(QueryId::single(), "verify this claim")
        };
        let out = FeedResult::Paused { queries: vec![q] }.to_json("s-grounded");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(
            out["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let q = LlmQuery {
            max_tokens: 200,
            underspecified: true,
            ..plain_query(QueryId::single(), "what output format do you need?")
        };
        let out = FeedResult::Paused { queries: vec![q] }.to_json("s-underspec");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(
            out["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            out.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let queries = vec![
            LlmQuery {
                grounded: true,
                ..plain_query(QueryId::batch(0), "verify")
            },
            plain_query(QueryId::batch(1), "generate"),
        ];
        let out = FeedResult::Paused { queries }.to_json("s-batch");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(
            entries[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            entries[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let queries = vec![
            LlmQuery {
                underspecified: true,
                ..plain_query(QueryId::batch(0), "clarify intent")
            },
            plain_query(QueryId::batch(1), "generate"),
        ];
        let out = FeedResult::Paused { queries }.to_json("s-batch-us");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(
            entries[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            entries[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let q = LlmQuery {
            max_tokens: 1024,
            ..plain_query(QueryId::single(), "hello")
        };
        let out = FeedResult::Paused { queries: vec![q] }.to_json("s-x");

        assert_eq!(out["status"], "needs_response");
        assert!(out["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let queries: Vec<LlmQuery> = (0..3).map(make_query).collect();
        let out = FeedResult::Paused { queries }.to_json("s-multi");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-multi");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0]["id"], "q-0");
        assert_eq!(entries[0]["prompt"], "prompt-0");
        assert_eq!(entries[1]["id"], "q-1");
        assert_eq!(entries[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let out = finished(TerminalState::Completed {
            result: json!({"answer": 42}),
        })
        .to_json("s-done");

        assert_eq!(out["status"], "completed");
        assert_eq!(out["result"]["answer"], 42);
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let out = finished(TerminalState::Failed {
            error: "lua error: bad argument".into(),
        })
        .to_json("s-err");

        assert_eq!(out["status"], "error");
        assert_eq!(out["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let out = finished(TerminalState::Cancelled).to_json("s-cancel");

        assert_eq!(out["status"], "cancelled");
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let ids: Vec<String> = (0..10).map(|_| gen_session_id()).collect();
        let distinct: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(distinct.len(), 10, "10 IDs should all be unique");
    }
}