Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
/// Errors returned by session operations on [`Session`] / [`SessionRegistry`].
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session with this ID is currently registered. This covers IDs that
    /// never existed, sessions that already finished (they are never stored or
    /// re-stored), and sessions temporarily taken out of the registry while
    /// they resume inside `feed_response`.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The domain state machine rejected a fed response.
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// A domain state transition (pause / complete / fail / resume) was
    /// rejected; carries the underlying error's message.
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
/// Session completion data: terminal state + metrics.
///
/// Carried by [`FeedResult::Finished`]. The metrics are moved out of the
/// session when it finishes, so this is the owning copy.
pub struct ExecutionResult {
    /// How the execution ended (completed, failed, or cancelled).
    pub state: TerminalState,
    /// Metrics accumulated over the session's lifetime.
    pub metrics: ExecutionMetrics,
}
38
/// Result of a session interaction (start or feed).
///
/// Converted to the MCP tool response with [`FeedResult::to_json`].
pub enum FeedResult {
    /// Partial feed accepted, still waiting for more responses.
    /// `remaining` is the number of queries in the current batch still unanswered.
    Accepted { remaining: usize },
    /// All queries answered, Lua re-paused with new queries.
    /// Also returned by the initial start when Lua pauses on its first `alc.llm()`.
    Paused { queries: Vec<LlmQuery> },
    /// Execution completed (success, failure, or cancellation).
    Finished(ExecutionResult),
}
48
49impl FeedResult {
50    /// Convert to JSON for MCP tool response.
51    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52        match self {
53            Self::Accepted { remaining } => json!({
54                "status": "accepted",
55                "remaining": remaining,
56            }),
57            Self::Paused { queries } => {
58                if queries.len() == 1 {
59                    let q = &queries[0];
60                    let mut obj = json!({
61                        "status": "needs_response",
62                        "session_id": session_id,
63                        "prompt": q.prompt,
64                        "system": q.system,
65                        "max_tokens": q.max_tokens,
66                    });
67                    if q.grounded {
68                        obj["grounded"] = json!(true);
69                    }
70                    if q.underspecified {
71                        obj["underspecified"] = json!(true);
72                    }
73                    obj
74                } else {
75                    let qs: Vec<_> = queries
76                        .iter()
77                        .map(|q| {
78                            let mut obj = json!({
79                                "id": q.id.as_str(),
80                                "prompt": q.prompt,
81                                "system": q.system,
82                                "max_tokens": q.max_tokens,
83                            });
84                            if q.grounded {
85                                obj["grounded"] = json!(true);
86                            }
87                            if q.underspecified {
88                                obj["underspecified"] = json!(true);
89                            }
90                            obj
91                        })
92                        .collect();
93                    json!({
94                        "status": "needs_response",
95                        "session_id": session_id,
96                        "queries": qs,
97                    })
98                }
99            }
100            Self::Finished(result) => match &result.state {
101                TerminalState::Completed { result: val } => json!({
102                    "status": "completed",
103                    "result": val,
104                    "stats": result.metrics.to_json(),
105                }),
106                TerminalState::Failed { error } => json!({
107                    "status": "error",
108                    "error": error,
109                }),
110                TerminalState::Cancelled => json!({
111                    "status": "cancelled",
112                    "stats": result.metrics.to_json(),
113                }),
114            },
115        }
116    }
117}
118
119// ─── Session ─────────────────────────────────────────────────
120
/// A Lua execution session with domain state tracking.
///
/// Each session owns a dedicated Lua VM via `_vm_driver`. The VM's OS thread
/// stays alive as long as the driver is held, and exits cleanly when the
/// session is dropped (channel closes → Lua thread drains and exits).
pub struct Session {
    /// Domain state machine (Running / Paused / terminal).
    state: ExecutionState,
    /// Session metrics; moved out (`take_metrics`) when the session finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics` in `new`; receives lifecycle events.
    observer: MetricsObserver,
    /// Receives `alc.llm()` requests emitted by the Lua side.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// The Lua execution itself; resolves to the JSON-encoded result string.
    exec_task: AsyncTask,
    /// QueryId → resp_tx. Populated when Lua pauses; each entry is removed
    /// as its response is fed.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Per-session VM lifecycle driver. Keeps the Lua thread alive.
    /// Dropped when the session completes or is abandoned.
    _vm_driver: AsyncIsleDriver,
}
138
139impl Session {
140    pub fn new(
141        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
142        exec_task: AsyncTask,
143        metrics: ExecutionMetrics,
144        vm_driver: AsyncIsleDriver,
145    ) -> Self {
146        let observer = metrics.create_observer();
147        Self {
148            state: ExecutionState::Running,
149            metrics,
150            observer,
151            llm_rx,
152            exec_task,
153            resp_txs: HashMap::new(),
154            _vm_driver: vm_driver,
155        }
156    }
157
158    /// Wait for the next event from Lua execution.
159    ///
160    /// Called after initial start or after feeding all responses.
161    /// State must be Running when called.
162    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
163        tokio::select! {
164            result = &mut self.exec_task => {
165                match result {
166                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
167                        Ok(v) => {
168                            self.state.complete(v.clone()).map_err(|e| {
169                                SessionError::InvalidTransition(e.to_string())
170                            })?;
171                            self.observer.on_completed(&v);
172                            Ok(FeedResult::Finished(ExecutionResult {
173                                state: TerminalState::Completed { result: v },
174                                metrics: self.take_metrics(),
175                            }))
176                        }
177                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
178                    },
179                    Err(e) => self.fail_with(e.to_string()),
180                }
181            }
182            Some(req) = self.llm_rx.recv() => {
183                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
184                    id: qr.id.clone(),
185                    prompt: qr.prompt.clone(),
186                    system: qr.system.clone(),
187                    max_tokens: qr.max_tokens,
188                    grounded: qr.grounded,
189                    underspecified: qr.underspecified,
190                }).collect();
191
192                for qr in req.queries {
193                    self.resp_txs.insert(qr.id, qr.resp_tx);
194                }
195
196                self.state.pause(queries.clone()).map_err(|e| {
197                    SessionError::InvalidTransition(e.to_string())
198                })?;
199                self.observer.on_paused(&queries);
200                Ok(FeedResult::Paused { queries })
201            }
202        }
203    }
204
205    /// Feed one response by query_id.
206    ///
207    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
208    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
209        // Track response before ownership transfer.
210        self.observer.on_response_fed(query_id, &response);
211
212        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
213        if let Some(tx) = self.resp_txs.remove(query_id) {
214            let _ = tx.send(Ok(response.clone()));
215        }
216
217        // Domain: record in state machine
218        let complete = self
219            .state
220            .feed(query_id, response)
221            .map_err(SessionError::Feed)?;
222
223        if complete {
224            // Domain: transition Paused(complete) → Running
225            self.state
226                .take_responses()
227                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
228            self.observer.on_resumed();
229        } else {
230            self.observer
231                .on_partial_feed(query_id, self.state.remaining());
232        }
233
234        Ok(complete)
235    }
236
237    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
238        self.state
239            .fail(msg.clone())
240            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
241        self.observer.on_failed(&msg);
242        Ok(FeedResult::Finished(ExecutionResult {
243            state: TerminalState::Failed { error: msg },
244            metrics: self.take_metrics(),
245        }))
246    }
247
248    fn take_metrics(&mut self) -> ExecutionMetrics {
249        std::mem::take(&mut self.metrics)
250    }
251
252    /// Lightweight snapshot for external observation (alc_status).
253    ///
254    /// Returns session state label and running metrics without consuming
255    /// or modifying the session.
256    pub fn snapshot(&self) -> serde_json::Value {
257        let state_label = match &self.state {
258            ExecutionState::Running => "running",
259            ExecutionState::Paused(_) => "paused",
260            _ => "terminal",
261        };
262
263        let mut json = serde_json::json!({
264            "state": state_label,
265        });
266
267        let metrics = self.metrics.snapshot();
268        if !metrics.is_null() {
269            json["metrics"] = metrics;
270        }
271
272        // Include pending query count when paused
273        if let ExecutionState::Paused(_) = &self.state {
274            json["pending_queries"] = self.state.remaining().into();
275        }
276
277        json
278    }
279}
280
281// ─── Registry ────────────────────────────────────────────────
282
/// Manages active sessions.
///
/// # Locking design (lock **C**)
///
/// `feed_response` holds this lock while calling `Session::feed_one()`.
/// According to the metrics implementation outside this module, `feed_one`
/// in turn acquires the per-session `std::sync::Mutex<SessionStatus>`,
/// lock **A**. The lock ordering invariant is always **C → A**: no code path
/// acquires A then C, so the two locks cannot deadlock each other.
///
/// `tokio::sync::Mutex` is used (rather than `std::sync::Mutex`) so that
/// acquiring the registry lock from async code (`.lock().await`) yields to
/// the runtime instead of blocking a worker thread while another task holds
/// it. The guard is never held across the `wait_event().await`.
/// `feed_response` uses a two-phase pattern (lock → remove → unlock → await →
/// lock → reinsert), so the session being resumed is absent from the map
/// while Lua runs. Feeds to other sessions are not blocked by it.
///
/// ## Contention
///
/// `list_snapshots()` (from `alc_status`) holds lock C while iterating
/// all sessions. During this time, `feed_response` for any session is
/// blocked. Given that snapshot iteration is O(n) with n = active
/// sessions (typically 1–3) and each snapshot takes microseconds, this
/// is acceptable. If session count grows significantly, consider
/// switching to a concurrent map or per-session locks.
///
/// ## Interaction with lock A
///
/// `Session::snapshot()` (called under lock C in `list_snapshots`)
/// acquires lock A via `ExecutionMetrics::snapshot()`. This is safe:
/// - Lock order: C → A (consistent with `feed_response`)
/// - Lock A hold time: microseconds (JSON field reads)
/// - Lock A is per-session (no cross-session contention)
pub struct SessionRegistry {
    /// Paused sessions awaiting responses, keyed by session ID.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
318
319impl Default for SessionRegistry {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
325impl SessionRegistry {
326    pub fn new() -> Self {
327        Self {
328            sessions: Arc::new(Mutex::new(HashMap::new())),
329        }
330    }
331
332    /// Start execution and wait for first event (pause or completion).
333    pub async fn start_execution(
334        &self,
335        mut session: Session,
336    ) -> Result<(String, FeedResult), SessionError> {
337        let session_id = gen_session_id();
338        let result = session.wait_event().await?;
339
340        if matches!(result, FeedResult::Paused { .. }) {
341            self.sessions
342                .lock()
343                .await
344                .insert(session_id.clone(), session);
345        }
346
347        Ok((session_id, result))
348    }
349
350    /// Feed one response to a paused session by query_id.
351    ///
352    /// If this completes all pending queries, the session resumes and
353    /// returns the next event (Paused or Finished).
354    /// If queries remain, returns Accepted { remaining }.
355    pub async fn feed_response(
356        &self,
357        session_id: &str,
358        query_id: &QueryId,
359        response: String,
360    ) -> Result<FeedResult, SessionError> {
361        // 1. Feed under lock
362        let complete = {
363            let mut map = self.sessions.lock().await;
364            let session = map
365                .get_mut(session_id)
366                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
367
368            let complete = session.feed_one(query_id, response)?;
369
370            if !complete {
371                return Ok(FeedResult::Accepted {
372                    remaining: session.state.remaining(),
373                });
374            }
375
376            complete
377        };
378
379        // 2. All complete → take session out for async resume
380        debug_assert!(complete);
381        let mut session = {
382            let mut map = self.sessions.lock().await;
383            map.remove(session_id)
384                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
385        };
386
387        let result = session.wait_event().await?;
388
389        if matches!(result, FeedResult::Paused { .. }) {
390            self.sessions
391                .lock()
392                .await
393                .insert(session_id.into(), session);
394        }
395
396        Ok(result)
397    }
398
399    /// Snapshot all active sessions for external observation (alc_status).
400    ///
401    /// Returns a map of session_id → snapshot JSON. Only includes sessions
402    /// currently held in the registry (i.e. paused, awaiting responses).
403    /// Sessions that have completed are already removed from the registry.
404    pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
405        let map = self.sessions.lock().await;
406        map.iter()
407            .map(|(id, session)| (id.clone(), session.snapshot()))
408            .collect()
409    }
410}
411
/// Generate a non-deterministic session ID of the form `s-<nanos hex>-<16 hex>`.
///
/// MCP asks for secure, non-deterministic session IDs to prevent session
/// hijacking. The first component is the wall-clock time in nanoseconds
/// since the Unix epoch (0 if the clock is before the epoch). The second is a
/// 64-bit value from hashing that timestamp with a fresh std `RandomState`,
/// whose keys std seeds randomly. Note that this is not a dedicated CSPRNG.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let entropy: u64 = hasher.finish();

