Skip to main content

llm/providers/bedrock/
streaming.rs

1use aws_sdk_bedrockruntime::error::SdkError;
2use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
3use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
4use aws_sdk_bedrockruntime::types::{
5    ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, StopReason as BedrockStopReason,
6    TokenUsage as BedrockTokenUsage,
7};
8use aws_smithy_types::event_stream::RawMessage;
9use futures::Stream;
10use std::collections::HashMap;
11use tracing::{debug, error, info, warn};
12
13use crate::{LlmError, LlmResponse, StopReason, TokenUsage, ToolCallRequest};
14
15impl From<&BedrockTokenUsage> for TokenUsage {
16    fn from(usage: &BedrockTokenUsage) -> Self {
17        TokenUsage {
18            input_tokens: u32::try_from(usage.input_tokens).unwrap_or(0),
19            output_tokens: u32::try_from(usage.output_tokens).unwrap_or(0),
20            cache_read_tokens: usage.cache_read_input_tokens().and_then(|v| u32::try_from(v).ok()),
21            cache_creation_tokens: usage.cache_write_input_tokens().and_then(|v| u32::try_from(v).ok()),
22            ..TokenUsage::default()
23        }
24    }
25}
26
/// A tool call whose argument payload is still being streamed.
///
/// Entries are keyed by content block index in the caller's map and are
/// turned into a `ToolCallRequest` once the block stops or the stream ends.
#[derive(Debug)]
struct PendingToolCall {
    /// Tool-use identifier taken from the block start event.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Argument text accumulated so far from the input deltas.
    args: String,
}
32
33enum StreamEvent {
34    Emit(LlmResponse),
35    Stop(StopReason),
36    Skip,
37}
38
39pub fn process_bedrock_stream(
40    mut receiver: EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>,
41) -> impl Stream<Item = crate::Result<LlmResponse>> + Send {
42    async_stream::stream! {
43        let message_id = uuid::Uuid::new_v4().to_string();
44        yield Ok(LlmResponse::Start { message_id });
45
46        let mut active_tool_calls: HashMap<i32, PendingToolCall> = HashMap::new();
47        let mut last_stop_reason: Option<StopReason> = None;
48
49        loop {
50            match receiver.recv().await {
51                Ok(Some(event)) => {
52                    match process_stream_event(&event, &mut active_tool_calls) {
53                        StreamEvent::Emit(resp) => yield Ok(resp),
54                        StreamEvent::Stop(sr) => last_stop_reason = Some(sr),
55                        StreamEvent::Skip => {}
56                    }
57                }
58                Ok(None) => {
59                    debug!("Bedrock stream ended (recv returned None)");
60                    break;
61                }
62                Err(e) => {
63                    error!("Bedrock stream recv error: {e}");
64                    yield Err(LlmError::from(e));
65                    break;
66                }
67            }
68        }
69
70        // Emit any remaining tool calls that weren't completed via ContentBlockStop
71        for (_index, tc) in active_tool_calls {
72            let tool_call = ToolCallRequest {
73                id: tc.id,
74                name: tc.name,
75                arguments: tc.args,
76            };
77            yield Ok(LlmResponse::ToolRequestComplete { tool_call });
78        }
79
80        yield Ok(LlmResponse::Done {
81            stop_reason: last_stop_reason,
82        });
83    }
84}
85
86fn process_stream_event(
87    event: &ConverseStreamOutput,
88    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
89) -> StreamEvent {
90    match event {
91        ConverseStreamOutput::MessageStart(_) => {
92            info!("Bedrock message started");
93            StreamEvent::Skip
94        }
95        ConverseStreamOutput::ContentBlockStart(start_event) => {
96            handle_content_block_start(start_event, active_tool_calls)
97        }
98        ConverseStreamOutput::ContentBlockDelta(delta_event) => {
99            handle_content_block_delta(delta_event, active_tool_calls)
100        }
101        ConverseStreamOutput::ContentBlockStop(stop_event) => {
102            handle_content_block_stop(stop_event.content_block_index(), active_tool_calls)
103        }
104        ConverseStreamOutput::MessageStop(stop_event) => {
105            let stop_reason = map_bedrock_stop_reason(&stop_event.stop_reason);
106            info!("Bedrock message stopped: {stop_reason:?}");
107            StreamEvent::Stop(stop_reason)
108        }
109        ConverseStreamOutput::Metadata(metadata_event) => metadata_event
110            .usage()
111            .map_or(StreamEvent::Skip, |usage| StreamEvent::Emit(LlmResponse::Usage { tokens: usage.into() })),
112        other => {
113            warn!("Unhandled Bedrock stream event: {other:?}");
114            StreamEvent::Skip
115        }
116    }
117}
118
119fn handle_content_block_start(
120    event: &aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
121    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
122) -> StreamEvent {
123    let index = event.content_block_index();
124
125    if let Some(ContentBlockStart::ToolUse(tool_start)) = event.start() {
126        let id = tool_start.tool_use_id().to_string();
127        let name = tool_start.name().to_string();
128        debug!("Bedrock tool use started: {name} ({id})");
129        active_tool_calls.insert(index, PendingToolCall { id: id.clone(), name: name.clone(), args: String::new() });
130        StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name })
131    } else {
132        debug!("Content block started at index {index}");
133        StreamEvent::Skip
134    }
135}
136
137fn handle_content_block_delta(
138    event: &aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
139    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
140) -> StreamEvent {
141    let index = event.content_block_index();
142
143    let Some(delta) = event.delta() else {
144        return StreamEvent::Skip;
145    };
146
147    match delta {
148        ContentBlockDelta::Text(text) if !text.is_empty() => {
149            StreamEvent::Emit(LlmResponse::Text { chunk: text.clone() })
150        }
151        ContentBlockDelta::ToolUse(tool_delta) => {
152            let input = tool_delta.input();
153            if input.is_empty() {
154                return StreamEvent::Skip;
155            }
156
157            if let Some(tc) = active_tool_calls.get_mut(&index) {
158                tc.args.push_str(input);
159                StreamEvent::Emit(LlmResponse::ToolRequestArg { id: tc.id.clone(), chunk: input.to_string() })
160            } else {
161                warn!("Received tool input delta for unknown content block index: {index}");
162                StreamEvent::Skip
163            }
164        }
165        ContentBlockDelta::ReasoningContent(reasoning) => {
166            if let Ok(text) = reasoning.as_text()
167                && !text.is_empty()
168            {
169                return StreamEvent::Emit(LlmResponse::Reasoning { chunk: text.clone() });
170            }
171            StreamEvent::Skip
172        }
173        _ => {
174            debug!("Unhandled content block delta type");
175            StreamEvent::Skip
176        }
177    }
178}
179
180fn handle_content_block_stop(index: i32, active_tool_calls: &mut HashMap<i32, PendingToolCall>) -> StreamEvent {
181    if let Some(tc) = active_tool_calls.remove(&index) {
182        let tool_call = ToolCallRequest { id: tc.id, name: tc.name, arguments: tc.args };
183        StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
184    } else {
185        debug!("Content block stopped at index {index}");
186        StreamEvent::Skip
187    }
188}
189
190impl From<SdkError<ConverseStreamOutputError, RawMessage>> for LlmError {
191    fn from(e: SdkError<ConverseStreamOutputError, RawMessage>) -> Self {
192        let message = format!("Bedrock stream error: {e}");
193        match e {
194            SdkError::ServiceError(svc) => {
195                let inner = svc.err();
196                if inner.is_throttling_exception() {
197                    LlmError::RateLimited(message)
198                } else if inner.is_service_unavailable_exception()
199                    || inner.is_internal_server_exception()
200                    || inner.is_model_stream_error_exception()
201                {
202                    LlmError::StreamInterrupted(message)
203                } else {
204                    LlmError::ApiError(message)
205                }
206            }
207            _ => LlmError::StreamInterrupted(message),
208        }
209    }
210}
211
212fn map_bedrock_stop_reason(reason: &BedrockStopReason) -> StopReason {
213    match reason {
214        BedrockStopReason::EndTurn | BedrockStopReason::StopSequence => StopReason::EndTurn,
215        BedrockStopReason::ToolUse => StopReason::ToolCalls,
216        BedrockStopReason::MaxTokens | BedrockStopReason::ModelContextWindowExceeded => StopReason::Length,
217        BedrockStopReason::ContentFiltered | BedrockStopReason::GuardrailIntervened => StopReason::ContentFilter,
218        other => StopReason::Unknown(format!("{other:?}")),
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use aws_sdk_bedrockruntime::types::{
        ContentBlockDeltaEvent, ContentBlockStartEvent, ConverseStreamMetadataEvent, ToolUseBlockDelta,
        ToolUseBlockStart,
    };

    // Shared fixture values for the tool-call tests.
    const TOOL_ID: &str = "tool_123";
    const TOOL_NAME: &str = "search";
    const TOOL_ARGS: &str = r#"{"query":"test"}"#;

    /// An empty pending-tool-call map.
    fn no_pending_calls() -> HashMap<i32, PendingToolCall> {
        HashMap::new()
    }

    /// A map holding the fixture tool call at block index 0 with the given arguments.
    fn pending_search_call(args: &str) -> HashMap<i32, PendingToolCall> {
        let call = PendingToolCall { id: TOOL_ID.to_string(), name: TOOL_NAME.to_string(), args: args.to_string() };
        HashMap::from([(0, call)])
    }

    /// Wraps `delta` in a delta event for content block index 0.
    fn delta_at_block_zero(delta: ContentBlockDelta) -> ContentBlockDeltaEvent {
        ContentBlockDeltaEvent::builder().content_block_index(0).delta(delta).build().unwrap()
    }

    /// Runs a metadata event carrying `usage` through `process_stream_event`
    /// and returns the emitted token usage.
    fn usage_from_metadata(usage: BedrockTokenUsage) -> TokenUsage {
        let metadata = ConverseStreamMetadataEvent::builder().usage(usage).build();
        let event = ConverseStreamOutput::Metadata(metadata);
        let mut active = no_pending_calls();

        match process_stream_event(&event, &mut active) {
            StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) => sample,
            _ => panic!("expected Emit(Usage{{..}})"),
        }
    }

    /// Maps an owned Bedrock stop reason.
    fn mapped(reason: BedrockStopReason) -> StopReason {
        map_bedrock_stop_reason(&reason)
    }

    #[test]
    fn test_map_stop_reason_end_turn() {
        assert_eq!(mapped(BedrockStopReason::EndTurn), StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_stop_sequence() {
        assert_eq!(mapped(BedrockStopReason::StopSequence), StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_tool_use() {
        assert_eq!(mapped(BedrockStopReason::ToolUse), StopReason::ToolCalls);
    }

    #[test]
    fn test_map_stop_reason_max_tokens() {
        assert_eq!(mapped(BedrockStopReason::MaxTokens), StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_context_window_exceeded() {
        assert_eq!(mapped(BedrockStopReason::ModelContextWindowExceeded), StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_content_filtered() {
        assert_eq!(mapped(BedrockStopReason::ContentFiltered), StopReason::ContentFilter);
    }

    #[test]
    fn test_map_stop_reason_guardrail() {
        assert_eq!(mapped(BedrockStopReason::GuardrailIntervened), StopReason::ContentFilter);
    }

    #[test]
    fn test_handle_content_block_start_tool_use() {
        let tool_start = ToolUseBlockStart::builder().tool_use_id(TOOL_ID).name(TOOL_NAME).build().unwrap();
        let event = ContentBlockStartEvent::builder()
            .content_block_index(0)
            .start(ContentBlockStart::ToolUse(tool_start))
            .build()
            .unwrap();
        let mut active = no_pending_calls();

        let result = handle_content_block_start(&event, &mut active);

        assert!(matches!(
            &result,
            StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name }) if id == TOOL_ID && name == TOOL_NAME
        ));
        assert!(active.contains_key(&0));
    }

    #[test]
    fn test_handle_content_block_delta_text() {
        let event = delta_at_block_zero(ContentBlockDelta::Text("Hello".to_string()));
        let mut active = no_pending_calls();

        let result = handle_content_block_delta(&event, &mut active);

        assert!(matches!(&result, StreamEvent::Emit(LlmResponse::Text { chunk }) if chunk == "Hello"));
    }

    #[test]
    fn test_handle_content_block_delta_tool_input() {
        let tool_delta = ToolUseBlockDelta::builder().input(TOOL_ARGS).build().unwrap();
        let event = delta_at_block_zero(ContentBlockDelta::ToolUse(tool_delta));
        let mut active = pending_search_call("");

        let result = handle_content_block_delta(&event, &mut active);

        assert!(matches!(
            &result,
            StreamEvent::Emit(LlmResponse::ToolRequestArg { id, chunk }) if id == TOOL_ID && chunk == TOOL_ARGS
        ));

        // The chunk must also have been appended to the pending call.
        assert_eq!(active.get(&0).unwrap().args, TOOL_ARGS);
    }

    #[test]
    fn test_handle_content_block_stop_completes_tool() {
        let mut active = pending_search_call(TOOL_ARGS);

        let result = handle_content_block_stop(0, &mut active);

        assert!(matches!(
            &result,
            StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
                if tool_call.id == TOOL_ID && tool_call.name == TOOL_NAME && tool_call.arguments == TOOL_ARGS
        ));
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_stop_no_tool() {
        let mut active = no_pending_calls();

        let result = handle_content_block_stop(0, &mut active);

        assert!(matches!(result, StreamEvent::Skip));
    }

    #[test]
    fn test_metadata_event_emits_cache_read_and_creation() {
        let usage = BedrockTokenUsage::builder()
            .input_tokens(100)
            .output_tokens(50)
            .total_tokens(150)
            .cache_read_input_tokens(40)
            .cache_write_input_tokens(20)
            .build()
            .unwrap();

        let sample = usage_from_metadata(usage);

        assert_eq!(sample.input_tokens, 100);
        assert_eq!(sample.output_tokens, 50);
        assert_eq!(sample.cache_read_tokens, Some(40));
        assert_eq!(sample.cache_creation_tokens, Some(20));
    }

    #[test]
    fn test_metadata_event_without_cache_fields() {
        let usage = BedrockTokenUsage::builder().input_tokens(10).output_tokens(5).total_tokens(15).build().unwrap();

        let sample = usage_from_metadata(usage);

        assert_eq!(sample.cache_read_tokens, None);
        assert_eq!(sample.cache_creation_tokens, None);
    }

    #[test]
    fn test_handle_content_block_delta_empty_text() {
        let event = delta_at_block_zero(ContentBlockDelta::Text(String::new()));
        let mut active = no_pending_calls();

        let result = handle_content_block_delta(&event, &mut active);

        assert!(matches!(result, StreamEvent::Skip));
    }
}