Skip to main content

algocline_engine/
session.rs

1//! Session-based Lua execution with pause/resume on alc.llm() calls.
2//!
3//! Runtime layer: ties Domain (ExecutionState) and Metrics (ExecutionMetrics)
4//! together with channel-based Lua pause/resume machinery.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use algocline_core::{
10    ExecutionMetrics, ExecutionObserver, ExecutionState, LlmQuery, MetricsObserver, QueryId,
11    TerminalState,
12};
13use mlua_isle::AsyncTask;
14use serde_json::json;
15use tokio::sync::Mutex;
16
17use crate::llm_bridge::LlmRequest;
18
19// ─── Error types (Runtime layer) ─────────────────────────────
20
21#[derive(Debug, thiserror::Error)]
22pub enum SessionError {
23    #[error("session '{0}' not found")]
24    NotFound(String),
25    #[error(transparent)]
26    Feed(#[from] algocline_core::FeedError),
27    #[error("invalid transition: {0}")]
28    InvalidTransition(String),
29}
30
31// ─── Result types (Runtime layer) ────────────────────────────
32
33/// Session completion data: terminal state + metrics.
34pub struct ExecutionResult {
35    pub state: TerminalState,
36    pub metrics: ExecutionMetrics,
37}
38
39/// Result of a session interaction (start or feed).
40pub enum FeedResult {
41    /// Partial feed accepted, still waiting for more responses.
42    Accepted { remaining: usize },
43    /// All queries answered, Lua re-paused with new queries.
44    Paused { queries: Vec<LlmQuery> },
45    /// Execution completed (success, failure, or cancellation).
46    Finished(ExecutionResult),
47}
48
49impl FeedResult {
50    /// Convert to JSON for MCP tool response.
51    pub fn to_json(&self, session_id: &str) -> serde_json::Value {
52        match self {
53            Self::Accepted { remaining } => json!({
54                "status": "accepted",
55                "remaining": remaining,
56            }),
57            Self::Paused { queries } => {
58                if queries.len() == 1 {
59                    let q = &queries[0];
60                    json!({
61                        "status": "needs_response",
62                        "session_id": session_id,
63                        "prompt": q.prompt,
64                        "system": q.system,
65                        "max_tokens": q.max_tokens,
66                    })
67                } else {
68                    let qs: Vec<_> = queries
69                        .iter()
70                        .map(|q| {
71                            json!({
72                                "id": q.id.as_str(),
73                                "prompt": q.prompt,
74                                "system": q.system,
75                                "max_tokens": q.max_tokens,
76                            })
77                        })
78                        .collect();
79                    json!({
80                        "status": "needs_response",
81                        "session_id": session_id,
82                        "queries": qs,
83                    })
84                }
85            }
86            Self::Finished(result) => match &result.state {
87                TerminalState::Completed { result: val } => json!({
88                    "status": "completed",
89                    "result": val,
90                    "stats": result.metrics.to_json(),
91                }),
92                TerminalState::Failed { error } => json!({
93                    "status": "error",
94                    "error": error,
95                }),
96                TerminalState::Cancelled => json!({
97                    "status": "cancelled",
98                    "stats": result.metrics.to_json(),
99                }),
100            },
101        }
102    }
103}
104
105// ─── Session ─────────────────────────────────────────────────
106
107/// A Lua execution session with domain state tracking.
108pub struct Session {
109    state: ExecutionState,
110    metrics: ExecutionMetrics,
111    observer: MetricsObserver,
112    llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
113    exec_task: AsyncTask,
114    /// QueryId → resp_tx. Populated on Paused, cleared on resume.
115    resp_txs: HashMap<QueryId, tokio::sync::oneshot::Sender<Result<String, String>>>,
116}
117
118impl Session {
119    pub fn new(
120        llm_rx: tokio::sync::mpsc::Receiver<LlmRequest>,
121        exec_task: AsyncTask,
122        metrics: ExecutionMetrics,
123    ) -> Self {
124        let observer = metrics.create_observer();
125        Self {
126            state: ExecutionState::Running,
127            metrics,
128            observer,
129            llm_rx,
130            exec_task,
131            resp_txs: HashMap::new(),
132        }
133    }
134
135    /// Wait for the next event from Lua execution.
136    ///
137    /// Called after initial start or after feeding all responses.
138    /// State must be Running when called.
139    async fn wait_event(&mut self) -> Result<FeedResult, SessionError> {
140        tokio::select! {
141            result = &mut self.exec_task => {
142                match result {
143                    Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
144                        Ok(v) => {
145                            self.state.complete(v.clone()).map_err(|e| {
146                                SessionError::InvalidTransition(e.to_string())
147                            })?;
148                            self.observer.on_completed(&v);
149                            Ok(FeedResult::Finished(ExecutionResult {
150                                state: TerminalState::Completed { result: v },
151                                metrics: self.take_metrics(),
152                            }))
153                        }
154                        Err(e) => self.fail_with(format!("JSON parse: {e}")),
155                    },
156                    Err(e) => self.fail_with(e.to_string()),
157                }
158            }
159            Some(req) = self.llm_rx.recv() => {
160                let queries: Vec<LlmQuery> = req.queries.iter().map(|qr| LlmQuery {
161                    id: qr.id.clone(),
162                    prompt: qr.prompt.clone(),
163                    system: qr.system.clone(),
164                    max_tokens: qr.max_tokens,
165                }).collect();
166
167                for qr in req.queries {
168                    self.resp_txs.insert(qr.id, qr.resp_tx);
169                }
170
171                self.state.pause(queries.clone()).map_err(|e| {
172                    SessionError::InvalidTransition(e.to_string())
173                })?;
174                self.observer.on_paused(&queries);
175                Ok(FeedResult::Paused { queries })
176            }
177        }
178    }
179
180    /// Feed one response by query_id.
181    ///
182    /// Returns Ok(true) if all queries are now complete, Ok(false) if still waiting.
183    fn feed_one(&mut self, query_id: &QueryId, response: String) -> Result<bool, SessionError> {
184        // Track response before ownership transfer.
185        self.observer.on_response_fed(query_id, &response);
186
187        // Runtime: send response to Lua thread (unblocks resp_rx.recv())
188        if let Some(tx) = self.resp_txs.remove(query_id) {
189            let _ = tx.send(Ok(response.clone()));
190        }
191
192        // Domain: record in state machine
193        let complete = self
194            .state
195            .feed(query_id, response)
196            .map_err(SessionError::Feed)?;
197
198        if complete {
199            // Domain: transition Paused(complete) → Running
200            self.state
201                .take_responses()
202                .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
203            self.observer.on_resumed();
204        } else {
205            self.observer
206                .on_partial_feed(query_id, self.state.remaining());
207        }
208
209        Ok(complete)
210    }
211
212    fn fail_with(&mut self, msg: String) -> Result<FeedResult, SessionError> {
213        self.state
214            .fail(msg.clone())
215            .map_err(|e| SessionError::InvalidTransition(e.to_string()))?;
216        self.observer.on_failed(&msg);
217        Ok(FeedResult::Finished(ExecutionResult {
218            state: TerminalState::Failed { error: msg },
219            metrics: self.take_metrics(),
220        }))
221    }
222
223    fn take_metrics(&mut self) -> ExecutionMetrics {
224        std::mem::take(&mut self.metrics)
225    }
226}
227
228// ─── Registry ────────────────────────────────────────────────
229
230/// Manages active sessions. Replaces the raw SessionMap.
231pub struct SessionRegistry {
232    sessions: Arc<Mutex<HashMap<String, Session>>>,
233}
234
235impl Default for SessionRegistry {
236    fn default() -> Self {
237        Self::new()
238    }
239}
240
241impl SessionRegistry {
242    pub fn new() -> Self {
243        Self {
244            sessions: Arc::new(Mutex::new(HashMap::new())),
245        }
246    }
247
248    /// Start execution and wait for first event (pause or completion).
249    pub async fn start_execution(
250        &self,
251        mut session: Session,
252    ) -> Result<(String, FeedResult), SessionError> {
253        let session_id = gen_session_id();
254        let result = session.wait_event().await?;
255
256        if matches!(result, FeedResult::Paused { .. }) {
257            self.sessions
258                .lock()
259                .await
260                .insert(session_id.clone(), session);
261        }
262
263        Ok((session_id, result))
264    }
265
266    /// Feed one response to a paused session by query_id.
267    ///
268    /// If this completes all pending queries, the session resumes and
269    /// returns the next event (Paused or Finished).
270    /// If queries remain, returns Accepted { remaining }.
271    pub async fn feed_response(
272        &self,
273        session_id: &str,
274        query_id: &QueryId,
275        response: String,
276    ) -> Result<FeedResult, SessionError> {
277        // 1. Feed under lock
278        let complete = {
279            let mut map = self.sessions.lock().await;
280            let session = map
281                .get_mut(session_id)
282                .ok_or_else(|| SessionError::NotFound(session_id.into()))?;
283
284            let complete = session.feed_one(query_id, response)?;
285
286            if !complete {
287                return Ok(FeedResult::Accepted {
288                    remaining: session.state.remaining(),
289                });
290            }
291
292            complete
293        };
294
295        // 2. All complete → take session out for async resume
296        debug_assert!(complete);
297        let mut session = {
298            let mut map = self.sessions.lock().await;
299            map.remove(session_id)
300                .ok_or_else(|| SessionError::NotFound(session_id.into()))?
301        };
302
303        let result = session.wait_event().await?;
304
305        if matches!(result, FeedResult::Paused { .. }) {
306            self.sessions
307                .lock()
308                .await
309                .insert(session_id.into(), session);
310        }
311
312        Ok(result)
313    }
314}
315
/// Generate a non-deterministic session ID of the form `s-<ts>-<tag>`.
///
/// The MCP spec asks for "secure, non-deterministic session IDs" to prevent
/// session hijacking. The ID has two hex parts:
/// - `<ts>` is the current Unix time in nanoseconds, in lowercase hex, or `0`
///   if the clock reads before the epoch.
/// - `<tag>` is 16 zero-padded hex digits from hashing that timestamp with a
///   freshly built std `RandomState`.
///
/// NOTE(review): `RandomState` is a randomly keyed hasher meant for DoS
/// resistance, not a CSPRNG. Switch to an OS-backed random source if
/// stronger unpredictability is required.
fn gen_session_id() -> String {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::{SystemTime, UNIX_EPOCH};

    // A clock set before the epoch degrades to 0 instead of panicking.
    let ts: u128 = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };

    // Every `RandomState::new()` produces a differently keyed hasher, so the
    // tag is intended to be unguessable even when `ts` is known.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(ts);
    let random: u64 = hasher.finish();

    format!("s-{ts:x}-{random:016x}")
}
338
#[cfg(test)]
mod tests {
    use super::*;
    use algocline_core::{ExecutionMetrics, LlmQuery, QueryId};
    use serde_json::json;

