Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
21#[derive(Debug, thiserror::Error)]
22pub enum SessionError {
23    #[error("session '{0}' not found")]
24    NotFound(String),
25    #[error(transparent)]
26    Feed(#[from] algocline_core::FeedError),
27    #[error("invalid transition: {0}")]
28    InvalidTransition(String),
29}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
33/// Session completion data: terminal state + metrics.
34pub struct ExecutionResult {
35    pub state: TerminalState,
36    pub metrics: ExecutionMetrics,
37}
38
39/// Result of a session interaction (start or feed).
40pub enum FeedResult {
41    /// Partial feed accepted, still waiting for more responses.
42    Accepted { remaining: usize },
43    /// All queries answered, Lua re-paused with new queries.
44    Paused { queries: Vec<LlmQuery> },
45    /// Execution completed (success, failure, or cancellation).
46    Finished(ExecutionResult),
47}
48
49impl FeedResult {
50    /// Convert to JSON for MCP tool response.
51    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52        match self {
53            Self::Accepted { remaining } => json!({
54                "status": "accepted",
55                "remaining": remaining,
56            }),
57            Self::Paused { queries } => {
58                if queries.len() == 1 {
59                    let q = &queries[0];
60                    let mut obj = json!({
61                        "status": "needs_response",
62                        "session_id": session_id,
63                        "prompt": q.prompt,
64                        "system": q.system,
65                        "max_tokens": q.max_tokens,
66                    });
67                    if q.grounded {
68                        obj["grounded"] = json!(true);
69                    }
70                    if q.underspecified {
71                        obj["underspecified"] = json!(true);
72                    }
73                    obj
74                } else {
75                    let qs: Vec<_> = queries
76                        .iter()
77                        .map(|q| {
78                            let mut obj = json!({
79                                "id": q.id.as_str(),
80                                "prompt": q.prompt,
81                                "system": q.system,
82                                "max_tokens": q.max_tokens,
83                            });
84                            if q.grounded {
85                                obj["grounded"] = json!(true);
86                            }
87                            if q.underspecified {
88                                obj["underspecified"] = json!(true);
89                            }
90                            obj
91                        })
92                        .collect();
93                    json!({
94                        "status": "needs_response",
95                        "session_id": session_id,
96                        "queries": qs,
97                    })
98                }
99            }
100            Self::Finished(result) => match &result.state {
101                TerminalState::Completed { result: val } => json!({
102                    "status": "completed",
103                    "result": val,
104                    "stats": result.metrics.to_json(),
105                }),
106                TerminalState::Failed { error } => json!({
107                    "status": "error",
108                    "error": error,
109                }),
110                TerminalState::Cancelled => json!({
111                    "status": "cancelled",
112                    "stats": result.metrics.to_json(),
113                }),
114            },
115        }
116    }
117}
118
119// ─── Session ─────────────────────────────────────────────────
120
121/// A Lua execution session with domain state tracking.
122///
123/// Each session owns a dedicated Lua VM via `_vm_driver`. The VM's OS thread
124/// stays alive as long as the driver is held, and exits cleanly when the
125/// session is dropped (channel closes → Lua thread drains and exits).
126pub struct Session {
127    state: ExecutionState,
128    metrics: ExecutionMetrics,
129    observer: MetricsObserver,
130    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
131    exec_task: AsyncTask,
132    /// QueryId → resp_tx. Populated on Paused, cleared on resume.
133    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
134    /// Per-session VM lifecycle driver. Keeps the Lua thread alive.
135    /// Dropped when the session completes or is abandoned.
136    _vm_driver: AsyncIsleDriver,
137}
138
139impl Session {
140    pub fn new(
141        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
142        exec_task: AsyncTask,
143        metrics: ExecutionMetrics,
144        vm_driver: AsyncIsleDriver,
145    ) -> Self {
146        let observer = metrics.create_observer();
147        Self {
148            state: ExecutionState::Running,
149            metrics,
150            observer,
151            llm_rx,
152            exec_task,
153            resp_txs: HashMap::new(),
154            _vm_driver: vm_driver,
155        }
156    }
157
158    /// Wait for the next event from Lua execution.
159    ///
160    /// Called after initial start or after feeding all responses.
161    /// State must be Running when called.
162    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
163        tokio::select! {
164            result = &mut self.exec_task => {
165                match result {
166                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
167                        Ok(v) => {
168                            self.state.complete(v.clone()).map_err(|e| {
169                                SessionError::InvalidTransition(e.to_string())
170                            })?;
171                            self.observer.on_completed(&v);
172                            Ok(FeedResult::Finished(ExecutionResult {
173                                state: TerminalState::Completed { result: v },
174                                metrics: self.take_metrics(),
175                            }))
176                        }
177                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
178                    },
179                    Err(e) => self.fail_with(e.to_string()),
180                }
181            }
182            Some(req) = self.llm_rx.recv() => {
183                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
184                    id: qr.id.clone(),
185                    prompt: qr.prompt.clone(),
186                    system: qr.system.clone(),
187                    max_tokens: qr.max_tokens,
188                    grounded: qr.grounded,
189                    underspecified: qr.underspecified,
190                }).collect();
191
192                for qr in req.queries {
193                    self.resp_txs.insert(qr.id, qr.resp_tx);
194                }
195
196                self.state.pause(queries.clone()).map_err(|e| {
197                    SessionError::InvalidTransition(e.to_string())
198                })?;
199                self.observer.on_paused(&queries);
200                Ok(FeedResult::Paused { queries })
201            }
202        }
203    }
204
205    /// Feed one response by query_id.
206    ///
207    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
208    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
209        // Track response before ownership transfer.
210        self.observer.on_response_fed(query_id, &response);
211
212        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
213        if let Some(tx) = self.resp_txs.remove(query_id) {
214            let _ = tx.send(Ok(response.clone()));
215        }
216
217        // Domain: record in state machine
218        let complete = self
219            .state
220            .feed(query_id, response)
221            .map_err(SessionError::Feed)?;
222
223        if complete {
224            // Domain: transition Paused(complete) → Running
225            self.state
226                .take_responses()
227                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
228            self.observer.on_resumed();
229        } else {
230            self.observer
231                .on_partial_feed(query_id, self.state.remaining());
232        }
233
234        Ok(complete)
235    }
236
237    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
238        self.state
239            .fail(msg.clone())
240            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
241        self.observer.on_failed(&msg);
242        Ok(FeedResult::Finished(ExecutionResult {
243            state: TerminalState::Failed { error: msg },
244            metrics: self.take_metrics(),
245        }))
246    }
247
248    fn take_metrics(&mut self) -> ExecutionMetrics {
249        std::mem::take(&mut self.metrics)
250    }
251}
252
253// ─── Registry ────────────────────────────────────────────────
254
255/// Manages active sessions. Replaces the raw SessionMap.
256pub struct SessionRegistry {
257    sessions: Arc<Mutex<HashMap<String, Session>>>,
258}
259
260impl Default for SessionRegistry {
261    fn default() -> Self {
262        Self::new()
263    }
264}
265
266impl SessionRegistry {
267    pub fn new() -> Self {
268        Self {
269            sessions: Arc::new(Mutex::new(HashMap::new())),
270        }
271    }
272
273    /// Start execution and wait for first event (pause or completion).
274    pub async fn start_execution(
275        &self,
276        mut session: Session,
277    ) -> Result<(String, FeedResult), SessionError> {
278        let session_id = gen_session_id();
279        let result = session.wait_event().await?;
280
281        if matches!(result, FeedResult::Paused { .. }) {
282            self.sessions
283                .lock()
284                .await
285                .insert(session_id.clone(), session);
286        }
287
288        Ok((session_id, result))
289    }
290
291    /// Feed one response to a paused session by query_id.
292    ///
293    /// If this completes all pending queries, the session resumes and
294    /// returns the next event (Paused or Finished).
295    /// If queries remain, returns Accepted { remaining }.
296    pub async fn feed_response(
297        &self,
298        session_id: &str,
299        query_id: &QueryId,
300        response: String,
301    ) -> Result<FeedResult, SessionError> {
302        // 1. Feed under lock
303        let complete = {
304            let mut map = self.sessions.lock().await;
305            let session = map
306                .get_mut(session_id)
307                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
308
309            let complete = session.feed_one(query_id, response)?;
310
311            if !complete {
312                return Ok(FeedResult::Accepted {
313                    remaining: session.state.remaining(),
314                });
315            }
316
317            complete
318        };
319
320        // 2. All complete → take session out for async resume
321        debug_assert!(complete);
322        let mut session = {
323            let mut map = self.sessions.lock().await;
324            map.remove(session_id)
325                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
326        };
327
328        let result = session.wait_event().await?;
329
330        if matches!(result, FeedResult::Paused { .. }) {
331            self.sessions
332                .lock()
333                .await
334                .insert(session_id.into(), session);
335        }
336
337        Ok(result)
338    }
339}
340
/// Generate a non-deterministic session ID of the form
/// `s-<timestamp hex>-<16 hex digits>`.
///
/// MCP spec requires "secure, non-deterministic session IDs" to prevent
/// session hijacking. The first part is the current time in nanoseconds
/// since the Unix epoch (0 if the clock reads before the epoch). The second
/// part is that timestamp hashed with a freshly created, randomly keyed
/// `RandomState` hasher, so it is not derivable from the timestamp alone.
///
/// NOTE(review): the unpredictable part comes from the standard library's
/// hasher keys, not from a dedicated cryptographic RNG — confirm this meets
/// the deployment's security requirements.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_nanos());

    // 64-bit keyed hash of the timestamp → 16 hex chars.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let token = hasher.finish();

    format!("s-{nanos:x}-{token:016x}")
}
363
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    // ─── Fixtures ───

    /// Batch-mode query number `index`: predictable prompt, no system prompt,
    /// no flags.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Single-mode query with the given prompt, no system prompt, a
    /// 200-token budget, and no flags. Tests override fields as needed.
    fn single_query(prompt: &str) -> LlmQuery {
        LlmQuery {
            id: QueryId::single(),
            prompt: prompt.into(),
            system: None,
            max_tokens: 200,
            grounded: false,
            underspecified: false,
        }
    }

    /// JSON for a `Paused` result holding `queries`.
    fn paused_json(session_id: &str, queries: Vec<LlmQuery>) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// JSON for a `Finished` result in `state` with fresh metrics.
    fn finished_json(session_id: &str, state: TerminalState) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let json = FeedResult::Accepted { remaining: 3 }.to_json("s-123");

        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            ..single_query("What is 2+2?")
        };
        let json = paused_json("s-abc", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        // single query mode: no "queries" array
        assert!(json.get("queries").is_none());
        // grounded=false must be absent
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        // underspecified=false must be absent
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let query = LlmQuery {
            grounded: true,
            ..single_query("verify this claim")
        };
        let json = paused_json("s-grounded", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let query = LlmQuery {
            underspecified: true,
            ..single_query("what output format do you need?")
        };
        let json = paused_json("s-underspec", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded_query = LlmQuery {
            prompt: "verify".into(),
            grounded: true,
            ..make_query(0)
        };
        let normal_query = LlmQuery {
            prompt: "generate".into(),
            ..make_query(1)
        };
        let json = paused_json("s-batch", vec![grounded_query, normal_query]);

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec_query = LlmQuery {
            prompt: "clarify intent".into(),
            underspecified: true,
            ..make_query(0)
        };
        let normal_query = LlmQuery {
            prompt: "generate".into(),
            ..make_query(1)
        };
        let json = paused_json("s-batch-us", vec![underspec_query, normal_query]);

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            max_tokens: 1024,
            ..single_query("hello")
        };
        let json = paused_json("s-x", vec![query]);

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let json = paused_json("s-multi", (0..3).map(make_query).collect());

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let json = finished_json(
            "s-done",
            TerminalState::Completed {
                result: json!({"answer": 42}),
            },
        );

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let json = finished_json(
            "s-err",
            TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
        );

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let json = finished_json("s-cancel", TerminalState::Cancelled);

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let unique: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(unique.len(), 10, "10 IDs should all be unique");
    }
}