// rig/providers/anthropic/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use super::decoders::sse::from_response as sse_from_response;
10use crate::OneOrMany;
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
/// Server-sent event payloads emitted by Anthropic's streaming Messages API.
///
/// Each SSE `data:` payload is tagged by its `type` field, which serde maps
/// onto these variants via `snake_case` renaming.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// First event of a stream; carries the message envelope and initial usage.
    MessageStart {
        message: MessageStart,
    },
    /// Opens a new content block (text, tool use, ...) at `index`.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// Closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Top-level message update; carries the stop reason and output token usage.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// Final event of a stream.
    MessageStop,
    /// Keep-alive event; carries no data.
    Ping,
    /// Forward-compatible catch-all for event types this enum does not model.
    #[serde(other)]
    Unknown,
}
43
/// Envelope delivered in a `message_start` event: the message identity plus
/// the usage counters known at the start of the stream.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    // Provider-assigned message id (recorded on the telemetry span).
    pub id: String,
    pub role: String,
    // Usually empty at stream start; content arrives via content-block events.
    pub content: Vec<Content>,
    // Model that is producing the response.
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    // Input-token counts are known up front; output counts arrive later.
    pub usage: Usage,
}
54
/// Incremental payload inside a `content_block_delta` event, tagged by `type`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of assistant text.
    TextDelta { text: String },
    /// A fragment of the JSON arguments for an in-progress tool call.
    InputJsonDelta { partial_json: String },
    /// A fragment of extended-thinking (reasoning) text.
    ThinkingDelta { thinking: String },
    /// Cryptographic signature for a thinking block; not surfaced to callers.
    SignatureDelta { signature: String },
}
63
/// Top-level message changes carried by a `message_delta` event.
/// A populated `stop_reason` signals the end of generation.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
69
/// Token usage as reported mid-stream. Output tokens are always present;
/// input tokens may be absent from the event (hence the serde default) and
/// are filled in from the `message_start` usage by the stream driver.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
}
76
77impl GetTokenUsage for PartialUsage {
78    fn token_usage(&self) -> Option<crate::completion::Usage> {
79        let mut usage = crate::completion::Usage::new();
80
81        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82        usage.output_tokens = self.output_tokens as u64;
83        usage.total_tokens = usage.input_tokens + usage.output_tokens;
84        Some(usage)
85    }
86}
87
/// Accumulator for a tool call assembled across several SSE events: opened by
/// a `content_block_start` with a `tool_use` block, grown by `input_json_delta`
/// fragments, and finalized (JSON parsed) on `content_block_stop`.
#[derive(Default)]
struct ToolCallState {
    // Tool name from the opening `tool_use` block.
    name: String,
    // Provider-assigned tool-use id.
    id: String,
    // Concatenated partial-JSON fragments; parsed once the block closes.
    input_json: String,
}
94
/// Final response payload yielded when the stream finishes, carrying the
/// aggregated token usage for the whole completion.
#[derive(Clone, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    pub usage: PartialUsage,
}
99
100impl GetTokenUsage for StreamingCompletionResponse {
101    fn token_usage(&self) -> Option<crate::completion::Usage> {
102        let mut usage = crate::completion::Usage::new();
103        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
104        usage.output_tokens = self.usage.output_tokens as u64;
105        usage.total_tokens =
106            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
107
108        Some(usage)
109    }
110}
111
112impl<T> CompletionModel<T>
113where
114    T: HttpClientExt + Clone + Default,
115{
116    pub(crate) async fn stream(
117        &self,
118        completion_request: CompletionRequest,
119    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
120    {
121        let span = if tracing::Span::current().is_disabled() {
122            info_span!(
123                target: "rig::completions",
124                "chat_streaming",
125                gen_ai.operation.name = "chat_streaming",
126                gen_ai.provider.name = "anthropic",
127                gen_ai.request.model = self.model,
128                gen_ai.system_instructions = &completion_request.preamble,
129                gen_ai.response.id = tracing::field::Empty,
130                gen_ai.response.model = self.model,
131                gen_ai.usage.output_tokens = tracing::field::Empty,
132                gen_ai.usage.input_tokens = tracing::field::Empty,
133                gen_ai.input.messages = tracing::field::Empty,
134                gen_ai.output.messages = tracing::field::Empty,
135            )
136        } else {
137            tracing::Span::current()
138        };
139        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
140            tokens
141        } else if let Some(tokens) = self.default_max_tokens {
142            tokens
143        } else {
144            return Err(CompletionError::RequestError(
145                "`max_tokens` must be set for Anthropic".into(),
146            ));
147        };
148
149        let mut full_history = vec![];
150        if let Some(docs) = completion_request.normalized_documents() {
151            full_history.push(docs);
152        }
153        full_history.extend(completion_request.chat_history);
154        span.record_model_input(&full_history);
155
156        let full_history = full_history
157            .into_iter()
158            .map(Message::try_from)
159            .collect::<Result<Vec<Message>, _>>()?;
160
161        let mut body = json!({
162            "model": self.model,
163            "messages": full_history,
164            "max_tokens": max_tokens,
165            "system": completion_request.preamble.unwrap_or("".to_string()),
166            "stream": true,
167        });
168
169        if let Some(temperature) = completion_request.temperature {
170            merge_inplace(&mut body, json!({ "temperature": temperature }));
171        }
172
173        if !completion_request.tools.is_empty() {
174            merge_inplace(
175                &mut body,
176                json!({
177                    "tools": completion_request
178                        .tools
179                        .into_iter()
180                        .map(|tool| ToolDefinition {
181                            name: tool.name,
182                            description: Some(tool.description),
183                            input_schema: tool.parameters,
184                        })
185                        .collect::<Vec<_>>(),
186                    "tool_choice": ToolChoice::Auto,
187                }),
188            );
189        }
190
191        if let Some(ref params) = completion_request.additional_params {
192            merge_inplace(&mut body, params.clone())
193        }
194
195        let body: Vec<u8> = serde_json::to_vec(&body)?;
196
197        let req = self
198            .client
199            .post("/v1/messages")
200            .header("Content-Type", "application/json")
201            .body(body)
202            .map_err(http_client::Error::Protocol)?;
203
204        let response = self.client.send_streaming(req).await?;
205
206        if !response.status().is_success() {
207            let mut stream = response.into_body();
208            let mut text = String::with_capacity(1024);
209            loop {
210                let Some(chunk) = stream.next().await else {
211                    break;
212                };
213
214                let chunk: Vec<u8> = chunk?.into();
215
216                let str = String::from_utf8_lossy(&chunk);
217
218                text.push_str(&str)
219            }
220            return Err(CompletionError::ProviderError(text));
221        }
222
223        let stream = sse_from_response(response.into_body());
224
225        // Use our SSE decoder to directly handle Server-Sent Events format
226        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
227            let mut current_tool_call: Option<ToolCallState> = None;
228            let mut sse_stream = Box::pin(stream);
229            let mut input_tokens = 0;
230
231            let mut text_content = String::new();
232
233            while let Some(sse_result) = sse_stream.next().await {
234                match sse_result {
235                    Ok(sse) => {
236                        // Parse the SSE data as a StreamingEvent
237                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
238                            Ok(event) => {
239                                match &event {
240                                    StreamingEvent::MessageStart { message } => {
241                                        input_tokens = message.usage.input_tokens;
242
243                                        let span = tracing::Span::current();
244                                        span.record("gen_ai.response.id", &message.id);
245                                        span.record("gen_ai.response.model_name", &message.model);
246                                    },
247                                    StreamingEvent::MessageDelta { delta, usage } => {
248                                        if delta.stop_reason.is_some() {
249                                            let usage = PartialUsage {
250                                                 output_tokens: usage.output_tokens,
251                                                 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
252                                            };
253
254                                            let span = tracing::Span::current();
255                                            span.record_token_usage(&usage);
256                                            span.record_model_output(&Message {
257                                                role: super::completion::Role::Assistant,
258                                                content: OneOrMany::one(Content::Text { text: text_content.clone() })}
259                                            );
260
261                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
262                                                usage
263                                            }))
264                                        }
265                                    }
266                                    _ => {}
267                                }
268
269                                if let Some(result) = handle_event(&event, &mut current_tool_call) {
270                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
271                                        text_content += text;
272                                    }
273                                    yield result;
274                                }
275                            },
276                            Err(e) => {
277                                if !sse.data.trim().is_empty() {
278                                    yield Err(CompletionError::ResponseError(
279                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
280                                    ));
281                                }
282                            }
283                        }
284                    },
285                    Err(e) => {
286                        yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
287                        break;
288                    }
289                }
290            }
291        }.instrument(span));
292
293        Ok(streaming::StreamingCompletionResponse::stream(stream))
294    }
295}
296
297fn handle_event(
298    event: &StreamingEvent,
299    current_tool_call: &mut Option<ToolCallState>,
300) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
301    match event {
302        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
303            ContentDelta::TextDelta { text } => {
304                if current_tool_call.is_none() {
305                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
306                }
307                None
308            }
309            ContentDelta::InputJsonDelta { partial_json } => {
310                if let Some(tool_call) = current_tool_call {
311                    tool_call.input_json.push_str(partial_json);
312                }
313                None
314            }
315            ContentDelta::ThinkingDelta { thinking } => Some(Ok(RawStreamingChoice::Reasoning {
316                id: None,
317                reasoning: thinking.clone(),
318            })),
319            ContentDelta::SignatureDelta { .. } => {
320                // Signature is used for verification of thinking blocks, we can ignore it
321                None
322            }
323        },
324        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
325            Content::ToolUse { id, name, .. } => {
326                *current_tool_call = Some(ToolCallState {
327                    name: name.clone(),
328                    id: id.clone(),
329                    input_json: String::new(),
330                });
331                None
332            }
333            // Handle other content types - they don't need special handling
334            _ => None,
335        },
336        StreamingEvent::ContentBlockStop { .. } => {
337            if let Some(tool_call) = Option::take(current_tool_call) {
338                let json_str = if tool_call.input_json.is_empty() {
339                    "{}"
340                } else {
341                    &tool_call.input_json
342                };
343                match serde_json::from_str(json_str) {
344                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
345                        name: tool_call.name,
346                        id: tool_call.id,
347                        arguments: json_value,
348                        call_id: None,
349                    })),
350                    Err(e) => Some(Err(CompletionError::from(e))),
351                }
352            } else {
353                None
354            }
355        }
356        // Ignore other event types or handle as needed
357        StreamingEvent::MessageStart { .. }
358        | StreamingEvent::MessageDelta { .. }
359        | StreamingEvent::MessageStop
360        | StreamingEvent::Ping
361        | StreamingEvent::Unknown => None,
362    }
363}
364
#[cfg(test)]
mod tests {
    //! Unit tests for SSE event deserialization and `handle_event` dispatch,
    //! with emphasis on the extended-thinking (reasoning) delta variants.
    use super::*;

    // `thinking_delta` payloads deserialize into ContentDelta::ThinkingDelta.
    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    // `signature_delta` payloads deserialize into ContentDelta::SignatureDelta.
    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    // A full content_block_delta event wrapping a thinking delta round-trips
    // through the tagged-enum deserialization.
    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    // A full content_block_delta event wrapping a signature delta round-trips
    // through the tagged-enum deserialization.
    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    // Thinking deltas are surfaced as Reasoning chunks by handle_event.
    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { id, reasoning } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected Reasoning choice"),
        }
    }

    // Signature deltas produce no chunk at all.
    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let result = handle_event(&event, &mut tool_call_state);

        // SignatureDelta should be ignored (returns None)
        assert!(result.is_none());
    }

    // Text deltas are surfaced as Message chunks when no tool call is active.
    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas should still be processed even if a tool call is in progress
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });

        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected Reasoning choice"),
        }

        // Tool call state should remain unchanged
        assert!(tool_call_state.is_some());
    }
}