1use std::collections::HashMap;
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5
6use crate::query::{LlmQuery, QueryId};
7
8#[derive(Debug, thiserror::Error)]
9#[error("invalid state transition: expected {expected}, got {actual}")]
10pub struct TransitionError {
11 pub expected: &'static str,
12 pub actual: &'static str,
13}
14
15#[derive(Debug, thiserror::Error)]
16pub enum FeedError {
17 #[error("unknown query_id: {0}")]
18 UnknownQuery(QueryId),
19 #[error("already responded to query_id: {0}")]
20 AlreadyResponded(QueryId),
21 #[error(transparent)]
22 InvalidState(#[from] TransitionError),
23}
24
25#[derive(Debug, Serialize, Deserialize)]
30pub struct PendingQueries {
31 queries: IndexMap<QueryId, LlmQuery>,
33 responses: HashMap<QueryId, String>,
34}
35
36impl PendingQueries {
37 pub fn new(queries: Vec<LlmQuery>) -> Self {
38 let map = queries
39 .into_iter()
40 .map(|q| (q.id.clone(), q))
41 .collect::<IndexMap<_, _>>();
42 Self {
43 queries: map,
44 responses: HashMap::new(),
45 }
46 }
47
48 pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
50 if !self.queries.contains_key(id) {
51 return Err(FeedError::UnknownQuery(id.clone()));
52 }
53 if self.responses.contains_key(id) {
54 return Err(FeedError::AlreadyResponded(id.clone()));
55 }
56 self.responses.insert(id.clone(), response);
57 Ok(self.is_complete())
58 }
59
60 pub fn pending_queries(&self) -> Vec<&LlmQuery> {
61 self.queries
62 .values()
63 .filter(|q| !self.responses.contains_key(&q.id))
64 .collect()
65 }
66
67 pub fn remaining(&self) -> usize {
68 self.queries.len() - self.responses.len()
69 }
70
71 pub fn is_complete(&self) -> bool {
72 self.responses.len() == self.queries.len()
73 }
74
75 pub fn into_ordered_responses(self) -> Vec<String> {
78 self.queries
79 .keys()
80 .map(|id| {
81 self.responses.get(id).cloned().unwrap_or_default()
84 })
85 .collect()
86 }
87}
88
89pub enum ExecutionState {
90 Running,
91 Paused(PendingQueries),
93 Completed {
94 result: serde_json::Value,
95 },
96 Failed {
97 error: String,
98 },
99 Cancelled,
101}
102
103impl ExecutionState {
104 pub fn is_terminal(&self) -> bool {
105 matches!(
106 self,
107 Self::Completed { .. } | Self::Failed { .. } | Self::Cancelled
108 )
109 }
110
111 pub fn remaining(&self) -> usize {
113 match self {
114 Self::Paused(pending) => pending.remaining(),
115 _ => 0,
116 }
117 }
118
119 pub fn name(&self) -> &'static str {
121 match self {
122 Self::Running => "Running",
123 Self::Paused(_) => "Paused",
124 Self::Completed { .. } => "Completed",
125 Self::Failed { .. } => "Failed",
126 Self::Cancelled => "Cancelled",
127 }
128 }
129
130 pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
133 match self {
134 Self::Paused(pending) => pending.feed(id, response),
135 other => Err(TransitionError {
136 expected: "Paused",
137 actual: other.name(),
138 }
139 .into()),
140 }
141 }
142
143 pub fn take_responses(&mut self) -> Result<Vec<String>, TransitionError> {
146 match std::mem::replace(self, Self::Running) {
147 Self::Paused(pending) if pending.is_complete() => Ok(pending.into_ordered_responses()),
148 prev => {
149 let actual = prev.name();
150 *self = prev;
151 Err(TransitionError {
152 expected: "Paused(complete)",
153 actual,
154 })
155 }
156 }
157 }
158
159 pub fn complete(&mut self, result: serde_json::Value) -> Result<(), TransitionError> {
161 match self {
162 Self::Running => {
163 *self = Self::Completed { result };
164 Ok(())
165 }
166 other => Err(TransitionError {
167 expected: "Running",
168 actual: other.name(),
169 }),
170 }
171 }
172
173 pub fn fail(&mut self, error: String) -> Result<(), TransitionError> {
175 match self {
176 Self::Running => {
177 *self = Self::Failed { error };
178 Ok(())
179 }
180 other => Err(TransitionError {
181 expected: "Running",
182 actual: other.name(),
183 }),
184 }
185 }
186
187 pub fn pause(&mut self, queries: Vec<LlmQuery>) -> Result<(), TransitionError> {
189 match self {
190 Self::Running => {
191 *self = Self::Paused(PendingQueries::new(queries));
192 Ok(())
193 }
194 other => Err(TransitionError {
195 expected: "Running",
196 actual: other.name(),
197 }),
198 }
199 }
200
201 pub fn cancel(&mut self) -> Result<(), TransitionError> {
203 match self {
204 Self::Running | Self::Paused(_) => {
205 *self = Self::Cancelled;
206 Ok(())
207 }
208 other => Err(TransitionError {
209 expected: "Running or Paused",
210 actual: other.name(),
211 }),
212 }
213 }
214}
215
216pub enum ResumeOutcome {
218 Paused {
220 queries: Vec<LlmQuery>,
221 },
222 Completed {
223 result: serde_json::Value,
224 },
225 Failed {
226 error: String,
227 },
228 Cancelled,
230}
231
232#[derive(Debug)]
234pub enum TerminalState {
235 Completed { result: serde_json::Value },
236 Failed { error: String },
237 Cancelled,
238}
239
240impl TryFrom<ExecutionState> for TerminalState {
241 type Error = TransitionError;
242
243 fn try_from(state: ExecutionState) -> Result<Self, TransitionError> {
244 match state {
245 ExecutionState::Completed { result } => Ok(Self::Completed { result }),
246 ExecutionState::Failed { error } => Ok(Self::Failed { error }),
247 ExecutionState::Cancelled => Ok(Self::Cancelled),
248 other => Err(TransitionError {
249 expected: "Completed, Failed, or Cancelled",
250 actual: other.name(),
251 }),
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use serde_json::json;

    /// Builds a minimal batch query whose id and prompt are derived from `index`.
    fn make_query(index: usize) -> LlmQuery {
        let id = QueryId::batch(index);
        let prompt = format!("prompt-{index}");
        LlmQuery {
            id,
            prompt,
            system: None,
            max_tokens: 100,
        }
    }

    // ---- PendingQueries ----

    #[test]
    fn pending_queries_single_feed() {
        let mut batch = PendingQueries::new(vec![make_query(0)]);
        assert_eq!(batch.remaining(), 1);
        assert!(!batch.is_complete());

        let finished = batch
            .feed(&QueryId::batch(0), String::from("resp"))
            .unwrap();
        assert!(finished);
        assert_eq!(batch.remaining(), 0);
    }

    #[test]
    fn pending_queries_multi_feed_ordering() {
        let mut batch = PendingQueries::new(vec![make_query(0), make_query(1), make_query(2)]);

        // Responses arrive out of order; only the last one completes the batch.
        let after_first = batch
            .feed(&QueryId::batch(2), String::from("resp-2"))
            .unwrap();
        assert!(!after_first);
        let after_second = batch
            .feed(&QueryId::batch(0), String::from("resp-0"))
            .unwrap();
        assert!(!after_second);
        let after_third = batch
            .feed(&QueryId::batch(1), String::from("resp-1"))
            .unwrap();
        assert!(after_third);

        // The output follows query order, not arrival order.
        let ordered = batch.into_ordered_responses();
        assert_eq!(ordered, vec!["resp-0", "resp-1", "resp-2"]);
    }

    #[test]
    fn pending_queries_unknown_query_error() {
        let mut batch = PendingQueries::new(vec![make_query(0)]);
        let outcome = batch.feed(&QueryId::batch(99), String::from("resp"));
        assert!(matches!(outcome.unwrap_err(), FeedError::UnknownQuery(_)));
    }

    #[test]
    fn pending_queries_double_feed_error() {
        let mut batch = PendingQueries::new(vec![make_query(0)]);
        batch
            .feed(&QueryId::batch(0), String::from("resp"))
            .unwrap();
        let outcome = batch.feed(&QueryId::batch(0), String::from("resp2"));
        assert!(matches!(
            outcome.unwrap_err(),
            FeedError::AlreadyResponded(_)
        ));
    }

    #[test]
    fn pending_queries_pending_list() {
        let mut batch = PendingQueries::new(vec![make_query(0), make_query(1)]);
        assert_eq!(batch.pending_queries().len(), 2);

        batch
            .feed(&QueryId::batch(0), String::from("resp"))
            .unwrap();
        let unanswered = batch.pending_queries();
        assert_eq!(unanswered.len(), 1);
        assert_eq!(unanswered[0].id, QueryId::batch(1));
    }

    #[test]
    fn pending_queries_roundtrip_json() {
        let mut batch = PendingQueries::new(vec![make_query(0), make_query(1)]);
        batch
            .feed(&QueryId::batch(0), String::from("resp-0"))
            .unwrap();

        // Serialize a partially answered batch and read it back.
        let encoded = serde_json::to_value(&batch).unwrap();
        let decoded: PendingQueries = serde_json::from_value(encoded).unwrap();
        assert_eq!(decoded.remaining(), 1);
        assert_eq!(decoded.queries.len(), 2);
    }

    // ---- ExecutionState: valid transitions ----

    #[test]
    fn running_to_paused() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0)]).unwrap();
        assert_eq!(exec.name(), "Paused");
    }

    #[test]
    fn paused_feed_and_take() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0), make_query(1)]).unwrap();

        let after_first = exec.feed(&QueryId::batch(0), String::from("r0")).unwrap();
        assert!(!after_first);
        let after_second = exec.feed(&QueryId::batch(1), String::from("r1")).unwrap();
        assert!(after_second);

        let collected = exec.take_responses().unwrap();
        assert_eq!(collected, vec!["r0", "r1"]);
        assert_eq!(exec.name(), "Running");
    }

    #[test]
    fn take_responses_incomplete_fails() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0), make_query(1)]).unwrap();
        exec.feed(&QueryId::batch(0), String::from("r0")).unwrap();

        let rejected = exec.take_responses().unwrap_err();
        assert_eq!(rejected.actual, "Paused");
        // The failed attempt must leave the state where it was.
        assert_eq!(exec.name(), "Paused");
    }

    #[test]
    fn running_to_completed() {
        let mut exec = ExecutionState::Running;
        exec.complete(json!({"answer": 42})).unwrap();
        assert_eq!(exec.name(), "Completed");
        assert!(exec.is_terminal());
    }

    #[test]
    fn running_to_failed() {
        let mut exec = ExecutionState::Running;
        exec.fail(String::from("boom")).unwrap();
        assert_eq!(exec.name(), "Failed");
        assert!(exec.is_terminal());
    }

    #[test]
    fn cancel_from_running() {
        let mut exec = ExecutionState::Running;
        exec.cancel().unwrap();
        assert_eq!(exec.name(), "Cancelled");
        assert!(exec.is_terminal());
    }

    #[test]
    fn cancel_from_paused() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0)]).unwrap();
        exec.cancel().unwrap();
        assert_eq!(exec.name(), "Cancelled");
    }

    // ---- ExecutionState: remaining() ----

    #[test]
    fn remaining_running_is_zero() {
        let exec = ExecutionState::Running;
        assert_eq!(exec.remaining(), 0);
    }

    #[test]
    fn remaining_tracks_feeds() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0), make_query(1), make_query(2)])
            .unwrap();
        assert_eq!(exec.remaining(), 3);

        exec.feed(&QueryId::batch(0), String::from("r")).unwrap();
        assert_eq!(exec.remaining(), 2);

        exec.feed(&QueryId::batch(1), String::from("r")).unwrap();
        assert_eq!(exec.remaining(), 1);
    }

    #[test]
    fn remaining_terminal_is_zero() {
        let exec = ExecutionState::Completed {
            result: json!(null),
        };
        assert_eq!(exec.remaining(), 0);
    }

    // ---- ExecutionState: rejected transitions ----

    #[test]
    fn feed_on_running_fails() {
        let mut exec = ExecutionState::Running;
        let outcome = exec.feed(&QueryId::single(), String::from("r"));
        assert!(matches!(outcome.unwrap_err(), FeedError::InvalidState(_)));
    }

    #[test]
    fn pause_on_paused_fails() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0)]).unwrap();
        let rejected = exec.pause(vec![make_query(1)]).unwrap_err();
        assert_eq!(rejected.expected, "Running");
    }

    #[test]
    fn complete_on_paused_fails() {
        let mut exec = ExecutionState::Running;
        exec.pause(vec![make_query(0)]).unwrap();
        let rejected = exec.complete(json!(null)).unwrap_err();
        assert_eq!(rejected.expected, "Running");
    }

    #[test]
    fn cancel_on_completed_fails() {
        let mut exec = ExecutionState::Running;
        exec.complete(json!(null)).unwrap();
        let rejected = exec.cancel().unwrap_err();
        assert_eq!(rejected.expected, "Running or Paused");
    }

    #[test]
    fn cancel_on_failed_fails() {
        let mut exec = ExecutionState::Running;
        exec.fail(String::from("e")).unwrap();
        let rejected = exec.cancel().unwrap_err();
        assert_eq!(rejected.expected, "Running or Paused");
    }

    // ---- TerminalState conversion ----

    #[test]
    fn terminal_state_rejects_non_terminal() {
        let exec = ExecutionState::Running;
        let rejected = TerminalState::try_from(exec).unwrap_err();
        assert_eq!(rejected.actual, "Running");
    }

    #[test]
    fn terminal_state_from_completed() {
        let exec = ExecutionState::Completed { result: json!(42) };
        let converted = TerminalState::try_from(exec).unwrap();
        assert!(matches!(converted, TerminalState::Completed { .. }));
    }

    #[test]
    fn terminal_state_from_cancelled() {
        let exec = ExecutionState::Cancelled;
        let converted = TerminalState::try_from(exec).unwrap();
        assert!(matches!(converted, TerminalState::Cancelled));
    }
}
494
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use proptest::prelude::*;

    /// Builds a minimal batch query whose id and prompt are derived from `index`.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
        }
    }

    /// Strategy yielding a random permutation of `0..size` for a `size` in `1..8`.
    fn shuffled_indices() -> impl Strategy<Value = Vec<usize>> {
        (1usize..8).prop_flat_map(|size| Just((0..size).collect::<Vec<usize>>()).prop_shuffle())
    }

    proptest! {
        /// Whatever order the responses arrive in, the output follows query order.
        #[test]
        fn feed_order_independent(order in shuffled_indices()) {
            let size = order.len();
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for (step, &i) in order.iter().enumerate() {
                // Every feed must succeed, and only the last one completes the batch.
                let done = pq.feed(&QueryId::batch(i), format!("resp-{i}")).unwrap();
                prop_assert_eq!(done, step + 1 == size);
            }

            let responses = pq.into_ordered_responses();
            prop_assert_eq!(responses.len(), size);
            for (i, resp) in responses.iter().enumerate() {
                prop_assert_eq!(resp, &format!("resp-{i}"));
            }
        }

        /// Feeding the same id twice is always rejected as a duplicate.
        #[test]
        fn double_feed_always_errors(size in 1usize..8, target in 0usize..8) {
            // Fold the raw target into the valid id range for this batch size.
            let target = target % size;
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            pq.feed(&QueryId::batch(target), "first".into()).unwrap();
            let err = pq.feed(&QueryId::batch(target), "second".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::AlreadyResponded(_)));
        }

        /// An id outside the batch is always rejected as unknown.
        #[test]
        fn unknown_query_always_errors(size in 1usize..8, bad_id in 100usize..200) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            let err = pq.feed(&QueryId::batch(bad_id), "resp".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::UnknownQuery(_)));
        }

        /// `remaining` drops by exactly one per accepted response and ends at zero.
        #[test]
        fn remaining_decreases_monotonically(size in 1usize..10) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for i in 0..size {
                prop_assert_eq!(pq.remaining(), size - i);
                // A rejected feed would invalidate the count, so fail right here.
                pq.feed(&QueryId::batch(i), format!("r-{i}")).unwrap();
            }
            prop_assert_eq!(pq.remaining(), 0);
            prop_assert!(pq.is_complete());
        }
    }
}