1use std::time::Duration;
4
5use super::content::ContentBlock;
6use super::inference::{InferenceOverride, StreamResult};
7use super::message::{Message, ToolCall};
8use super::tool::ToolDescriptor;
9use async_trait::async_trait;
10use thiserror::Error;
11
#[derive(Debug, Clone)]
/// A fully-resolved request handed to an [`LlmExecutor`].
pub struct InferenceRequest {
    /// Provider-side model identifier to invoke.
    pub upstream_model: String,
    /// Conversation history, oldest first.
    pub messages: Vec<Message>,
    /// Tools the model is allowed to call for this turn.
    pub tools: Vec<ToolDescriptor>,
    /// System prompt, as content blocks.
    pub system: Vec<ContentBlock>,
    /// Optional per-request inference parameter overrides.
    pub overrides: Option<InferenceOverride>,
    /// Whether to request provider-side prompt caching.
    pub enable_prompt_cache: bool,
}
29
#[derive(Debug, Clone)]
/// Why an in-progress inference stream stopped before completing.
pub enum InterruptCause {
    /// The transport connection was reset mid-stream.
    ConnectionReset,
    /// No data arrived for too long; the stream stalled.
    IdleStall,
    /// The server asked us to go away (e.g. HTTP/2 GOAWAY).
    GoAway,
    /// The provider returned a 5xx status after streaming had begun.
    Provider5xxMidStream(u16),
    /// The stream was resumed from a previously saved checkpoint.
    ResumedFromCheckpoint,
}

