codetether_agent/provider/bedrock.rs

//! Amazon Bedrock provider implementation using the Converse API
//!
//! Supports Bedrock foundation models via API-key (bearer token) authentication,
//! using the native Bedrock Converse API request/response format.
//! Available models are discovered dynamically via the Bedrock
//! ListFoundationModels and ListInferenceProfiles APIs.
//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;

const DEFAULT_REGION: &str = "us-east-1";

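/// Provider for the Amazon Bedrock Converse API.
///
/// A minimal construction sketch; the environment variable name here is
/// illustrative rather than something this module defines:
///
/// ```ignore
/// let key = std::env::var("AWS_BEARER_TOKEN_BEDROCK")?;
/// // Default region (us-east-1):
/// let provider = BedrockProvider::new(key.clone())?;
/// // Or with an explicit region:
/// let provider = BedrockProvider::with_region(key, "us-west-2".to_string())?;
/// ```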
pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

    /// Management API URL (for listing models, not inference)
    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

    /// Resolve a short model alias to the full Bedrock model ID.
    /// Allows users to specify e.g. "claude-sonnet-4" instead of
    /// "us.anthropic.claude-sonnet-4-20250514-v1:0".
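    ///
    /// A small sketch of the behavior (mappings taken from the table below;
    /// unknown IDs pass through unchanged):
    ///
    /// ```ignore
    /// assert_eq!(
    ///     BedrockProvider::resolve_model_id("claude-sonnet-4"),
    ///     "us.anthropic.claude-sonnet-4-20250514-v1:0"
    /// );
    /// assert_eq!(
    ///     BedrockProvider::resolve_model_id("us.deepseek.r1-v1:0"),
    ///     "us.deepseek.r1-v1:0"
    /// );
    /// ```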
    fn resolve_model_id(model: &str) -> &str {
        match model {
            // --- Anthropic Claude (verified via AWS CLI) ---
            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
            "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
            "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
            }
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-haiku-4.5" | "claude-4.5-haiku" => {
                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
            }
            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            }
            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
            }
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
            }
            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",

            // --- Amazon Nova ---
            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "us.amazon.nova-premier-v1:0",

            // --- Meta Llama ---
            "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3-70b" | "llama3-70b" => "us.meta.llama3-70b-instruct-v1:0",
            "llama-3-8b" | "llama3-8b" => "us.meta.llama3-8b-instruct-v1:0",

            // --- Mistral ---
            "mistral-large-3" | "mistral-large" => "us.mistral.mistral-large-3-675b-instruct",
            "mistral-large-2402" => "us.mistral.mistral-large-2402-v1:0",
            "mistral-small" => "us.mistral.mistral-small-2402-v1:0",
            "mixtral-8x7b" => "us.mistral.mixtral-8x7b-instruct-v0:1",
            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
            "magistral-small" => "us.mistral.magistral-small-2509",

            // --- DeepSeek ---
            "deepseek-r1" => "us.deepseek.r1-v1:0",
            "deepseek-v3" | "deepseek-v3.2" => "us.deepseek.v3.2",

            // --- Cohere ---
            "command-r" => "us.cohere.command-r-v1:0",
            "command-r-plus" => "us.cohere.command-r-plus-v1:0",

            // --- Qwen ---
            "qwen3-32b" => "us.qwen.qwen3-32b-v1:0",
            "qwen3-coder" | "qwen3-coder-next" => "us.qwen.qwen3-coder-next",
            "qwen3-coder-30b" => "us.qwen.qwen3-coder-30b-a3b-v1:0",

            // --- Google Gemma ---
            "gemma-3-27b" => "us.google.gemma-3-27b-it",
            "gemma-3-12b" => "us.google.gemma-3-12b-it",
            "gemma-3-4b" => "us.google.gemma-3-4b-it",

            // --- Moonshot / Kimi ---
            "kimi-k2" | "kimi-k2-thinking" => "us.moonshot.kimi-k2-thinking",
            "kimi-k2.5" => "us.moonshotai.kimi-k2.5",

            // --- AI21 Jamba ---
            "jamba-1.5-large" => "us.ai21.jamba-1-5-large-v1:0",
            "jamba-1.5-mini" => "us.ai21.jamba-1-5-mini-v1:0",

            // --- MiniMax ---
            "minimax-m2" => "us.minimax.minimax-m2",
            "minimax-m2.1" => "us.minimax.minimax-m2.1",

            // --- NVIDIA ---
            "nemotron-nano-30b" => "us.nvidia.nemotron-nano-3-30b",
            "nemotron-nano-12b" => "us.nvidia.nemotron-nano-12b-v2",
            "nemotron-nano-9b" => "us.nvidia.nemotron-nano-9b-v2",

            // --- Z.AI / GLM ---
            "glm-4.7" => "us.zai.glm-4.7",
            "glm-4.7-flash" => "us.zai.glm-4.7-flash",

            // Pass through full model IDs unchanged
            other => other,
        }
    }

    /// Dynamically discover available models from the Bedrock API.
    /// Merges foundation models with cross-region inference profiles.
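    ///
    /// An abridged sketch of the ListFoundationModels response shape this
    /// function reads (field subset only; values are illustrative, see the
    /// AWS API reference for the full schema):
    ///
    /// ```json
    /// {
    ///   "modelSummaries": [{
    ///     "modelId": "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ///     "modelName": "Claude 3.5 Sonnet",
    ///     "providerName": "Anthropic",
    ///     "inputModalities": ["TEXT", "IMAGE"],
    ///     "outputModalities": ["TEXT"],
    ///     "inferenceTypesSupported": ["ON_DEMAND"],
    ///     "responseStreamingSupported": true
    ///   }]
    /// }
    /// ```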
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

        // 1) Fetch foundation models
        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self
            .client
            .get(&fm_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            // Only include TEXT output models with ON_DEMAND or INFERENCE_PROFILE inference
                            if !output_modalities.contains(&"TEXT")
                                || (!inference_types.contains(&"ON_DEMAND")
                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
                            {
                                continue;
                            }

                            // Skip non-chat models
                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            // Non-Amazon models need the "us." cross-region prefix
                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else {
                                format!("us.{}", model_id)
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        // 2) Fetch cross-region inference profiles (adds models like Claude Sonnet 4,
        //    Llama 3.1/3.2/3.3/4, DeepSeek R1, etc. that aren't in foundation models)
        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self
            .client
            .get(&ip_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            // Only US cross-region profiles
                            if !pid.starts_with("us.") {
                                continue;
                            }

                            // Skip already-discovered models
                            if models.contains_key(pid) {
                                continue;
                            }

                            // Skip non-text models
                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            // Guess vision from known model families
                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4")
                                || pid.contains("claude-haiku-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            // Extract provider hint from model ID
                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

    /// Estimate context window size based on model family
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

    /// Estimate max output tokens based on model family
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-opus-4-6") {
            32_000
        } else if id.contains("claude-opus-4-5") {
            32_000
        } else if id.contains("claude-opus-4-1") {
            32_000
        } else if id.contains("claude-sonnet-4-5")
            || id.contains("claude-sonnet-4")
            || id.contains("claude-3-7")
        {
            64_000
        } else if id.contains("claude-haiku-4-5") {
            16_384
        } else if id.contains("claude-opus-4") {
            32_000
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

    /// Convert our generic messages to Bedrock Converse API format.
    ///
    /// Bedrock Converse uses:
    /// - system prompt as a top-level "system" array
    /// - messages with "role" and "content" array
    /// - tool_use blocks in assistant content
    /// - toolResult blocks in user content
    ///
    /// IMPORTANT: Bedrock requires strict role alternation (user/assistant).
    /// Consecutive Role::Tool messages must be merged into a single "user"
    /// message so all toolResult blocks for a given assistant turn appear
    /// together. Consecutive same-role messages are also merged to prevent
    /// validation errors.
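    ///
    /// Sketch of the merge, assuming one assistant turn with two toolUse
    /// blocks answered by two consecutive Role::Tool messages (fields elided):
    ///
    /// ```json
    /// [
    ///   {"role": "assistant", "content": [{"toolUse": {"toolUseId": "a", ...}},
    ///                                     {"toolUse": {"toolUseId": "b", ...}}]},
    ///   {"role": "user", "content": [{"toolResult": {"toolUseId": "a", ...}},
    ///                                {"toolResult": {"toolUseId": "b", ...}}]}
    /// ]
    /// ```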
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::Text { text } = part {
                            if !text.is_empty() {
                                content_parts.push(json!({"text": text}));
                            }
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into previous user message if the last message is also "user"
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) =
                                    last.get_mut("content").and_then(|c| c.as_array_mut())
                                {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": ""}));
                    }
                    // Merge into previous assistant message if consecutive
                    if let Some(last) = api_messages.last_mut() {
                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
                            if let Some(arr) =
                                last.get_mut("content").and_then(|c| c.as_array_mut())
                            {
                                arr.extend(content_parts);
                                continue;
                            }
                        }
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Tool results must be in a "user" message with toolResult blocks.
                    // Merge into the previous user message if one exists (handles
                    // consecutive Tool messages being collapsed into one user turn).
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into previous user message (from earlier Tool messages)
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) =
                                    last.get_mut("content").and_then(|c| c.as_array_mut())
                                {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

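    /// Convert generic tool definitions to the Converse toolConfig format.
    /// One emitted entry looks like the following sketch (the tool name and
    /// schema shown are a hypothetical illustration, not from this crate):
    ///
    /// ```json
    /// {
    ///   "toolSpec": {
    ///     "name": "read_file",
    ///     "description": "Read a file from disk",
    ///     "inputSchema": { "json": { "type": "object", "properties": {} } }
    ///   }
    /// }
    /// ```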
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

// --- Bedrock Converse API response types ---

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        // inferenceConfig
        let mut inference_config = json!({});
        if let Some(max_tokens) = request.max_tokens {
            inference_config["maxTokens"] = json!(max_tokens);
        } else {
            inference_config["maxTokens"] = json!(8192);
        }
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }

        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
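        // With the default region and a resolved Claude model ID this yields,
        // for example:
        // https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-20250514-v1%3A0/converse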
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // Truncate on char boundaries; slicing bytes at a fixed index can
            // panic mid-UTF-8 sequence.
            anyhow::bail!(
                "Bedrock API error: {} {}",
                status,
                text.chars().take(500).collect::<String>()
            );
        }

        let response: ConverseResponse = serde_json::from_str(&text).with_context(|| {
            format!(
                "Failed to parse Bedrock response: {}",
                text.chars().take(300).collect::<String>()
            )
        })?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Fall back to non-streaming for now. Only text parts are forwarded
        // as a single chunk; tool calls and usage are not surfaced here.
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}
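
// A small test sketch for the pure helpers above; expected values mirror the
// alias table and the merge rules documented on convert_messages. Field types
// for Message/ContentPart are assumed from their use elsewhere in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolves_aliases_and_passes_through_full_ids() {
        assert_eq!(
            BedrockProvider::resolve_model_id("claude-sonnet-4"),
            "us.anthropic.claude-sonnet-4-20250514-v1:0"
        );
        assert_eq!(
            BedrockProvider::resolve_model_id("nova-pro"),
            "amazon.nova-pro-v1:0"
        );
        // Unknown strings pass through unchanged.
        assert_eq!(
            BedrockProvider::resolve_model_id("us.deepseek.r1-v1:0"),
            "us.deepseek.r1-v1:0"
        );
    }

    #[test]
    fn merges_consecutive_tool_results_into_one_user_message() {
        let tool_msg = |id: &str| Message {
            role: Role::Tool,
            content: vec![ContentPart::ToolResult {
                tool_call_id: id.to_string(),
                content: "ok".to_string(),
            }],
        };
        let (system, api) = BedrockProvider::convert_messages(&[tool_msg("t1"), tool_msg("t2")]);
        assert!(system.is_empty());
        // Both toolResult blocks collapse into a single "user" message.
        assert_eq!(api.len(), 1);
        assert_eq!(api[0]["role"], "user");
        assert_eq!(api[0]["content"].as_array().unwrap().len(), 2);
    }
}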