Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::AsyncTask;
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
21#[derive(Debug, thiserror::Error)]
22pub enum SessionError {
23    #[error("session '{0}' not found")]
24    NotFound(String),
25    #[error(transparent)]
26    Feed(#[from] algocline_core::FeedError),
27    #[error("invalid transition: {0}")]
28    InvalidTransition(String),
29}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
33/// Session completion data: terminal state + metrics.
34pub struct ExecutionResult {
35    pub state: TerminalState,
36    pub metrics: ExecutionMetrics,
37}
38
39/// Result of a session interaction (start or feed).
40pub enum FeedResult {
41    /// Partial feed accepted, still waiting for more responses.
42    Accepted { remaining: usize },
43    /// All queries answered, Lua re-paused with new queries.
44    Paused { queries: Vec<LlmQuery> },
45    /// Execution completed (success, failure, or cancellation).
46    Finished(ExecutionResult),
47}
48
49impl FeedResult {
50    /// Convert to JSON for MCP tool response.
51    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52        match self {
53            Self::Accepted { remaining } => json!({
54                "status": "accepted",
55                "remaining": remaining,
56            }),
57            Self::Paused { queries } => {
58                if queries.len() == 1 {
59                    let q = &queries[0];
60                    let mut obj = json!({
61                        "status": "needs_response",
62                        "session_id": session_id,
63                        "prompt": q.prompt,
64                        "system": q.system,
65                        "max_tokens": q.max_tokens,
66                    });
67                    if q.grounded {
68                        obj["grounded"] = json!(true);
69                    }
70                    if q.underspecified {
71                        obj["underspecified"] = json!(true);
72                    }
73                    obj
74                } else {
75                    let qs: Vec<_> = queries
76                        .iter()
77                        .map(|q| {
78                            let mut obj = json!({
79                                "id": q.id.as_str(),
80                                "prompt": q.prompt,
81                                "system": q.system,
82                                "max_tokens": q.max_tokens,
83                            });
84                            if q.grounded {
85                                obj["grounded"] = json!(true);
86                            }
87                            if q.underspecified {
88                                obj["underspecified"] = json!(true);
89                            }
90                            obj
91                        })
92                        .collect();
93                    json!({
94                        "status": "needs_response",
95                        "session_id": session_id,
96                        "queries": qs,
97                    })
98                }
99            }
100            Self::Finished(result) => match &result.state {
101                TerminalState::Completed { result: val } => json!({
102                    "status": "completed",
103                    "result": val,
104                    "stats": result.metrics.to_json(),
105                }),
106                TerminalState::Failed { error } => json!({
107                    "status": "error",
108                    "error": error,
109                }),
110                TerminalState::Cancelled => json!({
111                    "status": "cancelled",
112                    "stats": result.metrics.to_json(),
113                }),
114            },
115        }
116    }
117}
118
119// ─── Session ─────────────────────────────────────────────────
120
121/// A Lua execution session with domain state tracking.
122pub struct Session {
123    state: ExecutionState,
124    metrics: ExecutionMetrics,
125    observer: MetricsObserver,
126    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
127    exec_task: AsyncTask,
128    /// QueryId → resp_tx. Populated on Paused, cleared on resume.
129    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
130}
131
132impl Session {
133    pub fn new(
134        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
135        exec_task: AsyncTask,
136        metrics: ExecutionMetrics,
137    ) -> Self {
138        let observer = metrics.create_observer();
139        Self {
140            state: ExecutionState::Running,
141            metrics,
142            observer,
143            llm_rx,
144            exec_task,
145            resp_txs: HashMap::new(),
146        }
147    }
148
149    /// Wait for the next event from Lua execution.
150    ///
151    /// Called after initial start or after feeding all responses.
152    /// State must be Running when called.
153    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
154        tokio::select! {
155            result = &mut self.exec_task => {
156                match result {
157                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
158                        Ok(v) => {
159                            self.state.complete(v.clone()).map_err(|e| {
160                                SessionError::InvalidTransition(e.to_string())
161                            })?;
162                            self.observer.on_completed(&v);
163                            Ok(FeedResult::Finished(ExecutionResult {
164                                state: TerminalState::Completed { result: v },
165                                metrics: self.take_metrics(),
166                            }))
167                        }
168                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
169                    },
170                    Err(e) => self.fail_with(e.to_string()),
171                }
172            }
173            Some(req) = self.llm_rx.recv() => {
174                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
175                    id: qr.id.clone(),
176                    prompt: qr.prompt.clone(),
177                    system: qr.system.clone(),
178                    max_tokens: qr.max_tokens,
179                    grounded: qr.grounded,
180                    underspecified: qr.underspecified,
181                }).collect();
182
183                for qr in req.queries {
184                    self.resp_txs.insert(qr.id, qr.resp_tx);
185                }
186
187                self.state.pause(queries.clone()).map_err(|e| {
188                    SessionError::InvalidTransition(e.to_string())
189                })?;
190                self.observer.on_paused(&queries);
191                Ok(FeedResult::Paused { queries })
192            }
193        }
194    }
195
196    /// Feed one response by query_id.
197    ///
198    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
199    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
200        // Track response before ownership transfer.
201        self.observer.on_response_fed(query_id, &response);
202
203        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
204        if let Some(tx) = self.resp_txs.remove(query_id) {
205            let _ = tx.send(Ok(response.clone()));
206        }
207
208        // Domain: record in state machine
209        let complete = self
210            .state
211            .feed(query_id, response)
212            .map_err(SessionError::Feed)?;
213
214        if complete {
215            // Domain: transition Paused(complete) → Running
216            self.state
217                .take_responses()
218                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
219            self.observer.on_resumed();
220        } else {
221            self.observer
222                .on_partial_feed(query_id, self.state.remaining());
223        }
224
225        Ok(complete)
226    }
227
228    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
229        self.state
230            .fail(msg.clone())
231            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
232        self.observer.on_failed(&msg);
233        Ok(FeedResult::Finished(ExecutionResult {
234            state: TerminalState::Failed { error: msg },
235            metrics: self.take_metrics(),
236        }))
237    }
238
239    fn take_metrics(&mut self) -> ExecutionMetrics {
240        std::mem::take(&mut self.metrics)
241    }
242}
243
244// ─── Registry ────────────────────────────────────────────────
245
246/// Manages active sessions. Replaces the raw SessionMap.
247pub struct SessionRegistry {
248    sessions: Arc<Mutex<HashMap<String, Session>>>,
249}
250
251impl Default for SessionRegistry {
252    fn default() -> Self {
253        Self::new()
254    }
255}
256
257impl SessionRegistry {
258    pub fn new() -> Self {
259        Self {
260            sessions: Arc::new(Mutex::new(HashMap::new())),
261        }
262    }
263
264    /// Start execution and wait for first event (pause or completion).
265    pub async fn start_execution(
266        &self,
267        mut session: Session,
268    ) -> Result<(String, FeedResult), SessionError> {
269        let session_id = gen_session_id();
270        let result = session.wait_event().await?;
271
272        if matches!(result, FeedResult::Paused { .. }) {
273            self.sessions
274                .lock()
275                .await
276                .insert(session_id.clone(), session);
277        }
278
279        Ok((session_id, result))
280    }
281
282    /// Feed one response to a paused session by query_id.
283    ///
284    /// If this completes all pending queries, the session resumes and
285    /// returns the next event (Paused or Finished).
286    /// If queries remain, returns Accepted { remaining }.
287    pub async fn feed_response(
288        &self,
289        session_id: &str,
290        query_id: &QueryId,
291        response: String,
292    ) -> Result<FeedResult, SessionError> {
293        // 1. Feed under lock
294        let complete = {
295            let mut map = self.sessions.lock().await;
296            let session = map
297                .get_mut(session_id)
298                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
299
300            let complete = session.feed_one(query_id, response)?;
301
302            if !complete {
303                return Ok(FeedResult::Accepted {
304                    remaining: session.state.remaining(),
305                });
306            }
307
308            complete
309        };
310
311        // 2. All complete → take session out for async resume
312        debug_assert!(complete);
313        let mut session = {
314            let mut map = self.sessions.lock().await;
315            map.remove(session_id)
316                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
317        };
318
319        let result = session.wait_event().await?;
320
321        if matches!(result, FeedResult::Paused { .. }) {
322            self.sessions
323                .lock()
324                .await
325                .insert(session_id.into(), session);
326        }
327
328        Ok(result)
329    }
330}
331
/// Produce a fresh session identifier of the form `s-<nanos>-<salt>`.
///
/// `<nanos>` is the current time in nanoseconds since the Unix epoch in
/// lowercase hex (0 if the clock reads earlier than the epoch). `<salt>` is
/// 16 hex digits: the timestamp hashed through a newly created `RandomState`.
///
/// Rationale carried over from the original comment: the MCP spec asks for
/// "secure, non-deterministic session IDs" to prevent session hijacking.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    // 64-bit hash of the timestamp → 16 hex chars.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let salt: u64 = hasher.finish();

    format!("s-{nanos:x}-{salt:016x}")
}
354
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    // ─── Shared builders ───

    /// Baseline query: no system prompt, 100 max tokens, both flags off.
    /// Tests override individual fields with struct-update syntax.
    fn plain(id: QueryId, prompt: &str) -> LlmQuery {
        LlmQuery {
            id,
            prompt: prompt.to_string(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Baseline batch query number `index` with prompt `prompt-<index>`.
    fn make_query(index: usize) -> LlmQuery {
        plain(QueryId::batch(index), &format!("prompt-{index}"))
    }

    /// JSON for a session paused on `queries`.
    fn paused(session_id: &str, queries: Vec<LlmQuery>) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// JSON for a session finished in `state` with fresh metrics.
    fn finished(session_id: &str, state: TerminalState) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let out = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(out["status"], "accepted");
        assert_eq!(out["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            ..plain(QueryId::single(), "What is 2+2?")
        };
        let out = paused("s-abc", vec![query]);

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-abc");
        assert_eq!(out["prompt"], "What is 2+2?");
        assert_eq!(out["system"], "You are a calculator.");
        assert_eq!(out["max_tokens"], 50);
        // A lone query is flattened: there is no "queries" array.
        assert!(out.get("queries").is_none());
        // Flags that are off must not be serialized at all.
        assert!(
            out.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            out.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let query = LlmQuery {
            max_tokens: 200,
            grounded: true,
            ..plain(QueryId::single(), "verify this claim")
        };
        let out = paused("s-grounded", vec![query]);

        assert_eq!(out["status"], "needs_response");
        assert_eq!(
            out["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let query = LlmQuery {
            max_tokens: 200,
            underspecified: true,
            ..plain(QueryId::single(), "what output format do you need?")
        };
        let out = paused("s-underspec", vec![query]);

        assert_eq!(out["status"], "needs_response");
        assert_eq!(
            out["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            out.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded_query = LlmQuery {
            grounded: true,
            ..plain(QueryId::batch(0), "verify")
        };
        let normal_query = plain(QueryId::batch(1), "generate");
        let out = paused("s-batch", vec![grounded_query, normal_query]);

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(
            entries[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            entries[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec_query = LlmQuery {
            underspecified: true,
            ..plain(QueryId::batch(0), "clarify intent")
        };
        let normal_query = plain(QueryId::batch(1), "generate");
        let out = paused("s-batch-us", vec![underspec_query, normal_query]);

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(
            entries[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            entries[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            max_tokens: 1024,
            ..plain(QueryId::single(), "hello")
        };
        let out = paused("s-x", vec![query]);

        assert_eq!(out["status"], "needs_response");
        assert!(out["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let out = paused("s-multi", (0..3).map(make_query).collect());

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-multi");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0]["id"], "q-0");
        assert_eq!(entries[0]["prompt"], "prompt-0");
        assert_eq!(entries[1]["id"], "q-1");
        assert_eq!(entries[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let out = finished(
            "s-done",
            TerminalState::Completed {
                result: json!({"answer": 42}),
            },
        );

        assert_eq!(out["status"], "completed");
        assert_eq!(out["result"]["answer"], 42);
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let out = finished(
            "s-err",
            TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
        );

        assert_eq!(out["status"], "error");
        assert_eq!(out["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let out = finished("s-cancel", TerminalState::Cancelled);

        assert_eq!(out["status"], "cancelled");
        assert!(out.get("stats").is_some());
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let distinct: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(distinct.len(), 10, "10 IDs should all be unique");
    }
}
623}