1use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10 ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11 TerminalState,
12};
13use mlua_isle::AsyncTask;
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
/// Errors returned by the session operations in this module.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session is registered under the given id; the payload is that id.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// A feed error from `algocline_core`, forwarded as-is (`transparent`),
    /// so its own message is what gets displayed.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// An execution-state transition was rejected; the payload is the
    /// stringified underlying error.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
/// Outcome of an execution that has reached a terminal state.
pub struct ExecutionResult {
    /// The terminal state the execution ended in.
    pub state: TerminalState,
    /// Metrics handed over together with the terminal state.
    pub metrics: ExecutionMetrics,
}
38
/// What a caller gets back after starting an execution or feeding a response.
pub enum FeedResult {
    /// The response was accepted, but `remaining` queries of the current
    /// batch are still unanswered.
    Accepted { remaining: usize },
    /// Execution is paused until the listed queries have been answered.
    Paused { queries: Vec<LlmQuery> },
    /// Execution reached a terminal state.
    Finished(ExecutionResult),
}
48
49impl FeedResult {
50 pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52 match self {
53 Self::Accepted { remaining } => json!({
54 "status": "accepted",
55 "remaining": remaining,
56 }),
57 Self::Paused { queries } => {
58 if queries.len() == 1 {
59 let q = &queries[0];
60 let mut obj = json!({
61 "status": "needs_response",
62 "session_id": session_id,
63 "prompt": q.prompt,
64 "system": q.system,
65 "max_tokens": q.max_tokens,
66 });
67 if q.grounded {
68 obj["grounded"] = json!(true);
69 }
70 if q.underspecified {
71 obj["underspecified"] = json!(true);
72 }
73 obj
74 } else {
75 let qs: Vec<_> = queries
76 .iter()
77 .map(|q| {
78 let mut obj = json!({
79 "id": q.id.as_str(),
80 "prompt": q.prompt,
81 "system": q.system,
82 "max_tokens": q.max_tokens,
83 });
84 if q.grounded {
85 obj["grounded"] = json!(true);
86 }
87 if q.underspecified {
88 obj["underspecified"] = json!(true);
89 }
90 obj
91 })
92 .collect();
93 json!({
94 "status": "needs_response",
95 "session_id": session_id,
96 "queries": qs,
97 })
98 }
99 }
100 Self::Finished(result) => match &result.state {
101 TerminalState::Completed { result: val } => json!({
102 "status": "completed",
103 "result": val,
104 "stats": result.metrics.to_json(),
105 }),
106 TerminalState::Failed { error } => json!({
107 "status": "error",
108 "error": error,
109 }),
110 TerminalState::Cancelled => json!({
111 "status": "cancelled",
112 "stats": result.metrics.to_json(),
113 }),
114 },
115 }
116 }
117}
118
/// A single execution together with the channels used to talk to it.
pub struct Session {
    /// Execution state machine; starts as `ExecutionState::Running`.
    state: ExecutionState,
    /// Metrics for this execution; moved out when a terminal result is built.
    metrics: ExecutionMetrics,
    /// Observer obtained from `metrics.create_observer()`; notified on every
    /// state change handled by this type.
    observer: MetricsObserver,
    /// Receives the LLM requests emitted by the running task.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// The running task; on success it yields a string that is parsed as JSON.
    exec_task: AsyncTask,
    /// One-shot senders for queries that are still awaiting a response,
    /// keyed by query id.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
}
131
impl Session {
    /// Builds a session in the `Running` state.
    ///
    /// The observer is created from `metrics` before `metrics` is moved into
    /// the session; no response senders are registered yet.
    pub fn new(
        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
        exec_task: AsyncTask,
        metrics: ExecutionMetrics,
    ) -> Self {
        let observer = metrics.create_observer();
        Self {
            state: ExecutionState::Running,
            metrics,
            observer,
            llm_rx,
            exec_task,
            resp_txs: HashMap::new(),
        }
    }

