// codetether_agent/provider/bedrock.rs
//! Amazon Bedrock provider implementation using the Converse API
//!
//! Supports all Bedrock foundation models via either:
//! - AWS SigV4 signing (standard AWS credentials from env/file/profile)
//! - Bearer token auth (API Gateway / Vault-managed keys)
//!
//! Uses the native Bedrock Converse API format.
//! Dynamically discovers available models via the Bedrock ListFoundationModels
//! and ListInferenceProfiles APIs.
//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html

12use super::{
13    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
14    Role, StreamChunk, ToolDefinition, Usage,
15};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use hmac::{Hmac, Mac};
19use reqwest::Client;
20use serde::Deserialize;
21use serde_json::{Value, json};
22use sha2::{Digest, Sha256};
23use std::collections::HashMap;
24
25pub const DEFAULT_REGION: &str = "us-east-1";
26
27/// AWS credentials for SigV4 signing
28#[derive(Debug, Clone)]
29pub struct AwsCredentials {
30    pub access_key_id: String,
31    pub secret_access_key: String,
32    pub session_token: Option<String>,
33}
34
35impl AwsCredentials {
36    /// Load credentials from environment variables, then fall back to
37    /// ~/.aws/credentials file (default or named profile).
38    pub fn from_environment() -> Option<Self> {
39        // 1) Try env vars first
40        if let (Ok(key_id), Ok(secret)) = (
41            std::env::var("AWS_ACCESS_KEY_ID"),
42            std::env::var("AWS_SECRET_ACCESS_KEY"),
43        ) {
44            if !key_id.is_empty() && !secret.is_empty() {
45                return Some(Self {
46                    access_key_id: key_id,
47                    secret_access_key: secret,
48                    session_token: std::env::var("AWS_SESSION_TOKEN")
49                        .ok()
50                        .filter(|s| !s.is_empty()),
51                });
52            }
53        }
54
55        // 2) Fall back to ~/.aws/credentials file
56        let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
57        Self::from_credentials_file(&profile)
58    }
59
60    /// Parse ~/.aws/credentials INI file for the given profile.
61    fn from_credentials_file(profile: &str) -> Option<Self> {
62        let home = std::env::var("HOME")
63            .or_else(|_| std::env::var("USERPROFILE"))
64            .ok()?;
65        let path = std::path::Path::new(&home).join(".aws").join("credentials");
66        let content = std::fs::read_to_string(&path).ok()?;
67
68        let section_header = format!("[{}]", profile);
69        let mut in_section = false;
70        let mut key_id = None;
71        let mut secret = None;
72        let mut token = None;
73
74        for line in content.lines() {
75            let trimmed = line.trim();
76            if trimmed.starts_with('[') {
77                in_section = trimmed == section_header;
78                continue;
79            }
80            if !in_section {
81                continue;
82            }
83            if let Some((k, v)) = trimmed.split_once('=') {
84                let k = k.trim();
85                let v = v.trim();
86                match k {
87                    "aws_access_key_id" => key_id = Some(v.to_string()),
88                    "aws_secret_access_key" => secret = Some(v.to_string()),
89                    "aws_session_token" => token = Some(v.to_string()),
90                    _ => {}
91                }
92            }
93        }
94
95        Some(Self {
96            access_key_id: key_id?,
97            secret_access_key: secret?,
98            session_token: token,
99        })
100    }
101
102    /// Detect region from AWS_REGION / AWS_DEFAULT_REGION env vars,
103    /// then from ~/.aws/config.
104    pub fn detect_region() -> Option<String> {
105        if let Ok(r) = std::env::var("AWS_REGION") {
106            if !r.is_empty() {
107                return Some(r);
108            }
109        }
110        if let Ok(r) = std::env::var("AWS_DEFAULT_REGION") {
111            if !r.is_empty() {
112                return Some(r);
113            }
114        }
115        // Try ~/.aws/config
116        let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
117        let home = std::env::var("HOME")
118            .or_else(|_| std::env::var("USERPROFILE"))
119            .ok()?;
120        let path = std::path::Path::new(&home).join(".aws").join("config");
121        let content = std::fs::read_to_string(&path).ok()?;
122
123        // In ~/.aws/config, default profile is [default], others are [profile foo]
124        let section_header = if profile == "default" {
125            "[default]".to_string()
126        } else {
127            format!("[profile {}]", profile)
128        };
129        let mut in_section = false;
130        for line in content.lines() {
131            let trimmed = line.trim();
132            if trimmed.starts_with('[') {
133                in_section = trimmed == section_header;
134                continue;
135            }
136            if !in_section {
137                continue;
138            }
139            if let Some((k, v)) = trimmed.split_once('=') {
140                if k.trim() == "region" {
141                    let v = v.trim();
142                    if !v.is_empty() {
143                        return Some(v.to_string());
144                    }
145                }
146            }
147        }
148        None
149    }
150}
151
152/// Authentication mode for the Bedrock provider.
153#[derive(Debug, Clone)]
154pub enum BedrockAuth {
155    /// Standard AWS SigV4 signing with IAM credentials
156    SigV4(AwsCredentials),
157    /// Bearer token (API Gateway or custom auth layer)
158    BearerToken(String),
159}
160
161pub struct BedrockProvider {
162    client: Client,
163    auth: BedrockAuth,
164    region: String,
165}
166
167impl std::fmt::Debug for BedrockProvider {
168    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169        f.debug_struct("BedrockProvider")
170            .field(
171                "auth",
172                &match &self.auth {
173                    BedrockAuth::SigV4(_) => "SigV4",
174                    BedrockAuth::BearerToken(_) => "BearerToken",
175                },
176            )
177            .field("region", &self.region)
178            .finish()
179    }
180}
181
182impl BedrockProvider {
183    /// Create from a bearer token (API Gateway / Vault key).
184    pub fn new(api_key: String) -> Result<Self> {
185        Self::with_region(api_key, DEFAULT_REGION.to_string())
186    }
187
188    /// Create from a bearer token with a specific region.
189    pub fn with_region(api_key: String, region: String) -> Result<Self> {
190        tracing::debug!(
191            provider = "bedrock",
192            region = %region,
193            auth = "bearer_token",
194            "Creating Bedrock provider"
195        );
196        Ok(Self {
197            client: Client::new(),
198            auth: BedrockAuth::BearerToken(api_key),
199            region,
200        })
201    }
202
203    /// Create from AWS IAM credentials (SigV4 signing).
204    pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
205        tracing::debug!(
206            provider = "bedrock",
207            region = %region,
208            auth = "sigv4",
209            "Creating Bedrock provider with AWS credentials"
210        );
211        Ok(Self {
212            client: Client::new(),
213            auth: BedrockAuth::SigV4(credentials),
214            region,
215        })
216    }
217
218    /// Public wrapper for sending a signed Converse API request.
219    /// Used by the thinker backend.
220    pub async fn send_converse_request(
221        &self,
222        url: &str,
223        body: &[u8],
224    ) -> Result<reqwest::Response> {
225        self.send_request("POST", url, Some(body), "bedrock-runtime")
226            .await
227    }
228
229    fn validate_auth(&self) -> Result<()> {
230        match &self.auth {
231            BedrockAuth::BearerToken(key) => {
232                if key.is_empty() {
233                    anyhow::bail!("Bedrock API key is empty");
234                }
235            }
236            BedrockAuth::SigV4(creds) => {
237                if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
238                    anyhow::bail!("AWS credentials are incomplete");
239                }
240            }
241        }
242        Ok(())
243    }
244
245    fn base_url(&self) -> String {
246        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
247    }
248
249    /// Management API URL (for listing models, not inference)
250    fn management_url(&self) -> String {
251        format!("https://bedrock.{}.amazonaws.com", self.region)
252    }
253
254    // ── AWS SigV4 signing helpers ──────────────────────────────────────
255
256    fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
257        let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
258        mac.update(data);
259        mac.finalize().into_bytes().to_vec()
260    }
261
262    fn sha256_hex(data: &[u8]) -> String {
263        let mut hasher = Sha256::new();
264        hasher.update(data);
265        hex::encode(hasher.finalize())
266    }
267
268    /// Build a SigV4-signed request and send it.
269    async fn send_signed_request(
270        &self,
271        method: &str,
272        url: &str,
273        body: &[u8],
274        service: &str,
275    ) -> Result<reqwest::Response> {
276        let creds = match &self.auth {
277            BedrockAuth::SigV4(c) => c,
278            BedrockAuth::BearerToken(_) => {
279                anyhow::bail!("send_signed_request called with bearer token auth");
280            }
281        };
282
283        let now = chrono::Utc::now();
284        let datestamp = now.format("%Y%m%d").to_string();
285        let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
286
287        // Parse URL components
288        let host_start = url.find("://").map(|i| i + 3).unwrap_or(0);
289        let after_host = url[host_start..]
290            .find('/')
291            .map(|i| host_start + i)
292            .unwrap_or(url.len());
293        let host = url[host_start..after_host].to_string();
294        let path_and_query = &url[after_host..];
295        let (canonical_uri, canonical_querystring) = match path_and_query.split_once('?') {
296            Some((p, q)) => (p.to_string(), q.to_string()),
297            None => (path_and_query.to_string(), String::new()),
298        };
299
300        let payload_hash = Self::sha256_hex(body);
301
302        // Build canonical headers (must be sorted)
303        let mut headers_map: Vec<(&str, String)> = vec![
304            ("content-type", "application/json".to_string()),
305            ("host", host.clone()),
306            ("x-amz-date", amz_date.clone()),
307        ];
308        if let Some(token) = &creds.session_token {
309            headers_map.push(("x-amz-security-token", token.clone()));
310        }
311        headers_map.sort_by_key(|(k, _)| *k);
312
313        let canonical_headers: String = headers_map
314            .iter()
315            .map(|(k, v)| format!("{}:{}", k, v))
316            .collect::<Vec<_>>()
317            .join("\n")
318            + "\n";
319
320        let signed_headers: String = headers_map
321            .iter()
322            .map(|(k, _)| *k)
323            .collect::<Vec<_>>()
324            .join(";");
325
326        let canonical_request = format!(
327            "{}\n{}\n{}\n{}\n{}\n{}",
328            method,
329            canonical_uri,
330            canonical_querystring,
331            canonical_headers,
332            signed_headers,
333            payload_hash
334        );
335
336        let credential_scope = format!("{}/{}/{}/aws4_request", datestamp, self.region, service);
337
338        let string_to_sign = format!(
339            "AWS4-HMAC-SHA256\n{}\n{}\n{}",
340            amz_date,
341            credential_scope,
342            Self::sha256_hex(canonical_request.as_bytes())
343        );
344
345        // Derive signing key
346        let k_date = Self::hmac_sha256(
347            format!("AWS4{}", creds.secret_access_key).as_bytes(),
348            datestamp.as_bytes(),
349        );
350        let k_region = Self::hmac_sha256(&k_date, self.region.as_bytes());
351        let k_service = Self::hmac_sha256(&k_region, service.as_bytes());
352        let k_signing = Self::hmac_sha256(&k_service, b"aws4_request");
353
354        let signature = hex::encode(Self::hmac_sha256(&k_signing, string_to_sign.as_bytes()));
355
356        let authorization = format!(
357            "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
358            creds.access_key_id, credential_scope, signed_headers, signature
359        );
360
361        let mut req = self
362            .client
363            .request(method.parse().unwrap_or(reqwest::Method::POST), url)
364            .header("content-type", "application/json")
365            .header("host", &host)
366            .header("x-amz-date", &amz_date)
367            .header("x-amz-content-sha256", &payload_hash)
368            .header("authorization", &authorization);
369
370        if let Some(token) = &creds.session_token {
371            req = req.header("x-amz-security-token", token);
372        }
373
374        if method == "POST" || method == "PUT" {
375            req = req.body(body.to_vec());
376        }
377
378        req.send()
379            .await
380            .context("Failed to send signed request to Bedrock")
381    }
382
383    /// Send an HTTP request using whichever auth mode is configured.
384    async fn send_request(
385        &self,
386        method: &str,
387        url: &str,
388        body: Option<&[u8]>,
389        service: &str,
390    ) -> Result<reqwest::Response> {
391        match &self.auth {
392            BedrockAuth::SigV4(_) => {
393                self.send_signed_request(method, url, body.unwrap_or(b""), service)
394                    .await
395            }
396            BedrockAuth::BearerToken(token) => {
397                let mut req = self
398                    .client
399                    .request(method.parse().unwrap_or(reqwest::Method::GET), url)
400                    .bearer_auth(token)
401                    .header("content-type", "application/json")
402                    .header("accept", "application/json");
403
404                if let Some(b) = body {
405                    req = req.body(b.to_vec());
406                }
407
408                req.send()
409                    .await
410                    .context("Failed to send request to Bedrock")
411            }
412        }
413    }
414
415    /// Resolve a short model alias to the full Bedrock model ID.
416    /// Allows users to specify e.g. "claude-sonnet-4" instead of
417    /// "us.anthropic.claude-sonnet-4-20250514-v1:0".
418    fn resolve_model_id(model: &str) -> &str {
419        match model {
420            // --- Anthropic Claude (verified via AWS CLI) ---
421            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
422            "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
423            "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
424            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
425            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
426                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
427            }
428            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
429            "claude-haiku-4.5" | "claude-4.5-haiku" => {
430                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
431            }
432            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
433                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
434            }
435            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
436                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
437            }
438            "claude-3.5-haiku" | "claude-haiku-3.5" => {
439                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
440            }
441            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
442                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
443            }
444            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
445            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
446            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",
447
448            // --- Amazon Nova ---
449            "nova-pro" => "amazon.nova-pro-v1:0",
450            "nova-lite" => "amazon.nova-lite-v1:0",
451            "nova-micro" => "amazon.nova-micro-v1:0",
452            "nova-premier" => "us.amazon.nova-premier-v1:0",
453
454            // --- Meta Llama ---
455            "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
456            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
457            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
458            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
459            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
460            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
461            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
462            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
463            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
464            "llama-3-70b" | "llama3-70b" => "meta.llama3-70b-instruct-v1:0",
465            "llama-3-8b" | "llama3-8b" => "meta.llama3-8b-instruct-v1:0",
466
467            // --- Mistral (mix of ON_DEMAND and INFERENCE_PROFILE) ---
468            "mistral-large-3" | "mistral-large" => "mistral.mistral-large-3-675b-instruct",
469            "mistral-large-2402" => "mistral.mistral-large-2402-v1:0",
470            "mistral-small" => "mistral.mistral-small-2402-v1:0",
471            "mixtral-8x7b" => "mistral.mixtral-8x7b-instruct-v0:1",
472            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
473            "magistral-small" => "mistral.magistral-small-2509",
474
475            // --- DeepSeek ---
476            "deepseek-r1" => "us.deepseek.r1-v1:0",
477            "deepseek-v3" | "deepseek-v3.2" => "deepseek.v3.2",
478
479            // --- Cohere (ON_DEMAND only, no us. prefix) ---
480            "command-r" => "cohere.command-r-v1:0",
481            "command-r-plus" => "cohere.command-r-plus-v1:0",
482
483            // --- Qwen (ON_DEMAND only, no us. prefix) ---
484            "qwen3-32b" => "qwen.qwen3-32b-v1:0",
485            "qwen3-coder" | "qwen3-coder-next" => "qwen.qwen3-coder-next",
486            "qwen3-coder-30b" => "qwen.qwen3-coder-30b-a3b-v1:0",
487
488            // --- Google Gemma (ON_DEMAND only, no us. prefix) ---
489            "gemma-3-27b" => "google.gemma-3-27b-it",
490            "gemma-3-12b" => "google.gemma-3-12b-it",
491            "gemma-3-4b" => "google.gemma-3-4b-it",
492
493            // --- Moonshot / Kimi (ON_DEMAND only, no us. prefix) ---
494            "kimi-k2" | "kimi-k2-thinking" => "moonshot.kimi-k2-thinking",
495            "kimi-k2.5" => "moonshotai.kimi-k2.5",
496
497            // --- AI21 Jamba (ON_DEMAND only, no us. prefix) ---
498            "jamba-1.5-large" => "ai21.jamba-1-5-large-v1:0",
499            "jamba-1.5-mini" => "ai21.jamba-1-5-mini-v1:0",
500
501            // --- MiniMax (ON_DEMAND only, no us. prefix) ---
502            "minimax-m2" => "minimax.minimax-m2",
503            "minimax-m2.1" => "minimax.minimax-m2.1",
504
505            // --- NVIDIA (ON_DEMAND only, no us. prefix) ---
506            "nemotron-nano-30b" => "nvidia.nemotron-nano-3-30b",
507            "nemotron-nano-12b" => "nvidia.nemotron-nano-12b-v2",
508            "nemotron-nano-9b" => "nvidia.nemotron-nano-9b-v2",
509
510            // --- Z.AI / GLM (ON_DEMAND only, no us. prefix) ---
511            "glm-5" => "zai.glm-5",
512            "glm-4.7" => "zai.glm-4.7",
513            "glm-4.7-flash" => "zai.glm-4.7-flash",
514
515            // Pass through full model IDs unchanged
516            other => other,
517        }
518    }
519
520    /// Dynamically discover available models from the Bedrock API.
521    /// Merges foundation models with cross-region inference profiles.
522    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
523        let mut models: HashMap<String, ModelInfo> = HashMap::new();
524
525        // 1) Fetch foundation models
526        let fm_url = format!("{}/foundation-models", self.management_url());
527        let fm_resp = self.send_request("GET", &fm_url, None, "bedrock").await;
528
529        if let Ok(resp) = fm_resp {
530            if resp.status().is_success() {
531                if let Ok(data) = resp.json::<Value>().await {
532                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
533                        for m in summaries {
534                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
535                            let model_name =
536                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
537                            let provider_name =
538                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");
539
540                            let output_modalities: Vec<&str> = m
541                                .get("outputModalities")
542                                .and_then(|v| v.as_array())
543                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
544                                .unwrap_or_default();
545
546                            let input_modalities: Vec<&str> = m
547                                .get("inputModalities")
548                                .and_then(|v| v.as_array())
549                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
550                                .unwrap_or_default();
551
552                            let inference_types: Vec<&str> = m
553                                .get("inferenceTypesSupported")
554                                .and_then(|v| v.as_array())
555                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
556                                .unwrap_or_default();
557
558                            // Only include TEXT output models with ON_DEMAND or INFERENCE_PROFILE inference
559                            if !output_modalities.contains(&"TEXT")
560                                || (!inference_types.contains(&"ON_DEMAND")
561                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
562                            {
563                                continue;
564                            }
565
566                            // Skip non-chat models
567                            let name_lower = model_name.to_lowercase();
568                            if name_lower.contains("rerank")
569                                || name_lower.contains("embed")
570                                || name_lower.contains("safeguard")
571                                || name_lower.contains("sonic")
572                                || name_lower.contains("pegasus")
573                            {
574                                continue;
575                            }
576
577                            let streaming = m
578                                .get("responseStreamingSupported")
579                                .and_then(|v| v.as_bool())
580                                .unwrap_or(false);
581                            let vision = input_modalities.contains(&"IMAGE");
582
583                            // Models with INFERENCE_PROFILE support use cross-region
584                            // us. prefix; ON_DEMAND-only models use bare model IDs.
585                            // Amazon models never get the prefix.
586                            let actual_id = if model_id.starts_with("amazon.") {
587                                model_id.to_string()
588                            } else if inference_types.contains(&"INFERENCE_PROFILE") {
589                                format!("us.{}", model_id)
590                            } else {
591                                model_id.to_string()
592                            };
593
594                            let display_name = format!("{} (Bedrock)", model_name);
595
596                            models.insert(
597                                actual_id.clone(),
598                                ModelInfo {
599                                    id: actual_id,
600                                    name: display_name,
601                                    provider: "bedrock".to_string(),
602                                    context_window: Self::estimate_context_window(
603                                        model_id,
604                                        provider_name,
605                                    ),
606                                    max_output_tokens: Some(Self::estimate_max_output(
607                                        model_id,
608                                        provider_name,
609                                    )),
610                                    supports_vision: vision,
611                                    supports_tools: true,
612                                    supports_streaming: streaming,
613                                    input_cost_per_million: None,
614                                    output_cost_per_million: None,
615                                },
616                            );
617                        }
618                    }
619                }
620            }
621        }
622
623        // 2) Fetch cross-region inference profiles (adds models like Claude Sonnet 4,
624        //    Llama 3.1/3.2/3.3/4, DeepSeek R1, etc. that aren't in foundation models)
625        let ip_url = format!(
626            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
627            self.management_url()
628        );
629        let ip_resp = self.send_request("GET", &ip_url, None, "bedrock").await;
630
631        if let Ok(resp) = ip_resp {
632            if resp.status().is_success() {
633                if let Ok(data) = resp.json::<Value>().await {
634                    if let Some(profiles) = data
635                        .get("inferenceProfileSummaries")
636                        .and_then(|v| v.as_array())
637                    {
638                        for p in profiles {
639                            let pid = p
640                                .get("inferenceProfileId")
641                                .and_then(|v| v.as_str())
642                                .unwrap_or("");
643                            let pname = p
644                                .get("inferenceProfileName")
645                                .and_then(|v| v.as_str())
646                                .unwrap_or("");
647
648                            // Only US cross-region profiles
649                            if !pid.starts_with("us.") {
650                                continue;
651                            }
652
653                            // Skip already-discovered models
654                            if models.contains_key(pid) {
655                                continue;
656                            }
657
658                            // Skip non-text models
659                            let name_lower = pname.to_lowercase();
660                            if name_lower.contains("image")
661                                || name_lower.contains("stable ")
662                                || name_lower.contains("upscale")
663                                || name_lower.contains("embed")
664                                || name_lower.contains("marengo")
665                                || name_lower.contains("outpaint")
666                                || name_lower.contains("inpaint")
667                                || name_lower.contains("erase")
668                                || name_lower.contains("recolor")
669                                || name_lower.contains("replace")
670                                || name_lower.contains("style ")
671                                || name_lower.contains("background")
672                                || name_lower.contains("sketch")
673                                || name_lower.contains("control")
674                                || name_lower.contains("transfer")
675                                || name_lower.contains("sonic")
676                                || name_lower.contains("pegasus")
677                                || name_lower.contains("rerank")
678                            {
679                                continue;
680                            }
681
682                            // Guess vision from known model families
683                            let vision = pid.contains("llama3-2-11b")
684                                || pid.contains("llama3-2-90b")
685                                || pid.contains("pixtral")
686                                || pid.contains("claude-3")
687                                || pid.contains("claude-sonnet-4")
688                                || pid.contains("claude-opus-4")
689                                || pid.contains("claude-haiku-4");
690
691                            let display_name = pname.replace("US ", "");
692                            let display_name = format!("{} (Bedrock)", display_name.trim());
693
694                            // Extract provider hint from model ID
695                            let provider_hint = pid
696                                .strip_prefix("us.")
697                                .unwrap_or(pid)
698                                .split('.')
699                                .next()
700                                .unwrap_or("");
701
702                            models.insert(
703                                pid.to_string(),
704                                ModelInfo {
705                                    id: pid.to_string(),
706                                    name: display_name,
707                                    provider: "bedrock".to_string(),
708                                    context_window: Self::estimate_context_window(
709                                        pid,
710                                        provider_hint,
711                                    ),
712                                    max_output_tokens: Some(Self::estimate_max_output(
713                                        pid,
714                                        provider_hint,
715                                    )),
716                                    supports_vision: vision,
717                                    supports_tools: true,
718                                    supports_streaming: true,
719                                    input_cost_per_million: None,
720                                    output_cost_per_million: None,
721                                },
722                            );
723                        }
724                    }
725                }
726            }
727        }
728
729        let mut result: Vec<ModelInfo> = models.into_values().collect();
730        result.sort_by(|a, b| a.id.cmp(&b.id));
731
732        tracing::info!(
733            provider = "bedrock",
734            model_count = result.len(),
735            "Discovered Bedrock models dynamically"
736        );
737
738        Ok(result)
739    }
740
741    /// Estimate context window size based on model family
742    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
743        let id = model_id.to_lowercase();
744        if id.contains("anthropic") || id.contains("claude") {
745            200_000
746        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
747        {
748            300_000
749        } else if id.contains("nova-micro") || id.contains("nova-2") {
750            128_000
751        } else if id.contains("deepseek") {
752            128_000
753        } else if id.contains("llama4") {
754            256_000
755        } else if id.contains("llama3") {
756            128_000
757        } else if id.contains("mistral-large-3") || id.contains("magistral") {
758            128_000
759        } else if id.contains("mistral") {
760            32_000
761        } else if id.contains("qwen") {
762            128_000
763        } else if id.contains("kimi") {
764            128_000
765        } else if id.contains("jamba") {
766            256_000
767        } else if id.contains("glm") {
768            128_000
769        } else if id.contains("minimax") {
770            128_000
771        } else if id.contains("gemma") {
772            128_000
773        } else if id.contains("cohere") || id.contains("command") {
774            128_000
775        } else if id.contains("nemotron") {
776            128_000
777        } else if provider.to_lowercase().contains("amazon") {
778            128_000
779        } else {
780            32_000
781        }
782    }
783
784    /// Estimate max output tokens based on model family
785    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
786        let id = model_id.to_lowercase();
787        if id.contains("claude-opus-4-6") {
788            32_000
789        } else if id.contains("claude-opus-4-5") {
790            32_000
791        } else if id.contains("claude-opus-4-1") {
792            32_000
793        } else if id.contains("claude-sonnet-4-5")
794            || id.contains("claude-sonnet-4")
795            || id.contains("claude-3-7")
796        {
797            64_000
798        } else if id.contains("claude-haiku-4-5") {
799            16_384
800        } else if id.contains("claude-opus-4") {
801            32_000
802        } else if id.contains("claude") {
803            8_192
804        } else if id.contains("nova") {
805            5_000
806        } else if id.contains("deepseek") {
807            16_384
808        } else if id.contains("llama4") {
809            16_384
810        } else if id.contains("llama") {
811            4_096
812        } else if id.contains("mistral-large-3") {
813            16_384
814        } else if id.contains("mistral") || id.contains("mixtral") {
815            8_192
816        } else if id.contains("qwen") {
817            8_192
818        } else if id.contains("kimi") {
819            8_192
820        } else if id.contains("jamba") {
821            4_096
822        } else {
823            4_096
824        }
825    }
826
827    /// Convert our generic messages to Bedrock Converse API format.
828    ///
829    /// Bedrock Converse uses:
830    /// - system prompt as a top-level "system" array
831    /// - messages with "role" and "content" array
832    /// - tool_use blocks in assistant content
833    /// - toolResult blocks in user content
834    ///
835    /// IMPORTANT: Bedrock requires strict role alternation (user/assistant).
836    /// Consecutive Role::Tool messages must be merged into a single "user"
837    /// message so all toolResult blocks for a given assistant turn appear
838    /// together. Consecutive same-role messages are also merged to prevent
839    /// validation errors.
840    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
841        let mut system_parts: Vec<Value> = Vec::new();
842        let mut api_messages: Vec<Value> = Vec::new();
843
844        for msg in messages {
845            match msg.role {
846                Role::System => {
847                    let text: String = msg
848                        .content
849                        .iter()
850                        .filter_map(|p| match p {
851                            ContentPart::Text { text } => Some(text.clone()),
852                            _ => None,
853                        })
854                        .collect::<Vec<_>>()
855                        .join("\n");
856                    system_parts.push(json!({"text": text}));
857                }
858                Role::User => {
859                    let mut content_parts: Vec<Value> = Vec::new();
860                    for part in &msg.content {
861                        match part {
862                            ContentPart::Text { text } => {
863                                if !text.is_empty() {
864                                    content_parts.push(json!({"text": text}));
865                                }
866                            }
867                            _ => {}
868                        }
869                    }
870                    if !content_parts.is_empty() {
871                        // Merge into previous user message if the last message is also "user"
872                        if let Some(last) = api_messages.last_mut() {
873                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
874                                if let Some(arr) =
875                                    last.get_mut("content").and_then(|c| c.as_array_mut())
876                                {
877                                    arr.extend(content_parts);
878                                    continue;
879                                }
880                            }
881                        }
882                        api_messages.push(json!({
883                            "role": "user",
884                            "content": content_parts
885                        }));
886                    }
887                }
888                Role::Assistant => {
889                    let mut content_parts: Vec<Value> = Vec::new();
890                    for part in &msg.content {
891                        match part {
892                            ContentPart::Text { text } => {
893                                if !text.is_empty() {
894                                    content_parts.push(json!({"text": text}));
895                                }
896                            }
897                            ContentPart::ToolCall {
898                                id,
899                                name,
900                                arguments,
901                            } => {
902                                let input: Value = serde_json::from_str(arguments)
903                                    .unwrap_or_else(|_| json!({"raw": arguments}));
904                                content_parts.push(json!({
905                                    "toolUse": {
906                                        "toolUseId": id,
907                                        "name": name,
908                                        "input": input
909                                    }
910                                }));
911                            }
912                            _ => {}
913                        }
914                    }
915                    if content_parts.is_empty() {
916                        content_parts.push(json!({"text": " "}));
917                    }
918                    // Merge into previous assistant message if consecutive
919                    if let Some(last) = api_messages.last_mut() {
920                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
921                            if let Some(arr) =
922                                last.get_mut("content").and_then(|c| c.as_array_mut())
923                            {
924                                arr.extend(content_parts);
925                                continue;
926                            }
927                        }
928                    }
929                    api_messages.push(json!({
930                        "role": "assistant",
931                        "content": content_parts
932                    }));
933                }
934                Role::Tool => {
935                    // Tool results must be in a "user" message with toolResult blocks.
936                    // Merge into the previous user message if one exists (handles
937                    // consecutive Tool messages being collapsed into one user turn).
938                    let mut content_parts: Vec<Value> = Vec::new();
939                    for part in &msg.content {
940                        if let ContentPart::ToolResult {
941                            tool_call_id,
942                            content,
943                        } = part
944                        {
945                            content_parts.push(json!({
946                                "toolResult": {
947                                    "toolUseId": tool_call_id,
948                                    "content": [{"text": content}],
949                                    "status": "success"
950                                }
951                            }));
952                        }
953                    }
954                    if !content_parts.is_empty() {
955                        // Merge into previous user message (from earlier Tool messages)
956                        if let Some(last) = api_messages.last_mut() {
957                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
958                                if let Some(arr) =
959                                    last.get_mut("content").and_then(|c| c.as_array_mut())
960                                {
961                                    arr.extend(content_parts);
962                                    continue;
963                                }
964                            }
965                        }
966                        api_messages.push(json!({
967                            "role": "user",
968                            "content": content_parts
969                        }));
970                    }
971                }
972            }
973        }
974
975        (system_parts, api_messages)
976    }
977
978    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
979        tools
980            .iter()
981            .map(|t| {
982                json!({
983                    "toolSpec": {
984                        "name": t.name,
985                        "description": t.description,
986                        "inputSchema": {
987                            "json": t.parameters
988                        }
989                    }
990                })
991            })
992            .collect()
993    }
994}
995
996/// Bedrock Converse API response types
997
/// Top-level body of a successful Converse API response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    /// Container for the assistant's reply message.
    output: ConverseOutput,
    /// Why generation stopped (e.g. "end_turn", "tool_use", "max_tokens");
    /// defaulted to None when the field is missing.
    #[serde(default)]
    stop_reason: Option<String>,
    /// Token accounting; optional because not every response includes it.
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

