1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::{Level, enabled, info_span};
6use tracing_futures::Instrument;
7
8use super::completion::{
9 CompletionModel, Content, Message, SystemContent, ToolChoice, ToolDefinition, Usage,
10 apply_cache_control,
11};
12use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
13use crate::http_client::sse::{Event, GenericEventSource};
14use crate::http_client::{self, HttpClientExt};
15use crate::json_utils::merge_inplace;
16use crate::streaming::{self, RawStreamingChoice, StreamingResult};
17use crate::telemetry::SpanCombinator;
18
19#[derive(Debug, Deserialize)]
20#[serde(tag = "type", rename_all = "snake_case")]
21pub enum StreamingEvent {
22 MessageStart {
23 message: MessageStart,
24 },
25 ContentBlockStart {
26 index: usize,
27 content_block: Content,
28 },
29 ContentBlockDelta {
30 index: usize,
31 delta: ContentDelta,
32 },
33 ContentBlockStop {
34 index: usize,
35 },
36 MessageDelta {
37 delta: MessageDelta,
38 usage: PartialUsage,
39 },
40 MessageStop,
41 Ping,
42 #[serde(other)]
43 Unknown,
44}
45
46#[derive(Debug, Deserialize)]
47pub struct MessageStart {
48 pub id: String,
49 pub role: String,
50 pub content: Vec<Content>,
51 pub model: String,
52 pub stop_reason: Option<String>,
53 pub stop_sequence: Option<String>,
54 pub usage: Usage,
55}
56
57#[derive(Debug, Deserialize)]
58#[serde(tag = "type", rename_all = "snake_case")]
59pub enum ContentDelta {
60 TextDelta { text: String },
61 InputJsonDelta { partial_json: String },
62 ThinkingDelta { thinking: String },
63 SignatureDelta { signature: String },
64}
65
66#[derive(Debug, Deserialize)]
67pub struct MessageDelta {
68 pub stop_reason: Option<String>,
69 pub stop_sequence: Option<String>,
70}
71
72#[derive(Debug, Deserialize, Clone, Serialize, Default)]
73pub struct PartialUsage {
74 pub output_tokens: usize,
75 #[serde(default)]
76 pub input_tokens: Option<usize>,
77}
78
79impl GetTokenUsage for PartialUsage {
80 fn token_usage(&self) -> Option<crate::completion::Usage> {
81 let mut usage = crate::completion::Usage::new();
82
83 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
84 usage.output_tokens = self.output_tokens as u64;
85 usage.total_tokens = usage.input_tokens + usage.output_tokens;
86 Some(usage)
87 }
88}
89
/// A tool call that is currently being streamed.
///
/// Holds the name and id taken from the `tool_use` content block, plus the raw
/// JSON input accumulated from successive `input_json_delta` fragments.
/// Derives `Debug` like the other types in this module, so it can be inspected
/// in logs and test failures.
#[derive(Debug, Default)]
struct ToolCallState {
    name: String,
    id: String,
    input_json: String,
}
96
/// A thinking block that is currently being streamed.
///
/// Accumulates the `thinking_delta` text and the `signature_delta` fragments
/// until the block is closed. Derives `Debug` like the other types in this
/// module, so it can be inspected in logs and test failures.
#[derive(Debug, Default)]
struct ThinkingState {
    thinking: String,
    signature: String,
}
102
103#[derive(Clone, Debug, Deserialize, Serialize)]
104pub struct StreamingCompletionResponse {
105 pub usage: PartialUsage,
106}
107
108impl GetTokenUsage for StreamingCompletionResponse {
109 fn token_usage(&self) -> Option<crate::completion::Usage> {
110 let mut usage = crate::completion::Usage::new();
111 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
112 usage.output_tokens = self.usage.output_tokens as u64;
113 usage.total_tokens =
114 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
115
116 Some(usage)
117 }
118}
119
120impl<T> CompletionModel<T>
121where
122 T: HttpClientExt + Clone + Default + 'static,
123{
124 pub(crate) async fn stream(
125 &self,
126 completion_request: CompletionRequest,
127 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
128 {
129 let span = if tracing::Span::current().is_disabled() {
130 info_span!(
131 target: "rig::completions",
132 "chat_streaming",
133 gen_ai.operation.name = "chat_streaming",
134 gen_ai.provider.name = "anthropic",
135 gen_ai.request.model = self.model,
136 gen_ai.system_instructions = &completion_request.preamble,
137 gen_ai.response.id = tracing::field::Empty,
138 gen_ai.response.model = self.model,
139 gen_ai.usage.output_tokens = tracing::field::Empty,
140 gen_ai.usage.input_tokens = tracing::field::Empty,
141 gen_ai.input.messages = tracing::field::Empty,
142 gen_ai.output.messages = tracing::field::Empty,
143 )
144 } else {
145 tracing::Span::current()
146 };
147 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
148 tokens
149 } else if let Some(tokens) = self.default_max_tokens {
150 tokens
151 } else {
152 return Err(CompletionError::RequestError(
153 "`max_tokens` must be set for Anthropic".into(),
154 ));
155 };
156
157 let mut full_history = vec![];
158 if let Some(docs) = completion_request.normalized_documents() {
159 full_history.push(docs);
160 }
161 full_history.extend(completion_request.chat_history);
162
163 let mut messages = full_history
164 .into_iter()
165 .map(Message::try_from)
166 .collect::<Result<Vec<Message>, _>>()?;
167
168 let mut system: Vec<SystemContent> =
170 if let Some(preamble) = completion_request.preamble.as_ref() {
171 if preamble.is_empty() {
172 vec![]
173 } else {
174 vec![SystemContent::Text {
175 text: preamble.clone(),
176 cache_control: None,
177 }]
178 }
179 } else {
180 vec![]
181 };
182
183 if self.prompt_caching {
185 apply_cache_control(&mut system, &mut messages);
186 }
187
188 let mut body = json!({
189 "model": self.model,
190 "messages": messages,
191 "max_tokens": max_tokens,
192 "stream": true,
193 });
194
195 if !system.is_empty() {
197 merge_inplace(&mut body, json!({ "system": system }));
198 }
199
200 if let Some(temperature) = completion_request.temperature {
201 merge_inplace(&mut body, json!({ "temperature": temperature }));
202 }
203
204 if !completion_request.tools.is_empty() {
205 merge_inplace(
206 &mut body,
207 json!({
208 "tools": completion_request
209 .tools
210 .into_iter()
211 .map(|tool| ToolDefinition {
212 name: tool.name,
213 description: Some(tool.description),
214 input_schema: tool.parameters,
215 })
216 .collect::<Vec<_>>(),
217 "tool_choice": ToolChoice::Auto,
218 }),
219 );
220 }
221
222 if let Some(ref params) = completion_request.additional_params {
223 merge_inplace(&mut body, params.clone())
224 }
225
226 if enabled!(Level::TRACE) {
227 tracing::trace!(
228 target: "rig::completions",
229 "Anthropic completion request: {}",
230 serde_json::to_string_pretty(&body)?
231 );
232 }
233
234 let body: Vec<u8> = serde_json::to_vec(&body)?;
235
236 let req = self
237 .client
238 .post("/v1/messages")?
239 .body(body)
240 .map_err(http_client::Error::Protocol)?;
241
242 let stream = GenericEventSource::new(self.client.clone(), req);
243
244 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
246 let mut current_tool_call: Option<ToolCallState> = None;
247 let mut current_thinking: Option<ThinkingState> = None;
248 let mut sse_stream = Box::pin(stream);
249 let mut input_tokens = 0;
250 let mut final_usage = None;
251
252 let mut text_content = String::new();
253
254 while let Some(sse_result) = sse_stream.next().await {
255 match sse_result {
256 Ok(Event::Open) => {}
257 Ok(Event::Message(sse)) => {
258 match serde_json::from_str::<StreamingEvent>(&sse.data) {
260 Ok(event) => {
261 match &event {
262 StreamingEvent::MessageStart { message } => {
263 input_tokens = message.usage.input_tokens;
264
265 let span = tracing::Span::current();
266 span.record("gen_ai.response.id", &message.id);
267 span.record("gen_ai.response.model_name", &message.model);
268 },
269 StreamingEvent::MessageDelta { delta, usage } => {
270 if delta.stop_reason.is_some() {
271 let usage = PartialUsage {
272 output_tokens: usage.output_tokens,
273 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
274 };
275
276 let span = tracing::Span::current();
277 span.record_token_usage(&usage);
278 final_usage = Some(usage);
279 break;
280 }
281 }
282 _ => {}
283 }
284
285 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
286 if let Ok(RawStreamingChoice::Message(ref text)) = result {
287 text_content += text;
288 }
289 yield result;
290 }
291 },
292 Err(e) => {
293 if !sse.data.trim().is_empty() {
294 yield Err(CompletionError::ResponseError(
295 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
296 ));
297 }
298 }
299 }
300 },
301 Err(e) => {
302 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
303 break;
304 }
305 }
306 }
307
308 sse_stream.close();
310
311 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
312 usage: final_usage.unwrap_or_default()
313 }))
314 }.instrument(span));
315
316 Ok(streaming::StreamingCompletionResponse::stream(stream))
317 }
318}
319
320fn handle_event(
321 event: &StreamingEvent,
322 current_tool_call: &mut Option<ToolCallState>,
323 current_thinking: &mut Option<ThinkingState>,
324) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
325 match event {
326 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
327 ContentDelta::TextDelta { text } => {
328 if current_tool_call.is_none() {
329 return Some(Ok(RawStreamingChoice::Message(text.clone())));
330 }
331 None
332 }
333 ContentDelta::InputJsonDelta { partial_json } => {
334 if let Some(tool_call) = current_tool_call {
335 tool_call.input_json.push_str(partial_json);
336 return Some(Ok(RawStreamingChoice::ToolCallDelta {
338 id: tool_call.id.clone(),
339 delta: partial_json.clone(),
340 }));
341 }
342 None
343 }
344 ContentDelta::ThinkingDelta { thinking } => {
345 if current_thinking.is_none() {
346 *current_thinking = Some(ThinkingState::default());
347 }
348
349 if let Some(state) = current_thinking {
350 state.thinking.push_str(thinking);
351 }
352
353 Some(Ok(RawStreamingChoice::Reasoning {
354 id: None,
355 reasoning: thinking.clone(),
356 signature: None,
357 }))
358 }
359 ContentDelta::SignatureDelta { signature } => {
360 if current_thinking.is_none() {
361 *current_thinking = Some(ThinkingState::default());
362 }
363
364 if let Some(state) = current_thinking {
365 state.signature.push_str(signature);
366 }
367
368 None
370 }
371 },
372 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
373 Content::ToolUse { id, name, .. } => {
374 *current_tool_call = Some(ToolCallState {
375 name: name.clone(),
376 id: id.clone(),
377 input_json: String::new(),
378 });
379 None
380 }
381 Content::Thinking { .. } => {
382 *current_thinking = Some(ThinkingState::default());
383 None
384 }
385 _ => None,
387 },
388 StreamingEvent::ContentBlockStop { .. } => {
389 if let Some(thinking_state) = Option::take(current_thinking)
390 && !thinking_state.thinking.is_empty()
391 {
392 let signature = if thinking_state.signature.is_empty() {
393 None
394 } else {
395 Some(thinking_state.signature)
396 };
397
398 return Some(Ok(RawStreamingChoice::Reasoning {
399 id: None,
400 reasoning: thinking_state.thinking,
401 signature,
402 }));
403 }
404
405 if let Some(tool_call) = Option::take(current_tool_call) {
406 let json_str = if tool_call.input_json.is_empty() {
407 "{}"
408 } else {
409 &tool_call.input_json
410 };
411 match serde_json::from_str(json_str) {
412 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
413 name: tool_call.name,
414 id: tool_call.id,
415 arguments: json_value,
416 call_id: None,
417 })),
418 Err(e) => Some(Err(CompletionError::from(e))),
419 }
420 } else {
421 None
422 }
423 }
424 StreamingEvent::MessageStart { .. }
426 | StreamingEvent::MessageDelta { .. }
427 | StreamingEvent::MessageStop
428 | StreamingEvent::Ping
429 | StreamingEvent::Unknown => None,
430 }
431}
432
#[cfg(test)]
mod tests {
    //! Unit tests for event deserialization and `handle_event` state handling.

