1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10 apply_cache_control,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::message::ReasoningContent;
17use crate::streaming::{
18 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
19};
20use crate::telemetry::SpanCombinator;
21
22#[derive(Debug, Deserialize)]
23#[serde(tag = "type", rename_all = "snake_case")]
24pub enum StreamingEvent {
25 MessageStart {
26 message: MessageStart,
27 },
28 ContentBlockStart {
29 index: usize,
30 content_block: Content,
31 },
32 ContentBlockDelta {
33 index: usize,
34 delta: ContentDelta,
35 },
36 ContentBlockStop {
37 index: usize,
38 },
39 MessageDelta {
40 delta: MessageDelta,
41 usage: PartialUsage,
42 },
43 MessageStop,
44 Ping,
45 #[serde(other)]
46 Unknown,
47}
48
49#[derive(Debug, Deserialize)]
50pub struct MessageStart {
51 pub id: String,
52 pub role: String,
53 pub content: Vec<Content>,
54 pub model: String,
55 pub stop_reason: Option<String>,
56 pub stop_sequence: Option<String>,
57 pub usage: Usage,
58}
59
60#[derive(Debug, Deserialize)]
61#[serde(tag = "type", rename_all = "snake_case")]
62pub enum ContentDelta {
63 TextDelta { text: String },
64 InputJsonDelta { partial_json: String },
65 ThinkingDelta { thinking: String },
66 SignatureDelta { signature: String },
67}
68
69#[derive(Debug, Deserialize)]
70pub struct MessageDelta {
71 pub stop_reason: Option<String>,
72 pub stop_sequence: Option<String>,
73}
74
75#[derive(Debug, Deserialize, Clone, Serialize, Default)]
76pub struct PartialUsage {
77 pub output_tokens: usize,
78 #[serde(default)]
79 pub input_tokens: Option<usize>,
80}
81
82impl GetTokenUsage for PartialUsage {
83 fn token_usage(&self) -> Option<crate::completion::Usage> {
84 let mut usage = crate::completion::Usage::new();
85
86 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
87 usage.output_tokens = self.output_tokens as u64;
88 usage.total_tokens = usage.input_tokens + usage.output_tokens;
89 Some(usage)
90 }
91}
92
/// Bookkeeping for the tool call whose content block is currently open.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the block's start event.
    name: String,
    /// Tool-use id taken from the block's start event.
    id: String,
    /// Locally generated id that accompanies every chunk emitted for this call.
    internal_call_id: String,
    /// JSON input accumulated from `input_json_delta` fragments.
    input_json: String,
}
100
/// Bookkeeping for the thinking block that is currently open.
#[derive(Default)]
struct ThinkingState {
    /// Thinking text accumulated from `thinking_delta` fragments.
    thinking: String,
    /// Signature accumulated from `signature_delta` fragments.
    signature: String,
}
106
107#[derive(Clone, Debug, Deserialize, Serialize)]
108pub struct StreamingCompletionResponse {
109 pub usage: PartialUsage,
110}
111
112impl GetTokenUsage for StreamingCompletionResponse {
113 fn token_usage(&self) -> Option<crate::completion::Usage> {
114 let mut usage = crate::completion::Usage::new();
115 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
116 usage.output_tokens = self.usage.output_tokens as u64;
117 usage.total_tokens =
118 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
119
120 Some(usage)
121 }
122}
123
124impl<T> CompletionModel<T>
125where
126 T: HttpClientExt + Clone + Default + 'static,
127{
128 pub(crate) async fn stream(
129 &self,
130 completion_request: CompletionRequest,
131 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
132 {
133 let request_model = completion_request
134 .model
135 .clone()
136 .unwrap_or_else(|| self.model.clone());
137 let span = if tracing::Span::current().is_disabled() {
138 info_span!(
139 target: "rig::completions",
140 "chat_streaming",
141 gen_ai.operation.name = "chat_streaming",
142 gen_ai.provider.name = "anthropic",
143 gen_ai.request.model = &request_model,
144 gen_ai.system_instructions = &completion_request.preamble,
145 gen_ai.response.id = tracing::field::Empty,
146 gen_ai.response.model = &request_model,
147 gen_ai.usage.output_tokens = tracing::field::Empty,
148 gen_ai.usage.input_tokens = tracing::field::Empty,
149 gen_ai.input.messages = tracing::field::Empty,
150 gen_ai.output.messages = tracing::field::Empty,
151 )
152 } else {
153 tracing::Span::current()
154 };
155 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
156 tokens
157 } else if let Some(tokens) = self.default_max_tokens {
158 tokens
159 } else {
160 return Err(CompletionError::RequestError(
161 "`max_tokens` must be set for Anthropic".into(),
162 ));
163 };
164
165 let mut full_history = vec![];
166 if let Some(docs) = completion_request.normalized_documents() {
167 full_history.push(docs);
168 }
169 full_history.extend(completion_request.chat_history);
170
171 let mut messages = full_history
172 .into_iter()
173 .map(Message::try_from)
174 .collect::<Result<Vec<Message>, _>>()?;
175
176 let mut system: Vec<SystemContent> =
178 if let Some(preamble) = completion_request.preamble.as_ref() {
179 if preamble.is_empty() {
180 vec![]
181 } else {
182 vec![SystemContent::Text {
183 text: preamble.clone(),
184 cache_control: None,
185 }]
186 }
187 } else {
188 vec![]
189 };
190
191 if self.prompt_caching {
193 apply_cache_control(&mut system, &mut messages);
194 }
195
196 let mut body = json!({
197 "model": request_model,
198 "messages": messages,
199 "max_tokens": max_tokens,
200 "stream": true,
201 });
202
203 if !system.is_empty() {
205 merge_inplace(&mut body, json!({ "system": system }));
206 }
207
208 if let Some(temperature) = completion_request.temperature {
209 merge_inplace(&mut body, json!({ "temperature": temperature }));
210 }
211
212 if !completion_request.tools.is_empty() {
213 merge_inplace(
214 &mut body,
215 json!({
216 "tools": completion_request
217 .tools
218 .into_iter()
219 .map(|tool| ToolDefinition {
220 name: tool.name,
221 description: Some(tool.description),
222 input_schema: tool.parameters,
223 })
224 .collect::<Vec<_>>(),
225 "tool_choice": ToolChoice::Auto,
226 }),
227 );
228 }
229
230 if let Some(ref params) = completion_request.additional_params {
231 merge_inplace(&mut body, params.clone())
232 }
233
234 if enabled!(Level::TRACE) {
235 tracing::trace!(
236 target: "rig::completions",
237 "Anthropic completion request: {}",
238 serde_json::to_string_pretty(&body)?
239 );
240 }
241
242 let body: Vec<u8> = serde_json::to_vec(&body)?;
243
244 let req = self
245 .client
246 .post("/v1/messages")?
247 .body(body)
248 .map_err(http_client::Error::Protocol)?;
249
250 let stream = GenericEventSource::new(self.client.clone(), req);
251
252 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
254 let mut current_tool_call: Option<ToolCallState> = None;
255 let mut current_thinking: Option<ThinkingState> = None;
256 let mut sse_stream = Box::pin(stream);
257 let mut input_tokens = 0;
258 let mut final_usage = None;
259
260 let mut text_content = String::new();
261
262 while let Some(sse_result) = sse_stream.next().await {
263 match sse_result {
264 Ok(Event::Open) => {}
265 Ok(Event::Message(sse)) => {
266 match serde_json::from_str::<StreamingEvent>(&sse.data) {
268 Ok(event) => {
269 match &event {
270 StreamingEvent::MessageStart { message } => {
271 input_tokens = message.usage.input_tokens;
272
273 let span = tracing::Span::current();
274 span.record("gen_ai.response.id", &message.id);
275 span.record("gen_ai.response.model_name", &message.model);
276 },
277 StreamingEvent::MessageDelta { delta, usage } => {
278 if delta.stop_reason.is_some() {
279 let usage = PartialUsage {
280 output_tokens: usage.output_tokens,
281 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
282 };
283
284 let span = tracing::Span::current();
285 span.record_token_usage(&usage);
286 final_usage = Some(usage);
287 break;
288 }
289 }
290 _ => {}
291 }
292
293 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
294 if let Ok(RawStreamingChoice::Message(ref text)) = result {
295 text_content += text;
296 }
297 yield result;
298 }
299 },
300 Err(e) => {
301 if !sse.data.trim().is_empty() {
302 yield Err(CompletionError::ResponseError(
303 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
304 ));
305 }
306 }
307 }
308 },
309 Err(e) => {
310 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
311 break;
312 }
313 }
314 }
315
316 sse_stream.close();
318
319 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
320 usage: final_usage.unwrap_or_default()
321 }))
322 }.instrument(span));
323
324 Ok(streaming::StreamingCompletionResponse::stream(stream))
325 }
326}
327
328fn handle_event(
329 event: &StreamingEvent,
330 current_tool_call: &mut Option<ToolCallState>,
331 current_thinking: &mut Option<ThinkingState>,
332) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
333 match event {
334 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
335 ContentDelta::TextDelta { text } => {
336 if current_tool_call.is_none() {
337 return Some(Ok(RawStreamingChoice::Message(text.clone())));
338 }
339 None
340 }
341 ContentDelta::InputJsonDelta { partial_json } => {
342 if let Some(tool_call) = current_tool_call {
343 tool_call.input_json.push_str(partial_json);
344 return Some(Ok(RawStreamingChoice::ToolCallDelta {
346 id: tool_call.id.clone(),
347 internal_call_id: tool_call.internal_call_id.clone(),
348 content: ToolCallDeltaContent::Delta(partial_json.clone()),
349 }));
350 }
351 None
352 }
353 ContentDelta::ThinkingDelta { thinking } => {
354 current_thinking
355 .get_or_insert_with(ThinkingState::default)
356 .thinking
357 .push_str(thinking);
358
359 Some(Ok(RawStreamingChoice::ReasoningDelta {
360 id: None,
361 reasoning: thinking.clone(),
362 }))
363 }
364 ContentDelta::SignatureDelta { signature } => {
365 current_thinking
366 .get_or_insert_with(ThinkingState::default)
367 .signature
368 .push_str(signature);
369
370 None
372 }
373 },
374 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
375 Content::ToolUse { id, name, .. } => {
376 let internal_call_id = nanoid::nanoid!();
377 *current_tool_call = Some(ToolCallState {
378 name: name.clone(),
379 id: id.clone(),
380 internal_call_id: internal_call_id.clone(),
381 input_json: String::new(),
382 });
383 Some(Ok(RawStreamingChoice::ToolCallDelta {
384 id: id.clone(),
385 internal_call_id,
386 content: ToolCallDeltaContent::Name(name.clone()),
387 }))
388 }
389 Content::Thinking { .. } => {
390 *current_thinking = Some(ThinkingState::default());
391 None
392 }
393 Content::RedactedThinking { data } => Some(Ok(RawStreamingChoice::Reasoning {
394 id: None,
395 content: ReasoningContent::Redacted { data: data.clone() },
396 })),
397 _ => None,
399 },
400 StreamingEvent::ContentBlockStop { .. } => {
401 if let Some(thinking_state) = Option::take(current_thinking)
402 && !thinking_state.thinking.is_empty()
403 {
404 let signature = if thinking_state.signature.is_empty() {
405 None
406 } else {
407 Some(thinking_state.signature)
408 };
409
410 return Some(Ok(RawStreamingChoice::Reasoning {
411 id: None,
412 content: ReasoningContent::Text {
413 text: thinking_state.thinking,
414 signature,
415 },
416 }));
417 }
418
419 if let Some(tool_call) = Option::take(current_tool_call) {
420 let json_str = if tool_call.input_json.is_empty() {
421 "{}"
422 } else {
423 &tool_call.input_json
424 };
425 match serde_json::from_str(json_str) {
426 Ok(json_value) => {
427 let raw_tool_call =
428 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
429 .with_internal_call_id(tool_call.internal_call_id);
430 Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
431 }
432 Err(e) => Some(Err(CompletionError::from(e))),
433 }
434 } else {
435 None
436 }
437 }
438 StreamingEvent::MessageStart { .. }
440 | StreamingEvent::MessageDelta { .. }
441 | StreamingEvent::MessageStop
442 | StreamingEvent::Ping
443 | StreamingEvent::Unknown => None,
444 }
445}
446
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `ContentBlockDelta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// An open tool call that has not received any input yet.
    fn pending_tool_call() -> ToolCallState {
        ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        }
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let raw = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(raw).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let raw = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(raw).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let raw = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(raw).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let raw = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(raw).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::ReasoningDelta { id, reasoning, .. } = result.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The fragment must also have been accumulated in the state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are stored, never emitted.
        assert!(result.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_redacted_thinking_content_block_start_event() {
        let event = StreamingEvent::ContentBlockStart {
            index: 0,
            content_block: Content::RedactedThinking {
                data: "redacted_blob".to_string(),
            },
        };
        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Reasoning {
            content: ReasoningContent::Redacted { data },
            ..
        } = result.unwrap().unwrap()
        else {
            panic!("Expected Redacted reasoning chunk");
        };
        assert_eq!(data, "redacted_blob");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Message(text) = result.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = Some(pending_tool_call());
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::ReasoningDelta { reasoning, .. } = result.unwrap().unwrap() else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The open tool call must be left untouched.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = delta_event(ContentDelta::InputJsonDelta {
            partial_json: "{\"arg\":\"value".to_string(),
        });

        let mut tool_call_state = Some(pending_tool_call());
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                let ToolCallDeltaContent::Delta(fragment) = content else {
                    panic!("Expected Delta content");
                };
                assert_eq!(fragment, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment must also have been appended to the accumulated input.
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(pending_tool_call());
        let mut thinking_state = None;

        // Feed the JSON input in three fragments; each one must yield a chunk.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = delta_event(ContentDelta::InputJsonDelta {
                partial_json: fragment.to_string(),
            });
            let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(result.is_some());
        }

        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block turns the accumulated input into a complete tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The state is consumed once the call has been emitted.
        assert!(tool_call_state.is_none());
    }
}