1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10 apply_cache_control,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::streaming::{self, RawStreamingChoice, RawStreamingToolCall, StreamingResult};
17use crate::telemetry::SpanCombinator;
18
19#[derive(Debug, Deserialize)]
20#[serde(tag = "type", rename_all = "snake_case")]
21pub enum StreamingEvent {
22 MessageStart {
23 message: MessageStart,
24 },
25 ContentBlockStart {
26 index: usize,
27 content_block: Content,
28 },
29 ContentBlockDelta {
30 index: usize,
31 delta: ContentDelta,
32 },
33 ContentBlockStop {
34 index: usize,
35 },
36 MessageDelta {
37 delta: MessageDelta,
38 usage: PartialUsage,
39 },
40 MessageStop,
41 Ping,
42 #[serde(other)]
43 Unknown,
44}
45
46#[derive(Debug, Deserialize)]
47pub struct MessageStart {
48 pub id: String,
49 pub role: String,
50 pub content: Vec<Content>,
51 pub model: String,
52 pub stop_reason: Option<String>,
53 pub stop_sequence: Option<String>,
54 pub usage: Usage,
55}
56
57#[derive(Debug, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ContentDelta {
60 TextDelta { text: String },
61 InputJsonDelta { partial_json: String },
62 ThinkingDelta { thinking: String },
63 SignatureDelta { signature: String },
64}
65
66#[derive(Debug, Deserialize)]
67pub struct MessageDelta {
68 pub stop_reason: Option<String>,
69 pub stop_sequence: Option<String>,
70}
71
72#[derive(Debug, Deserialize, Clone, Serialize, Default)]
73pub struct PartialUsage {
74 pub output_tokens: usize,
75 #[serde(default)]
76 pub input_tokens: Option<usize>,
77}
78
79impl GetTokenUsage for PartialUsage {
80 fn token_usage(&self) -> Option<crate::completion::Usage> {
81 let mut usage = crate::completion::Usage::new();
82
83 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
84 usage.output_tokens = self.output_tokens as u64;
85 usage.total_tokens = usage.input_tokens + usage.output_tokens;
86 Some(usage)
87 }
88}
89
/// Accumulator for a tool call whose input is still being streamed.
///
/// `Debug` is derived so the state can be printed in diagnostics and
/// assertion messages.
#[derive(Debug, Default)]
struct ToolCallState {
    /// Name of the tool being called.
    name: String,
    /// Identifier of the tool call.
    id: String,
    /// JSON text received so far for the tool input.
    input_json: String,
}
96
/// Accumulator for a thinking block that is still being streamed.
///
/// `Debug` is derived so the state can be printed in diagnostics and
/// assertion messages.
#[derive(Debug, Default)]
struct ThinkingState {
    /// Thinking text received so far.
    thinking: String,
    /// Signature text received so far.
    signature: String,
}
102
103#[derive(Clone, Debug, Deserialize, Serialize)]
104pub struct StreamingCompletionResponse {
105 pub usage: PartialUsage,
106}
107
108impl GetTokenUsage for StreamingCompletionResponse {
109 fn token_usage(&self) -> Option<crate::completion::Usage> {
110 let mut usage = crate::completion::Usage::new();
111 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
112 usage.output_tokens = self.usage.output_tokens as u64;
113 usage.total_tokens =
114 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
115
116 Some(usage)
117 }
118}
119
120impl<T> CompletionModel<T>
121where
122 T: HttpClientExt + Clone + Default + 'static,
123{
124 pub(crate) async fn stream(
125 &self,
126 completion_request: CompletionRequest,
127 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
128 {
129 let span = if tracing::Span::current().is_disabled() {
130 info_span!(
131 target: "rig::completions",
132 "chat_streaming",
133 gen_ai.operation.name = "chat_streaming",
134 gen_ai.provider.name = "anthropic",
135 gen_ai.request.model = self.model,
136 gen_ai.system_instructions = &completion_request.preamble,
137 gen_ai.response.id = tracing::field::Empty,
138 gen_ai.response.model = self.model,
139 gen_ai.usage.output_tokens = tracing::field::Empty,
140 gen_ai.usage.input_tokens = tracing::field::Empty,
141 gen_ai.input.messages = tracing::field::Empty,
142 gen_ai.output.messages = tracing::field::Empty,
143 )
144 } else {
145 tracing::Span::current()
146 };
147 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
148 tokens
149 } else if let Some(tokens) = self.default_max_tokens {
150 tokens
151 } else {
152 return Err(CompletionError::RequestError(
153 "`max_tokens` must be set for Anthropic".into(),
154 ));
155 };
156
157 let mut full_history = vec![];
158 if let Some(docs) = completion_request.normalized_documents() {
159 full_history.push(docs);
160 }
161 full_history.extend(completion_request.chat_history);
162
163 let mut messages = full_history
164 .into_iter()
165 .map(Message::try_from)
166 .collect::<Result<Vec<Message>, _>>()?;
167
168 let mut system: Vec<SystemContent> =
170 if let Some(preamble) = completion_request.preamble.as_ref() {
171 if preamble.is_empty() {
172 vec![]
173 } else {
174 vec![SystemContent::Text {
175 text: preamble.clone(),
176 cache_control: None,
177 }]
178 }
179 } else {
180 vec![]
181 };
182
183 if self.prompt_caching {
185 apply_cache_control(&mut system, &mut messages);
186 }
187
188 let mut body = json!({
189 "model": self.model,
190 "messages": messages,
191 "max_tokens": max_tokens,
192 "stream": true,
193 });
194
195 if !system.is_empty() {
197 merge_inplace(&mut body, json!({ "system": system }));
198 }
199
200 if let Some(temperature) = completion_request.temperature {
201 merge_inplace(&mut body, json!({ "temperature": temperature }));
202 }
203
204 if !completion_request.tools.is_empty() {
205 merge_inplace(
206 &mut body,
207 json!({
208 "tools": completion_request
209 .tools
210 .into_iter()
211 .map(|tool| ToolDefinition {
212 name: tool.name,
213 description: Some(tool.description),
214 input_schema: tool.parameters,
215 })
216 .collect::<Vec<_>>(),
217 "tool_choice": ToolChoice::Auto,
218 }),
219 );
220 }
221
222 if let Some(ref params) = completion_request.additional_params {
223 merge_inplace(&mut body, params.clone())
224 }
225
226 if enabled!(Level::TRACE) {
227 tracing::trace!(
228 target: "rig::completions",
229 "Anthropic completion request: {}",
230 serde_json::to_string_pretty(&body)?
231 );
232 }
233
234 let body: Vec<u8> = serde_json::to_vec(&body)?;
235
236 let req = self
237 .client
238 .post("/v1/messages")?
239 .body(body)
240 .map_err(http_client::Error::Protocol)?;
241
242 let stream = GenericEventSource::new(self.client.clone(), req);
243
244 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
246 let mut current_tool_call: Option<ToolCallState> = None;
247 let mut current_thinking: Option<ThinkingState> = None;
248 let mut sse_stream = Box::pin(stream);
249 let mut input_tokens = 0;
250 let mut final_usage = None;
251
252 let mut text_content = String::new();
253
254 while let Some(sse_result) = sse_stream.next().await {
255 match sse_result {
256 Ok(Event::Open) => {}
257 Ok(Event::Message(sse)) => {
258 match serde_json::from_str::<StreamingEvent>(&sse.data) {
260 Ok(event) => {
261 match &event {
262 StreamingEvent::MessageStart { message } => {
263 input_tokens = message.usage.input_tokens;
264
265 let span = tracing::Span::current();
266 span.record("gen_ai.response.id", &message.id);
267 span.record("gen_ai.response.model_name", &message.model);
268 },
269 StreamingEvent::MessageDelta { delta, usage } => {
270 if delta.stop_reason.is_some() {
271 let usage = PartialUsage {
272 output_tokens: usage.output_tokens,
273 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
274 };
275
276 let span = tracing::Span::current();
277 span.record_token_usage(&usage);
278 final_usage = Some(usage);
279 break;
280 }
281 }
282 _ => {}
283 }
284
285 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
286 if let Ok(RawStreamingChoice::Message(ref text)) = result {
287 text_content += text;
288 }
289 yield result;
290 }
291 },
292 Err(e) => {
293 if !sse.data.trim().is_empty() {
294 yield Err(CompletionError::ResponseError(
295 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
296 ));
297 }
298 }
299 }
300 },
301 Err(e) => {
302 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
303 break;
304 }
305 }
306 }
307
308 sse_stream.close();
310
311 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
312 usage: final_usage.unwrap_or_default()
313 }))
314 }.instrument(span));
315
316 Ok(streaming::StreamingCompletionResponse::stream(stream))
317 }
318}
319
320fn handle_event(
321 event: &StreamingEvent,
322 current_tool_call: &mut Option<ToolCallState>,
323 current_thinking: &mut Option<ThinkingState>,
324) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
325 match event {
326 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
327 ContentDelta::TextDelta { text } => {
328 if current_tool_call.is_none() {
329 return Some(Ok(RawStreamingChoice::Message(text.clone())));
330 }
331 None
332 }
333 ContentDelta::InputJsonDelta { partial_json } => {
334 if let Some(tool_call) = current_tool_call {
335 tool_call.input_json.push_str(partial_json);
336 return Some(Ok(RawStreamingChoice::ToolCallDelta {
338 id: tool_call.id.clone(),
339 delta: partial_json.clone(),
340 }));
341 }
342 None
343 }
344 ContentDelta::ThinkingDelta { thinking } => {
345 if current_thinking.is_none() {
346 *current_thinking = Some(ThinkingState::default());
347 }
348
349 if let Some(state) = current_thinking {
350 state.thinking.push_str(thinking);
351 }
352
353 Some(Ok(RawStreamingChoice::ReasoningDelta {
354 id: None,
355 reasoning: thinking.clone(),
356 }))
357 }
358 ContentDelta::SignatureDelta { signature } => {
359 if current_thinking.is_none() {
360 *current_thinking = Some(ThinkingState::default());
361 }
362
363 if let Some(state) = current_thinking {
364 state.signature.push_str(signature);
365 }
366
367 None
369 }
370 },
371 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
372 Content::ToolUse { id, name, .. } => {
373 *current_tool_call = Some(ToolCallState {
374 name: name.clone(),
375 id: id.clone(),
376 input_json: String::new(),
377 });
378 None
379 }
380 Content::Thinking { .. } => {
381 *current_thinking = Some(ThinkingState::default());
382 None
383 }
384 _ => None,
386 },
387 StreamingEvent::ContentBlockStop { .. } => {
388 if let Some(thinking_state) = Option::take(current_thinking)
389 && !thinking_state.thinking.is_empty()
390 {
391 let signature = if thinking_state.signature.is_empty() {
392 None
393 } else {
394 Some(thinking_state.signature)
395 };
396
397 return Some(Ok(RawStreamingChoice::Reasoning {
398 id: None,
399 reasoning: thinking_state.thinking,
400 signature,
401 }));
402 }
403
404 if let Some(tool_call) = Option::take(current_tool_call) {
405 let json_str = if tool_call.input_json.is_empty() {
406 "{}"
407 } else {
408 &tool_call.input_json
409 };
410 match serde_json::from_str(json_str) {
411 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall(
412 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value),
413 ))),
414 Err(e) => Some(Err(CompletionError::from(e))),
415 }
416 } else {
417 None
418 }
419 }
420 StreamingEvent::MessageStart { .. }
422 | StreamingEvent::MessageDelta { .. }
423 | StreamingEvent::MessageStop
424 | StreamingEvent::Ping
425 | StreamingEvent::Unknown => None,
426 }
427}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// Builds an input-JSON delta event carrying `fragment`.
    fn input_json_event(fragment: &str) -> StreamingEvent {
        delta_event(ContentDelta::InputJsonDelta {
            partial_json: fragment.to_string(),
        })
    }

    /// An in-flight tool call that has not received any input yet.
    fn pending_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let parsed: ContentDelta = serde_json::from_str(
            r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#,
        )
        .unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let parsed: ContentDelta =
            serde_json::from_str(r#"{"type": "signature_delta", "signature": "abc123def456"}"#)
                .unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });
        let mut tool_call_state = None;
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::ReasoningDelta { id, reasoning, .. } = outcome.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The delta must also have been accumulated into the thinking state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });
        let mut tool_call_state = None;
        let mut thinking_state = None;

        // Signature fragments are accumulated without emitting anything.
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });
        let mut tool_call_state = None;
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::Message(text) = outcome.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let RawStreamingChoice::ReasoningDelta { reasoning, .. } = outcome.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The in-flight tool call must survive the thinking delta.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = input_json_event("{\"arg\":\"value");
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
        assert!(outcome.is_some());

        let choice = outcome.unwrap().unwrap();
        match choice {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment must also have been appended to the accumulator.
        assert!(tool_call_state.is_some());
        assert_eq!(tool_call_state.unwrap().input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        // Feed the tool input in three fragments; each one emits a delta.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = input_json_event(fragment);
            let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(outcome.is_some());
        }

        assert!(tool_call_state.is_some());
        assert_eq!(
            tool_call_state.as_ref().unwrap().input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block turns the accumulated JSON into a tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The accumulator is consumed once the tool call has been emitted.
        assert!(tool_call_state.is_none());
    }
}