Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
/// Errors returned by session operations (start, feed, pending-query lookup).
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session is registered under the given ID. Besides unknown IDs, this
    /// covers sessions that already finished (they are never reinserted into
    /// the registry) and sessions currently taken out of the registry while
    /// they resume.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// The domain state machine rejected a fed response (`ExecutionState::feed`).
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// A state-machine transition failed (complete / fail / pause / resume),
    /// or the request cannot be served in the current state (e.g. zero or
    /// several pending queries when a sole one was expected).
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
/// Session completion data: terminal state + metrics.
///
/// Built when Lua execution ends. The metrics are moved out of the session
/// at that point via `std::mem::take`, which leaves the session holding a
/// default value.
#[derive(serde::Serialize)]
pub struct ExecutionResult {
    /// How execution ended: completed with a value, failed with a message,
    /// or cancelled.
    pub state: TerminalState,
    /// Metrics accumulated for this execution.
    pub metrics: ExecutionMetrics,
}
39
/// Result of a session interaction (start or feed).
///
/// The derived `Serialize` uses serde's default (externally tagged) enum
/// representation. The MCP-facing JSON shape is produced by `to_json`
/// instead.
#[derive(serde::Serialize)]
pub enum FeedResult {
    /// Partial feed accepted, still waiting for more responses.
    Accepted { remaining: usize },
    /// All queries answered, Lua re-paused with new queries.
    /// Also returned by the initial start when Lua pauses before finishing.
    Paused { queries: Vec<LlmQuery> },
    /// Execution completed (success, failure, or cancellation).
    Finished(ExecutionResult),
}
50
51impl FeedResult {
52    /// Convert to JSON for MCP tool response.
53    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
54        match self {
55            Self::Accepted { remaining } => json!({
56                "status": "accepted",
57                "remaining": remaining,
58            }),
59            Self::Paused { queries } => {
60                if queries.len() == 1 {
61                    let q = &queries[0];
62                    let mut obj = json!({
63                        "status": "needs_response",
64                        "session_id": session_id,
65                        "query_id": q.id.as_str(),
66                        "prompt": q.prompt,
67                        "system": q.system,
68                        "max_tokens": q.max_tokens,
69                    });
70                    if q.grounded {
71                        obj["grounded"] = json!(true);
72                    }
73                    if q.underspecified {
74                        obj["underspecified"] = json!(true);
75                    }
76                    obj
77                } else {
78                    let qs: Vec<_> = queries
79                        .iter()
80                        .map(|q| {
81                            let mut obj = json!({
82                                "id": q.id.as_str(),
83                                "prompt": q.prompt,
84                                "system": q.system,
85                                "max_tokens": q.max_tokens,
86                            });
87                            if q.grounded {
88                                obj["grounded"] = json!(true);
89                            }
90                            if q.underspecified {
91                                obj["underspecified"] = json!(true);
92                            }
93                            obj
94                        })
95                        .collect();
96                    json!({
97                        "status": "needs_response",
98                        "session_id": session_id,
99                        "queries": qs,
100                    })
101                }
102            }
103            Self::Finished(result) => match &result.state {
104                TerminalState::Completed { result: val } => json!({
105                    "status": "completed",
106                    "result": val,
107                    "stats": result.metrics.to_json(),
108                }),
109                TerminalState::Failed { error } => json!({
110                    "status": "error",
111                    "error": error,
112                }),
113                TerminalState::Cancelled => json!({
114                    "status": "cancelled",
115                    "stats": result.metrics.to_json(),
116                }),
117            },
118        }
119    }
120}
121
122// ─── Session ─────────────────────────────────────────────────
123
/// A Lua execution session with domain state tracking.
///
/// Each session owns a dedicated Lua VM via `_vm_driver`. The VM's OS thread
/// stays alive as long as the driver is held, and exits cleanly when the
/// session is dropped (channel closes → Lua thread drains and exits).
///
/// Rust drops struct fields in declaration order, so `llm_rx`, `exec_task`
/// and any pending `resp_txs` senders are dropped before `_vm_driver`.
/// NOTE(review): keep `_vm_driver` as the last field unless it is confirmed
/// that the VM shutdown sequence does not depend on this order.
pub struct Session {
    /// Domain state machine: Running, Paused, or a terminal state.
    state: ExecutionState,
    /// Metrics for this execution. Moved out when the session finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics` in `new`. Receives lifecycle events
    /// (paused, response fed, resumed, completed, failed).
    observer: MetricsObserver,
    /// LLM requests sent by the Lua side. Each received request pauses the
    /// session.
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// Lua execution task. Resolves with the script's result as a JSON string,
    /// or with an error.
    exec_task: AsyncTask,
    /// QueryId → resp_tx. Populated on Paused, cleared on resume.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Per-session VM lifecycle driver. Keeps the Lua thread alive.
    /// Dropped when the session completes or is abandoned.
    _vm_driver: AsyncIsleDriver,
}
141
142impl Session {
143    pub fn new(
144        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
145        exec_task: AsyncTask,
146        metrics: ExecutionMetrics,
147        vm_driver: AsyncIsleDriver,
148    ) -> Self {
149        let observer = metrics.create_observer();
150        Self {
151            state: ExecutionState::Running,
152            metrics,
153            observer,
154            llm_rx,
155            exec_task,
156            resp_txs: HashMap::new(),
157            _vm_driver: vm_driver,
158        }
159    }
160
161    /// Wait for the next event from Lua execution.
162    ///
163    /// Called after initial start or after feeding all responses.
164    /// State must be Running when called.
165    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
166        tokio::select! {
167            result = &mut self.exec_task => {
168                match result {
169                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
170                        Ok(v) => {
171                            self.state.complete(v.clone()).map_err(|e| {
172                                SessionError::InvalidTransition(e.to_string())
173                            })?;
174                            self.observer.on_completed(&v);
175                            Ok(FeedResult::Finished(ExecutionResult {
176                                state: TerminalState::Completed { result: v },
177                                metrics: self.take_metrics(),
178                            }))
179                        }
180                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
181                    },
182                    Err(e) => self.fail_with(e.to_string()),
183                }
184            }
185            Some(req) = self.llm_rx.recv() => {
186                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
187                    id: qr.id.clone(),
188                    prompt: qr.prompt.clone(),
189                    system: qr.system.clone(),
190                    max_tokens: qr.max_tokens,
191                    grounded: qr.grounded,
192                    underspecified: qr.underspecified,
193                }).collect();
194
195                for qr in req.queries {
196                    self.resp_txs.insert(qr.id, qr.resp_tx);
197                }
198
199                self.state.pause(queries.clone()).map_err(|e| {
200                    SessionError::InvalidTransition(e.to_string())
201                })?;
202                self.observer.on_paused(&queries);
203                Ok(FeedResult::Paused { queries })
204            }
205        }
206    }
207
208    /// Feed one response by query_id.
209    ///
210    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
211    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
212        // Track response before ownership transfer.
213        self.observer.on_response_fed(query_id, &response);
214
215        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
216        if let Some(tx) = self.resp_txs.remove(query_id) {
217            let _ = tx.send(Ok(response.clone()));
218        }
219
220        // Domain: record in state machine
221        let complete = self
222            .state
223            .feed(query_id, response)
224            .map_err(SessionError::Feed)?;
225
226        if complete {
227            // Domain: transition Paused(complete) → Running
228            self.state
229                .take_responses()
230                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
231            self.observer.on_resumed();
232        } else {
233            self.observer
234                .on_partial_feed(query_id, self.state.remaining());
235        }
236
237        Ok(complete)
238    }
239
240    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
241        self.state
242            .fail(msg.clone())
243            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
244        self.observer.on_failed(&msg);
245        Ok(FeedResult::Finished(ExecutionResult {
246            state: TerminalState::Failed { error: msg },
247            metrics: self.take_metrics(),
248        }))
249    }
250
251    fn take_metrics(&mut self) -> ExecutionMetrics {
252        std::mem::take(&mut self.metrics)
253    }
254
255    /// Lightweight snapshot for external observation (alc_status).
256    ///
257    /// Returns session state label and running metrics without consuming
258    /// or modifying the session.
259    pub fn snapshot(&self) -> serde_json::Value {
260        let state_label = match &self.state {
261            ExecutionState::Running => "running",
262            ExecutionState::Paused(_) => "paused",
263            _ => "terminal",
264        };
265
266        let mut json = serde_json::json!({
267            "state": state_label,
268        });
269
270        let metrics = self.metrics.snapshot();
271        if !metrics.is_null() {
272            json["metrics"] = metrics;
273        }
274
275        // Include pending query count when paused
276        if let ExecutionState::Paused(_) = &self.state {
277            json["pending_queries"] = self.state.remaining().into();
278        }
279
280        json
281    }
282}
283
284// ─── Registry ────────────────────────────────────────────────
285
/// Manages active sessions.
///
/// # Locking design (lock **C**)
///
/// Uses `tokio::sync::Mutex` because `feed_response` holds the lock
/// while calling `Session::feed_one()` (which itself acquires the
/// per-session `std::sync::Mutex<SessionStatus>`, lock **A**). The lock
/// ordering invariant is always **C → A** — no code path acquires A
/// then C, so deadlock is structurally impossible.
///
/// `tokio::sync::Mutex` is chosen here (rather than `std::sync::Mutex`)
/// because `feed_response` must take the session out of the map for
/// the async `wait_event()` call. The two-phase pattern (lock → remove
/// → unlock → await → lock → reinsert) requires an async-aware mutex
/// to avoid holding the lock across the `wait_event().await`.
///
/// NOTE(review): in the registry methods, every guard is dropped before
/// `wait_event().await`, so the lock is never held across an `.await`.
/// That pattern would also work with `std::sync::Mutex`. What the async
/// mutex does provide is that contending callers `.await` the lock instead
/// of blocking a runtime worker thread. Confirm which rationale is intended.
///
/// ## Contention
///
/// `list_snapshots()` (from `alc_status`) holds lock C while iterating
/// all sessions. During this time, `feed_response` for any session is
/// blocked. Given that snapshot iteration is O(n) with n = active
/// sessions (typically 1–3) and each snapshot takes microseconds, this
/// is acceptable. If session count grows significantly, consider
/// switching to a concurrent map or per-session locks.
///
/// ## Interaction with lock A
///
/// `Session::snapshot()` (called under lock C in `list_snapshots`)
/// acquires lock A via `ExecutionMetrics::snapshot()`. This is safe:
/// - Lock order: C → A (consistent with `feed_response`)
/// - Lock A hold time: microseconds (JSON field reads)
/// - Lock A is per-session (no cross-session contention)
pub struct SessionRegistry {
    /// session_id → session awaiting responses. A session is absent while it
    /// is taken out to resume (`wait_event()`), and it is not reinserted once
    /// it finishes.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
321
322impl Default for SessionRegistry {
323    fn default() -> Self {
324        Self::new()
325    }
326}
327
328impl SessionRegistry {
329    pub fn new() -> Self {
330        Self {
331            sessions: Arc::new(Mutex::new(HashMap::new())),
332        }
333    }
334
335    /// Start execution and wait for first event (pause or completion).
336    pub async fn start_execution(
337        &self,
338        mut session: Session,
339    ) -> Result<(String, FeedResult), SessionError> {
340        let session_id = gen_session_id();
341        let result = session.wait_event().await?;
342
343        if matches!(result, FeedResult::Paused { .. }) {
344            self.sessions
345                .lock()
346                .await
347                .insert(session_id.clone(), session);
348        }
349
350        Ok((session_id, result))
351    }
352
353    /// Feed one response to a paused session by query_id.
354    ///
355    /// If this completes all pending queries, the session resumes and
356    /// returns the next event (Paused or Finished).
357    /// If queries remain, returns Accepted { remaining }.
358    pub async fn feed_response(
359        &self,
360        session_id: &str,
361        query_id: &QueryId,
362        response: String,
363    ) -> Result<FeedResult, SessionError> {
364        // 1. Feed under lock
365        let complete = {
366            let mut map = self.sessions.lock().await;
367            let session = map
368                .get_mut(session_id)
369                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
370
371            let complete = session.feed_one(query_id, response)?;
372
373            if !complete {
374                return Ok(FeedResult::Accepted {
375                    remaining: session.state.remaining(),
376                });
377            }
378
379            complete
380        };
381
382        // 2. All complete → take session out for async resume
383        debug_assert!(complete);
384        let mut session = {
385            let mut map = self.sessions.lock().await;
386            map.remove(session_id)
387                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
388        };
389
390        let result = session.wait_event().await?;
391
392        if matches!(result, FeedResult::Paused { .. }) {
393            self.sessions
394                .lock()
395                .await
396                .insert(session_id.into(), session);
397        }
398
399        Ok(result)
400    }
401
402    /// Resolve the sole pending query ID for a session.
403    ///
404    /// When `alc_continue` is called without an explicit `query_id`, this
405    /// method checks if exactly one query is pending and returns its ID.
406    /// Returns an error if zero or multiple queries are pending.
407    pub async fn resolve_sole_pending_id(&self, session_id: &str) -> Result<QueryId, SessionError> {
408        let map = self.sessions.lock().await;
409        let session = map
410            .get(session_id)
411            .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
412        let keys: Vec<QueryId> = session.resp_txs.keys().cloned().collect();
413        match keys.len() {
414            0 => Err(SessionError::InvalidTransition("no pending queries".into())),
415            1 => keys
416                .into_iter()
417                .next()
418                .ok_or_else(|| SessionError::InvalidTransition("unexpected empty keys".into())),
419            n => Err(SessionError::InvalidTransition(format!(
420                "{n} queries pending; specify query_id explicitly"
421            ))),
422        }
423    }
424
425    /// Snapshot all active sessions for external observation (alc_status).
426    ///
427    /// Returns a map of session_id → snapshot JSON. Only includes sessions
428    /// currently held in the registry (i.e. paused, awaiting responses).
429    /// Sessions that have completed are already removed from the registry.
430    pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
431        let map = self.sessions.lock().await;
432        map.iter()
433            .map(|(id, session)| (id.clone(), session.snapshot()))
434            .collect()
435    }
436}
437
/// Generate a hard-to-guess, practically unique session ID.
///
/// Format: `s-{ts:x}-{suffix:016x}`. `ts` is the wall-clock time in
/// nanoseconds since the Unix epoch, in hex. `suffix` is a 64-bit value
/// rendered as exactly 16 hex chars.
///
/// The MCP spec asks for "secure, non-deterministic session IDs" to prevent
/// session hijacking. Unpredictability comes from `suffix`: the timestamp is
/// hashed by a freshly constructed `RandomState`, which std documents as
/// "initialized with random keys". The limits of this source:
/// - std makes no cryptographic-strength promise for `RandomState` /
///   `DefaultHasher`, and the hashing algorithm is unspecified.
/// - `suffix` is a hash output of at most 64 bits, not raw random bytes.
///   Uniqueness is probabilistic (a timestamp + suffix collision is very
///   unlikely), not guaranteed.
///
/// NOTE(review): if IDs must withstand a determined attacker, derive the
/// suffix from a CSPRNG (e.g. the `getrandom` crate) instead.
///
/// # `unwrap_or_default` on `duration_since(UNIX_EPOCH)`
///
/// `SystemTime::now().duration_since(UNIX_EPOCH)` can fail if the system
/// clock is set before 1970-01-01 (e.g. NTP drift, misconfigured VM).
/// The Rust std docs recommend `expect()` or `match` for explicit handling,
/// but `expect` would panic in library code (prohibited by project policy).
///
/// `unwrap_or_default` returns `Duration::ZERO` on failure, yielding
/// timestamp `0`. This is acceptable because the timestamp is a convenience
/// prefix, not the security-relevant component. The random-keyed suffix
/// still differs between calls even when the timestamp is constant.
fn gen_session_id() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let ts = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();
    // 64-bit SipHash-style output under randomly keyed `RandomState`.
    let random: u64 = {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        let s = RandomState::new();
        let mut h = s.build_hasher();
        h.write_u128(ts);
        h.finish()
    };
    format!("s-{ts:x}-{random:016x}")
}
473
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Batch query `QueryId::batch(index)` with prompt `prompt-{index}`,
    /// no system prompt, 100 max tokens, and both hint flags off.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Same defaults as `make_query`, but with `QueryId::single()` and the
    /// given prompt.
    fn single_query(prompt: &str) -> LlmQuery {
        let mut q = make_query(0);
        q.id = QueryId::single();
        q.prompt = prompt.to_string();
        q
    }

    /// Render `FeedResult::Paused` for the given queries.
    fn paused_json(queries: Vec<LlmQuery>, session_id: &str) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// Render `FeedResult::Finished` with fresh metrics.
    fn finished_json(state: TerminalState, session_id: &str) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let json = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let mut query = single_query("What is 2+2?");
        query.system = Some("You are a calculator.".into());
        query.max_tokens = 50;
        let json = paused_json(vec![query], "s-abc");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        // Single-query mode is flat: there is no "queries" array.
        assert!(json.get("queries").is_none());
        // Flags that are false are omitted entirely.
        assert!(
            json.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            json.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let mut query = single_query("verify this claim");
        query.max_tokens = 200;
        query.grounded = true;
        let json = paused_json(vec![query], "s-grounded");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let mut query = single_query("what output format do you need?");
        query.max_tokens = 200;
        query.underspecified = true;
        let json = paused_json(vec![query], "s-underspec");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(
            json["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            json.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let mut grounded_query = make_query(0);
        grounded_query.prompt = "verify".into();
        grounded_query.grounded = true;
        let mut normal_query = make_query(1);
        normal_query.prompt = "generate".into();

        let json = paused_json(vec![grounded_query, normal_query], "s-batch");
        let qs = json["queries"].as_array().expect("queries should be array");

        assert_eq!(
            qs[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            qs[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let mut underspec_query = make_query(0);
        underspec_query.prompt = "clarify intent".into();
        underspec_query.underspecified = true;
        let mut normal_query = make_query(1);
        normal_query.prompt = "generate".into();

        let json = paused_json(vec![underspec_query, normal_query], "s-batch-us");
        let qs = json["queries"].as_array().expect("queries should be array");

        assert_eq!(
            qs[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            qs[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let mut query = single_query("hello");
        query.max_tokens = 1024;
        let json = paused_json(vec![query], "s-x");

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let json = paused_json((0..3).map(make_query).collect(), "s-multi");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let state = TerminalState::Completed {
            result: json!({"answer": 42}),
        };
        let json = finished_json(state, "s-done");

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let state = TerminalState::Failed {
            error: "lua error: bad argument".into(),
        };
        let json = finished_json(state, "s-err");

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let json = finished_json(TerminalState::Cancelled, "s-cancel");

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let unique: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(unique.len(), 10, "10 IDs should all be unique");
    }
}