rig/providers/anthropic/streaming.rs
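//! Streaming support for the Anthropic provider: builds the streaming request to
//! `/v1/messages` and decodes the resulting Server-Sent Events into Rig's
//! provider-agnostic streaming types.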

use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::{
    CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
    apply_cache_control,
};
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge_inplace;
use crate::streaming::{
    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
};
use crate::telemetry::SpanCombinator;

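/// A single event from the Anthropic Messages SSE stream, discriminated by its `type` field.
/// Unrecognized event types deserialize to `Unknown` and are ignored downstream.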
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    MessageStart {
        message: MessageStart,
    },
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    ContentBlockStop {
        index: usize,
    },
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    MessageStop,
    Ping,
    #[serde(other)]
    Unknown,
}

#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}

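/// Incremental update to a single content block: text, partial tool-call JSON,
/// thinking text, or a thinking signature.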
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    TextDelta { text: String },
    InputJsonDelta { partial_json: String },
    ThinkingDelta { thinking: String },
    SignatureDelta { signature: String },
}

#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}

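/// Token usage as reported by streaming events; `input_tokens` may be absent on some
/// events and is filled in from the `message_start` usage.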
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
        usage.output_tokens = self.output_tokens as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;
        Some(usage)
    }
}

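/// Accumulates a tool call across streaming deltas until its content block is closed.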
#[derive(Default)]
struct ToolCallState {
    name: String,
    id: String,
    internal_call_id: String,
    input_json: String,
}

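/// Accumulates thinking text and its signature until the thinking block is closed.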
#[derive(Default)]
struct ThinkingState {
    thinking: String,
    signature: String,
}

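/// Final response payload attached to the end of the stream, carrying aggregated token usage.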
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    pub usage: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
        usage.output_tokens = self.usage.output_tokens as u64;
        usage.total_tokens =
            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;

        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + 'static,
{
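    /// Builds and sends a streaming request to `/v1/messages`, then adapts the SSE
    /// stream into a [`streaming::StreamingCompletionResponse`].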
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "anthropic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
            tokens
        } else if let Some(tokens) = self.default_max_tokens {
            tokens
        } else {
            return Err(CompletionError::RequestError(
                "`max_tokens` must be set for Anthropic".into(),
            ));
        };

        let mut full_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(completion_request.chat_history);

        let mut messages = full_history
            .into_iter()
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        // Convert system prompt to array format for cache_control support
        let mut system: Vec<SystemContent> =
            if let Some(preamble) = completion_request.preamble.as_ref() {
                if preamble.is_empty() {
                    vec![]
                } else {
                    vec![SystemContent::Text {
                        text: preamble.clone(),
                        cache_control: None,
                    }]
                }
            } else {
                vec![]
            };

        // Apply cache control breakpoints only if prompt_caching is enabled
        if self.prompt_caching {
            apply_cache_control(&mut system, &mut messages);
        }

        let mut body = json!({
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "stream": true,
        });

        // Add system prompt if non-empty
        if !system.is_empty() {
            merge_inplace(&mut body, json!({ "system": system }));
        }

        if let Some(temperature) = completion_request.temperature {
            merge_inplace(&mut body, json!({ "temperature": temperature }));
        }

        if !completion_request.tools.is_empty() {
            merge_inplace(
                &mut body,
                json!({
                    "tools": completion_request
                        .tools
                        .into_iter()
                        .map(|tool| ToolDefinition {
                            name: tool.name,
                            description: Some(tool.description),
                            input_schema: tool.parameters,
                        })
                        .collect::<Vec<_>>(),
                    "tool_choice": ToolChoice::Auto,
                }),
            );
        }

        if let Some(ref params) = completion_request.additional_params {
            merge_inplace(&mut body, params.clone());
        }

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&body)?
            );
        }

        let body: Vec<u8> = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/v1/messages")?
            .body(body)
            .map_err(http_client::Error::Protocol)?;

        let stream = GenericEventSource::new(self.client.clone(), req);

        // Use our SSE decoder to directly handle Server-Sent Events format
        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
            let mut current_tool_call: Option<ToolCallState> = None;
            let mut current_thinking: Option<ThinkingState> = None;
            let mut sse_stream = Box::pin(stream);
            let mut input_tokens = 0;
            let mut final_usage = None;

            let mut text_content = String::new();

            while let Some(sse_result) = sse_stream.next().await {
                match sse_result {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(sse)) => {
                        // Parse the SSE data as a StreamingEvent
                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
                            Ok(event) => {
                                match &event {
                                    StreamingEvent::MessageStart { message } => {
                                        input_tokens = message.usage.input_tokens;

                                        let span = tracing::Span::current();
                                        span.record("gen_ai.response.id", &message.id);
                                        span.record("gen_ai.response.model", &message.model);
                                    },
                                    StreamingEvent::MessageDelta { delta, usage } => {
                                        if delta.stop_reason.is_some() {
                                            let usage = PartialUsage {
                                                output_tokens: usage.output_tokens,
                                                input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
                                            };

                                            let span = tracing::Span::current();
                                            span.record_token_usage(&usage);
                                            final_usage = Some(usage);
                                            break;
                                        }
                                    }
                                    _ => {}
                                }

                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
                                        text_content += text;
                                    }
                                    yield result;
                                }
                            },
                            Err(e) => {
                                if !sse.data.trim().is_empty() {
                                    yield Err(CompletionError::ResponseError(
                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
                                    ));
                                }
                            }
                        }
                    },
                    Err(e) => {
                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            sse_stream.close();

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage.unwrap_or_default()
            }))
        }.instrument(span));

        Ok(streaming::StreamingCompletionResponse::stream(stream))
    }
}

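/// Maps a single streaming event to at most one [`RawStreamingChoice`], accumulating
/// partial tool-call JSON and thinking/signature text in the caller-provided state.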
fn handle_event(
    event: &StreamingEvent,
    current_tool_call: &mut Option<ToolCallState>,
    current_thinking: &mut Option<ThinkingState>,
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
    match event {
        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
            ContentDelta::TextDelta { text } => {
                if current_tool_call.is_none() {
                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                }
                None
            }
            ContentDelta::InputJsonDelta { partial_json } => {
                if let Some(tool_call) = current_tool_call {
                    tool_call.input_json.push_str(partial_json);
                    // Emit the delta so UI can show progress
                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
                        id: tool_call.id.clone(),
                        internal_call_id: tool_call.internal_call_id.clone(),
                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
                    }));
                }
                None
            }
            ContentDelta::ThinkingDelta { thinking } => {
                if current_thinking.is_none() {
                    *current_thinking = Some(ThinkingState::default());
                }

                if let Some(state) = current_thinking {
                    state.thinking.push_str(thinking);
                }

                Some(Ok(RawStreamingChoice::ReasoningDelta {
                    id: None,
                    reasoning: thinking.clone(),
                }))
            }
            ContentDelta::SignatureDelta { signature } => {
                if current_thinking.is_none() {
                    *current_thinking = Some(ThinkingState::default());
                }

                if let Some(state) = current_thinking {
                    state.signature.push_str(signature);
                }

                // Don't yield signature chunks, they will be included in the final Reasoning
                None
            }
        },
        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
            Content::ToolUse { id, name, .. } => {
                let internal_call_id = nanoid::nanoid!();
                *current_tool_call = Some(ToolCallState {
                    name: name.clone(),
                    id: id.clone(),
                    internal_call_id: internal_call_id.clone(),
                    input_json: String::new(),
                });
                Some(Ok(RawStreamingChoice::ToolCallDelta {
                    id: id.clone(),
                    internal_call_id,
                    content: ToolCallDeltaContent::Name(name.clone()),
                }))
            }
            Content::Thinking { .. } => {
                *current_thinking = Some(ThinkingState::default());
                None
            }
            // Handle other content types - they don't need special handling
            _ => None,
        },
        StreamingEvent::ContentBlockStop { .. } => {
            if let Some(thinking_state) = Option::take(current_thinking)
                && !thinking_state.thinking.is_empty()
            {
                let signature = if thinking_state.signature.is_empty() {
                    None
                } else {
                    Some(thinking_state.signature)
                };

                return Some(Ok(RawStreamingChoice::Reasoning {
                    id: None,
                    reasoning: thinking_state.thinking,
                    signature,
                }));
            }

            if let Some(tool_call) = Option::take(current_tool_call) {
                let json_str = if tool_call.input_json.is_empty() {
                    "{}"
                } else {
                    &tool_call.input_json
                };
                match serde_json::from_str(json_str) {
                    Ok(json_value) => {
                        let raw_tool_call =
                            RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
                                .with_internal_call_id(tool_call.internal_call_id);
                        Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
                    }
                    Err(e) => Some(Err(CompletionError::from(e))),
                }
            } else {
                None
            }
        }
        // Ignore other event types or handle as needed
        StreamingEvent::MessageStart { .. }
        | StreamingEvent::MessageDelta { .. }
        | StreamingEvent::MessageStop
        | StreamingEvent::Ping
        | StreamingEvent::Unknown => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Verify thinking state was updated
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // SignatureDelta should not yield anything (returns None)
        assert!(result.is_none());

        // But signature should be captured in thinking state
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas should still be processed even if a tool call is in progress
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Tool call state should remain unchanged
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Should emit a ToolCallDelta
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // Verify the input_json was accumulated
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        // First delta
        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        // Second delta
        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        // Third delta
        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        // Verify accumulated JSON
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Final ContentBlockStop should emit complete tool call
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // Tool call state should be taken
        assert!(tool_call_state.is_none());
    }
}