1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use crate::OneOrMany;
10use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
17#[derive(Debug, Deserialize)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum StreamingEvent {
20 MessageStart {
21 message: MessageStart,
22 },
23 ContentBlockStart {
24 index: usize,
25 content_block: Content,
26 },
27 ContentBlockDelta {
28 index: usize,
29 delta: ContentDelta,
30 },
31 ContentBlockStop {
32 index: usize,
33 },
34 MessageDelta {
35 delta: MessageDelta,
36 usage: PartialUsage,
37 },
38 MessageStop,
39 Ping,
40 #[serde(other)]
41 Unknown,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct MessageStart {
46 pub id: String,
47 pub role: String,
48 pub content: Vec<Content>,
49 pub model: String,
50 pub stop_reason: Option<String>,
51 pub stop_sequence: Option<String>,
52 pub usage: Usage,
53}
54
55#[derive(Debug, Deserialize)]
56#[serde(tag = "type", rename_all = "snake_case")]
57pub enum ContentDelta {
58 TextDelta { text: String },
59 InputJsonDelta { partial_json: String },
60 ThinkingDelta { thinking: String },
61 SignatureDelta { signature: String },
62}
63
64#[derive(Debug, Deserialize)]
65pub struct MessageDelta {
66 pub stop_reason: Option<String>,
67 pub stop_sequence: Option<String>,
68}
69
70#[derive(Debug, Deserialize, Clone, Serialize)]
71pub struct PartialUsage {
72 pub output_tokens: usize,
73 #[serde(default)]
74 pub input_tokens: Option<usize>,
75}
76
77impl GetTokenUsage for PartialUsage {
78 fn token_usage(&self) -> Option<crate::completion::Usage> {
79 let mut usage = crate::completion::Usage::new();
80
81 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82 usage.output_tokens = self.output_tokens as u64;
83 usage.total_tokens = usage.input_tokens + usage.output_tokens;
84 Some(usage)
85 }
86}
87
/// Accumulator for a tool call whose arguments are still being streamed.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the `tool_use` content block.
    name: String,
    /// Tool call identifier taken from the `tool_use` content block.
    id: String,
    /// Concatenation of every `partial_json` fragment received so far.
    input_json: String,
}
94
/// Accumulator for a thinking block that is still being streamed.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of every thinking fragment received so far.
    thinking: String,
    /// Concatenation of every signature fragment received so far.
    signature: String,
}
100
101#[derive(Clone, Debug, Deserialize, Serialize)]
102pub struct StreamingCompletionResponse {
103 pub usage: PartialUsage,
104}
105
106impl GetTokenUsage for StreamingCompletionResponse {
107 fn token_usage(&self) -> Option<crate::completion::Usage> {
108 let mut usage = crate::completion::Usage::new();
109 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
110 usage.output_tokens = self.usage.output_tokens as u64;
111 usage.total_tokens =
112 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
113
114 Some(usage)
115 }
116}
117
118impl<T> CompletionModel<T>
119where
120 T: HttpClientExt + Clone + Default + 'static,
121{
122 pub(crate) async fn stream(
123 &self,
124 completion_request: CompletionRequest,
125 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
126 {
127 let span = if tracing::Span::current().is_disabled() {
128 info_span!(
129 target: "rig::completions",
130 "chat_streaming",
131 gen_ai.operation.name = "chat_streaming",
132 gen_ai.provider.name = "anthropic",
133 gen_ai.request.model = self.model,
134 gen_ai.system_instructions = &completion_request.preamble,
135 gen_ai.response.id = tracing::field::Empty,
136 gen_ai.response.model = self.model,
137 gen_ai.usage.output_tokens = tracing::field::Empty,
138 gen_ai.usage.input_tokens = tracing::field::Empty,
139 gen_ai.input.messages = tracing::field::Empty,
140 gen_ai.output.messages = tracing::field::Empty,
141 )
142 } else {
143 tracing::Span::current()
144 };
145 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
146 tokens
147 } else if let Some(tokens) = self.default_max_tokens {
148 tokens
149 } else {
150 return Err(CompletionError::RequestError(
151 "`max_tokens` must be set for Anthropic".into(),
152 ));
153 };
154
155 let mut full_history = vec![];
156 if let Some(docs) = completion_request.normalized_documents() {
157 full_history.push(docs);
158 }
159 full_history.extend(completion_request.chat_history);
160 span.record_model_input(&full_history);
161
162 let full_history = full_history
163 .into_iter()
164 .map(Message::try_from)
165 .collect::<Result<Vec<Message>, _>>()?;
166
167 let mut body = json!({
168 "model": self.model,
169 "messages": full_history,
170 "max_tokens": max_tokens,
171 "system": completion_request.preamble.unwrap_or("".to_string()),
172 "stream": true,
173 });
174
175 if let Some(temperature) = completion_request.temperature {
176 merge_inplace(&mut body, json!({ "temperature": temperature }));
177 }
178
179 if !completion_request.tools.is_empty() {
180 merge_inplace(
181 &mut body,
182 json!({
183 "tools": completion_request
184 .tools
185 .into_iter()
186 .map(|tool| ToolDefinition {
187 name: tool.name,
188 description: Some(tool.description),
189 input_schema: tool.parameters,
190 })
191 .collect::<Vec<_>>(),
192 "tool_choice": ToolChoice::Auto,
193 }),
194 );
195 }
196
197 if let Some(ref params) = completion_request.additional_params {
198 merge_inplace(&mut body, params.clone())
199 }
200
201 let body: Vec<u8> = serde_json::to_vec(&body)?;
202
203 let req = self
204 .client
205 .post("/v1/messages")
206 .header("Content-Type", "application/json")
207 .body(body)
208 .map_err(http_client::Error::Protocol)?;
209
210 let stream = GenericEventSource::new(self.client.http_client.clone(), req);
211
212 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
214 let mut current_tool_call: Option<ToolCallState> = None;
215 let mut current_thinking: Option<ThinkingState> = None;
216 let mut sse_stream = Box::pin(stream);
217 let mut input_tokens = 0;
218
219 let mut text_content = String::new();
220
221 while let Some(sse_result) = sse_stream.next().await {
222 match sse_result {
223 Ok(Event::Open) => {}
224 Ok(Event::Message(sse)) => {
225 match serde_json::from_str::<StreamingEvent>(&sse.data) {
227 Ok(event) => {
228 match &event {
229 StreamingEvent::MessageStart { message } => {
230 input_tokens = message.usage.input_tokens;
231
232 let span = tracing::Span::current();
233 span.record("gen_ai.response.id", &message.id);
234 span.record("gen_ai.response.model_name", &message.model);
235 },
236 StreamingEvent::MessageDelta { delta, usage } => {
237 if delta.stop_reason.is_some() {
238 let usage = PartialUsage {
239 output_tokens: usage.output_tokens,
240 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
241 };
242
243 let span = tracing::Span::current();
244 span.record_token_usage(&usage);
245 span.record_model_output(&Message {
246 role: super::completion::Role::Assistant,
247 content: OneOrMany::one(Content::Text { text: text_content.clone() })}
248 );
249
250 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
251 usage
252 }))
253 }
254 }
255 _ => {}
256 }
257
258 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
259 if let Ok(RawStreamingChoice::Message(ref text)) = result {
260 text_content += text;
261 }
262 yield result;
263 }
264 },
265 Err(e) => {
266 if !sse.data.trim().is_empty() {
267 yield Err(CompletionError::ResponseError(
268 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
269 ));
270 }
271 }
272 }
273 },
274 Err(e) => {
275 yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
276 break;
277 }
278 }
279 }
280 }.instrument(span));
281
282 Ok(streaming::StreamingCompletionResponse::stream(stream))
283 }
284}
285
286fn handle_event(
287 event: &StreamingEvent,
288 current_tool_call: &mut Option<ToolCallState>,
289 current_thinking: &mut Option<ThinkingState>,
290) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
291 match event {
292 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
293 ContentDelta::TextDelta { text } => {
294 if current_tool_call.is_none() {
295 return Some(Ok(RawStreamingChoice::Message(text.clone())));
296 }
297 None
298 }
299 ContentDelta::InputJsonDelta { partial_json } => {
300 if let Some(tool_call) = current_tool_call {
301 tool_call.input_json.push_str(partial_json);
302 return Some(Ok(RawStreamingChoice::ToolCallDelta {
304 id: tool_call.id.clone(),
305 delta: partial_json.clone(),
306 }));
307 }
308 None
309 }
310 ContentDelta::ThinkingDelta { thinking } => {
311 if current_thinking.is_none() {
312 *current_thinking = Some(ThinkingState::default());
313 }
314
315 if let Some(state) = current_thinking {
316 state.thinking.push_str(thinking);
317 }
318
319 Some(Ok(RawStreamingChoice::Reasoning {
320 id: None,
321 reasoning: thinking.clone(),
322 signature: None,
323 }))
324 }
325 ContentDelta::SignatureDelta { signature } => {
326 if current_thinking.is_none() {
327 *current_thinking = Some(ThinkingState::default());
328 }
329
330 if let Some(state) = current_thinking {
331 state.signature.push_str(signature);
332 }
333
334 None
336 }
337 },
338 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
339 Content::ToolUse { id, name, .. } => {
340 *current_tool_call = Some(ToolCallState {
341 name: name.clone(),
342 id: id.clone(),
343 input_json: String::new(),
344 });
345 None
346 }
347 Content::Thinking { .. } => {
348 *current_thinking = Some(ThinkingState::default());
349 None
350 }
351 _ => None,
353 },
354 StreamingEvent::ContentBlockStop { .. } => {
355 if let Some(thinking_state) = Option::take(current_thinking)
356 && !thinking_state.thinking.is_empty()
357 {
358 let signature = if thinking_state.signature.is_empty() {
359 None
360 } else {
361 Some(thinking_state.signature)
362 };
363
364 return Some(Ok(RawStreamingChoice::Reasoning {
365 id: None,
366 reasoning: thinking_state.thinking,
367 signature,
368 }));
369 }
370
371 if let Some(tool_call) = Option::take(current_tool_call) {
372 let json_str = if tool_call.input_json.is_empty() {
373 "{}"
374 } else {
375 &tool_call.input_json
376 };
377 match serde_json::from_str(json_str) {
378 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
379 name: tool_call.name,
380 id: tool_call.id,
381 arguments: json_value,
382 call_id: None,
383 })),
384 Err(e) => Some(Err(CompletionError::from(e))),
385 }
386 } else {
387 None
388 }
389 }
390 StreamingEvent::MessageStart { .. }
392 | StreamingEvent::MessageDelta { .. }
393 | StreamingEvent::MessageStop
394 | StreamingEvent::Ping
395 | StreamingEvent::Unknown => None,
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event at index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// Wraps a JSON fragment in an `input_json_delta` event at index 0.
    fn json_delta_event(fragment: &str) -> StreamingEvent {
        delta_event(ContentDelta::InputJsonDelta {
            partial_json: fragment.to_string(),
        })
    }

    /// An in-flight tool call with no arguments accumulated yet.
    fn pending_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let payload = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant");
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let payload = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant");
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Reasoning { id, reasoning, .. } = outcome.unwrap().unwrap() else {
            panic!("Expected Reasoning choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The fragment must also have been accumulated in the thinking state.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are accumulated silently.
        assert!(outcome.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Message(text) = outcome.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Reasoning { reasoning, .. } = outcome.unwrap().unwrap() else {
            panic!("Expected Reasoning choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The pending tool call must survive the thinking delta.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = json_delta_event("{\"arg\":\"value");

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let choice = outcome.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        // The fragment must also have been accumulated in the tool call state.
        assert!(tool_call_state.is_some());
        let state = tool_call_state.unwrap();
        assert_eq!(state.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        // Feed the arguments in three fragments; each one must be forwarded.
        let first = json_delta_event("{\"location\":");
        assert!(handle_event(&first, &mut tool_call_state, &mut thinking_state).is_some());

        let second = json_delta_event("\"Paris\",");
        assert!(handle_event(&second, &mut tool_call_state, &mut thinking_state).is_some());

        let third = json_delta_event("\"temp\":\"20C\"}");
        assert!(handle_event(&third, &mut tool_call_state, &mut thinking_state).is_some());

        assert!(tool_call_state.is_some());
        let accumulated = tool_call_state.as_ref().unwrap();
        assert_eq!(
            accumulated.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Stopping the block parses the accumulated JSON into a full tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall {
                id,
                name,
                arguments,
                ..
            } => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The accumulator is cleared once the tool call has been emitted.
        assert!(tool_call_state.is_none());
    }
}