codetether_agent/provider/bedrock.rs

//! Amazon Bedrock provider implementation using the Converse API
//!
//! Supports all Bedrock foundation models via API key (bearer token) auth.
//! Uses the native Bedrock Converse API format.
//! Dynamically discovers available models via the Bedrock ListFoundationModels
//! and ListInferenceProfiles APIs.
//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
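//!
//! A minimal usage sketch (illustrative only; the crate/module paths and the
//! environment variable name are assumptions, not pinned by this file):
//!
//! ```ignore
//! use codetether_agent::provider::{BedrockProvider, Provider};
//!
//! let provider = BedrockProvider::with_region(
//!     std::env::var("AWS_BEARER_TOKEN_BEDROCK")?,
//!     "us-west-2".to_string(),
//! )?;
//! let models = provider.list_models().await?;
//! println!("{} Bedrock models available", models.len());
//! ```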

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;

const DEFAULT_REGION: &str = "us-east-1";

pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

    /// Management API URL (for listing models, not inference)
    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

    /// Resolve a short model alias to the full Bedrock model ID.
    /// Allows users to specify e.g. "claude-sonnet-4" instead of
    /// "us.anthropic.claude-sonnet-4-20250514-v1:0".
    fn resolve_model_id(model: &str) -> &str {
        match model {
            // --- Anthropic Claude ---
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            }
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
            }
            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",

            // --- Amazon Nova ---
            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "us.amazon.nova-premier-v1:0",

            // --- Meta Llama ---
            "llama-4-maverick" | "llama4-maverick" => {
                "us.meta.llama4-maverick-17b-instruct-v1:0"
            }
            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3-70b" | "llama3-70b" => "us.meta.llama3-70b-instruct-v1:0",
            "llama-3-8b" | "llama3-8b" => "us.meta.llama3-8b-instruct-v1:0",

            // --- Mistral ---
            "mistral-large-3" | "mistral-large" => "us.mistral.mistral-large-3-675b-instruct",
            "mistral-large-2402" => "us.mistral.mistral-large-2402-v1:0",
            "mistral-small" => "us.mistral.mistral-small-2402-v1:0",
            "mixtral-8x7b" => "us.mistral.mixtral-8x7b-instruct-v0:1",
            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
            "magistral-small" => "us.mistral.magistral-small-2509",

            // --- DeepSeek ---
            "deepseek-r1" => "us.deepseek.r1-v1:0",
            "deepseek-v3" | "deepseek-v3.2" => "us.deepseek.v3.2",

            // --- Cohere ---
            "command-r" => "us.cohere.command-r-v1:0",
            "command-r-plus" => "us.cohere.command-r-plus-v1:0",

            // --- Qwen ---
            "qwen3-32b" => "us.qwen.qwen3-32b-v1:0",
            "qwen3-coder" | "qwen3-coder-next" => "us.qwen.qwen3-coder-next",
            "qwen3-coder-30b" => "us.qwen.qwen3-coder-30b-a3b-v1:0",

            // --- Google Gemma ---
            "gemma-3-27b" => "us.google.gemma-3-27b-it",
            "gemma-3-12b" => "us.google.gemma-3-12b-it",
            "gemma-3-4b" => "us.google.gemma-3-4b-it",

            // --- Moonshot / Kimi ---
            "kimi-k2" | "kimi-k2-thinking" => "us.moonshot.kimi-k2-thinking",
            "kimi-k2.5" => "us.moonshotai.kimi-k2.5",

            // --- AI21 Jamba ---
            "jamba-1.5-large" => "us.ai21.jamba-1-5-large-v1:0",
            "jamba-1.5-mini" => "us.ai21.jamba-1-5-mini-v1:0",

            // --- MiniMax ---
            "minimax-m2" => "us.minimax.minimax-m2",
            "minimax-m2.1" => "us.minimax.minimax-m2.1",

            // --- NVIDIA ---
            "nemotron-nano-30b" => "us.nvidia.nemotron-nano-3-30b",
            "nemotron-nano-12b" => "us.nvidia.nemotron-nano-12b-v2",
            "nemotron-nano-9b" => "us.nvidia.nemotron-nano-9b-v2",

            // --- Z.AI / GLM ---
            "glm-4.7" => "us.zai.glm-4.7",
            "glm-4.7-flash" => "us.zai.glm-4.7-flash",

            // Pass through full model IDs unchanged
            other => other,
        }
    }

    /// Dynamically discover available models from the Bedrock API.
    /// Merges foundation models with cross-region inference profiles.
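    ///
    /// Only a handful of fields are read from each `modelSummaries` entry
    /// (illustrative shape, trimmed from the full API response):
    ///
    /// ```text
    /// { "modelId": "anthropic.claude-3-5-haiku-20241022-v1:0",
    ///   "modelName": "Claude 3.5 Haiku",
    ///   "providerName": "Anthropic",
    ///   "inputModalities": ["TEXT", "IMAGE"],
    ///   "outputModalities": ["TEXT"],
    ///   "inferenceTypesSupported": ["ON_DEMAND"],
    ///   "responseStreamingSupported": true }
    /// ```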
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

        // 1) Fetch foundation models
        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self
            .client
            .get(&fm_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            // Only include TEXT output models with ON_DEMAND inference
                            if !output_modalities.contains(&"TEXT")
                                || !inference_types.contains(&"ON_DEMAND")
                            {
                                continue;
                            }

                            // Skip non-chat models (rerankers, embedders,
                            // guardrails, speech and video models)
                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            // Non-Amazon models need the `us.` cross-region
                            // inference profile prefix
                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else {
                                format!("us.{}", model_id)
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        // 2) Fetch cross-region inference profiles (adds models like Claude Sonnet 4,
        //    Llama 3.1/3.2/3.3/4, DeepSeek R1, etc. that aren't in foundation models)
        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self
            .client
            .get(&ip_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            // Only US cross-region profiles
                            if !pid.starts_with("us.") {
                                continue;
                            }

                            // Skip already-discovered models
                            if models.contains_key(pid) {
                                continue;
                            }

                            // Skip non-text models (image generation/editing,
                            // embeddings, speech and video profiles)
                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            // Guess vision support from known model families
                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            // Extract a provider hint from the profile ID
                            // (e.g. "us.meta.llama3-3-70b-instruct-v1:0" -> "meta")
                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

    /// Estimate context window size based on model family
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

    /// Estimate max output tokens based on model family
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-sonnet-4") || id.contains("claude-3-7") {
            64_000
        } else if id.contains("claude-opus-4") {
            32_000
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

    /// Convert our generic messages to the Bedrock Converse API format.
    ///
    /// Bedrock Converse uses:
    /// - the system prompt as a top-level "system" array
    /// - messages with a "role" and a "content" array
    /// - toolUse blocks in assistant content
    /// - toolResult blocks in user content
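    ///
    /// For example, a user question followed by an assistant tool call and its
    /// result converts to (illustrative; the tool name and IDs are made up):
    ///
    /// ```text
    /// system:   [{"text": "You are a helpful assistant."}]
    /// messages: [
    ///   {"role": "user", "content": [{"text": "What's the weather in Oslo?"}]},
    ///   {"role": "assistant", "content": [{"toolUse": {
    ///       "toolUseId": "call_1", "name": "get_weather",
    ///       "input": {"city": "Oslo"}}}]},
    ///   {"role": "user", "content": [{"toolResult": {
    ///       "toolUseId": "call_1", "content": [{"text": "4°C, cloudy"}],
    ///       "status": "success"}}]}
    /// ]
    /// ```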
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            // Non-text parts (e.g. images) are not forwarded yet
                            _ => {}
                        }
                    }
                    if !content_parts.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                // Arguments arrive as a JSON string; wrap the raw
                                // string if it fails to parse
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    // Keep assistant content non-empty; pad with empty text
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": ""}));
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Tool results go into a user message with toolResult blocks
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

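    /// Convert generic tool definitions to Bedrock `toolSpec` entries.
    /// Each tool becomes (illustrative shape):
    ///
    /// ```text
    /// {"toolSpec": {"name": "...", "description": "...",
    ///               "inputSchema": {"json": <JSON Schema>}}}
    /// ```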
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

// Bedrock Converse API response types

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        // Build inferenceConfig; maxTokens defaults to 8192 when unset
        let mut inference_config = json!({});
        if let Some(max_tokens) = request.max_tokens {
            inference_config["maxTokens"] = json!(max_tokens);
        } else {
            inference_config["maxTokens"] = json!(8192);
        }
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }
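
        // At this point the request body looks like (illustrative values):
        //   {"messages": [...], "system": [{"text": "..."}],
        //    "inferenceConfig": {"maxTokens": 8192, "temperature": 0.2},
        //    "toolConfig": {"tools": [...]}}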

        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // Truncate by characters, not bytes, so multi-byte UTF-8 in the
            // error body cannot cause a slice panic
            let snippet: String = text.chars().take(500).collect();
            anyhow::bail!("Bedrock API error: {} {}", status, snippet);
        }

        let response: ConverseResponse = serde_json::from_str(&text).with_context(|| {
            let snippet: String = text.chars().take(300).collect();
            format!("Failed to parse Bedrock response: {}", snippet)
        })?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Fall back to non-streaming for now: fetch the full completion, then
        // emit its text as a single chunk. Tool calls and usage are dropped
        // here, so streaming callers should not rely on them yet.
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}
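
// A minimal test sketch for the alias table above; the expected strings
// mirror `resolve_model_id` and should be updated alongside it.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolves_known_aliases() {
        assert_eq!(
            BedrockProvider::resolve_model_id("claude-sonnet-4"),
            "us.anthropic.claude-sonnet-4-20250514-v1:0"
        );
        assert_eq!(
            BedrockProvider::resolve_model_id("nova-pro"),
            "amazon.nova-pro-v1:0"
        );
    }

    #[test]
    fn passes_through_full_model_ids() {
        let full = "us.meta.llama3-3-70b-instruct-v1:0";
        assert_eq!(BedrockProvider::resolve_model_id(full), full);
    }
}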