Skip to main content

algocline_core/
state.rs

1use std::collections::HashMap;
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5
6use crate::query::{LlmQuery, QueryId};
7
/// An operation was attempted from an [`ExecutionState`] that does not permit it.
#[derive(Debug, thiserror::Error)]
#[error("invalid state transition: expected {expected}, got {actual}")]
pub struct TransitionError {
    /// State(s) the operation requires, e.g. `"Running"` or `"Running or Paused"`.
    pub expected: &'static str,
    /// Name of the state the machine was actually in (as returned by `ExecutionState::name`).
    pub actual: &'static str,
}
14
/// Errors returned when feeding an LLM response into a pending batch.
#[derive(Debug, thiserror::Error)]
pub enum FeedError {
    /// The id does not belong to any query in the current batch.
    #[error("unknown query_id: {0}")]
    UnknownQuery(QueryId),
    /// A response for this id was already recorded; the earlier response is kept.
    #[error("already responded to query_id: {0}")]
    AlreadyResponded(QueryId),
    /// Produced by `ExecutionState::feed` when the state is not `Paused`.
    #[error(transparent)]
    InvalidState(#[from] TransitionError),
}
24
25/// Join barrier that collects N LLM responses.
26///
27/// Responses can be fed in any order and concurrency.
28/// Becomes complete when all queries have been responded to.
29#[derive(Debug, Serialize, Deserialize)]
30pub struct PendingQueries {
31    /// Issued queries (insertion order preserved via IndexMap).
32    queries: IndexMap<QueryId, LlmQuery>,
33    responses: HashMap<QueryId, String>,
34}
35
36impl PendingQueries {
37    pub fn new(queries: Vec<LlmQuery>) -> Self {
38        let map = queries
39            .into_iter()
40            .map(|q| (q.id.clone(), q))
41            .collect::<IndexMap<_, _>>();
42        Self {
43            queries: map,
44            responses: HashMap::new(),
45        }
46    }
47
48    /// Feed one response. Returns `true` if all queries are now complete.
49    pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
50        if !self.queries.contains_key(id) {
51            return Err(FeedError::UnknownQuery(id.clone()));
52        }
53        if self.responses.contains_key(id) {
54            return Err(FeedError::AlreadyResponded(id.clone()));
55        }
56        self.responses.insert(id.clone(), response);
57        Ok(self.is_complete())
58    }
59
60    pub fn pending_queries(&self) -> Vec<&LlmQuery> {
61        self.queries
62            .values()
63            .filter(|q| !self.responses.contains_key(&q.id))
64            .collect()
65    }
66
67    pub fn remaining(&self) -> usize {
68        self.queries.len() - self.responses.len()
69    }
70
71    pub fn is_complete(&self) -> bool {
72        self.responses.len() == self.queries.len()
73    }
74
75    /// Consume and return responses in query insertion order.
76    /// Corresponds to the Paused → Running transition.
77    pub fn into_ordered_responses(self) -> Vec<String> {
78        self.queries
79            .keys()
80            .map(|id| {
81                // is_complete() guarantees queries and responses share the same key set,
82                // but fall back to empty string if called without checking is_complete()
83                self.responses.get(id).cloned().unwrap_or_default()
84            })
85            .collect()
86    }
87}
88
89pub enum ExecutionState {
90    Running,
91    /// Awaiting 1..N LLM responses.
92    Paused(PendingQueries),
93    Completed {
94        result: serde_json::Value,
95    },
96    Failed {
97        error: String,
98    },
99    /// Explicit cancellation by the host.
100    Cancelled,
101}
102
103impl ExecutionState {
104    pub fn is_terminal(&self) -> bool {
105        matches!(
106            self,
107            Self::Completed { .. } | Self::Failed { .. } | Self::Cancelled
108        )
109    }
110
111    /// Number of pending queries. Returns 0 for non-Paused states.
112    pub fn remaining(&self) -> usize {
113        match self {
114            Self::Paused(pending) => pending.remaining(),
115            _ => 0,
116        }
117    }
118
119    /// Returns the state name (for error messages).
120    pub fn name(&self) -> &'static str {
121        match self {
122            Self::Running => "Running",
123            Self::Paused(_) => "Paused",
124            Self::Completed { .. } => "Completed",
125            Self::Failed { .. } => "Failed",
126            Self::Cancelled => "Cancelled",
127        }
128    }
129
130    /// Feed a response. Only valid in Paused state.
131    /// Returns `Ok(true)` when all queries are complete, `Ok(false)` otherwise.
132    pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
133        match self {
134            Self::Paused(pending) => pending.feed(id, response),
135            other => Err(TransitionError {
136                expected: "Paused",
137                actual: other.name(),
138            }
139            .into()),
140        }
141    }
142
143    /// Extract responses from a complete Paused state.
144    /// Transitions self to Running (preparing for Lua resumption).
145    pub fn take_responses(&mut self) -> Result<Vec<String>, TransitionError> {
146        match std::mem::replace(self, Self::Running) {
147            Self::Paused(pending) if pending.is_complete() => Ok(pending.into_ordered_responses()),
148            prev => {
149                let actual = prev.name();
150                *self = prev;
151                Err(TransitionError {
152                    expected: "Paused(complete)",
153                    actual,
154                })
155            }
156        }
157    }
158
159    /// Running → Completed.
160    pub fn complete(&mut self, result: serde_json::Value) -> Result<(), TransitionError> {
161        match self {
162            Self::Running => {
163                *self = Self::Completed { result };
164                Ok(())
165            }
166            other => Err(TransitionError {
167                expected: "Running",
168                actual: other.name(),
169            }),
170        }
171    }
172
173    /// Running → Failed.
174    pub fn fail(&mut self, error: String) -> Result<(), TransitionError> {
175        match self {
176            Self::Running => {
177                *self = Self::Failed { error };
178                Ok(())
179            }
180            other => Err(TransitionError {
181                expected: "Running",
182                actual: other.name(),
183            }),
184        }
185    }
186
187    /// Running → Paused (triggered by alc.llm() / alc.llm_batch()).
188    pub fn pause(&mut self, queries: Vec<LlmQuery>) -> Result<(), TransitionError> {
189        match self {
190            Self::Running => {
191                *self = Self::Paused(PendingQueries::new(queries));
192                Ok(())
193            }
194            other => Err(TransitionError {
195                expected: "Running",
196                actual: other.name(),
197            }),
198        }
199    }
200
201    /// Running | Paused → Cancelled (explicit host cancellation).
202    pub fn cancel(&mut self) -> Result<(), TransitionError> {
203        match self {
204            Self::Running | Self::Paused(_) => {
205                *self = Self::Cancelled;
206                Ok(())
207            }
208            other => Err(TransitionError {
209                expected: "Running or Paused",
210                actual: other.name(),
211            }),
212        }
213    }
214}
215
216/// Return type of Session.resume(). Never returns Running.
217pub enum ResumeOutcome {
218    /// Lua resumed and paused again at alc.llm().
219    Paused {
220        queries: Vec<LlmQuery>,
221    },
222    Completed {
223        result: serde_json::Value,
224    },
225    Failed {
226        error: String,
227    },
228    /// Cancelled during resume.
229    Cancelled,
230}
231
/// Terminal execution state. Only Completed, Failed, or Cancelled.
///
/// Obtained from an [`ExecutionState`] via `TryFrom`, which rejects `Running` and `Paused`.
#[derive(Debug)]
pub enum TerminalState {
    /// Execution finished successfully with a JSON result.
    Completed { result: serde_json::Value },
    /// Execution ended with an error message.
    Failed { error: String },
    /// Execution was explicitly cancelled by the host.
    Cancelled,
}
239
240impl TryFrom<ExecutionState> for TerminalState {
241    type Error = TransitionError;
242
243    fn try_from(state: ExecutionState) -> Result<Self, TransitionError> {
244        match state {
245            ExecutionState::Completed { result } => Ok(Self::Completed { result }),
246            ExecutionState::Failed { error } => Ok(Self::Failed { error }),
247            ExecutionState::Cancelled => Ok(Self::Cancelled),
248            other => Err(TransitionError {
249                expected: "Completed, Failed, or Cancelled",
250                actual: other.name(),
251            }),
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use serde_json::json;

    /// Builds a batch query whose id is `QueryId::batch(index)`.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    // ─── PendingQueries tests ───

    #[test]
    fn pending_queries_single_feed() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        assert_eq!(pq.remaining(), 1);
        assert!(!pq.is_complete());

        let complete = pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        assert!(complete);
        assert_eq!(pq.remaining(), 0);
    }

    #[test]
    fn pending_queries_multi_feed_ordering() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1), make_query(2)]);

        // feed in reverse order
        assert!(!pq.feed(&QueryId::batch(2), "resp-2".into()).unwrap());
        assert!(!pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap());
        assert!(pq.feed(&QueryId::batch(1), "resp-1".into()).unwrap());

        // into_ordered_responses returns in insertion order
        let responses = pq.into_ordered_responses();
        assert_eq!(responses, vec!["resp-0", "resp-1", "resp-2"]);
    }

    #[test]
    fn pending_queries_unknown_query_error() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        let err = pq.feed(&QueryId::batch(99), "resp".into()).unwrap_err();
        assert!(matches!(err, FeedError::UnknownQuery(_)));
    }

    #[test]
    fn pending_queries_double_feed_error() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        let err = pq.feed(&QueryId::batch(0), "resp2".into()).unwrap_err();
        assert!(matches!(err, FeedError::AlreadyResponded(_)));
    }

    #[test]
    fn pending_queries_pending_list() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        assert_eq!(pq.pending_queries().len(), 2);

        pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        let pending = pq.pending_queries();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, QueryId::batch(1));
    }

    #[test]
    fn pending_queries_roundtrip_json() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap();

        let json = serde_json::to_value(&pq).unwrap();
        let restored: PendingQueries = serde_json::from_value(json).unwrap();
        assert_eq!(restored.remaining(), 1);
        assert_eq!(restored.queries.len(), 2);
    }

    #[test]
    fn pending_queries_roundtrip_json_preserves_responses() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap();

        let json = serde_json::to_value(&pq).unwrap();
        let mut restored: PendingQueries = serde_json::from_value(json).unwrap();

        // The already-answered query must still be rejected after a roundtrip.
        let err = restored
            .feed(&QueryId::batch(0), "again".into())
            .unwrap_err();
        assert!(matches!(err, FeedError::AlreadyResponded(_)));

        assert!(restored.feed(&QueryId::batch(1), "resp-1".into()).unwrap());
        assert_eq!(restored.into_ordered_responses(), vec!["resp-0", "resp-1"]);
    }

    #[test]
    fn into_ordered_responses_fills_unanswered_with_empty() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        pq.feed(&QueryId::batch(1), "resp-1".into()).unwrap();
        assert_eq!(pq.into_ordered_responses(), vec!["", "resp-1"]);
    }

    // ─── ExecutionState transition tests ───

    #[test]
    fn running_to_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        assert_eq!(state.name(), "Paused");
    }

    #[test]
    fn paused_feed_and_take() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0), make_query(1)]).unwrap();

        assert!(!state.feed(&QueryId::batch(0), "r0".into()).unwrap());
        assert!(state.feed(&QueryId::batch(1), "r1".into()).unwrap());

        let responses = state.take_responses().unwrap();
        assert_eq!(responses, vec!["r0", "r1"]);
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn take_responses_incomplete_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0), make_query(1)]).unwrap();
        state.feed(&QueryId::batch(0), "r0".into()).unwrap();

        let err = state.take_responses().unwrap_err();
        assert_eq!(err.actual, "Paused");
        // state should remain Paused
        assert_eq!(state.name(), "Paused");
    }

    #[test]
    fn take_responses_on_running_fails_and_keeps_state() {
        let mut state = ExecutionState::Running;
        let err = state.take_responses().unwrap_err();
        assert_eq!(err.expected, "Paused(complete)");
        assert_eq!(err.actual, "Running");
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn running_to_completed() {
        let mut state = ExecutionState::Running;
        state.complete(json!({"answer": 42})).unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Completed");
    }

    #[test]
    fn running_to_failed() {
        let mut state = ExecutionState::Running;
        state.fail("boom".into()).unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Failed");
    }

    #[test]
    fn cancel_from_running() {
        let mut state = ExecutionState::Running;
        state.cancel().unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Cancelled");
    }

    #[test]
    fn cancel_from_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        state.cancel().unwrap();
        assert_eq!(state.name(), "Cancelled");
    }

    // ─── remaining() tests ───

    #[test]
    fn remaining_running_is_zero() {
        let state = ExecutionState::Running;
        assert_eq!(state.remaining(), 0);
    }

    #[test]
    fn remaining_tracks_feeds() {
        let mut state = ExecutionState::Running;
        state
            .pause(vec![make_query(0), make_query(1), make_query(2)])
            .unwrap();
        assert_eq!(state.remaining(), 3);

        state.feed(&QueryId::batch(0), "r".into()).unwrap();
        assert_eq!(state.remaining(), 2);

        state.feed(&QueryId::batch(1), "r".into()).unwrap();
        assert_eq!(state.remaining(), 1);
    }

    #[test]
    fn remaining_terminal_is_zero() {
        let state = ExecutionState::Completed {
            result: json!(null),
        };
        assert_eq!(state.remaining(), 0);
    }

    // ─── Invalid transition tests ───

    #[test]
    fn feed_on_running_fails() {
        let mut state = ExecutionState::Running;
        let err = state.feed(&QueryId::single(), "r".into()).unwrap_err();
        assert!(matches!(err, FeedError::InvalidState(_)));
    }

    #[test]
    fn feed_after_cancel_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        state.cancel().unwrap();
        let err = state.feed(&QueryId::batch(0), "r".into()).unwrap_err();
        assert!(matches!(
            err,
            FeedError::InvalidState(TransitionError {
                actual: "Cancelled",
                ..
            })
        ));
    }

    #[test]
    fn pause_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.pause(vec![make_query(1)]).unwrap_err();
        assert_eq!(err.expected, "Running");
    }

    #[test]
    fn complete_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.complete(json!(null)).unwrap_err();
        assert_eq!(err.expected, "Running");
    }

    #[test]
    fn fail_on_completed_fails() {
        let mut state = ExecutionState::Running;
        state.complete(json!(null)).unwrap();
        let err = state.fail("late".into()).unwrap_err();
        assert_eq!(err.actual, "Completed");
        // A rejected transition must not alter the state.
        assert_eq!(state.name(), "Completed");
    }

    #[test]
    fn cancel_on_completed_fails() {
        let mut state = ExecutionState::Running;
        state.complete(json!(null)).unwrap();
        let err = state.cancel().unwrap_err();
        assert_eq!(err.expected, "Running or Paused");
    }

    #[test]
    fn cancel_on_failed_fails() {
        let mut state = ExecutionState::Running;
        state.fail("e".into()).unwrap();
        let err = state.cancel().unwrap_err();
        assert_eq!(err.expected, "Running or Paused");
    }

    #[test]
    fn terminal_state_rejects_non_terminal() {
        let state = ExecutionState::Running;
        let err = TerminalState::try_from(state).unwrap_err();
        assert_eq!(err.actual, "Running");
    }

    #[test]
    fn terminal_state_rejects_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = TerminalState::try_from(state).unwrap_err();
        assert_eq!(err.actual, "Paused");
    }

    #[test]
    fn terminal_state_from_completed() {
        let state = ExecutionState::Completed { result: json!(42) };
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Completed { .. }));
    }

    #[test]
    fn terminal_state_from_failed() {
        let state = ExecutionState::Failed {
            error: "boom".into(),
        };
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Failed { error } if error == "boom"));
    }

    #[test]
    fn terminal_state_from_cancelled() {
        let state = ExecutionState::Cancelled;
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Cancelled));
    }
}
494
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use proptest::prelude::*;

    /// Builds a batch query whose id is `QueryId::batch(index)`.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    /// A random permutation of `0..size`, with `size` drawn from `1..8`.
    fn feed_order() -> impl Strategy<Value = Vec<usize>> {
        (1usize..8).prop_flat_map(|size| Just((0..size).collect::<Vec<usize>>()).prop_shuffle())
    }

    proptest! {
        /// into_ordered_responses returns insertion order regardless of feed order.
        #[test]
        fn feed_order_independent(order in feed_order()) {
            let size = order.len();
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for (fed, &i) in order.iter().enumerate() {
                let done = pq.feed(&QueryId::batch(i), format!("resp-{i}"));
                // Every feed must succeed, and only the last one completes the barrier.
                prop_assert!(
                    matches!(done, Ok(d) if d == (fed + 1 == size)),
                    "feed #{} (batch {}) returned {:?}", fed, i, done
                );
            }

            let responses = pq.into_ordered_responses();
            // Without this check, an empty result would pass the loop below vacuously.
            prop_assert_eq!(responses.len(), size);
            // must return in insertion order (0, 1, 2, ...)
            for (i, resp) in responses.iter().enumerate() {
                prop_assert_eq!(resp, &format!("resp-{i}"));
            }
        }

        /// Feeding the same query twice returns AlreadyResponded error.
        #[test]
        fn double_feed_always_errors(size in 1usize..8, target in 0usize..8) {
            let target = target % size; // clamp to valid range
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            pq.feed(&QueryId::batch(target), "first".into()).unwrap();
            let err = pq.feed(&QueryId::batch(target), "second".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::AlreadyResponded(_)));
            // The rejected feed must not count as a second response.
            prop_assert_eq!(pq.remaining(), size - 1);
        }

        /// Feeding a non-existent query_id returns UnknownQuery error.
        #[test]
        fn unknown_query_always_errors(size in 1usize..8, bad_id in 100usize..200) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            let err = pq.feed(&QueryId::batch(bad_id), "resp".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::UnknownQuery(_)));
            // A rejected feed leaves the barrier untouched.
            prop_assert_eq!(pq.remaining(), size);
        }

        /// remaining() decreases by 1 with each feed.
        #[test]
        fn remaining_decreases_monotonically(size in 1usize..10) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for i in 0..size {
                prop_assert_eq!(pq.remaining(), size - i);
                let done = pq.feed(&QueryId::batch(i), format!("r-{i}"));
                prop_assert!(
                    matches!(done, Ok(d) if d == (i + 1 == size)),
                    "feed of batch {} returned {:?}", i, done
                );
            }
            prop_assert_eq!(pq.remaining(), 0);
            prop_assert!(pq.is_complete());
        }
    }
}