rig/providers/anthropic/
streaming.rs

1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use crate::OneOrMany;
10use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
17#[derive(Debug, Deserialize)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum StreamingEvent {
20    MessageStart {
21        message: MessageStart,
22    },
23    ContentBlockStart {
24        index: usize,
25        content_block: Content,
26    },
27    ContentBlockDelta {
28        index: usize,
29        delta: ContentDelta,
30    },
31    ContentBlockStop {
32        index: usize,
33    },
34    MessageDelta {
35        delta: MessageDelta,
36        usage: PartialUsage,
37    },
38    MessageStop,
39    Ping,
40    #[serde(other)]
41    Unknown,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct MessageStart {
46    pub id: String,
47    pub role: String,
48    pub content: Vec<Content>,
49    pub model: String,
50    pub stop_reason: Option<String>,
51    pub stop_sequence: Option<String>,
52    pub usage: Usage,
53}
54
55#[derive(Debug, Deserialize)]
56#[serde(tag = "type", rename_all = "snake_case")]
57pub enum ContentDelta {
58    TextDelta { text: String },
59    InputJsonDelta { partial_json: String },
60    ThinkingDelta { thinking: String },
61    SignatureDelta { signature: String },
62}
63
64#[derive(Debug, Deserialize)]
65pub struct MessageDelta {
66    pub stop_reason: Option<String>,
67    pub stop_sequence: Option<String>,
68}
69
70#[derive(Debug, Deserialize, Clone, Serialize)]
71pub struct PartialUsage {
72    pub output_tokens: usize,
73    #[serde(default)]
74    pub input_tokens: Option<usize>,
75}
76
77impl GetTokenUsage for PartialUsage {
78    fn token_usage(&self) -> Option<crate::completion::Usage> {
79        let mut usage = crate::completion::Usage::new();
80
81        usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82        usage.output_tokens = self.output_tokens as u64;
83        usage.total_tokens = usage.input_tokens + usage.output_tokens;
84        Some(usage)
85    }
86}
87
/// Accumulator for a tool call whose arguments are still being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the opening `tool_use` content block.
    name: String,
    /// Tool call id taken from the opening `tool_use` content block.
    id: String,
    /// Concatenation of every `input_json_delta` fragment received so far.
    input_json: String,
}
94
/// Accumulator for a thinking block that is still being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of every `thinking_delta` fragment received so far.
    thinking: String,
    /// Concatenation of every `signature_delta` fragment received so far.
    signature: String,
}
100
101#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct StreamingCompletionResponse {
103    pub usage: PartialUsage,
104}
105
106impl GetTokenUsage for StreamingCompletionResponse {
107    fn token_usage(&self) -> Option<crate::completion::Usage> {
108        let mut usage = crate::completion::Usage::new();
109        usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
110        usage.output_tokens = self.usage.output_tokens as u64;
111        usage.total_tokens =
112            self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
113
114        Some(usage)
115    }
116}
117
118impl<T> CompletionModel<T>
119where
120    T: HttpClientExt + Clone + Default + 'static,
121{
122    pub(crate) async fn stream(
123        &self,
124        completion_request: CompletionRequest,
125    ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
126    {
127        let span = if tracing::Span::current().is_disabled() {
128            info_span!(
129                target: "rig::completions",
130                "chat_streaming",
131                gen_ai.operation.name = "chat_streaming",
132                gen_ai.provider.name = "anthropic",
133                gen_ai.request.model = self.model,
134                gen_ai.system_instructions = &completion_request.preamble,
135                gen_ai.response.id = tracing::field::Empty,
136                gen_ai.response.model = self.model,
137                gen_ai.usage.output_tokens = tracing::field::Empty,
138                gen_ai.usage.input_tokens = tracing::field::Empty,
139                gen_ai.input.messages = tracing::field::Empty,
140                gen_ai.output.messages = tracing::field::Empty,
141            )
142        } else {
143            tracing::Span::current()
144        };
145        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
146            tokens
147        } else if let Some(tokens) = self.default_max_tokens {
148            tokens
149        } else {
150            return Err(CompletionError::RequestError(
151                "`max_tokens` must be set for Anthropic".into(),
152            ));
153        };
154
155        let mut full_history = vec![];
156        if let Some(docs) = completion_request.normalized_documents() {
157            full_history.push(docs);
158        }
159        full_history.extend(completion_request.chat_history);
160        span.record_model_input(&full_history);
161
162        let full_history = full_history
163            .into_iter()
164            .map(Message::try_from)
165            .collect::<Result<Vec<Message>, _>>()?;
166
167        let mut body = json!({
168            "model": self.model,
169            "messages": full_history,
170            "max_tokens": max_tokens,
171            "system": completion_request.preamble.unwrap_or("".to_string()),
172            "stream": true,
173        });
174
175        if let Some(temperature) = completion_request.temperature {
176            merge_inplace(&mut body, json!({ "temperature": temperature }));
177        }
178
179        if !completion_request.tools.is_empty() {
180            merge_inplace(
181                &mut body,
182                json!({
183                    "tools": completion_request
184                        .tools
185                        .into_iter()
186                        .map(|tool| ToolDefinition {
187                            name: tool.name,
188                            description: Some(tool.description),
189                            input_schema: tool.parameters,
190                        })
191                        .collect::<Vec<_>>(),
192                    "tool_choice": ToolChoice::Auto,
193                }),
194            );
195        }
196
197        if let Some(ref params) = completion_request.additional_params {
198            merge_inplace(&mut body, params.clone())
199        }
200
201        let body: Vec<u8> = serde_json::to_vec(&body)?;
202
203        let req = self
204            .client
205            .post("/v1/messages")
206            .header("Content-Type", "application/json")
207            .body(body)
208            .map_err(http_client::Error::Protocol)?;
209
210        let stream = GenericEventSource::new(self.client.http_client.clone(), req);
211
212        // Use our SSE decoder to directly handle Server-Sent Events format
213        let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
214            let mut current_tool_call: Option<ToolCallState> = None;
215            let mut current_thinking: Option<ThinkingState> = None;
216            let mut sse_stream = Box::pin(stream);
217            let mut input_tokens = 0;
218
219            let mut text_content = String::new();
220
221            while let Some(sse_result) = sse_stream.next().await {
222                match sse_result {
223                    Ok(Event::Open) => {}
224                    Ok(Event::Message(sse)) => {
225                        // Parse the SSE data as a StreamingEvent
226                        match serde_json::from_str::<StreamingEvent>(&sse.data) {
227                            Ok(event) => {
228                                match &event {
229                                    StreamingEvent::MessageStart { message } => {
230                                        input_tokens = message.usage.input_tokens;
231
232                                        let span = tracing::Span::current();
233                                        span.record("gen_ai.response.id", &message.id);
234                                        span.record("gen_ai.response.model_name", &message.model);
235                                    },
236                                    StreamingEvent::MessageDelta { delta, usage } => {
237                                        if delta.stop_reason.is_some() {
238                                            let usage = PartialUsage {
239                                                 output_tokens: usage.output_tokens,
240                                                 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
241                                            };
242
243                                            let span = tracing::Span::current();
244                                            span.record_token_usage(&usage);
245                                            span.record_model_output(&Message {
246                                                role: super::completion::Role::Assistant,
247                                                content: OneOrMany::one(Content::Text { text: text_content.clone() })}
248                                            );
249
250                                            yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
251                                                usage
252                                            }))
253                                        }
254                                    }
255                                    _ => {}
256                                }
257
258                                if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
259                                    if let Ok(RawStreamingChoice::Message(ref text)) = result {
260                                        text_content += text;
261                                    }
262                                    yield result;
263                                }
264                            },
265                            Err(e) => {
266                                if !sse.data.trim().is_empty() {
267                                    yield Err(CompletionError::ResponseError(
268                                        format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
269                                    ));
270                                }
271                            }
272                        }
273                    },
274                    Err(e) => {
275                        yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
276                        break;
277                    }
278                }
279            }
280        }.instrument(span));
281
282        Ok(streaming::StreamingCompletionResponse::stream(stream))
283    }
284}
285
286fn handle_event(
287    event: &StreamingEvent,
288    current_tool_call: &mut Option<ToolCallState>,
289    current_thinking: &mut Option<ThinkingState>,
290) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
291    match event {
292        StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
293            ContentDelta::TextDelta { text } => {
294                if current_tool_call.is_none() {
295                    return Some(Ok(RawStreamingChoice::Message(text.clone())));
296                }
297                None
298            }
299            ContentDelta::InputJsonDelta { partial_json } => {
300                if let Some(tool_call) = current_tool_call {
301                    tool_call.input_json.push_str(partial_json);
302                    // Emit the delta so UI can show progress
303                    return Some(Ok(RawStreamingChoice::ToolCallDelta {
304                        id: tool_call.id.clone(),
305                        delta: partial_json.clone(),
306                    }));
307                }
308                None
309            }
310            ContentDelta::ThinkingDelta { thinking } => {
311                if current_thinking.is_none() {
312                    *current_thinking = Some(ThinkingState::default());
313                }
314
315                if let Some(state) = current_thinking {
316                    state.thinking.push_str(thinking);
317                }
318
319                Some(Ok(RawStreamingChoice::Reasoning {
320                    id: None,
321                    reasoning: thinking.clone(),
322                    signature: None,
323                }))
324            }
325            ContentDelta::SignatureDelta { signature } => {
326                if current_thinking.is_none() {
327                    *current_thinking = Some(ThinkingState::default());
328                }
329
330                if let Some(state) = current_thinking {
331                    state.signature.push_str(signature);
332                }
333
334                // Don't yield signature chunks, they will be included in the final Reasoning
335                None
336            }
337        },
338        StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
339            Content::ToolUse { id, name, .. } => {
340                *current_tool_call = Some(ToolCallState {
341                    name: name.clone(),
342                    id: id.clone(),
343                    input_json: String::new(),
344                });
345                None
346            }
347            Content::Thinking { .. } => {
348                *current_thinking = Some(ThinkingState::default());
349                None
350            }
351            // Handle other content types - they don't need special handling
352            _ => None,
353        },
354        StreamingEvent::ContentBlockStop { .. } => {
355            if let Some(thinking_state) = Option::take(current_thinking)
356                && !thinking_state.thinking.is_empty()
357            {
358                let signature = if thinking_state.signature.is_empty() {
359                    None
360                } else {
361                    Some(thinking_state.signature)
362                };
363
364                return Some(Ok(RawStreamingChoice::Reasoning {
365                    id: None,
366                    reasoning: thinking_state.thinking,
367                    signature,
368                }));
369            }
370
371            if let Some(tool_call) = Option::take(current_tool_call) {
372                let json_str = if tool_call.input_json.is_empty() {
373                    "{}"
374                } else {
375                    &tool_call.input_json
376                };
377                match serde_json::from_str(json_str) {
378                    Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
379                        name: tool_call.name,
380                        id: tool_call.id,
381                        arguments: json_value,
382                        call_id: None,
383                    })),
384                    Err(e) => Some(Err(CompletionError::from(e))),
385                }
386            } else {
387                None
388            }
389        }
390        // Ignore other event types or handle as needed
391        StreamingEvent::MessageStart { .. }
392        | StreamingEvent::MessageDelta { .. }
393        | StreamingEvent::MessageStop
394        | StreamingEvent::Ping
395        | StreamingEvent::Unknown => None,
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// Builds an `input_json_delta` event carrying `fragment`.
    fn json_fragment(fragment: &str) -> StreamingEvent {
        delta_event(ContentDelta::InputJsonDelta {
            partial_json: fragment.to_string(),
        })
    }

    /// An in-progress tool call with no arguments accumulated yet.
    fn pending_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let raw = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(raw).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let raw = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(raw).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let raw = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(raw).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let raw = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(raw).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Reasoning { id, reasoning, .. } = outcome.unwrap().unwrap() else {
            panic!("Expected Reasoning choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The fragment must also have been accumulated in the thinking state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are never yielded on their own...
        assert!(outcome.is_none());

        // ...but they are captured in the thinking state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Message(text) = outcome.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        // Thinking deltas are processed even while a tool call is in progress.
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Reasoning { reasoning, .. } = outcome.unwrap().unwrap() else {
            panic!("Expected Reasoning choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The tool call must still be pending.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = json_fragment("{\"arg\":\"value");

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // The fragment is echoed as a ToolCallDelta.
        assert!(outcome.is_some());
        let choice = outcome.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment is also accumulated on the pending tool call.
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        // Feed the arguments in three fragments; each one must yield a delta.
        let fragments = ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"];
        for fragment in fragments {
            let event = json_fragment(fragment);
            let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(outcome.is_some());
        }

        // The fragments must have been concatenated in order.
        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block emits the complete tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall {
                id,
                name,
                arguments,
                ..
            } => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The pending tool call is consumed by the stop event.
        assert!(tool_call_state.is_none());
    }
}
697        assert!(tool_call_state.is_none());
698    }
699}