Skip to main content

algocline_core/
state.rs

1use std::collections::HashMap;
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5
6use crate::query::{LlmQuery, QueryId};
7
8#[derive(Debug, thiserror::Error)]
9#[error("invalid state transition: expected {expected}, got {actual}")]
10pub struct TransitionError {
11    pub expected: &'static str,
12    pub actual: &'static str,
13}
14
15#[derive(Debug, thiserror::Error)]
16pub enum FeedError {
17    #[error("unknown query_id: {0}")]
18    UnknownQuery(QueryId),
19    #[error("already responded to query_id: {0}")]
20    AlreadyResponded(QueryId),
21    #[error(transparent)]
22    InvalidState(#[from] TransitionError),
23}
24
25/// Join barrier that collects N LLM responses.
26///
27/// Responses can be fed in any order and concurrency.
28/// Becomes complete when all queries have been responded to.
29#[derive(Debug, Serialize, Deserialize)]
30pub struct PendingQueries {
31    /// Issued queries (insertion order preserved via IndexMap).
32    queries: IndexMap<QueryId, LlmQuery>,
33    responses: HashMap<QueryId, String>,
34}
35
36impl PendingQueries {
37    pub fn new(queries: Vec<LlmQuery>) -> Self {
38        let map = queries
39            .into_iter()
40            .map(|q| (q.id.clone(), q))
41            .collect::<IndexMap<_, _>>();
42        Self {
43            queries: map,
44            responses: HashMap::new(),
45        }
46    }
47
48    /// Feed one response. Returns `true` if all queries are now complete.
49    pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
50        if !self.queries.contains_key(id) {
51            return Err(FeedError::UnknownQuery(id.clone()));
52        }
53        if self.responses.contains_key(id) {
54            return Err(FeedError::AlreadyResponded(id.clone()));
55        }
56        self.responses.insert(id.clone(), response);
57        Ok(self.is_complete())
58    }
59
60    pub fn pending_queries(&self) -> Vec<&LlmQuery> {
61        self.queries
62            .values()
63            .filter(|q| !self.responses.contains_key(&q.id))
64            .collect()
65    }
66
67    pub fn remaining(&self) -> usize {
68        self.queries.len() - self.responses.len()
69    }
70
71    pub fn is_complete(&self) -> bool {
72        self.responses.len() == self.queries.len()
73    }
74
75    /// Consume and return responses in query insertion order.
76    /// Corresponds to the Paused → Running transition.
77    pub fn into_ordered_responses(self) -> Vec<String> {
78        self.queries
79            .keys()
80            .map(|id| {
81                // is_complete() guarantees queries and responses share the same key set,
82                // but fall back to empty string if called without checking is_complete()
83                self.responses.get(id).cloned().unwrap_or_default()
84            })
85            .collect()
86    }
87}
88
89pub enum ExecutionState {
90    Running,
91    /// Awaiting 1..N LLM responses.
92    Paused(PendingQueries),
93    Completed {
94        result: serde_json::Value,
95    },
96    Failed {
97        error: String,
98    },
99    /// Explicit cancellation by the host.
100    Cancelled,
101}
102
103impl ExecutionState {
104    pub fn is_terminal(&self) -> bool {
105        matches!(
106            self,
107            Self::Completed { .. } | Self::Failed { .. } | Self::Cancelled
108        )
109    }
110
111    /// Number of pending queries. Returns 0 for non-Paused states.
112    pub fn remaining(&self) -> usize {
113        match self {
114            Self::Paused(pending) => pending.remaining(),
115            _ => 0,
116        }
117    }
118
119    /// Returns the state name (for error messages).
120    pub fn name(&self) -> &'static str {
121        match self {
122            Self::Running => "Running",
123            Self::Paused(_) => "Paused",
124            Self::Completed { .. } => "Completed",
125            Self::Failed { .. } => "Failed",
126            Self::Cancelled => "Cancelled",
127        }
128    }
129
130    /// Feed a response. Only valid in Paused state.
131    /// Returns `Ok(true)` when all queries are complete, `Ok(false)` otherwise.
132    pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
133        match self {
134            Self::Paused(pending) => pending.feed(id, response),
135            other => Err(TransitionError {
136                expected: "Paused",
137                actual: other.name(),
138            }
139            .into()),
140        }
141    }
142
143    /// Extract responses from a complete Paused state.
144    /// Transitions self to Running (preparing for Lua resumption).
145    pub fn take_responses(&mut self) -> Result<Vec<String>, TransitionError> {
146        match std::mem::replace(self, Self::Running) {
147            Self::Paused(pending) if pending.is_complete() => Ok(pending.into_ordered_responses()),
148            prev => {
149                let actual = prev.name();
150                *self = prev;
151                Err(TransitionError {
152                    expected: "Paused(complete)",
153                    actual,
154                })
155            }
156        }
157    }
158
159    /// Running → Completed.
160    pub fn complete(&mut self, result: serde_json::Value) -> Result<(), TransitionError> {
161        match self {
162            Self::Running => {
163                *self = Self::Completed { result };
164                Ok(())
165            }
166            other => Err(TransitionError {
167                expected: "Running",
168                actual: other.name(),
169            }),
170        }
171    }
172
173    /// Running → Failed.
174    pub fn fail(&mut self, error: String) -> Result<(), TransitionError> {
175        match self {
176            Self::Running => {
177                *self = Self::Failed { error };
178                Ok(())
179            }
180            other => Err(TransitionError {
181                expected: "Running",
182                actual: other.name(),
183            }),
184        }
185    }
186
187    /// Running → Paused (triggered by alc.llm() / alc.llm_batch()).
188    pub fn pause(&mut self, queries: Vec<LlmQuery>) -> Result<(), TransitionError> {
189        match self {
190            Self::Running => {
191                *self = Self::Paused(PendingQueries::new(queries));
192                Ok(())
193            }
194            other => Err(TransitionError {
195                expected: "Running",
196                actual: other.name(),
197            }),
198        }
199    }
200
201    /// Running | Paused → Cancelled (explicit host cancellation).
202    pub fn cancel(&mut self) -> Result<(), TransitionError> {
203        match self {
204            Self::Running | Self::Paused(_) => {
205                *self = Self::Cancelled;
206                Ok(())
207            }
208            other => Err(TransitionError {
209                expected: "Running or Paused",
210                actual: other.name(),
211            }),
212        }
213    }
214}
215
216/// Return type of Session.resume(). Never returns Running.
217pub enum ResumeOutcome {
218    /// Lua resumed and paused again at alc.llm().
219    Paused {
220        queries: Vec<LlmQuery>,
221    },
222    Completed {
223        result: serde_json::Value,
224    },
225    Failed {
226        error: String,
227    },
228    /// Cancelled during resume.
229    Cancelled,
230}
231
232/// Terminal execution state. Only Completed, Failed, or Cancelled.
233#[derive(Debug)]
234pub enum TerminalState {
235    Completed { result: serde_json::Value },
236    Failed { error: String },
237    Cancelled,
238}
239
240impl TryFrom<ExecutionState> for TerminalState {
241    type Error = TransitionError;
242
243    fn try_from(state: ExecutionState) -> Result<Self, TransitionError> {
244        match state {
245            ExecutionState::Completed { result } => Ok(Self::Completed { result }),
246            ExecutionState::Failed { error } => Ok(Self::Failed { error }),
247            ExecutionState::Cancelled => Ok(Self::Cancelled),
248            other => Err(TransitionError {
249                expected: "Completed, Failed, or Cancelled",
250                actual: other.name(),
251            }),
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use serde_json::json;

    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    // ─── PendingQueries tests ───

    #[test]
    fn pending_queries_single_feed() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        assert_eq!(pq.remaining(), 1);
        assert!(!pq.is_complete());

        let complete = pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        assert!(complete);
        assert_eq!(pq.remaining(), 0);
    }

    #[test]
    fn pending_queries_multi_feed_ordering() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1), make_query(2)]);

        // feed in reverse order
        assert!(!pq.feed(&QueryId::batch(2), "resp-2".into()).unwrap());
        assert!(!pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap());
        assert!(pq.feed(&QueryId::batch(1), "resp-1".into()).unwrap());

        // into_ordered_responses returns in insertion order
        let responses = pq.into_ordered_responses();
        assert_eq!(responses, vec!["resp-0", "resp-1", "resp-2"]);
    }

    #[test]
    fn pending_queries_unknown_query_error() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        let err = pq.feed(&QueryId::batch(99), "resp".into()).unwrap_err();
        assert!(matches!(err, FeedError::UnknownQuery(_)));
    }

    #[test]
    fn pending_queries_double_feed_error() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        let err = pq.feed(&QueryId::batch(0), "resp2".into()).unwrap_err();
        assert!(matches!(err, FeedError::AlreadyResponded(_)));
    }

    #[test]
    fn pending_queries_rejected_feeds_keep_first_response() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        pq.feed(&QueryId::batch(0), "first".into()).unwrap();
        pq.feed(&QueryId::batch(0), "second".into()).unwrap_err();
        pq.feed(&QueryId::batch(99), "stray".into()).unwrap_err();
        assert_eq!(pq.into_ordered_responses(), vec!["first"]);
    }

    #[test]
    fn pending_queries_empty_batch_is_complete() {
        let pq = PendingQueries::new(Vec::new());
        assert!(pq.is_complete());
        assert_eq!(pq.remaining(), 0);
        assert!(pq.pending_queries().is_empty());
        assert!(pq.into_ordered_responses().is_empty());
    }

    #[test]
    fn pending_queries_pending_list() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        assert_eq!(pq.pending_queries().len(), 2);

        pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        let pending = pq.pending_queries();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, QueryId::batch(1));
    }

    #[test]
    fn pending_queries_roundtrip_json() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap();

        let json = serde_json::to_value(&pq).unwrap();
        let mut restored: PendingQueries = serde_json::from_value(json).unwrap();
        assert_eq!(restored.remaining(), 1);
        assert_eq!(restored.queries.len(), 2);

        // The restored barrier keeps its progress: the answered query stays
        // answered, and feeding the other one completes it.
        let pending = restored.pending_queries();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, QueryId::batch(1));
        assert!(matches!(
            restored.feed(&QueryId::batch(0), "again".into()),
            Err(FeedError::AlreadyResponded(_))
        ));
        assert!(restored.feed(&QueryId::batch(1), "resp-1".into()).unwrap());
    }

    // ─── ExecutionState transition tests ───

    #[test]
    fn running_to_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        assert_eq!(state.name(), "Paused");
    }

    #[test]
    fn paused_feed_and_take() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0), make_query(1)]).unwrap();

        assert!(!state.feed(&QueryId::batch(0), "r0".into()).unwrap());
        assert!(state.feed(&QueryId::batch(1), "r1".into()).unwrap());

        let responses = state.take_responses().unwrap();
        assert_eq!(responses, vec!["r0", "r1"]);
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn pause_with_empty_batch_can_resume_immediately() {
        let mut state = ExecutionState::Running;
        state.pause(Vec::new()).unwrap();
        assert_eq!(state.remaining(), 0);
        assert!(state.take_responses().unwrap().is_empty());
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn paused_feed_unknown_query_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.feed(&QueryId::batch(7), "r".into()).unwrap_err();
        assert!(matches!(err, FeedError::UnknownQuery(_)));
        assert_eq!(state.remaining(), 1);
    }

    #[test]
    fn take_responses_incomplete_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0), make_query(1)]).unwrap();
        state.feed(&QueryId::batch(0), "r0".into()).unwrap();

        let err = state.take_responses().unwrap_err();
        assert_eq!(err.actual, "Paused");
        // state should remain Paused
        assert_eq!(state.name(), "Paused");
    }

    #[test]
    fn take_responses_on_running_fails() {
        let mut state = ExecutionState::Running;
        let err = state.take_responses().unwrap_err();
        assert_eq!(err.expected, "Paused(complete)");
        assert_eq!(err.actual, "Running");
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn take_responses_on_completed_keeps_state() {
        let mut state = ExecutionState::Running;
        state.complete(json!(1)).unwrap();
        let err = state.take_responses().unwrap_err();
        assert_eq!(err.actual, "Completed");
        // A failed take must not reset a terminal state to Running.
        assert_eq!(state.name(), "Completed");
    }

    #[test]
    fn running_to_completed() {
        let mut state = ExecutionState::Running;
        state.complete(json!({"answer": 42})).unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Completed");
    }

    #[test]
    fn running_to_failed() {
        let mut state = ExecutionState::Running;
        state.fail("boom".into()).unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Failed");
    }

    #[test]
    fn cancel_from_running() {
        let mut state = ExecutionState::Running;
        state.cancel().unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Cancelled");
    }

    #[test]
    fn cancel_from_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        state.cancel().unwrap();
        assert_eq!(state.name(), "Cancelled");
    }

    // ─── remaining() tests ───

    #[test]
    fn remaining_running_is_zero() {
        let state = ExecutionState::Running;
        assert_eq!(state.remaining(), 0);
    }

    #[test]
    fn remaining_tracks_feeds() {
        let mut state = ExecutionState::Running;
        state
            .pause(vec![make_query(0), make_query(1), make_query(2)])
            .unwrap();
        assert_eq!(state.remaining(), 3);

        state.feed(&QueryId::batch(0), "r".into()).unwrap();
        assert_eq!(state.remaining(), 2);

        state.feed(&QueryId::batch(1), "r".into()).unwrap();
        assert_eq!(state.remaining(), 1);
    }

    #[test]
    fn remaining_terminal_is_zero() {
        let state = ExecutionState::Completed {
            result: json!(null),
        };
        assert_eq!(state.remaining(), 0);
    }

    // ─── Invalid transition tests ───

    #[test]
    fn feed_on_running_fails() {
        let mut state = ExecutionState::Running;
        let err = state.feed(&QueryId::single(), "r".into()).unwrap_err();
        assert!(matches!(err, FeedError::InvalidState(_)));
    }

    #[test]
    fn pause_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.pause(vec![make_query(1)]).unwrap_err();
        assert_eq!(err.expected, "Running");
    }

    #[test]
    fn complete_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.complete(json!(null)).unwrap_err();
        assert_eq!(err.expected, "Running");
    }

    #[test]
    fn fail_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.fail("e".into()).unwrap_err();
        assert_eq!(err.actual, "Paused");
        // The pending batch must survive the rejected transition.
        assert_eq!(state.remaining(), 1);
    }

    #[test]
    fn cancel_on_completed_fails() {
        let mut state = ExecutionState::Running;
        state.complete(json!(null)).unwrap();
        let err = state.cancel().unwrap_err();
        assert_eq!(err.expected, "Running or Paused");
    }

    #[test]
    fn cancel_on_failed_fails() {
        let mut state = ExecutionState::Running;
        state.fail("e".into()).unwrap();
        let err = state.cancel().unwrap_err();
        assert_eq!(err.expected, "Running or Paused");
    }

    #[test]
    fn cancel_on_cancelled_fails() {
        let mut state = ExecutionState::Cancelled;
        let err = state.cancel().unwrap_err();
        assert_eq!(err.actual, "Cancelled");
    }

    #[test]
    fn terminal_state_rejects_non_terminal() {
        let state = ExecutionState::Running;
        let err = TerminalState::try_from(state).unwrap_err();
        assert_eq!(err.actual, "Running");
    }

    #[test]
    fn terminal_state_rejects_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = TerminalState::try_from(state).unwrap_err();
        assert_eq!(err.actual, "Paused");
    }

    #[test]
    fn terminal_state_from_completed() {
        let state = ExecutionState::Completed { result: json!(42) };
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Completed { .. }));
    }

    #[test]
    fn terminal_state_from_failed() {
        let state = ExecutionState::Failed {
            error: "boom".into(),
        };
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Failed { ref error } if error == "boom"));
    }

    #[test]
    fn terminal_state_from_cancelled() {
        let state = ExecutionState::Cancelled;
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Cancelled));
    }
}
496
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use proptest::prelude::*;

    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// A random permutation of `0..size` for `size` in `1..8`.
    fn feed_order() -> impl Strategy<Value = Vec<usize>> {
        (1usize..8).prop_flat_map(|size| Just((0..size).collect::<Vec<usize>>()).prop_shuffle())
    }

    proptest! {
        /// into_ordered_responses returns insertion order regardless of feed order.
        #[test]
        fn feed_order_independent(order in feed_order()) {
            let size = order.len();
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            // Every feed must succeed, and only the last one completes the barrier.
            for (step, &i) in order.iter().enumerate() {
                let complete = pq.feed(&QueryId::batch(i), format!("resp-{i}")).unwrap();
                prop_assert_eq!(complete, step + 1 == size);
            }

            // Compare whole vectors so a short result cannot pass vacuously.
            let responses = pq.into_ordered_responses();
            let expected: Vec<String> = (0..size).map(|i| format!("resp-{i}")).collect();
            prop_assert_eq!(responses, expected);
        }

        /// Feeding the same query twice returns AlreadyResponded error.
        #[test]
        fn double_feed_always_errors(size in 1usize..8, target in 0usize..8) {
            let target = target % size; // clamp to valid range
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            pq.feed(&QueryId::batch(target), "first".into()).unwrap();
            let err = pq.feed(&QueryId::batch(target), "second".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::AlreadyResponded(_)));
            // The rejected feed must not count as progress.
            prop_assert_eq!(pq.remaining(), size - 1);
        }

        /// Feeding a non-existent query_id returns UnknownQuery error.
        #[test]
        fn unknown_query_always_errors(size in 1usize..8, bad_id in 100usize..200) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            let err = pq.feed(&QueryId::batch(bad_id), "resp".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::UnknownQuery(_)));
            prop_assert_eq!(pq.remaining(), size);
        }

        /// remaining() decreases by 1 with each feed.
        #[test]
        fn remaining_decreases_monotonically(size in 1usize..10) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for i in 0..size {
                prop_assert_eq!(pq.remaining(), size - i);
                let complete = pq.feed(&QueryId::batch(i), format!("r-{i}")).unwrap();
                prop_assert_eq!(complete, i + 1 == size);
            }
            prop_assert_eq!(pq.remaining(), 0);
            prop_assert!(pq.is_complete());
        }
    }
}