1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10 apply_cache_control,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::streaming::{
17 self, RawStreamingChoice, RawStreamingToolCall, StreamingResult, ToolCallDeltaContent,
18};
19use crate::telemetry::SpanCombinator;
20
21#[derive(Debug, Deserialize)]
22#[serde(tag = "type", rename_all = "snake_case")]
23pub enum StreamingEvent {
24 MessageStart {
25 message: MessageStart,
26 },
27 ContentBlockStart {
28 index: usize,
29 content_block: Content,
30 },
31 ContentBlockDelta {
32 index: usize,
33 delta: ContentDelta,
34 },
35 ContentBlockStop {
36 index: usize,
37 },
38 MessageDelta {
39 delta: MessageDelta,
40 usage: PartialUsage,
41 },
42 MessageStop,
43 Ping,
44 #[serde(other)]
45 Unknown,
46}
47
48#[derive(Debug, Deserialize)]
49pub struct MessageStart {
50 pub id: String,
51 pub role: String,
52 pub content: Vec<Content>,
53 pub model: String,
54 pub stop_reason: Option<String>,
55 pub stop_sequence: Option<String>,
56 pub usage: Usage,
57}
58
59#[derive(Debug, Deserialize)]
60#[serde(tag = "type", rename_all = "snake_case")]
61pub enum ContentDelta {
62 TextDelta { text: String },
63 InputJsonDelta { partial_json: String },
64 ThinkingDelta { thinking: String },
65 SignatureDelta { signature: String },
66}
67
68#[derive(Debug, Deserialize)]
69pub struct MessageDelta {
70 pub stop_reason: Option<String>,
71 pub stop_sequence: Option<String>,
72}
73
74#[derive(Debug, Deserialize, Clone, Serialize, Default)]
75pub struct PartialUsage {
76 pub output_tokens: usize,
77 #[serde(default)]
78 pub input_tokens: Option<usize>,
79}
80
81impl GetTokenUsage for PartialUsage {
82 fn token_usage(&self) -> Option<crate::completion::Usage> {
83 let mut usage = crate::completion::Usage::new();
84
85 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
86 usage.output_tokens = self.output_tokens as u64;
87 usage.total_tokens = usage.input_tokens + usage.output_tokens;
88 Some(usage)
89 }
90}
91
/// Accumulator for the tool call whose content block is currently open.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the `tool_use` block.
    name: String,
    /// Provider-assigned id of the tool call.
    id: String,
    /// Locally generated id attached to every delta of this call.
    internal_call_id: String,
    /// JSON argument text received so far, concatenated in arrival order.
    input_json: String,
}
99
/// Accumulator for the thinking block that is currently open.
#[derive(Default)]
struct ThinkingState {
    /// Thinking text received so far.
    thinking: String,
    /// Signature text received so far; stays empty when none was sent.
    signature: String,
}
105
106#[derive(Clone, Debug, Deserialize, Serialize)]
107pub struct StreamingCompletionResponse {
108 pub usage: PartialUsage,
109}
110
111impl GetTokenUsage for StreamingCompletionResponse {
112 fn token_usage(&self) -> Option<crate::completion::Usage> {
113 let mut usage = crate::completion::Usage::new();
114 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
115 usage.output_tokens = self.usage.output_tokens as u64;
116 usage.total_tokens =
117 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
118
119 Some(usage)
120 }
121}
122
123impl<T> CompletionModel<T>
124where
125 T: HttpClientExt + Clone + Default + 'static,
126{
127 pub(crate) async fn stream(
128 &self,
129 completion_request: CompletionRequest,
130 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
131 {
132 let span = if tracing::Span::current().is_disabled() {
133 info_span!(
134 target: "rig::completions",
135 "chat_streaming",
136 gen_ai.operation.name = "chat_streaming",
137 gen_ai.provider.name = "anthropic",
138 gen_ai.request.model = self.model,
139 gen_ai.system_instructions = &completion_request.preamble,
140 gen_ai.response.id = tracing::field::Empty,
141 gen_ai.response.model = self.model,
142 gen_ai.usage.output_tokens = tracing::field::Empty,
143 gen_ai.usage.input_tokens = tracing::field::Empty,
144 gen_ai.input.messages = tracing::field::Empty,
145 gen_ai.output.messages = tracing::field::Empty,
146 )
147 } else {
148 tracing::Span::current()
149 };
150 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
151 tokens
152 } else if let Some(tokens) = self.default_max_tokens {
153 tokens
154 } else {
155 return Err(CompletionError::RequestError(
156 "`max_tokens` must be set for Anthropic".into(),
157 ));
158 };
159
160 let mut full_history = vec![];
161 if let Some(docs) = completion_request.normalized_documents() {
162 full_history.push(docs);
163 }
164 full_history.extend(completion_request.chat_history);
165
166 let mut messages = full_history
167 .into_iter()
168 .map(Message::try_from)
169 .collect::<Result<Vec<Message>, _>>()?;
170
171 let mut system: Vec<SystemContent> =
173 if let Some(preamble) = completion_request.preamble.as_ref() {
174 if preamble.is_empty() {
175 vec![]
176 } else {
177 vec![SystemContent::Text {
178 text: preamble.clone(),
179 cache_control: None,
180 }]
181 }
182 } else {
183 vec![]
184 };
185
186 if self.prompt_caching {
188 apply_cache_control(&mut system, &mut messages);
189 }
190
191 let mut body = json!({
192 "model": self.model,
193 "messages": messages,
194 "max_tokens": max_tokens,
195 "stream": true,
196 });
197
198 if !system.is_empty() {
200 merge_inplace(&mut body, json!({ "system": system }));
201 }
202
203 if let Some(temperature) = completion_request.temperature {
204 merge_inplace(&mut body, json!({ "temperature": temperature }));
205 }
206
207 if !completion_request.tools.is_empty() {
208 merge_inplace(
209 &mut body,
210 json!({
211 "tools": completion_request
212 .tools
213 .into_iter()
214 .map(|tool| ToolDefinition {
215 name: tool.name,
216 description: Some(tool.description),
217 input_schema: tool.parameters,
218 })
219 .collect::<Vec<_>>(),
220 "tool_choice": ToolChoice::Auto,
221 }),
222 );
223 }
224
225 if let Some(ref params) = completion_request.additional_params {
226 merge_inplace(&mut body, params.clone())
227 }
228
229 if enabled!(Level::TRACE) {
230 tracing::trace!(
231 target: "rig::completions",
232 "Anthropic completion request: {}",
233 serde_json::to_string_pretty(&body)?
234 );
235 }
236
237 let body: Vec<u8> = serde_json::to_vec(&body)?;
238
239 let req = self
240 .client
241 .post("/v1/messages")?
242 .body(body)
243 .map_err(http_client::Error::Protocol)?;
244
245 let stream = GenericEventSource::new(self.client.clone(), req);
246
247 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
249 let mut current_tool_call: Option<ToolCallState> = None;
250 let mut current_thinking: Option<ThinkingState> = None;
251 let mut sse_stream = Box::pin(stream);
252 let mut input_tokens = 0;
253 let mut final_usage = None;
254
255 let mut text_content = String::new();
256
257 while let Some(sse_result) = sse_stream.next().await {
258 match sse_result {
259 Ok(Event::Open) => {}
260 Ok(Event::Message(sse)) => {
261 match serde_json::from_str::<StreamingEvent>(&sse.data) {
263 Ok(event) => {
264 match &event {
265 StreamingEvent::MessageStart { message } => {
266 input_tokens = message.usage.input_tokens;
267
268 let span = tracing::Span::current();
269 span.record("gen_ai.response.id", &message.id);
270 span.record("gen_ai.response.model_name", &message.model);
271 },
272 StreamingEvent::MessageDelta { delta, usage } => {
273 if delta.stop_reason.is_some() {
274 let usage = PartialUsage {
275 output_tokens: usage.output_tokens,
276 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
277 };
278
279 let span = tracing::Span::current();
280 span.record_token_usage(&usage);
281 final_usage = Some(usage);
282 break;
283 }
284 }
285 _ => {}
286 }
287
288 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
289 if let Ok(RawStreamingChoice::Message(ref text)) = result {
290 text_content += text;
291 }
292 yield result;
293 }
294 },
295 Err(e) => {
296 if !sse.data.trim().is_empty() {
297 yield Err(CompletionError::ResponseError(
298 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
299 ));
300 }
301 }
302 }
303 },
304 Err(e) => {
305 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
306 break;
307 }
308 }
309 }
310
311 sse_stream.close();
313
314 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
315 usage: final_usage.unwrap_or_default()
316 }))
317 }.instrument(span));
318
319 Ok(streaming::StreamingCompletionResponse::stream(stream))
320 }
321}
322
323fn handle_event(
324 event: &StreamingEvent,
325 current_tool_call: &mut Option<ToolCallState>,
326 current_thinking: &mut Option<ThinkingState>,
327) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
328 match event {
329 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
330 ContentDelta::TextDelta { text } => {
331 if current_tool_call.is_none() {
332 return Some(Ok(RawStreamingChoice::Message(text.clone())));
333 }
334 None
335 }
336 ContentDelta::InputJsonDelta { partial_json } => {
337 if let Some(tool_call) = current_tool_call {
338 tool_call.input_json.push_str(partial_json);
339 return Some(Ok(RawStreamingChoice::ToolCallDelta {
341 id: tool_call.id.clone(),
342 internal_call_id: tool_call.internal_call_id.clone(),
343 content: ToolCallDeltaContent::Delta(partial_json.clone()),
344 }));
345 }
346 None
347 }
348 ContentDelta::ThinkingDelta { thinking } => {
349 if current_thinking.is_none() {
350 *current_thinking = Some(ThinkingState::default());
351 }
352
353 if let Some(state) = current_thinking {
354 state.thinking.push_str(thinking);
355 }
356
357 Some(Ok(RawStreamingChoice::ReasoningDelta {
358 id: None,
359 reasoning: thinking.clone(),
360 }))
361 }
362 ContentDelta::SignatureDelta { signature } => {
363 if current_thinking.is_none() {
364 *current_thinking = Some(ThinkingState::default());
365 }
366
367 if let Some(state) = current_thinking {
368 state.signature.push_str(signature);
369 }
370
371 None
373 }
374 },
375 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
376 Content::ToolUse { id, name, .. } => {
377 let internal_call_id = nanoid::nanoid!();
378 *current_tool_call = Some(ToolCallState {
379 name: name.clone(),
380 id: id.clone(),
381 internal_call_id: internal_call_id.clone(),
382 input_json: String::new(),
383 });
384 Some(Ok(RawStreamingChoice::ToolCallDelta {
385 id: id.clone(),
386 internal_call_id,
387 content: ToolCallDeltaContent::Name(name.clone()),
388 }))
389 }
390 Content::Thinking { .. } => {
391 *current_thinking = Some(ThinkingState::default());
392 None
393 }
394 _ => None,
396 },
397 StreamingEvent::ContentBlockStop { .. } => {
398 if let Some(thinking_state) = Option::take(current_thinking)
399 && !thinking_state.thinking.is_empty()
400 {
401 let signature = if thinking_state.signature.is_empty() {
402 None
403 } else {
404 Some(thinking_state.signature)
405 };
406
407 return Some(Ok(RawStreamingChoice::Reasoning {
408 id: None,
409 reasoning: thinking_state.thinking,
410 signature,
411 }));
412 }
413
414 if let Some(tool_call) = Option::take(current_tool_call) {
415 let json_str = if tool_call.input_json.is_empty() {
416 "{}"
417 } else {
418 &tool_call.input_json
419 };
420 match serde_json::from_str(json_str) {
421 Ok(json_value) => {
422 let raw_tool_call =
423 RawStreamingToolCall::new(tool_call.id, tool_call.name, json_value)
424 .with_internal_call_id(tool_call.internal_call_id);
425 Some(Ok(RawStreamingChoice::ToolCall(raw_tool_call)))
426 }
427 Err(e) => Some(Err(CompletionError::from(e))),
428 }
429 } else {
430 None
431 }
432 }
433 StreamingEvent::MessageStart { .. }
435 | StreamingEvent::MessageDelta { .. }
436 | StreamingEvent::MessageStop
437 | StreamingEvent::Ping
438 | StreamingEvent::Unknown => None,
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// A tool call that has been opened but has not received any arguments.
    fn open_tool_call() -> ToolCallState {
        ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        }
    }

    /// Wraps `delta` in a `content_block_delta` event for block 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let payload = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let payload = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::ReasoningDelta { id, reasoning, .. } = outcome.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The delta must also have been recorded in the accumulator.
        assert_eq!(
            thinking_state.map(|state| state.thinking).as_deref(),
            Some("Analyzing the request...")
        );
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signatures are collected silently.
        assert!(outcome.is_none());
        assert_eq!(
            thinking_state.map(|state| state.signature).as_deref(),
            Some("test_signature")
        );
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Message(text) = outcome.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = Some(open_tool_call());
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::ReasoningDelta { reasoning, .. } = outcome.unwrap().unwrap()
        else {
            panic!("Expected ReasoningDelta choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The open tool call must be left untouched.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = delta_event(ContentDelta::InputJsonDelta {
            partial_json: "{\"arg\":\"value".to_string(),
        });

        let mut tool_call_state = Some(open_tool_call());
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let choice = outcome.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta {
                id,
                internal_call_id: _,
                content,
            } => {
                assert_eq!(id, "tool_123");
                let ToolCallDeltaContent::Delta(fragment) = content else {
                    panic!("Expected Delta content");
                };
                assert_eq!(fragment, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment must also have been appended to the accumulator.
        assert_eq!(
            tool_call_state.map(|state| state.input_json).as_deref(),
            Some("{\"arg\":\"value")
        );
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(open_tool_call());
        let mut thinking_state = None;

        // Feed the arguments in three fragments; each one must be forwarded.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let event = delta_event(ContentDelta::InputJsonDelta {
                partial_json: fragment.to_string(),
            });
            let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);
            assert!(outcome.is_some());
        }

        assert!(tool_call_state.is_some());
        let accumulated = tool_call_state.as_ref().unwrap();
        assert_eq!(
            accumulated.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block turns the accumulated text into a complete call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall(RawStreamingToolCall {
                id,
                name,
                arguments,
                ..
            }) => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The accumulator is consumed by the stop event.
        assert!(tool_call_state.is_none());
    }
}