
codetether_agent/provider/bedrock.rs

//! Amazon Bedrock provider implementation using the Converse API
//!
//! Supports all Bedrock foundation models via either:
//! - AWS SigV4 signing (standard AWS credentials from env/file/profile)
//! - Bearer token auth (API Gateway / Vault-managed keys)
//!
//! Uses the native Bedrock Converse API format.
//! Dynamically discovers available models via the Bedrock ListFoundationModels
//! and ListInferenceProfiles APIs.
//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
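//!
//! # Example
//!
//! An illustrative sketch of wiring the provider up (not compiled as a
//! doctest); assumes AWS credentials are available in the environment:
//!
//! ```ignore
//! let creds = AwsCredentials::from_environment()
//!     .expect("no AWS credentials found");
//! let region = AwsCredentials::detect_region()
//!     .unwrap_or_else(|| DEFAULT_REGION.to_string());
//! let provider = BedrockProvider::with_credentials(creds, region)?;
//! // Or, for bearer-token auth: BedrockProvider::new(api_key)?
//! let models = provider.list_models().await?;
//! ```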

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use hmac::{Hmac, Mac};
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use sha2::{Digest, Sha256};
use std::collections::HashMap;

pub const DEFAULT_REGION: &str = "us-east-1";

/// AWS credentials for SigV4 signing
#[derive(Debug, Clone)]
pub struct AwsCredentials {
    pub access_key_id: String,
    pub secret_access_key: String,
    pub session_token: Option<String>,
}

impl AwsCredentials {
    /// Load credentials from environment variables, then fall back to
    /// ~/.aws/credentials file (default or named profile).
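    ///
    /// Illustrative sketch (not compiled as a doctest); assumes the standard
    /// `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` variables are set:
    ///
    /// ```ignore
    /// let creds = AwsCredentials::from_environment()
    ///     .expect("no credentials in env or ~/.aws/credentials");
    /// assert!(!creds.access_key_id.is_empty());
    /// ```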
    pub fn from_environment() -> Option<Self> {
        // 1) Try env vars first
        if let (Ok(key_id), Ok(secret)) = (
            std::env::var("AWS_ACCESS_KEY_ID"),
            std::env::var("AWS_SECRET_ACCESS_KEY"),
        ) {
            if !key_id.is_empty() && !secret.is_empty() {
                return Some(Self {
                    access_key_id: key_id,
                    secret_access_key: secret,
                    session_token: std::env::var("AWS_SESSION_TOKEN")
                        .ok()
                        .filter(|s| !s.is_empty()),
                });
            }
        }

        // 2) Fall back to ~/.aws/credentials file
        let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
        Self::from_credentials_file(&profile)
    }

    /// Parse ~/.aws/credentials INI file for the given profile.
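    ///
    /// Recognized keys per profile section (illustrative layout of
    /// `~/.aws/credentials`):
    ///
    /// ```text
    /// [default]
    /// aws_access_key_id = AKIA...
    /// aws_secret_access_key = ...
    /// aws_session_token = ...
    /// ```
    ///
    /// `aws_session_token` is optional.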
    fn from_credentials_file(profile: &str) -> Option<Self> {
        let home = std::env::var("HOME")
            .or_else(|_| std::env::var("USERPROFILE"))
            .ok()?;
        let path = std::path::Path::new(&home).join(".aws").join("credentials");
        let content = std::fs::read_to_string(&path).ok()?;

        let section_header = format!("[{}]", profile);
        let mut in_section = false;
        let mut key_id = None;
        let mut secret = None;
        let mut token = None;

        for line in content.lines() {
            let trimmed = line.trim();
            if trimmed.starts_with('[') {
                in_section = trimmed == section_header;
                continue;
            }
            if !in_section {
                continue;
            }
            if let Some((k, v)) = trimmed.split_once('=') {
                let k = k.trim();
                let v = v.trim();
                match k {
                    "aws_access_key_id" => key_id = Some(v.to_string()),
                    "aws_secret_access_key" => secret = Some(v.to_string()),
                    "aws_session_token" => token = Some(v.to_string()),
                    _ => {}
                }
            }
        }

        Some(Self {
            access_key_id: key_id?,
            secret_access_key: secret?,
            session_token: token,
        })
    }

    /// Detect region from AWS_REGION / AWS_DEFAULT_REGION env vars,
    /// then from ~/.aws/config.
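    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// // Falls back to the compiled-in default if nothing is configured.
    /// let region = AwsCredentials::detect_region()
    ///     .unwrap_or_else(|| DEFAULT_REGION.to_string());
    /// ```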
    pub fn detect_region() -> Option<String> {
        if let Ok(r) = std::env::var("AWS_REGION") {
            if !r.is_empty() {
                return Some(r);
            }
        }
        if let Ok(r) = std::env::var("AWS_DEFAULT_REGION") {
            if !r.is_empty() {
                return Some(r);
            }
        }
        // Try ~/.aws/config
        let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
        let home = std::env::var("HOME")
            .or_else(|_| std::env::var("USERPROFILE"))
            .ok()?;
        let path = std::path::Path::new(&home).join(".aws").join("config");
        let content = std::fs::read_to_string(&path).ok()?;

        // In ~/.aws/config, default profile is [default], others are [profile foo]
        let section_header = if profile == "default" {
            "[default]".to_string()
        } else {
            format!("[profile {}]", profile)
        };
        let mut in_section = false;
        for line in content.lines() {
            let trimmed = line.trim();
            if trimmed.starts_with('[') {
                in_section = trimmed == section_header;
                continue;
            }
            if !in_section {
                continue;
            }
            if let Some((k, v)) = trimmed.split_once('=') {
                if k.trim() == "region" {
                    let v = v.trim();
                    if !v.is_empty() {
                        return Some(v.to_string());
                    }
                }
            }
        }
        None
    }
}

/// Authentication mode for the Bedrock provider.
#[derive(Debug, Clone)]
pub enum BedrockAuth {
    /// Standard AWS SigV4 signing with IAM credentials
    SigV4(AwsCredentials),
    /// Bearer token (API Gateway or custom auth layer)
    BearerToken(String),
}

pub struct BedrockProvider {
    client: Client,
    auth: BedrockAuth,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field(
                "auth",
                &match &self.auth {
                    BedrockAuth::SigV4(_) => "SigV4",
                    BedrockAuth::BearerToken(_) => "BearerToken",
                },
            )
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
    /// Create from a bearer token (API Gateway / Vault key).
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    /// Create from a bearer token with a specific region.
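    ///
    /// Illustrative sketch (not compiled as a doctest); the token value is a
    /// placeholder:
    ///
    /// ```ignore
    /// let provider = BedrockProvider::with_region(
    ///     "my-gateway-token".to_string(),
    ///     "us-west-2".to_string(),
    /// )?;
    /// ```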
    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            auth = "bearer_token",
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            auth: BedrockAuth::BearerToken(api_key),
            region,
        })
    }

    /// Create from AWS IAM credentials (SigV4 signing).
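    ///
    /// Illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let creds = AwsCredentials::from_environment()
    ///     .expect("no AWS credentials found");
    /// let provider = BedrockProvider::with_credentials(creds, "us-east-1".to_string())?;
    /// ```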
    pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            auth = "sigv4",
            "Creating Bedrock provider with AWS credentials"
        );
        Ok(Self {
            client: Client::new(),
            auth: BedrockAuth::SigV4(credentials),
            region,
        })
    }

    /// Public wrapper for sending a signed Converse API request.
    /// Used by the thinker backend.
    pub async fn send_converse_request(&self, url: &str, body: &[u8]) -> Result<reqwest::Response> {
        self.send_request("POST", url, Some(body), "bedrock-runtime")
            .await
    }

    fn validate_auth(&self) -> Result<()> {
        match &self.auth {
            BedrockAuth::BearerToken(key) => {
                if key.is_empty() {
                    anyhow::bail!("Bedrock API key is empty");
                }
            }
            BedrockAuth::SigV4(creds) => {
                if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
                    anyhow::bail!("AWS credentials are incomplete");
                }
            }
        }
        Ok(())
    }

    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

    /// Management API URL (for listing models, not inference)
    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

    // ── AWS SigV4 signing helpers ──────────────────────────────────────

    fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
        let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
        mac.update(data);
        mac.finalize().into_bytes().to_vec()
    }

    fn sha256_hex(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        hex::encode(hasher.finalize())
    }

    /// Build a SigV4-signed request and send it.
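    ///
    /// Follows the standard SigV4 flow:
    /// 1. canonical request (method, URI, query, sorted headers, payload hash)
    /// 2. string to sign (AWS4-HMAC-SHA256, timestamp, credential scope, hashed canonical request)
    /// 3. signing key derived by chained HMAC-SHA256 over date, region, service, "aws4_request"
    /// 4. `Authorization` header carrying credential scope, signed headers, and signature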
    async fn send_signed_request(
        &self,
        method: &str,
        url: &str,
        body: &[u8],
        service: &str,
    ) -> Result<reqwest::Response> {
        let creds = match &self.auth {
            BedrockAuth::SigV4(c) => c,
            BedrockAuth::BearerToken(_) => {
                anyhow::bail!("send_signed_request called with bearer token auth");
            }
        };

        let now = chrono::Utc::now();
        let datestamp = now.format("%Y%m%d").to_string();
        let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();

        // Parse URL components
        let host_start = url.find("://").map(|i| i + 3).unwrap_or(0);
        let after_host = url[host_start..]
            .find('/')
            .map(|i| host_start + i)
            .unwrap_or(url.len());
        let host = url[host_start..after_host].to_string();
        let path_and_query = &url[after_host..];
        let (canonical_uri, canonical_querystring) = match path_and_query.split_once('?') {
            Some((p, q)) => {
                // SigV4 canonical query strings must list parameters in sorted order.
                let mut params: Vec<&str> = q.split('&').collect();
                params.sort_unstable();
                (p.to_string(), params.join("&"))
            }
            None => (path_and_query.to_string(), String::new()),
        };

        let payload_hash = Self::sha256_hex(body);

        // Build canonical headers (must be sorted)
        let mut headers_map: Vec<(&str, String)> = vec![
            ("content-type", "application/json".to_string()),
            ("host", host.clone()),
            ("x-amz-date", amz_date.clone()),
        ];
        if let Some(token) = &creds.session_token {
            headers_map.push(("x-amz-security-token", token.clone()));
        }
        headers_map.sort_by_key(|(k, _)| *k);

        let canonical_headers: String = headers_map
            .iter()
            .map(|(k, v)| format!("{}:{}", k, v))
            .collect::<Vec<_>>()
            .join("\n")
            + "\n";

        let signed_headers: String = headers_map
            .iter()
            .map(|(k, _)| *k)
            .collect::<Vec<_>>()
            .join(";");

        let canonical_request = format!(
            "{}\n{}\n{}\n{}\n{}\n{}",
            method,
            canonical_uri,
            canonical_querystring,
            canonical_headers,
            signed_headers,
            payload_hash
        );

        let credential_scope = format!("{}/{}/{}/aws4_request", datestamp, self.region, service);

        let string_to_sign = format!(
            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
            amz_date,
            credential_scope,
            Self::sha256_hex(canonical_request.as_bytes())
        );

        // Derive signing key
        let k_date = Self::hmac_sha256(
            format!("AWS4{}", creds.secret_access_key).as_bytes(),
            datestamp.as_bytes(),
        );
        let k_region = Self::hmac_sha256(&k_date, self.region.as_bytes());
        let k_service = Self::hmac_sha256(&k_region, service.as_bytes());
        let k_signing = Self::hmac_sha256(&k_service, b"aws4_request");

        let signature = hex::encode(Self::hmac_sha256(&k_signing, string_to_sign.as_bytes()));

        let authorization = format!(
            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
            creds.access_key_id, credential_scope, signed_headers, signature
        );

        let mut req = self
            .client
            .request(method.parse().unwrap_or(reqwest::Method::POST), url)
            .header("content-type", "application/json")
            .header("host", &host)
            .header("x-amz-date", &amz_date)
            .header("x-amz-content-sha256", &payload_hash)
            .header("authorization", &authorization);

        if let Some(token) = &creds.session_token {
            req = req.header("x-amz-security-token", token);
        }

        if method == "POST" || method == "PUT" {
            req = req.body(body.to_vec());
        }

        req.send()
            .await
            .context("Failed to send signed request to Bedrock")
    }

    /// Send an HTTP request using whichever auth mode is configured.
    async fn send_request(
        &self,
        method: &str,
        url: &str,
        body: Option<&[u8]>,
        service: &str,
    ) -> Result<reqwest::Response> {
        match &self.auth {
            BedrockAuth::SigV4(_) => {
                self.send_signed_request(method, url, body.unwrap_or(b""), service)
                    .await
            }
            BedrockAuth::BearerToken(token) => {
                let mut req = self
                    .client
                    .request(method.parse().unwrap_or(reqwest::Method::GET), url)
                    .bearer_auth(token)
                    .header("content-type", "application/json")
                    .header("accept", "application/json");

                if let Some(b) = body {
                    req = req.body(b.to_vec());
                }

                req.send()
                    .await
                    .context("Failed to send request to Bedrock")
            }
        }
    }

    /// Resolve a short model alias to the full Bedrock model ID.
    /// Allows users to specify e.g. "claude-sonnet-4" instead of
    /// "us.anthropic.claude-sonnet-4-20250514-v1:0".
    fn resolve_model_id(model: &str) -> &str {
        match model {
            // --- Anthropic Claude (verified via AWS CLI) ---
            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
            "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
            "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
            }
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-haiku-4.5" | "claude-4.5-haiku" => {
                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
            }
            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            }
            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
            }
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
            }
            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",

            // --- Amazon Nova ---
            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "us.amazon.nova-premier-v1:0",

            // --- Meta Llama ---
            "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3-70b" | "llama3-70b" => "meta.llama3-70b-instruct-v1:0",
            "llama-3-8b" | "llama3-8b" => "meta.llama3-8b-instruct-v1:0",

            // --- Mistral (mix of ON_DEMAND and INFERENCE_PROFILE) ---
            "mistral-large-3" | "mistral-large" => "mistral.mistral-large-3-675b-instruct",
            "mistral-large-2402" => "mistral.mistral-large-2402-v1:0",
            "mistral-small" => "mistral.mistral-small-2402-v1:0",
            "mixtral-8x7b" => "mistral.mixtral-8x7b-instruct-v0:1",
            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
            "magistral-small" => "mistral.magistral-small-2509",

            // --- DeepSeek ---
            "deepseek-r1" => "us.deepseek.r1-v1:0",
            "deepseek-v3" | "deepseek-v3.2" => "deepseek.v3.2",

            // --- Cohere (ON_DEMAND only, no us. prefix) ---
            "command-r" => "cohere.command-r-v1:0",
            "command-r-plus" => "cohere.command-r-plus-v1:0",

            // --- Qwen (ON_DEMAND only, no us. prefix) ---
            "qwen3-32b" => "qwen.qwen3-32b-v1:0",
            "qwen3-coder" | "qwen3-coder-next" => "qwen.qwen3-coder-next",
            "qwen3-coder-30b" => "qwen.qwen3-coder-30b-a3b-v1:0",

            // --- Google Gemma (ON_DEMAND only, no us. prefix) ---
            "gemma-3-27b" => "google.gemma-3-27b-it",
            "gemma-3-12b" => "google.gemma-3-12b-it",
            "gemma-3-4b" => "google.gemma-3-4b-it",

            // --- Moonshot / Kimi (ON_DEMAND only, no us. prefix) ---
            "kimi-k2" | "kimi-k2-thinking" => "moonshot.kimi-k2-thinking",
            "kimi-k2.5" => "moonshotai.kimi-k2.5",

            // --- AI21 Jamba (ON_DEMAND only, no us. prefix) ---
            "jamba-1.5-large" => "ai21.jamba-1-5-large-v1:0",
            "jamba-1.5-mini" => "ai21.jamba-1-5-mini-v1:0",

            // --- MiniMax (ON_DEMAND only, no us. prefix) ---
            "minimax-m2" => "minimax.minimax-m2",
            "minimax-m2.1" => "minimax.minimax-m2.1",

            // --- NVIDIA (ON_DEMAND only, no us. prefix) ---
            "nemotron-nano-30b" => "nvidia.nemotron-nano-3-30b",
            "nemotron-nano-12b" => "nvidia.nemotron-nano-12b-v2",
            "nemotron-nano-9b" => "nvidia.nemotron-nano-9b-v2",

            // --- Z.AI / GLM (ON_DEMAND only, no us. prefix) ---
            "glm-5" => "zai.glm-5",
            "glm-4.7" => "zai.glm-4.7",
            "glm-4.7-flash" => "zai.glm-4.7-flash",

            // Pass through full model IDs unchanged
            other => other,
        }
    }

    /// Dynamically discover available models from the Bedrock API.
    /// Merges foundation models with cross-region inference profiles.
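    ///
    /// Callers normally reach this through `Provider::list_models`; an
    /// illustrative sketch (not compiled as a doctest):
    ///
    /// ```ignore
    /// let models = provider.list_models().await?;
    /// for m in &models {
    ///     println!("{} ({} tokens)", m.id, m.context_window);
    /// }
    /// ```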
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

        // 1) Fetch foundation models
        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self.send_request("GET", &fm_url, None, "bedrock").await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            // Only include TEXT output models with ON_DEMAND or INFERENCE_PROFILE inference
                            if !output_modalities.contains(&"TEXT")
                                || (!inference_types.contains(&"ON_DEMAND")
                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
                            {
                                continue;
                            }

                            // Skip non-chat models
                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            // Models with INFERENCE_PROFILE support use cross-region
                            // us. prefix; ON_DEMAND-only models use bare model IDs.
                            // Amazon models never get the prefix.
                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else if inference_types.contains(&"INFERENCE_PROFILE") {
                                format!("us.{}", model_id)
                            } else {
                                model_id.to_string()
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        // 2) Fetch cross-region inference profiles (adds models like Claude Sonnet 4,
        //    Llama 3.1/3.2/3.3/4, DeepSeek R1, etc. that aren't in foundation models)
        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self.send_request("GET", &ip_url, None, "bedrock").await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            // Only US cross-region profiles
                            if !pid.starts_with("us.") {
                                continue;
                            }

                            // Skip already-discovered models
                            if models.contains_key(pid) {
                                continue;
                            }

                            // Skip non-text models
                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            // Guess vision from known model families
                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4")
                                || pid.contains("claude-haiku-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            // Extract provider hint from model ID
                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

    /// Estimate context window size based on model family
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

    /// Estimate max output tokens based on model family
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-opus-4-6") {
            32_000
        } else if id.contains("claude-opus-4-5") {
            32_000
        } else if id.contains("claude-opus-4-1") {
            32_000
        } else if id.contains("claude-sonnet-4-5")
            || id.contains("claude-sonnet-4")
            || id.contains("claude-3-7")
        {
            64_000
        } else if id.contains("claude-haiku-4-5") {
            16_384
        } else if id.contains("claude-opus-4") {
            32_000
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

    /// Convert our generic messages to Bedrock Converse API format.
    ///
    /// Bedrock Converse uses:
    /// - system prompt as a top-level "system" array
    /// - messages with "role" and "content" array
    /// - tool_use blocks in assistant content
    /// - toolResult blocks in user content
    ///
    /// IMPORTANT: Bedrock requires strict role alternation (user/assistant).
    /// Consecutive Role::Tool messages must be merged into a single "user"
    /// message so all toolResult blocks for a given assistant turn appear
    /// together. Consecutive same-role messages are also merged to prevent
    /// validation errors.
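    ///
    /// Illustrative sketch of the tool-result merging (not compiled as a
    /// doctest); field shapes follow the `Message` / `ContentPart` types used
    /// in this module:
    ///
    /// ```ignore
    /// // Two consecutive Role::Tool messages...
    /// let msgs = vec![
    ///     Message { role: Role::Tool, content: vec![ContentPart::ToolResult {
    ///         tool_call_id: "a".into(), content: "ok".into(),
    ///     }]},
    ///     Message { role: Role::Tool, content: vec![ContentPart::ToolResult {
    ///         tool_call_id: "b".into(), content: "ok".into(),
    ///     }]},
    /// ];
    /// // ...collapse into a single "user" message holding both toolResult blocks.
    /// let (_system, api) = BedrockProvider::convert_messages(&msgs);
    /// assert_eq!(api.len(), 1);
    /// ```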
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            _ => {}
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into previous user message if the last message is also "user"
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) =
                                    last.get_mut("content").and_then(|c| c.as_array_mut())
                                {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": " "}));
                    }
                    // Merge into previous assistant message if consecutive
                    if let Some(last) = api_messages.last_mut() {
                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
                            if let Some(arr) =
                                last.get_mut("content").and_then(|c| c.as_array_mut())
                            {
                                arr.extend(content_parts);
                                continue;
                            }
                        }
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Tool results must be in a "user" message with toolResult blocks.
                    // Merge into the previous user message if one exists (handles
                    // consecutive Tool messages being collapsed into one user turn).
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into previous user message (from earlier Tool messages)
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) =
                                    last.get_mut("content").and_then(|c| c.as_array_mut())
                                {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

/// Bedrock Converse API response types

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    ReasoningContent {
        #[serde(rename = "reasoningContent")]
        reasoning_content: ReasoningContentBlock,
    },
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentBlock {
    reasoning_text: ReasoningText,
}

#[derive(Debug, Deserialize)]
struct ReasoningText {
    text: String,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_auth()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_auth()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        // inferenceConfig
        let mut inference_config = json!({});
        if let Some(max_tokens) = request.max_tokens {
            inference_config["maxTokens"] = json!(max_tokens);
        } else {
            inference_config["maxTokens"] = json!(8192);
        }
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }

        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let body_bytes = serde_json::to_vec(&body)?;
        let response = self
            .send_request("POST", &url, Some(&body_bytes), "bedrock-runtime")
            .await?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            anyhow::bail!(
                "Bedrock API error: {} {}",
                status,
                &text[..text.len().min(500)]
            );
        }

        let response: ConverseResponse = serde_json::from_str(&text).context(format!(
            "Failed to parse Bedrock response: {}",
            &text[..text.len().min(300)]
        ))?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::ReasoningContent { reasoning_content } => {
                    if !reasoning_content.reasoning_text.text.is_empty() {
                        content.push(ContentPart::Thinking {
                            text: reasoning_content.reasoning_text.text.clone(),
                        });
                    }
                }
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Fall back to non-streaming for now
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}