    /// Build a batch query with a deterministic id / prompt for `index`.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    // ─── FeedResult::to_json tests ───

    #[test]
    fn to_json_accepted() {
        let result = FeedResult::Accepted { remaining: 3 };
        let json = result.to_json("s-123");
        assert_eq!(json["status"], "accepted");
        assert_eq!(json["remaining"], 3);
    }

    #[test]
    fn to_json_paused_single_query() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "What is 2+2?".into(),
            system: Some("You are a calculator.".into()),
            max_tokens: 50,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-abc");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-abc");
        assert_eq!(json["prompt"], "What is 2+2?");
        assert_eq!(json["system"], "You are a calculator.");
        assert_eq!(json["max_tokens"], 50);
        // single query mode: no "queries" array
        assert!(json.get("queries").is_none());
    }

    #[test]
    fn to_json_paused_single_query_no_system() {
        let query = LlmQuery {
            id: QueryId::single(),
            prompt: "hello".into(),
            system: None,
            max_tokens: 1024,
        };
        let result = FeedResult::Paused {
            queries: vec![query],
        };
        let json = result.to_json("s-x");

        assert_eq!(json["status"], "needs_response");
        assert!(json["system"].is_null());
    }

    #[test]
    fn to_json_paused_multiple_queries() {
        let queries = vec![make_query(0), make_query(1), make_query(2)];
        let result = FeedResult::Paused { queries };
        let json = result.to_json("s-multi");

        assert_eq!(json["status"], "needs_response");
        assert_eq!(json["session_id"], "s-multi");

        let qs = json["queries"].as_array().expect("queries should be array");
        assert_eq!(qs.len(), 3);
        assert_eq!(qs[0]["id"], "q-0");
        assert_eq!(qs[0]["prompt"], "prompt-0");
        assert_eq!(qs[1]["id"], "q-1");
        assert_eq!(qs[2]["id"], "q-2");
    }

    #[test]
    fn to_json_finished_completed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Completed {
                result: json!({"answer": 42}),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-done");

        assert_eq!(json["status"], "completed");
        assert_eq!(json["result"]["answer"], 42);
        assert!(json.get("stats").is_some());
    }

    #[test]
    fn to_json_finished_failed() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Failed {
                error: "lua error: bad argument".into(),
            },
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-err");

        assert_eq!(json["status"], "error");
        assert_eq!(json["error"], "lua error: bad argument");
    }

    #[test]
    fn to_json_finished_cancelled() {
        let result = FeedResult::Finished(ExecutionResult {
            state: TerminalState::Cancelled,
            metrics: ExecutionMetrics::new(),
        });
        let json = result.to_json("s-cancel");

        assert_eq!(json["status"], "cancelled");
        assert!(json.get("stats").is_some());
    }

    // ─── SessionError tests ───

    #[test]
    fn session_error_messages() {
        assert_eq!(
            SessionError::NotFound("s-1".into()).to_string(),
            "session 's-1' not found"
        );
        assert_eq!(
            SessionError::InvalidTransition("bad".into()).to_string(),
            "invalid transition: bad"
        );
    }

    // ─── gen_session_id tests ───

    #[test]
    fn session_id_starts_with_prefix() {
        let id = gen_session_id();
        assert!(id.starts_with("s-"), "id should start with 's-': {id}");
    }

    #[test]
    fn session_id_format() {
        let id = gen_session_id();
        let parts: Vec<&str> = id.split('-').collect();
        assert_eq!(parts.len(), 3, "expected s-<ts>-<tag>: {id}");
        assert_eq!(parts[0], "s");
        assert!(
            !parts[1].is_empty() && parts[1].chars().all(|c| c.is_ascii_hexdigit()),
            "timestamp segment should be hex: {id}"
        );
        assert_eq!(parts[2].len(), 16, "tag segment should be 16 hex chars: {id}");
        assert!(parts[2].chars().all(|c| c.is_ascii_hexdigit()));
    }

    #[test]
    fn session_id_uniqueness() {
        let ids: Vec<String> = (0..10).map(|_| gen_session_id()).collect();
        let set: std::collections::HashSet<&String> = ids.iter().collect();
        assert_eq!(set.len(), 10, "10 IDs should all be unique");
    }
}