impl std::fmt::Display for InterruptCause {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ConnectionReset => write!(f, "connection reset"),
            Self::IdleStall => write!(f, "idle stall"),
            Self::GoAway => write!(f, "goaway"),
            Self::Provider5xxMidStream(status) => write!(f, "provider {status} mid-stream"),
            Self::ResumedFromCheckpoint => write!(f, "resumed from checkpoint"),
        }
    }
}
58
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// A tool call whose argument stream was cut off before completion.
/// Serializable so it can be persisted alongside a checkpoint.
pub struct InFlightTool {
    /// Provider-assigned tool-call id.
    pub id: String,
    /// Tool name; empty if the interrupt happened before the name arrived
    /// (see [`InterruptSnapshot::from_partials`]).
    pub name: String,
    /// Raw, likely-incomplete JSON argument fragment received so far.
    pub partial_args: String,
}
67
#[derive(Debug, Clone)]
/// Everything salvaged from a stream at the moment it was interrupted.
pub struct InterruptSnapshot {
    /// Accumulated assistant text, if any was received.
    pub text: Option<String>,
    /// Tool calls whose arguments arrived as complete, parseable JSON.
    pub completed_tool_calls: Vec<ToolCall>,
    /// The (at most one) tool call still streaming arguments at interrupt time.
    pub in_flight_tool: Option<InFlightTool>,
    /// Total bytes received from the provider before the interrupt.
    pub bytes_received: usize,
}
83
#[derive(Debug, Clone)]
/// Strategy for resuming after a stream interrupt, derived from an
/// [`InterruptSnapshot`] via [`InterruptSnapshot::plan`].
pub enum RecoveryPlan {
    /// Only text was received: continue generation using it as an assistant prefix.
    ContinueText { assistant_prefix: String },
    /// At least one tool call completed: synthesize a tool-use turn from the
    /// completed calls; any still-in-flight call is passed along as a hint.
    SynthesizeToolUse {
        completed: Vec<ToolCall>,
        cancelled_tool_hint: Option<InFlightTool>,
    },
    /// Text plus one unfinished tool call: keep the text, drop the unfinished
    /// call (identified here so the caller can report/annotate it).
    TruncateBeforeTool {
        assistant_prefix: String,
        cancelled_tool_id: String,
        cancelled_tool_name: String,
    },
    /// Nothing salvageable: restart the whole request from scratch.
    WholeRestart,
}
111
112impl InterruptSnapshot {
113 pub fn from_partials<I>(text: Option<String>, partials: I, bytes_received: usize) -> Self
123 where
124 I: IntoIterator<Item = (String, String, String)>,
125 {
126 let mut completed: Vec<ToolCall> = Vec::new();
127 let mut in_flight: Option<InFlightTool> = None;
128
129 for (id, name, args_json) in partials {
130 if name.is_empty() {
131 in_flight = Some(InFlightTool {
132 id,
133 name: String::new(),
134 partial_args: args_json,
135 });
136 continue;
137 }
138 match serde_json::from_str::<serde_json::Value>(&args_json) {
139 Ok(arguments) if !(arguments.is_null() && !args_json.is_empty()) => {
140 completed.push(ToolCall::new(id, name, arguments));
141 }
142 _ => {
143 in_flight = Some(InFlightTool {
144 id,
145 name,
146 partial_args: args_json,
147 });
148 }
149 }
150 }
151
152 Self {
153 text,
154 completed_tool_calls: completed,
155 in_flight_tool: in_flight,
156 bytes_received,
157 }
158 }
159
160 pub fn plan(&self) -> RecoveryPlan {
162 let text = self.text.as_deref().unwrap_or("");
163 let has_text = !text.is_empty();
164 let has_completed = !self.completed_tool_calls.is_empty();
165
166 if has_completed {
168 return RecoveryPlan::SynthesizeToolUse {
169 completed: self.completed_tool_calls.clone(),
170 cancelled_tool_hint: self.in_flight_tool.clone(),
171 };
172 }
173
174 if has_text {
176 if let Some(p) = &self.in_flight_tool {
177 return RecoveryPlan::TruncateBeforeTool {
178 assistant_prefix: text.to_string(),
179 cancelled_tool_id: p.id.clone(),
180 cancelled_tool_name: p.name.clone(),
181 };
182 }
183 return RecoveryPlan::ContinueText {
185 assistant_prefix: text.to_string(),
186 };
187 }
188
189 RecoveryPlan::WholeRestart
191 }
192}
193
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
/// Errors surfaced while executing an inference request.
pub enum InferenceExecutionError {
    /// Generic provider-reported failure.
    #[error("provider error: {0}")]
    Provider(String),
    /// The provider rejected the request due to rate limiting (e.g. HTTP 429).
    #[error("rate limited: {message}")]
    RateLimited {
        message: String,
        /// Provider-suggested wait before retrying, if it sent one.
        retry_after: Option<Duration>,
    },
    /// The provider reported overload (e.g. HTTP 529).
    #[error("provider overloaded: {message}")]
    Overloaded {
        message: String,
        /// Provider-suggested wait before retrying, if it sent one.
        retry_after: Option<Duration>,
    },
    /// The request timed out.
    #[error("timeout: {0}")]
    Timeout(String),
    /// The stream was cut off mid-response; `snapshot` holds what was salvaged.
    /// Boxed to keep this enum small (the snapshot can be large).
    #[error("stream interrupted ({cause})")]
    StreamInterrupted {
        cause: InterruptCause,
        snapshot: Box<InterruptSnapshot>,
    },
    /// The prompt exceeded the model's context window.
    #[error("context overflow: {0}")]
    ContextOverflow(String),
    /// The request was malformed or rejected by validation.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// Authentication/authorization failed.
    #[error("unauthorized: {0}")]
    Unauthorized(String),
    /// The requested model does not exist at the provider.
    #[error("model not found: {0}")]
    ModelNotFound(String),
    /// The provider's content filter blocked the request or response.
    #[error("content filtered: {0}")]
    ContentFiltered(String),
    /// Every candidate model is currently unavailable.
    #[error("all models unavailable (circuit breakers open)")]
    AllModelsUnavailable,
    /// The caller cancelled the request.
    #[error("cancelled")]
    Cancelled,
}
245
246impl InferenceExecutionError {
247 pub fn rate_limited(message: impl Into<String>) -> Self {
249 Self::RateLimited {
250 message: message.into(),
251 retry_after: None,
252 }
253 }
254
255 pub fn overloaded(message: impl Into<String>) -> Self {
257 Self::Overloaded {
258 message: message.into(),
259 retry_after: None,
260 }
261 }
262
263 pub fn is_retryable(&self) -> bool {
268 matches!(
269 self,
270 Self::Provider(_)
271 | Self::RateLimited { .. }
272 | Self::Overloaded { .. }
273 | Self::Timeout(_)
274 | Self::StreamInterrupted { .. }
275 )
276 }
277
278 pub fn counts_toward_circuit_breaker(&self) -> bool {
283 self.is_retryable()
284 }
285
286 pub fn retry_after(&self) -> Option<Duration> {
288 match self {
289 Self::RateLimited { retry_after, .. } | Self::Overloaded { retry_after, .. } => {
290 *retry_after
291 }
292 _ => None,
293 }
294 }
295}
296
#[derive(Debug, Clone)]
/// Incremental events yielded by an [`InferenceStream`].
pub enum LlmStreamEvent {
    /// A chunk of assistant text.
    TextDelta(String),
    /// A chunk of reasoning ("thinking") text.
    ReasoningDelta(String),
    /// A tool call has started; its arguments follow as deltas.
    ToolCallStart { id: String, name: String },
    /// A chunk of the identified tool call's JSON arguments.
    ToolCallDelta { id: String, args_delta: String },
    /// The current content block has finished.
    ContentBlockStop,
    /// Token accounting for the request.
    Usage(super::inference::TokenUsage),
    /// Terminal event carrying the stop reason.
    Stop(super::inference::StopReason),
}
315
/// Pinned, boxed stream of [`LlmStreamEvent`]s; an `Err` item terminates the stream.
pub type InferenceStream = std::pin::Pin<
    Box<dyn futures::Stream<Item = Result<LlmStreamEvent, InferenceExecutionError>> + Send>,