    use super::*;

    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    #[test]
    fn test_message_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "message_delta",
            "delta": {"stop_reason": "end_turn", "stop_sequence": null},
            "usage": {"output_tokens": 15}
        }"#;

        match serde_json::from_str::<StreamingEvent>(json).unwrap() {
            StreamingEvent::MessageDelta { delta, usage } => {
                assert_eq!(delta.stop_reason.as_deref(), Some("end_turn"));
                assert_eq!(delta.stop_sequence, None);
                assert_eq!(usage.output_tokens, 15);
                assert_eq!(usage.input_tokens, None);
            }
            other => panic!("Expected MessageDelta event, got {:?}", other),
        }
    }

    #[test]
    fn test_ping_and_unknown_event_deserialization() {
        let ping: StreamingEvent = serde_json::from_str(r#"{"type": "ping"}"#).unwrap();
        assert!(matches!(ping, StreamingEvent::Ping));

        let unknown: StreamingEvent =
            serde_json::from_str(r#"{"type": "some_future_event"}"#).unwrap();
        assert!(matches!(unknown, StreamingEvent::Unknown));
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { id, reasoning, .. } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected Reasoning choice"),
        }

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    #[test]
    fn test_text_delta_is_suppressed_while_tool_call_is_open() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "should not surface".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected Reasoning choice"),
        }

        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"arg\":\"value".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_input_json_delta_without_open_tool_call_is_ignored() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"orphan\":true}".to_string(),
            },
        };

        let mut tool_call_state = None;
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let event1 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "{\"location\":".to_string(),
            },
        };
        let result1 = handle_event(&event1, &mut tool_call_state, &mut thinking_state);
        assert!(result1.is_some());

        let event2 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"Paris\",".to_string(),
            },
        };
        let result2 = handle_event(&event2, &mut tool_call_state, &mut thinking_state);
        assert!(result2.is_some());

        let event3 = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::InputJsonDelta {
                partial_json: "\"temp\":\"20C\"}".to_string(),
            },
        };
        let result3 = handle_event(&event3, &mut tool_call_state, &mut thinking_state);
        assert!(result3.is_some());

        assert!(tool_call_state.is_some());
        let state = tool_call_state.as_ref().unwrap();
        assert_eq!(
            state.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall {
                id,
                name,
                arguments,
                ..
            } => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_without_input_defaults_to_empty_object() {
        let mut tool_call_state = Some(ToolCallState {
            name: "no_args".to_string(),
            id: "tool_456".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state)
            .expect("closing a tool call should emit a choice")
            .expect("empty tool input should parse as {}");

        match result {
            RawStreamingChoice::ToolCall {
                id,
                name,
                arguments,
                ..
            } => {
                assert_eq!(id, "tool_456");
                assert_eq!(name, "no_args");
                assert_eq!(arguments, serde_json::json!({}));
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_tool_call_with_invalid_json_yields_error() {
        let mut tool_call_state = Some(ToolCallState {
            name: "broken".to_string(),
            id: "tool_789".to_string(),
            input_json: "{\"unterminated\":".to_string(),
        });
        let mut thinking_state = None;

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state)
            .expect("closing a tool call should emit a choice");

        assert!(result.is_err());
        assert!(tool_call_state.is_none());
    }

    #[test]
    fn test_content_block_stop_emits_reasoning_with_signature() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: "step one".to_string(),
            signature: "sig".to_string(),
        });

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state)
            .expect("closing a thinking block should emit a choice")
            .expect("reasoning should not be an error");

        match result {
            RawStreamingChoice::Reasoning {
                reasoning,
                signature,
                ..
            } => {
                assert_eq!(reasoning, "step one");
                assert_eq!(signature.as_deref(), Some("sig"));
            }
            other => panic!("Expected Reasoning choice, got {:?}", other),
        }

        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_content_block_stop_discards_empty_thinking_state() {
        let mut tool_call_state = None;
        let mut thinking_state = Some(ThinkingState {
            thinking: String::new(),
            signature: "orphan_signature".to_string(),
        });

        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());
        assert!(thinking_state.is_none());
    }

    #[test]
    fn test_partial_usage_token_usage_treats_missing_input_as_zero() {
        let usage = PartialUsage {
            output_tokens: 7,
            input_tokens: None,
        }
        .token_usage()
        .expect("PartialUsage always reports usage");

        assert_eq!(usage.input_tokens, 0);
        assert_eq!(usage.output_tokens, 7);
        assert_eq!(usage.total_tokens, 7);

        let usage = PartialUsage {
            output_tokens: 5,
            input_tokens: Some(10),
        }
        .token_usage()
        .expect("PartialUsage always reports usage");

        assert_eq!(usage.input_tokens, 10);
        assert_eq!(usage.output_tokens, 5);
        assert_eq!(usage.total_tokens, 15);
    }
}