Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::{AsyncIsleDriver, AsyncTask};
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
/// Errors surfaced by session and registry operations (Runtime layer).
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// No session with this ID is registered. Sessions are only stored while
    /// paused, so this also covers IDs whose execution already finished.
    #[error("session '{0}' not found")]
    NotFound(String),
    /// A response feed rejected by the domain state machine
    /// (`ExecutionState::feed`).
    #[error(transparent)]
    Feed(#[from] algocline_core::FeedError),
    /// A domain state transition failed, or the request does not fit the
    /// session's current state (e.g. zero or several pending queries when a
    /// sole pending query was expected).
    #[error("invalid transition: {0}")]
    InvalidTransition(String),
}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
/// Session completion data: terminal state + metrics.
///
/// Produced when a session finishes; in this file the metrics are moved out
/// of the session via `Session::take_metrics`.
#[derive(serde::Serialize)]
pub struct ExecutionResult {
    /// How the execution ended: completed, failed, or cancelled.
    pub state: TerminalState,
    /// Metrics collected over the session's lifetime.
    pub metrics: ExecutionMetrics,
}
39
/// Result of a session interaction (start or feed).
#[derive(serde::Serialize)]
pub enum FeedResult {
    /// Partial feed accepted, still waiting for more responses.
    Accepted {
        /// Number of queries in the current batch still awaiting a response.
        remaining: usize,
    },
    /// Lua is paused on LLM queries that need responses: either the first
    /// pause after start, or a re-pause after all previous queries were
    /// answered.
    Paused { queries: Vec<LlmQuery> },
    /// Execution completed (success, failure, or cancellation).
    Finished(ExecutionResult),
}
50
51impl FeedResult {
52    /// Convert to JSON for MCP tool response.
53    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
54        match self {
55            Self::Accepted { remaining } => json!({
56                "status": "accepted",
57                "remaining": remaining,
58            }),
59            Self::Paused { queries } => {
60                if queries.len() == 1 {
61                    let q = &queries[0];
62                    let mut obj = json!({
63                        "status": "needs_response",
64                        "session_id": session_id,
65                        "query_id": q.id.as_str(),
66                        "prompt": q.prompt,
67                        "system": q.system,
68                        "max_tokens": q.max_tokens,
69                    });
70                    if q.grounded {
71                        obj["grounded"] = json!(true);
72                    }
73                    if q.underspecified {
74                        obj["underspecified"] = json!(true);
75                    }
76                    obj
77                } else {
78                    let qs: Vec<_> = queries
79                        .iter()
80                        .map(|q| {
81                            let mut obj = json!({
82                                "id": q.id.as_str(),
83                                "prompt": q.prompt,
84                                "system": q.system,
85                                "max_tokens": q.max_tokens,
86                            });
87                            if q.grounded {
88                                obj["grounded"] = json!(true);
89                            }
90                            if q.underspecified {
91                                obj["underspecified"] = json!(true);
92                            }
93                            obj
94                        })
95                        .collect();
96                    json!({
97                        "status": "needs_response",
98                        "session_id": session_id,
99                        "queries": qs,
100                    })
101                }
102            }
103            Self::Finished(result) => match &result.state {
104                TerminalState::Completed { result: val } => json!({
105                    "status": "completed",
106                    "result": val,
107                    "stats": result.metrics.to_json(),
108                }),
109                TerminalState::Failed { error } => json!({
110                    "status": "error",
111                    "error": error,
112                }),
113                TerminalState::Cancelled => json!({
114                    "status": "cancelled",
115                    "stats": result.metrics.to_json(),
116                }),
117            },
118        }
119    }
120}
121
122// ─── Session ─────────────────────────────────────────────────
123
/// A Lua execution session with domain state tracking.
///
/// Each session owns a dedicated Lua VM via `_vm_driver`. The VM's OS thread
/// stays alive as long as the driver is held, and exits cleanly when the
/// session is dropped (channel closes → Lua thread drains and exits).
pub struct Session {
    /// Domain state machine: `Running`, `Paused`, or a terminal state.
    state: ExecutionState,
    /// Metrics for this run; moved out (via `std::mem::take`) when the
    /// session finishes.
    metrics: ExecutionMetrics,
    /// Observer created from `metrics` in `Session::new`; notified on pause,
    /// fed responses, resume, completion, and failure.
    observer: MetricsObserver,
    /// Receives LLM requests issued by the Lua side (`alc.llm()` per the
    /// module docs).
    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
    /// The Lua execution task; resolves with the result as a JSON string, or
    /// an error.
    exec_task: AsyncTask,
    /// QueryId → resp_tx. Populated when Lua pauses on a batch of queries;
    /// each entry is removed as its response is fed.
    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
    /// Per-session VM lifecycle driver. Keeps the Lua thread alive.
    /// Dropped when the session completes or is abandoned.
    _vm_driver: AsyncIsleDriver,
}
141
142impl Session {
143    pub fn new(
144        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
145        exec_task: AsyncTask,
146        metrics: ExecutionMetrics,
147        vm_driver: AsyncIsleDriver,
148    ) -> Self {
149        let observer = metrics.create_observer();
150        Self {
151            state: ExecutionState::Running,
152            metrics,
153            observer,
154            llm_rx,
155            exec_task,
156            resp_txs: HashMap::new(),
157            _vm_driver: vm_driver,
158        }
159    }
160
161    /// Wait for the next event from Lua execution.
162    ///
163    /// Called after initial start or after feeding all responses.
164    /// State must be Running when called.
165    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
166        tokio::select! {
167            result = &mut self.exec_task => {
168                match result {
169                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
170                        Ok(v) => {
171                            self.state.complete(v.clone()).map_err(|e| {
172                                SessionError::InvalidTransition(e.to_string())
173                            })?;
174                            self.observer.on_completed(&v);
175                            Ok(FeedResult::Finished(ExecutionResult {
176                                state: TerminalState::Completed { result: v },
177                                metrics: self.take_metrics(),
178                            }))
179                        }
180                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
181                    },
182                    Err(e) => self.fail_with(e.to_string()),
183                }
184            }
185            Some(req) = self.llm_rx.recv() => {
186                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
187                    id: qr.id.clone(),
188                    prompt: qr.prompt.clone(),
189                    system: qr.system.clone(),
190                    max_tokens: qr.max_tokens,
191                    grounded: qr.grounded,
192                    underspecified: qr.underspecified,
193                }).collect();
194
195                for qr in req.queries {
196                    self.resp_txs.insert(qr.id, qr.resp_tx);
197                }
198
199                self.state.pause(queries.clone()).map_err(|e| {
200                    SessionError::InvalidTransition(e.to_string())
201                })?;
202                self.observer.on_paused(&queries);
203                Ok(FeedResult::Paused { queries })
204            }
205        }
206    }
207
208    /// Feed one response by query_id.
209    ///
210    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
211    fn feed_one(
212        &mut self,
213        query_id: &QueryId,
214        response: String,
215        usage: Option<&algocline_core::TokenUsage>,
216    ) -> Result<bool, SessionError> {
217        // Track response before ownership transfer.
218        self.observer.on_response_fed(query_id, &response, usage);
219
220        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
221        if let Some(tx) = self.resp_txs.remove(query_id) {
222            let _ = tx.send(Ok(response.clone()));
223        }
224
225        // Domain: record in state machine
226        let complete = self
227            .state
228            .feed(query_id, response)
229            .map_err(SessionError::Feed)?;
230
231        if complete {
232            // Domain: transition Paused(complete) → Running
233            self.state
234                .take_responses()
235                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
236            self.observer.on_resumed();
237        } else {
238            self.observer
239                .on_partial_feed(query_id, self.state.remaining());
240        }
241
242        Ok(complete)
243    }
244
245    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
246        self.state
247            .fail(msg.clone())
248            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
249        self.observer.on_failed(&msg);
250        Ok(FeedResult::Finished(ExecutionResult {
251            state: TerminalState::Failed { error: msg },
252            metrics: self.take_metrics(),
253        }))
254    }
255
256    fn take_metrics(&mut self) -> ExecutionMetrics {
257        std::mem::take(&mut self.metrics)
258    }
259
260    /// Lightweight snapshot for external observation (alc_status).
261    ///
262    /// Returns session state label and running metrics without consuming
263    /// or modifying the session.
264    pub fn snapshot(&self) -> serde_json::Value {
265        let state_label = match &self.state {
266            ExecutionState::Running => "running",
267            ExecutionState::Paused(_) => "paused",
268            _ => "terminal",
269        };
270
271        let mut json = serde_json::json!({
272            "state": state_label,
273        });
274
275        let metrics = self.metrics.snapshot();
276        if !metrics.is_null() {
277            json["metrics"] = metrics;
278        }
279
280        // Include pending query count when paused
281        if let ExecutionState::Paused(_) = &self.state {
282            json["pending_queries"] = self.state.remaining().into();
283        }
284
285        json
286    }
287}
288
289// ─── Registry ────────────────────────────────────────────────
290
/// Manages active sessions.
///
/// Sessions are keyed by session ID. They are inserted only when
/// `wait_event` returns `Paused`; a finished session is never (re)inserted.
///
/// # Locking design (lock **C**)
///
/// Uses `tokio::sync::Mutex` because `feed_response` holds the lock
/// while calling `Session::feed_one()` (which itself acquires the
/// per-session `std::sync::Mutex<SessionStatus>`, lock **A**). The lock
/// ordering invariant is always **C → A** — no code path acquires A
/// then C, so deadlock is structurally impossible.
///
/// `feed_response` never holds the guard across the async `wait_event()`
/// call. It takes the session out of the map and releases the lock, awaits,
/// then re-locks to reinsert a still-paused session (lock → remove →
/// unlock → await → lock → reinsert).
///
/// NOTE(review): since the guard is never held across an `.await`, a
/// `std::sync::Mutex` would also be sound here. The async mutex mainly keeps
/// lock acquisition from blocking a runtime worker thread. Confirm the
/// intended rationale before relying on it.
///
/// ## Contention
///
/// `list_snapshots()` (from `alc_status`) holds lock C while iterating
/// all sessions. During this time, `feed_response` for any session is
/// blocked. Given that snapshot iteration is O(n) with n = active
/// sessions (typically 1–3) and each snapshot takes microseconds, this
/// is acceptable. If session count grows significantly, consider
/// switching to a concurrent map or per-session locks.
///
/// ## Interaction with lock A
///
/// `Session::snapshot()` (called under lock C in `list_snapshots`)
/// acquires lock A via `ExecutionMetrics::snapshot()`. This is safe:
/// - Lock order: C → A (consistent with `feed_response`)
/// - Lock A hold time: microseconds (JSON field reads)
/// - Lock A is per-session (no cross-session contention)
pub struct SessionRegistry {
    sessions: Arc<Mutex<HashMap<String, Session>>>,
}
326
327impl Default for SessionRegistry {
328    fn default() -> Self {
329        Self::new()
330    }
331}
332
333impl SessionRegistry {
334    pub fn new() -> Self {
335        Self {
336            sessions: Arc::new(Mutex::new(HashMap::new())),
337        }
338    }
339
340    /// Start execution and wait for first event (pause or completion).
341    pub async fn start_execution(
342        &self,
343        mut session: Session,
344    ) -> Result<(String, FeedResult), SessionError> {
345        let session_id = gen_session_id();
346        let result = session.wait_event().await?;
347
348        if matches!(result, FeedResult::Paused { .. }) {
349            self.sessions
350                .lock()
351                .await
352                .insert(session_id.clone(), session);
353        }
354
355        Ok((session_id, result))
356    }
357
358    /// Feed one response to a paused session by query_id.
359    ///
360    /// If this completes all pending queries, the session resumes and
361    /// returns the next event (Paused or Finished).
362    /// If queries remain, returns Accepted { remaining }.
363    pub async fn feed_response(
364        &self,
365        session_id: &str,
366        query_id: &QueryId,
367        response: String,
368        usage: Option<&algocline_core::TokenUsage>,
369    ) -> Result<FeedResult, SessionError> {
370        // 1. Feed under lock
371        let complete = {
372            let mut map = self.sessions.lock().await;
373            let session = map
374                .get_mut(session_id)
375                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
376
377            let complete = session.feed_one(query_id, response, usage)?;
378
379            if !complete {
380                return Ok(FeedResult::Accepted {
381                    remaining: session.state.remaining(),
382                });
383            }
384
385            complete
386        };
387
388        // 2. All complete → take session out for async resume
389        debug_assert!(complete);
390        let mut session = {
391            let mut map = self.sessions.lock().await;
392            map.remove(session_id)
393                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
394        };
395
396        let result = session.wait_event().await?;
397
398        if matches!(result, FeedResult::Paused { .. }) {
399            self.sessions
400                .lock()
401                .await
402                .insert(session_id.into(), session);
403        }
404
405        Ok(result)
406    }
407
408    /// Resolve the sole pending query ID for a session.
409    ///
410    /// When `alc_continue` is called without an explicit `query_id`, this
411    /// method checks if exactly one query is pending and returns its ID.
412    /// Returns an error if zero or multiple queries are pending.
413    pub async fn resolve_sole_pending_id(&self, session_id: &str) -> Result<QueryId, SessionError> {
414        let map = self.sessions.lock().await;
415        let session = map
416            .get(session_id)
417            .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
418        let keys: Vec<QueryId> = session.resp_txs.keys().cloned().collect();
419        match keys.len() {
420            0 => Err(SessionError::InvalidTransition("no pending queries".into())),
421            1 => keys
422                .into_iter()
423                .next()
424                .ok_or_else(|| SessionError::InvalidTransition("unexpected empty keys".into())),
425            n => Err(SessionError::InvalidTransition(format!(
426                "{n} queries pending; specify query_id explicitly"
427            ))),
428        }
429    }
430
431    /// Snapshot all active sessions for external observation (alc_status).
432    ///
433    /// Returns a map of session_id → snapshot JSON. Only includes sessions
434    /// currently held in the registry (i.e. paused, awaiting responses).
435    /// Sessions that have completed are already removed from the registry.
436    pub async fn list_snapshots(&self) -> HashMap<String, serde_json::Value> {
437        let map = self.sessions.lock().await;
438        map.iter()
439            .map(|(id, session)| (id.clone(), session.snapshot()))
440            .collect()
441    }
442}
443
/// Generate a non-deterministic session ID of the form
/// `s-{nanos:x}-{suffix:016x}`.
///
/// MCP spec requires "secure, non-deterministic session IDs" to prevent
/// session hijacking. The ID combines:
/// - `nanos`: nanoseconds since the Unix epoch, in lowercase hex;
/// - `suffix`: a 64-bit value from hashing `nanos` with a hasher built from
///   a fresh `RandomState`, which std seeds randomly.
///
/// NOTE(review): `RandomState` is not documented as a CSPRNG, and a 64-bit
/// value can in principle collide. The suffix makes IDs hard to guess and
/// very unlikely to repeat, but that is not a guarantee. Consider a CSPRNG if
/// the threat model needs cryptographic unpredictability.
///
/// # Clock before the Unix epoch
///
/// `duration_since(UNIX_EPOCH)` fails if the system clock is set before
/// 1970-01-01 (e.g. NTP drift, misconfigured VM). Panicking is prohibited in
/// library code by project policy, so the timestamp falls back to `0`. The
/// timestamp is only a convenience prefix; the random suffix is the part that
/// distinguishes IDs.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    let nanos: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let suffix: u64 = hasher.finish();

    format!("s-{nanos:x}-{suffix:016x}")
}
479
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Base fixture with id `QueryId::batch(index)`, prompt `prompt-{index}`,
    /// no system prompt, 100 max tokens, and both flags cleared. Other tests
    /// derive their queries from it with struct-update syntax.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Render a `Paused` result for `queries` under `session_id`.
    fn paused_json(queries: Vec<LlmQuery>, session_id: &str) -> serde_json::Value {
        FeedResult::Paused { queries }.to_json(session_id)
    }

    /// Render a `Finished` result (with fresh metrics) under `session_id`.
    fn finished_json(state: TerminalState, session_id: &str) -> serde_json::Value {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
        .to_json(session_id)
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let out = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(out["status"], "accepted");
        assert_eq!(out["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let calc = LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
            ..make_query(0)
        };
        let out = paused_json(vec![calc], "s-abc");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-abc");
        assert_eq!(out["prompt"], "What is 2+2?");
        assert_eq!(out["system"], "You are a calculator.");
        assert_eq!(out["max_tokens"], 50);
        // The single-query shape is flat: there is no "queries" array.
        assert!(out.get("queries").is_none());
        // Flags that are false are omitted entirely.
        assert!(
            out.get("grounded").is_none(),
            "grounded key must be absent when false"
        );
        assert!(
            out.get("underspecified").is_none(),
            "underspecified key must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_single_query_grounded() {
        let claim = LlmQuery {
            id: QueryId::single(),
            prompt: "verify this claim".into(),
            max_tokens: 200,
            grounded: true,
            ..make_query(0)
        };
        let out = paused_json(vec![claim], "s-grounded");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(
            out["grounded"], true,
            "grounded must appear in single-query MCP JSON"
        );
    }

    #[test]
    fn to_json_paused_single_query_underspecified() {
        let clarify = LlmQuery {
            id: QueryId::single(),
            prompt: "what output format do you need?".into(),
            max_tokens: 200,
            underspecified: true,
            ..make_query(0)
        };
        let out = paused_json(vec![clarify], "s-underspec");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(
            out["underspecified"], true,
            "underspecified must appear in single-query MCP JSON"
        );
        assert!(
            out.get("grounded").is_none(),
            "grounded must be absent when false"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_grounded() {
        let verify = LlmQuery {
            prompt: "verify".into(),
            grounded: true,
            ..make_query(0)
        };
        let generate = LlmQuery {
            prompt: "generate".into(),
            ..make_query(1)
        };
        let out = paused_json(vec![verify, generate], "s-batch");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(
            entries[0]["grounded"], true,
            "grounded query must have grounded=true"
        );
        assert!(
            entries[1].get("grounded").is_none(),
            "non-grounded query must omit grounded key"
        );
    }

    #[test]
    fn to_json_paused_multiple_queries_mixed_underspecified() {
        let vague = LlmQuery {
            prompt: "clarify intent".into(),
            underspecified: true,
            ..make_query(0)
        };
        let generate = LlmQuery {
            prompt: "generate".into(),
            ..make_query(1)
        };
        let out = paused_json(vec![vague, generate], "s-batch-us");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(
            entries[0]["underspecified"], true,
            "underspecified query must have underspecified=true"
        );
        assert!(
            entries[1].get("underspecified").is_none(),
            "non-underspecified query must omit underspecified key"
        );
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let hello = LlmQuery {
            id: QueryId::single(),
            prompt: "hello".into(),
            max_tokens: 1024,
            ..make_query(0)
        };
        let out = paused_json(vec![hello], "s-x");

        assert_eq!(out["status"], "needs_response");
        assert!(out["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let out = paused_json((0..3).map(make_query).collect(), "s-multi");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-multi");

        let entries = out["queries"].as_array().expect("queries should be array");
        assert_eq!(entries.len(), 3);
        assert_eq!(entries[0]["id"], "q-0");
        assert_eq!(entries[0]["prompt"], "prompt-0");
        assert_eq!(entries[1]["id"], "q-1");
        assert_eq!(entries[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let out = finished_json(
            TerminalState::Completed {
                result: json!({"answer": 42}),
            },
            "s-done",
        );

        assert_eq!(out["status"], "completed");
        assert_eq!(out["result"]["answer"], 42);
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let out = finished_json(
            TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
            "s-err",
        );

        assert_eq!(out["status"], "error");
        assert_eq!(out["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let out = finished_json(TerminalState::Cancelled, "s-cancel");

        assert_eq!(out["status"], "cancelled");
        assert!(out.get("stats").is_some());
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let unique: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(unique.len(), 10, "10 IDs should all be unique");
    }
}