Skip to main content

codetether_agent/provider/
bedrock.rs

1//! Amazon Bedrock provider implementation using the Converse API
2//!
3//! Supports all Bedrock foundation models via API Key bearer token auth.
4//! Uses the native Bedrock Converse API format.
5//! Reference: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
6
7use super::{
8    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
9    Role, StreamChunk, ToolDefinition, Usage,
10};
11use anyhow::{Context, Result};
12use async_trait::async_trait;
13use reqwest::Client;
14use serde::Deserialize;
15use serde_json::{Value, json};
16
/// AWS region used by [`BedrockProvider::new`] when the caller does not choose one.
const DEFAULT_REGION: &str = "us-east-1";
18
19pub struct BedrockProvider {
20    client: Client,
21    api_key: String,
22    region: String,
23}
24
25impl std::fmt::Debug for BedrockProvider {
26    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27        f.debug_struct("BedrockProvider")
28            .field("api_key", &"<REDACTED>")
29            .field("region", &self.region)
30            .finish()
31    }
32}
33
34impl BedrockProvider {
35    pub fn new(api_key: String) -> Result<Self> {
36        Self::with_region(api_key, DEFAULT_REGION.to_string())
37    }
38
39    pub fn with_region(api_key: String, region: String) -> Result<Self> {
40        tracing::debug!(
41            provider = "bedrock",
42            region = %region,
43            api_key_len = api_key.len(),
44            "Creating Bedrock provider"
45        );
46        Ok(Self {
47            client: Client::new(),
48            api_key,
49            region,
50        })
51    }
52
53    fn validate_api_key(&self) -> Result<()> {
54        if self.api_key.is_empty() {
55            anyhow::bail!("Bedrock API key is empty");
56        }
57        Ok(())
58    }
59
60    fn base_url(&self) -> String {
61        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
62    }
63
/// Translate a short model alias into the full Bedrock model ID.
///
/// Lets callers write e.g. "claude-sonnet-4" instead of
/// "us.anthropic.claude-sonnet-4-20250514-v1:0". Matching is exact and
/// case-sensitive; any string that is not a known alias is returned as-is,
/// so full model IDs pass straight through.
fn resolve_model_id(model: &str) -> &str {
    // (alias, full model ID or inference-profile ID)
    const ALIASES: &[(&str, &str)] = &[
        // Anthropic Claude models (cross-region inference profiles)
        ("claude-sonnet-4", "us.anthropic.claude-sonnet-4-20250514-v1:0"),
        ("claude-4-sonnet", "us.anthropic.claude-sonnet-4-20250514-v1:0"),
        ("claude-opus-4", "us.anthropic.claude-opus-4-20250514-v1:0"),
        ("claude-4-opus", "us.anthropic.claude-opus-4-20250514-v1:0"),
        ("claude-3.5-haiku", "us.anthropic.claude-3-5-haiku-20241022-v1:0"),
        ("claude-haiku-3.5", "us.anthropic.claude-3-5-haiku-20241022-v1:0"),
        ("claude-3.5-sonnet", "us.anthropic.claude-3-5-sonnet-20241022-v2:0"),
        ("claude-sonnet-3.5", "us.anthropic.claude-3-5-sonnet-20241022-v2:0"),
        // Amazon Nova models (on-demand)
        ("nova-pro", "amazon.nova-pro-v1:0"),
        ("nova-lite", "amazon.nova-lite-v1:0"),
        ("nova-micro", "amazon.nova-micro-v1:0"),
        ("nova-premier", "amazon.nova-premier-v1:0"),
        // Meta Llama models (cross-region inference profiles)
        ("llama-3.1-8b", "us.meta.llama3-1-8b-instruct-v1:0"),
        ("llama-3.1-70b", "us.meta.llama3-1-70b-instruct-v1:0"),
        ("llama-3.1-405b", "us.meta.llama3-1-405b-instruct-v1:0"),
        // Mistral models (cross-region inference profiles)
        ("mistral-large", "us.mistral.mistral-large-2407-v1:0"),
    ];

    for &(alias, full_id) in ALIASES {
        if alias == model {
            return full_id;
        }
    }

    // Not an alias: assume the caller already supplied a full model ID.
    model
}
93
94    /// Convert our generic messages to Bedrock Converse API format.
95    ///
96    /// Bedrock Converse uses:
97    /// - system prompt as a top-level "system" array
98    /// - messages with "role" and "content" array
99    /// - tool_use blocks in assistant content
100    /// - toolResult blocks in user content
101    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
102        let mut system_parts: Vec<Value> = Vec::new();
103        let mut api_messages: Vec<Value> = Vec::new();
104
105        for msg in messages {
106            match msg.role {
107                Role::System => {
108                    let text: String = msg
109                        .content
110                        .iter()
111                        .filter_map(|p| match p {
112                            ContentPart::Text { text } => Some(text.clone()),
113                            _ => None,
114                        })
115                        .collect::<Vec<_>>()
116                        .join("\n");
117                    system_parts.push(json!({"text": text}));
118                }
119                Role::User => {
120                    let mut content_parts: Vec<Value> = Vec::new();
121                    for part in &msg.content {
122                        match part {
123                            ContentPart::Text { text } => {
124                                if !text.is_empty() {
125                                    content_parts.push(json!({"text": text}));
126                                }
127                            }
128                            _ => {}
129                        }
130                    }
131                    if !content_parts.is_empty() {
132                        api_messages.push(json!({
133                            "role": "user",
134                            "content": content_parts
135                        }));
136                    }
137                }
138                Role::Assistant => {
139                    let mut content_parts: Vec<Value> = Vec::new();
140                    for part in &msg.content {
141                        match part {
142                            ContentPart::Text { text } => {
143                                if !text.is_empty() {
144                                    content_parts.push(json!({"text": text}));
145                                }
146                            }
147                            ContentPart::ToolCall {
148                                id,
149                                name,
150                                arguments,
151                            } => {
152                                let input: Value = serde_json::from_str(arguments)
153                                    .unwrap_or_else(|_| json!({"raw": arguments}));
154                                content_parts.push(json!({
155                                    "toolUse": {
156                                        "toolUseId": id,
157                                        "name": name,
158                                        "input": input
159                                    }
160                                }));
161                            }
162                            _ => {}
163                        }
164                    }
165                    if content_parts.is_empty() {
166                        content_parts.push(json!({"text": ""}));
167                    }
168                    api_messages.push(json!({
169                        "role": "assistant",
170                        "content": content_parts
171                    }));
172                }
173                Role::Tool => {
174                    // Tool results go into a user message with toolResult blocks
175                    let mut content_parts: Vec<Value> = Vec::new();
176                    for part in &msg.content {
177                        if let ContentPart::ToolResult {
178                            tool_call_id,
179                            content,
180                        } = part
181                        {
182                            content_parts.push(json!({
183                                "toolResult": {
184                                    "toolUseId": tool_call_id,
185                                    "content": [{"text": content}],
186                                    "status": "success"
187                                }
188                            }));
189                        }
190                    }
191                    if !content_parts.is_empty() {
192                        api_messages.push(json!({
193                            "role": "user",
194                            "content": content_parts
195                        }));
196                    }
197                }
198            }
199        }
200
201        (system_parts, api_messages)
202    }
203
204    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
205        tools
206            .iter()
207            .map(|t| {
208                json!({
209                    "toolSpec": {
210                        "name": t.name,
211                        "description": t.description,
212                        "inputSchema": {
213                            "json": t.parameters
214                        }
215                    }
216                })
217            })
218            .collect()
219    }
220}
221
222/// Bedrock Converse API response types
223
224#[derive(Debug, Deserialize)]
225#[serde(rename_all = "camelCase")]
226struct ConverseResponse {
227    output: ConverseOutput,
228    #[serde(default)]
229    stop_reason: Option<String>,
230    #[serde(default)]
231    usage: Option<ConverseUsage>,
232}
233
234#[derive(Debug, Deserialize)]
235struct ConverseOutput {
236    message: ConverseMessage,
237}
238
239#[derive(Debug, Deserialize)]
240struct ConverseMessage {
241    #[allow(dead_code)]
242    role: String,
243    content: Vec<ConverseContent>,
244}
245
246#[derive(Debug, Deserialize)]
247#[serde(untagged)]
248enum ConverseContent {
249    Text {
250        text: String,
251    },
252    ToolUse {
253        #[serde(rename = "toolUse")]
254        tool_use: ConverseToolUse,
255    },
256}
257
258#[derive(Debug, Deserialize)]
259#[serde(rename_all = "camelCase")]
260struct ConverseToolUse {
261    tool_use_id: String,
262    name: String,
263    input: Value,
264}
265
266#[derive(Debug, Deserialize)]
267#[serde(rename_all = "camelCase")]
268struct ConverseUsage {
269    #[serde(default)]
270    input_tokens: usize,
271    #[serde(default)]
272    output_tokens: usize,
273    #[serde(default)]
274    total_tokens: usize,
275}
276
277#[derive(Debug, Deserialize)]
278struct BedrockError {
279    message: String,
280}
281
282#[async_trait]
283impl Provider for BedrockProvider {
284    fn name(&self) -> &str {
285        "bedrock"
286    }
287
288    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
289        self.validate_api_key()?;
290
291        Ok(vec![
292            // Anthropic Claude via Bedrock (best price/performance)
293            ModelInfo {
294                id: "us.anthropic.claude-sonnet-4-20250514-v1:0".to_string(),
295                name: "Claude Sonnet 4 (Bedrock)".to_string(),
296                provider: "bedrock".to_string(),
297                context_window: 200_000,
298                max_output_tokens: Some(64_000),
299                supports_vision: true,
300                supports_tools: true,
301                supports_streaming: true,
302                input_cost_per_million: Some(3.0),
303                output_cost_per_million: Some(15.0),
304            },
305            ModelInfo {
306                id: "us.anthropic.claude-opus-4-20250514-v1:0".to_string(),
307                name: "Claude Opus 4 (Bedrock)".to_string(),
308                provider: "bedrock".to_string(),
309                context_window: 200_000,
310                max_output_tokens: Some(32_000),
311                supports_vision: true,
312                supports_tools: true,
313                supports_streaming: true,
314                input_cost_per_million: Some(15.0),
315                output_cost_per_million: Some(75.0),
316            },
317            ModelInfo {
318                id: "us.anthropic.claude-3-5-haiku-20241022-v1:0".to_string(),
319                name: "Claude 3.5 Haiku (Bedrock)".to_string(),
320                provider: "bedrock".to_string(),
321                context_window: 200_000,
322                max_output_tokens: Some(8_192),
323                supports_vision: true,
324                supports_tools: true,
325                supports_streaming: true,
326                input_cost_per_million: Some(0.80),
327                output_cost_per_million: Some(4.0),
328            },
329            ModelInfo {
330                id: "us.anthropic.claude-3-5-sonnet-20241022-v2:0".to_string(),
331                name: "Claude 3.5 Sonnet v2 (Bedrock)".to_string(),
332                provider: "bedrock".to_string(),
333                context_window: 200_000,
334                max_output_tokens: Some(8_192),
335                supports_vision: true,
336                supports_tools: true,
337                supports_streaming: true,
338                input_cost_per_million: Some(3.0),
339                output_cost_per_million: Some(15.0),
340            },
341            // Amazon Nova models
342            ModelInfo {
343                id: "amazon.nova-pro-v1:0".to_string(),
344                name: "Amazon Nova Pro".to_string(),
345                provider: "bedrock".to_string(),
346                context_window: 300_000,
347                max_output_tokens: Some(5_000),
348                supports_vision: true,
349                supports_tools: true,
350                supports_streaming: true,
351                input_cost_per_million: Some(0.80),
352                output_cost_per_million: Some(3.20),
353            },
354            ModelInfo {
355                id: "amazon.nova-lite-v1:0".to_string(),
356                name: "Amazon Nova Lite".to_string(),
357                provider: "bedrock".to_string(),
358                context_window: 300_000,
359                max_output_tokens: Some(5_000),
360                supports_vision: true,
361                supports_tools: true,
362                supports_streaming: true,
363                input_cost_per_million: Some(0.06),
364                output_cost_per_million: Some(0.24),
365            },
366            ModelInfo {
367                id: "amazon.nova-micro-v1:0".to_string(),
368                name: "Amazon Nova Micro".to_string(),
369                provider: "bedrock".to_string(),
370                context_window: 128_000,
371                max_output_tokens: Some(5_000),
372                supports_vision: false,
373                supports_tools: true,
374                supports_streaming: true,
375                input_cost_per_million: Some(0.035),
376                output_cost_per_million: Some(0.14),
377            },
378            // Meta Llama models
379            ModelInfo {
380                id: "us.meta.llama3-1-70b-instruct-v1:0".to_string(),
381                name: "Llama 3.1 70B (Bedrock)".to_string(),
382                provider: "bedrock".to_string(),
383                context_window: 128_000,
384                max_output_tokens: Some(2_048),
385                supports_vision: false,
386                supports_tools: true,
387                supports_streaming: true,
388                input_cost_per_million: Some(0.72),
389                output_cost_per_million: Some(0.72),
390            },
391            ModelInfo {
392                id: "us.meta.llama3-1-8b-instruct-v1:0".to_string(),
393                name: "Llama 3.1 8B (Bedrock)".to_string(),
394                provider: "bedrock".to_string(),
395                context_window: 128_000,
396                max_output_tokens: Some(2_048),
397                supports_vision: false,
398                supports_tools: true,
399                supports_streaming: true,
400                input_cost_per_million: Some(0.22),
401                output_cost_per_million: Some(0.22),
402            },
403            // Mistral
404            ModelInfo {
405                id: "us.mistral.mistral-large-2407-v1:0".to_string(),
406                name: "Mistral Large (Bedrock)".to_string(),
407                provider: "bedrock".to_string(),
408                context_window: 128_000,
409                max_output_tokens: Some(8_192),
410                supports_vision: false,
411                supports_tools: true,
412                supports_streaming: true,
413                input_cost_per_million: Some(2.0),
414                output_cost_per_million: Some(6.0),
415            },
416        ])
417    }
418
419    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
420        let model_id = Self::resolve_model_id(&request.model);
421
422        tracing::debug!(
423            provider = "bedrock",
424            model = %model_id,
425            original_model = %request.model,
426            message_count = request.messages.len(),
427            tool_count = request.tools.len(),
428            "Starting Bedrock Converse request"
429        );
430
431        self.validate_api_key()?;
432
433        let (system_parts, messages) = Self::convert_messages(&request.messages);
434        let tools = Self::convert_tools(&request.tools);
435
436        let mut body = json!({
437            "messages": messages,
438        });
439
440        if !system_parts.is_empty() {
441            body["system"] = json!(system_parts);
442        }
443
444        // inferenceConfig
445        let mut inference_config = json!({});
446        if let Some(max_tokens) = request.max_tokens {
447            inference_config["maxTokens"] = json!(max_tokens);
448        } else {
449            inference_config["maxTokens"] = json!(8192);
450        }
451        if let Some(temp) = request.temperature {
452            inference_config["temperature"] = json!(temp);
453        }
454        if let Some(top_p) = request.top_p {
455            inference_config["topP"] = json!(top_p);
456        }
457        body["inferenceConfig"] = inference_config;
458
459        if !tools.is_empty() {
460            body["toolConfig"] = json!({"tools": tools});
461        }
462
463        // URL-encode the colon in model IDs (e.g. v1:0 -> v1%3A0)
464        let encoded_model_id = model_id.replace(':', "%3A");
465        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
466        tracing::debug!("Bedrock request URL: {}", url);
467
468        let response = self
469            .client
470            .post(&url)
471            .bearer_auth(&self.api_key)
472            .header("content-type", "application/json")
473            .header("accept", "application/json")
474            .json(&body)
475            .send()
476            .await
477            .context("Failed to send request to Bedrock")?;
478
479        let status = response.status();
480        let text = response
481            .text()
482            .await
483            .context("Failed to read Bedrock response")?;
484
485        if !status.is_success() {
486            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
487                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
488            }
489            anyhow::bail!(
490                "Bedrock API error: {} {}",
491                status,
492                &text[..text.len().min(500)]
493            );
494        }
495
496        let response: ConverseResponse = serde_json::from_str(&text).context(format!(
497            "Failed to parse Bedrock response: {}",
498            &text[..text.len().min(300)]
499        ))?;
500
501        tracing::debug!(
502            stop_reason = ?response.stop_reason,
503            "Received Bedrock response"
504        );
505
506        let mut content = Vec::new();
507        let mut has_tool_calls = false;
508
509        for part in &response.output.message.content {
510            match part {
511                ConverseContent::Text { text } => {
512                    if !text.is_empty() {
513                        content.push(ContentPart::Text { text: text.clone() });
514                    }
515                }
516                ConverseContent::ToolUse { tool_use } => {
517                    has_tool_calls = true;
518                    content.push(ContentPart::ToolCall {
519                        id: tool_use.tool_use_id.clone(),
520                        name: tool_use.name.clone(),
521                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
522                    });
523                }
524            }
525        }
526
527        let finish_reason = if has_tool_calls {
528            FinishReason::ToolCalls
529        } else {
530            match response.stop_reason.as_deref() {
531                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
532                Some("max_tokens") => FinishReason::Length,
533                Some("tool_use") => FinishReason::ToolCalls,
534                Some("content_filtered") => FinishReason::ContentFilter,
535                _ => FinishReason::Stop,
536            }
537        };
538
539        let usage = response.usage.as_ref();
540
541        Ok(CompletionResponse {
542            message: Message {
543                role: Role::Assistant,
544                content,
545            },
546            usage: Usage {
547                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
548                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
549                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
550                cache_read_tokens: None,
551                cache_write_tokens: None,
552            },
553            finish_reason,
554        })
555    }
556
557    async fn complete_stream(
558        &self,
559        request: CompletionRequest,
560    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
561        // Fall back to non-streaming for now
562        let response = self.complete(request).await?;
563        let text = response
564            .message
565            .content
566            .iter()
567            .filter_map(|p| match p {
568                ContentPart::Text { text } => Some(text.clone()),
569                _ => None,
570            })
571            .collect::<Vec<_>>()
572            .join("");
573
574        Ok(Box::pin(futures::stream::once(async move {
575            StreamChunk::Text(text)
576        })))
577    }
578}