agent_sdk/llm.rs

pub mod router;
pub mod streaming;
pub mod types;

pub use router::{ModelRouter, ModelTier, TaskComplexity};
pub use streaming::{StreamAccumulator, StreamBox, StreamDelta};
pub use types::*;

use anyhow::Result;
use async_trait::async_trait;
use futures::StreamExt;

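/// A chat-completion backend.
///
/// Implementors must provide [`chat()`](Self::chat) together with the
/// [`model()`](Self::model) and [`provider()`](Self::provider) identifiers;
/// [`chat_stream()`](Self::chat_stream) comes with a default implementation.
///
/// The sketch below is illustrative only: `StubProvider`, its identifiers, and
/// the canned outcome it returns are invented for the example, and the
/// `agent_sdk::llm` import path is an assumption about how this module is
/// exposed.
///
/// ```ignore
/// use agent_sdk::llm::{ChatOutcome, ChatRequest, LlmProvider};
/// use anyhow::Result;
/// use async_trait::async_trait;
///
/// struct StubProvider;
///
/// #[async_trait]
/// impl LlmProvider for StubProvider {
///     async fn chat(&self, _request: ChatRequest) -> Result<ChatOutcome> {
///         // A real provider would call its API here and map the reply into
///         // `ChatOutcome::Success(..)`; the stub only reports a server error.
///         Ok(ChatOutcome::ServerError("stub provider".to_string()))
///     }
///
///     fn model(&self) -> &str {
///         "stub-model"
///     }
///
///     fn provider(&self) -> &'static str {
///         "stub"
///     }
/// }
/// ```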
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Non-streaming chat completion.
    async fn chat(&self, request: ChatRequest) -> Result<ChatOutcome>;

    /// Streaming chat completion.
    ///
    /// Returns a stream of [`StreamDelta`] events. The default implementation
    /// calls [`chat()`](Self::chat) and replays the buffered response as a
    /// sequence of deltas.
    ///
    /// Providers should override this method to provide true streaming support.
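    ///
    /// The sketch below shows one way to drain the stream; `provider` and
    /// `request` are assumed to exist already and are not constructed here.
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = provider.chat_stream(request);
    /// while let Some(delta) = stream.next().await {
    ///     match delta? {
    ///         StreamDelta::TextDelta { delta, .. } => print!("{delta}"),
    ///         StreamDelta::Done { .. } => break,
    ///         _ => {}
    ///     }
    /// }
    /// ```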
    fn chat_stream(&self, request: ChatRequest) -> StreamBox<'_> {
        Box::pin(async_stream::stream! {
            match self.chat(request).await {
                Ok(outcome) => match outcome {
                    ChatOutcome::Success(response) => {
                        // Emit content as deltas
                        for (idx, block) in response.content.iter().enumerate() {
                            match block {
                                ContentBlock::Text { text } => {
                                    yield Ok(StreamDelta::TextDelta {
                                        delta: text.clone(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::ToolUse { id, name, input, .. } => {
                                    yield Ok(StreamDelta::ToolUseStart {
                                        id: id.clone(),
                                        name: name.clone(),
                                        block_index: idx,
                                    });
                                    yield Ok(StreamDelta::ToolInputDelta {
                                        id: id.clone(),
                                        delta: serde_json::to_string(input).unwrap_or_default(),
                                        block_index: idx,
                                    });
                                }
                                ContentBlock::ToolResult { .. } => {
                                    // Tool results are not emitted in streaming responses
                                }
                            }
                        }
                        yield Ok(StreamDelta::Usage(response.usage));
                        yield Ok(StreamDelta::Done {
                            stop_reason: response.stop_reason,
                        });
                    }
                    ChatOutcome::RateLimited => {
                        yield Ok(StreamDelta::Error {
                            message: "Rate limited".to_string(),
                            recoverable: true,
                        });
                    }
                    ChatOutcome::InvalidRequest(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: false,
                        });
                    }
                    ChatOutcome::ServerError(msg) => {
                        yield Ok(StreamDelta::Error {
                            message: msg,
                            recoverable: true,
                        });
                    }
                },
                Err(e) => yield Err(e),
            }
        })
    }

    /// Identifier of the model this provider sends requests to.
    fn model(&self) -> &str;

    /// Short, static name of the backing provider.
    fn provider(&self) -> &'static str;
}

/// Helper function to consume a stream and collect it into a `ChatOutcome`.
///
/// This is useful for providers that want to test their streaming implementation
/// or for cases where you need the full response after streaming.
///
/// # Errors
///
/// Returns an error if the stream yields an error result.
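///
/// # Examples
///
/// The snippet below is illustrative; `provider` (an [`LlmProvider`]) and
/// `request` are assumed to exist already.
///
/// ```ignore
/// let stream = provider.chat_stream(request);
/// let outcome = collect_stream(stream, provider.model().to_string()).await?;
/// if let ChatOutcome::Success(response) = outcome {
///     println!("output tokens: {}", response.usage.output_tokens);
/// }
/// ```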
pub async fn collect_stream(mut stream: StreamBox<'_>, model: String) -> Result<ChatOutcome> {
    let mut accumulator = StreamAccumulator::new();
    let mut last_error: Option<(String, bool)> = None;

    while let Some(result) = stream.next().await {
        match result {
            Ok(delta) => {
                if let StreamDelta::Error {
                    message,
                    recoverable,
                } = &delta
                {
                    last_error = Some((message.clone(), *recoverable));
                }
                accumulator.apply(&delta);
            }
            Err(e) => return Err(e),
        }
    }

    // If we encountered an error during streaming, return it
    if let Some((message, recoverable)) = last_error {
        if !recoverable {
            return Ok(ChatOutcome::InvalidRequest(message));
        }
        // Check if it was a rate limit
        if message.contains("Rate limited") || message.contains("rate limit") {
            return Ok(ChatOutcome::RateLimited);
        }
        return Ok(ChatOutcome::ServerError(message));
    }

    // Extract usage and stop_reason before consuming the accumulator
    let usage = accumulator.take_usage().unwrap_or(Usage {
        input_tokens: 0,
        output_tokens: 0,
    });
    let stop_reason = accumulator.take_stop_reason();
    let content = accumulator.into_content_blocks();

    Ok(ChatOutcome::Success(ChatResponse {
        id: String::new(),
        content,
        model,
        stop_reason,
        usage,
    }))
}