Skip to main content

codetether_agent/provider/bedrock/
stream.rs

//! Real streaming adapter for Bedrock's `/converse-stream` endpoint.
//!
//! Parses AWS eventstream frames into [`StreamChunk`]s so callers see text
//! deltas and tool-call argument deltas as they arrive, rather than waiting
//! for the full response.
//!
//! Bedrock event types handled:
//! - `messageStart`         → ignored (role is always assistant)
//! - `contentBlockStart`    → `StreamChunk::ToolCallStart` (only for tool blocks)
//! - `contentBlockDelta`    → `StreamChunk::Text` or `StreamChunk::ToolCallDelta`
//! - `contentBlockStop`     → `StreamChunk::ToolCallEnd` (only for open tool blocks)
//! - `messageStop`          → ignored (the stop reason is not surfaced; `Done` carries only usage)
//! - `metadata`             → captured into usage, emitted at `Done`
//! - `exception` / `error`  → `StreamChunk::Error`
15
16use super::BedrockProvider;
17use super::eventstream::{EventMessage, FrameBuffer};
18use crate::provider::{StreamChunk, Usage};
19use anyhow::{Context, Result};
20use futures::StreamExt;
21use serde_json::Value;
22use std::collections::HashMap;
23
24impl BedrockProvider {
25    /// POST to `/model/{id}/converse-stream` and yield `StreamChunk`s as
26    /// eventstream frames arrive.
27    ///
28    /// # Errors
29    ///
30    /// Returns [`anyhow::Error`] if the initial HTTP request fails or the
31    /// server responds non-200. Per-frame decode errors are emitted as
32    /// [`StreamChunk::Error`] but do not abort the stream.
33    pub(super) async fn converse_stream(
34        &self,
35        model_id: &str,
36        body: Vec<u8>,
37    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
38        let url = format!("{}/model/{}/converse-stream", self.base_url(), model_id);
39        tracing::debug!("Bedrock stream URL: {}", url);
40
41        let response = self
42            .send_request("POST", &url, Some(&body), "bedrock")
43            .await?;
44
45        let status = response.status();
46        if !status.is_success() {
47            let text = response.text().await.context("failed to read error body")?;
48            anyhow::bail!(
49                "Bedrock stream error ({status}): {}",
50                crate::util::truncate_bytes_safe(&text, 500)
51            );
52        }
53
54        let mut byte_stream = response.bytes_stream();
55        let mut framer = FrameBuffer::new();
56        let mut state = StreamState::default();
57
58        let stream = async_stream::stream! {
59            while let Some(chunk) = byte_stream.next().await {
60                match chunk {
61                    Ok(bytes) => framer.extend(&bytes),
62                    Err(e) => {
63                        yield StreamChunk::Error(format!("transport error: {e}"));
64                        return;
65                    }
66                }
67
68                loop {
69                    match framer.next_frame() {
70                        Ok(Some(msg)) => {
71                            for out in handle_event(&mut state, msg) {
72                                yield out;
73                            }
74                        }
75                        Ok(None) => break,
76                        Err(e) => {
77                            yield StreamChunk::Error(format!("frame decode: {e}"));
78                            return;
79                        }
80                    }
81                }
82            }
83
84            yield StreamChunk::Done { usage: state.usage };
85        };
86
87        Ok(Box::pin(stream))
88    }
89}
90
91/// Per-stream state: tracks which content-block indexes are open tool-use
92/// blocks so we can emit matching `ToolCallEnd` events, and accumulates usage
93/// for the final `Done` chunk.
94#[derive(Debug, Default)]
95struct StreamState {
96    /// Map of contentBlockIndex → tool_use_id for active tool-use blocks.
97    open_tool_blocks: HashMap<u64, String>,
98    usage: Option<Usage>,
99}
100
101fn handle_event(state: &mut StreamState, msg: EventMessage) -> Vec<StreamChunk> {
102    let message_type = msg.message_type().unwrap_or("event");
103    if matches!(message_type, "exception" | "error") {
104        let event_type = msg.event_type().unwrap_or("unknown");
105        let payload = String::from_utf8_lossy(&msg.payload);
106        return vec![StreamChunk::Error(format!("{event_type}: {payload}"))];
107    }
108
109    let event_type = msg.event_type().unwrap_or("");
110    let body: Value = match serde_json::from_slice(&msg.payload) {
111        Ok(v) => v,
112        Err(_) if msg.payload.is_empty() => Value::Null,
113        Err(e) => return vec![StreamChunk::Error(format!("bad {event_type} json: {e}"))],
114    };
115
116    match event_type {
117        "contentBlockStart" => handle_block_start(state, &body),
118        "contentBlockDelta" => handle_block_delta(state, &body),
119        "contentBlockStop" => handle_block_stop(state, &body),
120        "metadata" => {
121            state.usage = extract_usage(&body);
122            Vec::new()
123        }
124        // messageStart / messageStop — no session-observable effect
125        _ => Vec::new(),
126    }
127}
128
129fn handle_block_start(state: &mut StreamState, body: &Value) -> Vec<StreamChunk> {
130    let Some(idx) = body.get("contentBlockIndex").and_then(|v| v.as_u64()) else {
131        return Vec::new();
132    };
133    let Some(tool_use) = body.get("start").and_then(|v| v.get("toolUse")) else {
134        return Vec::new();
135    };
136    let id = tool_use
137        .get("toolUseId")
138        .and_then(|v| v.as_str())
139        .unwrap_or("")
140        .to_string();
141    let name = tool_use
142        .get("name")
143        .and_then(|v| v.as_str())
144        .unwrap_or("")
145        .to_string();
146    if id.is_empty() {
147        return Vec::new();
148    }
149    state.open_tool_blocks.insert(idx, id.clone());
150    vec![StreamChunk::ToolCallStart { id, name }]
151}
152
153fn handle_block_delta(state: &StreamState, body: &Value) -> Vec<StreamChunk> {
154    let idx = body.get("contentBlockIndex").and_then(|v| v.as_u64());
155    let Some(delta) = body.get("delta") else {
156        return Vec::new();
157    };
158
159    if let Some(text) = delta.get("text").and_then(|v| v.as_str())
160        && !text.is_empty()
161    {
162        return vec![StreamChunk::Text(text.to_string())];
163    }
164
165    if let Some(tool) = delta.get("toolUse")
166        && let Some(partial) = tool.get("input").and_then(|v| v.as_str())
167        && let Some(idx) = idx
168        && let Some(id) = state.open_tool_blocks.get(&idx)
169    {
170        return vec![StreamChunk::ToolCallDelta {
171            id: id.clone(),
172            arguments_delta: partial.to_string(),
173        }];
174    }
175
176    // Reasoning/thinking deltas — no matching StreamChunk variant; drop silently.
177    Vec::new()
178}
179
180fn handle_block_stop(state: &mut StreamState, body: &Value) -> Vec<StreamChunk> {
181    let Some(idx) = body.get("contentBlockIndex").and_then(|v| v.as_u64()) else {
182        return Vec::new();
183    };
184    if let Some(id) = state.open_tool_blocks.remove(&idx) {
185        return vec![StreamChunk::ToolCallEnd { id }];
186    }
187    Vec::new()
188}
189
190fn extract_usage(body: &Value) -> Option<Usage> {
191    let u = body.get("usage")?;
192    Some(Usage {
193        prompt_tokens: u.get("inputTokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
194        completion_tokens: u.get("outputTokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
195        total_tokens: u.get("totalTokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
196        cache_read_tokens: u
197            .get("cacheReadInputTokens")
198            .and_then(|v| v.as_u64())
199            .map(|n| n as usize),
200        cache_write_tokens: u
201            .get("cacheWriteInputTokens")
202            .and_then(|v| v.as_u64())
203            .map(|n| n as usize),
204    })
205}