// rig/providers/anthropic/streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use crate::OneOrMany;
10use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
17#[derive(Debug, Deserialize)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum StreamingEvent {
20    MessageStart {
21        message: MessageStart,
22    },
23    ContentBlockStart {
24        index: usize,
25        content_block: Content,
26    },
27    ContentBlockDelta {
28        index: usize,
29        delta: ContentDelta,
30    },
31    ContentBlockStop {
32        index: usize,
33    },
34    MessageDelta {
35        delta: MessageDelta,
36        usage: PartialUsage,
37    },
38    MessageStop,
39    Ping,
40    #[serde(other)]
41    Unknown,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct MessageStart {
46    pub id: String,
47    pub role: String,
48    pub content: Vec<Content>,
49    pub model: String,
50    pub stop_reason: Option<String>,
51    pub stop_sequence: Option<String>,
52    pub usage: Usage,
53}
54
55#[derive(Debug, Deserialize)]
56#[serde(tag = "type", rename_all = "snake_case")]
57pub enum ContentDelta {
58    TextDelta { text: String },
59    InputJsonDelta { partial_json: String },
60    ThinkingDelta { thinking: String },
61    SignatureDelta { signature: String },
62}
63
64#[derive(Debug, Deserialize)]
65pub struct MessageDelta {
66    pub stop_reason: Option<String>,
67    pub stop_sequence: Option<String>,
68}
69
70#[derive(Debug, Deserialize, Clone, Serialize, Default)]
71pub struct PartialUsage {
72    pub output_tokens: usize,
73    #[serde(default)]
74    pub input_tokens: Option<usize>,
75}
76
77impl GetTokenUsage for PartialUsage {
78    fn token_usage(&self) -> Option<crate::completion::Usage> {
79        let mut usage = crate::completion::Usage::new();
80
81        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82        usage.output_tokens = self.output_tokens as u64;
83        usage.total_tokens = usage.input_tokens + usage.output_tokens;
84        Some(usage)
85    }
86}
87
/// Buffer for a `tool_use` content block that is still streaming.
///
/// `input_json` collects `input_json_delta` fragments until `content_block_stop` arrives,
/// at which point the buffer is parsed into the tool call's arguments.
#[derive(Default)]
struct ToolCallState {
    name: String,
    id: String,
    input_json: String,
}
94
/// Buffer for a thinking content block that is still streaming.
///
/// Collects `thinking_delta` text and `signature_delta` chunks until `content_block_stop`.
#[derive(Default)]
struct ThinkingState {
    thinking: String,
    signature: String,
}
100
101#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct StreamingCompletionResponse {
103    pub usage: PartialUsage,
104}
105
106impl GetTokenUsage for StreamingCompletionResponse {
107    fn token_usage(&self) -> Option<crate::completion::Usage> {
108        let mut usage = crate::completion::Usage::new();
109        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
110        usage.output_tokens = self.usage.output_tokens as u64;
111        usage.total_tokens =
112            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
113
114        Some(usage)
115    }
116}
117
118impl<T> CompletionModel<T>
119where
120    T: HttpClientExt + Clone + Default + 'static,
121{
122    pub(crate) async fn stream(
123        &self,
124        completion_request: CompletionRequest,
125    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
126    {
127        let span = if tracing::Span::current().is_disabled() {
128            info_span!(
129                target: "rig::completions",
130                "chat_streaming",
131                gen_ai.operation.name = "chat_streaming",
132                gen_ai.provider.name = "anthropic",
133                gen_ai.request.model = self.model,
134                gen_ai.system_instructions = &completion_request.preamble,
135                gen_ai.response.id = tracing::field::Empty,
136                gen_ai.response.model = self.model,
137                gen_ai.usage.output_tokens = tracing::field::Empty,
138                gen_ai.usage.input_tokens = tracing::field::Empty,
139                gen_ai.input.messages = tracing::field::Empty,
140                gen_ai.output.messages = tracing::field::Empty,
141            )
142        } else {
143            tracing::Span::current()
144        };
145        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
146            tokens
147        } else if let Some(tokens) = self.default_max_tokens {
148            tokens
149        } else {
150            return Err(CompletionError::RequestError(
151                "`max_tokens` must be set for Anthropic".into(),
152            ));
153        };
154
155        let mut full_history = vec![];
156        if let Some(docs) = completion_request.normalized_documents() {
157            full_history.push(docs);
158        }
159        full_history.extend(completion_request.chat_history);
160        span.record_model_input(&full_history);
161
162        let full_history = full_history
163            .into_iter()
164            .map(Message::try_from)
165            .collect::<Result<Vec<Message>, _>>()?;
166
167        let mut body = json!({
168            "model": self.model,
169            "messages": full_history,
170            "max_tokens": max_tokens,
171            "system": completion_request.preamble.unwrap_or("".to_string()),
172            "stream": true,
173        });
174
175        if let Some(temperature) = completion_request.temperature {
176            merge_inplace(&mut body, json!({ "temperature": temperature }));
177        }
178
179        if !completion_request.tools.is_empty() {
180            merge_inplace(
181                &mut body,
182                json!({
183                    "tools": completion_request
184                        .tools
185                        .into_iter()
186                        .map(|tool| ToolDefinition {
187                            name: tool.name,
188                            description: Some(tool.description),
189                            input_schema: tool.parameters,
190                        })
191                        .collect::<Vec<_>>(),
192                    "tool_choice": ToolChoice::Auto,
193                }),
194            );
195        }
196
197        if let Some(ref params) = completion_request.additional_params {
198            merge_inplace(&mut body, params.clone())
199        }
200
201        let body: Vec<u8> = serde_json::to_vec(&body)?;
202
203        let req = self
204            .client
205            .post("/v1/messages")?
206            .body(body)
207            .map_err(http_client::Error::Protocol)?;
208
209        let stream = GenericEventSource::new(self.client.http_client().clone(), req);
210
211        // Use our SSE decoder to directly handle Server-Sent Events format
212        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
213            let mut current_tool_call: Option<ToolCallState> = None;
214            let mut current_thinking: Option<ThinkingState> = None;
215            let mut sse_stream = Box::pin(stream);
216            let mut input_tokens = 0;
217            let mut final_usage = None;
218
219            let mut text_content = String::new();
220
221            while let Some(sse_result) = sse_stream.next().await {
222                match sse_result {
223                    Ok(Event::Open) => {}
224                    Ok(Event::Message(sse)) => {
225                        // Parse the SSE data as a StreamingEvent
226                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
227                            Ok(event) => {
228                                match &event {
229                                    StreamingEvent::MessageStart { message } => {
230                                        input_tokens = message.usage.input_tokens;
231
232                                        let span = tracing::Span::current();
233                                        span.record("gen_ai.response.id", &message.id);
234                                        span.record("gen_ai.response.model_name", &message.model);
235                                    },
236                                    StreamingEvent::MessageDelta { delta, usage } => {
237                                        if delta.stop_reason.is_some() {
238                                            let usage = PartialUsage {
239                                                 output_tokens: usage.output_tokens,
240                                                 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
241                                            };
242
243                                            let span = tracing::Span::current();
244                                            span.record_token_usage(&usage);
245                                            span.record_model_output(&Message {
246                                                role: super::completion::Role::Assistant,
247                                                content: OneOrMany::one(Content::Text { text: text_content.clone() })}
248                                            );
249
250                                            final_usage = Some(usage);
251                                            break;
252                                        }
253                                    }
254                                    _ => {}
255                                }
256
257                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
258                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
259                                        text_content += text;
260                                    }
261                                    yield result;
262                                }
263                            },
264                            Err(e) => {
265                                if !sse.data.trim().is_empty() {
266                                    yield Err(CompletionError::ResponseError(
267                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
268                                    ));
269                                }
270                            }
271                        }
272                    },
273                    Err(e) => {
274                        yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
275                        break;
276                    }
277                }
278            }
279
280            // Ensure event source is closed when stream ends
281            sse_stream.close();
282
283            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
284                usage: final_usage.unwrap_or_default()
285            }))
286        }.instrument(span));
287
288        Ok(streaming::StreamingCompletionResponse::stream(stream))
289    }
290}
291
292fn handle_event(
293    event: &StreamingEvent,
294    current_tool_call: &mut Option<ToolCallState>,
295    current_thinking: &mut Option<ThinkingState>,
296) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
297    match event {
298        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
299            ContentDelta::TextDelta { text } => {
300                if current_tool_call.is_none() {
301                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
302                }
303                None
304            }
305            ContentDelta::InputJsonDelta { partial_json } => {
306                if let Some(tool_call) = current_tool_call {
307                    tool_call.input_json.push_str(partial_json);
308                    // Emit the delta so UI can show progress
309                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
310                        id: tool_call.id.clone(),
311                        delta: partial_json.clone(),
312                    }));
313                }
314                None
315            }
316            ContentDelta::ThinkingDelta { thinking } => {
317                if current_thinking.is_none() {
318                    *current_thinking = Some(ThinkingState::default());
319                }
320
321                if let Some(state) = current_thinking {
322                    state.thinking.push_str(thinking);
323                }
324
325                Some(Ok(RawStreamingChoice::Reasoning {
326                    id: None,
327                    reasoning: thinking.clone(),
328                    signature: None,
329                }))
330            }
331            ContentDelta::SignatureDelta { signature } => {
332                if current_thinking.is_none() {
333                    *current_thinking = Some(ThinkingState::default());
334                }
335
336                if let Some(state) = current_thinking {
337                    state.signature.push_str(signature);
338                }
339
340                // Don't yield signature chunks, they will be included in the final Reasoning
341                None
342            }
343        },
344        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
345            Content::ToolUse { id, name, .. } => {
346                *current_tool_call = Some(ToolCallState {
347                    name: name.clone(),
348                    id: id.clone(),
349                    input_json: String::new(),
350                });
351                None
352            }
353            Content::Thinking { .. } => {
354                *current_thinking = Some(ThinkingState::default());
355                None
356            }
357            // Handle other content types - they don't need special handling
358            _ => None,
359        },
360        StreamingEvent::ContentBlockStop { .. } => {
361            if let Some(thinking_state) = Option::take(current_thinking)
362                && !thinking_state.thinking.is_empty()
363            {
364                let signature = if thinking_state.signature.is_empty() {
365                    None
366                } else {
367                    Some(thinking_state.signature)
368                };
369
370                return Some(Ok(RawStreamingChoice::Reasoning {
371                    id: None,
372                    reasoning: thinking_state.thinking,
373                    signature,
374                }));
375            }
376
377            if let Some(tool_call) = Option::take(current_tool_call) {
378                let json_str = if tool_call.input_json.is_empty() {
379                    "{}"
380                } else {
381                    &tool_call.input_json
382                };
383                match serde_json::from_str(json_str) {
384                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
385                        name: tool_call.name,
386                        id: tool_call.id,
387                        arguments: json_value,
388                        call_id: None,
389                    })),
390                    Err(e) => Some(Err(CompletionError::from(e))),
391                }
392            } else {
393                None
394            }
395        }
396        // Ignore other event types or handle as needed
397        StreamingEvent::MessageStart { .. }
398        | StreamingEvent::MessageDelta { .. }
399        | StreamingEvent::MessageStop
400        | StreamingEvent::Ping
401        | StreamingEvent::Unknown => None,
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_message_delta_deserialization() {
        let json = r#"{
            "type": "message_delta",
            "delta": { "stop_reason": "end_turn", "stop_sequence": null },
            "usage": { "output_tokens": 15 }
        }"#;

        match serde_json::from_str::<StreamingEvent>(json).unwrap() {
            StreamingEvent::MessageDelta { delta, usage } => {
                assert_eq!(delta.stop_reason.as_deref(), Some("end_turn"));
                assert_eq!(delta.stop_sequence, None);
                assert_eq!(usage.output_tokens, 15);
                // `input_tokens` is optional and defaults to `None` when absent.
                assert_eq!(usage.input_tokens, None);
            }
            other => panic!("Expected MessageDelta event, got {:?}", other),
        }
    }

    #[test]
    fn test_unrecognized_event_type_deserializes_to_unknown() {
        let event: StreamingEvent = serde_json::from_str(r#"{"type": "brand_new_event"}"#).unwrap();
        assert!(matches!(event, StreamingEvent::Unknown));

        let ping: StreamingEvent = serde_json::from_str(r#"{"type": "ping"}"#).unwrap();
        assert!(matches!(ping, StreamingEvent::Ping));

        let stop: StreamingEvent = serde_json::from_str(r#"{"type": "message_stop"}"#).unwrap();
        assert!(matches!(stop, StreamingEvent::MessageStop));
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected Reasoning choice"),
        }

        // Verify thinking state was updated
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // SignatureDelta should not yield anything (returns None)
        assert!(result.is_none());

        // But signature should be captured in thinking state
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_content_block_stop_emits_accumulated_reasoning_with_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = None;

        for chunk in ["First, ", "then."] {
            let event = StreamingEvent::ContentBlockDelta {
                index: 0,
                delta: ContentDelta::ThinkingDelta {
                    thinking: chunk.to_string(),
                },
            };
            assert!(handle_event(&event, &mut tool_call_state, &mut thinking_state).is_some());
        }

        let signature_event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "sig".to_string(),
            },
        };
        assert!(handle_event(&signature_event, &mut tool_call_state, &mut thinking_state).is_none());

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        match handle_event(&stop_event, &mut tool_call_state, &mut thinking_state)
            .unwrap()
            .unwrap()
        {
            RawStreamingChoice::Reasoning {
                id,
                reasoning,
                signature,
            } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "First, then.");
                assert_eq!(signature.as_deref(), Some("sig"));
            }
            other => panic!("Expected Reasoning, got {:?}", other),
        }

        // Thinking state is consumed by the stop event.
        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_text_delta_suppressed_during_tool_call() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "ignored".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(result.is_none());
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas should still be processed even if a tool call is in progress
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected Reasoning choice"),
        }

        // Tool call state should remain unchanged
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Should emit a ToolCallDelta
        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // Verify the input_json was accumulated
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        // First delta
        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        // Second delta
        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        // Third delta
        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        // Verify accumulated JSON
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Final ContentBlockStop should emit complete tool call
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall {
                id,
                name,
                arguments,
                ..
            } => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // Tool call state should be taken
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_without_input_deltas_yields_empty_object() {
        let mut tool_call_state = Some(ToolCallState {
            name: "no_args".to_string(),
            id: "tool_456".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        match handle_event(&stop_event, &mut tool_call_state, &mut thinking_state)
            .unwrap()
            .unwrap()
        {
            RawStreamingChoice::ToolCall { id, arguments, .. } => {
                assert_eq!(id, "tool_456");
                assert!(arguments.as_object().is_some_and(|map| map.is_empty()));
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_with_malformed_input_returns_error() {
        let mut tool_call_state = Some(ToolCallState {
            name: "broken".to_string(),
            id: "tool_789".to_string(),
            input_json: "{\"arg\":".to_string(),
        });
        let mut thinking_state = None;

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);

        assert!(matches!(result, Some(Err(_))));
        // The pending call is consumed even when its arguments fail to parse.
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_partial_usage_token_usage_defaults_missing_input_to_zero() {
        let usage = PartialUsage {
            output_tokens: 7,
            input_tokens: None,
        }
        .token_usage()
        .unwrap();

        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 7);
        assert_eq!(usage.total_tokens, 7);
    }

    #[test]
    fn test_streaming_response_token_usage_sums_input_and_output() {
        let response = StreamingCompletionResponse {
            usage: PartialUsage {
                output_tokens: 5,
                input_tokens: Some(3),
            },
        };
        let usage = response.token_usage().unwrap();

        assert_eq!(usage.input_tokens, 3);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.total_tokens, 8);
    }
}