>;
324
#[async_trait]
/// Abstraction over an LLM backend capable of serving [`InferenceRequest`]s.
pub trait LlmExecutor: Send + Sync {
    /// Runs the request to completion and returns the collected result.
    async fn execute(
        &self,
        request: InferenceRequest,
    ) -> Result<StreamResult, InferenceExecutionError>;

    /// Streaming variant of [`Self::execute`].
    ///
    /// The default implementation awaits `execute` and replays the collected
    /// result as a synthetic event stream (via
    /// [`collected_to_stream_events`]), so non-streaming executors get
    /// streaming support for free. Hand-written boxed-future signature
    /// because `async_trait` cannot express a default method returning a
    /// custom stream type directly.
    fn execute_stream(
        &self,
        request: InferenceRequest,
    ) -> std::pin::Pin<
        Box<
            dyn std::future::Future<Output = Result<InferenceStream, InferenceExecutionError>>
                + Send
                + '_,
        >,
    > {
        Box::pin(async move {
            let result = self.execute(request).await?;
            let events = collected_to_stream_events(result);
            Ok(Box::pin(futures::stream::iter(events)) as InferenceStream)
        })
    }

    /// Short identifying name for this executor (e.g. `"mock"` in tests).
    fn name(&self) -> &str;
}
361
362pub fn collected_to_stream_events(
364 result: StreamResult,
365) -> Vec<Result<LlmStreamEvent, InferenceExecutionError>> {
366 use super::content::ContentBlock;
367 let mut events = Vec::new();
368
369 for block in &result.content {
371 match block {
372 ContentBlock::Text { text } if !text.is_empty() => {
373 events.push(Ok(LlmStreamEvent::TextDelta(text.clone())));
374 }
375 ContentBlock::Thinking { thinking } if !thinking.is_empty() => {
376 events.push(Ok(LlmStreamEvent::ReasoningDelta(thinking.clone())));
377 }
378 _ => {}
379 }
380 }
381
382 for call in &result.tool_calls {
384 events.push(Ok(LlmStreamEvent::ToolCallStart {
385 id: call.id.clone(),
386 name: call.name.clone(),
387 }));
388 let args = serde_json::to_string(&call.arguments).unwrap_or_default();
389 if !args.is_empty() {
390 events.push(Ok(LlmStreamEvent::ToolCallDelta {
391 id: call.id.clone(),
392 args_delta: args,
393 }));
394 }
395 }
396
397 if let Some(usage) = result.usage {
399 events.push(Ok(LlmStreamEvent::Usage(usage)));
400 }
401
402 if let Some(stop) = result.stop_reason {
404 events.push(Ok(LlmStreamEvent::Stop(stop)));
405 }
406
407 events
408}
409
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
/// How tool calls within a turn are scheduled for execution.
/// (Scheduling semantics live with the tool runner, not in this module.)
pub enum ToolExecutionMode {
    /// Run tool calls one at a time, in order. The default.
    #[default]
    Sequential,
    /// Run in parallel, with approval handled per batch.
    ParallelBatchApproval,
    /// Run in parallel as calls stream in.
    ParallelStreaming,
}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::contract::inference::{StopReason, TokenUsage};
    use crate::contract::message::ToolCall;
    use crate::contract::tool::ToolDescriptor;
    use serde_json::json;

    // Minimal in-memory executor: returns canned text and/or tool calls.
    struct MockLlm {
        response_text: String,
        tool_calls: Vec<ToolCall>,
    }

    #[async_trait]
    impl LlmExecutor for MockLlm {
        async fn execute(
            &self,
            _request: InferenceRequest,
        ) -> Result<StreamResult, InferenceExecutionError> {
            Ok(StreamResult {
                content: if self.response_text.is_empty() {
                    vec![]
                } else {
                    vec![ContentBlock::text(self.response_text.clone())]
                },
                tool_calls: self.tool_calls.clone(),
                // Fixed token counts; tests below only assert text/tool behavior.
                usage: Some(TokenUsage {
                    prompt_tokens: Some(100),
                    completion_tokens: Some(50),
                    total_tokens: Some(150),
                    ..Default::default()
                }),
                // Mirror real providers: ToolUse stop reason when calls are present.
                stop_reason: if self.tool_calls.is_empty() {
                    Some(StopReason::EndTurn)
                } else {
                    Some(StopReason::ToolUse)
                },
                has_incomplete_tool_calls: false,
            })
        }

        fn name(&self) -> &str {
            "mock"
        }
    }

    // Plain text response: no tools needed, EndTurn stop reason.
    #[tokio::test]
    async fn mock_llm_returns_text() {
        let llm = MockLlm {
            response_text: "Hello!".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("hi")],
            tools: vec![],
            system: vec![],
            overrides: None,
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert_eq!(result.text(), "Hello!");
        assert!(!result.needs_tools());
        assert_eq!(result.stop_reason, Some(StopReason::EndTurn));
    }

    // Tool-call response: needs_tools() is set and stop reason is ToolUse.
    #[tokio::test]
    async fn mock_llm_returns_tool_calls() {
        let llm = MockLlm {
            response_text: String::new(),
            tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
        };
        let request = InferenceRequest {
            upstream_model: "test-model".into(),
            messages: vec![Message::user("search for rust")],
            tools: vec![ToolDescriptor::new("search", "search", "Web search")],
            system: vec![ContentBlock::text("You are helpful.")],
            overrides: None,
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert!(result.needs_tools());
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.tool_calls[0].name, "search");
        assert_eq!(result.stop_reason, Some(StopReason::ToolUse));
    }

    // Overrides are accepted by the request shape (MockLlm ignores them).
    #[tokio::test]
    async fn mock_llm_with_overrides() {
        let llm = MockLlm {
            response_text: "ok".into(),
            tool_calls: vec![],
        };
        let request = InferenceRequest {
            upstream_model: "base-model".into(),
            messages: vec![],
            tools: vec![],
            system: vec![],
            overrides: Some(InferenceOverride {
                temperature: Some(0.7),
                ..Default::default()
            }),
            enable_prompt_cache: false,
        };
        let result = llm.execute(request).await.unwrap();
        assert_eq!(result.text(), "ok");
    }

    #[test]
    fn llm_executor_name_is_exposed() {
        let llm = MockLlm {
            response_text: String::new(),
            tool_calls: vec![],
        };

        assert_eq!(llm.name(), "mock");
    }

    #[test]
    fn tool_execution_mode_default_is_sequential() {
        assert_eq!(ToolExecutionMode::default(), ToolExecutionMode::Sequential);
    }

    // Display strings are part of the public contract (logged/matched upstream),
    // so pin every variant's exact rendering.
    #[test]
    fn inference_execution_error_display_strings_are_stable() {
        assert_eq!(
            InferenceExecutionError::Provider("provider failed".into()).to_string(),
            "provider error: provider failed"
        );
        assert_eq!(
            InferenceExecutionError::rate_limited("too many requests").to_string(),
            "rate limited: too many requests"
        );
        assert_eq!(
            InferenceExecutionError::overloaded("server overloaded").to_string(),
            "provider overloaded: server overloaded"
        );
        assert_eq!(
            InferenceExecutionError::Timeout("slow backend".into()).to_string(),
            "timeout: slow backend"
        );
        assert_eq!(
            InferenceExecutionError::ContextOverflow("prompt too long".into()).to_string(),
            "context overflow: prompt too long"
        );
        assert_eq!(
            InferenceExecutionError::InvalidRequest("bad schema".into()).to_string(),
            "invalid request: bad schema"
        );
        assert_eq!(
            InferenceExecutionError::Unauthorized("bad key".into()).to_string(),
            "unauthorized: bad key"
        );
        assert_eq!(
            InferenceExecutionError::ModelNotFound("no such model".into()).to_string(),
            "model not found: no such model"
        );
        assert_eq!(
            InferenceExecutionError::AllModelsUnavailable.to_string(),
            "all models unavailable (circuit breakers open)"
        );
        assert_eq!(InferenceExecutionError::Cancelled.to_string(), "cancelled");

        let stream_err = InferenceExecutionError::StreamInterrupted {
            cause: InterruptCause::ConnectionReset,
            snapshot: Box::new(InterruptSnapshot {
                text: None,
                completed_tool_calls: vec![],
                in_flight_tool: None,
                bytes_received: 0,
            }),
        };
        assert_eq!(
            stream_err.to_string(),
            "stream interrupted (connection reset)"
        );
    }

    // Exhaustively partition variants into retryable vs terminal.
    #[test]
    fn is_retryable_partitions_variants() {
        use InferenceExecutionError::*;
        let partial_snapshot = || {
            Box::new(InterruptSnapshot {
                text: None,
                completed_tool_calls: vec![],
                in_flight_tool: None,
                bytes_received: 0,
            })
        };

        assert!(Provider("x".into()).is_retryable());
        assert!(InferenceExecutionError::rate_limited("x").is_retryable());
        assert!(InferenceExecutionError::overloaded("x").is_retryable());
        assert!(Timeout("x".into()).is_retryable());
        assert!(
            StreamInterrupted {
                cause: InterruptCause::ConnectionReset,
                snapshot: partial_snapshot(),
            }
            .is_retryable()
        );

        assert!(!ContextOverflow("x".into()).is_retryable());
        assert!(!InvalidRequest("x".into()).is_retryable());
        assert!(!Unauthorized("x".into()).is_retryable());
        assert!(!ModelNotFound("x".into()).is_retryable());
        assert!(!ContentFiltered("x".into()).is_retryable());

        assert!(!AllModelsUnavailable.is_retryable());
        assert!(!Cancelled.is_retryable());
    }

    #[test]
    fn retry_after_is_only_exposed_for_rate_limit_variants() {
        use std::time::Duration;

        let rl = InferenceExecutionError::RateLimited {
            message: "429".into(),
            retry_after: Some(Duration::from_secs(5)),
        };
        assert_eq!(rl.retry_after(), Some(Duration::from_secs(5)));

        let ov = InferenceExecutionError::Overloaded {
            message: "529".into(),
            retry_after: Some(Duration::from_secs(10)),
        };
        assert_eq!(ov.retry_after(), Some(Duration::from_secs(10)));

        assert_eq!(
            InferenceExecutionError::Timeout("slow".into()).retry_after(),
            None
        );
    }

    #[test]
    fn plan_returns_continue_text_when_only_text_present() {
        let snap = InterruptSnapshot {
            text: Some("hello".into()),
            completed_tool_calls: vec![],
            in_flight_tool: None,
            bytes_received: 5,
        };
        match snap.plan() {
            RecoveryPlan::ContinueText { assistant_prefix } => {
                assert_eq!(assistant_prefix, "hello");
            }
            other => panic!("expected ContinueText, got {other:?}"),
        }
    }

    // Completed tool calls take priority over both text and in-flight calls.
    #[test]
    fn plan_returns_synthesize_tool_use_when_completed_tool_present() {
        use serde_json::json;
        let snap = InterruptSnapshot {
            text: Some("I'll search.".into()),
            completed_tool_calls: vec![ToolCall::new("c1", "search", json!({"q": "rust"}))],
            in_flight_tool: Some(InFlightTool {
                id: "c2".into(),
                name: "fetch".into(),
                partial_args: r#"{"url":"#.into(),
            }),
            bytes_received: 64,
        };
        match snap.plan() {
            RecoveryPlan::SynthesizeToolUse {
                completed,
                cancelled_tool_hint,
            } => {
                assert_eq!(completed.len(), 1);
                assert_eq!(completed[0].name, "search");
                let hint = cancelled_tool_hint.expect("in-flight tool becomes hint");
                assert_eq!(hint.name, "fetch");
            }
            other => panic!("expected SynthesizeToolUse, got {other:?}"),
        }
    }

    #[test]
    fn plan_returns_truncate_before_tool_when_text_and_in_flight_only() {
        let snap = InterruptSnapshot {
            text: Some("let me think".into()),
            completed_tool_calls: vec![],
            in_flight_tool: Some(InFlightTool {
                id: "c1".into(),
                name: "calc".into(),
                partial_args: r#"{"expr":"#.into(),
            }),
            bytes_received: 24,
        };
        match snap.plan() {
            RecoveryPlan::TruncateBeforeTool {
                assistant_prefix,
                cancelled_tool_id,
                cancelled_tool_name,
            } => {
                assert_eq!(assistant_prefix, "let me think");
                assert_eq!(cancelled_tool_id, "c1");
                assert_eq!(cancelled_tool_name, "calc");
            }
            other => panic!("expected TruncateBeforeTool, got {other:?}"),
        }
    }

    // An in-flight fragment with no text and no completed calls is not
    // worth resuming from — the plan falls back to a full restart.
    #[test]
    fn plan_returns_whole_restart_when_nothing_salvageable() {
        let snap = InterruptSnapshot {
            text: None,
            completed_tool_calls: vec![],
            in_flight_tool: None,
            bytes_received: 0,
        };
        assert!(matches!(snap.plan(), RecoveryPlan::WholeRestart));

        let snap2 = InterruptSnapshot {
            text: None,
            completed_tool_calls: vec![],
            in_flight_tool: Some(InFlightTool {
                id: "c1".into(),
                name: "x".into(),
                partial_args: "{".into(),
            }),
            bytes_received: 1,
        };
        assert!(matches!(snap2.plan(), RecoveryPlan::WholeRestart));
    }
}