Skip to main content

codetether_agent/provider/
bedrock.rs

1//! Amazon Bedrock provider implementation using the Converse API
2//!
3//! Supports all Bedrock foundation models via either:
4//! - AWS SigV4 signing (standard AWS credentials from env/file/profile)
5//! - Bearer token auth (API Gateway / Vault-managed keys)
6//!
7//! Uses the native Bedrock Converse API format.
8//! Dynamically discovers available models via the Bedrock ListFoundationModels
9//! and ListInferenceProfiles APIs.
10//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
11
12use super::{
13    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
14    Role, StreamChunk, ToolDefinition, Usage,
15};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use hmac::{Hmac, Mac};
19use reqwest::Client;
20use serde::Deserialize;
21use serde_json::{Value, json};
22use sha2::{Digest, Sha256};
23use std::collections::HashMap;
24
/// Region used when the caller does not supply one.
pub const DEFAULT_REGION: &str = "us-east-1";

/// AWS credentials for SigV4 signing.
///
/// `Debug` is implemented by hand rather than derived so the secret access
/// key and session token never end up in log output; only the access key ID
/// (an identifier, not a secret) is printed.
#[derive(Clone)]
pub struct AwsCredentials {
    pub access_key_id: String,
    pub secret_access_key: String,
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl AwsCredentials {
    /// Load credentials from the environment, falling back to the shared
    /// credentials file.
    ///
    /// 1. `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` (both non-empty), with
    ///    an optional non-empty `AWS_SESSION_TOKEN`.
    /// 2. Otherwise the active profile (`AWS_PROFILE`, or `default` when unset
    ///    or empty) from the shared credentials file:
    ///    `AWS_SHARED_CREDENTIALS_FILE` if set, else `~/.aws/credentials`.
    ///
    /// Returns `None` when neither source yields a non-empty key pair.
    pub fn from_environment() -> Option<Self> {
        let key_id = Self::non_empty_env("AWS_ACCESS_KEY_ID");
        let secret = Self::non_empty_env("AWS_SECRET_ACCESS_KEY");
        if let (Some(access_key_id), Some(secret_access_key)) = (key_id, secret) {
            return Some(Self {
                access_key_id,
                secret_access_key,
                session_token: Self::non_empty_env("AWS_SESSION_TOKEN"),
            });
        }

        Self::from_credentials_file(&Self::active_profile())
    }

    /// Parse the shared credentials file for `profile` (section `[profile]`).
    ///
    /// Keys with blank values are treated as absent, so an empty
    /// `aws_session_token =` line does not produce an empty signed header.
    fn from_credentials_file(profile: &str) -> Option<Self> {
        let path = Self::aws_file_path("AWS_SHARED_CREDENTIALS_FILE", "credentials")?;
        let content = std::fs::read_to_string(path).ok()?;
        Self::from_ini_section(&Self::parse_ini_section(&content, profile))
    }

    /// Build credentials from the parsed `key -> value` pairs of one profile.
    ///
    /// Requires `aws_access_key_id` and `aws_secret_access_key`;
    /// `aws_session_token` is optional.
    fn from_ini_section(section: &HashMap<String, String>) -> Option<Self> {
        Some(Self {
            access_key_id: section.get("aws_access_key_id")?.clone(),
            secret_access_key: section.get("aws_secret_access_key")?.clone(),
            session_token: section.get("aws_session_token").cloned(),
        })
    }

    /// Detect the region from `AWS_REGION` / `AWS_DEFAULT_REGION` (first
    /// non-empty wins), then from the active profile in the AWS config file
    /// (`AWS_CONFIG_FILE` if set, else `~/.aws/config`).
    pub fn detect_region() -> Option<String> {
        if let Some(region) = Self::non_empty_env("AWS_REGION")
            .or_else(|| Self::non_empty_env("AWS_DEFAULT_REGION"))
        {
            return Some(region);
        }

        // In the config file the default profile is `[default]`; every other
        // profile is `[profile <name>]`.
        let profile = Self::active_profile();
        let section = if profile == "default" {
            profile
        } else {
            format!("profile {}", profile)
        };
        let path = Self::aws_file_path("AWS_CONFIG_FILE", "config")?;
        let content = std::fs::read_to_string(path).ok()?;
        Self::parse_ini_section(&content, &section).remove("region")
    }

    /// Collect the non-empty `key = value` pairs of the INI section whose
    /// header text (between `[` and `]`, surrounding whitespace ignored)
    /// equals `section`.
    ///
    /// Lines starting with `#` or `;` are comments. If a key repeats, the
    /// last occurrence wins.
    fn parse_ini_section(content: &str, section: &str) -> HashMap<String, String> {
        let mut values = HashMap::new();
        let mut in_section = false;

        for line in content.lines() {
            let line = line.trim();
            if line.starts_with('#') || line.starts_with(';') {
                continue;
            }
            if line.starts_with('[') {
                let header = line
                    .strip_prefix('[')
                    .and_then(|rest| rest.strip_suffix(']'))
                    .map(str::trim);
                in_section = header == Some(section);
                continue;
            }
            if !in_section {
                continue;
            }
            if let Some((key, value)) = line.split_once('=') {
                let value = value.trim();
                if !value.is_empty() {
                    values.insert(key.trim().to_string(), value.to_string());
                }
            }
        }

        values
    }

    /// Profile name from `AWS_PROFILE`, or `default` when unset or empty.
    fn active_profile() -> String {
        Self::non_empty_env("AWS_PROFILE").unwrap_or_else(|| "default".to_string())
    }

    /// Value of env var `name`, treating unset, non-Unicode and empty alike.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }

    /// Path of an AWS shared file: the non-empty `override_var` env var if
    /// set, else `$HOME/.aws/<file_name>` (falling back to `USERPROFILE`).
    fn aws_file_path(override_var: &str, file_name: &str) -> Option<std::path::PathBuf> {
        if let Some(path) = Self::non_empty_env(override_var) {
            return Some(std::path::PathBuf::from(path));
        }
        let home = std::env::var("HOME")
            .or_else(|_| std::env::var("USERPROFILE"))
            .ok()?;
        Some(std::path::Path::new(&home).join(".aws").join(file_name))
    }
}
151
152/// Authentication mode for the Bedrock provider.
153#[derive(Debug, Clone)]
154pub enum BedrockAuth {
155    /// Standard AWS SigV4 signing with IAM credentials
156    SigV4(AwsCredentials),
157    /// Bearer token (API Gateway or custom auth layer)
158    BearerToken(String),
159}
160
161pub struct BedrockProvider {
162    client: Client,
163    auth: BedrockAuth,
164    region: String,
165}
166
167impl std::fmt::Debug for BedrockProvider {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("BedrockProvider")
170            .field(
171                "auth",
172                &match &self.auth {
173                    BedrockAuth::SigV4(_) => "SigV4",
174                    BedrockAuth::BearerToken(_) => "BearerToken",
175                },
176            )
177            .field("region", &self.region)
178            .finish()
179    }
180}
181
182impl BedrockProvider {
183    /// Create from a bearer token (API Gateway / Vault key).
184    pub fn new(api_key: String) -> Result<Self> {
185        Self::with_region(api_key, DEFAULT_REGION.to_string())
186    }
187
188    /// Create from a bearer token with a specific region.
189    pub fn with_region(api_key: String, region: String) -> Result<Self> {
190        tracing::debug!(
191            provider = "bedrock",
192            region = %region,
193            auth = "bearer_token",
194            "Creating Bedrock provider"
195        );
196        Ok(Self {
197            client: Client::new(),
198            auth: BedrockAuth::BearerToken(api_key),
199            region,
200        })
201    }
202
203    /// Create from AWS IAM credentials (SigV4 signing).
204    pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
205        tracing::debug!(
206            provider = "bedrock",
207            region = %region,
208            auth = "sigv4",
209            "Creating Bedrock provider with AWS credentials"
210        );
211        Ok(Self {
212            client: Client::new(),
213            auth: BedrockAuth::SigV4(credentials),
214            region,
215        })
216    }
217
218    /// Public wrapper for sending a signed Converse API request.
219    /// Used by the thinker backend.
220    pub async fn send_converse_request(&self, url: &str, body: &[u8]) -> Result<reqwest::Response> {
221        self.send_request("POST", url, Some(body), "bedrock-runtime")
222            .await
223    }
224
225    fn validate_auth(&self) -> Result<()> {
226        match &self.auth {
227            BedrockAuth::BearerToken(key) => {
228                if key.is_empty() {
229                    anyhow::bail!("Bedrock API key is empty");
230                }
231            }
232            BedrockAuth::SigV4(creds) => {
233                if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
234                    anyhow::bail!("AWS credentials are incomplete");
235                }
236            }
237        }
238        Ok(())
239    }
240
241    fn base_url(&self) -> String {
242        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
243    }
244
245    /// Management API URL (for listing models, not inference)
246    fn management_url(&self) -> String {
247        format!("https://bedrock.{}.amazonaws.com", self.region)
248    }
249
250    // ── AWS SigV4 signing helpers ──────────────────────────────────────
251
252    fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
253        let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
254        mac.update(data);
255        mac.finalize().into_bytes().to_vec()
256    }
257
258    fn sha256_hex(data: &[u8]) -> String {
259        let mut hasher = Sha256::new();
260        hasher.update(data);
261        hex::encode(hasher.finalize())
262    }
263
264    /// Build a SigV4-signed request and send it.
265    async fn send_signed_request(
266        &self,
267        method: &str,
268        url: &str,
269        body: &[u8],
270        service: &str,
271    ) -> Result<reqwest::Response> {
272        let creds = match &self.auth {
273            BedrockAuth::SigV4(c) => c,
274            BedrockAuth::BearerToken(_) => {
275                anyhow::bail!("send_signed_request called with bearer token auth");
276            }
277        };
278
279        let now = chrono::Utc::now();
280        let datestamp = now.format("%Y%m%d").to_string();
281        let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
282
283        // Parse URL components
284        let host_start = url.find("://").map(|i| i + 3).unwrap_or(0);
285        let after_host = url[host_start..]
286            .find('/')
287            .map(|i| host_start + i)
288            .unwrap_or(url.len());
289        let host = url[host_start..after_host].to_string();
290        let path_and_query = &url[after_host..];
291        let (canonical_uri, canonical_querystring) = match path_and_query.split_once('?') {
292            Some((p, q)) => (p.to_string(), q.to_string()),
293            None => (path_and_query.to_string(), String::new()),
294        };
295
296        let payload_hash = Self::sha256_hex(body);
297
298        // Build canonical headers (must be sorted)
299        let mut headers_map: Vec<(&str, String)> = vec![
300            ("content-type", "application/json".to_string()),
301            ("host", host.clone()),
302            ("x-amz-date", amz_date.clone()),
303        ];
304        if let Some(token) = &creds.session_token {
305            headers_map.push(("x-amz-security-token", token.clone()));
306        }
307        headers_map.sort_by_key(|(k, _)| *k);
308
309        let canonical_headers: String = headers_map
310            .iter()
311            .map(|(k, v)| format!("{}:{}", k, v))
312            .collect::<Vec<_>>()
313            .join("\n")
314            + "\n";
315
316        let signed_headers: String = headers_map
317            .iter()
318            .map(|(k, _)| *k)
319            .collect::<Vec<_>>()
320            .join(";");
321
322        let canonical_request = format!(
323            "{}\n{}\n{}\n{}\n{}\n{}",
324            method,
325            canonical_uri,
326            canonical_querystring,
327            canonical_headers,
328            signed_headers,
329            payload_hash
330        );
331
332        let credential_scope = format!("{}/{}/{}/aws4_request", datestamp, self.region, service);
333
334        let string_to_sign = format!(
335            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
336            amz_date,
337            credential_scope,
338            Self::sha256_hex(canonical_request.as_bytes())
339        );
340
341        // Derive signing key
342        let k_date = Self::hmac_sha256(
343            format!("AWS4{}", creds.secret_access_key).as_bytes(),
344            datestamp.as_bytes(),
345        );
346        let k_region = Self::hmac_sha256(&k_date, self.region.as_bytes());
347        let k_service = Self::hmac_sha256(&k_region, service.as_bytes());
348        let k_signing = Self::hmac_sha256(&k_service, b"aws4_request");
349
350        let signature = hex::encode(Self::hmac_sha256(&k_signing, string_to_sign.as_bytes()));
351
352        let authorization = format!(
353            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
354            creds.access_key_id, credential_scope, signed_headers, signature
355        );
356
357        let mut req = self
358            .client
359            .request(method.parse().unwrap_or(reqwest::Method::POST), url)
360            .header("content-type", "application/json")
361            .header("host", &host)
362            .header("x-amz-date", &amz_date)
363            .header("x-amz-content-sha256", &payload_hash)
364            .header("authorization", &authorization);
365
366        if let Some(token) = &creds.session_token {
367            req = req.header("x-amz-security-token", token);
368        }
369
370        if method == "POST" || method == "PUT" {
371            req = req.body(body.to_vec());
372        }
373
374        req.send()
375            .await
376            .context("Failed to send signed request to Bedrock")
377    }
378
379    /// Send an HTTP request using whichever auth mode is configured.
380    async fn send_request(
381        &self,
382        method: &str,
383        url: &str,
384        body: Option<&[u8]>,
385        service: &str,
386    ) -> Result<reqwest::Response> {
387        match &self.auth {
388            BedrockAuth::SigV4(_) => {
389                self.send_signed_request(method, url, body.unwrap_or(b""), service)
390                    .await
391            }
392            BedrockAuth::BearerToken(token) => {
393                let mut req = self
394                    .client
395                    .request(method.parse().unwrap_or(reqwest::Method::GET), url)
396                    .bearer_auth(token)
397                    .header("content-type", "application/json")
398                    .header("accept", "application/json");
399
400                if let Some(b) = body {
401                    req = req.body(b.to_vec());
402                }
403
404                req.send()
405                    .await
406                    .context("Failed to send request to Bedrock")
407            }
408        }
409    }
410
411    /// Resolve a short model alias to the full Bedrock model ID.
412    /// Allows users to specify e.g. "claude-sonnet-4" instead of
413    /// "us.anthropic.claude-sonnet-4-20250514-v1:0".
414    fn resolve_model_id(model: &str) -> &str {
415        match model {
416            // --- Anthropic Claude (verified via AWS CLI) ---
417            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
418            "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
419            "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
420            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
421            "claude-sonnet-4.6" | "claude-4.6-sonnet" | "claude-sonnet-4-6" => {
422                "us.anthropic.claude-sonnet-4-6-v1:0"
423            }
424            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
425                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
426            }
427            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
428            "claude-haiku-4.5" | "claude-4.5-haiku" => {
429                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
430            }
431            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
432                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
433            }
434            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
435                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
436            }
437            "claude-3.5-haiku" | "claude-haiku-3.5" => {
438                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
439            }
440            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
441                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
442            }
443            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
444            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
445            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",
446
447            // --- Amazon Nova ---
448            "nova-pro" => "amazon.nova-pro-v1:0",
449            "nova-lite" => "amazon.nova-lite-v1:0",
450            "nova-micro" => "amazon.nova-micro-v1:0",
451            "nova-premier" => "us.amazon.nova-premier-v1:0",
452
453            // --- Meta Llama ---
454            "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
455            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
456            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
457            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
458            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
459            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
460            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
461            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
462            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
463            "llama-3-70b" | "llama3-70b" => "meta.llama3-70b-instruct-v1:0",
464            "llama-3-8b" | "llama3-8b" => "meta.llama3-8b-instruct-v1:0",
465
466            // --- Mistral (mix of ON_DEMAND and INFERENCE_PROFILE) ---
467            "mistral-large-3" | "mistral-large" => "mistral.mistral-large-3-675b-instruct",
468            "mistral-large-2402" => "mistral.mistral-large-2402-v1:0",
469            "mistral-small" => "mistral.mistral-small-2402-v1:0",
470            "mixtral-8x7b" => "mistral.mixtral-8x7b-instruct-v0:1",
471            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
472            "magistral-small" => "mistral.magistral-small-2509",
473
474            // --- DeepSeek ---
475            "deepseek-r1" => "us.deepseek.r1-v1:0",
476            "deepseek-v3" | "deepseek-v3.2" => "deepseek.v3.2",
477
478            // --- Cohere (ON_DEMAND only, no us. prefix) ---
479            "command-r" => "cohere.command-r-v1:0",
480            "command-r-plus" => "cohere.command-r-plus-v1:0",
481
482            // --- Qwen (ON_DEMAND only, no us. prefix) ---
483            "qwen3-32b" => "qwen.qwen3-32b-v1:0",
484            "qwen3-coder" | "qwen3-coder-next" => "qwen.qwen3-coder-next",
485            "qwen3-coder-30b" => "qwen.qwen3-coder-30b-a3b-v1:0",
486
487            // --- Google Gemma (ON_DEMAND only, no us. prefix) ---
488            "gemma-3-27b" => "google.gemma-3-27b-it",
489            "gemma-3-12b" => "google.gemma-3-12b-it",
490            "gemma-3-4b" => "google.gemma-3-4b-it",
491
492            // --- Moonshot / Kimi (ON_DEMAND only, no us. prefix) ---
493            "kimi-k2" | "kimi-k2-thinking" => "moonshot.kimi-k2-thinking",
494            "kimi-k2.5" => "moonshotai.kimi-k2.5",
495
496            // --- AI21 Jamba (ON_DEMAND only, no us. prefix) ---
497            "jamba-1.5-large" => "ai21.jamba-1-5-large-v1:0",
498            "jamba-1.5-mini" => "ai21.jamba-1-5-mini-v1:0",
499
500            // --- MiniMax (ON_DEMAND only, no us. prefix) ---
501            "minimax-m2" => "minimax.minimax-m2",
502            "minimax-m2.1" => "minimax.minimax-m2.1",
503
504            // --- NVIDIA (ON_DEMAND only, no us. prefix) ---
505            "nemotron-nano-30b" => "nvidia.nemotron-nano-3-30b",
506            "nemotron-nano-12b" => "nvidia.nemotron-nano-12b-v2",
507            "nemotron-nano-9b" => "nvidia.nemotron-nano-9b-v2",
508
509            // --- Z.AI / GLM (ON_DEMAND only, no us. prefix) ---
510            "glm-5" => "zai.glm-5",
511            "glm-4.7" => "zai.glm-4.7",
512            "glm-4.7-flash" => "zai.glm-4.7-flash",
513
514            // Pass through full model IDs unchanged
515            other => other,
516        }
517    }
518
519    /// Dynamically discover available models from the Bedrock API.
520    /// Merges foundation models with cross-region inference profiles.
521    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
522        let mut models: HashMap<String, ModelInfo> = HashMap::new();
523
524        // 1) Fetch foundation models
525        let fm_url = format!("{}/foundation-models", self.management_url());
526        let fm_resp = self.send_request("GET", &fm_url, None, "bedrock").await;
527
528        if let Ok(resp) = fm_resp {
529            if resp.status().is_success() {
530                if let Ok(data) = resp.json::<Value>().await {
531                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
532                        for m in summaries {
533                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
534                            let model_name =
535                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
536                            let provider_name =
537                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");
538
539                            let output_modalities: Vec<&str> = m
540                                .get("outputModalities")
541                                .and_then(|v| v.as_array())
542                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
543                                .unwrap_or_default();
544
545                            let input_modalities: Vec<&str> = m
546                                .get("inputModalities")
547                                .and_then(|v| v.as_array())
548                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
549                                .unwrap_or_default();
550
551                            let inference_types: Vec<&str> = m
552                                .get("inferenceTypesSupported")
553                                .and_then(|v| v.as_array())
554                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
555                                .unwrap_or_default();
556
557                            // Only include TEXT output models with ON_DEMAND or INFERENCE_PROFILE inference
558                            if !output_modalities.contains(&"TEXT")
559                                || (!inference_types.contains(&"ON_DEMAND")
560                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
561                            {
562                                continue;
563                            }
564
565                            // Skip non-chat models
566                            let name_lower = model_name.to_lowercase();
567                            if name_lower.contains("rerank")
568                                || name_lower.contains("embed")
569                                || name_lower.contains("safeguard")
570                                || name_lower.contains("sonic")
571                                || name_lower.contains("pegasus")
572                            {
573                                continue;
574                            }
575
576                            let streaming = m
577                                .get("responseStreamingSupported")
578                                .and_then(|v| v.as_bool())
579                                .unwrap_or(false);
580                            let vision = input_modalities.contains(&"IMAGE");
581
582                            // Models with INFERENCE_PROFILE support use cross-region
583                            // us. prefix; ON_DEMAND-only models use bare model IDs.
584                            // Amazon models never get the prefix.
585                            let actual_id = if model_id.starts_with("amazon.") {
586                                model_id.to_string()
587                            } else if inference_types.contains(&"INFERENCE_PROFILE") {
588                                format!("us.{}", model_id)
589                            } else {
590                                model_id.to_string()
591                            };
592
593                            let display_name = format!("{} (Bedrock)", model_name);
594
595                            models.insert(
596                                actual_id.clone(),
597                                ModelInfo {
598                                    id: actual_id,
599                                    name: display_name,
600                                    provider: "bedrock".to_string(),
601                                    context_window: Self::estimate_context_window(
602                                        model_id,
603                                        provider_name,
604                                    ),
605                                    max_output_tokens: Some(Self::estimate_max_output(
606                                        model_id,
607                                        provider_name,
608                                    )),
609                                    supports_vision: vision,
610                                    supports_tools: true,
611                                    supports_streaming: streaming,
612                                    input_cost_per_million: None,
613                                    output_cost_per_million: None,
614                                },
615                            );
616                        }
617                    }
618                }
619            }
620        }
621
622        // 2) Fetch cross-region inference profiles (adds models like Claude Sonnet 4,
623        //    Llama 3.1/3.2/3.3/4, DeepSeek R1, etc. that aren't in foundation models)
624        let ip_url = format!(
625            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
626            self.management_url()
627        );
628        let ip_resp = self.send_request("GET", &ip_url, None, "bedrock").await;
629
630        if let Ok(resp) = ip_resp {
631            if resp.status().is_success() {
632                if let Ok(data) = resp.json::<Value>().await {
633                    if let Some(profiles) = data
634                        .get("inferenceProfileSummaries")
635                        .and_then(|v| v.as_array())
636                    {
637                        for p in profiles {
638                            let pid = p
639                                .get("inferenceProfileId")
640                                .and_then(|v| v.as_str())
641                                .unwrap_or("");
642                            let pname = p
643                                .get("inferenceProfileName")
644                                .and_then(|v| v.as_str())
645                                .unwrap_or("");
646
647                            // Only US cross-region profiles
648                            if !pid.starts_with("us.") {
649                                continue;
650                            }
651
652                            // Skip already-discovered models
653                            if models.contains_key(pid) {
654                                continue;
655                            }
656
657                            // Skip non-text models
658                            let name_lower = pname.to_lowercase();
659                            if name_lower.contains("image")
660                                || name_lower.contains("stable ")
661                                || name_lower.contains("upscale")
662                                || name_lower.contains("embed")
663                                || name_lower.contains("marengo")
664                                || name_lower.contains("outpaint")
665                                || name_lower.contains("inpaint")
666                                || name_lower.contains("erase")
667                                || name_lower.contains("recolor")
668                                || name_lower.contains("replace")
669                                || name_lower.contains("style ")
670                                || name_lower.contains("background")
671                                || name_lower.contains("sketch")
672                                || name_lower.contains("control")
673                                || name_lower.contains("transfer")
674                                || name_lower.contains("sonic")
675                                || name_lower.contains("pegasus")
676                                || name_lower.contains("rerank")
677                            {
678                                continue;
679                            }
680
681                            // Guess vision from known model families
682                            let vision = pid.contains("llama3-2-11b")
683                                || pid.contains("llama3-2-90b")
684                                || pid.contains("pixtral")
685                                || pid.contains("claude-3")
686                                || pid.contains("claude-sonnet-4")
687                                || pid.contains("claude-opus-4")
688                                || pid.contains("claude-haiku-4");
689
690                            let display_name = pname.replace("US ", "");
691                            let display_name = format!("{} (Bedrock)", display_name.trim());
692
693                            // Extract provider hint from model ID
694                            let provider_hint = pid
695                                .strip_prefix("us.")
696                                .unwrap_or(pid)
697                                .split('.')
698                                .next()
699                                .unwrap_or("");
700
701                            models.insert(
702                                pid.to_string(),
703                                ModelInfo {
704                                    id: pid.to_string(),
705                                    name: display_name,
706                                    provider: "bedrock".to_string(),
707                                    context_window: Self::estimate_context_window(
708                                        pid,
709                                        provider_hint,
710                                    ),
711                                    max_output_tokens: Some(Self::estimate_max_output(
712                                        pid,
713                                        provider_hint,
714                                    )),
715                                    supports_vision: vision,
716                                    supports_tools: true,
717                                    supports_streaming: true,
718                                    input_cost_per_million: None,
719                                    output_cost_per_million: None,
720                                },
721                            );
722                        }
723                    }
724                }
725            }
726        }
727
728        let mut result: Vec<ModelInfo> = models.into_values().collect();
729        result.sort_by(|a, b| a.id.cmp(&b.id));
730
731        tracing::info!(
732            provider = "bedrock",
733            model_count = result.len(),
734            "Discovered Bedrock models dynamically"
735        );
736
737        Ok(result)
738    }
739
740    /// Estimate context window size based on model family
741    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
742        let id = model_id.to_lowercase();
743        if id.contains("anthropic") || id.contains("claude") {
744            200_000
745        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
746        {
747            300_000
748        } else if id.contains("nova-micro") || id.contains("nova-2") {
749            128_000
750        } else if id.contains("deepseek") {
751            128_000
752        } else if id.contains("llama4") {
753            256_000
754        } else if id.contains("llama3") {
755            128_000
756        } else if id.contains("mistral-large-3") || id.contains("magistral") {
757            128_000
758        } else if id.contains("mistral") {
759            32_000
760        } else if id.contains("qwen") {
761            128_000
762        } else if id.contains("kimi") {
763            128_000
764        } else if id.contains("jamba") {
765            256_000
766        } else if id.contains("glm") {
767            128_000
768        } else if id.contains("minimax") {
769            128_000
770        } else if id.contains("gemma") {
771            128_000
772        } else if id.contains("cohere") || id.contains("command") {
773            128_000
774        } else if id.contains("nemotron") {
775            128_000
776        } else if provider.to_lowercase().contains("amazon") {
777            128_000
778        } else {
779            32_000
780        }
781    }
782
783    /// Estimate max output tokens based on model family
784    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
785        let id = model_id.to_lowercase();
786        if id.contains("claude-opus-4-6") {
787            32_000
788        } else if id.contains("claude-opus-4-5") {
789            32_000
790        } else if id.contains("claude-opus-4-1") {
791            32_000
792        } else if id.contains("claude-sonnet-4-6") {
793            128_000
794        } else if id.contains("claude-sonnet-4-5")
795            || id.contains("claude-sonnet-4")
796            || id.contains("claude-3-7")
797        {
798            64_000
799        } else if id.contains("claude-haiku-4-5") {
800            16_384
801        } else if id.contains("claude-opus-4") {
802            32_000
803        } else if id.contains("claude") {
804            8_192
805        } else if id.contains("nova") {
806            5_000
807        } else if id.contains("deepseek") {
808            16_384
809        } else if id.contains("llama4") {
810            16_384
811        } else if id.contains("llama") {
812            4_096
813        } else if id.contains("mistral-large-3") {
814            16_384
815        } else if id.contains("mistral") || id.contains("mixtral") {
816            8_192
817        } else if id.contains("qwen") {
818            8_192
819        } else if id.contains("kimi") {
820            8_192
821        } else if id.contains("jamba") {
822            4_096
823        } else {
824            4_096
825        }
826    }
827
828    /// Convert our generic messages to Bedrock Converse API format.
829    ///
830    /// Bedrock Converse uses:
831    /// - system prompt as a top-level "system" array
832    /// - messages with "role" and "content" array
833    /// - tool_use blocks in assistant content
834    /// - toolResult blocks in user content
835    ///
836    /// IMPORTANT: Bedrock requires strict role alternation (user/assistant).
837    /// Consecutive Role::Tool messages must be merged into a single "user"
838    /// message so all toolResult blocks for a given assistant turn appear
839    /// together. Consecutive same-role messages are also merged to prevent
840    /// validation errors.
841    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
842        let mut system_parts: Vec<Value> = Vec::new();
843        let mut api_messages: Vec<Value> = Vec::new();
844
845        for msg in messages {
846            match msg.role {
847                Role::System => {
848                    let text: String = msg
849                        .content
850                        .iter()
851                        .filter_map(|p| match p {
852                            ContentPart::Text { text } => Some(text.clone()),
853                            _ => None,
854                        })
855                        .collect::<Vec<_>>()
856                        .join("\n");
857                    if !text.trim().is_empty() {
858                        system_parts.push(json!({"text": text}));
859                    }
860                }
861                Role::User => {
862                    let mut content_parts: Vec<Value> = Vec::new();
863                    for part in &msg.content {
864                        match part {
865                            ContentPart::Text { text } => {
866                                if !text.trim().is_empty() {
867                                    content_parts.push(json!({"text": text}));
868                                }
869                            }
870                            _ => {}
871                        }
872                    }
873                    if !content_parts.is_empty() {
874                        // Merge into previous user message if the last message is also "user"
875                        if let Some(last) = api_messages.last_mut() {
876                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
877                                if let Some(arr) =
878                                    last.get_mut("content").and_then(|c| c.as_array_mut())
879                                {
880                                    arr.extend(content_parts);
881                                    continue;
882                                }
883                            }
884                        }
885                        api_messages.push(json!({
886                            "role": "user",
887                            "content": content_parts
888                        }));
889                    }
890                }
891                Role::Assistant => {
892                    let mut content_parts: Vec<Value> = Vec::new();
893                    for part in &msg.content {
894                        match part {
895                            ContentPart::Text { text } => {
896                                if !text.trim().is_empty() {
897                                    content_parts.push(json!({"text": text}));
898                                }
899                            }
900                            ContentPart::ToolCall {
901                                id,
902                                name,
903                                arguments,
904                                ..
905                            } => {
906                                let input: Value = serde_json::from_str(arguments)
907                                    .unwrap_or_else(|_| json!({"raw": arguments}));
908                                content_parts.push(json!({
909                                    "toolUse": {
910                                        "toolUseId": id,
911                                        "name": name,
912                                        "input": input
913                                    }
914                                }));
915                            }
916                            _ => {}
917                        }
918                    }
919                    // Bedrock rejects whitespace-only text blocks; if the assistant message has
920                    // no usable content (e.g. thinking-only), drop it from the request.
921                    if content_parts.is_empty() {
922                        continue;
923                    }
924                    // Merge into previous assistant message if consecutive
925                    if let Some(last) = api_messages.last_mut() {
926                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
927                            if let Some(arr) =
928                                last.get_mut("content").and_then(|c| c.as_array_mut())
929                            {
930                                arr.extend(content_parts);
931                                continue;
932                            }
933                        }
934                    }
935                    api_messages.push(json!({
936                        "role": "assistant",
937                        "content": content_parts
938                    }));
939                }
940                Role::Tool => {
941                    // Tool results must be in a "user" message with toolResult blocks.
942                    // Merge into the previous user message if one exists (handles
943                    // consecutive Tool messages being collapsed into one user turn).
944                    let mut content_parts: Vec<Value> = Vec::new();
945                    for part in &msg.content {
946                        if let ContentPart::ToolResult {
947                            tool_call_id,
948                            content,
949                        } = part
950                        {
951                            let content = if content.trim().is_empty() {
952                                "(empty tool result)".to_string()
953                            } else {
954                                content.clone()
955                            };
956                            content_parts.push(json!({
957                                "toolResult": {
958                                    "toolUseId": tool_call_id,
959                                    "content": [{"text": content}],
960                                    "status": "success"
961                                }
962                            }));
963                        }
964                    }
965                    if !content_parts.is_empty() {
966                        // Merge into previous user message (from earlier Tool messages)
967                        if let Some(last) = api_messages.last_mut() {
968                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
969                                if let Some(arr) =
970                                    last.get_mut("content").and_then(|c| c.as_array_mut())
971                                {
972                                    arr.extend(content_parts);
973                                    continue;
974                                }
975                            }
976                        }
977                        api_messages.push(json!({
978                            "role": "user",
979                            "content": content_parts
980                        }));
981                    }
982                }
983            }
984        }
985
986        (system_parts, api_messages)
987    }
988
989    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
990        tools
991            .iter()
992            .map(|t| {
993                json!({
994                    "toolSpec": {
995                        "name": t.name,
996                        "description": t.description,
997                        "inputSchema": {
998                            "json": t.parameters
999                        }
1000                    }
1001                })
1002            })
1003            .collect()
1004    }
1005}
1006
1007/// Bedrock Converse API response types
1008
1009#[derive(Debug, Deserialize)]
1010#[serde(rename_all = "camelCase")]
1011struct ConverseResponse {
1012    output: ConverseOutput,
1013    #[serde(default)]
1014    stop_reason: Option<String>,
1015    #[serde(default)]
1016    usage: Option<ConverseUsage>,
1017}
1018
1019#[derive(Debug, Deserialize)]
1020struct ConverseOutput {
1021    message: ConverseMessage,
1022}
1023
1024#[derive(Debug, Deserialize)]
1025struct ConverseMessage {
1026    #[allow(dead_code)]
1027    role: String,
1028    content: Vec<ConverseContent>,
1029}
1030
1031#[derive(Debug, Deserialize)]
1032#[serde(untagged)]
1033enum ConverseContent {
1034    ReasoningContent {
1035        #[serde(rename = "reasoningContent")]
1036        reasoning_content: ReasoningContentBlock,
1037    },
1038    Text {
1039        text: String,
1040    },
1041    ToolUse {
1042        #[serde(rename = "toolUse")]
1043        tool_use: ConverseToolUse,
1044    },
1045}
1046
1047#[derive(Debug, Deserialize)]
1048#[serde(rename_all = "camelCase")]
1049struct ReasoningContentBlock {
1050    reasoning_text: ReasoningText,
1051}
1052
1053#[derive(Debug, Deserialize)]
1054struct ReasoningText {
1055    text: String,
1056}
1057
1058#[derive(Debug, Deserialize)]
1059#[serde(rename_all = "camelCase")]
1060struct ConverseToolUse {
1061    tool_use_id: String,
1062    name: String,
1063    input: Value,
1064}
1065
1066#[derive(Debug, Deserialize)]
1067#[serde(rename_all = "camelCase")]
1068struct ConverseUsage {
1069    #[serde(default)]
1070    input_tokens: usize,
1071    #[serde(default)]
1072    output_tokens: usize,
1073    #[serde(default)]
1074    total_tokens: usize,
1075}
1076
1077#[derive(Debug, Deserialize)]
1078struct BedrockError {
1079    message: String,
1080}
1081
1082#[async_trait]
1083impl Provider for BedrockProvider {
1084    fn name(&self) -> &str {
1085        "bedrock"
1086    }
1087
1088    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
1089        self.validate_auth()?;
1090        self.discover_models().await
1091    }
1092
1093    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
1094        let model_id = Self::resolve_model_id(&request.model);
1095
1096        tracing::debug!(
1097            provider = "bedrock",
1098            model = %model_id,
1099            original_model = %request.model,
1100            message_count = request.messages.len(),
1101            tool_count = request.tools.len(),
1102            "Starting Bedrock Converse request"
1103        );
1104
1105        self.validate_auth()?;
1106
1107        let (system_parts, messages) = Self::convert_messages(&request.messages);
1108        let tools = Self::convert_tools(&request.tools);
1109
1110        let mut body = json!({
1111            "messages": messages,
1112        });
1113
1114        if !system_parts.is_empty() {
1115            body["system"] = json!(system_parts);
1116        }
1117
1118        // inferenceConfig
1119        let mut inference_config = json!({});
1120        if let Some(max_tokens) = request.max_tokens {
1121            inference_config["maxTokens"] = json!(max_tokens);
1122        } else {
1123            inference_config["maxTokens"] = json!(8192);
1124        }
1125        if let Some(temp) = request.temperature {
1126            inference_config["temperature"] = json!(temp);
1127        }
1128        if let Some(top_p) = request.top_p {
1129            inference_config["topP"] = json!(top_p);
1130        }
1131        body["inferenceConfig"] = inference_config;
1132
1133        if !tools.is_empty() {
1134            body["toolConfig"] = json!({"tools": tools});
1135        }
1136
1137        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
1138        let encoded_model_id = model_id.replace(':', "%3A");
1139        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
1140        tracing::debug!("Bedrock request URL: {}", url);
1141
1142        let body_bytes = serde_json::to_vec(&body)?;
1143        let response = self
1144            .send_request("POST", &url, Some(&body_bytes), "bedrock-runtime")
1145            .await?;
1146
1147        let status = response.status();
1148        let text = response
1149            .text()
1150            .await
1151            .context("Failed to read Bedrock response")?;
1152
1153        if !status.is_success() {
1154            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
1155                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
1156            }
1157            anyhow::bail!(
1158                "Bedrock API error: {} {}",
1159                status,
1160                &text[..text.len().min(500)]
1161            );
1162        }
1163
1164        let response: ConverseResponse = serde_json::from_str(&text).context(format!(
1165            "Failed to parse Bedrock response: {}",
1166            &text[..text.len().min(300)]
1167        ))?;
1168
1169        tracing::debug!(
1170            stop_reason = ?response.stop_reason,
1171            "Received Bedrock response"
1172        );
1173
1174        let mut content = Vec::new();
1175        let mut has_tool_calls = false;
1176
1177        for part in &response.output.message.content {
1178            match part {
1179                ConverseContent::ReasoningContent { reasoning_content } => {
1180                    if !reasoning_content.reasoning_text.text.is_empty() {
1181                        content.push(ContentPart::Thinking {
1182                            text: reasoning_content.reasoning_text.text.clone(),
1183                        });
1184                    }
1185                }
1186                ConverseContent::Text { text } => {
1187                    if !text.is_empty() {
1188                        content.push(ContentPart::Text { text: text.clone() });
1189                    }
1190                }
1191                ConverseContent::ToolUse { tool_use } => {
1192                    has_tool_calls = true;
1193                    content.push(ContentPart::ToolCall {
1194                        id: tool_use.tool_use_id.clone(),
1195                        name: tool_use.name.clone(),
1196                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1197                        thought_signature: None,
1198                    });
1199                }
1200            }
1201        }
1202
1203        let finish_reason = if has_tool_calls {
1204            FinishReason::ToolCalls
1205        } else {
1206            match response.stop_reason.as_deref() {
1207                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
1208                Some("max_tokens") => FinishReason::Length,
1209                Some("tool_use") => FinishReason::ToolCalls,
1210                Some("content_filtered") => FinishReason::ContentFilter,
1211                _ => FinishReason::Stop,
1212            }
1213        };
1214
1215        let usage = response.usage.as_ref();
1216
1217        Ok(CompletionResponse {
1218            message: Message {
1219                role: Role::Assistant,
1220                content,
1221            },
1222            usage: Usage {
1223                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
1224                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
1225                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
1226                cache_read_tokens: None,
1227                cache_write_tokens: None,
1228            },
1229            finish_reason,
1230        })
1231    }
1232
1233    async fn complete_stream(
1234        &self,
1235        request: CompletionRequest,
1236    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
1237        // Fall back to non-streaming for now
1238        let response = self.complete(request).await?;
1239        let text = response
1240            .message
1241            .content
1242            .iter()
1243            .filter_map(|p| match p {
1244                ContentPart::Text { text } => Some(text.clone()),
1245                _ => None,
1246            })
1247            .collect::<Vec<_>>()
1248            .join("");
1249
1250        Ok(Box::pin(futures::stream::once(async move {
1251            StreamChunk::Text(text)
1252        })))
1253    }
1254}