/// Wrapper for the "output" field, which holds the assistant message.
#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

/// An assistant message: a role plus an ordered list of content blocks.
#[derive(Debug, Deserialize)]
struct ConverseMessage {
    /// Always "assistant" in responses; kept for completeness, unused.
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}
1019
/// One content block inside an assistant message.
///
/// NOTE: deserialized with `#[serde(untagged)]`, so serde tries the variants
/// top-to-bottom and keeps the first whose shape matches the JSON. Variant
/// order therefore matters: the more specific object shapes
/// (`reasoningContent`, `toolUse`) must come before/around the generic
/// `text` variant.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    /// Extended-thinking block: `{"reasoningContent": {...}}`.
    ReasoningContent {
        #[serde(rename = "reasoningContent")]
        reasoning_content: ReasoningContentBlock,
    },
    /// Plain text block: `{"text": "..."}`.
    Text {
        text: String,
    },
    /// Tool invocation requested by the model: `{"toolUse": {...}}`.
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

/// Inner payload of a reasoningContent block.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentBlock {
    reasoning_text: ReasoningText,
}

/// The actual reasoning text within a reasoningContent block.
#[derive(Debug, Deserialize)]
struct ReasoningText {
    text: String,
}

/// A tool call emitted by the model: id, tool name, and JSON input.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    /// Arbitrary JSON arguments matching the tool's input schema.
    input: Value,
}
1054
/// Token usage reported by the Converse API.
/// All fields default to 0 if absent from the response.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

