use aws_sdk_bedrockruntime::primitives::event_stream::EventReceiver;
use aws_sdk_bedrockruntime::types::error::ConverseStreamOutputError;
use aws_sdk_bedrockruntime::types::{
ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, StopReason as BedrockStopReason,
TokenUsage as BedrockTokenUsage,
};
use futures::Stream;
use std::collections::HashMap;
use tracing::{debug, error, info, warn};
use crate::{LlmError, LlmResponse, StopReason, TokenUsage, ToolCallRequest};
impl From<&BedrockTokenUsage> for TokenUsage {
fn from(usage: &BedrockTokenUsage) -> Self {
TokenUsage {
input_tokens: u32::try_from(usage.input_tokens).unwrap_or(0),
output_tokens: u32::try_from(usage.output_tokens).unwrap_or(0),
cache_read_tokens: usage.cache_read_input_tokens().and_then(|v| u32::try_from(v).ok()),
cache_creation_tokens: usage.cache_write_input_tokens().and_then(|v| u32::try_from(v).ok()),
..TokenUsage::default()
}
}
}
/// A tool invocation whose content block has started but not yet been closed.
///
/// Argument fragments from `ToolUse` deltas are appended to `args` in arrival order.
struct PendingToolCall {
    /// Tool-use id taken from the `ToolUse` content block start.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Raw argument text accumulated from deltas so far.
    args: String,
}
/// Result of translating a single Bedrock stream event.
enum StreamEvent {
    /// A response that should be forwarded to the stream consumer.
    Emit(LlmResponse),
    /// A stop reason to remember; it is reported later in `LlmResponse::Done`.
    Stop(StopReason),
    /// The event produced nothing to forward.
    Skip,
}
/// Adapts a Bedrock `ConverseStream` event receiver into a stream of [`LlmResponse`]s.
///
/// The returned stream:
/// 1. starts with [`LlmResponse::Start`] carrying a locally generated UUID message id;
/// 2. forwards every response produced by `process_stream_event`;
/// 3. on a receive error, yields [`LlmError::ApiError`] and stops reading events;
/// 4. flushes any tool call whose content block never received a stop event as
///    [`LlmResponse::ToolRequestComplete`], ordered by content-block index;
/// 5. ends with [`LlmResponse::Done`] carrying the last `MessageStop` reason, if any.
///
/// NOTE(review): after a receive error, flushed tool calls may carry truncated
/// argument text. Consumers should treat a preceding `Err` as authoritative.
pub fn process_bedrock_stream(
    mut receiver: EventReceiver<ConverseStreamOutput, ConverseStreamOutputError>,
) -> impl Stream<Item = crate::Result<LlmResponse>> + Send {
    async_stream::stream! {
        let message_id = uuid::Uuid::new_v4().to_string();
        yield Ok(LlmResponse::Start { message_id });
        let mut active_tool_calls: HashMap<i32, PendingToolCall> = HashMap::new();
        let mut last_stop_reason: Option<StopReason> = None;
        loop {
            match receiver.recv().await {
                Ok(Some(event)) => {
                    match process_stream_event(&event, &mut active_tool_calls) {
                        StreamEvent::Emit(resp) => yield Ok(resp),
                        StreamEvent::Stop(sr) => last_stop_reason = Some(sr),
                        StreamEvent::Skip => {}
                    }
                }
                Ok(None) => {
                    debug!("Bedrock stream ended (recv returned None)");
                    break;
                }
                Err(e) => {
                    error!("Bedrock stream recv error: {e}");
                    yield Err(LlmError::ApiError(format!("Bedrock stream error: {e}")));
                    break;
                }
            }
        }
        // HashMap iteration order is unspecified; sort by content-block index so that
        // dangling tool calls are reported deterministically and in model order.
        let mut dangling: Vec<(i32, PendingToolCall)> = active_tool_calls.into_iter().collect();
        dangling.sort_unstable_by_key(|(index, _)| *index);
        for (index, tc) in dangling {
            warn!(
                "Bedrock tool call {} ({}) at content block {index} was never closed; flushing",
                tc.id, tc.name
            );
            let tool_call = ToolCallRequest {
                id: tc.id,
                name: tc.name,
                arguments: tc.args,
            };
            yield Ok(LlmResponse::ToolRequestComplete { tool_call });
        }
        yield Ok(LlmResponse::Done {
            stop_reason: last_stop_reason,
        });
    }
}
fn process_stream_event(
event: &ConverseStreamOutput,
active_tool_calls: &mut HashMap<i32, PendingToolCall>,
) -> StreamEvent {
match event {
ConverseStreamOutput::MessageStart(_) => {
info!("Bedrock message started");
StreamEvent::Skip
}
ConverseStreamOutput::ContentBlockStart(start_event) => {
handle_content_block_start(start_event, active_tool_calls)
}
ConverseStreamOutput::ContentBlockDelta(delta_event) => {
handle_content_block_delta(delta_event, active_tool_calls)
}
ConverseStreamOutput::ContentBlockStop(stop_event) => {
handle_content_block_stop(stop_event.content_block_index(), active_tool_calls)
}
ConverseStreamOutput::MessageStop(stop_event) => {
let stop_reason = map_bedrock_stop_reason(&stop_event.stop_reason);
info!("Bedrock message stopped: {stop_reason:?}");
StreamEvent::Stop(stop_reason)
}
ConverseStreamOutput::Metadata(metadata_event) => metadata_event
.usage()
.map_or(StreamEvent::Skip, |usage| StreamEvent::Emit(LlmResponse::Usage { tokens: usage.into() })),
other => {
warn!("Unhandled Bedrock stream event: {other:?}");
StreamEvent::Skip
}
}
}
/// Handles a `ContentBlockStart` event.
///
/// For a `ToolUse` block, registers a [`PendingToolCall`] under the block's index and
/// emits [`LlmResponse::ToolRequestStart`]. Any other block kind is logged and skipped.
fn handle_content_block_start(
    event: &aws_sdk_bedrockruntime::types::ContentBlockStartEvent,
    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
) -> StreamEvent {
    let index = event.content_block_index();
    let Some(ContentBlockStart::ToolUse(tool_start)) = event.start() else {
        debug!("Content block started at index {index}");
        return StreamEvent::Skip;
    };
    let id = tool_start.tool_use_id().to_owned();
    let name = tool_start.name().to_owned();
    debug!("Bedrock tool use started: {name} ({id})");
    let pending = PendingToolCall {
        id: id.clone(),
        name: name.clone(),
        args: String::new(),
    };
    active_tool_calls.insert(index, pending);
    StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name })
}
/// Handles a `ContentBlockDelta` event.
///
/// Empty chunks of any kind are skipped. Otherwise:
/// - non-empty text is emitted as [`LlmResponse::Text`];
/// - tool-use input is appended to the matching [`PendingToolCall`] and emitted as
///   [`LlmResponse::ToolRequestArg`]. If no tool call is open at that index, the input
///   is logged and dropped;
/// - reasoning text is emitted as [`LlmResponse::Reasoning`]. Non-text reasoning
///   deltas are skipped.
///
/// Any other delta kind is logged and skipped.
fn handle_content_block_delta(
    event: &aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent,
    active_tool_calls: &mut HashMap<i32, PendingToolCall>,
) -> StreamEvent {
    let index = event.content_block_index();
    let Some(delta) = event.delta() else {
        return StreamEvent::Skip;
    };
    match delta {
        // Empty text chunks carry nothing; skip them explicitly so they are not
        // misreported below as an unhandled delta type.
        ContentBlockDelta::Text(text) if text.is_empty() => StreamEvent::Skip,
        ContentBlockDelta::Text(text) => StreamEvent::Emit(LlmResponse::Text { chunk: text.clone() }),
        ContentBlockDelta::ToolUse(tool_delta) => {
            let input = tool_delta.input();
            if input.is_empty() {
                return StreamEvent::Skip;
            }
            if let Some(tc) = active_tool_calls.get_mut(&index) {
                tc.args.push_str(input);
                StreamEvent::Emit(LlmResponse::ToolRequestArg {
                    id: tc.id.clone(),
                    chunk: input.to_string(),
                })
            } else {
                warn!("Received tool input delta for unknown content block index: {index}");
                StreamEvent::Skip
            }
        }
        ContentBlockDelta::ReasoningContent(reasoning) => {
            if let Ok(text) = reasoning.as_text()
                && !text.is_empty()
            {
                return StreamEvent::Emit(LlmResponse::Reasoning { chunk: text.clone() });
            }
            StreamEvent::Skip
        }
        _ => {
            debug!("Unhandled content block delta type at index {index}");
            StreamEvent::Skip
        }
    }
}
fn handle_content_block_stop(index: i32, active_tool_calls: &mut HashMap<i32, PendingToolCall>) -> StreamEvent {
if let Some(tc) = active_tool_calls.remove(&index) {
let tool_call = ToolCallRequest { id: tc.id, name: tc.name, arguments: tc.args };
StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
} else {
debug!("Content block stopped at index {index}");
StreamEvent::Skip
}
}
/// Maps a Bedrock stop reason onto the crate's provider-neutral [`StopReason`].
///
/// Reasons with no direct equivalent become `StopReason::Unknown`, which carries the
/// `Debug` rendering of the Bedrock value.
fn map_bedrock_stop_reason(reason: &BedrockStopReason) -> StopReason {
    match reason {
        BedrockStopReason::ToolUse => StopReason::ToolCalls,
        BedrockStopReason::ContentFiltered | BedrockStopReason::GuardrailIntervened => {
            StopReason::ContentFilter
        }
        BedrockStopReason::MaxTokens | BedrockStopReason::ModelContextWindowExceeded => {
            StopReason::Length
        }
        BedrockStopReason::StopSequence | BedrockStopReason::EndTurn => StopReason::EndTurn,
        other => StopReason::Unknown(format!("{other:?}")),
    }
}
/// Unit tests for the Bedrock stream event translation helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_map_stop_reason_end_turn() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::EndTurn), StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_stop_sequence() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::StopSequence), StopReason::EndTurn);
    }

    #[test]
    fn test_map_stop_reason_tool_use() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::ToolUse), StopReason::ToolCalls);
    }

    #[test]
    fn test_map_stop_reason_max_tokens() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::MaxTokens), StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_context_window_exceeded() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::ModelContextWindowExceeded), StopReason::Length);
    }

    #[test]
    fn test_map_stop_reason_content_filtered() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::ContentFiltered), StopReason::ContentFilter);
    }

    #[test]
    fn test_map_stop_reason_guardrail() {
        assert_eq!(map_bedrock_stop_reason(&BedrockStopReason::GuardrailIntervened), StopReason::ContentFilter);
    }

    #[test]
    fn test_handle_content_block_start_tool_use() {
        let mut active = HashMap::new();
        let tool_start = aws_sdk_bedrockruntime::types::ToolUseBlockStart::builder()
            .tool_use_id("tool_123")
            .name("search")
            .build()
            .unwrap();
        let event = aws_sdk_bedrockruntime::types::ContentBlockStartEvent::builder()
            .content_block_index(0)
            .start(ContentBlockStart::ToolUse(tool_start))
            .build()
            .unwrap();
        let result = handle_content_block_start(&event, &mut active);
        assert!(
            matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestStart { id, name }) if id == "tool_123" && name == "search")
        );
        assert!(active.contains_key(&0));
    }

    #[test]
    fn test_handle_content_block_delta_text() {
        let mut active = HashMap::new();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::Text("Hello".to_string()))
            .build()
            .unwrap();
        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(&result, StreamEvent::Emit(LlmResponse::Text { chunk }) if chunk == "Hello"));
    }

    #[test]
    fn test_handle_content_block_delta_tool_input() {
        let mut active = HashMap::new();
        active
            .insert(0, PendingToolCall { id: "tool_123".to_string(), name: "search".to_string(), args: String::new() });
        let tool_delta =
            aws_sdk_bedrockruntime::types::ToolUseBlockDelta::builder().input(r#"{"query":"test"}"#).build().unwrap();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::ToolUse(tool_delta))
            .build()
            .unwrap();
        let result = handle_content_block_delta(&delta, &mut active);
        assert!(
            matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestArg { id, chunk }) if id == "tool_123" && chunk == r#"{"query":"test"}"#)
        );
        assert_eq!(active.get(&0).unwrap().args, r#"{"query":"test"}"#);
    }

    #[test]
    fn test_handle_content_block_delta_tool_input_unknown_index() {
        let mut active = HashMap::new();
        let tool_delta = aws_sdk_bedrockruntime::types::ToolUseBlockDelta::builder().input("{}").build().unwrap();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(3)
            .delta(ContentBlockDelta::ToolUse(tool_delta))
            .build()
            .unwrap();
        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(result, StreamEvent::Skip));
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_delta_reasoning() {
        let mut active = HashMap::new();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::ReasoningContent(
                aws_sdk_bedrockruntime::types::ReasoningContentBlockDelta::Text("thinking".to_string()),
            ))
            .build()
            .unwrap();
        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(&result, StreamEvent::Emit(LlmResponse::Reasoning { chunk }) if chunk == "thinking"));
    }

    #[test]
    fn test_handle_content_block_stop_completes_tool() {
        let mut active = HashMap::new();
        active.insert(
            0,
            PendingToolCall {
                id: "tool_123".to_string(),
                name: "search".to_string(),
                args: r#"{"query":"test"}"#.to_string(),
            },
        );
        let result = handle_content_block_stop(0, &mut active);
        assert!(matches!(&result, StreamEvent::Emit(LlmResponse::ToolRequestComplete { tool_call })
            if tool_call.id == "tool_123"
                && tool_call.name == "search"
                && tool_call.arguments == r#"{"query":"test"}"#
        ));
        assert!(active.is_empty());
    }

    #[test]
    fn test_handle_content_block_stop_no_tool() {
        let mut active = HashMap::new();
        let result = handle_content_block_stop(0, &mut active);
        assert!(matches!(result, StreamEvent::Skip));
    }

    #[test]
    fn test_message_stop_event_records_stop_reason() {
        let stop = aws_sdk_bedrockruntime::types::MessageStopEvent::builder()
            .stop_reason(BedrockStopReason::ToolUse)
            .build()
            .unwrap();
        let event = ConverseStreamOutput::MessageStop(stop);
        let mut active = HashMap::new();
        let result = process_stream_event(&event, &mut active);
        assert!(matches!(result, StreamEvent::Stop(StopReason::ToolCalls)));
    }

    #[test]
    fn test_metadata_event_emits_cache_read_and_creation() {
        let usage = aws_sdk_bedrockruntime::types::TokenUsage::builder()
            .input_tokens(100)
            .output_tokens(50)
            .total_tokens(150)
            .cache_read_input_tokens(40)
            .cache_write_input_tokens(20)
            .build()
            .unwrap();
        let metadata = aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent::builder().usage(usage).build();
        let event = ConverseStreamOutput::Metadata(metadata);
        let mut active = HashMap::new();
        let result = process_stream_event(&event, &mut active);
        match result {
            StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) => {
                assert_eq!(sample.input_tokens, 100);
                assert_eq!(sample.output_tokens, 50);
                assert_eq!(sample.cache_read_tokens, Some(40));
                assert_eq!(sample.cache_creation_tokens, Some(20));
            }
            _ => panic!("expected Emit(Usage{{..}})"),
        }
    }

    #[test]
    fn test_metadata_event_without_cache_fields() {
        let usage = aws_sdk_bedrockruntime::types::TokenUsage::builder()
            .input_tokens(10)
            .output_tokens(5)
            .total_tokens(15)
            .build()
            .unwrap();
        let metadata = aws_sdk_bedrockruntime::types::ConverseStreamMetadataEvent::builder().usage(usage).build();
        let event = ConverseStreamOutput::Metadata(metadata);
        let mut active = HashMap::new();
        let result = process_stream_event(&event, &mut active);
        match result {
            StreamEvent::Emit(LlmResponse::Usage { tokens: sample }) => {
                assert_eq!(sample.cache_read_tokens, None);
                assert_eq!(sample.cache_creation_tokens, None);
            }
            _ => panic!("expected Emit(Usage{{..}})"),
        }
    }

    #[test]
    fn test_token_usage_negative_counts_are_clamped() {
        // Negative counters cannot be represented as u32: required fields fall back to
        // 0 and optional cache fields become None.
        let usage = aws_sdk_bedrockruntime::types::TokenUsage::builder()
            .input_tokens(-1)
            .output_tokens(7)
            .total_tokens(6)
            .cache_read_input_tokens(-3)
            .build()
            .unwrap();
        let tokens = TokenUsage::from(&usage);
        assert_eq!(tokens.input_tokens, 0);
        assert_eq!(tokens.output_tokens, 7);
        assert_eq!(tokens.cache_read_tokens, None);
        assert_eq!(tokens.cache_creation_tokens, None);
    }

    #[test]
    fn test_handle_content_block_delta_empty_text() {
        let mut active = HashMap::new();
        let delta = aws_sdk_bedrockruntime::types::ContentBlockDeltaEvent::builder()
            .content_block_index(0)
            .delta(ContentBlockDelta::Text(String::new()))
            .build()
            .unwrap();
        let result = handle_content_block_delta(&delta, &mut active);
        assert!(matches!(result, StreamEvent::Skip));
    }
}