rig/providers/anthropic/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9    CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10    apply_cache_control,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::streaming::{self, RawStreamingChoice, RawStreamingToolCall, StreamingResult};
17use crate::telemetry::SpanCombinator;
18
19#[derive(Debug, Deserialize)]
20#[serde(tag = "type", rename_all = "snake_case")]
21pub enum StreamingEvent {
22    MessageStart {
23        message: MessageStart,
24    },
25    ContentBlockStart {
26        index: usize,
27        content_block: Content,
28    },
29    ContentBlockDelta {
30        index: usize,
31        delta: ContentDelta,
32    },
33    ContentBlockStop {
34        index: usize,
35    },
36    MessageDelta {
37        delta: MessageDelta,
38        usage: PartialUsage,
39    },
40    MessageStop,
41    Ping,
42    #[serde(other)]
43    Unknown,
44}
45
46#[derive(Debug, Deserialize)]
47pub struct MessageStart {
48    pub id: String,
49    pub role: String,
50    pub content: Vec<Content>,
51    pub model: String,
52    pub stop_reason: Option<String>,
53    pub stop_sequence: Option<String>,
54    pub usage: Usage,
55}
56
57#[derive(Debug, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ContentDelta {
60    TextDelta { text: String },
61    InputJsonDelta { partial_json: String },
62    ThinkingDelta { thinking: String },
63    SignatureDelta { signature: String },
64}
65
66#[derive(Debug, Deserialize)]
67pub struct MessageDelta {
68    pub stop_reason: Option<String>,
69    pub stop_sequence: Option<String>,
70}
71
72#[derive(Debug, Deserialize, Clone, Serialize, Default)]
73pub struct PartialUsage {
74    pub output_tokens: usize,
75    #[serde(default)]
76    pub input_tokens: Option<usize>,
77}
78
79impl GetTokenUsage for PartialUsage {
80    fn token_usage(&self) -> Option<crate::completion::Usage> {
81        let mut usage = crate::completion::Usage::new();
82
83        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
84        usage.output_tokens = self.output_tokens as u64;
85        usage.total_tokens = usage.input_tokens + usage.output_tokens;
86        Some(usage)
87    }
88}
89
/// Accumulator for the tool call whose content block is currently being streamed.
///
/// Created on a `tool_use` content block start, extended by every `input_json_delta`,
/// and consumed when the block stops.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the content block start.
    name: String,
    /// Tool call id taken from the content block start.
    id: String,
    /// Concatenation of all JSON fragments received so far.
    input_json: String,
}
96
/// Accumulator for the thinking block that is currently being streamed.
///
/// Both fields grow as deltas arrive and are emitted together when the block stops.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of all `thinking_delta` fragments.
    thinking: String,
    /// Concatenation of all `signature_delta` fragments.
    signature: String,
}
102
103#[derive(Clone, Debug, Deserialize, Serialize)]
104pub struct StreamingCompletionResponse {
105    pub usage: PartialUsage,
106}
107
108impl GetTokenUsage for StreamingCompletionResponse {
109    fn token_usage(&self) -> Option<crate::completion::Usage> {
110        let mut usage = crate::completion::Usage::new();
111        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
112        usage.output_tokens = self.usage.output_tokens as u64;
113        usage.total_tokens =
114            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
115
116        Some(usage)
117    }
118}
119
120impl<T> CompletionModel<T>
121where
122    T: HttpClientExt + Clone + Default + 'static,
123{
124    pub(crate) async fn stream(
125        &self,
126        completion_request: CompletionRequest,
127    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
128    {
129        let span = if tracing::Span::current().is_disabled() {
130            info_span!(
131                target: "rig::completions",
132                "chat_streaming",
133                gen_ai.operation.name = "chat_streaming",
134                gen_ai.provider.name = "anthropic",
135                gen_ai.request.model = self.model,
136                gen_ai.system_instructions = &completion_request.preamble,
137                gen_ai.response.id = tracing::field::Empty,
138                gen_ai.response.model = self.model,
139                gen_ai.usage.output_tokens = tracing::field::Empty,
140                gen_ai.usage.input_tokens = tracing::field::Empty,
141                gen_ai.input.messages = tracing::field::Empty,
142                gen_ai.output.messages = tracing::field::Empty,
143            )
144        } else {
145            tracing::Span::current()
146        };
147        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
148            tokens
149        } else if let Some(tokens) = self.default_max_tokens {
150            tokens
151        } else {
152            return Err(CompletionError::RequestError(
153                "`max_tokens` must be set for Anthropic".into(),
154            ));
155        };
156
157        let mut full_history = vec![];
158        if let Some(docs) = completion_request.normalized_documents() {
159            full_history.push(docs);
160        }
161        full_history.extend(completion_request.chat_history);
162
163        let mut messages = full_history
164            .into_iter()
165            .map(Message::try_from)
166            .collect::<Result<Vec<Message>, _>>()?;
167
168        // Convert system prompt to array format for cache_control support
169        let mut system: Vec<SystemContent> =
170            if let Some(preamble) = completion_request.preamble.as_ref() {
171                if preamble.is_empty() {
172                    vec![]
173                } else {
174                    vec![SystemContent::Text {
175                        text: preamble.clone(),
176                        cache_control: None,
177                    }]
178                }
179            } else {
180                vec![]
181            };
182
183        // Apply cache control breakpoints only if prompt_caching is enabled
184        if self.prompt_caching {
185            apply_cache_control(&mut system, &mut messages);
186        }
187
188        let mut body = json!({
189            "model": self.model,
190            "messages": messages,
191            "max_tokens": max_tokens,
192            "stream": true,
193        });
194
195        // Add system prompt if non-empty
196        if !system.is_empty() {
197            merge_inplace(&mut body, json!({ "system": system }));
198        }
199
200        if let Some(temperature) = completion_request.temperature {
201            merge_inplace(&mut body, json!({ "temperature": temperature }));
202        }
203
204        if !completion_request.tools.is_empty() {
205            merge_inplace(
206                &mut body,
207                json!({
208                    "tools": completion_request
209                        .tools
210                        .into_iter()
211                        .map(|tool| ToolDefinition {
212                            name: tool.name,
213                            description: Some(tool.description),
214                            input_schema: tool.parameters,
215                        })
216                        .collect::<Vec<_>>(),
217                    "tool_choice": ToolChoice::Auto,
218                }),
219            );
220        }
221
222        if let Some(ref params) = completion_request.additional_params {
223            merge_inplace(&mut body, params.clone())
224        }
225
226        if enabled!(Level::TRACE) {
227            tracing::trace!(
228                target: "rig::completions",
229                "Anthropic completion request: {}",
230                serde_json::to_string_pretty(&body)?
231            );
232        }
233
234        let body: Vec<u8> = serde_json::to_vec(&body)?;
235
236        let req = self
237            .client
238            .post("/v1/messages")?
239            .body(body)
240            .map_err(http_client::Error::Protocol)?;
241
242        let stream = GenericEventSource::new(self.client.clone(), req);
243
244        // Use our SSE decoder to directly handle Server-Sent Events format
245        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
246            let mut current_tool_call: Option<ToolCallState> = None;
247            let mut current_thinking: Option<ThinkingState> = None;
248            let mut sse_stream = Box::pin(stream);
249            let mut input_tokens = 0;
250            let mut final_usage = None;
251
252            let mut text_content = String::new();
253
254            while let Some(sse_result) = sse_stream.next().await {
255                match sse_result {
256                    Ok(Event::Open) => {}
257                    Ok(Event::Message(sse)) => {
258                        // Parse the SSE data as a StreamingEvent
259                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
260                            Ok(event) => {
261                                match &event {
262                                    StreamingEvent::MessageStart { message } => {
263                                        input_tokens = message.usage.input_tokens;
264
265                                        let span = tracing::Span::current();
266                                        span.record("gen_ai.response.id", &message.id);
267                                        span.record("gen_ai.response.model_name", &message.model);
268                                    },
269                                    StreamingEvent::MessageDelta { delta, usage } => {
270                                        if delta.stop_reason.is_some() {
271                                            let usage = PartialUsage {
272                                                 output_tokens: usage.output_tokens,
273                                                 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
274                                            };
275
276                                            let span = tracing::Span::current();
277                                            span.record_token_usage(&usage);
278                                            final_usage = Some(usage);
279                                            break;
280                                        }
281                                    }
282                                    _ => {}
283                                }
284
285                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
286                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
287                                        text_content += text;
288                                    }
289                                    yield result;
290                                }
291                            },
292                            Err(e) => {
293                                if !sse.data.trim().is_empty() {
294                                    yield Err(CompletionError::ResponseError(
295                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
296                                    ));
297                                }
298                            }
299                        }
300                    },
301                    Err(e) => {
302                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
303                        break;
304                    }
305                }
306            }
307
308            // Ensure event source is closed when stream ends
309            sse_stream.close();
310
311            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
312                usage: final_usage.unwrap_or_default()
313            }))
314        }.instrument(span));
315
316        Ok(streaming::StreamingCompletionResponse::stream(stream))
317    }
318}
319
320fn handle_event(
321    event: &StreamingEvent,
322    current_tool_call: &mut Option<ToolCallState>,
323    current_thinking: &mut Option<ThinkingState>,
324) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
325    match event {
326        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
327            ContentDelta::TextDelta { text } => {
328                if current_tool_call.is_none() {
329                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
330                }
331                None
332            }
333            ContentDelta::InputJsonDelta { partial_json } => {
334                if let Some(tool_call) = current_tool_call {
335                    tool_call.input_json.push_str(partial_json);
336                    // Emit the delta so UI can show progress
337                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
338                        id: tool_call.id.clone(),
339                        delta: partial_json.clone(),
340                    }));
341                }
342                None
343            }
344            ContentDelta::ThinkingDelta { thinking } => {
345                if current_thinking.is_none() {
346                    *current_thinking = Some(ThinkingState::default());
347                }
348
349                if let Some(state) = current_thinking {
350                    state.thinking.push_str(thinking);
351                }
352
353                Some(Ok(RawStreamingChoice::ReasoningDelta {
354                    id: None,
355                    reasoning: thinking.clone(),
356                }))
357            }
358            ContentDelta::SignatureDelta { signature } => {
359                if current_thinking.is_none() {
360                    *current_thinking = Some(ThinkingState::default());
361                }
362
363                if let Some(state) = current_thinking {
364                    state.signature.push_str(signature);
365                }
366
367                // Don't yield signature chunks, they will be included in the final Reasoning
368                None
369            }
370        },
371        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
372            Content::ToolUse { id, name, .. } => {
373                *current_tool_call = Some(ToolCallState {
374                    name: name.clone(),
375                    id: id.clone(),
376                    input_json: String::new(),
377                });
378                None
379            }
380            Content::Thinking { .. } => {
381                *current_thinking = Some(ThinkingState::default());
382                None
383            }
384            // Handle other content types - they don't need special handling
385            _ => None,
386        },
387        StreamingEvent::ContentBlockStop { .. } => {
388            if let Some(thinking_state) = Option::take(current_thinking)
389                && !thinking_state.thinking.is_empty()
390            {
391                let signature = if thinking_state.signature.is_empty() {
392                    None
393                } else {
394                    Some(thinking_state.signature)
395                };
396
397                return Some(Ok(RawStreamingChoice::Reasoning {
398                    id: None,
399                    reasoning: thinking_state.thinking,
400                    signature,
401                }));
402            }
403
404            if let Some(tool_call) = Option::take(current_tool_call) {
405                let json_str = if tool_call.input_json.is_empty() {
406                    "{}"
407                } else {
408                    &tool_call.input_json
409                };
410                match serde_json::from_str(json_str) {
411                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
412                        RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value),
413                    ))),
414                    Err(e) => Some(Err(CompletionError::from(e))),
415                }
416            } else {
417                None
418            }
419        }
420        // Ignore other event types or handle as needed
421        StreamingEvent::MessageStart { .. }
422        | StreamingEvent::MessageDelta { .. }
423        | StreamingEvent::MessageStop
424        | StreamingEvent::Ping
425        | StreamingEvent::Unknown => None,
426    }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// Builds an `input_json_delta` event carrying `fragment`.
    fn json_delta_event(fragment: &str) -> StreamingEvent {
        delta_event(ContentDelta::InputJsonDelta {
            partial_json: fragment.to_string(),
        })
    }

    /// State of a tool call that has started but received no input yet.
    fn pending_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(json).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(json).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(json).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(json).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        match outcome.unwrap().unwrap() {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // The delta must also have been accumulated in the thinking state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are never forwarded on their own.
        assert!(outcome.is_none());

        // They are still captured for the final Reasoning item.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        match outcome.unwrap().unwrap() {
            RawStreamingChoice::Message(text) => assert_eq!(text, "Hello, world!"),
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // A thinking delta is processed even while a tool call is in progress.
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        match outcome.unwrap().unwrap() {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // The tool call must still be pending afterwards.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = json_delta_event("{\"arg\":\"value");

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // The fragment is forwarded as a ToolCallDelta.
        assert!(outcome.is_some());
        match outcome.unwrap().unwrap() {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            other => panic!("Expected ToolCallDelta choice, got {:?}", other),
        }

        // The fragment is also accumulated on the pending tool call.
        assert!(tool_call_state.is_some());
        assert_eq!(tool_call_state.unwrap().input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        // Three fragments that together form one JSON object; each one is forwarded.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let forwarded = handle_event(
                &json_delta_event(fragment),
                &mut tool_call_state,
                &mut thinking_state,
            );
            assert!(forwarded.is_some());
        }

        // The fragments were concatenated in arrival order.
        assert!(tool_call_state.is_some());
        assert_eq!(
            tool_call_state.as_ref().unwrap().input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Stopping the block emits the complete tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let completed = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(completed.is_some());

        match completed.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The pending state is consumed by the stop event.
        assert!(tool_call_state.is_none());
    }
}