awaken_contract/contract/executor.rs

//! LLM executor trait and tool execution strategy.

use std::time::Duration;

use super::content::ContentBlock;
use super::inference::{InferenceOverride, StreamResult};
use super::message::{Message, ToolCall};
use super::tool::ToolDescriptor;
use async_trait::async_trait;
use thiserror::Error;

/// A provider-neutral LLM inference request.
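///
/// # Example
///
/// A minimal sketch of building a request with no tools and no system prompt;
/// the crate path and model name here are illustrative assumptions:
///
/// ```
/// use awaken_contract::contract::executor::InferenceRequest;
/// use awaken_contract::contract::message::Message;
///
/// let request = InferenceRequest {
///     upstream_model: "example-model".into(), // assumed model name
///     messages: vec![Message::user("hi")],
///     tools: vec![],
///     system: vec![],
///     overrides: None,
///     enable_prompt_cache: false,
/// };
/// assert_eq!(request.messages.len(), 1);
/// ```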
#[derive(Debug, Clone)]
pub struct InferenceRequest {
    /// Effective upstream model name sent to the resolved provider executor.
    pub upstream_model: String,
    /// Messages to send.
    pub messages: Vec<Message>,
    /// Available tools.
    pub tools: Vec<ToolDescriptor>,
    /// System prompt content blocks. Empty means no system prompt.
    pub system: Vec<ContentBlock>,
    /// Per-inference overrides that remain after runtime routing is applied
    /// (temperature, max_tokens, fallback upstream models, etc).
    pub overrides: Option<InferenceOverride>,
    /// Whether to apply prompt cache hints (e.g. `CacheControl::Ephemeral`) to system messages.
    pub enable_prompt_cache: bool,
}

/// Cause of a mid-stream interruption.
#[derive(Debug, Clone)]
pub enum InterruptCause {
    /// Underlying socket reset (TCP RST, ECONNRESET) while receiving events.
    ConnectionReset,
    /// No delta received within the configured idle window.
    IdleStall,
    /// HTTP/2 GOAWAY or equivalent server-initiated disconnect.
    GoAway,
    /// Provider returned a 5xx status after headers had been sent.
    Provider5xxMidStream(u16),
    /// Synthetic cause used when a stream is being resumed from a
    /// persisted checkpoint (no real interruption happened in this
    /// process — the previous process crashed or restarted).
    ResumedFromCheckpoint,
}

impl std::fmt::Display for InterruptCause {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConnectionReset => f.write_str("connection reset"),
            Self::IdleStall => f.write_str("idle stall"),
            Self::GoAway => f.write_str("goaway"),
            Self::Provider5xxMidStream(s) => write!(f, "provider {s} mid-stream"),
            Self::ResumedFromCheckpoint => f.write_str("resumed from checkpoint"),
        }
    }
}

/// A tool_use block observed mid-stream whose argument JSON did not close.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct InFlightTool {
    pub id: String,
    pub name: String,
    /// Raw accumulated argument JSON fragment (unparseable as-is).
    pub partial_args: String,
}

/// Snapshot of everything a `StreamCollector` had accumulated at the moment
/// the stream was interrupted. Used by the loop runner to pick a
/// [`RecoveryPlan`](crate::contract::executor::RecoveryPlan).
#[derive(Debug, Clone)]
pub struct InterruptSnapshot {
    /// Assistant text accumulated before the interruption. `None` if no text
    /// was received.
    pub text: Option<String>,
    /// Tool calls whose argument JSON parsed successfully before interruption.
    pub completed_tool_calls: Vec<ToolCall>,
    /// The tool_use block (if any) that was open but not yet closed.
    pub in_flight_tool: Option<InFlightTool>,
    /// Total bytes of events processed (telemetry).
    pub bytes_received: usize,
}

/// Chosen recovery path for a mid-stream interruption. Computed by
/// [`InterruptSnapshot::plan`].
#[derive(Debug, Clone)]
pub enum RecoveryPlan {
    /// Only text accumulated. Retry the whole request with the accumulated
    /// text injected as an assistant prefix followed by a continuation
    /// prompt.
    ContinueText { assistant_prefix: String },
    /// At least one tool_use arrived intact. Synthesize a
    /// `StopReason::ToolUse` terminal state so the loop runner executes the
    /// completed tools. Any in-flight tool is surfaced as a hint for the
    /// next user message.
    SynthesizeToolUse {
        completed: Vec<ToolCall>,
        cancelled_tool_hint: Option<InFlightTool>,
    },
    /// There was text plus a single unclosed tool_use. Truncate before the
    /// tool, emit a cancel event for consumers, then continue with the text
    /// prefix.
    TruncateBeforeTool {
        assistant_prefix: String,
        cancelled_tool_id: String,
        cancelled_tool_name: String,
    },
    /// Nothing salvageable: retry the entire request fresh.
    WholeRestart,
}

impl InterruptSnapshot {
    /// Build an `InterruptSnapshot` from an iterator of `(id, name, args_json)`
    /// triples in declaration order, plus the accumulated text.
    ///
    /// Tools whose `name` is empty or whose `args_json` does not parse as
    /// JSON become the `in_flight_tool` (last-write-wins if multiple);
    /// the rest land in `completed_tool_calls`. This is the single source
    /// of truth for the partials → snapshot translation; the various
    /// stream-collector implementations across the runtime delegate to it
    /// instead of reimplementing the logic.
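    ///
    /// # Example
    ///
    /// A minimal sketch: one tool call whose argument JSON closed and one
    /// whose JSON was cut off mid-stream (the crate path is assumed):
    ///
    /// ```
    /// use awaken_contract::contract::executor::InterruptSnapshot;
    ///
    /// let snapshot = InterruptSnapshot::from_partials(
    ///     Some("Let me look that up.".to_string()),
    ///     vec![
    ///         ("c1".to_string(), "search".to_string(), r#"{"q":"rust"}"#.to_string()),
    ///         ("c2".to_string(), "fetch".to_string(), r#"{"url":"#.to_string()),
    ///     ],
    ///     128,
    /// );
    /// assert_eq!(snapshot.completed_tool_calls.len(), 1);
    /// assert_eq!(snapshot.in_flight_tool.as_ref().unwrap().name, "fetch");
    /// ```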
    pub fn from_partials<I>(text: Option<String>, partials: I, bytes_received: usize) -> Self
    where
        I: IntoIterator<Item = (String, String, String)>,
    {
        let mut completed: Vec<ToolCall> = Vec::new();
        let mut in_flight: Option<InFlightTool> = None;

        for (id, name, args_json) in partials {
            if name.is_empty() {
                in_flight = Some(InFlightTool {
                    id,
                    name: String::new(),
                    partial_args: args_json,
                });
                continue;
            }
            // A call is complete only when its fragment parses to a non-null JSON
            // value; an unparseable fragment or a bare `null` stays in flight.
            match serde_json::from_str::<serde_json::Value>(&args_json) {
                Ok(arguments) if !(arguments.is_null() && !args_json.is_empty()) => {
                    completed.push(ToolCall::new(id, name, arguments));
                }
                _ => {
                    in_flight = Some(InFlightTool {
                        id,
                        name,
                        partial_args: args_json,
                    });
                }
            }
        }

        Self {
            text,
            completed_tool_calls: completed,
            in_flight_tool: in_flight,
            bytes_received,
        }
    }

    /// Decide which recovery plan applies to this snapshot.
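    ///
    /// For instance, a text-only snapshot resolves to
    /// [`RecoveryPlan::ContinueText`]; a minimal sketch (crate path assumed):
    ///
    /// ```
    /// use awaken_contract::contract::executor::{InterruptSnapshot, RecoveryPlan};
    ///
    /// let snapshot =
    ///     InterruptSnapshot::from_partials(Some("partial answer".to_string()), vec![], 14);
    /// assert!(matches!(snapshot.plan(), RecoveryPlan::ContinueText { .. }));
    /// ```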
    pub fn plan(&self) -> RecoveryPlan {
        let text = self.text.as_deref().unwrap_or("");
        let has_text = !text.is_empty();
        let has_completed = !self.completed_tool_calls.is_empty();

        // R2: any completed tool → synthesize ToolUse regardless of text/in-flight.
        if has_completed {
            return RecoveryPlan::SynthesizeToolUse {
                completed: self.completed_tool_calls.clone(),
                cancelled_tool_hint: self.in_flight_tool.clone(),
            };
        }

        // R3: text with an in-flight tool → truncate to the text prefix.
        if has_text {
            if let Some(p) = &self.in_flight_tool {
                return RecoveryPlan::TruncateBeforeTool {
                    assistant_prefix: text.to_string(),
                    cancelled_tool_id: p.id.clone(),
                    cancelled_tool_name: p.name.clone(),
                };
            }
            // R1: text only.
            return RecoveryPlan::ContinueText {
                assistant_prefix: text.to_string(),
            };
        }

        // R4: nothing usable (no text and no completed tools).
        RecoveryPlan::WholeRestart
    }
}

/// Errors from LLM inference.
///
/// Variants split into three recoverability classes:
/// - **Transient** (retryable, count toward circuit breaker): `RateLimited`,
///   `Overloaded`, `Timeout`, `Provider`, `StreamInterrupted`.
/// - **Permanent** (not retryable, do NOT count toward circuit breaker):
///   `ContextOverflow`, `InvalidRequest`, `Unauthorized`, `ModelNotFound`,
///   `ContentFiltered`.
/// - **Fail-fast**: `AllModelsUnavailable`, `Cancelled`.
///
/// Use [`InferenceExecutionError::is_retryable`] and
/// [`InferenceExecutionError::counts_toward_circuit_breaker`] for policy
/// decisions instead of pattern-matching variants directly where possible.
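///
/// A minimal sketch of the classification in use (crate path assumed):
///
/// ```
/// use awaken_contract::contract::executor::InferenceExecutionError;
///
/// // Transient: retried, and counted by the per-model circuit breaker.
/// let rate_limited = InferenceExecutionError::rate_limited("429");
/// assert!(rate_limited.is_retryable());
/// assert!(rate_limited.counts_toward_circuit_breaker());
///
/// // Permanent: surfaced immediately, breaker left untouched.
/// let unauthorized = InferenceExecutionError::Unauthorized("bad key".into());
/// assert!(!unauthorized.is_retryable());
/// assert!(!unauthorized.counts_toward_circuit_breaker());
/// ```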
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum InferenceExecutionError {
    #[error("provider error: {0}")]
    Provider(String),
    #[error("rate limited: {message}")]
    RateLimited {
        message: String,
        /// Duration from the provider's `Retry-After` header, if any.
        retry_after: Option<Duration>,
    },
    #[error("provider overloaded: {message}")]
    Overloaded {
        message: String,
        retry_after: Option<Duration>,
    },
    #[error("timeout: {0}")]
    Timeout(String),
    #[error("stream interrupted ({cause})")]
    StreamInterrupted {
        cause: InterruptCause,
        snapshot: Box<InterruptSnapshot>,
    },
    #[error("context overflow: {0}")]
    ContextOverflow(String),
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    #[error("unauthorized: {0}")]
    Unauthorized(String),
    #[error("model not found: {0}")]
    ModelNotFound(String),
    #[error("content filtered: {0}")]
    ContentFiltered(String),
    #[error("all models unavailable (circuit breakers open)")]
    AllModelsUnavailable,
    #[error("cancelled")]
    Cancelled,
}

impl InferenceExecutionError {
    /// Short constructor for a rate-limit error with no `Retry-After`.
    pub fn rate_limited(message: impl Into<String>) -> Self {
        Self::RateLimited {
            message: message.into(),
            retry_after: None,
        }
    }

    /// Short constructor for an overloaded error with no `Retry-After`.
    pub fn overloaded(message: impl Into<String>) -> Self {
        Self::Overloaded {
            message: message.into(),
            retry_after: None,
        }
    }

    /// Whether the retry subsystem should try this request again.
    ///
    /// Transient errors return `true`; permanent and fail-fast errors
    /// (including `Cancelled`) return `false`.
    pub fn is_retryable(&self) -> bool {
        matches!(
            self,
            Self::Provider(_)
                | Self::RateLimited { .. }
                | Self::Overloaded { .. }
                | Self::Timeout(_)
                | Self::StreamInterrupted { .. }
        )
    }

    /// Whether this failure should increment the per-model circuit-breaker
    /// failure counter. Permanent errors (bad auth, bad schema, context
    /// overflow) must not trip the breaker — they would have failed with the
    /// same error on any model.
    pub fn counts_toward_circuit_breaker(&self) -> bool {
        self.is_retryable()
    }

    /// If this error carries a `Retry-After` hint from the provider, return it.
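    ///
    /// A minimal sketch (crate path assumed); the hint only exists on the
    /// `RateLimited` and `Overloaded` variants:
    ///
    /// ```
    /// use std::time::Duration;
    /// use awaken_contract::contract::executor::InferenceExecutionError;
    ///
    /// let err = InferenceExecutionError::RateLimited {
    ///     message: "429".into(),
    ///     retry_after: Some(Duration::from_secs(30)),
    /// };
    /// assert_eq!(err.retry_after(), Some(Duration::from_secs(30)));
    /// assert_eq!(InferenceExecutionError::Cancelled.retry_after(), None);
    /// ```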
    pub fn retry_after(&self) -> Option<Duration> {
        match self {
            Self::RateLimited { retry_after, .. } | Self::Overloaded { retry_after, .. } => {
                *retry_after
            }
            _ => None,
        }
    }
}

/// A token-level streaming event from the LLM.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// Incremental text content.
    TextDelta(String),
    /// Incremental reasoning/thinking content.
    ReasoningDelta(String),
    /// A tool use block started.
    ToolCallStart { id: String, name: String },
    /// Incremental tool call argument JSON.
    ToolCallDelta { id: String, args_delta: String },
    /// A content block finished.
    ContentBlockStop,
    /// Token usage data (typically sent once at the end).
    Usage(super::inference::TokenUsage),
    /// Stop reason (end of stream).
    Stop(super::inference::StopReason),
}

/// A boxed stream of `LlmStreamEvent`s.
///
/// Implementors wrap their provider-specific streaming response into this type.
/// The loop runner consumes events, emits deltas via `EventSink`, and collects
/// the final `StreamResult`.
pub type InferenceStream = std::pin::Pin<
    Box<dyn futures::Stream<Item = Result<LlmStreamEvent, InferenceExecutionError>> + Send>,
>;

/// Abstraction over LLM inference backends.
///
/// Providers implement `execute` (collected) and optionally `execute_stream` (streaming).
/// The loop runner prefers `execute_stream` when available.
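///
/// # Example
///
/// A minimal, non-streaming implementor returning a canned reply; a sketch
/// only, with the crate path assumed:
///
/// ```
/// use async_trait::async_trait;
/// use awaken_contract::contract::content::ContentBlock;
/// use awaken_contract::contract::executor::{
///     InferenceExecutionError, InferenceRequest, LlmExecutor,
/// };
/// use awaken_contract::contract::inference::{StopReason, StreamResult};
///
/// struct CannedLlm;
///
/// #[async_trait]
/// impl LlmExecutor for CannedLlm {
///     async fn execute(
///         &self,
///         _request: InferenceRequest,
///     ) -> Result<StreamResult, InferenceExecutionError> {
///         Ok(StreamResult {
///             content: vec![ContentBlock::text("canned reply")],
///             tool_calls: vec![],
///             usage: None,
///             stop_reason: Some(StopReason::EndTurn),
///             has_incomplete_tool_calls: false,
///         })
///     }
///
///     fn name(&self) -> &str {
///         "canned"
///     }
/// }
/// ```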
#[async_trait]
pub trait LlmExecutor: Send + Sync {
    /// Execute a chat completion and return the collected result.
    async fn execute(
        &self,
        request: InferenceRequest,
    ) -> Result<StreamResult, InferenceExecutionError>;

    /// Execute a chat completion as a token stream.
    ///
    /// The default implementation calls `execute()` and replays the collected
    /// result as a short stream of events via [`collected_to_stream_events`].
    /// Override to provide true token-level streaming from the LLM provider.
    fn execute_stream(
        &self,
        request: InferenceRequest,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async move {
            let result = self.execute(request).await?;
            let events = collected_to_stream_events(result);
            Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
        })
    }

    /// Provider name for logging/debugging.
    fn name(&self) -> &str;
}

/// Convert a collected `StreamResult` into a sequence of `LlmStreamEvent`s.
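///
/// For example, a collected result with one text block and a stop reason
/// becomes a `TextDelta` followed by a `Stop`; a sketch with the crate path
/// assumed:
///
/// ```
/// use awaken_contract::contract::content::ContentBlock;
/// use awaken_contract::contract::executor::{collected_to_stream_events, LlmStreamEvent};
/// use awaken_contract::contract::inference::{StopReason, StreamResult};
///
/// let result = StreamResult {
///     content: vec![ContentBlock::text("hi")],
///     tool_calls: vec![],
///     usage: None,
///     stop_reason: Some(StopReason::EndTurn),
///     has_incomplete_tool_calls: false,
/// };
/// let events = collected_to_stream_events(result);
/// assert_eq!(events.len(), 2);
/// assert!(matches!(events[0], Ok(LlmStreamEvent::TextDelta(_))));
/// assert!(matches!(events[1], Ok(LlmStreamEvent::Stop(StopReason::EndTurn))));
/// ```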
pub fn collected_to_stream_events(
    result: StreamResult,
) -> Vec<Result<LlmStreamEvent, InferenceExecutionError>> {
    use super::content::ContentBlock;
    let mut events = Vec::new();

    // Emit text/thinking deltas from content blocks
    for block in &result.content {
        match block {
            ContentBlock::Text { text } if !text.is_empty() => {
                events.push(Ok(LlmStreamEvent::TextDelta(text.clone())));
            }
            ContentBlock::Thinking { thinking } if !thinking.is_empty() => {
                events.push(Ok(LlmStreamEvent::ReasoningDelta(thinking.clone())));
            }
            _ => {}
        }
    }

    // Emit tool calls
    for call in &result.tool_calls {
        events.push(Ok(LlmStreamEvent::ToolCallStart {
            id: call.id.clone(),
            name: call.name.clone(),
        }));
        let args = serde_json::to_string(&call.arguments).unwrap_or_default();
        if !args.is_empty() {
            events.push(Ok(LlmStreamEvent::ToolCallDelta {
                id: call.id.clone(),
                args_delta: args,
            }));
        }
    }

    // Emit usage
    if let Some(usage) = result.usage {
        events.push(Ok(LlmStreamEvent::Usage(usage)));
    }

    // Emit stop reason
    if let Some(stop) = result.stop_reason {
        events.push(Ok(LlmStreamEvent::Stop(stop)));
    }

    events
}

/// Tool execution strategy.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum ToolExecutionMode {
    /// Execute tool calls one at a time.
    #[default]
    Sequential,
    /// Execute all tool calls concurrently, batch approval gate.
    ParallelBatchApproval,
    /// Execute all tool calls concurrently, streaming results.
    ParallelStreaming,
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::contract::inference::{StopReason, TokenUsage};
    use crate::contract::message::ToolCall;
    use crate::contract::tool::ToolDescriptor;
    use serde_json::json;

    /// A mock LLM executor for testing.
    struct MockLlm {
        response_text: String,
        tool_calls: Vec<ToolCall>,
    }

    #[async_trait]
    impl LlmExecutor for MockLlm {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            Ok(StreamResult {
                content: if self.response_text.is_empty() {
                    vec![]
                } else {
                    vec![ContentBlock::text(self.response_text.clone())]
                },
                tool_calls: self.tool_calls.clone(),
                usage: Some(TokenUsage {
                    prompt_tokens: Some(100),
                    completion_tokens: Some(50),
                    total_tokens: Some(150),
                    ..Default::default()
                }),
                stop_reason: if self.tool_calls.is_empty() {
                    Some(StopReason::EndTurn)
                } else {
                    Some(StopReason::ToolUse)
                },
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    #[tokio::test]
    async fn mock_llm_returns_text() {
        let llm = MockLlm {
            response_text: "Hello!".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("hi")],
            tools: vec![],
            system: vec![],
            overrides: None,
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert_eq!(result.text(), "Hello!");
        assert!(!result.needs_tools());
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    #[tokio::test]
    async fn mock_llm_returns_tool_calls() {
        let llm = MockLlm {
            response_text: String::new(),
            tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("search for rust")],
            tools: vec![ToolDescriptor::new("search", "search", "Web search")],
            system: vec![ContentBlock::text("You are helpful.")],
            overrides: None,
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert!(result.needs_tools());
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].name, "search");
        assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
    }

    #[tokio::test]
    async fn mock_llm_with_overrides() {
        let llm = MockLlm {
            response_text: "ok".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "base-model".into(),
            messages: vec![],
            tools: vec![],
            system: vec![],
            overrides: Some(InferenceOverride {
                temperature: Some(0.7),
                ..Default::default()
            }),
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert_eq!(result.text(), "ok");
    }

    #[test]
    fn llm_executor_name_is_exposed() {
        let llm = MockLlm {
            response_text: String::new(),
            tool_calls: vec![],
        };

        assert_eq!(llm.name(), "mock");
    }

    #[test]
    fn tool_execution_mode_default_is_sequential() {
        assert_eq!(ToolExecutionMode::default(), ToolExecutionMode::Sequential);
    }

    #[test]
    fn inference_execution_error_display_strings_are_stable() {
        assert_eq!(
            InferenceExecutionError::Provider("provider failed".into()).to_string(),
            "provider error: provider failed"
        );
        assert_eq!(
            InferenceExecutionError::rate_limited("too many requests").to_string(),
            "rate limited: too many requests"
        );
        assert_eq!(
            InferenceExecutionError::overloaded("server overloaded").to_string(),
            "provider overloaded: server overloaded"
        );
        assert_eq!(
            InferenceExecutionError::Timeout("slow backend".into()).to_string(),
            "timeout: slow backend"
        );
        assert_eq!(
            InferenceExecutionError::ContextOverflow("prompt too long".into()).to_string(),
            "context overflow: prompt too long"
        );
        assert_eq!(
            InferenceExecutionError::InvalidRequest("bad schema".into()).to_string(),
            "invalid request: bad schema"
        );
        assert_eq!(
            InferenceExecutionError::Unauthorized("bad key".into()).to_string(),
            "unauthorized: bad key"
        );
        assert_eq!(
            InferenceExecutionError::ModelNotFound("no such model".into()).to_string(),
            "model not found: no such model"
        );
        assert_eq!(
            InferenceExecutionError::AllModelsUnavailable.to_string(),
            "all models unavailable (circuit breakers open)"
        );
        assert_eq!(InferenceExecutionError::Cancelled.to_string(), "cancelled");

        let stream_err = InferenceExecutionError::StreamInterrupted {
            cause: InterruptCause::ConnectionReset,
            snapshot: Box::new(InterruptSnapshot {
                text: None,
                completed_tool_calls: vec![],
                in_flight_tool: None,
                bytes_received: 0,
            }),
        };
        assert_eq!(
            stream_err.to_string(),
            "stream interrupted (connection reset)"
        );
    }

    #[test]
    fn is_retryable_partitions_variants() {
        use InferenceExecutionError::*;
        let partial_snapshot = || {
            Box::new(InterruptSnapshot {
                text: None,
                completed_tool_calls: vec![],
                in_flight_tool: None,
                bytes_received: 0,
            })
        };

        // Retryable
        assert!(Provider("x".into()).is_retryable());
        assert!(InferenceExecutionError::rate_limited("x").is_retryable());
        assert!(InferenceExecutionError::overloaded("x").is_retryable());
        assert!(Timeout("x".into()).is_retryable());
        assert!(
            StreamInterrupted {
                cause: InterruptCause::ConnectionReset,
                snapshot: partial_snapshot(),
            }
            .is_retryable()
        );

        // Permanent
        assert!(!ContextOverflow("x".into()).is_retryable());
        assert!(!InvalidRequest("x".into()).is_retryable());
        assert!(!Unauthorized("x".into()).is_retryable());
        assert!(!ModelNotFound("x".into()).is_retryable());
        assert!(!ContentFiltered("x".into()).is_retryable());

        // Fail-fast / lifecycle
        assert!(!AllModelsUnavailable.is_retryable());
        assert!(!Cancelled.is_retryable());
    }

    #[test]
    fn retry_after_is_only_exposed_for_rate_limit_variants() {
        use std::time::Duration;

        let rl = InferenceExecutionError::RateLimited {
            message: "429".into(),
            retry_after: Some(Duration::from_secs(5)),
        };
        assert_eq!(rl.retry_after(), Some(Duration::from_secs(5)));

        let ov = InferenceExecutionError::Overloaded {
            message: "529".into(),
            retry_after: Some(Duration::from_secs(10)),
        };
        assert_eq!(ov.retry_after(), Some(Duration::from_secs(10)));

        assert_eq!(
            InferenceExecutionError::Timeout("slow".into()).retry_after(),
            None
        );
    }

    #[test]
    fn plan_returns_continue_text_when_only_text_present() {
        let snap = InterruptSnapshot {
            text: Some("hello".into()),
            completed_tool_calls: vec![],
            in_flight_tool: None,
            bytes_received: 5,
        };
        match snap.plan() {
            RecoveryPlan::ContinueText { assistant_prefix } => {
                assert_eq!(assistant_prefix, "hello");
            }
            other => panic!("expected ContinueText, got {other:?}"),
        }
    }

    #[test]
    fn plan_returns_synthesize_tool_use_when_completed_tool_present() {
        use serde_json::json;
        let snap = InterruptSnapshot {
            text: Some("I'll search.".into()),
            completed_tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
            in_flight_tool: Some(InFlightTool {
                id: "c2".into(),
                name: "fetch".into(),
                partial_args: r#"{"url":"#.into(),
            }),
            bytes_received: 64,
        };
        match snap.plan() {
            RecoveryPlan::SynthesizeToolUse {
                completed,
                cancelled_tool_hint,
            } => {
                assert_eq!(completed.len(), 1);
                assert_eq!(completed[0].name, "search");
                let hint = cancelled_tool_hint.expect("in-flight tool becomes hint");
                assert_eq!(hint.name, "fetch");
            }
            other => panic!("expected SynthesizeToolUse, got {other:?}"),
        }
    }

    #[test]
    fn plan_returns_truncate_before_tool_when_text_and_in_flight_only() {
        let snap = InterruptSnapshot {
            text: Some("let me think".into()),
            completed_tool_calls: vec![],
            in_flight_tool: Some(InFlightTool {
                id: "c1".into(),
                name: "calc".into(),
                partial_args: r#"{"expr":"#.into(),
            }),
            bytes_received: 24,
        };
        match snap.plan() {
            RecoveryPlan::TruncateBeforeTool {
                assistant_prefix,
                cancelled_tool_id,
                cancelled_tool_name,
            } => {
                assert_eq!(assistant_prefix, "let me think");
                assert_eq!(cancelled_tool_id, "c1");
                assert_eq!(cancelled_tool_name, "calc");
            }
            other => panic!("expected TruncateBeforeTool, got {other:?}"),
        }
    }

    #[test]
    fn plan_returns_whole_restart_when_nothing_salvageable() {
        let snap = InterruptSnapshot {
            text: None,
            completed_tool_calls: vec![],
            in_flight_tool: None,
            bytes_received: 0,
        };
        assert!(matches!(snap.plan(), RecoveryPlan::WholeRestart));

        // Also: only an in-flight tool, no text → whole restart.
        let snap2 = InterruptSnapshot {
            text: None,
            completed_tool_calls: vec![],
            in_flight_tool: Some(InFlightTool {
                id: "c1".into(),
                name: "x".into(),
                partial_args: "{".into(),
            }),
            bytes_received: 1,
        };
        assert!(matches!(snap2.plan(), RecoveryPlan::WholeRestart));
    }
}