    /// Waits for whichever happens first: the task finishing, or the task
    /// emitting an LLM request.
    ///
    /// Returns `FeedResult::Finished` when the task ends (successfully, with
    /// an error, or with output that is not valid JSON) and
    /// `FeedResult::Paused` when a request arrives.
    ///
    /// # Errors
    ///
    /// `SessionError::InvalidTransition` if the state machine rejects the
    /// `complete`, `fail` or `pause` transition.
    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
        tokio::select! {
            // The task finished: translate its outcome into a terminal state.
            result = &mut self.exec_task => {
                match result {
                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
                        Ok(v) => {
                            self.state.complete(v.clone()).map_err(|e| {
                                SessionError::InvalidTransition(e.to_string())
                            })?;
                            self.observer.on_completed(&v);
                            Ok(FeedResult::Finished(ExecutionResult {
                                state: TerminalState::Completed { result: v },
                                metrics: self.take_metrics(),
                            }))
                        }
                        // Output that is not valid JSON is reported as a failed run.
                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
                    },
                    Err(e) => self.fail_with(e.to_string()),
                }
            }
            // The task asked for LLM input. Per the `tokio::select!` pattern
            // rules, a `None` from `recv()` (channel closed) does not match and
            // disables this branch, leaving only the task branch to wait on.
            Some(req) = self.llm_rx.recv() => {
                // Copy the caller-visible part of each request...
                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
                    id: qr.id.clone(),
                    prompt: qr.prompt.clone(),
                    system: qr.system.clone(),
                    max_tokens: qr.max_tokens,
                    grounded: qr.grounded,
                    underspecified: qr.underspecified,
                }).collect();

                // ...and keep the reply channels so `feed_one` can answer later.
                // NOTE(review): the senders are stored before `pause` is
                // attempted, so they stay registered if `pause` fails.
                for qr in req.queries {
                    self.resp_txs.insert(qr.id, qr.resp_tx);
                }

                self.state.pause(queries.clone()).map_err(|e| {
                    SessionError::InvalidTransition(e.to_string())
                })?;
                self.observer.on_paused(&queries);
                Ok(FeedResult::Paused { queries })
            }
        }
    }

    /// Feeds one response into the paused execution.
    ///
    /// Returns `Ok(true)` when `state.feed` reports the batch as complete
    /// (the responses are then taken from the state and the observer is told
    /// the run resumed), and `Ok(false)` when more responses are still needed.
    ///
    /// # Errors
    ///
    /// `SessionError::Feed` if the state rejects the response, or
    /// `SessionError::InvalidTransition` if taking the responses fails.
    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
        // NOTE(review): the observer is notified before `state.feed` validates
        // the id, so a rejected feed is still reported to it - confirm intended.
        self.observer.on_response_fed(query_id, &response);

        // Hand the response to the waiting side, if a sender is registered for
        // this id. A send error is deliberately ignored.
        if let Some(tx) = self.resp_txs.remove(query_id) {
            let _ = tx.send(Ok(response.clone()));
        }

        let complete = self
            .state
            .feed(query_id, response)
            .map_err(SessionError::Feed)?;

        if complete {
            // Only the error is inspected; the value returned by
            // `take_responses` is dropped here.
            self.state
                .take_responses()
                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
            self.observer.on_resumed();
        } else {
            self.observer
                .on_partial_feed(query_id, self.state.remaining());
        }

        Ok(complete)
    }

    /// Moves the execution into the failed state with `msg` and returns the
    /// matching `FeedResult::Finished`.
    ///
    /// # Errors
    ///
    /// `SessionError::InvalidTransition` if the state rejects the `fail`
    /// transition.
    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
        self.state
            .fail(msg.clone())
            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
        self.observer.on_failed(&msg);
        Ok(FeedResult::Finished(ExecutionResult {
            state: TerminalState::Failed { error: msg },
            metrics: self.take_metrics(),
        }))
    }

    /// Moves the metrics out of the session, leaving a `Default` value behind.
    fn take_metrics(&mut self) -> ExecutionMetrics {
        std::mem::take(&mut self.metrics)
    }
}
243
/// Shared store of sessions that are currently paused, keyed by session id.
pub struct SessionRegistry {
    /// Session map behind an async mutex, wrapped in an `Arc`.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
250
251impl Default for SessionRegistry {
252 fn default() -> Self {
253 Self::new()
254 }
255}
256
257impl SessionRegistry {
258 pub fn new() -> Self {
259 Self {
260 sessions: Arc::new(Mutex::new(HashMap::new())),
261 }
262 }
263
264 pub async fn start_execution(
266 &self,
267 mut session: Session,
268 ) -> Result<(String, FeedResult), SessionError> {
269 let session_id = gen_session_id();
270 let result = session.wait_event().await?;
271
272 if matches!(result, FeedResult::Paused { .. }) {
273 self.sessions
274 .lock()
275 .await
276 .insert(session_id.clone(), session);
277 }
278
279 Ok((session_id, result))
280 }
281
282 pub async fn feed_response(
288 &self,
289 session_id: &str,
290 query_id: &QueryId,
291 response: String,
292 ) -> Result<FeedResult, SessionError> {
293 let complete = {
295 let mut map = self.sessions.lock().await;
296 let session = map
297 .get_mut(session_id)
298 .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
299
300 let complete = session.feed_one(query_id, response)?;
301
302 if !complete {
303 return Ok(FeedResult::Accepted {
304 remaining: session.state.remaining(),
305 });
306 }
307
308 complete
309 };
310
311 debug_assert!(complete);
313 let mut session = {
314 let mut map = self.sessions.lock().await;
315 map.remove(session_id)
316 .ok_or_else(|| SessionError::NotFound(session_id.into()))?
317 };
318
319 let result = session.wait_event().await?;
320
321 if matches!(result, FeedResult::Paused { .. }) {
322 self.sessions
323 .lock()
324 .await
325 .insert(session_id.into(), session);
326 }
327
328 Ok(result)
329 }
330}
331
/// Produces a session id of the form `s-<nanos hex>-<16 hex digits>`.
///
/// The first part is the time since the Unix epoch in nanoseconds (zero if
/// the clock reads earlier than the epoch). The second part is that timestamp
/// hashed with a freshly created `RandomState`.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos());

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("s-{nanos:x}-{salt:016x}")
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Batch query `index` with no system prompt, 100 max tokens and no flags.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Single (non-batch) query; tests override individual fields as needed.
    fn single_query(prompt: &str) -> LlmQuery {
        LlmQuery {
            id: QueryId::single(),
            prompt: prompt.into(),
            system: None,
            max_tokens: 200,
            grounded: false,
            underspecified: false,
        }
    }

