1use super::{
13 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
14 Role, StreamChunk, ToolDefinition, Usage,
15};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use hmac::{Hmac, Mac};
19use reqwest::Client;
20use serde::Deserialize;
21use serde_json::{Value, json};
22use sha2::{Digest, Sha256};
23use std::collections::HashMap;
24
/// Region that `BedrockProvider::new` passes to `with_region` when the caller
/// does not choose one.
pub const DEFAULT_REGION: &str = "us-east-1";
26
/// AWS credentials used for SigV4 request signing.
///
/// `Debug` is implemented by hand (instead of derived) so that the secret
/// access key and the session token never end up in logs or panic messages.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID; it appears in the `Credential=` part of the
    /// `Authorization` header.
    pub access_key_id: String,
    /// Secret access key; the SigV4 signing key is derived from it.
    pub secret_access_key: String,
    /// Optional session token, sent as `x-amz-security-token` when present.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Only reveal whether a token exists, never its value.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
34
35impl AwsCredentials {
36 pub fn from_environment() -> Option<Self> {
39 if let (Ok(key_id), Ok(secret)) = (
41 std::env::var("AWS_ACCESS_KEY_ID"),
42 std::env::var("AWS_SECRET_ACCESS_KEY"),
43 ) {
44 if !key_id.is_empty() && !secret.is_empty() {
45 return Some(Self {
46 access_key_id: key_id,
47 secret_access_key: secret,
48 session_token: std::env::var("AWS_SESSION_TOKEN")
49 .ok()
50 .filter(|s| !s.is_empty()),
51 });
52 }
53 }
54
55 let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
57 Self::from_credentials_file(&profile)
58 }
59
60 fn from_credentials_file(profile: &str) -> Option<Self> {
62 let home = std::env::var("HOME")
63 .or_else(|_| std::env::var("USERPROFILE"))
64 .ok()?;
65 let path = std::path::Path::new(&home).join(".aws").join("credentials");
66 let content = std::fs::read_to_string(&path).ok()?;
67
68 let section_header = format!("[{}]", profile);
69 let mut in_section = false;
70 let mut key_id = None;
71 let mut secret = None;
72 let mut token = None;
73
74 for line in content.lines() {
75 let trimmed = line.trim();
76 if trimmed.starts_with('[') {
77 in_section = trimmed == section_header;
78 continue;
79 }
80 if !in_section {
81 continue;
82 }
83 if let Some((k, v)) = trimmed.split_once('=') {
84 let k = k.trim();
85 let v = v.trim();
86 match k {
87 "aws_access_key_id" => key_id = Some(v.to_string()),
88 "aws_secret_access_key" => secret = Some(v.to_string()),
89 "aws_session_token" => token = Some(v.to_string()),
90 _ => {}
91 }
92 }
93 }
94
95 Some(Self {
96 access_key_id: key_id?,
97 secret_access_key: secret?,
98 session_token: token,
99 })
100 }
101
102 pub fn detect_region() -> Option<String> {
105 if let Ok(r) = std::env::var("AWS_REGION") {
106 if !r.is_empty() {
107 return Some(r);
108 }
109 }
110 if let Ok(r) = std::env::var("AWS_DEFAULT_REGION") {
111 if !r.is_empty() {
112 return Some(r);
113 }
114 }
115 let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
117 let home = std::env::var("HOME")
118 .or_else(|_| std::env::var("USERPROFILE"))
119 .ok()?;
120 let path = std::path::Path::new(&home).join(".aws").join("config");
121 let content = std::fs::read_to_string(&path).ok()?;
122
123 let section_header = if profile == "default" {
125 "[default]".to_string()
126 } else {
127 format!("[profile {}]", profile)
128 };
129 let mut in_section = false;
130 for line in content.lines() {
131 let trimmed = line.trim();
132 if trimmed.starts_with('[') {
133 in_section = trimmed == section_header;
134 continue;
135 }
136 if !in_section {
137 continue;
138 }
139 if let Some((k, v)) = trimmed.split_once('=') {
140 if k.trim() == "region" {
141 let v = v.trim();
142 if !v.is_empty() {
143 return Some(v.to_string());
144 }
145 }
146 }
147 }
148 None
149 }
150}
151
152#[derive(Debug, Clone)]
154pub enum BedrockAuth {
155 SigV4(AwsCredentials),
157 BearerToken(String),
159}
160
/// Provider that talks to the Amazon Bedrock HTTP endpoints.
pub struct BedrockProvider {
    /// HTTP client used for every request.
    client: Client,
    /// Authentication mode: SigV4 credentials or a bearer token.
    auth: BedrockAuth,
    /// AWS region interpolated into the endpoint host names and the SigV4
    /// credential scope.
    region: String,
}
166
167impl std::fmt::Debug for BedrockProvider {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("BedrockProvider")
170 .field(
171 "auth",
172 &match &self.auth {
173 BedrockAuth::SigV4(_) => "SigV4",
174 BedrockAuth::BearerToken(_) => "BearerToken",
175 },
176 )
177 .field("region", &self.region)
178 .finish()
179 }
180}
181
182impl BedrockProvider {
183 pub fn new(api_key: String) -> Result<Self> {
185 Self::with_region(api_key, DEFAULT_REGION.to_string())
186 }
187
188 pub fn with_region(api_key: String, region: String) -> Result<Self> {
190 tracing::debug!(
191 provider = "bedrock",
192 region = %region,
193 auth = "bearer_token",
194 "Creating Bedrock provider"
195 );
196 Ok(Self {
197 client: Client::new(),
198 auth: BedrockAuth::BearerToken(api_key),
199 region,
200 })
201 }
202
203 pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
205 tracing::debug!(
206 provider = "bedrock",
207 region = %region,
208 auth = "sigv4",
209 "Creating Bedrock provider with AWS credentials"
210 );
211 Ok(Self {
212 client: Client::new(),
213 auth: BedrockAuth::SigV4(credentials),
214 region,
215 })
216 }
217
218 pub async fn send_converse_request(
221 &self,
222 url: &str,
223 body: &[u8],
224 ) -> Result<reqwest::Response> {
225 self.send_request("POST", url, Some(body), "bedrock-runtime")
226 .await
227 }
228
229 fn validate_auth(&self) -> Result<()> {
230 match &self.auth {
231 BedrockAuth::BearerToken(key) => {
232 if key.is_empty() {
233 anyhow::bail!("Bedrock API key is empty");
234 }
235 }
236 BedrockAuth::SigV4(creds) => {
237 if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
238 anyhow::bail!("AWS credentials are incomplete");
239 }
240 }
241 }
242 Ok(())
243 }
244
245 fn base_url(&self) -> String {
246 format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
247 }
248
249 fn management_url(&self) -> String {
251 format!("https://bedrock.{}.amazonaws.com", self.region)
252 }
253
254 fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
257 let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
258 mac.update(data);
259 mac.finalize().into_bytes().to_vec()
260 }
261
262 fn sha256_hex(data: &[u8]) -> String {
263 let mut hasher = Sha256::new();
264 hasher.update(data);
265 hex::encode(hasher.finalize())
266 }
267
268 async fn send_signed_request(
270 &self,
271 method: &str,
272 url: &str,
273 body: &[u8],
274 service: &str,
275 ) -> Result<reqwest::Response> {
276 let creds = match &self.auth {
277 BedrockAuth::SigV4(c) => c,
278 BedrockAuth::BearerToken(_) => {
279 anyhow::bail!("send_signed_request called with bearer token auth");
280 }
281 };
282
283 let now = chrono::Utc::now();
284 let datestamp = now.format("%Y%m%d").to_string();
285 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
286
287 let host_start = url.find("://").map(|i| i + 3).unwrap_or(0);
289 let after_host = url[host_start..]
290 .find('/')
291 .map(|i| host_start + i)
292 .unwrap_or(url.len());
293 let host = url[host_start..after_host].to_string();
294 let path_and_query = &url[after_host..];
295 let (canonical_uri, canonical_querystring) = match path_and_query.split_once('?') {
296 Some((p, q)) => (p.to_string(), q.to_string()),
297 None => (path_and_query.to_string(), String::new()),
298 };
299
300 let payload_hash = Self::sha256_hex(body);
301
302 let mut headers_map: Vec<(&str, String)> = vec![
304 ("content-type", "application/json".to_string()),
305 ("host", host.clone()),
306 ("x-amz-date", amz_date.clone()),
307 ];
308 if let Some(token) = &creds.session_token {
309 headers_map.push(("x-amz-security-token", token.clone()));
310 }
311 headers_map.sort_by_key(|(k, _)| *k);
312
313 let canonical_headers: String = headers_map
314 .iter()
315 .map(|(k, v)| format!("{}:{}", k, v))
316 .collect::<Vec<_>>()
317 .join("\n")
318 + "\n";
319
320 let signed_headers: String = headers_map
321 .iter()
322 .map(|(k, _)| *k)
323 .collect::<Vec<_>>()
324 .join(";");
325
326 let canonical_request = format!(
327 "{}\n{}\n{}\n{}\n{}\n{}",
328 method,
329 canonical_uri,
330 canonical_querystring,
331 canonical_headers,
332 signed_headers,
333 payload_hash
334 );
335
336 let credential_scope = format!("{}/{}/{}/aws4_request", datestamp, self.region, service);
337
338 let string_to_sign = format!(
339 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
340 amz_date,
341 credential_scope,
342 Self::sha256_hex(canonical_request.as_bytes())
343 );
344
345 let k_date = Self::hmac_sha256(
347 format!("AWS4{}", creds.secret_access_key).as_bytes(),
348 datestamp.as_bytes(),
349 );
350 let k_region = Self::hmac_sha256(&k_date, self.region.as_bytes());
351 let k_service = Self::hmac_sha256(&k_region, service.as_bytes());
352 let k_signing = Self::hmac_sha256(&k_service, b"aws4_request");
353
354 let signature = hex::encode(Self::hmac_sha256(&k_signing, string_to_sign.as_bytes()));
355
356 let authorization = format!(
357 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
358 creds.access_key_id, credential_scope, signed_headers, signature
359 );
360
361 let mut req = self
362 .client
363 .request(method.parse().unwrap_or(reqwest::Method::POST), url)
364 .header("content-type", "application/json")
365 .header("host", &host)
366 .header("x-amz-date", &amz_date)
367 .header("x-amz-content-sha256", &payload_hash)
368 .header("authorization", &authorization);
369
370 if let Some(token) = &creds.session_token {
371 req = req.header("x-amz-security-token", token);
372 }
373
374 if method == "POST" || method == "PUT" {
375 req = req.body(body.to_vec());
376 }
377
378 req.send()
379 .await
380 .context("Failed to send signed request to Bedrock")
381 }
382
383 async fn send_request(
385 &self,
386 method: &str,
387 url: &str,
388 body: Option<&[u8]>,
389 service: &str,
390 ) -> Result<reqwest::Response> {
391 match &self.auth {
392 BedrockAuth::SigV4(_) => {
393 self.send_signed_request(method, url, body.unwrap_or(b""), service)
394 .await
395 }
396 BedrockAuth::BearerToken(token) => {
397 let mut req = self
398 .client
399 .request(method.parse().unwrap_or(reqwest::Method::GET), url)
400 .bearer_auth(token)
401 .header("content-type", "application/json")
402 .header("accept", "application/json");
403
404 if let Some(b) = body {
405 req = req.body(b.to_vec());
406 }
407
408 req.send()
409 .await
410 .context("Failed to send request to Bedrock")
411 }
412 }
413 }
414
415 fn resolve_model_id(model: &str) -> &str {
419 match model {
420 "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
422 "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
423 "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
424 "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
425 "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
426 "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
427 }
428 "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
429 "claude-haiku-4.5" | "claude-4.5-haiku" => {
430 "us.anthropic.claude-haiku-4-5-20251001-v1:0"
431 }
432 "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
433 "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
434 }
435 "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
436 "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
437 }
438 "claude-3.5-haiku" | "claude-haiku-3.5" => {
439 "us.anthropic.claude-3-5-haiku-20241022-v1:0"
440 }
441 "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
442 "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
443 }
444 "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
445 "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
446 "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",
447
448 "nova-pro" => "amazon.nova-pro-v1:0",
450 "nova-lite" => "amazon.nova-lite-v1:0",
451 "nova-micro" => "amazon.nova-micro-v1:0",
452 "nova-premier" => "us.amazon.nova-premier-v1:0",
453
454 "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
456 "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
457 "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
458 "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
459 "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
460 "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
461 "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
462 "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
463 "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
464 "llama-3-70b" | "llama3-70b" => "meta.llama3-70b-instruct-v1:0",
465 "llama-3-8b" | "llama3-8b" => "meta.llama3-8b-instruct-v1:0",
466
467 "mistral-large-3" | "mistral-large" => "mistral.mistral-large-3-675b-instruct",
469 "mistral-large-2402" => "mistral.mistral-large-2402-v1:0",
470 "mistral-small" => "mistral.mistral-small-2402-v1:0",
471 "mixtral-8x7b" => "mistral.mixtral-8x7b-instruct-v0:1",
472 "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
473 "magistral-small" => "mistral.magistral-small-2509",
474
475 "deepseek-r1" => "us.deepseek.r1-v1:0",
477 "deepseek-v3" | "deepseek-v3.2" => "deepseek.v3.2",
478
479 "command-r" => "cohere.command-r-v1:0",
481 "command-r-plus" => "cohere.command-r-plus-v1:0",
482
483 "qwen3-32b" => "qwen.qwen3-32b-v1:0",
485 "qwen3-coder" | "qwen3-coder-next" => "qwen.qwen3-coder-next",
486 "qwen3-coder-30b" => "qwen.qwen3-coder-30b-a3b-v1:0",
487
488 "gemma-3-27b" => "google.gemma-3-27b-it",
490 "gemma-3-12b" => "google.gemma-3-12b-it",
491 "gemma-3-4b" => "google.gemma-3-4b-it",
492
493 "kimi-k2" | "kimi-k2-thinking" => "moonshot.kimi-k2-thinking",
495 "kimi-k2.5" => "moonshotai.kimi-k2.5",
496
497 "jamba-1.5-large" => "ai21.jamba-1-5-large-v1:0",
499 "jamba-1.5-mini" => "ai21.jamba-1-5-mini-v1:0",
500
501 "minimax-m2" => "minimax.minimax-m2",
503 "minimax-m2.1" => "minimax.minimax-m2.1",
504
505 "nemotron-nano-30b" => "nvidia.nemotron-nano-3-30b",
507 "nemotron-nano-12b" => "nvidia.nemotron-nano-12b-v2",
508 "nemotron-nano-9b" => "nvidia.nemotron-nano-9b-v2",
509
510 "glm-5" => "zai.glm-5",
512 "glm-4.7" => "zai.glm-4.7",
513 "glm-4.7-flash" => "zai.glm-4.7-flash",
514
515 other => other,
517 }
518 }
519
520 async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
523 let mut models: HashMap<String, ModelInfo> = HashMap::new();
524
525 let fm_url = format!("{}/foundation-models", self.management_url());
527 let fm_resp = self.send_request("GET", &fm_url, None, "bedrock").await;
528
529 if let Ok(resp) = fm_resp {
530 if resp.status().is_success() {
531 if let Ok(data) = resp.json::<Value>().await {
532 if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
533 for m in summaries {
534 let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
535 let model_name =
536 m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
537 let provider_name =
538 m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");
539
540 let output_modalities: Vec<&str> = m
541 .get("outputModalities")
542 .and_then(|v| v.as_array())
543 .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
544 .unwrap_or_default();
545
546 let input_modalities: Vec<&str> = m
547 .get("inputModalities")
548 .and_then(|v| v.as_array())
549 .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
550 .unwrap_or_default();
551
552 let inference_types: Vec<&str> = m
553 .get("inferenceTypesSupported")
554 .and_then(|v| v.as_array())
555 .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
556 .unwrap_or_default();
557
558 if !output_modalities.contains(&"TEXT")
560 || (!inference_types.contains(&"ON_DEMAND")
561 && !inference_types.contains(&"INFERENCE_PROFILE"))
562 {
563 continue;
564 }
565
566 let name_lower = model_name.to_lowercase();
568 if name_lower.contains("rerank")
569 || name_lower.contains("embed")
570 || name_lower.contains("safeguard")
571 || name_lower.contains("sonic")
572 || name_lower.contains("pegasus")
573 {
574 continue;
575 }
576
577 let streaming = m
578 .get("responseStreamingSupported")
579 .and_then(|v| v.as_bool())
580 .unwrap_or(false);
581 let vision = input_modalities.contains(&"IMAGE");
582
583 let actual_id = if model_id.starts_with("amazon.") {
587 model_id.to_string()
588 } else if inference_types.contains(&"INFERENCE_PROFILE") {
589 format!("us.{}", model_id)
590 } else {
591 model_id.to_string()
592 };
593
594 let display_name = format!("{} (Bedrock)", model_name);
595
596 models.insert(
597 actual_id.clone(),
598 ModelInfo {
599 id: actual_id,
600 name: display_name,
601 provider: "bedrock".to_string(),
602 context_window: Self::estimate_context_window(
603 model_id,
604 provider_name,
605 ),
606 max_output_tokens: Some(Self::estimate_max_output(
607 model_id,
608 provider_name,
609 )),
610 supports_vision: vision,
611 supports_tools: true,
612 supports_streaming: streaming,
613 input_cost_per_million: None,
614 output_cost_per_million: None,
615 },
616 );
617 }
618 }
619 }
620 }
621 }
622
623 let ip_url = format!(
626 "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
627 self.management_url()
628 );
629 let ip_resp = self.send_request("GET", &ip_url, None, "bedrock").await;
630
631 if let Ok(resp) = ip_resp {
632 if resp.status().is_success() {
633 if let Ok(data) = resp.json::<Value>().await {
634 if let Some(profiles) = data
635 .get("inferenceProfileSummaries")
636 .and_then(|v| v.as_array())
637 {
638 for p in profiles {
639 let pid = p
640 .get("inferenceProfileId")
641 .and_then(|v| v.as_str())
642 .unwrap_or("");
643 let pname = p
644 .get("inferenceProfileName")
645 .and_then(|v| v.as_str())
646 .unwrap_or("");
647
648 if !pid.starts_with("us.") {
650 continue;
651 }
652
653 if models.contains_key(pid) {
655 continue;
656 }
657
658 let name_lower = pname.to_lowercase();
660 if name_lower.contains("image")
661 || name_lower.contains("stable ")
662 || name_lower.contains("upscale")
663 || name_lower.contains("embed")
664 || name_lower.contains("marengo")
665 || name_lower.contains("outpaint")
666 || name_lower.contains("inpaint")
667 || name_lower.contains("erase")
668 || name_lower.contains("recolor")
669 || name_lower.contains("replace")
670 || name_lower.contains("style ")
671 || name_lower.contains("background")
672 || name_lower.contains("sketch")
673 || name_lower.contains("control")
674 || name_lower.contains("transfer")
675 || name_lower.contains("sonic")
676 || name_lower.contains("pegasus")
677 || name_lower.contains("rerank")
678 {
679 continue;
680 }
681
682 let vision = pid.contains("llama3-2-11b")
684 || pid.contains("llama3-2-90b")
685 || pid.contains("pixtral")
686 || pid.contains("claude-3")
687 || pid.contains("claude-sonnet-4")
688 || pid.contains("claude-opus-4")
689 || pid.contains("claude-haiku-4");
690
691 let display_name = pname.replace("US ", "");
692 let display_name = format!("{} (Bedrock)", display_name.trim());
693
694 let provider_hint = pid
696 .strip_prefix("us.")
697 .unwrap_or(pid)
698 .split('.')
699 .next()
700 .unwrap_or("");
701
702 models.insert(
703 pid.to_string(),
704 ModelInfo {
705 id: pid.to_string(),
706 name: display_name,
707 provider: "bedrock".to_string(),
708 context_window: Self::estimate_context_window(
709 pid,
710 provider_hint,
711 ),
712 max_output_tokens: Some(Self::estimate_max_output(
713 pid,
714 provider_hint,
715 )),
716 supports_vision: vision,
717 supports_tools: true,
718 supports_streaming: true,
719 input_cost_per_million: None,
720 output_cost_per_million: None,
721 },
722 );
723 }
724 }
725 }
726 }
727 }
728
729 let mut result: Vec<ModelInfo> = models.into_values().collect();
730 result.sort_by(|a, b| a.id.cmp(&b.id));
731
732 tracing::info!(
733 provider = "bedrock",
734 model_count = result.len(),
735 "Discovered Bedrock models dynamically"
736 );
737
738 Ok(result)
739 }
740
741 fn estimate_context_window(model_id: &str, provider: &str) -> usize {
743 let id = model_id.to_lowercase();
744 if id.contains("anthropic") || id.contains("claude") {
745 200_000
746 } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
747 {
748 300_000
749 } else if id.contains("nova-micro") || id.contains("nova-2") {
750 128_000
751 } else if id.contains("deepseek") {
752 128_000
753 } else if id.contains("llama4") {
754 256_000
755 } else if id.contains("llama3") {
756 128_000
757 } else if id.contains("mistral-large-3") || id.contains("magistral") {
758 128_000
759 } else if id.contains("mistral") {
760 32_000
761 } else if id.contains("qwen") {
762 128_000
763 } else if id.contains("kimi") {
764 128_000
765 } else if id.contains("jamba") {
766 256_000
767 } else if id.contains("glm") {
768 128_000
769 } else if id.contains("minimax") {
770 128_000
771 } else if id.contains("gemma") {
772 128_000
773 } else if id.contains("cohere") || id.contains("command") {
774 128_000
775 } else if id.contains("nemotron") {
776 128_000
777 } else if provider.to_lowercase().contains("amazon") {
778 128_000
779 } else {
780 32_000
781 }
782 }
783
784 fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
786 let id = model_id.to_lowercase();
787 if id.contains("claude-opus-4-6") {
788 32_000
789 } else if id.contains("claude-opus-4-5") {
790 32_000
791 } else if id.contains("claude-opus-4-1") {
792 32_000
793 } else if id.contains("claude-sonnet-4-5")
794 || id.contains("claude-sonnet-4")
795 || id.contains("claude-3-7")
796 {
797 64_000
798 } else if id.contains("claude-haiku-4-5") {
799 16_384
800 } else if id.contains("claude-opus-4") {
801 32_000
802 } else if id.contains("claude") {
803 8_192
804 } else if id.contains("nova") {
805 5_000
806 } else if id.contains("deepseek") {
807 16_384
808 } else if id.contains("llama4") {
809 16_384
810 } else if id.contains("llama") {
811 4_096
812 } else if id.contains("mistral-large-3") {
813 16_384
814 } else if id.contains("mistral") || id.contains("mixtral") {
815 8_192
816 } else if id.contains("qwen") {
817 8_192
818 } else if id.contains("kimi") {
819 8_192
820 } else if id.contains("jamba") {
821 4_096
822 } else {
823 4_096
824 }
825 }
826
827 fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
841 let mut system_parts: Vec<Value> = Vec::new();
842 let mut api_messages: Vec<Value> = Vec::new();
843
844 for msg in messages {
845 match msg.role {
846 Role::System => {
847 let text: String = msg
848 .content
849 .iter()
850 .filter_map(|p| match p {
851 ContentPart::Text { text } => Some(text.clone()),
852 _ => None,
853 })
854 .collect::<Vec<_>>()
855 .join("\n");
856 system_parts.push(json!({"text": text}));
857 }
858 Role::User => {
859 let mut content_parts: Vec<Value> = Vec::new();
860 for part in &msg.content {
861 match part {
862 ContentPart::Text { text } => {
863 if !text.is_empty() {
864 content_parts.push(json!({"text": text}));
865 }
866 }
867 _ => {}
868 }
869 }
870 if !content_parts.is_empty() {
871 if let Some(last) = api_messages.last_mut() {
873 if last.get("role").and_then(|r| r.as_str()) == Some("user") {
874 if let Some(arr) =
875 last.get_mut("content").and_then(|c| c.as_array_mut())
876 {
877 arr.extend(content_parts);
878 continue;
879 }
880 }
881 }
882 api_messages.push(json!({
883 "role": "user",
884 "content": content_parts
885 }));
886 }
887 }
888 Role::Assistant => {
889 let mut content_parts: Vec<Value> = Vec::new();
890 for part in &msg.content {
891 match part {
892 ContentPart::Text { text } => {
893 if !text.is_empty() {
894 content_parts.push(json!({"text": text}));
895 }
896 }
897 ContentPart::ToolCall {
898 id,
899 name,
900 arguments,
901 } => {
902 let input: Value = serde_json::from_str(arguments)
903 .unwrap_or_else(|_| json!({"raw": arguments}));
904 content_parts.push(json!({
905 "toolUse": {
906 "toolUseId": id,
907 "name": name,
908 "input": input
909 }
910 }));
911 }
912 _ => {}
913 }
914 }
915 if content_parts.is_empty() {
916 content_parts.push(json!({"text": " "}));
917 }
918 if let Some(last) = api_messages.last_mut() {
920 if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
921 if let Some(arr) =
922 last.get_mut("content").and_then(|c| c.as_array_mut())
923 {
924 arr.extend(content_parts);
925 continue;
926 }
927 }
928 }
929 api_messages.push(json!({
930 "role": "assistant",
931 "content": content_parts
932 }));
933 }
934 Role::Tool => {
935 let mut content_parts: Vec<Value> = Vec::new();
939 for part in &msg.content {
940 if let ContentPart::ToolResult {
941 tool_call_id,
942 content,
943 } = part
944 {
945 content_parts.push(json!({
946 "toolResult": {
947 "toolUseId": tool_call_id,
948 "content": [{"text": content}],
949 "status": "success"
950 }
951 }));
952 }
953 }
954 if !content_parts.is_empty() {
955 if let Some(last) = api_messages.last_mut() {
957 if last.get("role").and_then(|r| r.as_str()) == Some("user") {
958 if let Some(arr) =
959 last.get_mut("content").and_then(|c| c.as_array_mut())
960 {
961 arr.extend(content_parts);
962 continue;
963 }
964 }
965 }
966 api_messages.push(json!({
967 "role": "user",
968 "content": content_parts
969 }));
970 }
971 }
972 }
973 }
974
975 (system_parts, api_messages)
976 }
977
978 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
979 tools
980 .iter()
981 .map(|t| {
982 json!({
983 "toolSpec": {
984 "name": t.name,
985 "description": t.description,
986 "inputSchema": {
987 "json": t.parameters
988 }
989 }
990 })
991 })
992 .collect()
993 }
994}
995
/// Top-level body of a successful Converse response, as parsed by `complete`.
/// Field names are read in camelCase (`stopReason`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    /// The generated message.
    output: ConverseOutput,
    /// Why generation stopped; `None` when the field is absent.
    #[serde(default)]
    stop_reason: Option<String>,
    /// Token accounting; `None` when the field is absent.
    #[serde(default)]
    usage: Option<ConverseUsage>,
}
1007
/// `output` object of a Converse response; wraps the generated message.
#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}
1012
/// A message returned by the model: its role plus an ordered list of content
/// blocks.
#[derive(Debug, Deserialize)]
struct ConverseMessage {
    // Deserialized but never read by this file, hence the lint allowance.
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}
1019
/// One content block of a response message.
///
/// The enum is `untagged`, so a block is recognised by which key it carries
/// (`reasoningContent`, `text` or `toolUse`).
///
/// NOTE(review): there is no catch-all variant, so a block that matches none
/// of these shapes presumably makes deserialization of the whole response
/// fail — confirm whether other block types can occur.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    /// Model reasoning; `complete` maps it to `ContentPart::Thinking`.
    ReasoningContent {
        #[serde(rename = "reasoningContent")]
        reasoning_content: ReasoningContentBlock,
    },
    /// Plain text output.
    Text {
        text: String,
    },
    /// A tool invocation requested by the model.
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}
1035
/// Payload of a `reasoningContent` block; the text lives under the camelCase
/// key `reasoningText`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ReasoningContentBlock {
    reasoning_text: ReasoningText,
}
1041
/// Text of a reasoning block.
#[derive(Debug, Deserialize)]
struct ReasoningText {
    text: String,
}
1046
/// A tool call emitted by the model (`toolUseId`, `name`, `input`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    /// Identifier of the call; `complete` copies it into `ContentPart::ToolCall::id`.
    tool_use_id: String,
    /// Name of the tool to invoke.
    name: String,
    /// Tool arguments as arbitrary JSON; `complete` re-serialises them to a string.
    input: Value,
}
1054
/// Token usage reported with a response (`inputTokens`, `outputTokens`,
/// `totalTokens`); any missing counter defaults to 0.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}
1065
/// Error body shape that `complete` tries to parse on a non-success status.
///
/// NOTE(review): only a lower-case `message` key is accepted; bodies using a
/// different key fall through to the raw-text error path in `complete` —
/// confirm whether an alias such as `Message` is needed.
#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}
1070
1071#[async_trait]
1072impl Provider for BedrockProvider {
1073 fn name(&self) -> &str {
1074 "bedrock"
1075 }
1076
1077 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
1078 self.validate_auth()?;
1079 self.discover_models().await
1080 }
1081
1082 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
1083 let model_id = Self::resolve_model_id(&request.model);
1084
1085 tracing::debug!(
1086 provider = "bedrock",
1087 model = %model_id,
1088 original_model = %request.model,
1089 message_count = request.messages.len(),
1090 tool_count = request.tools.len(),
1091 "Starting Bedrock Converse request"
1092 );
1093
1094 self.validate_auth()?;
1095
1096 let (system_parts, messages) = Self::convert_messages(&request.messages);
1097 let tools = Self::convert_tools(&request.tools);
1098
1099 let mut body = json!({
1100 "messages": messages,
1101 });
1102
1103 if !system_parts.is_empty() {
1104 body["system"] = json!(system_parts);
1105 }
1106
1107 let mut inference_config = json!({});
1109 if let Some(max_tokens) = request.max_tokens {
1110 inference_config["maxTokens"] = json!(max_tokens);
1111 } else {
1112 inference_config["maxTokens"] = json!(8192);
1113 }
1114 if let Some(temp) = request.temperature {
1115 inference_config["temperature"] = json!(temp);
1116 }
1117 if let Some(top_p) = request.top_p {
1118 inference_config["topP"] = json!(top_p);
1119 }
1120 body["inferenceConfig"] = inference_config;
1121
1122 if !tools.is_empty() {
1123 body["toolConfig"] = json!({"tools": tools});
1124 }
1125
1126 let encoded_model_id = model_id.replace(':', "%3A");
1128 let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
1129 tracing::debug!("Bedrock request URL: {}", url);
1130
1131 let body_bytes = serde_json::to_vec(&body)?;
1132 let response = self
1133 .send_request("POST", &url, Some(&body_bytes), "bedrock-runtime")
1134 .await?;
1135
1136 let status = response.status();
1137 let text = response
1138 .text()
1139 .await
1140 .context("Failed to read Bedrock response")?;
1141
1142 if !status.is_success() {
1143 if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
1144 anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
1145 }
1146 anyhow::bail!(
1147 "Bedrock API error: {} {}",
1148 status,
1149 &text[..text.len().min(500)]
1150 );
1151 }
1152
1153 let response: ConverseResponse = serde_json::from_str(&text).context(format!(
1154 "Failed to parse Bedrock response: {}",
1155 &text[..text.len().min(300)]
1156 ))?;
1157
1158 tracing::debug!(
1159 stop_reason = ?response.stop_reason,
1160 "Received Bedrock response"
1161 );
1162
1163 let mut content = Vec::new();
1164 let mut has_tool_calls = false;
1165
1166 for part in &response.output.message.content {
1167 match part {
1168 ConverseContent::ReasoningContent { reasoning_content } => {
1169 if !reasoning_content.reasoning_text.text.is_empty() {
1170 content.push(ContentPart::Thinking {
1171 text: reasoning_content.reasoning_text.text.clone(),
1172 });
1173 }
1174 }
1175 ConverseContent::Text { text } => {
1176 if !text.is_empty() {
1177 content.push(ContentPart::Text { text: text.clone() });
1178 }
1179 }
1180 ConverseContent::ToolUse { tool_use } => {
1181 has_tool_calls = true;
1182 content.push(ContentPart::ToolCall {
1183 id: tool_use.tool_use_id.clone(),
1184 name: tool_use.name.clone(),
1185 arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1186 });
1187 }
1188 }
1189 }
1190
1191 let finish_reason = if has_tool_calls {
1192 FinishReason::ToolCalls
1193 } else {
1194 match response.stop_reason.as_deref() {
1195 Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
1196 Some("max_tokens") => FinishReason::Length,
1197 Some("tool_use") => FinishReason::ToolCalls,
1198 Some("content_filtered") => FinishReason::ContentFilter,
1199 _ => FinishReason::Stop,
1200 }
1201 };
1202
1203 let usage = response.usage.as_ref();
1204
1205 Ok(CompletionResponse {
1206 message: Message {
1207 role: Role::Assistant,
1208 content,
1209 },
1210 usage: Usage {
1211 prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
1212 completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
1213 total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
1214 cache_read_tokens: None,
1215 cache_write_tokens: None,
1216 },
1217 finish_reason,
1218 })
1219 }
1220
1221 async fn complete_stream(
1222 &self,
1223 request: CompletionRequest,
1224 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
1225 let response = self.complete(request).await?;
1227 let text = response
1228 .message
1229 .content
1230 .iter()
1231 .filter_map(|p| match p {
1232 ContentPart::Text { text } => Some(text.clone()),
1233 _ => None,
1234 })
1235 .collect::<Vec<_>>()
1236 .join("");
1237
1238 Ok(Box::pin(futures::stream::once(async move {
1239 StreamChunk::Text(text)
1240 })))
1241 }
1242}