Skip to main content

llm/providers/bedrock/
streaming.rs

1use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
2use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
3use aws_sdk_bedrockruntime::types::{
4    ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, StopReason as BedrockStopReason,
5    TokenUsage as BedrockTokenUsage,
6};
7use futures::Stream;
8use std::collections::HashMap;
9use tracing::{debug, error, info, warn};
10
11use crate::{LlmError, LlmResponse, StopReason, TokenUsage, ToolCallRequest};
12
13impl From<&BedrockTokenUsage> for TokenUsage {
14    fn from(usage: &BedrockTokenUsage) -> Self {
15        TokenUsage {
16            input_tokens: u32::try_from(usage.input_tokens).unwrap_or(0),
17            output_tokens: u32::try_from(usage.output_tokens).unwrap_or(0),
18            cache_read_tokens: usage.cache_read_input_tokens().and_then(|v| u32::try_from(v).ok()),
19            cache_creation_tokens: usage.cache_write_input_tokens().and_then(|v| u32::try_from(v).ok()),
20            ..TokenUsage::default()
21        }
22    }
23}
24
/// A tool call that is still being streamed.
///
/// Entries are created when a tool-use content block starts, grow as input deltas
/// arrive, and are turned into a `ToolCallRequest` once the block stops (or when the
/// stream ends without a stop event for that block).
#[derive(Debug)]
struct PendingToolCall {
    /// Tool-use id reported by the content block start event.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Concatenation of every input chunk received so far for this call.
    args: String,
}
30
31enum StreamEvent {
32    Emit(LlmResponse),
33    Stop(StopReason),
34    Skip,
35}
36
37pub fn process_bedrock_stream(
38    mut receiver: EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>,
39) -> impl Stream<Item = crate::Result<LlmResponse>> + Send {
40    async_stream::stream! {
41        let message_id = uuid::Uuid::new_v4().to_string();
42        yield Ok(LlmResponse::Start { message_id });
43
44        let mut active_tool_calls: HashMap<i32, PendingToolCall> = HashMap::new();
45        let mut last_stop_reason: Option<StopReason> = None;
46
47        loop {
48            match receiver.recv().await {
49                Ok(Some(event)) => {
50                    match process_stream_event(&event, &mut active_tool_calls) {
51                        StreamEvent::Emit(resp) => yield Ok(resp),
52                        StreamEvent::Stop(sr) => last_stop_reason = Some(sr),
53                        StreamEvent::Skip => {}
54                    }
55                }
56                Ok(None) => {
57                    debug!("Bedrock stream ended (recv returned None)");
58                    break;
59                }
60                Err(e) => {
61                    error!("Bedrock stream recv error: {e}");
62                    yield Err(LlmError::ApiError(format!("Bedrock stream error: {e}")));
63                    break;
64                }
65            }
66        }
67
68        // Emit any remaining tool calls that weren't completed via ContentBlockStop
69        for (_index, tc) in active_tool_calls {
70            let tool_call = ToolCallRequest {
71                id: tc.id,
72                name: tc.name,
73                arguments: tc.args,
74            };
75            yield Ok(LlmResponse::ToolRequestComplete { tool_call });
76        }
77
78        yield Ok(LlmResponse::Done {
79            stop_reason: last_stop_reason,
80        });
81    }
82}
83
84fn process_stream_event(
85    event: &ConverseStreamOutput,
86    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
87) -> StreamEvent {
88    match event {
89        ConverseStreamOutput::MessageStart(_) => {
90            info!("Bedrock message started");
91            StreamEvent::Skip
92        }
93        ConverseStreamOutput::ContentBlockStart(start_event) => {
94            handle_content_block_start(start_event, active_tool_calls)
95        }
96        ConverseStreamOutput::ContentBlockDelta(delta_event) => {
97            handle_content_block_delta(delta_event, active_tool_calls)
98        }
99        ConverseStreamOutput::ContentBlockStop(stop_event) => {
100            handle_content_block_stop(stop_event.content_block_index(), active_tool_calls)
101        }
102        ConverseStreamOutput::MessageStop(stop_event) => {
103            let stop_reason = map_bedrock_stop_reason(&stop_event.stop_reason);
104            info!("Bedrock message stopped: {stop_reason:?}");
105            StreamEvent::Stop(stop_reason)
106        }
107        ConverseStreamOutput::Metadata(metadata_event) => metadata_event
108            .usage()
109            .map_or(StreamEvent::Skip, |usage| StreamEvent::Emit(LlmResponse::Usage { tokens: usage.into() })),
110        other => {
111            warn!("Unhandled Bedrock stream event: {other:?}");
112            StreamEvent::Skip
113        }
114    }
115}
116
117fn handle_content_block_start(
118    event: &aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
119    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
120) -> StreamEvent {
121    let index = event.content_block_index();
122
123    if let Some(ContentBlockStart::ToolUse(tool_start)) = event.start() {
124        let id = tool_start.tool_use_id().to_string();
125        let name = tool_start.name().to_string();
126        debug!("Bedrock tool use started: {name} ({id})");
127        active_tool_calls.insert(index, PendingToolCall { id: id.clone(), name: name.clone(), args: String::new() });
128        StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name })
129    } else {
130        debug!("Content block started at index {index}");
131        StreamEvent::Skip
132    }
133}
134
135fn handle_content_block_delta(
136    event: &aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
137    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
138) -> StreamEvent {
139    let index = event.content_block_index();
140
141    let Some(delta) = event.delta() else {
142        return StreamEvent::Skip;
143    };
144
145    match delta {
146        ContentBlockDelta::Text(text) if !text.is_empty() => {
147            StreamEvent::Emit(LlmResponse::Text { chunk: text.clone() })
148        }
149        ContentBlockDelta::ToolUse(tool_delta) => {
150            let input = tool_delta.input();
151            if input.is_empty() {
152                return StreamEvent::Skip;
153            }
154
155            if let Some(tc) = active_tool_calls.get_mut(&index) {
156                tc.args.push_str(input);
157                StreamEvent::Emit(LlmResponse::ToolRequestArg { id: tc.id.clone(), chunk: input.to_string() })
158            } else {
159                warn!("Received tool input delta for unknown content block index: {index}");
160                StreamEvent::Skip
161            }
162        }
163        ContentBlockDelta::ReasoningContent(reasoning) => {
164            if let Ok(text) = reasoning.as_text()
165                && !text.is_empty()
166            {
167                return StreamEvent::Emit(LlmResponse::Reasoning { chunk: text.clone() });
168            }
169            StreamEvent::Skip
170        }
171        _ => {
172            debug!("Unhandled content block delta type");
173            StreamEvent::Skip
174        }
175    }
176}
177
178fn handle_content_block_stop(index: i32, active_tool_calls: &mut HashMap<i32, PendingToolCall>) -> StreamEvent {
179    if let Some(tc) = active_tool_calls.remove(&index) {
180        let tool_call = ToolCallRequest { id: tc.id, name: tc.name, arguments: tc.args };
181        StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
182    } else {
183        debug!("Content block stopped at index {index}");
184        StreamEvent::Skip
185    }
186}
187
188fn map_bedrock_stop_reason(reason: &BedrockStopReason) -> StopReason {
189    match reason {
190        BedrockStopReason::EndTurn | BedrockStopReason::StopSequence => StopReason::EndTurn,
191        BedrockStopReason::ToolUse => StopReason::ToolCalls,
192        BedrockStopReason::MaxTokens | BedrockStopReason::ModelContextWindowExceeded => StopReason::Length,
193        BedrockStopReason::ContentFiltered | BedrockStopReason::GuardrailIntervened => StopReason::ContentFilter,
194        other => StopReason::Unknown(format!("{other:?}")),
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_bedrockruntime::types::{
        ContentBlockDeltaEvent, ContentBlockStartEvent, ConverseStreamMetadataEvent, ToolUseBlockDelta,
        ToolUseBlockStart,
    };

    const TOOL_ID: &str = "tool_123";
    const TOOL_NAME: &str = "search";
    const TOOL_ARGS: &str = r#"{"query":"test"}"#;

    /// Builds a pending call for the shared test tool with the given accumulated arguments.
    fn pending_call(args: &str) -> PendingToolCall {
        PendingToolCall { id: TOOL_ID.to_string(), name: TOOL_NAME.to_string(), args: args.to_string() }
    }

    /// Wraps `delta` in a delta event addressed to content block 0.
    fn delta_at_block_zero(delta: ContentBlockDelta) -> ContentBlockDeltaEvent {
        ContentBlockDeltaEvent::builder().content_block_index(0).delta(delta).build().unwrap()
    }

    /// Runs a metadata event carrying `usage` through `process_stream_event`.
    fn process_usage(usage: BedrockTokenUsage) -> StreamEvent {
        let metadata = ConverseStreamMetadataEvent::builder().usage(usage).build();
        let event = ConverseStreamOutput::Metadata(metadata);
        let mut active = HashMap::new();
        process_stream_event(&event, &mut active)
    }

    /// Asserts that `bedrock` maps to `expected`, reporting failures at the call site.
    #[track_caller]
    fn assert_stop_reason(bedrock: BedrockStopReason, expected: StopReason) {
        assert_eq!(map_bedrock_stop_reason(&bedrock), expected);
    }

    #[test]
    fn test_map_stop_reason_end_turn() {
        assert_stop_reason(BedrockStopReason::EndTurn, StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_stop_sequence() {
        assert_stop_reason(BedrockStopReason::StopSequence, StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_tool_use() {
        assert_stop_reason(BedrockStopReason::ToolUse, StopReason::ToolCalls);
    }

    #[test]
    fn test_map_stop_reason_max_tokens() {
        assert_stop_reason(BedrockStopReason::MaxTokens, StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_context_window_exceeded() {
        assert_stop_reason(BedrockStopReason::ModelContextWindowExceeded, StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_content_filtered() {
        assert_stop_reason(BedrockStopReason::ContentFiltered, StopReason::ContentFilter);
    }

    #[test]
    fn test_map_stop_reason_guardrail() {
        assert_stop_reason(BedrockStopReason::GuardrailIntervened, StopReason::ContentFilter);
    }

    #[test]
    fn test_handle_content_block_start_tool_use() {
        let mut active = HashMap::new();
        let tool_start = ToolUseBlockStart::builder().tool_use_id(TOOL_ID).name(TOOL_NAME).build().unwrap();
        let event = ContentBlockStartEvent::builder()
            .content_block_index(0)
            .start(ContentBlockStart::ToolUse(tool_start))
            .build()
            .unwrap();

        let result = handle_content_block_start(&event, &mut active);

        let StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name }) = &result else {
            panic!("expected Emit(ToolRequestStart{{..}})");
        };
        assert_eq!(id, TOOL_ID);
        assert_eq!(name, TOOL_NAME);
        assert!(active.contains_key(&0));
    }

    #[test]
    fn test_handle_content_block_delta_text() {
        let mut active = HashMap::new();
        let event = delta_at_block_zero(ContentBlockDelta::Text("Hello".to_string()));

        let result = handle_content_block_delta(&event, &mut active);

        let StreamEvent::Emit(LlmResponse::Text { chunk }) = &result else {
            panic!("expected Emit(Text{{..}})");
        };
        assert_eq!(chunk, "Hello");
    }

    #[test]
    fn test_handle_content_block_delta_tool_input() {
        let mut active = HashMap::from([(0, pending_call(""))]);
        let tool_delta = ToolUseBlockDelta::builder().input(TOOL_ARGS).build().unwrap();
        let event = delta_at_block_zero(ContentBlockDelta::ToolUse(tool_delta));

        let result = handle_content_block_delta(&event, &mut active);

        let StreamEvent::Emit(LlmResponse::ToolRequestArg { id, chunk }) = &result else {
            panic!("expected Emit(ToolRequestArg{{..}})");
        };
        assert_eq!(id, TOOL_ID);
        assert_eq!(chunk, TOOL_ARGS);

        // The chunk must also have been appended to the pending call.
        assert_eq!(active.get(&0).unwrap().args, TOOL_ARGS);
    }

    #[test]
    fn test_handle_content_block_stop_completes_tool() {
        let mut active = HashMap::from([(0, pending_call(TOOL_ARGS))]);

        let result = handle_content_block_stop(0, &mut active);

        let StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call }) = &result else {
            panic!("expected Emit(ToolRequestComplete{{..}})");
        };
        assert_eq!(tool_call.id, TOOL_ID);
        assert_eq!(tool_call.name, TOOL_NAME);
        assert_eq!(tool_call.arguments, TOOL_ARGS);
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_stop_no_tool() {
        let mut active = HashMap::new();

        let result = handle_content_block_stop(0, &mut active);

        assert!(matches!(result, StreamEvent::Skip));
    }

    #[test]
    fn test_metadata_event_emits_cache_read_and_creation() {
        let usage = BedrockTokenUsage::builder()
            .input_tokens(100)
            .output_tokens(50)
            .total_tokens(150)
            .cache_read_input_tokens(40)
            .cache_write_input_tokens(20)
            .build()
            .unwrap();

        let StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) = process_usage(usage) else {
            panic!("expected Emit(Usage{{..}})");
        };
        assert_eq!(sample.input_tokens, 100);
        assert_eq!(sample.output_tokens, 50);
        assert_eq!(sample.cache_read_tokens, Some(40));
        assert_eq!(sample.cache_creation_tokens, Some(20));
    }

    #[test]
    fn test_metadata_event_without_cache_fields() {
        let usage = BedrockTokenUsage::builder().input_tokens(10).output_tokens(5).total_tokens(15).build().unwrap();

        let StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) = process_usage(usage) else {
            panic!("expected Emit(Usage{{..}})");
        };
        assert_eq!(sample.cache_read_tokens, None);
        assert_eq!(sample.cache_creation_tokens, None);
    }

    #[test]
    fn test_handle_content_block_delta_empty_text() {
        let mut active = HashMap::new();
        let event = delta_at_block_zero(ContentBlockDelta::Text(String::new()));

        let result = handle_content_block_delta(&event, &mut active);

        assert!(matches!(result, StreamEvent::Skip));
    }
}