Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::AsyncTask;
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
21#[derive(Debug, thiserror::Error)]
22pub enum SessionError {
23    #[error("session '{0}' not found")]
24    NotFound(String),
25    #[error(transparent)]
26    Feed(#[from] algocline_core::FeedError),
27    #[error("invalid transition: {0}")]
28    InvalidTransition(String),
29}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
33/// Session completion data: terminal state + metrics.
34pub struct ExecutionResult {
35    pub state: TerminalState,
36    pub metrics: ExecutionMetrics,
37}
38
39/// Result of a session interaction (start or feed).
40pub enum FeedResult {
41    /// Partial feed accepted, still waiting for more responses.
42    Accepted { remaining: usize },
43    /// All queries answered, Lua re-paused with new queries.
44    Paused { queries: Vec<LlmQuery> },
45    /// Execution completed (success, failure, or cancellation).
46    Finished(ExecutionResult),
47}
48
49impl FeedResult {
50    /// Convert to JSON for MCP tool response.
51    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52        match self {
53            Self::Accepted { remaining } => json!({
54                "status": "accepted",
55                "remaining": remaining,
56            }),
57            Self::Paused { queries } => {
58                if queries.len() == 1 {
59                    let q = &queries[0];
60                    json!({
61                        "status": "needs_response",
62                        "session_id": session_id,
63                        "prompt": q.prompt,
64                        "system": q.system,
65                        "max_tokens": q.max_tokens,
66                    })
67                } else {
68                    let qs: Vec<_> = queries
69                        .iter()
70                        .map(|q| {
71                            json!({
72                                "id": q.id.as_str(),
73                                "prompt": q.prompt,
74                                "system": q.system,
75                                "max_tokens": q.max_tokens,
76                            })
77                        })
78                        .collect();
79                    json!({
80                        "status": "needs_response",
81                        "session_id": session_id,
82                        "queries": qs,
83                    })
84                }
85            }
86            Self::Finished(result) => match &result.state {
87                TerminalState::Completed { result: val } => json!({
88                    "status": "completed",
89                    "result": val,
90                    "stats": result.metrics.to_json(),
91                }),
92                TerminalState::Failed { error } => json!({
93                    "status": "error",
94                    "error": error,
95                }),
96                TerminalState::Cancelled => json!({
97                    "status": "cancelled",
98                    "stats": result.metrics.to_json(),
99                }),
100            },
101        }
102    }
103}
104
105// ─── Session ─────────────────────────────────────────────────
106
107/// A Lua execution session with domain state tracking.
108pub struct Session {
109    state: ExecutionState,
110    metrics: ExecutionMetrics,
111    observer: MetricsObserver,
112    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
113    exec_task: AsyncTask,
114    /// QueryId → resp_tx. Populated on Paused, cleared on resume.
115    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
116}
117
118impl Session {
119    pub fn new(
120        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
121        exec_task: AsyncTask,
122        metrics: ExecutionMetrics,
123    ) -> Self {
124        let observer = metrics.create_observer();
125        Self {
126            state: ExecutionState::Running,
127            metrics,
128            observer,
129            llm_rx,
130            exec_task,
131            resp_txs: HashMap::new(),
132        }
133    }
134
135    /// Wait for the next event from Lua execution.
136    ///
137    /// Called after initial start or after feeding all responses.
138    /// State must be Running when called.
139    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
140        tokio::select! {
141            result = &mut self.exec_task => {
142                match result {
143                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
144                        Ok(v) => {
145                            self.state.complete(v.clone()).map_err(|e| {
146                                SessionError::InvalidTransition(e.to_string())
147                            })?;
148                            self.observer.on_completed(&v);
149                            Ok(FeedResult::Finished(ExecutionResult {
150                                state: TerminalState::Completed { result: v },
151                                metrics: self.take_metrics(),
152                            }))
153                        }
154                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
155                    },
156                    Err(e) => self.fail_with(e.to_string()),
157                }
158            }
159            Some(req) = self.llm_rx.recv() => {
160                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
161                    id: qr.id.clone(),
162                    prompt: qr.prompt.clone(),
163                    system: qr.system.clone(),
164                    max_tokens: qr.max_tokens,
165                }).collect();
166
167                for qr in req.queries {
168                    self.resp_txs.insert(qr.id, qr.resp_tx);
169                }
170
171                self.state.pause(queries.clone()).map_err(|e| {
172                    SessionError::InvalidTransition(e.to_string())
173                })?;
174                self.observer.on_paused(&queries);
175                Ok(FeedResult::Paused { queries })
176            }
177        }
178    }
179
180    /// Feed one response by query_id.
181    ///
182    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
183    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
184        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
185        if let Some(tx) = self.resp_txs.remove(query_id) {
186            let _ = tx.send(Ok(response.clone()));
187        }
188
189        // Domain: record in state machine
190        let complete = self
191            .state
192            .feed(query_id, response)
193            .map_err(SessionError::Feed)?;
194
195        if complete {
196            // Domain: transition Paused(complete) → Running
197            self.state
198                .take_responses()
199                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
200            self.observer.on_resumed();
201        } else {
202            self.observer
203                .on_partial_feed(query_id, self.state.remaining());
204        }
205
206        Ok(complete)
207    }
208
209    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
210        self.state
211            .fail(msg.clone())
212            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
213        self.observer.on_failed(&msg);
214        Ok(FeedResult::Finished(ExecutionResult {
215            state: TerminalState::Failed { error: msg },
216            metrics: self.take_metrics(),
217        }))
218    }
219
220    fn take_metrics(&mut self) -> ExecutionMetrics {
221        std::mem::take(&mut self.metrics)
222    }
223}
224
225// ─── Registry ────────────────────────────────────────────────
226
227/// Manages active sessions. Replaces the raw SessionMap.
228pub struct SessionRegistry {
229    sessions: Arc<Mutex<HashMap<String, Session>>>,
230}
231
232impl Default for SessionRegistry {
233    fn default() -> Self {
234        Self::new()
235    }
236}
237
238impl SessionRegistry {
239    pub fn new() -> Self {
240        Self {
241            sessions: Arc::new(Mutex::new(HashMap::new())),
242        }
243    }
244
245    /// Start execution and wait for first event (pause or completion).
246    pub async fn start_execution(
247        &self,
248        mut session: Session,
249    ) -> Result<(String, FeedResult), SessionError> {
250        let session_id = gen_session_id();
251        let result = session.wait_event().await?;
252
253        if matches!(result, FeedResult::Paused { .. }) {
254            self.sessions
255                .lock()
256                .await
257                .insert(session_id.clone(), session);
258        }
259
260        Ok((session_id, result))
261    }
262
263    /// Feed one response to a paused session by query_id.
264    ///
265    /// If this completes all pending queries, the session resumes and
266    /// returns the next event (Paused or Finished).
267    /// If queries remain, returns Accepted { remaining }.
268    pub async fn feed_response(
269        &self,
270        session_id: &str,
271        query_id: &QueryId,
272        response: String,
273    ) -> Result<FeedResult, SessionError> {
274        // 1. Feed under lock
275        let complete = {
276            let mut map = self.sessions.lock().await;
277            let session = map
278                .get_mut(session_id)
279                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
280
281            let complete = session.feed_one(query_id, response)?;
282
283            if !complete {
284                return Ok(FeedResult::Accepted {
285                    remaining: session.state.remaining(),
286                });
287            }
288
289            complete
290        };
291
292        // 2. All complete → take session out for async resume
293        debug_assert!(complete);
294        let mut session = {
295            let mut map = self.sessions.lock().await;
296            map.remove(session_id)
297                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
298        };
299
300        let result = session.wait_event().await?;
301
302        if matches!(result, FeedResult::Paused { .. }) {
303            self.sessions
304                .lock()
305                .await
306                .insert(session_id.into(), session);
307        }
308
309        Ok(result)
310    }
311}
312
/// Produce a fresh, hard-to-guess session ID of the form
/// `s-<nanos hex>-<16 hex digits>`.
///
/// The MCP spec asks for "secure, non-deterministic session IDs" to resist
/// session hijacking. The first part is the current Unix time in
/// nanoseconds; the second is that timestamp run through a hasher from a
/// newly constructed `RandomState` (randomly keyed), giving 64 bits that are
/// not predictable from the timestamp alone.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    // A clock set before the epoch degrades to 0 instead of panicking.
    let nanos: u128 = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_nanos())
        .unwrap_or(0);

    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    let entropy: u64 = hasher.finish();

    format!("s-{:x}-{:016x}", nanos, entropy)
}
335
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Batch query `q-{index}` with prompt `prompt-{index}` and no system prompt.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    /// Wrap a terminal state into a `Finished` result with fresh metrics.
    fn finished(state: TerminalState) -> FeedResult {
        FeedResult::Finished(ExecutionResult {
            state,
            metrics: ExecutionMetrics::new(),
        })
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let out = FeedResult::Accepted { remaining: 3 }.to_json("s-123");
        assert_eq!(out["status"], "accepted");
        assert_eq!(out["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let paused = FeedResult::Paused {
            queries: vec![LlmQuery {
                id: QueryId::single(),
                prompt: "What is 2+2?".into(),
                system: Some("You are a calculator.".into()),
                max_tokens: 50,
            }],
        };
        let out = paused.to_json("s-abc");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-abc");
        assert_eq!(out["prompt"], "What is 2+2?");
        assert_eq!(out["system"], "You are a calculator.");
        assert_eq!(out["max_tokens"], 50);
        // A single query is flattened: there is no "queries" array.
        assert!(out.get("queries").is_none());
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let paused = FeedResult::Paused {
            queries: vec![LlmQuery {
                id: QueryId::single(),
                prompt: "hello".into(),
                system: None,
                max_tokens: 1024,
            }],
        };
        let out = paused.to_json("s-x");

        assert_eq!(out["status"], "needs_response");
        assert!(out["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let paused = FeedResult::Paused {
            queries: (0..3).map(make_query).collect(),
        };
        let out = paused.to_json("s-multi");

        assert_eq!(out["status"], "needs_response");
        assert_eq!(out["session_id"], "s-multi");

        let listed = out["queries"].as_array().expect("queries should be array");
        assert_eq!(listed.len(), 3);
        assert_eq!(listed[0]["prompt"], "prompt-0");
        for (i, entry) in listed.iter().enumerate() {
            assert_eq!(entry["id"], format!("q-{i}"));
        }
    }

    #[test]
    fn to_json_finished_completed() {
        let out = finished(TerminalState::Completed {
            result: json!({"answer": 42}),
        })
        .to_json("s-done");

        assert_eq!(out["status"], "completed");
        assert_eq!(out["result"]["answer"], 42);
        assert!(out.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let out = finished(TerminalState::Failed {
            error: "lua error: bad argument".into(),
        })
        .to_json("s-err");

        assert_eq!(out["status"], "error");
        assert_eq!(out["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let out = finished(TerminalState::Cancelled).to_json("s-cancel");

        assert_eq!(out["status"], "cancelled");
        assert!(out.get("stats").is_some());
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_uniqueness() {
        let unique: std::collections::HashSet<String> =
            (0..10).map(|_| gen_session_id()).collect();
        assert_eq!(unique.len(), 10, "10 IDs should all be unique");
    }
}