1use async_stream::stream;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4use serde_json::json;
5use tracing::info_span;
6use tracing_futures::Instrument;
7
8use super::completion::{CompletionModel, Content, Message, ToolChoice, ToolDefinition, Usage};
9use super::decoders::sse::from_response as sse_from_response;
10use crate::OneOrMany;
11use crate::completion::{CompletionError, CompletionRequest, GetTokenUsage};
12use crate::http_client::{self, HttpClientExt};
13use crate::json_utils::merge_inplace;
14use crate::streaming::{self, RawStreamingChoice, StreamingResult};
15use crate::telemetry::SpanCombinator;
16
/// Server-sent event payloads emitted by Anthropic's streaming Messages API.
///
/// Deserialized from each SSE `data` field; the JSON `type` tag selects the
/// variant (snake_case on the wire, e.g. `"content_block_delta"`).
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum StreamingEvent {
    /// First event of a stream; carries message metadata and initial usage.
    MessageStart {
        message: MessageStart,
    },
    /// Opens a new content block at `index` (text, tool use, thinking, ...).
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// Incremental update to the content block at `index`.
    ContentBlockDelta {
        index: usize,
        delta: ContentDelta,
    },
    /// Closes the content block at `index`.
    ContentBlockStop {
        index: usize,
    },
    /// Top-level message changes (e.g. stop reason) plus output-token usage.
    MessageDelta {
        delta: MessageDelta,
        usage: PartialUsage,
    },
    /// Final event of the stream.
    MessageStop,
    /// Keep-alive event; carries no data.
    Ping,
    /// Catch-all for event types this client does not recognize, so new
    /// provider events don't break deserialization (forward compatibility).
    #[serde(other)]
    Unknown,
}
43
/// Payload of a `message_start` event: the skeleton of the message being
/// streamed, delivered before any content deltas arrive.
#[derive(Debug, Deserialize)]
pub struct MessageStart {
    /// Provider-assigned message id.
    pub id: String,
    /// Role of the message author (e.g. "assistant").
    pub role: String,
    /// Initial content blocks (typically empty at stream start).
    pub content: Vec<Content>,
    /// Model generating the response.
    pub model: String,
    /// Why generation stopped; `None` while streaming is still in progress.
    pub stop_reason: Option<String>,
    /// Which configured stop sequence fired, if any.
    pub stop_sequence: Option<String>,
    /// Token usage known at stream start.
    pub usage: Usage,
}
54
/// Incremental content carried by a `content_block_delta` event; the JSON
/// `type` tag selects the variant.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentDelta {
    /// A fragment of assistant-visible text.
    TextDelta { text: String },
    /// A fragment of the JSON arguments for an in-flight tool call.
    InputJsonDelta { partial_json: String },
    /// A fragment of extended-thinking (reasoning) output.
    ThinkingDelta { thinking: String },
    /// Signature data attached to a thinking block.
    SignatureDelta { signature: String },
}
63
/// Payload of a `message_delta` event: top-level changes to the message.
#[derive(Debug, Deserialize)]
pub struct MessageDelta {
    /// Why generation stopped; set on the final delta of the stream.
    pub stop_reason: Option<String>,
    /// Which configured stop sequence fired, if any.
    pub stop_sequence: Option<String>,
}
69
/// Token usage as reported mid-stream. Input tokens are only known once the
/// `message_start` event has been seen, hence the `Option`.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct PartialUsage {
    /// Output tokens generated so far.
    pub output_tokens: usize,
    /// Input (prompt) tokens, when known; absent in raw `message_delta` usage.
    #[serde(default)]
    pub input_tokens: Option<usize>,
}
76
77impl GetTokenUsage for PartialUsage {
78 fn token_usage(&self) -> Option<crate::completion::Usage> {
79 let mut usage = crate::completion::Usage::new();
80
81 usage.input_tokens = self.input_tokens.unwrap_or_default() as u64;
82 usage.output_tokens = self.output_tokens as u64;
83 usage.total_tokens = usage.input_tokens + usage.output_tokens;
84 Some(usage)
85 }
86}
87
/// Accumulator for a tool call assembled across multiple SSE events:
/// created on `content_block_start` (tool_use), grown by `input_json_delta`
/// fragments, and finalized into a tool-call choice on `content_block_stop`.
#[derive(Default)]
struct ToolCallState {
    // Tool name as announced by the provider.
    name: String,
    // Provider-assigned tool-use id.
    id: String,
    // Concatenated JSON argument fragments; parsed once the block closes.
    input_json: String,
}
94
/// Final response summary yielded once an Anthropic stream completes.
#[derive(Clone, Deserialize, Serialize)]
pub struct StreamingCompletionResponse {
    /// Token usage for the completed stream.
    pub usage: PartialUsage,
}
99
100impl GetTokenUsage for StreamingCompletionResponse {
101 fn token_usage(&self) -> Option<crate::completion::Usage> {
102 let mut usage = crate::completion::Usage::new();
103 usage.input_tokens = self.usage.input_tokens.unwrap_or(0) as u64;
104 usage.output_tokens = self.usage.output_tokens as u64;
105 usage.total_tokens =
106 self.usage.input_tokens.unwrap_or(0) as u64 + self.usage.output_tokens as u64;
107
108 Some(usage)
109 }
110}
111
112impl<T> CompletionModel<T>
113where
114 T: HttpClientExt + Clone + Default,
115{
116 pub(crate) async fn stream(
117 &self,
118 completion_request: CompletionRequest,
119 ) -> Result<streaming::StreamingCompletionResponse<StreamingCompletionResponse>, CompletionError>
120 {
121 let span = if tracing::Span::current().is_disabled() {
122 info_span!(
123 target: "rig::completions",
124 "chat_streaming",
125 gen_ai.operation.name = "chat_streaming",
126 gen_ai.provider.name = "anthropic",
127 gen_ai.request.model = self.model,
128 gen_ai.system_instructions = &completion_request.preamble,
129 gen_ai.response.id = tracing::field::Empty,
130 gen_ai.response.model = self.model,
131 gen_ai.usage.output_tokens = tracing::field::Empty,
132 gen_ai.usage.input_tokens = tracing::field::Empty,
133 gen_ai.input.messages = tracing::field::Empty,
134 gen_ai.output.messages = tracing::field::Empty,
135 )
136 } else {
137 tracing::Span::current()
138 };
139 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
140 tokens
141 } else if let Some(tokens) = self.default_max_tokens {
142 tokens
143 } else {
144 return Err(CompletionError::RequestError(
145 "`max_tokens` must be set for Anthropic".into(),
146 ));
147 };
148
149 let mut full_history = vec![];
150 if let Some(docs) = completion_request.normalized_documents() {
151 full_history.push(docs);
152 }
153 full_history.extend(completion_request.chat_history);
154 span.record_model_input(&full_history);
155
156 let full_history = full_history
157 .into_iter()
158 .map(Message::try_from)
159 .collect::<Result<Vec<Message>, _>>()?;
160
161 let mut body = json!({
162 "model": self.model,
163 "messages": full_history,
164 "max_tokens": max_tokens,
165 "system": completion_request.preamble.unwrap_or("".to_string()),
166 "stream": true,
167 });
168
169 if let Some(temperature) = completion_request.temperature {
170 merge_inplace(&mut body, json!({ "temperature": temperature }));
171 }
172
173 if !completion_request.tools.is_empty() {
174 merge_inplace(
175 &mut body,
176 json!({
177 "tools": completion_request
178 .tools
179 .into_iter()
180 .map(|tool| ToolDefinition {
181 name: tool.name,
182 description: Some(tool.description),
183 input_schema: tool.parameters,
184 })
185 .collect::<Vec<_>>(),
186 "tool_choice": ToolChoice::Auto,
187 }),
188 );
189 }
190
191 if let Some(ref params) = completion_request.additional_params {
192 merge_inplace(&mut body, params.clone())
193 }
194
195 let body: Vec<u8> = serde_json::to_vec(&body)?;
196
197 let req = self
198 .client
199 .post("/v1/messages")
200 .header("Content-Type", "application/json")
201 .body(body)
202 .map_err(http_client::Error::Protocol)?;
203
204 let response = self.client.send_streaming(req).await?;
205
206 if !response.status().is_success() {
207 let mut stream = response.into_body();
208 let mut text = String::with_capacity(1024);
209 loop {
210 let Some(chunk) = stream.next().await else {
211 break;
212 };
213
214 let chunk: Vec<u8> = chunk?.into();
215
216 let str = String::from_utf8_lossy(&chunk);
217
218 text.push_str(&str)
219 }
220 return Err(CompletionError::ProviderError(text));
221 }
222
223 let stream = sse_from_response(response.into_body());
224
225 let stream: StreamingResult<StreamingCompletionResponse> = Box::pin(stream! {
227 let mut current_tool_call: Option<ToolCallState> = None;
228 let mut sse_stream = Box::pin(stream);
229 let mut input_tokens = 0;
230
231 let mut text_content = String::new();
232
233 while let Some(sse_result) = sse_stream.next().await {
234 match sse_result {
235 Ok(sse) => {
236 match serde_json::from_str::<StreamingEvent>(&sse.data) {
238 Ok(event) => {
239 match &event {
240 StreamingEvent::MessageStart { message } => {
241 input_tokens = message.usage.input_tokens;
242
243 let span = tracing::Span::current();
244 span.record("gen_ai.response.id", &message.id);
245 span.record("gen_ai.response.model_name", &message.model);
246 },
247 StreamingEvent::MessageDelta { delta, usage } => {
248 if delta.stop_reason.is_some() {
249 let usage = PartialUsage {
250 output_tokens: usage.output_tokens,
251 input_tokens: Some(input_tokens.try_into().expect("Failed to convert input_tokens to usize")),
252 };
253
254 let span = tracing::Span::current();
255 span.record_token_usage(&usage);
256 span.record_model_output(&Message {
257 role: super::completion::Role::Assistant,
258 content: OneOrMany::one(Content::Text { text: text_content.clone() })}
259 );
260
261 yield Ok(RawStreamingChoice::FinalResponse(StreamingCompletionResponse {
262 usage
263 }))
264 }
265 }
266 _ => {}
267 }
268
269 if let Some(result) = handle_event(&event, &mut current_tool_call) {
270 if let Ok(RawStreamingChoice::Message(ref text)) = result {
271 text_content += text;
272 }
273 yield result;
274 }
275 },
276 Err(e) => {
277 if !sse.data.trim().is_empty() {
278 yield Err(CompletionError::ResponseError(
279 format!("Failed to parse JSON: {} (Data: {})", e, sse.data)
280 ));
281 }
282 }
283 }
284 },
285 Err(e) => {
286 yield Err(CompletionError::ResponseError(format!("SSE Error: {e}")));
287 break;
288 }
289 }
290 }
291 }.instrument(span));
292
293 Ok(streaming::StreamingCompletionResponse::stream(stream))
294 }
295}
296
297fn handle_event(
298 event: &StreamingEvent,
299 current_tool_call: &mut Option<ToolCallState>,
300) -> Option<Result<RawStreamingChoice<StreamingCompletionResponse>, CompletionError>> {
301 match event {
302 StreamingEvent::ContentBlockDelta { delta, .. } => match delta {
303 ContentDelta::TextDelta { text } => {
304 if current_tool_call.is_none() {
305 return Some(Ok(RawStreamingChoice::Message(text.clone())));
306 }
307 None
308 }
309 ContentDelta::InputJsonDelta { partial_json } => {
310 if let Some(tool_call) = current_tool_call {
311 tool_call.input_json.push_str(partial_json);
312 }
313 None
314 }
315 ContentDelta::ThinkingDelta { thinking } => Some(Ok(RawStreamingChoice::Reasoning {
316 id: None,
317 reasoning: thinking.clone(),
318 })),
319 ContentDelta::SignatureDelta { .. } => {
320 None
322 }
323 },
324 StreamingEvent::ContentBlockStart { content_block, .. } => match content_block {
325 Content::ToolUse { id, name, .. } => {
326 *current_tool_call = Some(ToolCallState {
327 name: name.clone(),
328 id: id.clone(),
329 input_json: String::new(),
330 });
331 None
332 }
333 _ => None,
335 },
336 StreamingEvent::ContentBlockStop { .. } => {
337 if let Some(tool_call) = Option::take(current_tool_call) {
338 let json_str = if tool_call.input_json.is_empty() {
339 "{}"
340 } else {
341 &tool_call.input_json
342 };
343 match serde_json::from_str(json_str) {
344 Ok(json_value) => Some(Ok(RawStreamingChoice::ToolCall {
345 name: tool_call.name,
346 id: tool_call.id,
347 arguments: json_value,
348 call_id: None,
349 })),
350 Err(e) => Some(Err(CompletionError::from(e))),
351 }
352 } else {
353 None
354 }
355 }
356 StreamingEvent::MessageStart { .. }
358 | StreamingEvent::MessageDelta { .. }
359 | StreamingEvent::MessageStop
360 | StreamingEvent::Ping
361 | StreamingEvent::Unknown => None,
362 }
363}
364
#[cfg(test)]
mod tests {
    //! Unit tests for streaming-event deserialization and `handle_event`
    //! dispatch, with emphasis on extended-thinking (reasoning) deltas.
    use super::*;

