rig/providers/anthropic/streaming.rs

use async_stream::stream;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{Level, enabled, info_span};
use tracing_futures::Instrument;

use super::completion::{
    CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
    apply_cache_control,
};
use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
use crate::http_client::sse::{Event, GenericEventSource};
use crate::http_client::{self, HttpClientExt};
use crate::json_utils::merge_inplace;
use crate::streaming::{
    self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
};
use crate::telemetry::SpanCombinator;

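/// Server-sent events emitted by the Anthropic Messages streaming API, tagged by
/// their `type` field. Unrecognized event types deserialize to `Unknown` so new
/// API events do not break the stream.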
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    MessageStart {
        message: MessageStart,
    },
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    ContentBlockStop {
        index: usize,
    },
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    MessageStop,
    Ping,
    #[serde(other)]
    Unknown,
}

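/// Payload of the `message_start` event: initial message metadata, including the
/// model and the starting token usage.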
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}

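/// Incremental update to a content block: streamed text, partial tool-input JSON,
/// thinking text, or a thinking signature.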
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    TextDelta { text: String },
    InputJsonDelta { partial_json: String },
    ThinkingDelta { thinking: String },
    SignatureDelta { signature: String },
}

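/// Top-level changes delivered by `message_delta` events, such as the stop reason.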
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}

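/// Token usage as reported by streaming events; `input_tokens` may be absent until
/// it is filled in from the `message_start` event.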
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    pub output_tokens: usize,
    #[serde(default)]
    pub input_tokens: Option<usize>,
}

impl GetTokenUsage for PartialUsage {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();

        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
        usage.output_tokens = self.output_tokens as u64;
        usage.total_tokens = usage.input_tokens + usage.output_tokens;
        Some(usage)
    }
}

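/// Accumulator for an in-progress tool call: the id and name arrive with
/// `content_block_start`, and the input JSON is concatenated from
/// `input_json_delta` fragments.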
#[derive(Default)]
struct ToolCallState {
    name: String,
    id: String,
    input_json: String,
}

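/// Accumulator for an extended-thinking block: thinking text and its signature are
/// collected across deltas and emitted when the block stops.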
#[derive(Default)]
struct ThinkingState {
    thinking: String,
    signature: String,
}

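/// Final response yielded once the stream ends, carrying the aggregated token usage.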
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    pub usage: PartialUsage,
}

impl GetTokenUsage for StreamingCompletionResponse {
    fn token_usage(&self) -> Option<crate::completion::Usage> {
        let mut usage = crate::completion::Usage::new();
        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
        usage.output_tokens = self.usage.output_tokens as u64;
        usage.total_tokens =
            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;

        Some(usage)
    }
}

impl<T> CompletionModel<T>
where
    T: HttpClientExt + Clone + Default + 'static,
{
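    /// Sends a streaming request to `/v1/messages` and adapts the SSE response into
    /// `RawStreamingChoice` items (text, tool calls, and reasoning), finishing with
    /// a `FinalResponse` that carries the aggregated token usage.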
    pub(crate) async fn stream(
        &self,
        completion_request: CompletionRequest,
    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
    {
        let span = if tracing::Span::current().is_disabled() {
            info_span!(
                target: "rig::completions",
                "chat_streaming",
                gen_ai.operation.name = "chat_streaming",
                gen_ai.provider.name = "anthropic",
                gen_ai.request.model = self.model,
                gen_ai.system_instructions = &completion_request.preamble,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = self.model,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };
        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
            tokens
        } else if let Some(tokens) = self.default_max_tokens {
            tokens
        } else {
            return Err(CompletionError::RequestError(
                "`max_tokens` must be set for Anthropic".into(),
            ));
        };

        let mut full_history = vec![];
        if let Some(docs) = completion_request.normalized_documents() {
            full_history.push(docs);
        }
        full_history.extend(completion_request.chat_history);

        let mut messages = full_history
            .into_iter()
            .map(Message::try_from)
            .collect::<Result<Vec<Message>, _>>()?;

        // Convert system prompt to array format for cache_control support
        let mut system: Vec<SystemContent> =
            if let Some(preamble) = completion_request.preamble.as_ref() {
                if preamble.is_empty() {
                    vec![]
                } else {
                    vec![SystemContent::Text {
                        text: preamble.clone(),
                        cache_control: None,
                    }]
                }
            } else {
                vec![]
            };

        // Apply cache control breakpoints only if prompt_caching is enabled
        if self.prompt_caching {
            apply_cache_control(&mut system, &mut messages);
        }

        let mut body = json!({
            "model": self.model,
            "messages": messages,
            "max_tokens": max_tokens,
            "stream": true,
        });

        // Add system prompt if non-empty
        if !system.is_empty() {
            merge_inplace(&mut body, json!({ "system": system }));
        }

        if let Some(temperature) = completion_request.temperature {
            merge_inplace(&mut body, json!({ "temperature": temperature }));
        }

        if !completion_request.tools.is_empty() {
            merge_inplace(
                &mut body,
                json!({
                    "tools": completion_request
                        .tools
                        .into_iter()
                        .map(|tool| ToolDefinition {
                            name: tool.name,
                            description: Some(tool.description),
                            input_schema: tool.parameters,
                        })
                        .collect::<Vec<_>>(),
                    "tool_choice": ToolChoice::Auto,
                }),
            );
        }

        if let Some(ref params) = completion_request.additional_params {
            merge_inplace(&mut body, params.clone());
        }

        if enabled!(Level::TRACE) {
            tracing::trace!(
                target: "rig::completions",
                "Anthropic completion request: {}",
                serde_json::to_string_pretty(&body)?
            );
        }

        let body: Vec<u8> = serde_json::to_vec(&body)?;

        let req = self
            .client
            .post("/v1/messages")?
            .body(body)
            .map_err(http_client::Error::Protocol)?;

        let stream = GenericEventSource::new(self.client.clone(), req);

        // Use our SSE decoder to directly handle Server-Sent Events format
        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
            let mut current_tool_call: Option<ToolCallState> = None;
            let mut current_thinking: Option<ThinkingState> = None;
            let mut sse_stream = Box::pin(stream);
            let mut input_tokens = 0;
            let mut final_usage = None;

            let mut text_content = String::new();

            while let Some(sse_result) = sse_stream.next().await {
                match sse_result {
                    Ok(Event::Open) => {}
                    Ok(Event::Message(sse)) => {
                        // Parse the SSE data as a StreamingEvent
                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
                            Ok(event) => {
                                match &event {
                                    StreamingEvent::MessageStart { message } => {
                                        input_tokens = message.usage.input_tokens;

                                        let span = tracing::Span::current();
                                        span.record("gen_ai.response.id", &message.id);
                                        span.record("gen_ai.response.model", &message.model);
                                    },
                                    StreamingEvent::MessageDelta { delta, usage } => {
                                        if delta.stop_reason.is_some() {
                                            let usage = PartialUsage {
                                                output_tokens: usage.output_tokens,
                                                input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
                                            };

                                            let span = tracing::Span::current();
                                            span.record_token_usage(&usage);
                                            final_usage = Some(usage);
                                            break;
                                        }
                                    }
                                    _ => {}
                                }

                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
                                        text_content += text;
                                    }
                                    yield result;
                                }
                            },
                            Err(e) => {
                                if !sse.data.trim().is_empty() {
                                    yield Err(CompletionError::ResponseError(
                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
                                    ));
                                }
                            }
                        }
                    },
                    Err(e) => {
                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
                        break;
                    }
                }
            }

            // Ensure event source is closed when stream ends
            sse_stream.close();

            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
                usage: final_usage.unwrap_or_default()
            }))
        }.instrument(span));

        Ok(streaming::StreamingCompletionResponse::stream(stream))
    }
}

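/// Translates a single streaming event into an optional `RawStreamingChoice`,
/// updating the in-progress tool-call and thinking accumulators as content blocks
/// start, receive deltas, and stop.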
fn handle_event(
    event: &StreamingEvent,
    current_tool_call: &mut Option<ToolCallState>,
    current_thinking: &mut Option<ThinkingState>,
) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
    match event {
        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
            ContentDelta::TextDelta { text } => {
                if current_tool_call.is_none() {
                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
                }
                None
            }
            ContentDelta::InputJsonDelta { partial_json } => {
                if let Some(tool_call) = current_tool_call {
                    tool_call.input_json.push_str(partial_json);
                    // Emit the delta so the UI can show progress
                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
                        id: tool_call.id.clone(),
                        content: ToolCallDeltaContent::Delta(partial_json.clone()),
                    }));
                }
                None
            }
            ContentDelta::ThinkingDelta { thinking } => {
                if current_thinking.is_none() {
                    *current_thinking = Some(ThinkingState::default());
                }

                if let Some(state) = current_thinking {
                    state.thinking.push_str(thinking);
                }

                Some(Ok(RawStreamingChoice::ReasoningDelta {
                    id: None,
                    reasoning: thinking.clone(),
                }))
            }
            ContentDelta::SignatureDelta { signature } => {
                if current_thinking.is_none() {
                    *current_thinking = Some(ThinkingState::default());
                }

                if let Some(state) = current_thinking {
                    state.signature.push_str(signature);
                }

                // Don't yield signature chunks; they will be included in the final Reasoning
                None
            }
        },
        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
            Content::ToolUse { id, name, .. } => {
                *current_tool_call = Some(ToolCallState {
                    name: name.clone(),
                    id: id.clone(),
                    input_json: String::new(),
                });
                Some(Ok(RawStreamingChoice::ToolCallDelta {
                    id: id.clone(),
                    content: ToolCallDeltaContent::Name(name.clone()),
                }))
            }
            Content::Thinking { .. } => {
                *current_thinking = Some(ThinkingState::default());
                None
            }
            // Handle other content types - they don't need special handling
            _ => None,
        },
        StreamingEvent::ContentBlockStop { .. } => {
            if let Some(thinking_state) = current_thinking.take()
                && !thinking_state.thinking.is_empty()
            {
                let signature = if thinking_state.signature.is_empty() {
                    None
                } else {
                    Some(thinking_state.signature)
                };

                return Some(Ok(RawStreamingChoice::Reasoning {
                    id: None,
                    reasoning: thinking_state.thinking,
                    signature,
                }));
            }

            if let Some(tool_call) = current_tool_call.take() {
                let json_str = if tool_call.input_json.is_empty() {
                    "{}"
                } else {
                    &tool_call.input_json
                };
                match serde_json::from_str(json_str) {
                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
                        RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value),
                    ))),
                    Err(e) => Some(Err(CompletionError::from(e))),
                }
            } else {
                None
            }
        }
        // Ignore other event types or handle as needed
        StreamingEvent::MessageStart { .. }
        | StreamingEvent::MessageDelta { .. }
        | StreamingEvent::MessageStop
        | StreamingEvent::Ping
        | StreamingEvent::Unknown => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Verify thinking state was updated
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // SignatureDelta should not yield anything (returns None)
        assert!(result.is_none());

        // But signature should be captured in thinking state
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas should still be processed even if a tool call is in progress
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ReasoningDelta { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected ReasoningDelta choice"),
        }

        // Tool call state should remain unchanged
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Should emit a ToolCallDelta
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, content } => {
                assert_eq!(id, "tool_123");
                match content {
                    ToolCallDeltaContent::Delta(delta) => assert_eq!(delta, "{\"arg\":\"value"),
                    _ => panic!("Expected Delta content"),
                }
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // Verify the input_json was accumulated
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        // First delta
        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        // Second delta
        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        // Third delta
        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        // Verify accumulated JSON
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Final ContentBlockStop should emit complete tool call
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // Tool call state should be taken
        assert!(tool_call_state.is_none());
    }
}