// rig/providers/anthropic/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use super::decoders::sse::from_response as sse_from_response;
8use crate::completion::{CompletionError, CompletionRequest};
9use crate::json_utils::merge_inplace;
10use crate::streaming;
11use crate::streaming::{RawStreamingChoice, StreamingResult};
12
13#[derive(Debug, Deserialize)]
14#[serde(tag = "type", rename_all = "snake_case")]
15pub enum StreamingEvent {
16    MessageStart {
17        message: MessageStart,
18    },
19    ContentBlockStart {
20        index: usize,
21        content_block: Content,
22    },
23    ContentBlockDelta {
24        index: usize,
25        delta: ContentDelta,
26    },
27    ContentBlockStop {
28        index: usize,
29    },
30    MessageDelta {
31        delta: MessageDelta,
32        usage: PartialUsage,
33    },
34    MessageStop,
35    Ping,
36    #[serde(other)]
37    Unknown,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct MessageStart {
42    pub id: String,
43    pub role: String,
44    pub content: Vec<Content>,
45    pub model: String,
46    pub stop_reason: Option<String>,
47    pub stop_sequence: Option<String>,
48    pub usage: Usage,
49}
50
51#[derive(Debug, Deserialize)]
52#[serde(tag = "type", rename_all = "snake_case")]
53pub enum ContentDelta {
54    TextDelta { text: String },
55    InputJsonDelta { partial_json: String },
56}
57
58#[derive(Debug, Deserialize)]
59pub struct MessageDelta {
60    pub stop_reason: Option<String>,
61    pub stop_sequence: Option<String>,
62}
63
64#[derive(Debug, Deserialize, Clone, Serialize)]
65pub struct PartialUsage {
66    pub output_tokens: usize,
67    #[serde(default)]
68    pub input_tokens: Option<usize>,
69}
70
/// Accumulator for the tool call whose content block is currently open.
///
/// `handle_event` creates one when a tool-use block starts, appends to it on every
/// input-JSON delta, and consumes it when the block stops.
#[derive(Default)]
struct ToolCallState {
    /// Tool name copied from the tool-use content block.
    name: String,
    /// Tool-use id copied from the tool-use content block.
    id: String,
    /// Concatenation of the `partial_json` fragments received so far.
    input_json: String,
}
77
78#[derive(Clone, Deserialize, Serialize)]
79pub struct StreamingCompletionResponse {
80    pub usage: PartialUsage,
81}
82
83impl CompletionModel {
84    pub(crate) async fn stream(
85        &self,
86        completion_request: CompletionRequest,
87    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
88    {
89        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
90            tokens
91        } else if let Some(tokens) = self.default_max_tokens {
92            tokens
93        } else {
94            return Err(CompletionError::RequestError(
95                "`max_tokens` must be set for Anthropic".into(),
96            ));
97        };
98
99        let mut full_history = vec![];
100        if let Some(docs) = completion_request.normalized_documents() {
101            full_history.push(docs);
102        }
103        full_history.extend(completion_request.chat_history);
104
105        let full_history = full_history
106            .into_iter()
107            .map(Message::try_from)
108            .collect::<Result<Vec<Message>, _>>()?;
109
110        let mut request = json!({
111            "model": self.model,
112            "messages": full_history,
113            "max_tokens": max_tokens,
114            "system": completion_request.preamble.unwrap_or("".to_string()),
115            "stream": true,
116        });
117
118        if let Some(temperature) = completion_request.temperature {
119            merge_inplace(&mut request, json!({ "temperature": temperature }));
120        }
121
122        if !completion_request.tools.is_empty() {
123            merge_inplace(
124                &mut request,
125                json!({
126                    "tools": completion_request
127                        .tools
128                        .into_iter()
129                        .map(|tool| ToolDefinition {
130                            name: tool.name,
131                            description: Some(tool.description),
132                            input_schema: tool.parameters,
133                        })
134                        .collect::<Vec<_>>(),
135                    "tool_choice": ToolChoice::Auto,
136                }),
137            );
138        }
139
140        if let Some(ref params) = completion_request.additional_params {
141            merge_inplace(&mut request, params.clone())
142        }
143
144        let response = self
145            .client
146            .post("/v1/messages")
147            .json(&request)
148            .send()
149            .await?;
150
151        if !response.status().is_success() {
152            return Err(CompletionError::ProviderError(response.text().await?));
153        }
154
155        // Use our SSE decoder to directly handle Server-Sent Events format
156        let sse_stream = sse_from_response(response);
157
158        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
159            let mut current_tool_call: Option<ToolCallState> = None;
160            let mut sse_stream = Box::pin(sse_stream);
161            let mut input_tokens = 0;
162
163            while let Some(sse_result) = sse_stream.next().await {
164                match sse_result {
165                    Ok(sse) => {
166                        // Parse the SSE data as a StreamingEvent
167                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
168                            Ok(event) => {
169                                match &event {
170                                    StreamingEvent::MessageStart { message } => {
171                                        input_tokens = message.usage.input_tokens;
172                                    },
173                                    StreamingEvent::MessageDelta { delta, usage } => {
174                                        if delta.stop_reason.is_some() {
175
176                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
177                                                usage: PartialUsage {
178                                                    output_tokens: usage.output_tokens,
179                                                    input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
180                                                }
181                                            }))
182                                        }
183                                    }
184                                    _ => {}
185                                }
186
187                                if let Some(result) = handle_event(&event, &mut current_tool_call) {
188                                    yield result;
189                                }
190                            },
191                            Err(e) => {
192                                if !sse.data.trim().is_empty() {
193                                    yield Err(CompletionError::ResponseError(
194                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
195                                    ));
196                                }
197                            }
198                        }
199                    },
200                    Err(e) => {
201                        yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
202                        break;
203                    }
204                }
205            }
206        });
207
208        Ok(streaming::StreamingCompletionResponse::stream(stream))
209    }
210}
211
212fn handle_event(
213    event: &StreamingEvent,
214    current_tool_call: &mut Option<ToolCallState>,
215) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
216    match event {
217        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
218            ContentDelta::TextDelta { text } => {
219                if current_tool_call.is_none() {
220                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
221                }
222                None
223            }
224            ContentDelta::InputJsonDelta { partial_json } => {
225                if let Some(tool_call) = current_tool_call {
226                    tool_call.input_json.push_str(partial_json);
227                }
228                None
229            }
230        },
231        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
232            Content::ToolUse { id, name, .. } => {
233                *current_tool_call = Some(ToolCallState {
234                    name: name.clone(),
235                    id: id.clone(),
236                    input_json: String::new(),
237                });
238                None
239            }
240            // Handle other content types - they don't need special handling
241            _ => None,
242        },
243        StreamingEvent::ContentBlockStop { .. } => {
244            if let Some(tool_call) = current_tool_call.take() {
245                let json_str = if tool_call.input_json.is_empty() {
246                    "{}"
247                } else {
248                    &tool_call.input_json
249                };
250                match serde_json::from_str(json_str) {
251                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
252                        name: tool_call.name,
253                        id: tool_call.id,
254                        arguments: json_value,
255                        call_id: None,
256                    })),
257                    Err(e) => Some(Err(CompletionError::from(e))),
258                }
259            } else {
260                None
261            }
262        }
263        // Ignore other event types or handle as needed
264        StreamingEvent::MessageStart { .. }
265        | StreamingEvent::MessageDelta { .. }
266        | StreamingEvent::MessageStop
267        | StreamingEvent::Ping
268        | StreamingEvent::Unknown => None,
269    }
270}