    // `thinking_delta` payloads deserialize into ContentDelta::ThinkingDelta.
    #[test]
    fn test_thinking_delta_deserialization() {
        let json = r#"{"type": "thinking_delta", "thinking": "Let me think about this..."}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::ThinkingDelta { thinking } => {
                assert_eq!(thinking, "Let me think about this...");
            }
            _ => panic!("Expected ThinkingDelta variant"),
        }
    }

    // `signature_delta` payloads deserialize into ContentDelta::SignatureDelta.
    #[test]
    fn test_signature_delta_deserialization() {
        let json = r#"{"type": "signature_delta", "signature": "abc123def456"}"#;
        let delta: ContentDelta = serde_json::from_str(json).unwrap();

        match delta {
            ContentDelta::SignatureDelta { signature } => {
                assert_eq!(signature, "abc123def456");
            }
            _ => panic!("Expected SignatureDelta variant"),
        }
    }

    // A full content_block_delta event wrapping a thinking delta round-trips.
    #[test]
    fn test_thinking_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "thinking_delta",
                "thinking": "First, I need to understand the problem."
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::ThinkingDelta { thinking } => {
                        assert_eq!(thinking, "First, I need to understand the problem.");
                    }
                    _ => panic!("Expected ThinkingDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    // A full content_block_delta event wrapping a signature delta round-trips.
    #[test]
    fn test_signature_delta_streaming_event_deserialization() {
        let json = r#"{
            "type": "content_block_delta",
            "index": 0,
            "delta": {
                "type": "signature_delta",
                "signature": "ErUBCkYICBgCIkCaGbqC85F4"
            }
        }"#;

        let event: StreamingEvent = serde_json::from_str(json).unwrap();

        match event {
            StreamingEvent::ContentBlockDelta { index, delta } => {
                assert_eq!(index, 0);
                match delta {
                    ContentDelta::SignatureDelta { signature } => {
                        assert_eq!(signature, "ErUBCkYICBgCIkCaGbqC85F4");
                    }
                    _ => panic!("Expected SignatureDelta"),
                }
            }
            _ => panic!("Expected ContentBlockDelta event"),
        }
    }

    // Thinking deltas surface as RawStreamingChoice::Reasoning with no id.
    #[test]
    fn test_handle_thinking_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Analyzing the request...".to_string(),
            },
        };

        let mut tool_call_state = None;
        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { id, reasoning } => {
                assert_eq!(id, None);
                assert_eq!(reasoning, "Analyzing the request...");
            }
            _ => panic!("Expected Reasoning choice"),
        }
    }

    // Signature deltas are consumed silently (no stream item emitted).
    #[test]
    fn test_handle_signature_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::SignatureDelta {
                signature: "test_signature".to_string(),
            },
        };

        let mut tool_call_state = None;
        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_none());
    }

    // Plain text deltas surface as RawStreamingChoice::Message when no tool
    // call is in flight.
    #[test]
    fn test_handle_text_delta_event() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::TextDelta {
                text: "Hello, world!".to_string(),
            },
        };

        let mut tool_call_state = None;
        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Message(text) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected Message choice"),
        }
    }

    // Thinking deltas must neither clear nor mutate an in-flight tool call.
    #[test]
    fn test_thinking_delta_does_not_interfere_with_tool_calls() {
        let event = StreamingEvent::ContentBlockDelta {
            index: 0,
            delta: ContentDelta::ThinkingDelta {
                thinking: "Thinking while tool is active...".to_string(),
            },
        };

        let mut tool_call_state = Some(ToolCallState {
            name: "test_tool".to_string(),
            id: "tool_123".to_string(),
            input_json: String::new(),
        });

        let result = handle_event(&event, &mut tool_call_state);

        assert!(result.is_some());
        let choice = result.unwrap().unwrap();

        match choice {
            RawStreamingChoice::Reasoning { reasoning, .. } => {
                assert_eq!(reasoning, "Thinking while tool is active...");
            }
            _ => panic!("Expected Reasoning choice"),
        }

        // The accumulator must survive the reasoning delta untouched.
        assert!(tool_call_state.is_some());
    }
}