1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use crate::OneOrMany;
10use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
/// One server-sent event payload from the streaming messages endpoint.
///
/// The variant is selected by the JSON `type` field, matched against the
/// snake_case form of the variant name (e.g. `content_block_delta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// `message_start`: carries the initial [`MessageStart`] metadata.
    MessageStart {
        message: MessageStart,
    },
    /// `content_block_start`: opens the content block at `index`.
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// `content_block_delta`: an incremental update for the block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// `content_block_stop`: closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// `message_delta`: message-level changes (stop reason/sequence) plus usage.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// `message_stop`: carries no payload.
    MessageStop,
    /// `ping`: carries no payload.
    Ping,
    /// Catch-all for any `type` value not listed above, so unrecognised
    /// events deserialize instead of failing.
    #[serde(other)]
    Unknown,
}
43
/// Payload of a [`StreamingEvent::MessageStart`] event.
///
/// Within this file, `id` and `model` are recorded on the tracing span and
/// `usage.input_tokens` is remembered for the final usage report.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    /// Identifier of the response message.
    pub id: String,
    pub role: String,
    pub content: Vec<Content>,
    /// Model name as reported in the event.
    pub model: String,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    /// Token usage reported at the start of the message.
    pub usage: Usage,
}
54
/// Incremental payload of a [`StreamingEvent::ContentBlockDelta`] event.
///
/// The variant is selected by the JSON `type` field in snake_case
/// (`text_delta`, `input_json_delta`, `thinking_delta`, `signature_delta`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of message text.
    TextDelta { text: String },
    /// A fragment of a tool call's JSON input; fragments are concatenated
    /// and parsed once the block stops.
    InputJsonDelta { partial_json: String },
    /// A fragment of thinking (reasoning) text.
    ThinkingDelta { thinking: String },
    /// A fragment of the signature attached to a thinking block.
    SignatureDelta { signature: String },
}
63
/// Message-level fields carried by a [`StreamingEvent::MessageDelta`] event.
///
/// The streaming loop in this file treats a present `stop_reason` as the
/// signal that the message is finished.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
69
/// Token counts as reported on a streaming `message_delta` event.
#[derive(Debug, Deserialize, Clone, Serialize, Default)]
pub struct PartialUsage {
    /// Number of output tokens reported by the event.
    pub output_tokens: usize,
    /// Number of input tokens; deserializes to `None` when the field is absent.
    #[serde(default)]
    pub input_tokens: Option<usize>,
}
76
77impl GetTokenUsage for PartialUsage {
78 fn token_usage(&self) -> Option<crate::completion::Usage> {
79 let mut usage = crate::completion::Usage::new();
80
81 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82 usage.output_tokens = self.output_tokens as u64;
83 usage.total_tokens = usage.input_tokens + usage.output_tokens;
84 Some(usage)
85 }
86}
87
/// Accumulator for a tool call whose content block is currently open.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the block-start event.
    name: String,
    /// Tool-use id taken from the block-start event.
    id: String,
    /// Concatenation of every `partial_json` fragment received so far.
    input_json: String,
}
94
/// Accumulator for a thinking block that is currently open.
#[derive(Default)]
struct ThinkingState {
    /// Concatenation of every thinking fragment received so far.
    thinking: String,
    /// Concatenation of every signature fragment received so far.
    signature: String,
}
100
/// Final item of a streamed completion, carrying the token usage collected
/// while the stream was consumed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    /// Usage captured when the message reported a stop reason; left at its
    /// default value if the stream ended without one.
    pub usage: PartialUsage,
}
105
106impl GetTokenUsage for StreamingCompletionResponse {
107 fn token_usage(&self) -> Option<crate::completion::Usage> {
108 let mut usage = crate::completion::Usage::new();
109 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
110 usage.output_tokens = self.usage.output_tokens as u64;
111 usage.total_tokens =
112 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
113
114 Some(usage)
115 }
116}
117
118impl<T> CompletionModel<T>
119where
120 T: HttpClientExt + Clone + Default + 'static,
121{
122 pub(crate) async fn stream(
123 &self,
124 completion_request: CompletionRequest,
125 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
126 {
127 let span = if tracing::Span::current().is_disabled() {
128 info_span!(
129 target: "rig::completions",
130 "chat_streaming",
131 gen_ai.operation.name = "chat_streaming",
132 gen_ai.provider.name = "anthropic",
133 gen_ai.request.model = self.model,
134 gen_ai.system_instructions = &completion_request.preamble,
135 gen_ai.response.id = tracing::field::Empty,
136 gen_ai.response.model = self.model,
137 gen_ai.usage.output_tokens = tracing::field::Empty,
138 gen_ai.usage.input_tokens = tracing::field::Empty,
139 gen_ai.input.messages = tracing::field::Empty,
140 gen_ai.output.messages = tracing::field::Empty,
141 )
142 } else {
143 tracing::Span::current()
144 };
145 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
146 tokens
147 } else if let Some(tokens) = self.default_max_tokens {
148 tokens
149 } else {
150 return Err(CompletionError::RequestError(
151 "`max_tokens` must be set for Anthropic".into(),
152 ));
153 };
154
155 let mut full_history = vec![];
156 if let Some(docs) = completion_request.normalized_documents() {
157 full_history.push(docs);
158 }
159 full_history.extend(completion_request.chat_history);
160 span.record_model_input(&full_history);
161
162 let full_history = full_history
163 .into_iter()
164 .map(Message::try_from)
165 .collect::<Result<Vec<Message>, _>>()?;
166
167 let mut body = json!({
168 "model": self.model,
169 "messages": full_history,
170 "max_tokens": max_tokens,
171 "system": completion_request.preamble.unwrap_or("".to_string()),
172 "stream": true,
173 });
174
175 if let Some(temperature) = completion_request.temperature {
176 merge_inplace(&mut body, json!({ "temperature": temperature }));
177 }
178
179 if !completion_request.tools.is_empty() {
180 merge_inplace(
181 &mut body,
182 json!({
183 "tools": completion_request
184 .tools
185 .into_iter()
186 .map(|tool| ToolDefinition {
187 name: tool.name,
188 description: Some(tool.description),
189 input_schema: tool.parameters,
190 })
191 .collect::<Vec<_>>(),
192 "tool_choice": ToolChoice::Auto,
193 }),
194 );
195 }
196
197 if let Some(ref params) = completion_request.additional_params {
198 merge_inplace(&mut body, params.clone())
199 }
200
201 let body: Vec<u8> = serde_json::to_vec(&body)?;
202
203 let req = self
204 .client
205 .post("/v1/messages")?
206 .body(body)
207 .map_err(http_client::Error::Protocol)?;
208
209 let stream = GenericEventSource::new(self.client.http_client().clone(), req);
210
211 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
213 let mut current_tool_call: Option<ToolCallState> = None;
214 let mut current_thinking: Option<ThinkingState> = None;
215 let mut sse_stream = Box::pin(stream);
216 let mut input_tokens = 0;
217 let mut final_usage = None;
218
219 let mut text_content = String::new();
220
221 while let Some(sse_result) = sse_stream.next().await {
222 match sse_result {
223 Ok(Event::Open) => {}
224 Ok(Event::Message(sse)) => {
225 match serde_json::from_str::<StreamingEvent>(&sse.data) {
227 Ok(event) => {
228 match &event {
229 StreamingEvent::MessageStart { message } => {
230 input_tokens = message.usage.input_tokens;
231
232 let span = tracing::Span::current();
233 span.record("gen_ai.response.id", &message.id);
234 span.record("gen_ai.response.model_name", &message.model);
235 },
236 StreamingEvent::MessageDelta { delta, usage } => {
237 if delta.stop_reason.is_some() {
238 let usage = PartialUsage {
239 output_tokens: usage.output_tokens,
240 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
241 };
242
243 let span = tracing::Span::current();
244 span.record_token_usage(&usage);
245 span.record_model_output(&Message {
246 role: super::completion::Role::Assistant,
247 content: OneOrMany::one(Content::Text { text: text_content.clone() })}
248 );
249
250 final_usage = Some(usage);
251 break;
252 }
253 }
254 _ => {}
255 }
256
257 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
258 if let Ok(RawStreamingChoice::Message(ref text)) = result {
259 text_content += text;
260 }
261 yield result;
262 }
263 },
264 Err(e) => {
265 if !sse.data.trim().is_empty() {
266 yield Err(CompletionError::ResponseError(
267 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
268 ));
269 }
270 }
271 }
272 },
273 Err(e) => {
274 yield Err(CompletionError::ProviderError(format!("SSE Error: {e}")));
275 break;
276 }
277 }
278 }
279
280 sse_stream.close();
282
283 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
284 usage: final_usage.unwrap_or_default()
285 }))
286 }.instrument(span));
287
288 Ok(streaming::StreamingCompletionResponse::stream(stream))
289 }
290}
291
292fn handle_event(
293 event: &StreamingEvent,
294 current_tool_call: &mut Option<ToolCallState>,
295 current_thinking: &mut Option<ThinkingState>,
296) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
297 match event {
298 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
299 ContentDelta::TextDelta { text } => {
300 if current_tool_call.is_none() {
301 return Some(Ok(RawStreamingChoice::Message(text.clone())));
302 }
303 None
304 }
305 ContentDelta::InputJsonDelta { partial_json } => {
306 if let Some(tool_call) = current_tool_call {
307 tool_call.input_json.push_str(partial_json);
308 return Some(Ok(RawStreamingChoice::ToolCallDelta {
310 id: tool_call.id.clone(),
311 delta: partial_json.clone(),
312 }));
313 }
314 None
315 }
316 ContentDelta::ThinkingDelta { thinking } => {
317 if current_thinking.is_none() {
318 *current_thinking = Some(ThinkingState::default());
319 }
320
321 if let Some(state) = current_thinking {
322 state.thinking.push_str(thinking);
323 }
324
325 Some(Ok(RawStreamingChoice::Reasoning {
326 id: None,
327 reasoning: thinking.clone(),
328 signature: None,
329 }))
330 }
331 ContentDelta::SignatureDelta { signature } => {
332 if current_thinking.is_none() {
333 *current_thinking = Some(ThinkingState::default());
334 }
335
336 if let Some(state) = current_thinking {
337 state.signature.push_str(signature);
338 }
339
340 None
342 }
343 },
344 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
345 Content::ToolUse { id, name, .. } => {
346 *current_tool_call = Some(ToolCallState {
347 name: name.clone(),
348 id: id.clone(),
349 input_json: String::new(),
350 });
351 None
352 }
353 Content::Thinking { .. } => {
354 *current_thinking = Some(ThinkingState::default());
355 None
356 }
357 _ => None,
359 },
360 StreamingEvent::ContentBlockStop { .. } => {
361 if let Some(thinking_state) = Option::take(current_thinking)
362 && !thinking_state.thinking.is_empty()
363 {
364 let signature = if thinking_state.signature.is_empty() {
365 None
366 } else {
367 Some(thinking_state.signature)
368 };
369
370 return Some(Ok(RawStreamingChoice::Reasoning {
371 id: None,
372 reasoning: thinking_state.thinking,
373 signature,
374 }));
375 }
376
377 if let Some(tool_call) = Option::take(current_tool_call) {
378 let json_str = if tool_call.input_json.is_empty() {
379 "{}"
380 } else {
381 &tool_call.input_json
382 };
383 match serde_json::from_str(json_str) {
384 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
385 name: tool_call.name,
386 id: tool_call.id,
387 arguments: json_value,
388 call_id: None,
389 })),
390 Err(e) => Some(Err(CompletionError::from(e))),
391 }
392 } else {
393 None
394 }
395 }
396 StreamingEvent::MessageStart { .. }
398 | StreamingEvent::MessageDelta { .. }
399 | StreamingEvent::MessageStop
400 | StreamingEvent::Ping
401 | StreamingEvent::Unknown => None,
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `content_block_delta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    /// Builds an input-JSON delta event carrying `fragment`.
    fn json_delta_event(fragment: &str) -> StreamingEvent {
        delta_event(ContentDelta::InputJsonDelta {
            partial_json: fragment.to_string(),
        })
    }

    /// An open tool call named `test_tool` with no input accumulated yet.
    fn pending_tool_call() -> Option<ToolCallState> {
        Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        })
    }

    #[test]
    fn test_thinking_delta_deserialization() {
        let payload = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::ThinkingDelta { thinking } = parsed else {
            panic!("Expected ThinkingDelta variant")
        };
        assert_eq!(thinking, "Let me think about this...");
    }

    #[test]
    fn test_signature_delta_deserialization() {
        let payload = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let parsed: ContentDelta = serde_json::from_str(payload).unwrap();

        let ContentDelta::SignatureDelta { signature } = parsed else {
            panic!("Expected SignatureDelta variant")
        };
        assert_eq!(signature, "abc123def456");
    }

    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event")
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta")
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let parsed: StreamingEvent = serde_json::from_str(payload).unwrap();

        let StreamingEvent::ContentBlockDelta { index, delta } = parsed else {
            panic!("Expected ContentBlockDelta event")
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta")
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Reasoning { id, reasoning, .. } = outcome.unwrap().unwrap() else {
            panic!("Expected Reasoning choice")
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        // The fragment must also have been accumulated for the block-stop event.
        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        // Signature fragments are only accumulated, never emitted on their own.
        assert!(outcome.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Message(text) = outcome.unwrap().unwrap() else {
            panic!("Expected Message choice")
        };
        assert_eq!(text, "Hello, world!");
    }

    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let RawStreamingChoice::Reasoning { reasoning, .. } = outcome.unwrap().unwrap() else {
            panic!("Expected Reasoning choice")
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        // The open tool call must be left untouched.
        assert!(tool_call_state.is_some());
    }

    #[test]
    fn test_handle_input_json_delta_event() {
        let event = json_delta_event("{\"arg\":\"value");

        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        let outcome = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(outcome.is_some());
        let choice = outcome.unwrap().unwrap();

        match choice {
            RawStreamingChoice::ToolCallDelta { id, delta } => {
                assert_eq!(id, "tool_123");
                assert_eq!(delta, "{\"arg\":\"value");
            }
            _ => panic!("Expected ToolCallDelta choice, got {:?}", choice),
        }

        assert!(tool_call_state.is_some());
        let accumulated = tool_call_state.unwrap();
        assert_eq!(accumulated.input_json, "{\"arg\":\"value");
    }

    #[test]
    fn test_tool_call_accumulation_with_multiple_deltas() {
        let mut tool_call_state = pending_tool_call();
        let mut thinking_state = None;

        // Feed the JSON input in three fragments; each one yields a delta item.
        for fragment in ["{\"location\":", "\"Paris\",", "\"temp\":\"20C\"}"] {
            let outcome = handle_event(
                &json_delta_event(fragment),
                &mut tool_call_state,
                &mut thinking_state,
            );
            assert!(outcome.is_some());
        }

        assert!(tool_call_state.is_some());
        let accumulated = tool_call_state.as_ref().unwrap();
        assert_eq!(
            accumulated.input_json,
            "{\"location\":\"Paris\",\"temp\":\"20C\"}"
        );

        // Closing the block parses the accumulated JSON into a finished tool call.
        let stop_event = StreamingEvent::ContentBlockStop { index: 0 };
        let final_result = handle_event(&stop_event, &mut tool_call_state, &mut thinking_state);
        assert!(final_result.is_some());

        match final_result.unwrap().unwrap() {
            RawStreamingChoice::ToolCall {
                id,
                name,
                arguments,
                ..
            } => {
                assert_eq!(id, "tool_123");
                assert_eq!(name, "test_tool");
                assert_eq!(
                    arguments.get("location").unwrap().as_str().unwrap(),
                    "Paris"
                );
                assert_eq!(arguments.get("temp").unwrap().as_str().unwrap(), "20C");
            }
            other => panic!("Expected ToolCall, got {:?}", other),
        }

        // The accumulator is consumed once the tool call has been emitted.
        assert!(tool_call_state.is_none());
    }
}