rig/providers/anthropic/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use super::decoders::sse::from_response as sse_from_response;
8use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
9use crate::json_utils::merge_inplace;
10use crate::streaming;
11use crate::streaming::{RawStreamingChoice, StreamingResult};
12
13#[derive(Debug, Deserialize)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum StreamingEvent {
16    MessageStart {
17        message: MessageStart,
18    },
19    ContentBlockStart {
20        index: usize,
21        content_block: Content,
22    },
23    ContentBlockDelta {
24        index: usize,
25        delta: ContentDelta,
26    },
27    ContentBlockStop {
28        index: usize,
29    },
30    MessageDelta {
31        delta: MessageDelta,
32        usage: PartialUsage,
33    },
34    MessageStop,
35    Ping,
36    #[serde(other)]
37    Unknown,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct MessageStart {
42    pub id: String,
43    pub role: String,
44    pub content: Vec<Content>,
45    pub model: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51#[derive(Debug, Deserialize)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ContentDelta {
54    TextDelta { text: String },
55    InputJsonDelta { partial_json: String },
56}
57
58#[derive(Debug, Deserialize)]
59pub struct MessageDelta {
60    pub stop_reason: Option<String>,
61    pub stop_sequence: Option<String>,
62}
63
64#[derive(Debug, Deserialize, Clone, Serialize)]
65pub struct PartialUsage {
66    pub output_tokens: usize,
67    #[serde(default)]
68    pub input_tokens: Option<usize>,
69}
70
/// Scratch state for a tool call that is still being streamed.
///
/// Created by `handle_event` when a tool-use block starts and consumed when
/// that block stops.
#[derive(Default)]
struct ToolCallState {
    /// Tool name copied from the tool-use content block.
    name: String,
    /// Tool-use id copied from the tool-use content block.
    id: String,
    /// Concatenation of every `partial_json` fragment received so far.
    input_json: String,
}
77
78#[derive(Clone, Deserialize, Serialize)]
79pub struct StreamingCompletionResponse {
80    pub usage: PartialUsage,
81}
82
83impl GetTokenUsage for StreamingCompletionResponse {
84    fn token_usage(&self) -> Option<crate::completion::Usage> {
85        let mut usage = crate::completion::Usage::new();
86        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
87        usage.output_tokens = self.usage.output_tokens as u64;
88        usage.total_tokens =
89            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
90
91        Some(usage)
92    }
93}
94
95impl CompletionModel {
96    pub(crate) async fn stream(
97        &self,
98        completion_request: CompletionRequest,
99    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
100    {
101        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
102            tokens
103        } else if let Some(tokens) = self.default_max_tokens {
104            tokens
105        } else {
106            return Err(CompletionError::RequestError(
107                "`max_tokens` must be set for Anthropic".into(),
108            ));
109        };
110
111        let mut full_history = vec![];
112        if let Some(docs) = completion_request.normalized_documents() {
113            full_history.push(docs);
114        }
115        full_history.extend(completion_request.chat_history);
116
117        let full_history = full_history
118            .into_iter()
119            .map(Message::try_from)
120            .collect::<Result<Vec<Message>, _>>()?;
121
122        let mut request = json!({
123            "model": self.model,
124            "messages": full_history,
125            "max_tokens": max_tokens,
126            "system": completion_request.preamble.unwrap_or("".to_string()),
127            "stream": true,
128        });
129
130        if let Some(temperature) = completion_request.temperature {
131            merge_inplace(&mut request, json!({ "temperature": temperature }));
132        }
133
134        if !completion_request.tools.is_empty() {
135            merge_inplace(
136                &mut request,
137                json!({
138                    "tools": completion_request
139                        .tools
140                        .into_iter()
141                        .map(|tool| ToolDefinition {
142                            name: tool.name,
143                            description: Some(tool.description),
144                            input_schema: tool.parameters,
145                        })
146                        .collect::<Vec<_>>(),
147                    "tool_choice": ToolChoice::Auto,
148                }),
149            );
150        }
151
152        if let Some(ref params) = completion_request.additional_params {
153            merge_inplace(&mut request, params.clone())
154        }
155
156        let response = self
157            .client
158            .post("/v1/messages")
159            .json(&request)
160            .send()
161            .await?;
162
163        if !response.status().is_success() {
164            return Err(CompletionError::ProviderError(response.text().await?));
165        }
166
167        // Use our SSE decoder to directly handle Server-Sent Events format
168        let sse_stream = sse_from_response(response);
169
170        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
171            let mut current_tool_call: Option<ToolCallState> = None;
172            let mut sse_stream = Box::pin(sse_stream);
173            let mut input_tokens = 0;
174
175            while let Some(sse_result) = sse_stream.next().await {
176                match sse_result {
177                    Ok(sse) => {
178                        // Parse the SSE data as a StreamingEvent
179                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
180                            Ok(event) => {
181                                match &event {
182                                    StreamingEvent::MessageStart { message } => {
183                                        input_tokens = message.usage.input_tokens;
184                                    },
185                                    StreamingEvent::MessageDelta { delta, usage } => {
186                                        if delta.stop_reason.is_some() {
187
188                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
189                                                usage: PartialUsage {
190                                                    output_tokens: usage.output_tokens,
191                                                    input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
192                                                }
193                                            }))
194                                        }
195                                    }
196                                    _ => {}
197                                }
198
199                                if let Some(result) = handle_event(&event, &mut current_tool_call) {
200                                    yield result;
201                                }
202                            },
203                            Err(e) => {
204                                if !sse.data.trim().is_empty() {
205                                    yield Err(CompletionError::ResponseError(
206                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
207                                    ));
208                                }
209                            }
210                        }
211                    },
212                    Err(e) => {
213                        yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
214                        break;
215                    }
216                }
217            }
218        });
219
220        Ok(streaming::StreamingCompletionResponse::stream(stream))
221    }
222}
223
224fn handle_event(
225    event: &StreamingEvent,
226    current_tool_call: &mut Option<ToolCallState>,
227) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
228    match event {
229        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
230            ContentDelta::TextDelta { text } => {
231                if current_tool_call.is_none() {
232                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
233                }
234                None
235            }
236            ContentDelta::InputJsonDelta { partial_json } => {
237                if let Some(tool_call) = current_tool_call {
238                    tool_call.input_json.push_str(partial_json);
239                }
240                None
241            }
242        },
243        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
244            Content::ToolUse { id, name, .. } => {
245                *current_tool_call = Some(ToolCallState {
246                    name: name.clone(),
247                    id: id.clone(),
248                    input_json: String::new(),
249                });
250                None
251            }
252            // Handle other content types - they don't need special handling
253            _ => None,
254        },
255        StreamingEvent::ContentBlockStop { .. } => {
256            if let Some(tool_call) = current_tool_call.take() {
257                let json_str = if tool_call.input_json.is_empty() {
258                    "{}"
259                } else {
260                    &tool_call.input_json
261                };
262                match serde_json::from_str(json_str) {
263                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
264                        name: tool_call.name,
265                        id: tool_call.id,
266                        arguments: json_value,
267                        call_id: None,
268                    })),
269                    Err(e) => Some(Err(CompletionError::from(e))),
270                }
271            } else {
272                None
273            }
274        }
275        // Ignore other event types or handle as needed
276        StreamingEvent::MessageStart { .. }
277        | StreamingEvent::MessageDelta { .. }
278        | StreamingEvent::MessageStop
279        | StreamingEvent::Ping
280        | StreamingEvent::Unknown => None,
281    }
282}