1use std::collections::HashMap;
2
3use indexmap::IndexMap;
4use serde::{Deserialize, Serialize};
5
6use crate::query::{LlmQuery, QueryId};
7
8#[derive(Debug, thiserror::Error)]
9#[error("invalid state transition: expected {expected}, got {actual}")]
10pub struct TransitionError {
11 pub expected: &'static str,
12 pub actual: &'static str,
13}
14
/// Errors returned when feeding an LLM response into a pending batch.
#[derive(Debug, thiserror::Error)]
pub enum FeedError {
    /// The given id is not one of the queries in the pending batch.
    #[error("unknown query_id: {0}")]
    UnknownQuery(QueryId),
    /// A response for this id was already recorded; the earlier response is
    /// kept and the new one is rejected.
    #[error("already responded to query_id: {0}")]
    AlreadyResponded(QueryId),
    /// The execution was not in a state that accepts responses.
    #[error(transparent)]
    InvalidState(#[from] TransitionError),
}
24
/// A batch of LLM queries an execution is waiting on, together with the
/// responses received so far.
///
/// `queries` is an `IndexMap`, so it remembers the order in which queries were
/// supplied; responses are reported back in that order. Through `feed`,
/// `responses` only ever gains keys that already exist in `queries`.
///
/// NOTE(review): the derived `Deserialize` does not re-check that invariant,
/// so persisted state with extra or unknown response keys is accepted as-is.
#[derive(Debug, Serialize, Deserialize)]
pub struct PendingQueries {
    // Queries keyed by their id, in the order they were supplied.
    queries: IndexMap<QueryId, LlmQuery>,
    // Responses received so far, keyed by query id.
    responses: HashMap<QueryId, String>,
}
35
36impl PendingQueries {
37 pub fn new(queries: Vec<LlmQuery>) -> Self {
38 let map = queries
39 .into_iter()
40 .map(|q| (q.id.clone(), q))
41 .collect::<IndexMap<_, _>>();
42 Self {
43 queries: map,
44 responses: HashMap::new(),
45 }
46 }
47
48 pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
50 if !self.queries.contains_key(id) {
51 return Err(FeedError::UnknownQuery(id.clone()));
52 }
53 if self.responses.contains_key(id) {
54 return Err(FeedError::AlreadyResponded(id.clone()));
55 }
56 self.responses.insert(id.clone(), response);
57 Ok(self.is_complete())
58 }
59
60 pub fn pending_queries(&self) -> Vec<&LlmQuery> {
61 self.queries
62 .values()
63 .filter(|q| !self.responses.contains_key(&q.id))
64 .collect()
65 }
66
67 pub fn remaining(&self) -> usize {
68 self.queries.len() - self.responses.len()
69 }
70
71 pub fn is_complete(&self) -> bool {
72 self.responses.len() == self.queries.len()
73 }
74
75 pub fn into_ordered_responses(self) -> Vec<String> {
78 self.queries
79 .keys()
80 .map(|id| {
81 self.responses.get(id).cloned().unwrap_or_default()
84 })
85 .collect()
86 }
87}
88
89pub enum ExecutionState {
90 Running,
91 Paused(PendingQueries),
93 Completed {
94 result: serde_json::Value,
95 },
96 Failed {
97 error: String,
98 },
99 Cancelled,
101}
102
103impl ExecutionState {
104 pub fn is_terminal(&self) -> bool {
105 matches!(
106 self,
107 Self::Completed { .. } | Self::Failed { .. } | Self::Cancelled
108 )
109 }
110
111 pub fn remaining(&self) -> usize {
113 match self {
114 Self::Paused(pending) => pending.remaining(),
115 _ => 0,
116 }
117 }
118
119 pub fn name(&self) -> &'static str {
121 match self {
122 Self::Running => "Running",
123 Self::Paused(_) => "Paused",
124 Self::Completed { .. } => "Completed",
125 Self::Failed { .. } => "Failed",
126 Self::Cancelled => "Cancelled",
127 }
128 }
129
130 pub fn feed(&mut self, id: &QueryId, response: String) -> Result<bool, FeedError> {
133 match self {
134 Self::Paused(pending) => pending.feed(id, response),
135 other => Err(TransitionError {
136 expected: "Paused",
137 actual: other.name(),
138 }
139 .into()),
140 }
141 }
142
143 pub fn take_responses(&mut self) -> Result<Vec<String>, TransitionError> {
146 match std::mem::replace(self, Self::Running) {
147 Self::Paused(pending) if pending.is_complete() => Ok(pending.into_ordered_responses()),
148 prev => {
149 let actual = prev.name();
150 *self = prev;
151 Err(TransitionError {
152 expected: "Paused(complete)",
153 actual,
154 })
155 }
156 }
157 }
158
159 pub fn complete(&mut self, result: serde_json::Value) -> Result<(), TransitionError> {
161 match self {
162 Self::Running => {
163 *self = Self::Completed { result };
164 Ok(())
165 }
166 other => Err(TransitionError {
167 expected: "Running",
168 actual: other.name(),
169 }),
170 }
171 }
172
173 pub fn fail(&mut self, error: String) -> Result<(), TransitionError> {
175 match self {
176 Self::Running => {
177 *self = Self::Failed { error };
178 Ok(())
179 }
180 other => Err(TransitionError {
181 expected: "Running",
182 actual: other.name(),
183 }),
184 }
185 }
186
187 pub fn pause(&mut self, queries: Vec<LlmQuery>) -> Result<(), TransitionError> {
189 match self {
190 Self::Running => {
191 *self = Self::Paused(PendingQueries::new(queries));
192 Ok(())
193 }
194 other => Err(TransitionError {
195 expected: "Running",
196 actual: other.name(),
197 }),
198 }
199 }
200
201 pub fn cancel(&mut self) -> Result<(), TransitionError> {
203 match self {
204 Self::Running | Self::Paused(_) => {
205 *self = Self::Cancelled;
206 Ok(())
207 }
208 other => Err(TransitionError {
209 expected: "Running or Paused",
210 actual: other.name(),
211 }),
212 }
213 }
214}
215
216pub enum ResumeOutcome {
218 Paused {
220 queries: Vec<LlmQuery>,
221 },
222 Completed {
223 result: serde_json::Value,
224 },
225 Failed {
226 error: String,
227 },
228 Cancelled,
230}
231
/// A finished execution, obtained from an `ExecutionState` via `TryFrom`.
///
/// Mirrors the three terminal `ExecutionState` variants. Converting a
/// `Running` or `Paused` state fails with a `TransitionError`.
#[derive(Debug)]
pub enum TerminalState {
    // Finished successfully with a JSON result.
    Completed { result: serde_json::Value },
    // Finished with an error message.
    Failed { error: String },
    // Execution was cancelled.
    Cancelled,
}
239
240impl TryFrom<ExecutionState> for TerminalState {
241 type Error = TransitionError;
242
243 fn try_from(state: ExecutionState) -> Result<Self, TransitionError> {
244 match state {
245 ExecutionState::Completed { result } => Ok(Self::Completed { result }),
246 ExecutionState::Failed { error } => Ok(Self::Failed { error }),
247 ExecutionState::Cancelled => Ok(Self::Cancelled),
248 other => Err(TransitionError {
249 expected: "Completed, Failed, or Cancelled",
250 actual: other.name(),
251 }),
252 }
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use serde_json::json;

    /// Builds a minimal query whose id is `QueryId::batch(index)`.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    // ---- PendingQueries -------------------------------------------------

    #[test]
    fn pending_queries_single_feed() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        assert_eq!(pq.remaining(), 1);
        assert!(!pq.is_complete());

        let complete = pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        assert!(complete);
        assert_eq!(pq.remaining(), 0);
    }

    #[test]
    fn pending_queries_multi_feed_ordering() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1), make_query(2)]);

        // Feed out of order; output must still follow query order.
        assert!(!pq.feed(&QueryId::batch(2), "resp-2".into()).unwrap());
        assert!(!pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap());
        assert!(pq.feed(&QueryId::batch(1), "resp-1".into()).unwrap());

        let responses = pq.into_ordered_responses();
        assert_eq!(responses, vec!["resp-0", "resp-1", "resp-2"]);
    }

    #[test]
    fn pending_queries_unknown_query_error() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        let err = pq.feed(&QueryId::batch(99), "resp".into()).unwrap_err();
        assert!(matches!(err, FeedError::UnknownQuery(_)));
    }

    #[test]
    fn pending_queries_double_feed_error() {
        let mut pq = PendingQueries::new(vec![make_query(0)]);
        pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        let err = pq.feed(&QueryId::batch(0), "resp2".into()).unwrap_err();
        assert!(matches!(err, FeedError::AlreadyResponded(_)));
    }

    #[test]
    fn pending_queries_pending_list() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        assert_eq!(pq.pending_queries().len(), 2);

        pq.feed(&QueryId::batch(0), "resp".into()).unwrap();
        let pending = pq.pending_queries();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].id, QueryId::batch(1));
    }

    #[test]
    fn pending_queries_roundtrip_json() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap();

        let json = serde_json::to_value(&pq).unwrap();
        let restored: PendingQueries = serde_json::from_value(json).unwrap();
        assert_eq!(restored.remaining(), 1);
        assert_eq!(restored.queries.len(), 2);
    }

    #[test]
    fn pending_queries_roundtrip_preserves_responses() {
        let mut pq = PendingQueries::new(vec![make_query(0), make_query(1)]);
        pq.feed(&QueryId::batch(0), "resp-0".into()).unwrap();

        let json = serde_json::to_value(&pq).unwrap();
        let mut restored: PendingQueries = serde_json::from_value(json).unwrap();

        // The recorded response survives: re-feeding it is rejected...
        let err = restored.feed(&QueryId::batch(0), "again".into()).unwrap_err();
        assert!(matches!(err, FeedError::AlreadyResponded(_)));
        // ...and it comes back in order once the batch is finished.
        assert!(restored.feed(&QueryId::batch(1), "resp-1".into()).unwrap());
        assert_eq!(restored.into_ordered_responses(), vec!["resp-0", "resp-1"]);
    }

    // ---- ExecutionState transitions --------------------------------------

    #[test]
    fn running_to_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        assert_eq!(state.name(), "Paused");
    }

    #[test]
    fn paused_feed_and_take() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0), make_query(1)]).unwrap();

        assert!(!state.feed(&QueryId::batch(0), "r0".into()).unwrap());
        assert!(state.feed(&QueryId::batch(1), "r1".into()).unwrap());

        let responses = state.take_responses().unwrap();
        assert_eq!(responses, vec!["r0", "r1"]);
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn take_responses_incomplete_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0), make_query(1)]).unwrap();
        state.feed(&QueryId::batch(0), "r0".into()).unwrap();

        let err = state.take_responses().unwrap_err();
        assert_eq!(err.actual, "Paused");
        // A failed take must leave the paused batch in place.
        assert_eq!(state.name(), "Paused");
    }

    #[test]
    fn take_responses_on_running_fails_and_keeps_state() {
        let mut state = ExecutionState::Running;
        let err = state.take_responses().unwrap_err();
        assert_eq!(err.expected, "Paused(complete)");
        assert_eq!(err.actual, "Running");
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn pause_with_no_queries_is_immediately_complete() {
        let mut state = ExecutionState::Running;
        state.pause(Vec::new()).unwrap();
        assert_eq!(state.name(), "Paused");
        assert_eq!(state.remaining(), 0);

        assert_eq!(state.take_responses().unwrap(), Vec::<String>::new());
        assert_eq!(state.name(), "Running");
    }

    #[test]
    fn running_to_completed() {
        let mut state = ExecutionState::Running;
        state.complete(json!({"answer": 42})).unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Completed");
    }

    #[test]
    fn running_to_failed() {
        let mut state = ExecutionState::Running;
        state.fail("boom".into()).unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Failed");
    }

    #[test]
    fn cancel_from_running() {
        let mut state = ExecutionState::Running;
        state.cancel().unwrap();
        assert!(state.is_terminal());
        assert_eq!(state.name(), "Cancelled");
    }

    #[test]
    fn cancel_from_paused() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        state.cancel().unwrap();
        assert_eq!(state.name(), "Cancelled");
    }

    // ---- remaining() -----------------------------------------------------

    #[test]
    fn remaining_running_is_zero() {
        let state = ExecutionState::Running;
        assert_eq!(state.remaining(), 0);
    }

    #[test]
    fn remaining_tracks_feeds() {
        let mut state = ExecutionState::Running;
        state
            .pause(vec![make_query(0), make_query(1), make_query(2)])
            .unwrap();
        assert_eq!(state.remaining(), 3);

        state.feed(&QueryId::batch(0), "r".into()).unwrap();
        assert_eq!(state.remaining(), 2);

        state.feed(&QueryId::batch(1), "r".into()).unwrap();
        assert_eq!(state.remaining(), 1);
    }

    #[test]
    fn remaining_terminal_is_zero() {
        let state = ExecutionState::Completed {
            result: json!(null),
        };
        assert_eq!(state.remaining(), 0);
    }

    // ---- invalid transitions ---------------------------------------------

    #[test]
    fn feed_on_running_fails() {
        let mut state = ExecutionState::Running;
        let err = state.feed(&QueryId::single(), "r".into()).unwrap_err();
        assert!(matches!(err, FeedError::InvalidState(_)));
    }

    #[test]
    fn feed_on_cancelled_fails() {
        let mut state = ExecutionState::Cancelled;
        let err = state.feed(&QueryId::single(), "r".into()).unwrap_err();
        assert!(matches!(
            err,
            FeedError::InvalidState(TransitionError {
                actual: "Cancelled",
                ..
            })
        ));
    }

    #[test]
    fn pause_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.pause(vec![make_query(1)]).unwrap_err();
        assert_eq!(err.expected, "Running");
    }

    #[test]
    fn complete_on_paused_fails() {
        let mut state = ExecutionState::Running;
        state.pause(vec![make_query(0)]).unwrap();
        let err = state.complete(json!(null)).unwrap_err();
        assert_eq!(err.expected, "Running");
    }

    #[test]
    fn complete_on_failed_fails_and_keeps_state() {
        let mut state = ExecutionState::Running;
        state.fail("e".into()).unwrap();
        let err = state.complete(json!(null)).unwrap_err();
        assert_eq!(err.actual, "Failed");
        assert_eq!(state.name(), "Failed");
    }

    #[test]
    fn cancel_on_completed_fails() {
        let mut state = ExecutionState::Running;
        state.complete(json!(null)).unwrap();
        let err = state.cancel().unwrap_err();
        assert_eq!(err.expected, "Running or Paused");
    }

    #[test]
    fn cancel_on_failed_fails() {
        let mut state = ExecutionState::Running;
        state.fail("e".into()).unwrap();
        let err = state.cancel().unwrap_err();
        assert_eq!(err.expected, "Running or Paused");
    }

    // ---- TerminalState conversion ----------------------------------------

    #[test]
    fn terminal_state_rejects_non_terminal() {
        let state = ExecutionState::Running;
        let err = TerminalState::try_from(state).unwrap_err();
        assert_eq!(err.actual, "Running");
    }

    #[test]
    fn terminal_state_from_completed() {
        let state = ExecutionState::Completed { result: json!(42) };
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Completed { .. }));
    }

    #[test]
    fn terminal_state_from_failed() {
        let state = ExecutionState::Failed {
            error: "boom".into(),
        };
        match TerminalState::try_from(state).unwrap() {
            TerminalState::Failed { error } => assert_eq!(error, "boom"),
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[test]
    fn terminal_state_from_cancelled() {
        let state = ExecutionState::Cancelled;
        let terminal = TerminalState::try_from(state).unwrap();
        assert!(matches!(terminal, TerminalState::Cancelled));
    }
}
496
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::query::{LlmQuery, QueryId};
    use proptest::prelude::*;

    /// Builds a minimal query whose id is `QueryId::batch(index)`.
    fn make_query(index: usize) -> LlmQuery {
        LlmQuery {
            id: QueryId::batch(index),
            prompt: format!("prompt-{index}"),
            system: None,
            max_tokens: 100,
            grounded: false,
            underspecified: false,
        }
    }

    /// Strategy yielding a random permutation of `0..size` for `size` in
    /// `1..8`, i.e. an arbitrary order in which to feed a batch's responses.
    fn feed_permutation() -> impl Strategy<Value = Vec<usize>> {
        (1usize..8).prop_flat_map(|size| {
            let identity: Vec<usize> = (0..size).collect();
            Just(identity).prop_shuffle()
        })
    }

    proptest! {
        // Whatever order responses arrive in, only the final feed reports
        // completion, and the output follows the original query order.
        #[test]
        fn feed_order_independent(order in feed_permutation()) {
            let size = order.len();
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for (fed, &i) in order.iter().enumerate() {
                let outcome = pq.feed(&QueryId::batch(i), format!("resp-{i}"));
                prop_assert!(outcome.is_ok(), "feed of batch({}) failed: {:?}", i, outcome);
                prop_assert_eq!(outcome.unwrap(), fed + 1 == size);
            }

            let responses = pq.into_ordered_responses();
            prop_assert_eq!(responses.len(), size);
            for (i, resp) in responses.iter().enumerate() {
                prop_assert_eq!(resp, &format!("resp-{i}"));
            }
        }

        // A second feed for the same id is always rejected and does not
        // change how many queries remain.
        #[test]
        fn double_feed_always_errors(size in 1usize..8, target in 0usize..8) {
            let target = target % size;
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            pq.feed(&QueryId::batch(target), "first".into()).unwrap();
            let err = pq.feed(&QueryId::batch(target), "second".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::AlreadyResponded(_)));
            prop_assert_eq!(pq.remaining(), size - 1);
        }

        // Ids outside the batch are always rejected without side effects.
        #[test]
        fn unknown_query_always_errors(size in 1usize..8, bad_id in 100usize..200) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            let err = pq.feed(&QueryId::batch(bad_id), "resp".into()).unwrap_err();
            prop_assert!(matches!(err, FeedError::UnknownQuery(_)));
            prop_assert_eq!(pq.remaining(), size);
        }

        // Each successful feed lowers `remaining` by exactly one.
        #[test]
        fn remaining_decreases_monotonically(size in 1usize..10) {
            let queries: Vec<LlmQuery> = (0..size).map(make_query).collect();
            let mut pq = PendingQueries::new(queries);

            for i in 0..size {
                prop_assert_eq!(pq.remaining(), size - i);
                let outcome = pq.feed(&QueryId::batch(i), format!("r-{i}"));
                prop_assert!(outcome.is_ok(), "feed of batch({}) failed: {:?}", i, outcome);
            }
            prop_assert_eq!(pq.remaining(), 0);
            prop_assert!(pq.is_complete());
        }
    }
}