rig/providers/anthropic/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
7use super::decoders::sse::from_response as sse_from_response;
8use crate::completion::{CompletionError, CompletionRequest};
9use crate::json_utils::merge_inplace;
10use crate::streaming::{StreamingChoice, StreamingCompletionModel, StreamingResult};
11
/// One event from Anthropic's streaming Messages API, decoded from the SSE `data` payload.
///
/// The variant is chosen by the JSON `type` field, with variant names mapped to
/// `snake_case` (e.g. `ContentBlockDelta` <-> `"content_block_delta"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// `message_start`: carries the initial [`MessageStart`] payload.
    MessageStart {
        message: MessageStart,
    },
    /// `content_block_start`: opens content block `index` with its initial content.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// `content_block_delta`: an incremental [`ContentDelta`] for content block `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// `content_block_stop`: closes content block `index`.
    ContentBlockStop {
        index: usize,
    },
    /// `message_delta`: stop reason/sequence updates plus usage counts.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// `message_stop`: carries no modelled fields.
    MessageStop,
    /// `ping`: carries no modelled fields.
    Ping,
    /// Catch-all (`#[serde(other)]`) for any `type` not listed above, so unrecognised
    /// events deserialize instead of failing. Because there is no dedicated variant, an
    /// `"error"` event also lands here.
    #[serde(other)]
    Unknown,
}
38
/// Message metadata delivered by a `message_start` event.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    /// Message identifier assigned by the API.
    pub id: String,
    /// Role of the message author, as a raw string.
    pub role: String,
    /// Initial content blocks included in the start event.
    pub content: Vec<Content>,
    /// Model name reported by the API.
    pub model: String,
    /// Stop reason, if already known at message start.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token usage reported in the start event.
    pub usage: Usage,
}
49
/// Payload of a `content_block_delta` event, selected by its own `type` field.
///
/// NOTE(review): unlike [`StreamingEvent`], this enum has no `#[serde(other)]` fallback,
/// so a delta whose `type` is neither `text_delta` nor `input_json_delta` makes the whole
/// enclosing event fail to deserialize.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// `text_delta`: a fragment of text to append to the current text block.
    TextDelta { text: String },
    /// `input_json_delta`: a fragment of a tool call's JSON input. `handle_event`
    /// concatenates these and parses the result when the block stops.
    InputJsonDelta { partial_json: String },
}
56
/// Top-level message changes carried by a `message_delta` event.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    /// Why generation stopped, if reported in this delta.
    pub stop_reason: Option<String>,
    /// Stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
}
62
/// Usage counts attached to a `message_delta` event.
#[derive(Debug, Deserialize)]
pub struct PartialUsage {
    /// Output token count; required, so deserialization fails if it is missing.
    pub output_tokens: usize,
    /// Input token count; `#[serde(default)]` makes a missing field deserialize as `None`.
    #[serde(default)]
    pub input_tokens: Option<usize>,
}
69
/// Accumulator for the tool-use content block currently being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the `content_block_start` event.
    name: String,
    /// Tool-use id taken from the `content_block_start` event.
    id: String,
    /// Concatenated `input_json_delta` fragments; parsed as JSON when the block stops.
    input_json: String,
}
76
77impl StreamingCompletionModel for CompletionModel {
78    async fn stream(
79        &self,
80        completion_request: CompletionRequest,
81    ) -> Result<StreamingResult, CompletionError> {
82        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
83            tokens
84        } else if let Some(tokens) = self.default_max_tokens {
85            tokens
86        } else {
87            return Err(CompletionError::RequestError(
88                "`max_tokens` must be set for Anthropic".into(),
89            ));
90        };
91
92        let mut full_history = vec![];
93        if let Some(docs) = completion_request.normalized_documents() {
94            full_history.push(docs);
95        }
96        full_history.extend(completion_request.chat_history);
97
98        let full_history = full_history
99            .into_iter()
100            .map(Message::try_from)
101            .collect::<Result<Vec<Message>, _>>()?;
102
103        let mut request = json!({
104            "model": self.model,
105            "messages": full_history,
106            "max_tokens": max_tokens,
107            "system": completion_request.preamble.unwrap_or("".to_string()),
108            "stream": true,
109        });
110
111        if let Some(temperature) = completion_request.temperature {
112            merge_inplace(&mut request, json!({ "temperature": temperature }));
113        }
114
115        if !completion_request.tools.is_empty() {
116            merge_inplace(
117                &mut request,
118                json!({
119                    "tools": completion_request
120                        .tools
121                        .into_iter()
122                        .map(|tool| ToolDefinition {
123                            name: tool.name,
124                            description: Some(tool.description),
125                            input_schema: tool.parameters,
126                        })
127                        .collect::<Vec<_>>(),
128                    "tool_choice": ToolChoice::Auto,
129                }),
130            );
131        }
132
133        if let Some(ref params) = completion_request.additional_params {
134            merge_inplace(&mut request, params.clone())
135        }
136
137        let response = self
138            .client
139            .post("/v1/messages")
140            .json(&request)
141            .send()
142            .await?;
143
144        if !response.status().is_success() {
145            return Err(CompletionError::ProviderError(response.text().await?));
146        }
147
148        // Use our SSE decoder to directly handle Server-Sent Events format
149        let sse_stream = sse_from_response(response);
150
151        Ok(Box::pin(stream! {
152            let mut current_tool_call: Option<ToolCallState> = None;
153            let mut sse_stream = Box::pin(sse_stream);
154
155            while let Some(sse_result) = sse_stream.next().await {
156                match sse_result {
157                    Ok(sse) => {
158                        // Parse the SSE data as a StreamingEvent
159                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
160                            Ok(event) => {
161                                if let Some(result) = handle_event(&event, &mut current_tool_call) {
162                                    yield result;
163                                }
164                            },
165                            Err(e) => {
166                                if !sse.data.trim().is_empty() {
167                                    yield Err(CompletionError::ResponseError(
168                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
169                                    ));
170                                }
171                            }
172                        }
173                    },
174                    Err(e) => {
175                        yield Err(CompletionError::ResponseError(format!("SSE Error: {}", e)));
176                        break;
177                    }
178                }
179            }
180        }))
181    }
182}
183
184fn handle_event(
185    event: &StreamingEvent,
186    current_tool_call: &mut Option<ToolCallState>,
187) -> Option<Result<StreamingChoice, CompletionError>> {
188    match event {
189        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
190            ContentDelta::TextDelta { text } => {
191                if current_tool_call.is_none() {
192                    return Some(Ok(StreamingChoice::Message(text.clone())));
193                }
194                None
195            }
196            ContentDelta::InputJsonDelta { partial_json } => {
197                if let Some(ref mut tool_call) = current_tool_call {
198                    tool_call.input_json.push_str(partial_json);
199                }
200                None
201            }
202        },
203        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
204            Content::ToolUse { id, name, .. } => {
205                *current_tool_call = Some(ToolCallState {
206                    name: name.clone(),
207                    id: id.clone(),
208                    input_json: String::new(),
209                });
210                None
211            }
212            // Handle other content types - they don't need special handling
213            _ => None,
214        },
215        StreamingEvent::ContentBlockStop { .. } => {
216            if let Some(tool_call) = current_tool_call.take() {
217                let json_str = if tool_call.input_json.is_empty() {
218                    "{}"
219                } else {
220                    &tool_call.input_json
221                };
222                match serde_json::from_str(json_str) {
223                    Ok(json_value) => Some(Ok(StreamingChoice::ToolCall(
224                        tool_call.name,
225                        tool_call.id,
226                        json_value,
227                    ))),
228                    Err(e) => Some(Err(CompletionError::from(e))),
229                }
230            } else {
231                None
232            }
233        }
234        // Ignore other event types or handle as needed
235        StreamingEvent::MessageStart { .. }
236        | StreamingEvent::MessageDelta { .. }
237        | StreamingEvent::MessageStop
238        | StreamingEvent::Ping
239        | StreamingEvent::Unknown => None,
240    }
241}