1use super::traits::*;
49use crate::types::*;
50use async_trait::async_trait;
51use futures::StreamExt;
52use serde::Deserialize;
53use tokio::sync::mpsc;
54use tracing::{debug, warn};
55
/// Streaming provider that talks to a Bedrock-style `converse-stream` endpoint.
///
/// The type is a stateless unit struct; all per-request data arrives through
/// the `StreamConfig` passed to `stream`.
#[derive(Debug, Clone, Copy, Default)]
pub struct BedrockProvider;
58
59#[async_trait]
60impl StreamProvider for BedrockProvider {
61 fn provider_id(&self) -> &str {
62 "bedrock"
63 }
64
65 async fn stream(
66 &self,
67 config: StreamConfig, tx: mpsc::UnboundedSender<StreamEvent>, cancel: tokio_util::sync::CancellationToken, ) -> Result<Message, ProviderError> {
71 let model_config = &config.model_config;
72 let api_key = model_config.resolve_api_key().await?;
75
76 if !matches!(config.response_format, ResponseFormat::Text)
82 && !model_config.id.contains("anthropic")
83 {
84 return Err(ProviderError::SchemaMismatch {
85 reason: format!(
86 "Bedrock model `{}` does not support structured output via the \
87 phi-core Converse adapter (only `anthropic.*` foundation models do). \
88 Either switch to a Bedrock Anthropic model or set response_format to Text.",
89 model_config.id
90 ),
91 });
92 }
93
94 let base_url = &model_config.base_url;
95 let url = format!(
96 "{}/model/{}/converse-stream",
97 base_url, config.model_config.id
98 );
99
100 let body = build_bedrock_body(&config);
101 debug!(
102 "Bedrock request: model={} url={}",
103 config.model_config.id, url
104 );
105
106 let parts: Vec<&str> = api_key.splitn(3, ':').collect();
116 if parts.len() < 2 {
117 return Err(ProviderError::Auth(
118 "Bedrock api_key must be 'access_key:secret_key[:session_token]'".into(),
119 ));
120 }
121
122 let client = reqwest::Client::new();
123 let mut request = client.post(&url).header("content-type", "application/json");
124
125 for (k, v) in &model_config.headers {
129 request = request.header(k, v);
130 }
131
132 if !model_config.headers.contains_key("authorization") {
135 request = request.header("authorization", format!("Bearer {}", api_key));
136 }
137
138 let response = request
139 .json(&body)
140 .send()
141 .await
142 .map_err(|e| ProviderError::Network(e.to_string()))?;
143
144 if !response.status().is_success() {
145 let status = response.status();
146 let body = response.text().await.unwrap_or_default();
147 return Err(ProviderError::classify(
148 status.as_u16(),
149 &format!("Bedrock error {}: {}", status, body),
150 ));
151 }
152
153 let mut content: Vec<Content> = Vec::new();
154 let mut usage = Usage::default();
155 let mut stop_reason = StopReason::Stop;
156
157 let _ = tx.send(StreamEvent::Start);
158
159 let mut stream = response.bytes_stream();
162 let mut buffer = String::new();
163
164 loop {
165 tokio::select! {
166 _ = cancel.cancelled() => {
167 return Err(ProviderError::Cancelled);
168 }
169 chunk = stream.next() => {
170 match chunk {
171 None => break,
172 Some(Err(e)) => {
173 warn!("Bedrock stream error: {}", e);
174 break;
175 }
176 Some(Ok(bytes)) => {
177 buffer.push_str(&String::from_utf8_lossy(&bytes));
178
179 while let Some(pos) = buffer.find('\n') {
181 let line = buffer[..pos].trim().to_string();
182 buffer = buffer[pos + 1..].to_string();
183
184 if line.is_empty() {
185 continue;
186 }
187
188 let event: BedrockEvent = match serde_json::from_str(&line) {
189 Ok(e) => e,
190 Err(_) => continue,
191 };
192
193 match event {
194 BedrockEvent::ContentBlockDelta { delta, .. } => {
195 if let Some(text) = delta.text {
196 let text_idx = content.iter().position(|c| matches!(c, Content::Text { .. }));
197 let idx = match text_idx {
198 Some(i) => i,
199 None => {
200 content.push(Content::Text { text: String::new() });
201 content.len() - 1
202 }
203 };
204 if let Some(Content::Text { text: t }) = content.get_mut(idx) {
205 t.push_str(&text);
206 }
207 let _ = tx.send(StreamEvent::TextDelta {
208 content_index: idx,
209 delta: text,
210 });
211 }
212 if let Some(tool_use) = delta.tool_use {
213 let _ = tx.send(StreamEvent::ToolCallDelta {
214 content_index: content.len(),
215 delta: tool_use.input,
216 });
217 }
218 }
219 BedrockEvent::ContentBlockStart { start, .. } => {
220 if let Some(tool_use) = start.tool_use {
221 let idx = content.len();
222 content.push(Content::ToolCall {
223 id: tool_use.tool_use_id.clone(),
224 name: tool_use.name.clone(),
225 arguments: serde_json::Value::Object(Default::default()),
226 });
227 let _ = tx.send(StreamEvent::ToolCallStart {
228 content_index: idx,
229 id: tool_use.tool_use_id,
230 name: tool_use.name,
231 });
232 }
233 }
234 BedrockEvent::ContentBlockStop { .. } => {
235 if content.iter().any(|c| matches!(c, Content::ToolCall { .. })) {
236 let _ = tx.send(StreamEvent::ToolCallEnd {
237 content_index: content.len() - 1,
238 });
239 }
240 }
241 BedrockEvent::MessageStop { stop_reason: sr } => {
242 stop_reason = match sr.as_deref() {
243 Some("end_turn") => StopReason::Stop,
244 Some("max_tokens") => StopReason::Length,
245 Some("tool_use") => StopReason::ToolUse,
246 _ => StopReason::Stop,
247 };
248 }
249 BedrockEvent::Metadata { usage: u } => {
250 if let Some(u) = u {
251 usage.input = u.input_tokens;
252 usage.output = u.output_tokens;
253 usage.total_tokens = u.input_tokens + u.output_tokens;
254 }
255 }
256 BedrockEvent::Unknown => {}
257 }
258 }
259 }
260 }
261 }
262 }
263 }
264
265 let message = Message::Assistant {
266 content,
267 stop_reason,
268 model: config.model_config.id.clone(),
269 provider: model_config.provider.clone(),
270 usage,
271 timestamp: now_ms(),
272 error_message: None,
273 };
274
275 let _ = tx.send(StreamEvent::Done {
276 message: message.clone(),
277 });
278 Ok(message)
279 }
280}
281
282fn build_bedrock_body(config: &StreamConfig) -> serde_json::Value {
283 let mut messages: Vec<serde_json::Value> = Vec::new();
284
285 for msg in &config.messages {
286 match msg {
287 Message::User { content, .. } => {
288 let blocks = content_to_bedrock(content);
289 messages.push(serde_json::json!({"role": "user", "content": blocks}));
290 }
291 Message::Assistant { content, .. } => {
292 let blocks = content_to_bedrock(content);
293 messages.push(serde_json::json!({"role": "assistant", "content": blocks}));
294 }
295 Message::ToolResult {
296 tool_call_id,
297 content,
298 is_error,
299 ..
300 } => {
301 let tool_content: Vec<serde_json::Value> = content
303 .iter()
304 .filter_map(|c| match c {
305 Content::Text { text } => Some(serde_json::json!({"text": text})),
306 Content::Image { data, mime_type } => Some(serde_json::json!({
307 "image": {
308 "format": mime_type.split('/').nth(1).unwrap_or("png"),
309 "source": {"bytes": data},
310 }
311 })),
312 _ => None,
313 })
314 .collect();
315
316 let tool_content = if tool_content.is_empty() {
317 vec![serde_json::json!({"text": ""})]
318 } else {
319 tool_content
320 };
321
322 messages.push(serde_json::json!({
323 "role": "user",
324 "content": [{
325 "toolResult": {
326 "toolUseId": tool_call_id,
327 "content": tool_content,
328 "status": if *is_error { "error" } else { "success" },
329 }
330 }],
331 }));
332 }
333 }
334 }
335
336 let mut body = serde_json::json!({"messages": messages});
337
338 if !config.system_prompt.is_empty() {
339 body["system"] = serde_json::json!([{"text": config.system_prompt}]);
340 }
341
342 let mut inference_config = serde_json::json!({});
343 if let Some(max) = config.max_tokens {
344 inference_config["maxTokens"] = serde_json::json!(max);
345 }
346 if let Some(temp) = config.temperature {
347 inference_config["temperature"] = serde_json::json!(temp);
348 }
349 if inference_config != serde_json::json!({}) {
350 body["inferenceConfig"] = inference_config;
351 }
352
353 if !config.tools.is_empty() {
354 let tools: Vec<serde_json::Value> = config
355 .tools
356 .iter()
357 .map(|t| {
358 serde_json::json!({
359 "toolSpec": {
360 "name": t.name,
361 "description": t.description,
362 "inputSchema": {"json": t.parameters},
363 }
364 })
365 })
366 .collect();
367 body["toolConfig"] = serde_json::json!({"tools": tools});
368 }
369
370 match &config.response_format {
375 ResponseFormat::Text => {}
376 ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. } => {
377 let (schema, description) = match &config.response_format {
378 ResponseFormat::JsonSchema { schema, name, .. } => (
379 schema.clone(),
380 format!("Return the response as a JSON object matching `{}`.", name),
381 ),
382 _ => (
383 serde_json::json!({"type": "object", "additionalProperties": true}),
384 "Return the response as a JSON object.".to_string(),
385 ),
386 };
387 let synthetic = serde_json::json!({
388 "toolSpec": {
389 "name": "respond_json",
390 "description": description,
391 "inputSchema": {"json": schema},
392 }
393 });
394 if let Some(tools_arr) = body
396 .get_mut("toolConfig")
397 .and_then(|tc| tc.get_mut("tools"))
398 .and_then(|t| t.as_array_mut())
399 {
400 tools_arr.push(synthetic);
401 } else {
402 body["toolConfig"] = serde_json::json!({"tools": [synthetic]});
403 }
404 body["toolConfig"]["toolChoice"] =
406 serde_json::json!({"tool": {"name": "respond_json"}});
407 }
408 }
409
410 body
411}
412
413fn content_to_bedrock(content: &[Content]) -> Vec<serde_json::Value> {
414 content
415 .iter()
416 .filter_map(|c| match c {
417 Content::Text { text } => Some(serde_json::json!({"text": text})),
418 Content::Image { data, mime_type } => Some(serde_json::json!({
419 "image": {
420 "format": mime_type.split('/').nth(1).unwrap_or("png"),
421 "source": {"bytes": data},
422 }
423 })),
424 Content::ToolCall {
425 id,
426 name,
427 arguments,
428 } => Some(serde_json::json!({
429 "toolUse": {"toolUseId": id, "name": name, "input": arguments},
430 })),
431 Content::Thinking { .. } => None,
432 })
433 .collect()
434}
435
436#[derive(Deserialize)]
438#[serde(untagged)]
439enum BedrockEvent {
440 ContentBlockDelta {
441 #[serde(rename = "contentBlockDelta")]
442 delta: BedrockDelta,
443 },
444 ContentBlockStart {
445 #[serde(rename = "contentBlockStart")]
446 start: BedrockBlockStart,
447 },
448 ContentBlockStop {
449 #[serde(rename = "contentBlockStop")]
450 #[allow(dead_code)]
451 stop: serde_json::Value,
452 },
453 MessageStop {
454 #[serde(rename = "messageStop")]
455 stop_reason: Option<String>,
456 },
457 Metadata {
458 #[serde(rename = "metadata")]
459 usage: Option<BedrockUsage>,
460 },
461 Unknown,
462}
463
464#[derive(Deserialize)]
465struct BedrockDelta {
466 #[serde(default)]
467 text: Option<String>,
468 #[serde(default, rename = "toolUse")]
469 tool_use: Option<BedrockToolUseDelta>,
470}
471
472#[derive(Deserialize)]
473struct BedrockToolUseDelta {
474 input: String,
475}
476
477#[derive(Deserialize)]
478struct BedrockBlockStart {
479 #[serde(default, rename = "toolUse")]
480 tool_use: Option<BedrockToolUseStart>,
481}
482
483#[derive(Deserialize)]
484struct BedrockToolUseStart {
485 #[serde(rename = "toolUseId")]
486 tool_use_id: String,
487 name: String,
488}
489
490#[derive(Deserialize)]
491struct BedrockUsage {
492 #[serde(default, rename = "inputTokens")]
493 input_tokens: u64,
494 #[serde(default, rename = "outputTokens")]
495 output_tokens: u64,
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    /// Text-only request with a system prompt, one user message, no tools and
    /// a 1024-token limit.
    fn text_request() -> StreamConfig {
        let model_config = crate::provider::ModelConfig::anthropic(
            "anthropic.claude-3-sonnet-20240229-v1:0",
            "Claude Sonnet",
            "key:secret",
        );

        StreamConfig {
            model_config,
            system_prompt: "Be helpful".into(),
            messages: vec![Message::user("Hello")],
            tools: vec![],
            thinking_level: ThinkingLevel::Off,
            max_tokens: Some(1024),
            temperature: None,
            cache_config: CacheConfig::default(),
            response_format: ResponseFormat::Text,
        }
    }

    /// The body carries the messages, the system prompt and the token limit.
    #[test]
    fn test_build_bedrock_body() {
        let request = text_request();
        let body = build_bedrock_body(&request);

        assert!(body["messages"].is_array());
        assert_eq!(body["messages"][0]["role"], "user");
        assert!(body["system"].is_array());
        assert_eq!(body["inferenceConfig"]["maxTokens"], 1024);
    }

    /// Text and tool-call content map to `text` and `toolUse` blocks in order.
    #[test]
    fn test_content_to_bedrock() {
        let greeting = Content::Text {
            text: "hello".into(),
        };
        let call = Content::ToolCall {
            id: "tc-1".into(),
            name: "bash".into(),
            arguments: serde_json::json!({"command": "ls"}),
        };

        let blocks = content_to_bedrock(&[greeting, call]);

        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0]["text"], "hello");
        assert_eq!(blocks[1]["toolUse"]["name"], "bash");
    }
}