1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use crate::OneOrMany;
10use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
11use crate::http_client::sse::{Event, GenericEventSource};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
17#[derive(Debug, Deserialize)]
18#[serde(tag = "type", rename_all = "snake_case")]
19pub enum StreamingEvent {
20 MessageStart {
21 message: MessageStart,
22 },
23 ContentBlockStart {
24 index: usize,
25 content_block: Content,
26 },
27 ContentBlockDelta {
28 index: usize,
29 delta: ContentDelta,
30 },
31 ContentBlockStop {
32 index: usize,
33 },
34 MessageDelta {
35 delta: MessageDelta,
36 usage: PartialUsage,
37 },
38 MessageStop,
39 Ping,
40 #[serde(other)]
41 Unknown,
42}
43
44#[derive(Debug, Deserialize)]
45pub struct MessageStart {
46 pub id: String,
47 pub role: String,
48 pub content: Vec<Content>,
49 pub model: String,
50 pub stop_reason: Option<String>,
51 pub stop_sequence: Option<String>,
52 pub usage: Usage,
53}
54
55#[derive(Debug, Deserialize)]
56#[serde(tag = "type", rename_all = "snake_case")]
57pub enum ContentDelta {
58 TextDelta { text: String },
59 InputJsonDelta { partial_json: String },
60 ThinkingDelta { thinking: String },
61 SignatureDelta { signature: String },
62}
63
64#[derive(Debug, Deserialize)]
65pub struct MessageDelta {
66 pub stop_reason: Option<String>,
67 pub stop_sequence: Option<String>,
68}
69
70#[derive(Debug, Deserialize, Clone, Serialize)]
71pub struct PartialUsage {
72 pub output_tokens: usize,
73 #[serde(default)]
74 pub input_tokens: Option<usize>,
75}
76
77impl GetTokenUsage for PartialUsage {
78 fn token_usage(&self) -> Option<crate::completion::Usage> {
79 let mut usage = crate::completion::Usage::new();
80
81 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82 usage.output_tokens = self.output_tokens as u64;
83 usage.total_tokens = usage.input_tokens + usage.output_tokens;
84 Some(usage)
85 }
86}
87
/// Accumulator for the tool call whose content block is currently being streamed.
///
/// `handle_event` creates it on a tool-use `content_block_start`, appends every
/// `input_json_delta` fragment to `input_json`, and consumes it on `content_block_stop`.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the content block start.
    name: String,
    /// Tool-use id taken from the content block start.
    id: String,
    /// Concatenation of the JSON fragments received so far; parsed when the block stops.
    input_json: String,
}
94
/// Accumulator for the thinking block that is currently being streamed.
///
/// `handle_event` appends `thinking_delta` and `signature_delta` fragments to it and
/// consumes it on `content_block_stop`.
#[derive(Default)]
struct ThinkingState {
    /// Thinking text received so far.
    thinking: String,
    /// Signature fragments received so far; empty when none arrived.
    signature: String,
}
100
101#[derive(Clone, Deserialize, Serialize)]
102pub struct StreamingCompletionResponse {
103 pub usage: PartialUsage,
104}
105
106impl GetTokenUsage for StreamingCompletionResponse {
107 fn token_usage(&self) -> Option<crate::completion::Usage> {
108 let mut usage = crate::completion::Usage::new();
109 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
110 usage.output_tokens = self.usage.output_tokens as u64;
111 usage.total_tokens =
112 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
113
114 Some(usage)
115 }
116}
117
118impl<T> CompletionModel<T>
119where
120 T: HttpClientExt + Clone + Default + 'static,
121{
122 pub(crate) async fn stream(
123 &self,
124 completion_request: CompletionRequest,
125 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
126 {
127 let span = if tracing::Span::current().is_disabled() {
128 info_span!(
129 target: "rig::completions",
130 "chat_streaming",
131 gen_ai.operation.name = "chat_streaming",
132 gen_ai.provider.name = "anthropic",
133 gen_ai.request.model = self.model,
134 gen_ai.system_instructions = &completion_request.preamble,
135 gen_ai.response.id = tracing::field::Empty,
136 gen_ai.response.model = self.model,
137 gen_ai.usage.output_tokens = tracing::field::Empty,
138 gen_ai.usage.input_tokens = tracing::field::Empty,
139 gen_ai.input.messages = tracing::field::Empty,
140 gen_ai.output.messages = tracing::field::Empty,
141 )
142 } else {
143 tracing::Span::current()
144 };
145 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
146 tokens
147 } else if let Some(tokens) = self.default_max_tokens {
148 tokens
149 } else {
150 return Err(CompletionError::RequestError(
151 "`max_tokens` must be set for Anthropic".into(),
152 ));
153 };
154
155 let mut full_history = vec![];
156 if let Some(docs) = completion_request.normalized_documents() {
157 full_history.push(docs);
158 }
159 full_history.extend(completion_request.chat_history);
160 span.record_model_input(&full_history);
161
162 let full_history = full_history
163 .into_iter()
164 .map(Message::try_from)
165 .collect::<Result<Vec<Message>, _>>()?;
166
167 let mut body = json!({
168 "model": self.model,
169 "messages": full_history,
170 "max_tokens": max_tokens,
171 "system": completion_request.preamble.unwrap_or("".to_string()),
172 "stream": true,
173 });
174
175 if let Some(temperature) = completion_request.temperature {
176 merge_inplace(&mut body, json!({ "temperature": temperature }));
177 }
178
179 if !completion_request.tools.is_empty() {
180 merge_inplace(
181 &mut body,
182 json!({
183 "tools": completion_request
184 .tools
185 .into_iter()
186 .map(|tool| ToolDefinition {
187 name: tool.name,
188 description: Some(tool.description),
189 input_schema: tool.parameters,
190 })
191 .collect::<Vec<_>>(),
192 "tool_choice": ToolChoice::Auto,
193 }),
194 );
195 }
196
197 if let Some(ref params) = completion_request.additional_params {
198 merge_inplace(&mut body, params.clone())
199 }
200
201 let body: Vec<u8> = serde_json::to_vec(&body)?;
202
203 let req = self
204 .client
205 .post("/v1/messages")
206 .header("Content-Type", "application/json")
207 .body(body)
208 .map_err(http_client::Error::Protocol)?;
209
210 let stream = GenericEventSource::new(self.client.http_client.clone(), req);
211
212 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
214 let mut current_tool_call: Option<ToolCallState> = None;
215 let mut current_thinking: Option<ThinkingState> = None;
216 let mut sse_stream = Box::pin(stream);
217 let mut input_tokens = 0;
218
219 let mut text_content = String::new();
220
221 while let Some(sse_result) = sse_stream.next().await {
222 match sse_result {
223 Ok(Event::Open) => {}
224 Ok(Event::Message(sse)) => {
225 match serde_json::from_str::<StreamingEvent>(&sse.data) {
227 Ok(event) => {
228 match &event {
229 StreamingEvent::MessageStart { message } => {
230 input_tokens = message.usage.input_tokens;
231
232 let span = tracing::Span::current();
233 span.record("gen_ai.response.id", &message.id);
234 span.record("gen_ai.response.model_name", &message.model);
235 },
236 StreamingEvent::MessageDelta { delta, usage } => {
237 if delta.stop_reason.is_some() {
238 let usage = PartialUsage {
239 output_tokens: usage.output_tokens,
240 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
241 };
242
243 let span = tracing::Span::current();
244 span.record_token_usage(&usage);
245 span.record_model_output(&Message {
246 role: super::completion::Role::Assistant,
247 content: OneOrMany::one(Content::Text { text: text_content.clone() })}
248 );
249
250 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
251 usage
252 }))
253 }
254 }
255 _ => {}
256 }
257
258 if let Some(result) = handle_event(&event, &mut current_tool_call, &mut current_thinking) {
259 if let Ok(RawStreamingChoice::Message(ref text)) = result {
260 text_content += text;
261 }
262 yield result;
263 }
264 },
265 Err(e) => {
266 if !sse.data.trim().is_empty() {
267 yield Err(CompletionError::ResponseError(
268 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
269 ));
270 }
271 }
272 }
273 },
274 Err(e) => {
275 yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
276 break;
277 }
278 }
279 }
280 }.instrument(span));
281
282 Ok(streaming::StreamingCompletionResponse::stream(stream))
283 }
284}
285
286fn handle_event(
287 event: &StreamingEvent,
288 current_tool_call: &mut Option<ToolCallState>,
289 current_thinking: &mut Option<ThinkingState>,
290) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
291 match event {
292 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
293 ContentDelta::TextDelta { text } => {
294 if current_tool_call.is_none() {
295 return Some(Ok(RawStreamingChoice::Message(text.clone())));
296 }
297 None
298 }
299 ContentDelta::InputJsonDelta { partial_json } => {
300 if let Some(tool_call) = current_tool_call {
301 tool_call.input_json.push_str(partial_json);
302 }
303 None
304 }
305 ContentDelta::ThinkingDelta { thinking } => {
306 if current_thinking.is_none() {
307 *current_thinking = Some(ThinkingState::default());
308 }
309
310 if let Some(state) = current_thinking {
311 state.thinking.push_str(thinking);
312 }
313
314 Some(Ok(RawStreamingChoice::Reasoning {
315 id: None,
316 reasoning: thinking.clone(),
317 signature: None,
318 }))
319 }
320 ContentDelta::SignatureDelta { signature } => {
321 if current_thinking.is_none() {
322 *current_thinking = Some(ThinkingState::default());
323 }
324
325 if let Some(state) = current_thinking {
326 state.signature.push_str(signature);
327 }
328
329 None
331 }
332 },
333 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
334 Content::ToolUse { id, name, .. } => {
335 *current_tool_call = Some(ToolCallState {
336 name: name.clone(),
337 id: id.clone(),
338 input_json: String::new(),
339 });
340 None
341 }
342 Content::Thinking { .. } => {
343 *current_thinking = Some(ThinkingState::default());
344 None
345 }
346 _ => None,
348 },
349 StreamingEvent::ContentBlockStop { .. } => {
350 if let Some(thinking_state) = Option::take(current_thinking)
351 && !thinking_state.thinking.is_empty()
352 {
353 let signature = if thinking_state.signature.is_empty() {
354 None
355 } else {
356 Some(thinking_state.signature)
357 };
358
359 return Some(Ok(RawStreamingChoice::Reasoning {
360 id: None,
361 reasoning: thinking_state.thinking,
362 signature,
363 }));
364 }
365
366 if let Some(tool_call) = Option::take(current_tool_call) {
367 let json_str = if tool_call.input_json.is_empty() {
368 "{}"
369 } else {
370 &tool_call.input_json
371 };
372 match serde_json::from_str(json_str) {
373 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
374 name: tool_call.name,
375 id: tool_call.id,
376 arguments: json_value,
377 call_id: None,
378 })),
379 Err(e) => Some(Err(CompletionError::from(e))),
380 }
381 } else {
382 None
383 }
384 }
385 StreamingEvent::MessageStart { .. }
387 | StreamingEvent::MessageDelta { .. }
388 | StreamingEvent::MessageStop
389 | StreamingEvent::Ping
390 | StreamingEvent::Unknown => None,
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `delta` in a `ContentBlockDelta` event for block index 0.
    fn delta_event(delta: ContentDelta) -> StreamingEvent {
        StreamingEvent::ContentBlockDelta { index: 0, delta }
    }

    // A bare `thinking_delta` object deserializes into `ContentDelta::ThinkingDelta`.
    #[test]
    fn test_thinking_delta_deserialization() {
        let payload = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;

        let ContentDelta::ThinkingDelta { thinking } =
            serde_json::from_str::<ContentDelta>(payload).unwrap()
        else {
            panic!("Expected ThinkingDelta variant");
        };

        assert_eq!(thinking, "Let me think about this...");
    }

    // A bare `signature_delta` object deserializes into `ContentDelta::SignatureDelta`.
    #[test]
    fn test_signature_delta_deserialization() {
        let payload = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;

        let ContentDelta::SignatureDelta { signature } =
            serde_json::from_str::<ContentDelta>(payload).unwrap()
        else {
            panic!("Expected SignatureDelta variant");
        };

        assert_eq!(signature, "abc123def456");
    }

    // A full `content_block_delta` event carrying a thinking delta round-trips through serde.
    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let StreamingEvent::ContentBlockDelta { index, delta } =
            serde_json::from_str::<StreamingEvent>(payload).unwrap()
        else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::ThinkingDelta { thinking } = delta else {
            panic!("Expected ThinkingDelta");
        };
        assert_eq!(thinking, "First, I need to understand the problem.");
    }

    // A full `content_block_delta` event carrying a signature delta round-trips through serde.
    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let payload = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let StreamingEvent::ContentBlockDelta { index, delta } =
            serde_json::from_str::<StreamingEvent>(payload).unwrap()
        else {
            panic!("Expected ContentBlockDelta event");
        };
        assert_eq!(index, 0);

        let ContentDelta::SignatureDelta { signature } = delta else {
            panic!("Expected SignatureDelta");
        };
        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
    }

    // A thinking delta is forwarded as a reasoning choice and stored in the thinking state.
    #[test]
    fn test_handle_thinking_delta_event() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Analyzing the request...".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Reasoning { id, reasoning, .. } = result.unwrap().unwrap() else {
            panic!("Expected Reasoning choice");
        };
        assert_eq!(id, None);
        assert_eq!(reasoning, "Analyzing the request...");

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().thinking, "Analyzing the request...");
    }

    // A signature delta yields nothing but is stored in the thinking state.
    #[test]
    fn test_handle_signature_delta_event() {
        let event = delta_event(ContentDelta::SignatureDelta {
            signature: "test_signature".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_none());

        assert!(thinking_state.is_some());
        assert_eq!(thinking_state.unwrap().signature, "test_signature");
    }

    // A text delta with no open tool call is forwarded as a message choice.
    #[test]
    fn test_handle_text_delta_event() {
        let event = delta_event(ContentDelta::TextDelta {
            text: "Hello, world!".to_string(),
        });

        let mut tool_call_state = None;
        let mut thinking_state = None;
        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Message(text) = result.unwrap().unwrap() else {
            panic!("Expected Message choice");
        };
        assert_eq!(text, "Hello, world!");
    }

    // A thinking delta is still forwarded while a tool call is open, and leaves it untouched.
    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = delta_event(ContentDelta::ThinkingDelta {
            thinking: "Thinking while tool is active...".to_string(),
        });

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });
        let mut thinking_state = None;

        let result = handle_event(&event, &mut tool_call_state, &mut thinking_state);

        assert!(result.is_some());
        let RawStreamingChoice::Reasoning { reasoning, .. } = result.unwrap().unwrap() else {
            panic!("Expected Reasoning choice");
        };
        assert_eq!(reasoning, "Thinking while tool is active...");

        assert!(tool_call_state.is_some());
    }
}