1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10 apply_cache_control,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::streaming::{
17 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
18};
19use crate::telemetry::SpanCombinator;
20
21#[derive(Debug, Deserialize)]
22#[serde(tag = "type", rename_all = "snake_case")]
23pub enum StreamingEvent {
24 MessageStart {
25 message: MessageStart,
26 },
27 ContentBlockStart {
28 index: usize,
29 content_block: Content,
30 },
31 ContentBlockDelta {
32 index: usize,
33 delta: ContentDelta,
34 },
35 ContentBlockStop {
36 index: usize,
37 },
38 MessageDelta {
39 delta: MessageDelta,
40 usage: PartialUsage,
41 },
42 MessageStop,
43 Ping,
44 #[serde(other)]
45 Unknown,
46}
47
48#[derive(Debug, Deserialize)]
49pub struct MessageStart {
50 pub id: String,
51 pub role: String,
52 pub content: Vec<Content>,
53 pub model: String,
54 pub stop_reason: Option<String>,
55 pub stop_sequence: Option<String>,
56 pub usage: Usage,
57}
58
59#[derive(Debug, Deserialize)]
60#[serde(tag = "type", rename_all = "snake_case")]
61pub enum ContentDelta {
62 TextDelta { text: String },
63 InputJsonDelta { partial_json: String },
64 ThinkingDelta { thinking: String },
65 SignatureDelta { signature: String },
66}
67
68#[derive(Debug, Deserialize)]
69pub struct MessageDelta {
70 pub stop_reason: Option<String>,
71 pub stop_sequence: Option<String>,
72}
73
74#[derive(Debug, Deserialize, Clone, Serialize, Default)]
75pub struct PartialUsage {
76 pub output_tokens: usize,
77 #[serde(default)]
78 pub input_tokens: Option<usize>,
79}
80
81impl GetTokenUsage for PartialUsage {
82 fn token_usage(&self) -> Option<crate::completion::Usage> {
83 let mut usage = crate::completion::Usage::new();
84
85 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
86 usage.output_tokens = self.output_tokens as u64;
87 usage.total_tokens = usage.input_tokens + usage.output_tokens;
88 Some(usage)
89 }
90}
91
/// Accumulator for a tool call whose content block is still being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Identifier of the tool-use block.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Concatenation of the JSON input fragments received so far.
    input_json: String,
}
98
/// Accumulator for a thinking block whose content is still being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of the thinking-text fragments received so far.
    thinking: String,
    /// Concatenation of the signature fragments received so far.
    signature: String,
}
104
105#[derive(Clone, Debug, Deserialize, Serialize)]
106pub struct StreamingCompletionResponse {
107 pub usage: PartialUsage,
108}
109
110impl GetTokenUsage for StreamingCompletionResponse {
111 fn token_usage(&self) -> Option<crate::completion::Usage> {
112 let mut usage = crate::completion::Usage::new();
113 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
114 usage.output_tokens = self.usage.output_tokens as u64;
115 usage.total_tokens =
116 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
117
118 Some(usage)
119 }
120}
121
122impl<T> CompletionModel<T>
123where
124 T: HttpClientExt + Clone + Default + 'static,
125{
126 pub(crate) async fn stream(
127 &self,
128 completion_request: CompletionRequest,
129 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
130 {
131 let span = if tracing::Span::current().is_disabled() {
132 info_span!(
133 target: "rig::completions",
134 "chat_streaming",
135 gen_ai.operation.name = "chat_streaming",
136 gen_ai.provider.name = "anthropic",
137 gen_ai.request.model = self.model,
138 gen_ai.system_instructions = &completion_request.preamble,
139 gen_ai.response.id = tracing::field::Empty,
140 gen_ai.response.model = self.model,
141 gen_ai.usage.output_tokens = tracing::field::Empty,
142 gen_ai.usage.input_tokens = tracing::field::Empty,
143 gen_ai.input.messages = tracing::field::Empty,
144 gen_ai.output.messages = tracing::field::Empty,
145 )
146 } else {
147 tracing::Span::current()
148 };
149 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
150 tokens
151 } else if let Some(tokens) = self.default_max_tokens {
152 tokens
153 } else {
154 return Err(CompletionError::RequestError(
155 "`max_tokens` must be set for Anthropic".into(),
156 ));
157 };
158
159 let mut full_history = vec![];
160 if let Some(docs) = completion_request.normalized_documents() {
161 full_history.push(docs);
162 }
163 full_history.extend(completion_request.chat_history);
164
165 let mut messages = full_history
166 .into_iter()
167 .map(Message::try_from)
168 .collect::<Result<Vec<Message>, _>>()?;
169
170 let mut system: Vec<SystemContent> =
172 if let Some(preamble) = completion_request.preamble.as_ref() {
173 if preamble.is_empty() {
174 vec![]
175 } else {
176 vec![SystemContent::Text {
177 text: preamble.clone(),
178 cache_control: None,
179 }]
180 }
181 } else {
182 vec![]
183 };
184
185 if self.prompt_caching {
187 apply_cache_control(&mut system, &mut messages);
188 }
189
190 let mut body = json!({
191 "model": self.model,
192 "messages": messages,
193 "max_tokens": max_tokens,
194 "stream": true,
195 });
196
197 if !system.is_empty() {
199 merge_inplace(&mut body, json!({ "system": system }));
200 }
201
202 if let Some(temperature) = completion_request.temperature {
203 merge_inplace(&mut body, json!({ "temperature": temperature }));
204 }
205
206 if !completion_request.tools.is_empty() {
207 merge_inplace(
208 &mut body,
209 json!({
210 "tools": completion_request
211 .tools
212 .into_iter()
213 .map(|tool| ToolDefinition {
214 name: tool.name,
215 description: Some(tool.description),
216 input_schema: tool.parameters,
217 })
218 .collect::<Vec<_>>(),
219 "tool_choice": ToolChoice::Auto,
220 }),
221 );
222 }
223
224 if let Some(ref params) = completion_request.additional_params {
225 merge_inplace(&mut body, params.clone())
226 }
227
228 if enabled!(Level::TRACE) {
229 tracing::trace!(
230 target: "rig::completions",
231 "Anthropic completion request: {}",
232 serde_json::to_string_pretty(&body)?
233 );
234 }
235
236 let body: Vec<u8> = serde_json::to_vec(&body)?;
237
238 let req = self
239 .client
240 .post("/v1/messages")?
241 .body(body)
242 .map_err(http_client::Error::Protocol)?;
243
244 let stream = GenericEventSource::new(self.client.clone(), req);
245
246 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
248 let mut current_tool_call: Option<ToolCallState> = None;
249 let mut current_thinking: Option<ThinkingState> = None;
250 let mut sse_stream = Box::pin(stream);
251 let mut input_tokens = 0;
252 let mut final_usage = None;
253
254 let mut text_content = String::new();
255
256 while let Some(sse_result) = sse_stream.next().await {
257 match sse_result {
258 Ok(Event::Open) => {}
259 Ok(Event::Message(sse)) => {
260 match serde_json::from_str::<StreamingEvent>(&sse.data) {
262 Ok(event) => {
263 match &event {
264 StreamingEvent::MessageStart { message } => {
265 input_tokens = message.usage.input_tokens;
266
267 let span = tracing::Span::current();
268 span.record("gen_ai.response.id", &message.id);
269 span.record("gen_ai.response.model_name", &message.model);
270 },
271 StreamingEvent::MessageDelta { delta, usage } => {
272 if delta.stop_reason.is_some() {
273 let usage = PartialUsage {
274 output_tokens: usage.output_tokens,
275 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
276 };
277
278 let span = tracing::Span::current();
279 span.record_token_usage(&usage);
280 final_usage = Some(usage);
281 break;
282 }
283 }
284 _ => {}
285 }
286
287 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
288 if let Ok(RawStreamingChoice::Message(ref text)) = result {
289 text_content += text;
290 }
291 yield result;
292 }
293 },
294 Err(e) => {
295 if !sse.data.trim().is_empty() {
296 yield Err(CompletionError::ResponseError(
297 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
298 ));
299 }
300 }
301 }
302 },
303 Err(e) => {
304 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
305 break;
306 }
307 }
308 }
309
310 sse_stream.close();
312
313 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
314 usage: final_usage.unwrap_or_default()
315 }))
316 }.instrument(span));
317
318 Ok(streaming::StreamingCompletionResponse::stream(stream))
319 }
320}
321
322fn handle_event(
323 event: &StreamingEvent,
324 current_tool_call: &mut Option<ToolCallState>,
325 current_thinking: &mut Option<ThinkingState>,
326) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
327 match event {
328 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
329 ContentDelta::TextDelta { text } => {
330 if current_tool_call.is_none() {
331 return Some(Ok(RawStreamingChoice::Message(text.clone())));
332 }
333 None
334 }
335 ContentDelta::InputJsonDelta { partial_json } => {
336 if let Some(tool_call) = current_tool_call {
337 tool_call.input_json.push_str(partial_json);
338 return Some(Ok(RawStreamingChoice::ToolCallDelta {
340 id: tool_call.id.clone(),
341 content: ToolCallDeltaContent::Delta(partial_json.clone()),
342 }));
343 }
344 None
345 }
346 ContentDelta::ThinkingDelta { thinking } => {
347 if current_thinking.is_none() {
348 *current_thinking = Some(ThinkingState::default());
349 }
350
351 if let Some(state) = current_thinking {
352 state.thinking.push_str(thinking);
353 }
354
355 Some(Ok(RawStreamingChoice::ReasoningDelta {
356 id: None,
357 reasoning: thinking.clone(),
358 }))
359 }
360 ContentDelta::SignatureDelta { signature } => {
361 if current_thinking.is_none() {
362 *current_thinking = Some(ThinkingState::default());
363 }
364
365 if let Some(state) = current_thinking {
366 state.signature.push_str(signature);
367 }
368
369 None
371 }
372 },
373 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
374 Content::ToolUse { id, name, .. } => {
375 *current_tool_call = Some(ToolCallState {
376 name: name.clone(),
377 id: id.clone(),
378 input_json: String::new(),
379 });
380 Some(Ok(RawStreamingChoice::ToolCallDelta {
381 id: id.clone(),
382 content: ToolCallDeltaContent::Name(name.clone()),
383 }))
384 }
385 Content::Thinking { .. } => {
386 *current_thinking = Some(ThinkingState::default());
387 None
388 }
389 _ => None,
391 },
392 StreamingEvent::ContentBlockStop { .. } => {
393 if let Some(thinking_state) = Option::take(current_thinking)
394 && !thinking_state.thinking.is_empty()
395 {
396 let signature = if thinking_state.signature.is_empty() {
397 None
398 } else {
399 Some(thinking_state.signature)
400 };
401
402 return Some(Ok(RawStreamingChoice::Reasoning {
403 id: None,
404 reasoning: thinking_state.thinking,
405 signature,
406 }));
407 }
408
409 if let Some(tool_call) = Option::take(current_tool_call) {
410 let json_str = if tool_call.input_json.is_empty() {
411 "{}"
412 } else {
413 &tool_call.input_json
414 };
415 match serde_json::from_str(json_str) {
416 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
417 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value),
418 ))),
419 Err(e) => Some(Err(CompletionError::from(e))),
420 }
421 } else {
422 None
423 }
424 }
425 StreamingEvent::MessageStart { .. }
427 | StreamingEvent::MessageDelta { .. }
428 | StreamingEvent::MessageStop
429 | StreamingEvent::Ping
430 | StreamingEvent::Unknown => None,
431 }
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block index 0.
    fn block_delta(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// An in-flight tool call that has not received any input yet.
    fn pending_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = event else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = event else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = block_delta(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::ReasoningDelta { id, reasoning, .. } = result.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The delta must also have been accumulated in the thinking state.
        assert_eq!(
            thinking_state.map(|state| state.thinking).as_deref(),
            Some("Analyzing the request...")
        );
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = block_delta(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are buffered, never emitted on their own.
        assert!(result.is_none());
        assert_eq!(
            thinking_state.map(|state| state.signature).as_deref(),
            Some("test_signature")
        );
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = block_delta(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Message(text) = result.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = block_delta(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::ReasoningDelta { reasoning, .. } = result.unwrap().unwrap() else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The pending tool call must survive the thinking delta.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = block_delta(ContentDelta::InputJsonDelta {
            partial_json: "{\"arg\":\"value".to_string(),
        });

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, content } => {
                assert_eq!(id, "tool_123");
                let ToolCallDeltaContent::Delta(delta) = content else {
                    panic!("Expected Delta content");
                };
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment must also have been appended to the accumulator.
        assert_eq!(
            tool_call_state.map(|state| state.input_json).as_deref(),
            Some("{\"arg\":\"value")
        );
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        // Feed the JSON input in three fragments; each must yield a delta.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = block_delta(ContentDelta::InputJsonDelta {
                partial_json: fragment.to_string(),
            });
            let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(result.is_some());
        }

        assert_eq!(
            tool_call_state
                .as_ref()
                .map(|state| state.input_json.as_str()),
            Some("{\"location\":\"Paris\",\"temp\":\"20C\"}")
        );

        // Stopping the block turns the accumulated input into a full tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").and_then(|value| value.as_str()),
                    Some("Paris")
                );
                assert_eq!(
                    arguments.get("temp").and_then(|value| value.as_str()),
                    Some("20C")
                );
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The accumulator is consumed once the tool call has been emitted.
        assert!(tool_call_state.is_none());
    }
}