1use aws_sdk_bedrockruntime::error::SdkError;
2use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
3use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
4use aws_sdk_bedrockruntime::types::{
5 ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, StopReason as BedrockStopReason,
6 TokenUsage as BedrockTokenUsage,
7};
8use aws_smithy_types::event_stream::RawMessage;
9use futures::Stream;
10use std::collections::HashMap;
11use tracing::{debug, error, info, warn};
12
13use crate::{LlmError, LlmResponse, StopReason, TokenUsage, ToolCallRequest};
14
15impl From<&BedrockTokenUsage> for TokenUsage {
16 fn from(usage: &BedrockTokenUsage) -> Self {
17 TokenUsage {
18 input_tokens: u32::try_from(usage.input_tokens).unwrap_or(0),
19 output_tokens: u32::try_from(usage.output_tokens).unwrap_or(0),
20 cache_read_tokens: usage.cache_read_input_tokens().and_then(|v| u32::try_from(v).ok()),
21 cache_creation_tokens: usage.cache_write_input_tokens().and_then(|v| u32::try_from(v).ok()),
22 ..TokenUsage::default()
23 }
24 }
25}
26
/// A tool call whose `ContentBlockStart` has been seen but whose block has not yet been
/// closed; argument chunks are appended to `args` as deltas arrive.
struct PendingToolCall {
    /// Bedrock's `tool_use_id` for this call.
    id: String,
    /// Name of the tool the model asked to invoke.
    name: String,
    /// Concatenation of every tool-input delta received so far for this block.
    args: String,
}
32
33enum StreamEvent {
34 Emit(LlmResponse),
35 Stop(StopReason),
36 Skip,
37}
38
39pub fn process_bedrock_stream(
40 mut receiver: EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>,
41) -> impl Stream<Item = crate::Result<LlmResponse>> + Send {
42 async_stream::stream! {
43 let message_id = uuid::Uuid::new_v4().to_string();
44 yield Ok(LlmResponse::Start { message_id });
45
46 let mut active_tool_calls: HashMap<i32, PendingToolCall> = HashMap::new();
47 let mut last_stop_reason: Option<StopReason> = None;
48
49 loop {
50 match receiver.recv().await {
51 Ok(Some(event)) => {
52 match process_stream_event(&event, &mut active_tool_calls) {
53 StreamEvent::Emit(resp) => yield Ok(resp),
54 StreamEvent::Stop(sr) => last_stop_reason = Some(sr),
55 StreamEvent::Skip => {}
56 }
57 }
58 Ok(None) => {
59 debug!("Bedrock stream ended (recv returned None)");
60 break;
61 }
62 Err(e) => {
63 error!("Bedrock stream recv error: {e}");
64 yield Err(LlmError::from(e));
65 break;
66 }
67 }
68 }
69
70 for (_index, tc) in active_tool_calls {
72 let tool_call = ToolCallRequest {
73 id: tc.id,
74 name: tc.name,
75 arguments: tc.args,
76 };
77 yield Ok(LlmResponse::ToolRequestComplete { tool_call });
78 }
79
80 yield Ok(LlmResponse::Done {
81 stop_reason: last_stop_reason,
82 });
83 }
84}
85
86fn process_stream_event(
87 event: &ConverseStreamOutput,
88 active_tool_calls: &mut HashMap<i32, PendingToolCall>,
89) -> StreamEvent {
90 match event {
91 ConverseStreamOutput::MessageStart(_) => {
92 info!("Bedrock message started");
93 StreamEvent::Skip
94 }
95 ConverseStreamOutput::ContentBlockStart(start_event) => {
96 handle_content_block_start(start_event, active_tool_calls)
97 }
98 ConverseStreamOutput::ContentBlockDelta(delta_event) => {
99 handle_content_block_delta(delta_event, active_tool_calls)
100 }
101 ConverseStreamOutput::ContentBlockStop(stop_event) => {
102 handle_content_block_stop(stop_event.content_block_index(), active_tool_calls)
103 }
104 ConverseStreamOutput::MessageStop(stop_event) => {
105 let stop_reason = map_bedrock_stop_reason(&stop_event.stop_reason);
106 info!("Bedrock message stopped: {stop_reason:?}");
107 StreamEvent::Stop(stop_reason)
108 }
109 ConverseStreamOutput::Metadata(metadata_event) => metadata_event
110 .usage()
111 .map_or(StreamEvent::Skip, |usage| StreamEvent::Emit(LlmResponse::Usage { tokens: usage.into() })),
112 other => {
113 warn!("Unhandled Bedrock stream event: {other:?}");
114 StreamEvent::Skip
115 }
116 }
117}
118
119fn handle_content_block_start(
120 event: &aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
121 active_tool_calls: &mut HashMap<i32, PendingToolCall>,
122) -> StreamEvent {
123 let index = event.content_block_index();
124
125 if let Some(ContentBlockStart::ToolUse(tool_start)) = event.start() {
126 let id = tool_start.tool_use_id().to_string();
127 let name = tool_start.name().to_string();
128 debug!("Bedrock tool use started: {name} ({id})");
129 active_tool_calls.insert(index, PendingToolCall { id: id.clone(), name: name.clone(), args: String::new() });
130 StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name })
131 } else {
132 debug!("Content block started at index {index}");
133 StreamEvent::Skip
134 }
135}
136
137fn handle_content_block_delta(
138 event: &aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
139 active_tool_calls: &mut HashMap<i32, PendingToolCall>,
140) -> StreamEvent {
141 let index = event.content_block_index();
142
143 let Some(delta) = event.delta() else {
144 return StreamEvent::Skip;
145 };
146
147 match delta {
148 ContentBlockDelta::Text(text) if !text.is_empty() => {
149 StreamEvent::Emit(LlmResponse::Text { chunk: text.clone() })
150 }
151 ContentBlockDelta::ToolUse(tool_delta) => {
152 let input = tool_delta.input();
153 if input.is_empty() {
154 return StreamEvent::Skip;
155 }
156
157 if let Some(tc) = active_tool_calls.get_mut(&index) {
158 tc.args.push_str(input);
159 StreamEvent::Emit(LlmResponse::ToolRequestArg { id: tc.id.clone(), chunk: input.to_string() })
160 } else {
161 warn!("Received tool input delta for unknown content block index: {index}");
162 StreamEvent::Skip
163 }
164 }
165 ContentBlockDelta::ReasoningContent(reasoning) => {
166 if let Ok(text) = reasoning.as_text()
167 && !text.is_empty()
168 {
169 return StreamEvent::Emit(LlmResponse::Reasoning { chunk: text.clone() });
170 }
171 StreamEvent::Skip
172 }
173 _ => {
174 debug!("Unhandled content block delta type");
175 StreamEvent::Skip
176 }
177 }
178}
179
180fn handle_content_block_stop(index: i32, active_tool_calls: &mut HashMap<i32, PendingToolCall>) -> StreamEvent {
181 if let Some(tc) = active_tool_calls.remove(&index) {
182 let tool_call = ToolCallRequest { id: tc.id, name: tc.name, arguments: tc.args };
183 StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
184 } else {
185 debug!("Content block stopped at index {index}");
186 StreamEvent::Skip
187 }
188}
189
190impl From<SdkError<ConverseStreamOutputError, RawMessage>> for LlmError {
191 fn from(e: SdkError<ConverseStreamOutputError, RawMessage>) -> Self {
192 let message = format!("Bedrock stream error: {e}");
193 match e {
194 SdkError::ServiceError(svc) => {
195 let inner = svc.err();
196 if inner.is_throttling_exception() {
197 LlmError::RateLimited(message)
198 } else if inner.is_service_unavailable_exception()
199 || inner.is_internal_server_exception()
200 || inner.is_model_stream_error_exception()
201 {
202 LlmError::StreamInterrupted(message)
203 } else {
204 LlmError::ApiError(message)
205 }
206 }
207 _ => LlmError::StreamInterrupted(message),
208 }
209 }
210}
211
212fn map_bedrock_stop_reason(reason: &BedrockStopReason) -> StopReason {
213 match reason {
214 BedrockStopReason::EndTurn | BedrockStopReason::StopSequence => StopReason::EndTurn,
215 BedrockStopReason::ToolUse => StopReason::ToolCalls,
216 BedrockStopReason::MaxTokens | BedrockStopReason::ModelContextWindowExceeded => StopReason::Length,
217 BedrockStopReason::ContentFiltered | BedrockStopReason::GuardrailIntervened => StopReason::ContentFilter,
218 other => StopReason::Unknown(format!("{other:?}")),
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    // --- stop reason mapping ---

    #[test]
    fn test_map_stop_reason_end_turn() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::EndTurn), StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_stop_sequence() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::StopSequence), StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_tool_use() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::ToolUse), StopReason::ToolCalls);
    }

    #[test]
    fn test_map_stop_reason_max_tokens() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::MaxTokens), StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_context_window_exceeded() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::ModelContextWindowExceeded), StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_content_filtered() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::ContentFiltered), StopReason::ContentFilter);
    }

    #[test]
    fn test_map_stop_reason_guardrail() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::GuardrailIntervened), StopReason::ContentFilter);
    }

    /// A stop reason the SDK does not recognise must be preserved as `Unknown`.
    #[test]
    fn test_map_stop_reason_unrecognised_value() {
        let reason = BedrockStopReason::from("brand_new_reason");
        let mapped = map_bedrock_stop_reason(&reason);
        assert!(matches!(&mapped, StopReason::Unknown(text) if text.contains("brand_new_reason")));
    }

    // --- content block start ---

    #[test]
    fn test_handle_content_block_start_tool_use() {
        let mut active = HashMap::new();
        let tool_start = aws_sdk_bedrockruntime::types::ToolUseBlockStart::builder()
            .tool_use_id("tool_123")
            .name("search")
            .build()
            .unwrap();

        let event = aws_sdk_bedrockruntime::types::ContentBlockStartEvent::builder()
            .content_block_index(0)
            .start(ContentBlockStart::ToolUse(tool_start))
            .build()
            .unwrap();

        let result = handle_content_block_start(&event, &mut active);
        assert!(
            matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name }) if id == "tool_123" && name == "search")
        );
        assert!(active.contains_key(&0));
    }

    // --- content block delta ---

    #[test]
    fn test_handle_content_block_delta_text() {
        let mut active = HashMap::new();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::Text("Hello".to_string()))
            .build()
            .unwrap();

        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(&result, StreamEvent::Emit(LlmResponse::Text { chunk }) if chunk == "Hello"));
    }

    #[test]
    fn test_handle_content_block_delta_tool_input() {
        let mut active = HashMap::new();
        active
            .insert(0, PendingToolCall { id: "tool_123".to_string(), name: "search".to_string(), args: String::new() });

        let tool_delta =
            aws_sdk_bedrockruntime::types::ToolUseBlockDelta::builder().input(r#"{"query":"test"}"#).build().unwrap();

        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::ToolUse(tool_delta))
            .build()
            .unwrap();

        let result = handle_content_block_delta(&delta, &mut active);
        assert!(
            matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestArg { id, chunk }) if id == "tool_123" && chunk == r#"{"query":"test"}"#)
        );

        assert_eq!(active.get(&0).unwrap().args, r#"{"query":"test"}"#);
    }

    /// Arguments split across several deltas must be concatenated in arrival order.
    #[test]
    fn test_handle_content_block_delta_tool_input_accumulates_chunks() {
        let mut active = HashMap::new();
        active.insert(7, PendingToolCall { id: "tool_7".to_string(), name: "search".to_string(), args: String::new() });

        for piece in [r#"{"query":"#, r#""test"}"#] {
            let tool_delta = aws_sdk_bedrockruntime::types::ToolUseBlockDelta::builder().input(piece).build().unwrap();
            let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
                .content_block_index(7)
                .delta(ContentBlockDelta::ToolUse(tool_delta))
                .build()
                .unwrap();

            let result = handle_content_block_delta(&delta, &mut active);
            assert!(
                matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestArg { id, chunk }) if id == "tool_7" && chunk == piece)
            );
        }

        assert_eq!(active.get(&7).unwrap().args, r#"{"query":"test"}"#);
    }

    /// A tool-input delta for an index with no pending tool call is skipped without side effects.
    #[test]
    fn test_handle_content_block_delta_tool_input_unknown_index() {
        let mut active: HashMap<i32, PendingToolCall> = HashMap::new();
        let tool_delta = aws_sdk_bedrockruntime::types::ToolUseBlockDelta::builder().input("{}").build().unwrap();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(3)
            .delta(ContentBlockDelta::ToolUse(tool_delta))
            .build()
            .unwrap();

        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(result, StreamEvent::Skip));
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_delta_empty_text() {
        let mut active = HashMap::new();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::Text(String::new()))
            .build()
            .unwrap();

        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(result, StreamEvent::Skip));
    }

    // --- content block stop ---

    #[test]
    fn test_handle_content_block_stop_completes_tool() {
        let mut active = HashMap::new();
        active.insert(
            0,
            PendingToolCall {
                id: "tool_123".to_string(),
                name: "search".to_string(),
                args: r#"{"query":"test"}"#.to_string(),
            },
        );

        let result = handle_content_block_stop(0, &mut active);
        assert!(matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
            if tool_call.id == "tool_123"
                && tool_call.name == "search"
                && tool_call.arguments == r#"{"query":"test"}"#
        ));
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_stop_no_tool() {
        let mut active = HashMap::new();
        let result = handle_content_block_stop(0, &mut active);
        assert!(matches!(result, StreamEvent::Skip));
    }

    // --- metadata / usage ---

    #[test]
    fn test_metadata_event_emits_cache_read_and_creation() {
        let usage = aws_sdk_bedrockruntime::types::TokenUsage::builder()
            .input_tokens(100)
            .output_tokens(50)
            .total_tokens(150)
            .cache_read_input_tokens(40)
            .cache_write_input_tokens(20)
            .build()
            .unwrap();

        let metadata = aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent::builder().usage(usage).build();

        let event = ConverseStreamOutput::Metadata(metadata);
        let mut active = HashMap::new();
        let result = process_stream_event(&event, &mut active);

        match result {
            StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) => {
                assert_eq!(sample.input_tokens, 100);
                assert_eq!(sample.output_tokens, 50);
                assert_eq!(sample.cache_read_tokens, Some(40));
                assert_eq!(sample.cache_creation_tokens, Some(20));
            }
            _ => panic!("expected Emit(Usage{{..}})"),
        }
    }

    #[test]
    fn test_metadata_event_without_cache_fields() {
        let usage = aws_sdk_bedrockruntime::types::TokenUsage::builder()
            .input_tokens(10)
            .output_tokens(5)
            .total_tokens(15)
            .build()
            .unwrap();

        let metadata = aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent::builder().usage(usage).build();

        let event = ConverseStreamOutput::Metadata(metadata);
        let mut active = HashMap::new();
        let result = process_stream_event(&event, &mut active);

        match result {
            StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) => {
                assert_eq!(sample.cache_read_tokens, None);
                assert_eq!(sample.cache_creation_tokens, None);
            }
            _ => panic!("expected Emit(Usage{{..}})"),
        }
    }
}