1use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
2use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
3use aws_sdk_bedrockruntime::types::{
4 ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, StopReason as BedrockStopReason,
5 TokenUsage as BedrockTokenUsage,
6};
7use futures::Stream;
8use std::collections::HashMap;
9use tracing::{debug, error, info, warn};
10
11use crate::{LlmError, LlmResponse, StopReason, TokenUsage, ToolCallRequest};
12
13impl From<&BedrockTokenUsage> for TokenUsage {
14 fn from(usage: &BedrockTokenUsage) -> Self {
15 TokenUsage {
16 input_tokens: u32::try_from(usage.input_tokens).unwrap_or(0),
17 output_tokens: u32::try_from(usage.output_tokens).unwrap_or(0),
18 cache_read_tokens: usage.cache_read_input_tokens().and_then(|v| u32::try_from(v).ok()),
19 cache_creation_tokens: usage.cache_write_input_tokens().and_then(|v| u32::try_from(v).ok()),
20 ..TokenUsage::default()
21 }
22 }
23}
24
/// A tool call whose content block has started but not yet stopped.
///
/// Entries are keyed by content block index in the stream's bookkeeping map
/// and are turned into a `ToolCallRequest` once the block completes.
#[derive(Debug)]
struct PendingToolCall {
    /// Tool-use identifier taken from the block's start event.
    id: String,
    /// Name of the tool being invoked, taken from the block's start event.
    name: String,
    /// Argument text accumulated so far from the block's tool-use deltas.
    args: String,
}
30
31enum StreamEvent {
32 Emit(LlmResponse),
33 Stop(StopReason),
34 Skip,
35}
36
37pub fn process_bedrock_stream(
38 mut receiver: EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>,
39) -> impl Stream<Item = crate::Result<LlmResponse>> + Send {
40 async_stream::stream! {
41 let message_id = uuid::Uuid::new_v4().to_string();
42 yield Ok(LlmResponse::Start { message_id });
43
44 let mut active_tool_calls: HashMap<i32, PendingToolCall> = HashMap::new();
45 let mut last_stop_reason: Option<StopReason> = None;
46
47 loop {
48 match receiver.recv().await {
49 Ok(Some(event)) => {
50 match process_stream_event(&event, &mut active_tool_calls) {
51 StreamEvent::Emit(resp) => yield Ok(resp),
52 StreamEvent::Stop(sr) => last_stop_reason = Some(sr),
53 StreamEvent::Skip => {}
54 }
55 }
56 Ok(None) => {
57 debug!("Bedrock stream ended (recv returned None)");
58 break;
59 }
60 Err(e) => {
61 error!("Bedrock stream recv error: {e}");
62 yield Err(LlmError::ApiError(format!("Bedrock stream error: {e}")));
63 break;
64 }
65 }
66 }
67
68 for (_index, tc) in active_tool_calls {
70 let tool_call = ToolCallRequest {
71 id: tc.id,
72 name: tc.name,
73 arguments: tc.args,
74 };
75 yield Ok(LlmResponse::ToolRequestComplete { tool_call });
76 }
77
78 yield Ok(LlmResponse::Done {
79 stop_reason: last_stop_reason,
80 });
81 }
82}
83
84fn process_stream_event(
85 event: &ConverseStreamOutput,
86 active_tool_calls: &mut HashMap<i32, PendingToolCall>,
87) -> StreamEvent {
88 match event {
89 ConverseStreamOutput::MessageStart(_) => {
90 info!("Bedrock message started");
91 StreamEvent::Skip
92 }
93 ConverseStreamOutput::ContentBlockStart(start_event) => {
94 handle_content_block_start(start_event, active_tool_calls)
95 }
96 ConverseStreamOutput::ContentBlockDelta(delta_event) => {
97 handle_content_block_delta(delta_event, active_tool_calls)
98 }
99 ConverseStreamOutput::ContentBlockStop(stop_event) => {
100 handle_content_block_stop(stop_event.content_block_index(), active_tool_calls)
101 }
102 ConverseStreamOutput::MessageStop(stop_event) => {
103 let stop_reason = map_bedrock_stop_reason(&stop_event.stop_reason);
104 info!("Bedrock message stopped: {stop_reason:?}");
105 StreamEvent::Stop(stop_reason)
106 }
107 ConverseStreamOutput::Metadata(metadata_event) => metadata_event
108 .usage()
109 .map_or(StreamEvent::Skip, |usage| StreamEvent::Emit(LlmResponse::Usage { tokens: usage.into() })),
110 other => {
111 warn!("Unhandled Bedrock stream event: {other:?}");
112 StreamEvent::Skip
113 }
114 }
115}
116
117fn handle_content_block_start(
118 event: &aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
119 active_tool_calls: &mut HashMap<i32, PendingToolCall>,
120) -> StreamEvent {
121 let index = event.content_block_index();
122
123 if let Some(ContentBlockStart::ToolUse(tool_start)) = event.start() {
124 let id = tool_start.tool_use_id().to_string();
125 let name = tool_start.name().to_string();
126 debug!("Bedrock tool use started: {name} ({id})");
127 active_tool_calls.insert(index, PendingToolCall { id: id.clone(), name: name.clone(), args: String::new() });
128 StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name })
129 } else {
130 debug!("Content block started at index {index}");
131 StreamEvent::Skip
132 }
133}
134
135fn handle_content_block_delta(
136 event: &aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
137 active_tool_calls: &mut HashMap<i32, PendingToolCall>,
138) -> StreamEvent {
139 let index = event.content_block_index();
140
141 let Some(delta) = event.delta() else {
142 return StreamEvent::Skip;
143 };
144
145 match delta {
146 ContentBlockDelta::Text(text) if !text.is_empty() => {
147 StreamEvent::Emit(LlmResponse::Text { chunk: text.clone() })
148 }
149 ContentBlockDelta::ToolUse(tool_delta) => {
150 let input = tool_delta.input();
151 if input.is_empty() {
152 return StreamEvent::Skip;
153 }
154
155 if let Some(tc) = active_tool_calls.get_mut(&index) {
156 tc.args.push_str(input);
157 StreamEvent::Emit(LlmResponse::ToolRequestArg { id: tc.id.clone(), chunk: input.to_string() })
158 } else {
159 warn!("Received tool input delta for unknown content block index: {index}");
160 StreamEvent::Skip
161 }
162 }
163 ContentBlockDelta::ReasoningContent(reasoning) => {
164 if let Ok(text) = reasoning.as_text()
165 && !text.is_empty()
166 {
167 return StreamEvent::Emit(LlmResponse::Reasoning { chunk: text.clone() });
168 }
169 StreamEvent::Skip
170 }
171 _ => {
172 debug!("Unhandled content block delta type");
173 StreamEvent::Skip
174 }
175 }
176}
177
178fn handle_content_block_stop(index: i32, active_tool_calls: &mut HashMap<i32, PendingToolCall>) -> StreamEvent {
179 if let Some(tc) = active_tool_calls.remove(&index) {
180 let tool_call = ToolCallRequest { id: tc.id, name: tc.name, arguments: tc.args };
181 StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
182 } else {
183 debug!("Content block stopped at index {index}");
184 StreamEvent::Skip
185 }
186}
187
188fn map_bedrock_stop_reason(reason: &BedrockStopReason) -> StopReason {
189 match reason {
190 BedrockStopReason::EndTurn | BedrockStopReason::StopSequence => StopReason::EndTurn,
191 BedrockStopReason::ToolUse => StopReason::ToolCalls,
192 BedrockStopReason::MaxTokens | BedrockStopReason::ModelContextWindowExceeded => StopReason::Length,
193 BedrockStopReason::ContentFiltered | BedrockStopReason::GuardrailIntervened => StopReason::ContentFilter,
194 other => StopReason::Unknown(format!("{other:?}")),
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_bedrockruntime::types::{
        ContentBlockDeltaEvent, ContentBlockStartEvent, ConverseStreamMetadataEvent, ToolUseBlockDelta,
        ToolUseBlockStart,
    };

    // Shared fixture values for the tool-call tests.
    const TOOL_ID: &str = "tool_123";
    const TOOL_NAME: &str = "search";
    const TOOL_ARGS: &str = r#"{"query":"test"}"#;

    /// Bookkeeping map with no tool call in flight.
    fn no_pending_calls() -> HashMap<i32, PendingToolCall> {
        HashMap::new()
    }

    /// Bookkeeping map holding the fixture tool call at index 0 with the given arguments.
    fn pending_search_call(args: &str) -> HashMap<i32, PendingToolCall> {
        let call = PendingToolCall { id: TOOL_ID.to_owned(), name: TOOL_NAME.to_owned(), args: args.to_owned() };
        HashMap::from([(0, call)])
    }

    /// Wraps `delta` in a delta event for content block 0.
    fn delta_event(delta: ContentBlockDelta) -> ContentBlockDeltaEvent {
        ContentBlockDeltaEvent::builder().content_block_index(0).delta(delta).build().unwrap()
    }

    /// Wraps `usage` in a metadata stream event.
    fn metadata_event(usage: BedrockTokenUsage) -> ConverseStreamOutput {
        ConverseStreamOutput::Metadata(ConverseStreamMetadataEvent::builder().usage(usage).build())
    }

    /// Asserts that `bedrock` maps to `expected`.
    #[track_caller]
    fn assert_stop_reason(bedrock: BedrockStopReason, expected: StopReason) {
        assert_eq!(map_bedrock_stop_reason(&bedrock), expected);
    }

    /// Reports whether `event` is the `Skip` outcome.
    fn is_skip(event: &StreamEvent) -> bool {
        matches!(event, StreamEvent::Skip)
    }

    #[test]
    fn test_map_stop_reason_end_turn() {
        assert_stop_reason(BedrockStopReason::EndTurn, StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_stop_sequence() {
        assert_stop_reason(BedrockStopReason::StopSequence, StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_tool_use() {
        assert_stop_reason(BedrockStopReason::ToolUse, StopReason::ToolCalls);
    }

    #[test]
    fn test_map_stop_reason_max_tokens() {
        assert_stop_reason(BedrockStopReason::MaxTokens, StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_context_window_exceeded() {
        assert_stop_reason(BedrockStopReason::ModelContextWindowExceeded, StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_content_filtered() {
        assert_stop_reason(BedrockStopReason::ContentFiltered, StopReason::ContentFilter);
    }

    #[test]
    fn test_map_stop_reason_guardrail() {
        assert_stop_reason(BedrockStopReason::GuardrailIntervened, StopReason::ContentFilter);
    }

    #[test]
    fn test_handle_content_block_start_tool_use() {
        let tool_start = ToolUseBlockStart::builder().tool_use_id(TOOL_ID).name(TOOL_NAME).build().unwrap();
        let event = ContentBlockStartEvent::builder()
            .content_block_index(0)
            .start(ContentBlockStart::ToolUse(tool_start))
            .build()
            .unwrap();
        let mut active = no_pending_calls();

        let outcome = handle_content_block_start(&event, &mut active);

        let StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name }) = &outcome else {
            panic!("expected Emit(ToolRequestStart{{..}})");
        };
        assert_eq!(id, TOOL_ID);
        assert_eq!(name, TOOL_NAME);
        assert!(active.contains_key(&0));
    }

    #[test]
    fn test_handle_content_block_delta_text() {
        let event = delta_event(ContentBlockDelta::Text("Hello".to_string()));
        let mut active = no_pending_calls();

        let outcome = handle_content_block_delta(&event, &mut active);

        let StreamEvent::Emit(LlmResponse::Text { chunk }) = &outcome else {
            panic!("expected Emit(Text{{..}})");
        };
        assert_eq!(chunk, "Hello");
    }

    #[test]
    fn test_handle_content_block_delta_tool_input() {
        let mut active = pending_search_call("");
        let tool_delta = ToolUseBlockDelta::builder().input(TOOL_ARGS).build().unwrap();
        let event = delta_event(ContentBlockDelta::ToolUse(tool_delta));

        let outcome = handle_content_block_delta(&event, &mut active);

        let StreamEvent::Emit(LlmResponse::ToolRequestArg { id, chunk }) = &outcome else {
            panic!("expected Emit(ToolRequestArg{{..}})");
        };
        assert_eq!(id, TOOL_ID);
        assert_eq!(chunk, TOOL_ARGS);

        // The chunk must also have been accumulated on the pending call.
        assert_eq!(active.get(&0).map(|call| call.args.as_str()), Some(TOOL_ARGS));
    }

    #[test]
    fn test_handle_content_block_stop_completes_tool() {
        let mut active = pending_search_call(TOOL_ARGS);

        let outcome = handle_content_block_stop(0, &mut active);

        let StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call }) = &outcome else {
            panic!("expected Emit(ToolRequestComplete{{..}})");
        };
        assert_eq!(tool_call.id, TOOL_ID);
        assert_eq!(tool_call.name, TOOL_NAME);
        assert_eq!(tool_call.arguments, TOOL_ARGS);
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_stop_no_tool() {
        let mut active = no_pending_calls();

        let outcome = handle_content_block_stop(0, &mut active);

        assert!(is_skip(&outcome));
    }

    #[test]
    fn test_metadata_event_emits_cache_read_and_creation() {
        let usage = BedrockTokenUsage::builder()
            .input_tokens(100)
            .output_tokens(50)
            .total_tokens(150)
            .cache_read_input_tokens(40)
            .cache_write_input_tokens(20)
            .build()
            .unwrap();
        let event = metadata_event(usage);
        let mut active = no_pending_calls();

        let StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) = process_stream_event(&event, &mut active)
        else {
            panic!("expected Emit(Usage{{..}})");
        };

        assert_eq!(sample.input_tokens, 100);
        assert_eq!(sample.output_tokens, 50);
        assert_eq!(sample.cache_read_tokens, Some(40));
        assert_eq!(sample.cache_creation_tokens, Some(20));
    }

    #[test]
    fn test_metadata_event_without_cache_fields() {
        let usage = BedrockTokenUsage::builder().input_tokens(10).output_tokens(5).total_tokens(15).build().unwrap();
        let event = metadata_event(usage);
        let mut active = no_pending_calls();

        let StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) = process_stream_event(&event, &mut active)
        else {
            panic!("expected Emit(Usage{{..}})");
        };

        assert_eq!(sample.cache_read_tokens, None);
        assert_eq!(sample.cache_creation_tokens, None);
    }

    #[test]
    fn test_handle_content_block_delta_empty_text() {
        let event = delta_event(ContentBlockDelta::Text(String::new()));
        let mut active = no_pending_calls();

        let outcome = handle_content_block_delta(&event, &mut active);

        assert!(is_skip(&outcome));
    }
}