/// Error body returned by Bedrock on non-2xx responses.
#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}
1070
1071#[async_trait]
1072impl Provider for BedrockProvider {
1073    fn name(&self) -> &str {
1074        "bedrock"
1075    }
1076
1077    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
1078        self.validate_auth()?;
1079        self.discover_models().await
1080    }
1081
1082    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
1083        let model_id = Self::resolve_model_id(&request.model);
1084
1085        tracing::debug!(
1086            provider = "bedrock",
1087            model = %model_id,
1088            original_model = %request.model,
1089            message_count = request.messages.len(),
1090            tool_count = request.tools.len(),
1091            "Starting Bedrock Converse request"
1092        );
1093
1094        self.validate_auth()?;
1095
1096        let (system_parts, messages) = Self::convert_messages(&request.messages);
1097        let tools = Self::convert_tools(&request.tools);
1098
1099        let mut body = json!({
1100            "messages": messages,
1101        });
1102
1103        if !system_parts.is_empty() {
1104            body["system"] = json!(system_parts);
1105        }
1106
1107        // inferenceConfig
1108        let mut inference_config = json!({});
1109        if let Some(max_tokens) = request.max_tokens {
1110            inference_config["maxTokens"] = json!(max_tokens);
1111        } else {
1112            inference_config["maxTokens"] = json!(8192);
1113        }
1114        if let Some(temp) = request.temperature {
1115            inference_config["temperature"] = json!(temp);
1116        }
1117        if let Some(top_p) = request.top_p {
1118            inference_config["topP"] = json!(top_p);
1119        }
1120        body["inferenceConfig"] = inference_config;
1121
1122        if !tools.is_empty() {
1123            body["toolConfig"] = json!({"tools": tools});
1124        }
1125
1126        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
1127        let encoded_model_id = model_id.replace(':', "%3A");
1128        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
1129        tracing::debug!("Bedrock request URL: {}", url);
1130
1131        let body_bytes = serde_json::to_vec(&body)?;
1132        let response = self
1133            .send_request("POST", &url, Some(&body_bytes), "bedrock-runtime")
1134            .await?;
1135
1136        let status = response.status();
1137        let text = response
1138            .text()
1139            .await
1140            .context("Failed to read Bedrock response")?;
1141
1142        if !status.is_success() {
1143            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
1144                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
1145            }
1146            anyhow::bail!(
1147                "Bedrock API error: {} {}",
1148                status,
1149                &text[..text.len().min(500)]
1150            );
1151        }
1152
1153        let response: ConverseResponse = serde_json::from_str(&text).context(format!(
1154            "Failed to parse Bedrock response: {}",
1155            &text[..text.len().min(300)]
1156        ))?;
1157
1158        tracing::debug!(
1159            stop_reason = ?response.stop_reason,
1160            "Received Bedrock response"
1161        );
1162
1163        let mut content = Vec::new();
1164        let mut has_tool_calls = false;
1165
1166        for part in &response.output.message.content {
1167            match part {
1168                ConverseContent::ReasoningContent { reasoning_content } => {
1169                    if !reasoning_content.reasoning_text.text.is_empty() {
1170                        content.push(ContentPart::Thinking {
1171                            text: reasoning_content.reasoning_text.text.clone(),
1172                        });
1173                    }
1174                }
1175                ConverseContent::Text { text } => {
1176                    if !text.is_empty() {
1177                        content.push(ContentPart::Text { text: text.clone() });
1178                    }
1179                }
1180                ConverseContent::ToolUse { tool_use } => {
1181                    has_tool_calls = true;
1182                    content.push(ContentPart::ToolCall {
1183                        id: tool_use.tool_use_id.clone(),
1184                        name: tool_use.name.clone(),
1185                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1186                    });
1187                }
1188            }
1189        }
1190
1191        let finish_reason = if has_tool_calls {
1192            FinishReason::ToolCalls
1193        } else {
1194            match response.stop_reason.as_deref() {
1195                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
1196                Some("max_tokens") => FinishReason::Length,
1197                Some("tool_use") => FinishReason::ToolCalls,
1198                Some("content_filtered") => FinishReason::ContentFilter,
1199                _ => FinishReason::Stop,
1200            }
1201        };
1202
1203        let usage = response.usage.as_ref();
1204
1205        Ok(CompletionResponse {
1206            message: Message {
1207                role: Role::Assistant,
1208                content,
1209            },
1210            usage: Usage {
1211                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
1212                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
1213                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
1214                cache_read_tokens: None,
1215                cache_write_tokens: None,
1216            },
1217            finish_reason,
1218        })
1219    }
1220
1221    async fn complete_stream(
1222        &self,
1223        request: CompletionRequest,
1224    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
1225        // Fall back to non-streaming for now
1226        let response = self.complete(request).await?;
1227        let text = response
1228            .message
1229            .content
1230            .iter()
1231            .filter_map(|p| match p {
1232                ContentPart::Text { text } => Some(text.clone()),
1233                _ => None,
1234            })
1235            .collect::<Vec<_>>()
1236            .join("");
1237
1238        Ok(Box::pin(futures::stream::once(async move {
1239            StreamChunk::Text(text)
1240        })))
1241    }
1242}