codetether_agent/provider/bedrock.rs

//! Amazon Bedrock provider implementation using the Converse API
//!
//! Supports all Bedrock foundation models via API key bearer-token auth.
//! Uses the native Bedrock Converse API format.
//! Dynamically discovers available models via the Bedrock ListFoundationModels
//! and ListInferenceProfiles APIs.
//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
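//!
//! Usage sketch (illustrative only; the env var name is an assumption, not
//! something this module reads itself):
//!
//! ```ignore
//! let api_key = std::env::var("AWS_BEARER_TOKEN_BEDROCK")?;
//! let provider = BedrockProvider::with_region(api_key, "us-east-1".to_string())?;
//! let models = provider.list_models().await?;
//! ```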

use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;

const DEFAULT_REGION: &str = "us-east-1";

pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

    /// Management API URL (for listing models, not inference)
    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

    /// Resolve a short model alias to the full Bedrock model ID.
    /// Allows users to specify e.g. "claude-sonnet-4" instead of
    /// "us.anthropic.claude-sonnet-4-20250514-v1:0".
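    ///
    /// Behavior sketch (aliases taken from the map below; unknown strings
    /// pass through unchanged):
    ///
    /// ```ignore
    /// assert_eq!(
    ///     BedrockProvider::resolve_model_id("claude-sonnet-4"),
    ///     "us.anthropic.claude-sonnet-4-20250514-v1:0"
    /// );
    /// assert_eq!(BedrockProvider::resolve_model_id("custom.model-v1:0"), "custom.model-v1:0");
    /// ```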
    fn resolve_model_id(model: &str) -> &str {
        match model {
            // --- Anthropic Claude (verified via AWS CLI) ---
            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
            "claude-opus-4.5" | "claude-4.5-opus" => {
                "us.anthropic.claude-opus-4-5-20251101-v1:0"
            }
            "claude-opus-4.1" | "claude-4.1-opus" => {
                "us.anthropic.claude-opus-4-1-20250805-v1:0"
            }
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
            }
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-haiku-4.5" | "claude-4.5-haiku" => {
                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
            }
            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            }
            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
            }
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
            }
            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",

            // --- Amazon Nova ---
            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "us.amazon.nova-premier-v1:0",

            // --- Meta Llama ---
            "llama-4-maverick" | "llama4-maverick" => {
                "us.meta.llama4-maverick-17b-instruct-v1:0"
            }
            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3-70b" | "llama3-70b" => "us.meta.llama3-70b-instruct-v1:0",
            "llama-3-8b" | "llama3-8b" => "us.meta.llama3-8b-instruct-v1:0",

            // --- Mistral ---
            "mistral-large-3" | "mistral-large" => "us.mistral.mistral-large-3-675b-instruct",
            "mistral-large-2402" => "us.mistral.mistral-large-2402-v1:0",
            "mistral-small" => "us.mistral.mistral-small-2402-v1:0",
            "mixtral-8x7b" => "us.mistral.mixtral-8x7b-instruct-v0:1",
            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
            "magistral-small" => "us.mistral.magistral-small-2509",

            // --- DeepSeek ---
            "deepseek-r1" => "us.deepseek.r1-v1:0",
            "deepseek-v3" | "deepseek-v3.2" => "us.deepseek.v3.2",

            // --- Cohere ---
            "command-r" => "us.cohere.command-r-v1:0",
            "command-r-plus" => "us.cohere.command-r-plus-v1:0",

            // --- Qwen ---
            "qwen3-32b" => "us.qwen.qwen3-32b-v1:0",
            "qwen3-coder" | "qwen3-coder-next" => "us.qwen.qwen3-coder-next",
            "qwen3-coder-30b" => "us.qwen.qwen3-coder-30b-a3b-v1:0",

            // --- Google Gemma ---
            "gemma-3-27b" => "us.google.gemma-3-27b-it",
            "gemma-3-12b" => "us.google.gemma-3-12b-it",
            "gemma-3-4b" => "us.google.gemma-3-4b-it",

            // --- Moonshot / Kimi ---
            "kimi-k2" | "kimi-k2-thinking" => "us.moonshot.kimi-k2-thinking",
            "kimi-k2.5" => "us.moonshotai.kimi-k2.5",

            // --- AI21 Jamba ---
            "jamba-1.5-large" => "us.ai21.jamba-1-5-large-v1:0",
            "jamba-1.5-mini" => "us.ai21.jamba-1-5-mini-v1:0",

            // --- MiniMax ---
            "minimax-m2" => "us.minimax.minimax-m2",
            "minimax-m2.1" => "us.minimax.minimax-m2.1",

            // --- NVIDIA ---
            "nemotron-nano-30b" => "us.nvidia.nemotron-nano-3-30b",
            "nemotron-nano-12b" => "us.nvidia.nemotron-nano-12b-v2",
            "nemotron-nano-9b" => "us.nvidia.nemotron-nano-9b-v2",

            // --- Z.AI / GLM ---
            "glm-4.7" => "us.zai.glm-4.7",
            "glm-4.7-flash" => "us.zai.glm-4.7-flash",

            // Pass through full model IDs unchanged
            other => other,
        }
    }

    /// Dynamically discover available models from the Bedrock API.
    /// Merges foundation models with cross-region inference profiles.
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

        // 1) Fetch foundation models
        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self
            .client
            .get(&fm_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            // Only include TEXT output models with ON_DEMAND or INFERENCE_PROFILE inference
                            if !output_modalities.contains(&"TEXT")
                                || (!inference_types.contains(&"ON_DEMAND")
                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
                            {
                                continue;
                            }

                            // Skip non-chat models
                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            // Non-Amazon models need the "us." cross-region prefix
                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else {
                                format!("us.{}", model_id)
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        // 2) Fetch cross-region inference profiles (adds models like Claude Sonnet 4,
        //    Llama 3.1/3.2/3.3/4, DeepSeek R1, etc. that aren't in foundation models)
        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self
            .client
            .get(&ip_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            // Only US cross-region profiles
                            if !pid.starts_with("us.") {
                                continue;
                            }

                            // Skip already-discovered models
                            if models.contains_key(pid) {
                                continue;
                            }

                            // Skip non-text models
                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            // Guess vision support from known model families
                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4")
                                || pid.contains("claude-haiku-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            // Extract provider hint from model ID
                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

    /// Estimate context window size based on model family
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

    /// Estimate max output tokens based on model family
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-opus-4") {
            // Covers Opus 4, 4.1, 4.5, and 4.6, which all share a 32K cap here.
            32_000
        } else if id.contains("claude-sonnet-4") || id.contains("claude-3-7") {
            64_000
        } else if id.contains("claude-haiku-4-5") {
            16_384
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

    /// Convert our generic messages to Bedrock Converse API format.
    ///
    /// Bedrock Converse uses:
    /// - system prompt as a top-level "system" array
    /// - messages with "role" and "content" array
    /// - tool_use blocks in assistant content
    /// - toolResult blocks in user content
    ///
    /// IMPORTANT: Bedrock requires strict role alternation (user/assistant).
    /// Consecutive Role::Tool messages must be merged into a single "user"
    /// message so all toolResult blocks for a given assistant turn appear
    /// together. Consecutive same-role messages are also merged to prevent
    /// validation errors.
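    ///
    /// Merge-behavior sketch (`tool_msg` is a hypothetical helper building a
    /// Role::Tool message; see the tests at the bottom of this file):
    ///
    /// ```ignore
    /// let (_system, api) = BedrockProvider::convert_messages(&[
    ///     tool_msg("id-1"),
    ///     tool_msg("id-2"),
    /// ]);
    /// assert_eq!(api.len(), 1); // one "user" turn with two toolResult blocks
    /// ```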
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
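                            // Other content kinds (e.g. images) are not mapped
                            // to Bedrock content blocks yet and are dropped.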
                            _ => {}
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into previous user message if the last message is also "user"
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) = last.get_mut("content").and_then(|c| c.as_array_mut()) {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": ""}));
                    }
                    // Merge into previous assistant message if consecutive
                    if let Some(last) = api_messages.last_mut() {
                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
                            if let Some(arr) = last.get_mut("content").and_then(|c| c.as_array_mut()) {
                                arr.extend(content_parts);
                                continue;
                            }
                        }
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Tool results must be in a "user" message with toolResult blocks.
                    // Merge into the previous user message if one exists (handles
                    // consecutive Tool messages being collapsed into one user turn).
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into previous user message (from earlier Tool messages)
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) = last.get_mut("content").and_then(|c| c.as_array_mut()) {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

// Bedrock Converse API response types

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        // inferenceConfig: always set maxTokens, defaulting to 8192 when the
        // caller does not specify a cap.
        let mut inference_config = json!({
            "maxTokens": request.max_tokens.unwrap_or(8192),
        });
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }

        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // Truncate on a char boundary: byte-slicing at a fixed index can
            // panic if the error body contains multi-byte UTF-8.
            let snippet: String = text.chars().take(500).collect();
            anyhow::bail!("Bedrock API error: {} {}", status, snippet);
        }

        let response: ConverseResponse = serde_json::from_str(&text).with_context(|| {
            format!(
                "Failed to parse Bedrock response: {}",
                text.chars().take(300).collect::<String>()
            )
        })?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Fall back to non-streaming for now.
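        // NOTE: this fallback emits the final text in a single chunk; tool-call
        // parts from the completion are not surfaced through the stream.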
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}
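
// ---------------------------------------------------------------------------
// Test sketch (added for illustration; not part of the original file). It only
// exercises behavior defined above: alias resolution in `resolve_model_id` and
// the role-merging rule in `convert_messages`. The `String` field types are
// inferred from how `Message` and `ContentPart` are used earlier in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_model_id_maps_aliases_and_passes_ids_through() {
        assert_eq!(
            BedrockProvider::resolve_model_id("claude-sonnet-4"),
            "us.anthropic.claude-sonnet-4-20250514-v1:0"
        );
        // Full model IDs are returned unchanged.
        assert_eq!(
            BedrockProvider::resolve_model_id("amazon.nova-pro-v1:0"),
            "amazon.nova-pro-v1:0"
        );
    }

    #[test]
    fn consecutive_tool_messages_merge_into_one_user_turn() {
        let tool_msg = |id: &str| Message {
            role: Role::Tool,
            content: vec![ContentPart::ToolResult {
                tool_call_id: id.to_string(),
                content: "ok".to_string(),
            }],
        };

        let (system, api) =
            BedrockProvider::convert_messages(&[tool_msg("id-1"), tool_msg("id-2")]);

        assert!(system.is_empty());
        // Both toolResult blocks must land in a single "user" message so
        // Bedrock's strict user/assistant alternation stays valid.
        assert_eq!(api.len(), 1);
        assert_eq!(api[0]["role"], "user");
        assert_eq!(api[0]["content"].as_array().unwrap().len(), 2);
    }
}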