    format!("s-{nanos:x}-{entropy:016x}")
}
434
// Unit tests for the pure parts of this module: the MCP JSON shapes produced
// by `FeedResult::to_json` and the format/uniqueness of `gen_session_id`.
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    // Build a batch query (id from `QueryId::batch(index)`, prompt
    // "prompt-{index}") with no system prompt and both flags off.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let result = FeedResult::Accepted { remaining: 3 };
        let json = result.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    // Single-query mode flattens the query into the top-level object.
    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-abc");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        // single query mode: no "queries" array
        assert!(json.get("queries").is_none());
        // grounded=false must be absent
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        // underspecified=false must be absent
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "verify this claim".into(),
            system: None,
            max_tokens: 200,
            grounded: true,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-grounded");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "what output format do you need?".into(),
            system: None,
            max_tokens: 200,
            grounded: false,
            underspecified: true,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-underspec");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    // Batch mode: flags are set per entry in the `queries` array.
    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let grounded_query = LlmQuery {
            id: QueryId::batch(0),
            prompt: "verify".into(),
            system: None,
            max_tokens: 100,
            grounded: true,
            underspecified: false,
        };
        let normal_query = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![grounded_query, normal_query],
        };
        let json = result.to_json("s-batch");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let underspec_query = LlmQuery {
            id: QueryId::batch(0),
            prompt: "clarify intent".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: true,
        };
        let normal_query = LlmQuery {
            id: QueryId::batch(1),
            prompt: "generate".into(),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![underspec_query, normal_query],
        };
        let json = result.to_json("s-batch-us");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    // A missing system prompt serializes as JSON null (key present, value null).
    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "hello".into(),
            system: None,
            max_tokens: 1024,
            grounded: false,
            underspecified: false,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-x");

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    // Batch ids are rendered as "q-{index}".
    #[test]
    fn to_json_paused_multiple_queries() {
        let queries = vec![make_query(0), make_query(1), make_query(2)];
        let result = FeedResult::Paused { queries };
        let json = result.to_json("s-multi");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Completed {
                result: json!({"answer": 42}),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-done");

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-err");

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Cancelled,
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-cancel");

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    // Back-to-back generation in one thread must not collide.
    #[test]
    fn session_id_uniqueness() {
        let ids: Vec<String> = (0..10).map(|_| gen_session_id()).collect();
        let set: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(set.len(), 10, "10 IDs should all be unique");
    }
}