    /// JSON for a paused result holding `queries`.
    fn paused_json(session_id: &str, queries: Vec<LlmQuery>) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// JSON for a finished result in `state` with fresh metrics.
    fn finished_json(session_id: &str, state: TerminalState) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    #[test]
    fn to_json_accepted() {
        let json = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            ..single_query("What is 2+2?")
        };
        let json = paused_json("s-abc", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        assert!(json.get("queries").is_none());
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let query = LlmQuery {
            grounded: true,
            ..single_query("verify this claim")
        };
        let json = paused_json("s-grounded", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let query = LlmQuery {
            underspecified: true,
            ..single_query("what output format do you need?")
        };
        let json = paused_json("s-underspec", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded_query = LlmQuery {
            prompt: "verify".into(),
            grounded: true,
            ..make_query(0)
        };
        let normal_query = LlmQuery {
            prompt: "generate".into(),
            ..make_query(1)
        };
        let json = paused_json("s-batch", vec![grounded_query, normal_query]);

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec_query = LlmQuery {
            prompt: "clarify intent".into(),
            underspecified: true,
            ..make_query(0)
        };
        let normal_query = LlmQuery {
            prompt: "generate".into(),
            ..make_query(1)
        };
        let json = paused_json("s-batch-us", vec![underspec_query, normal_query]);

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            max_tokens: 1024,
            ..single_query("hello")
        };
        let json = paused_json("s-x", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let json = paused_json("s-multi", (0..3).map(make_query).collect());

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let json = finished_json(
            "s-done",
            TerminalState::Completed {
                result: json!({"answer": 42}),
            },
        );

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let json = finished_json(
            "s-err",
            TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
        );

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let json = finished_json("s-cancel", TerminalState::Cancelled);

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let unique: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(unique.len(), 10, "10 IDs should all be unique");
    }
}