1use super::{
13 CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
14 Role, StreamChunk, ToolDefinition, Usage,
15};
16use anyhow::{Context, Result};
17use async_trait::async_trait;
18use hmac::{Hmac, Mac};
19use reqwest::Client;
20use serde::Deserialize;
21use serde_json::{Value, json};
22use sha2::{Digest, Sha256};
23use std::collections::HashMap;
24
/// Region used by [`BedrockProvider::new`] when the caller does not choose one.
pub const DEFAULT_REGION: &str = "us-east-1";

/// Static AWS credentials used to SigV4-sign Bedrock requests.
///
/// `Debug` is implemented by hand so that the secret access key and the
/// session token are never written to logs; only the access key ID is shown.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key ID; included in the SigV4 `Credential=` scope.
    pub access_key_id: String,
    /// Secret access key used to derive the SigV4 signing key.
    pub secret_access_key: String,
    /// Session token for temporary credentials, sent as `x-amz-security-token`.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
34
35impl AwsCredentials {
36 pub fn from_environment() -> Option<Self> {
39 if let (Ok(key_id), Ok(secret)) = (
41 std::env::var("AWS_ACCESS_KEY_ID"),
42 std::env::var("AWS_SECRET_ACCESS_KEY"),
43 ) {
44 if !key_id.is_empty() && !secret.is_empty() {
45 return Some(Self {
46 access_key_id: key_id,
47 secret_access_key: secret,
48 session_token: std::env::var("AWS_SESSION_TOKEN")
49 .ok()
50 .filter(|s| !s.is_empty()),
51 });
52 }
53 }
54
55 let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
57 Self::from_credentials_file(&profile)
58 }
59
60 fn from_credentials_file(profile: &str) -> Option<Self> {
62 let home = std::env::var("HOME")
63 .or_else(|_| std::env::var("USERPROFILE"))
64 .ok()?;
65 let path = std::path::Path::new(&home).join(".aws").join("credentials");
66 let content = std::fs::read_to_string(&path).ok()?;
67
68 let section_header = format!("[{}]", profile);
69 let mut in_section = false;
70 let mut key_id = None;
71 let mut secret = None;
72 let mut token = None;
73
74 for line in content.lines() {
75 let trimmed = line.trim();
76 if trimmed.starts_with('[') {
77 in_section = trimmed == section_header;
78 continue;
79 }
80 if !in_section {
81 continue;
82 }
83 if let Some((k, v)) = trimmed.split_once('=') {
84 let k = k.trim();
85 let v = v.trim();
86 match k {
87 "aws_access_key_id" => key_id = Some(v.to_string()),
88 "aws_secret_access_key" => secret = Some(v.to_string()),
89 "aws_session_token" => token = Some(v.to_string()),
90 _ => {}
91 }
92 }
93 }
94
95 Some(Self {
96 access_key_id: key_id?,
97 secret_access_key: secret?,
98 session_token: token,
99 })
100 }
101
102 pub fn detect_region() -> Option<String> {
105 if let Ok(r) = std::env::var("AWS_REGION") {
106 if !r.is_empty() {
107 return Some(r);
108 }
109 }
110 if let Ok(r) = std::env::var("AWS_DEFAULT_REGION") {
111 if !r.is_empty() {
112 return Some(r);
113 }
114 }
115 let profile = std::env::var("AWS_PROFILE").unwrap_or_else(|_| "default".to_string());
117 let home = std::env::var("HOME")
118 .or_else(|_| std::env::var("USERPROFILE"))
119 .ok()?;
120 let path = std::path::Path::new(&home).join(".aws").join("config");
121 let content = std::fs::read_to_string(&path).ok()?;
122
123 let section_header = if profile == "default" {
125 "[default]".to_string()
126 } else {
127 format!("[profile {}]", profile)
128 };
129 let mut in_section = false;
130 for line in content.lines() {
131 let trimmed = line.trim();
132 if trimmed.starts_with('[') {
133 in_section = trimmed == section_header;
134 continue;
135 }
136 if !in_section {
137 continue;
138 }
139 if let Some((k, v)) = trimmed.split_once('=') {
140 if k.trim() == "region" {
141 let v = v.trim();
142 if !v.is_empty() {
143 return Some(v.to_string());
144 }
145 }
146 }
147 }
148 None
149 }
150}
151
152#[derive(Debug, Clone)]
154pub enum BedrockAuth {
155 SigV4(AwsCredentials),
157 BearerToken(String),
159}
160
161pub struct BedrockProvider {
162 client: Client,
163 auth: BedrockAuth,
164 region: String,
165}
166
167impl std::fmt::Debug for BedrockProvider {
168 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
169 f.debug_struct("BedrockProvider")
170 .field(
171 "auth",
172 &match &self.auth {
173 BedrockAuth::SigV4(_) => "SigV4",
174 BedrockAuth::BearerToken(_) => "BearerToken",
175 },
176 )
177 .field("region", &self.region)
178 .finish()
179 }
180}
181
182impl BedrockProvider {
183 pub fn new(api_key: String) -> Result<Self> {
185 Self::with_region(api_key, DEFAULT_REGION.to_string())
186 }
187
188 pub fn with_region(api_key: String, region: String) -> Result<Self> {
190 tracing::debug!(
191 provider = "bedrock",
192 region = %region,
193 auth = "bearer_token",
194 "Creating Bedrock provider"
195 );
196 Ok(Self {
197 client: Client::new(),
198 auth: BedrockAuth::BearerToken(api_key),
199 region,
200 })
201 }
202
203 pub fn with_credentials(credentials: AwsCredentials, region: String) -> Result<Self> {
205 tracing::debug!(
206 provider = "bedrock",
207 region = %region,
208 auth = "sigv4",
209 "Creating Bedrock provider with AWS credentials"
210 );
211 Ok(Self {
212 client: Client::new(),
213 auth: BedrockAuth::SigV4(credentials),
214 region,
215 })
216 }
217
218 pub async fn send_converse_request(&self, url: &str, body: &[u8]) -> Result<reqwest::Response> {
221 self.send_request("POST", url, Some(body), "bedrock-runtime")
222 .await
223 }
224
225 fn validate_auth(&self) -> Result<()> {
226 match &self.auth {
227 BedrockAuth::BearerToken(key) => {
228 if key.is_empty() {
229 anyhow::bail!("Bedrock API key is empty");
230 }
231 }
232 BedrockAuth::SigV4(creds) => {
233 if creds.access_key_id.is_empty() || creds.secret_access_key.is_empty() {
234 anyhow::bail!("AWS credentials are incomplete");
235 }
236 }
237 }
238 Ok(())
239 }
240
241 fn base_url(&self) -> String {
242 format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
243 }
244
245 fn management_url(&self) -> String {
247 format!("https://bedrock.{}.amazonaws.com", self.region)
248 }
249
250 fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
253 let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
254 mac.update(data);
255 mac.finalize().into_bytes().to_vec()
256 }
257
258 fn sha256_hex(data: &[u8]) -> String {
259 let mut hasher = Sha256::new();
260 hasher.update(data);
261 hex::encode(hasher.finalize())
262 }
263
264 async fn send_signed_request(
266 &self,
267 method: &str,
268 url: &str,
269 body: &[u8],
270 service: &str,
271 ) -> Result<reqwest::Response> {
272 let creds = match &self.auth {
273 BedrockAuth::SigV4(c) => c,
274 BedrockAuth::BearerToken(_) => {
275 anyhow::bail!("send_signed_request called with bearer token auth");
276 }
277 };
278
279 let now = chrono::Utc::now();
280 let datestamp = now.format("%Y%m%d").to_string();
281 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
282
283 let host_start = url.find("://").map(|i| i + 3).unwrap_or(0);
285 let after_host = url[host_start..]
286 .find('/')
287 .map(|i| host_start + i)
288 .unwrap_or(url.len());
289 let host = url[host_start..after_host].to_string();
290 let path_and_query = &url[after_host..];
291 let (canonical_uri, canonical_querystring) = match path_and_query.split_once('?') {
292 Some((p, q)) => (p.to_string(), q.to_string()),
293 None => (path_and_query.to_string(), String::new()),
294 };
295
296 let payload_hash = Self::sha256_hex(body);
297
298 let mut headers_map: Vec<(&str, String)> = vec![
300 ("content-type", "application/json".to_string()),
301 ("host", host.clone()),
302 ("x-amz-date", amz_date.clone()),
303 ];
304 if let Some(token) = &creds.session_token {
305 headers_map.push(("x-amz-security-token", token.clone()));
306 }
307 headers_map.sort_by_key(|(k, _)| *k);
308
309 let canonical_headers: String = headers_map
310 .iter()
311 .map(|(k, v)| format!("{}:{}", k, v))
312 .collect::<Vec<_>>()
313 .join("\n")
314 + "\n";
315
316 let signed_headers: String = headers_map
317 .iter()
318 .map(|(k, _)| *k)
319 .collect::<Vec<_>>()
320 .join(";");
321
322 let canonical_request = format!(
323 "{}\n{}\n{}\n{}\n{}\n{}",
324 method,
325 canonical_uri,
326 canonical_querystring,
327 canonical_headers,
328 signed_headers,
329 payload_hash
330 );
331
332 let credential_scope = format!("{}/{}/{}/aws4_request", datestamp, self.region, service);
333
334 let string_to_sign = format!(
335 "AWS4-HMAC-SHA256\n{}\n{}\n{}",
336 amz_date,
337 credential_scope,
338 Self::sha256_hex(canonical_request.as_bytes())
339 );
340
341 let k_date = Self::hmac_sha256(
343 format!("AWS4{}", creds.secret_access_key).as_bytes(),
344 datestamp.as_bytes(),
345 );
346 let k_region = Self::hmac_sha256(&k_date, self.region.as_bytes());
347 let k_service = Self::hmac_sha256(&k_region, service.as_bytes());
348 let k_signing = Self::hmac_sha256(&k_service, b"aws4_request");
349
350 let signature = hex::encode(Self::hmac_sha256(&k_signing, string_to_sign.as_bytes()));
351
352 let authorization = format!(
353 "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
354 creds.access_key_id, credential_scope, signed_headers, signature
355 );
356
357 let mut req = self
358 .client
359 .request(method.parse().unwrap_or(reqwest::Method::POST), url)
360 .header("content-type", "application/json")
361 .header("host", &host)
362 .header("x-amz-date", &amz_date)
363 .header("x-amz-content-sha256", &payload_hash)
364 .header("authorization", &authorization);
365
366 if let Some(token) = &creds.session_token {
367 req = req.header("x-amz-security-token", token);
368 }
369
370 if method == "POST" || method == "PUT" {
371 req = req.body(body.to_vec());
372 }
373
374 req.send()
375 .await
376 .context("Failed to send signed request to Bedrock")
377 }
378
379 async fn send_request(
381 &self,
382 method: &str,
383 url: &str,
384 body: Option<&[u8]>,
385 service: &str,
386 ) -> Result<reqwest::Response> {
387 match &self.auth {
388 BedrockAuth::SigV4(_) => {
389 self.send_signed_request(method, url, body.unwrap_or(b""), service)
390 .await
391 }
392 BedrockAuth::BearerToken(token) => {
393 let mut req = self
394 .client
395 .request(method.parse().unwrap_or(reqwest::Method::GET), url)
396 .bearer_auth(token)
397 .header("content-type", "application/json")
398 .header("accept", "application/json");
399
400 if let Some(b) = body {
401 req = req.body(b.to_vec());
402 }
403
404 req.send()
405 .await
406 .context("Failed to send request to Bedrock")
407 }
408 }
409 }
410
411 fn resolve_model_id(model: &str) -> &str {
415 match model {
416 "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
418 "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
419 "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
420 "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
421 "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
422 "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
423 }
424 "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
425 "claude-haiku-4.5" | "claude-4.5-haiku" => {
426 "us.anthropic.claude-haiku-4-5-20251001-v1:0"
427 }
428 "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
429 "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
430 }
431 "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
432 "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
433 }
434 "claude-3.5-haiku" | "claude-haiku-3.5" => {
435 "us.anthropic.claude-3-5-haiku-20241022-v1:0"
436 }
437 "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
438 "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
439 }
440 "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
441 "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
442 "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",
443
444 "nova-pro" => "amazon.nova-pro-v1:0",
446 "nova-lite" => "amazon.nova-lite-v1:0",
447 "nova-micro" => "amazon.nova-micro-v1:0",
448 "nova-premier" => "us.amazon.nova-premier-v1:0",
449
450 "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
452 "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
453 "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
454 "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
455 "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
456 "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
457 "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
458 "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
459 "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
460 "llama-3-70b" | "llama3-70b" => "meta.llama3-70b-instruct-v1:0",
461 "llama-3-8b" | "llama3-8b" => "meta.llama3-8b-instruct-v1:0",
462
463 "mistral-large-3" | "mistral-large" => "mistral.mistral-large-3-675b-instruct",
465 "mistral-large-2402" => "mistral.mistral-large-2402-v1:0",
466 "mistral-small" => "mistral.mistral-small-2402-v1:0",
467 "mixtral-8x7b" => "mistral.mixtral-8x7b-instruct-v0:1",
468 "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
469 "magistral-small" => "mistral.magistral-small-2509",
470
471 "deepseek-r1" => "us.deepseek.r1-v1:0",
473 "deepseek-v3" | "deepseek-v3.2" => "deepseek.v3.2",
474
475 "command-r" => "cohere.command-r-v1:0",
477 "command-r-plus" => "cohere.command-r-plus-v1:0",
478
479 "qwen3-32b" => "qwen.qwen3-32b-v1:0",
481 "qwen3-coder" | "qwen3-coder-next" => "qwen.qwen3-coder-next",
482 "qwen3-coder-30b" => "qwen.qwen3-coder-30b-a3b-v1:0",
483
484 "gemma-3-27b" => "google.gemma-3-27b-it",
486 "gemma-3-12b" => "google.gemma-3-12b-it",
487 "gemma-3-4b" => "google.gemma-3-4b-it",
488
489 "kimi-k2" | "kimi-k2-thinking" => "moonshot.kimi-k2-thinking",
491 "kimi-k2.5" => "moonshotai.kimi-k2.5",
492
493 "jamba-1.5-large" => "ai21.jamba-1-5-large-v1:0",
495 "jamba-1.5-mini" => "ai21.jamba-1-5-mini-v1:0",
496
497 "minimax-m2" => "minimax.minimax-m2",
499 "minimax-m2.1" => "minimax.minimax-m2.1",
500
501 "nemotron-nano-30b" => "nvidia.nemotron-nano-3-30b",
503 "nemotron-nano-12b" => "nvidia.nemotron-nano-12b-v2",
504 "nemotron-nano-9b" => "nvidia.nemotron-nano-9b-v2",
505
506 "glm-5" => "zai.glm-5",
508 "glm-4.7" => "zai.glm-4.7",
509 "glm-4.7-flash" => "zai.glm-4.7-flash",
510
511 other => other,
513 }
514 }
515
516 async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
519 let mut models: HashMap<String, ModelInfo> = HashMap::new();
520
521 let fm_url = format!("{}/foundation-models", self.management_url());
523 let fm_resp = self.send_request("GET", &fm_url, None, "bedrock").await;
524
525 if let Ok(resp) = fm_resp {
526 if resp.status().is_success() {
527 if let Ok(data) = resp.json::<Value>().await {
528 if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
529 for m in summaries {
530 let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
531 let model_name =
532 m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
533 let provider_name =
534 m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");
535
536 let output_modalities: Vec<&str> = m
537 .get("outputModalities")
538 .and_then(|v| v.as_array())
539 .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
540 .unwrap_or_default();
541
542 let input_modalities: Vec<&str> = m
543 .get("inputModalities")
544 .and_then(|v| v.as_array())
545 .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
546 .unwrap_or_default();
547
548 let inference_types: Vec<&str> = m
549 .get("inferenceTypesSupported")
550 .and_then(|v| v.as_array())
551 .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
552 .unwrap_or_default();
553
554 if !output_modalities.contains(&"TEXT")
556 || (!inference_types.contains(&"ON_DEMAND")
557 && !inference_types.contains(&"INFERENCE_PROFILE"))
558 {
559 continue;
560 }
561
562 let name_lower = model_name.to_lowercase();
564 if name_lower.contains("rerank")
565 || name_lower.contains("embed")
566 || name_lower.contains("safeguard")
567 || name_lower.contains("sonic")
568 || name_lower.contains("pegasus")
569 {
570 continue;
571 }
572
573 let streaming = m
574 .get("responseStreamingSupported")
575 .and_then(|v| v.as_bool())
576 .unwrap_or(false);
577 let vision = input_modalities.contains(&"IMAGE");
578
579 let actual_id = if model_id.starts_with("amazon.") {
583 model_id.to_string()
584 } else if inference_types.contains(&"INFERENCE_PROFILE") {
585 format!("us.{}", model_id)
586 } else {
587 model_id.to_string()
588 };
589
590 let display_name = format!("{} (Bedrock)", model_name);
591
592 models.insert(
593 actual_id.clone(),
594 ModelInfo {
595 id: actual_id,
596 name: display_name,
597 provider: "bedrock".to_string(),
598 context_window: Self::estimate_context_window(
599 model_id,
600 provider_name,
601 ),
602 max_output_tokens: Some(Self::estimate_max_output(
603 model_id,
604 provider_name,
605 )),
606 supports_vision: vision,
607 supports_tools: true,
608 supports_streaming: streaming,
609 input_cost_per_million: None,
610 output_cost_per_million: None,
611 },
612 );
613 }
614 }
615 }
616 }
617 }
618
619 let ip_url = format!(
622 "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
623 self.management_url()
624 );
625 let ip_resp = self.send_request("GET", &ip_url, None, "bedrock").await;
626
627 if let Ok(resp) = ip_resp {
628 if resp.status().is_success() {
629 if let Ok(data) = resp.json::<Value>().await {
630 if let Some(profiles) = data
631 .get("inferenceProfileSummaries")
632 .and_then(|v| v.as_array())
633 {
634 for p in profiles {
635 let pid = p
636 .get("inferenceProfileId")
637 .and_then(|v| v.as_str())
638 .unwrap_or("");
639 let pname = p
640 .get("inferenceProfileName")
641 .and_then(|v| v.as_str())
642 .unwrap_or("");
643
644 if !pid.starts_with("us.") {
646 continue;
647 }
648
649 if models.contains_key(pid) {
651 continue;
652 }
653
654 let name_lower = pname.to_lowercase();
656 if name_lower.contains("image")
657 || name_lower.contains("stable ")
658 || name_lower.contains("upscale")
659 || name_lower.contains("embed")
660 || name_lower.contains("marengo")
661 || name_lower.contains("outpaint")
662 || name_lower.contains("inpaint")
663 || name_lower.contains("erase")
664 || name_lower.contains("recolor")
665 || name_lower.contains("replace")
666 || name_lower.contains("style ")
667 || name_lower.contains("background")
668 || name_lower.contains("sketch")
669 || name_lower.contains("control")
670 || name_lower.contains("transfer")
671 || name_lower.contains("sonic")
672 || name_lower.contains("pegasus")
673 || name_lower.contains("rerank")
674 {
675 continue;
676 }
677
678 let vision = pid.contains("llama3-2-11b")
680 || pid.contains("llama3-2-90b")
681 || pid.contains("pixtral")
682 || pid.contains("claude-3")
683 || pid.contains("claude-sonnet-4")
684 || pid.contains("claude-opus-4")
685 || pid.contains("claude-haiku-4");
686
687 let display_name = pname.replace("US ", "");
688 let display_name = format!("{} (Bedrock)", display_name.trim());
689
690 let provider_hint = pid
692 .strip_prefix("us.")
693 .unwrap_or(pid)
694 .split('.')
695 .next()
696 .unwrap_or("");
697
698 models.insert(
699 pid.to_string(),
700 ModelInfo {
701 id: pid.to_string(),
702 name: display_name,
703 provider: "bedrock".to_string(),
704 context_window: Self::estimate_context_window(
705 pid,
706 provider_hint,
707 ),
708 max_output_tokens: Some(Self::estimate_max_output(
709 pid,
710 provider_hint,
711 )),
712 supports_vision: vision,
713 supports_tools: true,
714 supports_streaming: true,
715 input_cost_per_million: None,
716 output_cost_per_million: None,
717 },
718 );
719 }
720 }
721 }
722 }
723 }
724
725 let mut result: Vec<ModelInfo> = models.into_values().collect();
726 result.sort_by(|a, b| a.id.cmp(&b.id));
727
728 tracing::info!(
729 provider = "bedrock",
730 model_count = result.len(),
731 "Discovered Bedrock models dynamically"
732 );
733
734 Ok(result)
735 }
736
737 fn estimate_context_window(model_id: &str, provider: &str) -> usize {
739 let id = model_id.to_lowercase();
740 if id.contains("anthropic") || id.contains("claude") {
741 200_000
742 } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
743 {
744 300_000
745 } else if id.contains("nova-micro") || id.contains("nova-2") {
746 128_000
747 } else if id.contains("deepseek") {
748 128_000
749 } else if id.contains("llama4") {
750 256_000
751 } else if id.contains("llama3") {
752 128_000
753 } else if id.contains("mistral-large-3") || id.contains("magistral") {
754 128_000
755 } else if id.contains("mistral") {
756 32_000
757 } else if id.contains("qwen") {
758 128_000
759 } else if id.contains("kimi") {
760 128_000
761 } else if id.contains("jamba") {
762 256_000
763 } else if id.contains("glm") {
764 128_000
765 } else if id.contains("minimax") {
766 128_000
767 } else if id.contains("gemma") {
768 128_000
769 } else if id.contains("cohere") || id.contains("command") {
770 128_000
771 } else if id.contains("nemotron") {
772 128_000
773 } else if provider.to_lowercase().contains("amazon") {
774 128_000
775 } else {
776 32_000
777 }
778 }
779
780 fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
782 let id = model_id.to_lowercase();
783 if id.contains("claude-opus-4-6") {
784 32_000
785 } else if id.contains("claude-opus-4-5") {
786 32_000
787 } else if id.contains("claude-opus-4-1") {
788 32_000
789 } else if id.contains("claude-sonnet-4-5")
790 || id.contains("claude-sonnet-4")
791 || id.contains("claude-3-7")
792 {
793 64_000
794 } else if id.contains("claude-haiku-4-5") {
795 16_384
796 } else if id.contains("claude-opus-4") {
797 32_000
798 } else if id.contains("claude") {
799 8_192
800 } else if id.contains("nova") {
801 5_000
802 } else if id.contains("deepseek") {
803 16_384
804 } else if id.contains("llama4") {
805 16_384
806 } else if id.contains("llama") {
807 4_096
808 } else if id.contains("mistral-large-3") {
809 16_384
810 } else if id.contains("mistral") || id.contains("mixtral") {
811 8_192
812 } else if id.contains("qwen") {
813 8_192
814 } else if id.contains("kimi") {
815 8_192
816 } else if id.contains("jamba") {
817 4_096
818 } else {
819 4_096
820 }
821 }
822
823 fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
837 let mut system_parts: Vec<Value> = Vec::new();
838 let mut api_messages: Vec<Value> = Vec::new();
839
840 for msg in messages {
841 match msg.role {
842 Role::System => {
843 let text: String = msg
844 .content
845 .iter()
846 .filter_map(|p| match p {
847 ContentPart::Text { text } => Some(text.clone()),
848 _ => None,
849 })
850 .collect::<Vec<_>>()
851 .join("\n");
852 system_parts.push(json!({"text": text}));
853 }
854 Role::User => {
855 let mut content_parts: Vec<Value> = Vec::new();
856 for part in &msg.content {
857 match part {
858 ContentPart::Text { text } => {
859 if !text.is_empty() {
860 content_parts.push(json!({"text": text}));
861 }
862 }
863 _ => {}
864 }
865 }
866 if !content_parts.is_empty() {
867 if let Some(last) = api_messages.last_mut() {
869 if last.get("role").and_then(|r| r.as_str()) == Some("user") {
870 if let Some(arr) =
871 last.get_mut("content").and_then(|c| c.as_array_mut())
872 {
873 arr.extend(content_parts);
874 continue;
875 }
876 }
877 }
878 api_messages.push(json!({
879 "role": "user",
880 "content": content_parts
881 }));
882 }
883 }
884 Role::Assistant => {
885 let mut content_parts: Vec<Value> = Vec::new();
886 for part in &msg.content {
887 match part {
888 ContentPart::Text { text } => {
889 if !text.is_empty() {
890 content_parts.push(json!({"text": text}));
891 }
892 }
893 ContentPart::ToolCall {
894 id,
895 name,
896 arguments,
897 } => {
898 let input: Value = serde_json::from_str(arguments)
899 .unwrap_or_else(|_| json!({"raw": arguments}));
900 content_parts.push(json!({
901 "toolUse": {
902 "toolUseId": id,
903 "name": name,
904 "input": input
905 }
906 }));
907 }
908 _ => {}
909 }
910 }
911 if content_parts.is_empty() {
912 content_parts.push(json!({"text": " "}));
913 }
914 if let Some(last) = api_messages.last_mut() {
916 if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
917 if let Some(arr) =
918 last.get_mut("content").and_then(|c| c.as_array_mut())
919 {
920 arr.extend(content_parts);
921 continue;
922 }
923 }
924 }
925 api_messages.push(json!({
926 "role": "assistant",
927 "content": content_parts
928 }));
929 }
930 Role::Tool => {
931 let mut content_parts: Vec<Value> = Vec::new();
935 for part in &msg.content {
936 if let ContentPart::ToolResult {
937 tool_call_id,
938 content,
939 } = part
940 {
941 content_parts.push(json!({
942 "toolResult": {
943 "toolUseId": tool_call_id,
944 "content": [{"text": content}],
945 "status": "success"
946 }
947 }));
948 }
949 }
950 if !content_parts.is_empty() {
951 if let Some(last) = api_messages.last_mut() {
953 if last.get("role").and_then(|r| r.as_str()) == Some("user") {
954 if let Some(arr) =
955 last.get_mut("content").and_then(|c| c.as_array_mut())
956 {
957 arr.extend(content_parts);
958 continue;
959 }
960 }
961 }
962 api_messages.push(json!({
963 "role": "user",
964 "content": content_parts
965 }));
966 }
967 }
968 }
969 }
970
971 (system_parts, api_messages)
972 }
973
974 fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
975 tools
976 .iter()
977 .map(|t| {
978 json!({
979 "toolSpec": {
980 "name": t.name,
981 "description": t.description,
982 "inputSchema": {
983 "json": t.parameters
984 }
985 }
986 })
987 })
988 .collect()
989 }
990}
991
992#[derive(Debug, Deserialize)]
995#[serde(rename_all = "camelCase")]
996struct ConverseResponse {
997 output: ConverseOutput,
998 #[serde(default)]
999 stop_reason: Option<String>,
1000 #[serde(default)]
1001 usage: Option<ConverseUsage>,
1002}
1003
1004#[derive(Debug, Deserialize)]
1005struct ConverseOutput {
1006 message: ConverseMessage,
1007}
1008
1009#[derive(Debug, Deserialize)]
1010struct ConverseMessage {
1011 #[allow(dead_code)]
1012 role: String,
1013 content: Vec<ConverseContent>,
1014}
1015
/// One content block of a Converse response message.
///
/// With `#[serde(untagged)]`, serde tries the variants in declaration order and
/// keeps the first that deserializes, so each variant is identified by the JSON
/// key it requires (`reasoningContent`, `text`, `toolUse`). Keep the order
/// stable. A block matching none of the variants makes the whole response fail
/// to parse.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    /// Model reasoning output (surfaced as thinking content).
    ReasoningContent {
        #[serde(rename = "reasoningContent")]
        reasoning_content: ReasoningContentBlock,
    },
    /// Plain assistant text.
    Text {
        text: String,
    },
    /// A request from the model to invoke a tool.
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}
1031
1032#[derive(Debug, Deserialize)]
1033#[serde(rename_all = "camelCase")]
1034struct ReasoningContentBlock {
1035 reasoning_text: ReasoningText,
1036}
1037
1038#[derive(Debug, Deserialize)]
1039struct ReasoningText {
1040 text: String,
1041}
1042
1043#[derive(Debug, Deserialize)]
1044#[serde(rename_all = "camelCase")]
1045struct ConverseToolUse {
1046 tool_use_id: String,
1047 name: String,
1048 input: Value,
1049}
1050
1051#[derive(Debug, Deserialize)]
1052#[serde(rename_all = "camelCase")]
1053struct ConverseUsage {
1054 #[serde(default)]
1055 input_tokens: usize,
1056 #[serde(default)]
1057 output_tokens: usize,
1058 #[serde(default)]
1059 total_tokens: usize,
1060}
1061
1062#[derive(Debug, Deserialize)]
1063struct BedrockError {
1064 message: String,
1065}
1066
1067#[async_trait]
1068impl Provider for BedrockProvider {
1069 fn name(&self) -> &str {
1070 "bedrock"
1071 }
1072
1073 async fn list_models(&self) -> Result<Vec<ModelInfo>> {
1074 self.validate_auth()?;
1075 self.discover_models().await
1076 }
1077
1078 async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
1079 let model_id = Self::resolve_model_id(&request.model);
1080
1081 tracing::debug!(
1082 provider = "bedrock",
1083 model = %model_id,
1084 original_model = %request.model,
1085 message_count = request.messages.len(),
1086 tool_count = request.tools.len(),
1087 "Starting Bedrock Converse request"
1088 );
1089
1090 self.validate_auth()?;
1091
1092 let (system_parts, messages) = Self::convert_messages(&request.messages);
1093 let tools = Self::convert_tools(&request.tools);
1094
1095 let mut body = json!({
1096 "messages": messages,
1097 });
1098
1099 if !system_parts.is_empty() {
1100 body["system"] = json!(system_parts);
1101 }
1102
1103 let mut inference_config = json!({});
1105 if let Some(max_tokens) = request.max_tokens {
1106 inference_config["maxTokens"] = json!(max_tokens);
1107 } else {
1108 inference_config["maxTokens"] = json!(8192);
1109 }
1110 if let Some(temp) = request.temperature {
1111 inference_config["temperature"] = json!(temp);
1112 }
1113 if let Some(top_p) = request.top_p {
1114 inference_config["topP"] = json!(top_p);
1115 }
1116 body["inferenceConfig"] = inference_config;
1117
1118 if !tools.is_empty() {
1119 body["toolConfig"] = json!({"tools": tools});
1120 }
1121
1122 let encoded_model_id = model_id.replace(':', "%3A");
1124 let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
1125 tracing::debug!("Bedrock request URL: {}", url);
1126
1127 let body_bytes = serde_json::to_vec(&body)?;
1128 let response = self
1129 .send_request("POST", &url, Some(&body_bytes), "bedrock-runtime")
1130 .await?;
1131
1132 let status = response.status();
1133 let text = response
1134 .text()
1135 .await
1136 .context("Failed to read Bedrock response")?;
1137
1138 if !status.is_success() {
1139 if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
1140 anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
1141 }
1142 anyhow::bail!(
1143 "Bedrock API error: {} {}",
1144 status,
1145 &text[..text.len().min(500)]
1146 );
1147 }
1148
1149 let response: ConverseResponse = serde_json::from_str(&text).context(format!(
1150 "Failed to parse Bedrock response: {}",
1151 &text[..text.len().min(300)]
1152 ))?;
1153
1154 tracing::debug!(
1155 stop_reason = ?response.stop_reason,
1156 "Received Bedrock response"
1157 );
1158
1159 let mut content = Vec::new();
1160 let mut has_tool_calls = false;
1161
1162 for part in &response.output.message.content {
1163 match part {
1164 ConverseContent::ReasoningContent { reasoning_content } => {
1165 if !reasoning_content.reasoning_text.text.is_empty() {
1166 content.push(ContentPart::Thinking {
1167 text: reasoning_content.reasoning_text.text.clone(),
1168 });
1169 }
1170 }
1171 ConverseContent::Text { text } => {
1172 if !text.is_empty() {
1173 content.push(ContentPart::Text { text: text.clone() });
1174 }
1175 }
1176 ConverseContent::ToolUse { tool_use } => {
1177 has_tool_calls = true;
1178 content.push(ContentPart::ToolCall {
1179 id: tool_use.tool_use_id.clone(),
1180 name: tool_use.name.clone(),
1181 arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
1182 });
1183 }
1184 }
1185 }
1186
1187 let finish_reason = if has_tool_calls {
1188 FinishReason::ToolCalls
1189 } else {
1190 match response.stop_reason.as_deref() {
1191 Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
1192 Some("max_tokens") => FinishReason::Length,
1193 Some("tool_use") => FinishReason::ToolCalls,
1194 Some("content_filtered") => FinishReason::ContentFilter,
1195 _ => FinishReason::Stop,
1196 }
1197 };
1198
1199 let usage = response.usage.as_ref();
1200
1201 Ok(CompletionResponse {
1202 message: Message {
1203 role: Role::Assistant,
1204 content,
1205 },
1206 usage: Usage {
1207 prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
1208 completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
1209 total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
1210 cache_read_tokens: None,
1211 cache_write_tokens: None,
1212 },
1213 finish_reason,
1214 })
1215 }
1216
1217 async fn complete_stream(
1218 &self,
1219 request: CompletionRequest,
1220 ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
1221 let response = self.complete(request).await?;
1223 let text = response
1224 .message
1225 .content
1226 .iter()
1227 .filter_map(|p| match p {
1228 ContentPart::Text { text } => Some(text.clone()),
1229 _ => None,
1230 })
1231 .collect::<Vec<_>>()
1232 .join("");
1233
1234 Ok(Box::pin(futures::stream::once(async move {
1235 StreamChunk::Text(text)
1236 })))
1237 }
1238}