1use crate::types::completion_request::AwsCompletionRequest;
2use crate::{
3 completion::{CompletionModel, resolve_request_model},
4 types::errors::AwsSdkConverseStreamError,
5};
6use async_stream::stream;
7use aws_sdk_bedrockruntime::types as aws_bedrock;
8use rig::completion::GetTokenUsage;
9use rig::streaming::StreamingCompletionResponse;
10use rig::{
11 completion::CompletionError,
12 message::ReasoningContent,
13 streaming::{RawStreamingChoice, RawStreamingToolCall, ToolCallDeltaContent},
14};
15use serde::{Deserialize, Serialize};
16
17#[derive(Clone, Deserialize, Serialize)]
18pub struct BedrockStreamingResponse {
19 pub usage: Option<BedrockUsage>,
20}
21
22#[derive(Clone, Deserialize, Serialize)]
23pub struct BedrockUsage {
24 pub input_tokens: i32,
25 pub output_tokens: i32,
26 pub total_tokens: i32,
27 #[serde(skip_serializing_if = "Option::is_none")]
28 pub cache_read_input_tokens: Option<i32>,
29 #[serde(skip_serializing_if = "Option::is_none")]
30 pub cache_write_input_tokens: Option<i32>,
31}
32
33impl GetTokenUsage for BedrockStreamingResponse {
34 fn token_usage(&self) -> Option<rig::completion::Usage> {
35 self.usage.as_ref().map(|u| rig::completion::Usage {
36 input_tokens: u.input_tokens as u64,
37 output_tokens: u.output_tokens as u64,
38 total_tokens: u.total_tokens as u64,
39 cached_input_tokens: u.cache_read_input_tokens.unwrap_or_default() as u64,
40 cache_creation_input_tokens: u.cache_write_input_tokens.unwrap_or_default() as u64,
41 })
42 }
43}
44
/// Accumulator for a tool call that is still being streamed.
///
/// One value is created when a tool-use content block starts; the argument
/// JSON is then appended fragment by fragment as deltas arrive.
#[derive(Default)]
struct ToolCallState {
    /// Tool name taken from the tool-use start event.
    name: String,
    /// Provider-supplied tool-use id taken from the start event.
    id: String,
    /// Locally generated id attached to every delta and to the final tool call.
    internal_call_id: String,
    /// Concatenation of the raw JSON argument fragments received so far.
    input_json: String,
}
52
/// Accumulator for a reasoning content block that is still being streamed.
#[derive(Default)]
struct ReasoningState {
    /// Reasoning text gathered from the text deltas seen so far.
    content: String,
    /// Signature delivered by a signature delta, if one has arrived.
    signature: Option<String>,
}
58
59impl CompletionModel {
60 pub(crate) async fn stream(
61 &self,
62 completion_request: rig::completion::CompletionRequest,
63 ) -> Result<StreamingCompletionResponse<BedrockStreamingResponse>, CompletionError> {
64 let request_model = resolve_request_model(&self.model, &completion_request);
65 let request = AwsCompletionRequest {
66 inner: completion_request,
67 prompt_caching: self.prompt_caching,
68 };
69
70 let mut converse_builder = self
71 .client
72 .get_inner()
73 .await
74 .converse_stream()
75 .model_id(request_model);
76
77 let tool_config = request.tools_config()?;
78 let prompt_with_history = request.messages()?;
79 converse_builder = converse_builder
80 .set_additional_model_request_fields(request.additional_params())
81 .set_inference_config(request.inference_config())
82 .set_tool_config(tool_config)
83 .set_system(request.system_prompt())
84 .set_messages(Some(prompt_with_history));
85
86 let response = converse_builder.send().await.map_err(|sdk_error| {
87 Into::<CompletionError>::into(AwsSdkConverseStreamError(sdk_error))
88 })?;
89
90 let stream = Box::pin(stream! {
91 let mut current_tool_call: Option<ToolCallState> = None;
92 let mut current_reasoning: Option<ReasoningState> = None;
93 let mut stream = response.stream;
94 while let Ok(Some(output)) = stream.recv().await {
95 match output {
96 aws_bedrock::ConverseStreamOutput::ContentBlockDelta(event) => {
97 let delta = event.delta.ok_or(CompletionError::ProviderError("The delta for a content block is missing".into()))?;
98 match delta {
99 aws_bedrock::ContentBlockDelta::Text(text) => {
100 if current_tool_call.is_none() {
101 yield Ok(RawStreamingChoice::Message(text))
102 }
103 },
104 aws_bedrock::ContentBlockDelta::ToolUse(tool) => {
105 if let Some(ref mut tool_call) = current_tool_call {
106 let delta = tool.input().to_string();
107 tool_call.input_json.push_str(&delta);
108
109 yield Ok(RawStreamingChoice::ToolCallDelta {
111 id: tool_call.id.clone(),
112 internal_call_id: tool_call.internal_call_id.clone(),
113 content: ToolCallDeltaContent::Delta(delta),
114 });
115 }
116 },
117 aws_bedrock::ContentBlockDelta::ReasoningContent(reasoning) => {
118 match reasoning {
119 aws_bedrock::ReasoningContentBlockDelta::Text(text) => {
120 if current_reasoning.is_none() {
121 current_reasoning = Some(ReasoningState::default());
122 }
123
124 if let Some(ref mut state) = current_reasoning {
125 state.content.push_str(text.as_str());
126 }
127
128 if !text.is_empty() {
129 yield Ok(RawStreamingChoice::ReasoningDelta {
130 reasoning: text.clone(),
131 id: None,
132 })
133 }
134 },
135 aws_bedrock::ReasoningContentBlockDelta::Signature(signature) => {
136 if current_reasoning.is_none() {
137 current_reasoning = Some(ReasoningState::default());
138 }
139
140 if let Some(ref mut state) = current_reasoning {
141 state.signature = Some(signature.clone());
142 }
143 },
144 _ => {}
145 }
146 },
147 _ => {}
148 }
149 },
150 aws_bedrock::ConverseStreamOutput::ContentBlockStart(event) => {
151 match event.start.ok_or(CompletionError::ProviderError("ContentBlockStart has no data".into()))? {
152 aws_bedrock::ContentBlockStart::ToolUse(tool_use) => {
153 let internal_call_id = nanoid::nanoid!();
154 current_tool_call = Some(ToolCallState {
155 name: tool_use.name.clone(),
156 id: tool_use.tool_use_id.clone(),
157 internal_call_id: internal_call_id.clone(),
158 input_json: String::new(),
159 });
160 yield Ok(RawStreamingChoice::ToolCallDelta {
161 id: tool_use.tool_use_id,
162 internal_call_id,
163 content: ToolCallDeltaContent::Name(tool_use.name),
164 });
165 },
166 _ => yield Err(CompletionError::ProviderError("Stream is empty".into()))
167 }
168 },
169 aws_bedrock::ConverseStreamOutput::ContentBlockStop(_event) => {
170 if let Some(reasoning_state) = current_reasoning.take()
171 && !reasoning_state.content.is_empty() {
172 yield Ok(RawStreamingChoice::Reasoning {
173 id: None,
174 content: ReasoningContent::Text {
175 text: reasoning_state.content,
176 signature: reasoning_state.signature,
177 },
178 })
179 }
180 },
181 aws_bedrock::ConverseStreamOutput::MessageStop(message_stop_event) => {
182 match message_stop_event.stop_reason {
183 aws_bedrock::StopReason::ToolUse => {
184 if let Some(tool_call) = current_tool_call.take() {
185 let tool_input = if tool_call.input_json.is_empty() {
187 serde_json::json!({})
188 } else {
189 serde_json::from_str(tool_call.input_json.as_str())?
190 };
191 yield Ok(RawStreamingChoice::ToolCall(
192 RawStreamingToolCall::new(tool_call.id, tool_call.name, tool_input)
193 .with_internal_call_id(tool_call.internal_call_id)
194 ));
195 } else {
196 yield Err(CompletionError::ProviderError("Failed to call tool".into()))
197 }
198 }
199 aws_bedrock::StopReason::MaxTokens => {
200 yield Err(CompletionError::ProviderError("Exceeded max tokens".into()))
201 }
202 _ => {}
203 }
204 },
205 aws_bedrock::ConverseStreamOutput::Metadata(metadata_event) => {
206 if let Some(usage) = metadata_event.usage {
208 yield Ok(RawStreamingChoice::FinalResponse(BedrockStreamingResponse {
209 usage: Some(BedrockUsage {
210 input_tokens: usage.input_tokens,
211 output_tokens: usage.output_tokens,
212 total_tokens: usage.total_tokens,
213 cache_read_input_tokens: usage.cache_read_input_tokens,
214 cache_write_input_tokens: usage.cache_write_input_tokens,
215 }),
216 }));
217 }
218 },
219 _ => {}
220 }
221 }
222 });
223
224 Ok(StreamingCompletionResponse::stream(stream))
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BedrockUsage` from its five counters, in declaration order.
    fn usage_of(
        input: i32,
        output: i32,
        total: i32,
        cache_read: Option<i32>,
        cache_write: Option<i32>,
    ) -> BedrockUsage {
        BedrockUsage {
            input_tokens: input,
            output_tokens: output,
            total_tokens: total,
            cache_read_input_tokens: cache_read,
            cache_write_input_tokens: cache_write,
        }
    }

    /// Builds a `ToolCallState` with a fresh internal id and no accumulated input.
    fn tool_call_named(name: &str, id: &str) -> ToolCallState {
        ToolCallState {
            name: name.to_string(),
            id: id.to_string(),
            internal_call_id: nanoid::nanoid!(),
            input_json: String::new(),
        }
    }

    // --- usage / response conversion -------------------------------------------------

    #[test]
    fn test_bedrock_usage_creation() {
        let counters = usage_of(100, 50, 150, None, None);

        assert_eq!(counters.input_tokens, 100);
        assert_eq!(counters.output_tokens, 50);
        assert_eq!(counters.total_tokens, 150);
    }

    #[test]
    fn test_bedrock_streaming_response_with_usage() {
        let response = BedrockStreamingResponse {
            usage: Some(usage_of(200, 75, 275, Some(40), Some(10))),
        };

        let converted = response.token_usage();
        assert!(converted.is_some());

        let converted = converted.unwrap();
        assert_eq!(converted.input_tokens, 200);
        assert_eq!(converted.output_tokens, 75);
        assert_eq!(converted.total_tokens, 275);
        assert_eq!(converted.cached_input_tokens, 40);
        assert_eq!(converted.cache_creation_input_tokens, 10);
    }

    #[test]
    fn test_bedrock_streaming_response_without_usage() {
        let converted = BedrockStreamingResponse { usage: None }.token_usage();

        assert!(converted.is_none());
    }

    #[test]
    fn test_get_token_usage_trait() {
        let response = BedrockStreamingResponse {
            usage: Some(usage_of(448, 68, 516, Some(80), Some(20))),
        };

        let converted = response.token_usage().expect("Usage should be present");
        assert_eq!(converted.input_tokens, 448);
        assert_eq!(converted.output_tokens, 68);
        assert_eq!(converted.total_tokens, 516);
        assert_eq!(converted.cached_input_tokens, 80);
        assert_eq!(converted.cache_creation_input_tokens, 20);
    }

    // --- serde round trips -----------------------------------------------------------

    #[test]
    fn test_bedrock_usage_serde() {
        let original = usage_of(100, 50, 150, Some(25), Some(5));

        let encoded = serde_json::to_string(&original).expect("Should serialize");
        for fragment in [
            "\"input_tokens\":100",
            "\"output_tokens\":50",
            "\"total_tokens\":150",
        ] {
            assert!(encoded.contains(fragment));
        }

        let restored: BedrockUsage = serde_json::from_str(&encoded).expect("Should deserialize");
        assert_eq!(restored.input_tokens, original.input_tokens);
        assert_eq!(restored.output_tokens, original.output_tokens);
        assert_eq!(restored.total_tokens, original.total_tokens);
        assert_eq!(
            restored.cache_read_input_tokens,
            original.cache_read_input_tokens
        );
        assert_eq!(
            restored.cache_write_input_tokens,
            original.cache_write_input_tokens
        );
    }

    #[test]
    fn test_bedrock_streaming_response_serde() {
        let original = BedrockStreamingResponse {
            usage: Some(usage_of(200, 75, 275, Some(30), Some(15))),
        };

        let encoded = serde_json::to_string(&original).expect("Should serialize");
        assert!(encoded.contains("\"input_tokens\":200"));

        let restored: BedrockStreamingResponse =
            serde_json::from_str(&encoded).expect("Should deserialize");
        assert!(restored.usage.is_some());

        let counters = restored.usage.unwrap();
        assert_eq!(counters.input_tokens, 200);
        assert_eq!(counters.output_tokens, 75);
        assert_eq!(counters.total_tokens, 275);
        assert_eq!(counters.cache_read_input_tokens, Some(30));
        assert_eq!(counters.cache_write_input_tokens, Some(15));
    }

    // --- reasoning accumulator -------------------------------------------------------

    #[test]
    fn test_reasoning_state_default() {
        let reasoning = ReasoningState::default();

        assert_eq!(reasoning.content, "");
        assert_eq!(reasoning.signature, None);
    }

    #[test]
    fn test_reasoning_state_accumulate_content() {
        let mut reasoning = ReasoningState::default();
        reasoning
            .content
            .extend(["First chunk", " Second chunk", " Third chunk"]);

        assert_eq!(reasoning.content, "First chunk Second chunk Third chunk");
        assert_eq!(reasoning.signature, None);
    }

    #[test]
    fn test_reasoning_state_with_signature() {
        let mut reasoning = ReasoningState::default();
        reasoning.content.push_str("Reasoning content");
        reasoning.signature = Some("test_signature_456".to_string());

        assert_eq!(reasoning.content, "Reasoning content");
        assert_eq!(reasoning.signature, Some("test_signature_456".to_string()));
    }

    #[test]
    fn test_reasoning_state_empty_content() {
        let mut reasoning = ReasoningState::default();
        reasoning.signature = Some("signature_only".to_string());

        assert_eq!(reasoning.content, "");
        assert!(reasoning.signature.is_some());
    }

    // --- tool-call accumulator -------------------------------------------------------

    #[test]
    fn test_tool_call_state_default() {
        let call = ToolCallState::default();

        assert_eq!(call.name, "");
        assert_eq!(call.id, "");
        assert_eq!(call.input_json, "");
    }

    #[test]
    fn test_tool_call_state_accumulate_json() {
        let mut call = tool_call_named("my_tool", "tool_123");

        call.input_json.push_str("{\"arg1\":");
        call.input_json.push_str("\"value1\"");
        call.input_json.push('}');

        assert_eq!(call.name, "my_tool");
        assert_eq!(call.id, "tool_123");
        assert_eq!(call.input_json, "{\"arg1\":\"value1\"}");
    }

    #[test]
    fn test_tool_call_state_empty_accumulation() {
        let call = tool_call_named("test_tool", "tool_abc");

        assert_eq!(call.name, "test_tool");
        assert_eq!(call.id, "tool_abc");
        assert!(call.input_json.is_empty());
    }

    #[test]
    fn test_tool_call_state_single_chunk() {
        let mut call = tool_call_named("get_weather", "call_123");

        call.input_json.push_str("{\"location\":\"Paris\"}");

        assert_eq!(call.input_json, "{\"location\":\"Paris\"}");
    }

    #[test]
    fn test_tool_call_state_multiple_small_chunks() {
        let mut call = tool_call_named("search", "call_xyz");

        // `String` implements `Extend<&str>`, so the fragments are appended in order.
        call.input_json
            .extend(["{", "\"q", "uery", "\":", "\"R", "ust", "\"}"]);

        assert_eq!(call.input_json, "{\"query\":\"Rust\"}");
    }

    #[test]
    fn test_tool_call_state_complex_json_accumulation() {
        let mut call = tool_call_named("analyze_data", "call_456");

        let fragments = [
            "{\"data\":{",
            "\"values\":[1,2,3],",
            "\"metadata\":{\"source\":\"api\"}",
            "}}",
        ];
        for fragment in fragments {
            call.input_json.push_str(fragment);
        }

        assert_eq!(
            call.input_json,
            "{\"data\":{\"values\":[1,2,3],\"metadata\":{\"source\":\"api\"}}}"
        );

        let parsed: serde_json::Value =
            serde_json::from_str(&call.input_json).expect("Should parse as valid JSON");
        assert!(parsed.is_object());
    }

    // --- reasoning accumulator, continued --------------------------------------------

    #[test]
    fn test_reasoning_state_accumulation() {
        let mut reasoning = ReasoningState::default();

        for piece in ["First, ", "I need to ", "analyze the problem."] {
            reasoning.content.push_str(piece);
        }

        assert_eq!(reasoning.content, "First, I need to analyze the problem.");
        assert!(reasoning.signature.is_none());
    }

    #[test]
    fn test_reasoning_state_with_signature_accumulation() {
        let mut reasoning = ReasoningState::default();

        reasoning.content.push_str("Reasoning content here");
        reasoning.signature = Some("sig_part1".to_string());

        if let Some(signature) = reasoning.signature.as_mut() {
            signature.push_str("_part2");
        }

        assert_eq!(reasoning.content, "Reasoning content here");
        assert_eq!(reasoning.signature, Some("sig_part1_part2".to_string()));
    }
}