codetether_agent/provider/bedrock/
stream.rs1use super::BedrockProvider;
17use super::eventstream::{EventMessage, FrameBuffer};
18use crate::provider::{StreamChunk, Usage};
19use anyhow::{Context, Result};
20use futures::StreamExt;
21use serde_json::Value;
22use std::collections::HashMap;
23
24impl BedrockProvider {
25 pub(super) async fn converse_stream(
34 &self,
35 model_id: &str,
36 body: Vec<u8>,
37 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
38 let url = format!("{}/model/{}/converse-stream", self.base_url(), model_id);
39 tracing::debug!("Bedrock stream URL: {}", url);
40
41 let response = self
42 .send_request("POST", &url, Some(&body), "bedrock")
43 .await?;
44
45 let status = response.status();
46 if !status.is_success() {
47 let text = response.text().await.context("failed to read error body")?;
48 anyhow::bail!(
49 "Bedrock stream error ({status}): {}",
50 crate::util::truncate_bytes_safe(&text, 500)
51 );
52 }
53
54 let mut byte_stream = response.bytes_stream();
55 let mut framer = FrameBuffer::new();
56 let mut state = StreamState::default();
57
58 let stream = async_stream::stream! {
59 while let Some(chunk) = byte_stream.next().await {
60 match chunk {
61 Ok(bytes) => framer.extend(&bytes),
62 Err(e) => {
63 yield StreamChunk::Error(format!("transport error: {e}"));
64 return;
65 }
66 }
67
68 loop {
69 match framer.next_frame() {
70 Ok(Some(msg)) => {
71 for out in handle_event(&mut state, msg) {
72 yield out;
73 }
74 }
75 Ok(None) => break,
76 Err(e) => {
77 yield StreamChunk::Error(format!("frame decode: {e}"));
78 return;
79 }
80 }
81 }
82 }
83
84 yield StreamChunk::Done { usage: state.usage };
85 };
86
87 Ok(Box::pin(stream))
88 }
89}
90
91#[derive(Debug, Default)]
95struct StreamState {
96 open_tool_blocks: HashMap<u64, String>,
98 usage: Option<Usage>,
99}
100
101fn handle_event(state: &mut StreamState, msg: EventMessage) -> Vec<StreamChunk> {
102 let message_type = msg.message_type().unwrap_or("event");
103 if matches!(message_type, "exception" | "error") {
104 let event_type = msg.event_type().unwrap_or("unknown");
105 let payload = String::from_utf8_lossy(&msg.payload);
106 return vec![StreamChunk::Error(format!("{event_type}: {payload}"))];
107 }
108
109 let event_type = msg.event_type().unwrap_or("");
110 let body: Value = match serde_json::from_slice(&msg.payload) {
111 Ok(v) => v,
112 Err(_) if msg.payload.is_empty() => Value::Null,
113 Err(e) => return vec![StreamChunk::Error(format!("bad {event_type} json: {e}"))],
114 };
115
116 match event_type {
117 "contentBlockStart" => handle_block_start(state, &body),
118 "contentBlockDelta" => handle_block_delta(state, &body),
119 "contentBlockStop" => handle_block_stop(state, &body),
120 "metadata" => {
121 state.usage = extract_usage(&body);
122 Vec::new()
123 }
124 _ => Vec::new(),
126 }
127}
128
129fn handle_block_start(state: &mut StreamState, body: &Value) -> Vec<StreamChunk> {
130 let Some(idx) = body.get("contentBlockIndex").and_then(|v| v.as_u64()) else {
131 return Vec::new();
132 };
133 let Some(tool_use) = body.get("start").and_then(|v| v.get("toolUse")) else {
134 return Vec::new();
135 };
136 let id = tool_use
137 .get("toolUseId")
138 .and_then(|v| v.as_str())
139 .unwrap_or("")
140 .to_string();
141 let name = tool_use
142 .get("name")
143 .and_then(|v| v.as_str())
144 .unwrap_or("")
145 .to_string();
146 if id.is_empty() {
147 return Vec::new();
148 }
149 state.open_tool_blocks.insert(idx, id.clone());
150 vec![StreamChunk::ToolCallStart { id, name }]
151}
152
153fn handle_block_delta(state: &StreamState, body: &Value) -> Vec<StreamChunk> {
154 let idx = body.get("contentBlockIndex").and_then(|v| v.as_u64());
155 let Some(delta) = body.get("delta") else {
156 return Vec::new();
157 };
158
159 if let Some(text) = delta.get("text").and_then(|v| v.as_str())
160 && !text.is_empty()
161 {
162 return vec![StreamChunk::Text(text.to_string())];
163 }
164
165 if let Some(tool) = delta.get("toolUse")
166 && let Some(partial) = tool.get("input").and_then(|v| v.as_str())
167 && let Some(idx) = idx
168 && let Some(id) = state.open_tool_blocks.get(&idx)
169 {
170 return vec![StreamChunk::ToolCallDelta {
171 id: id.clone(),
172 arguments_delta: partial.to_string(),
173 }];
174 }
175
176 Vec::new()
178}
179
180fn handle_block_stop(state: &mut StreamState, body: &Value) -> Vec<StreamChunk> {
181 let Some(idx) = body.get("contentBlockIndex").and_then(|v| v.as_u64()) else {
182 return Vec::new();
183 };
184 if let Some(id) = state.open_tool_blocks.remove(&idx) {
185 return vec![StreamChunk::ToolCallEnd { id }];
186 }
187 Vec::new()
188}
189
190fn extract_usage(body: &Value) -> Option<Usage> {
191 let u = body.get("usage")?;
192 Some(Usage {
193 prompt_tokens: u.get("inputTokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
194 completion_tokens: u.get("outputTokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
195 total_tokens: u.get("totalTokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
196 cache_read_tokens: u
197 .get("cacheReadInputTokens")
198 .and_then(|v| v.as_u64())
199 .map(|n| n as usize),
200 cache_write_tokens: u
201 .get("cacheWriteInputTokens")
202 .and_then(|v| v.as_u64())
203 .map(|n| n as usize),
204 })
205}