use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};

const DEFAULT_REGION: &str = "us-east-1";

pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

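    /// Maps short model aliases (e.g. "claude-sonnet-4", "nova-pro") to fully
    /// qualified Bedrock model IDs. Unrecognized names are passed through
    /// unchanged, so callers can also supply a full Bedrock model ID directly.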
    fn resolve_model_id(model: &str) -> &str {
        match model {
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
            }
            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "amazon.nova-premier-v1:0",
            "llama-3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-405b" => "us.meta.llama3-1-405b-instruct-v1:0",
            "mistral-large" => "us.mistral.mistral-large-2407-v1:0",
            other => other,
        }
    }

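    /// Splits chat history into the two shapes the Converse API expects:
    /// system prompt blocks (returned separately) and the user/assistant
    /// message list. Tool results are sent back as user-role `toolResult`
    /// content blocks, mirroring the assistant's `toolUse` blocks.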
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::Text { text } = part {
                            if !text.is_empty() {
                                content_parts.push(json!({"text": text}));
                            }
                        }
                    }
                    if !content_parts.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": ""}));
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

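    /// Converts generic tool definitions into Converse `toolConfig` entries,
    /// wrapping each JSON schema in a `toolSpec.inputSchema.json` field.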
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

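// Minimal subset of the Bedrock Converse response body that this provider
// needs to deserialize.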
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

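    /// Returns a static, hand-maintained catalog of commonly used Bedrock
    /// models; context windows, output limits, and per-million-token prices
    /// are point-in-time values, not data fetched from AWS.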
    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;

        Ok(vec![
            ModelInfo {
                id: "us.anthropic.claude-sonnet-4-20250514-v1:0".to_string(),
                name: "Claude Sonnet 4 (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(64_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(3.0),
                output_cost_per_million: Some(15.0),
            },
            ModelInfo {
                id: "us.anthropic.claude-opus-4-20250514-v1:0".to_string(),
                name: "Claude Opus 4 (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(32_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(15.0),
                output_cost_per_million: Some(75.0),
            },
            ModelInfo {
                id: "us.anthropic.claude-3-5-haiku-20241022-v1:0".to_string(),
                name: "Claude 3.5 Haiku (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(8_192),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.80),
                output_cost_per_million: Some(4.0),
            },
            ModelInfo {
                id: "us.anthropic.claude-3-5-sonnet-20241022-v2:0".to_string(),
                name: "Claude 3.5 Sonnet v2 (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 200_000,
                max_output_tokens: Some(8_192),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(3.0),
                output_cost_per_million: Some(15.0),
            },
            ModelInfo {
                id: "amazon.nova-pro-v1:0".to_string(),
                name: "Amazon Nova Pro".to_string(),
                provider: "bedrock".to_string(),
                context_window: 300_000,
                max_output_tokens: Some(5_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.80),
                output_cost_per_million: Some(3.20),
            },
            ModelInfo {
                id: "amazon.nova-lite-v1:0".to_string(),
                name: "Amazon Nova Lite".to_string(),
                provider: "bedrock".to_string(),
                context_window: 300_000,
                max_output_tokens: Some(5_000),
                supports_vision: true,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.06),
                output_cost_per_million: Some(0.24),
            },
            ModelInfo {
                id: "amazon.nova-micro-v1:0".to_string(),
                name: "Amazon Nova Micro".to_string(),
                provider: "bedrock".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(5_000),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.035),
                output_cost_per_million: Some(0.14),
            },
            ModelInfo {
                id: "us.meta.llama3-1-70b-instruct-v1:0".to_string(),
                name: "Llama 3.1 70B (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(2_048),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.72),
                output_cost_per_million: Some(0.72),
            },
            ModelInfo {
                id: "us.meta.llama3-1-8b-instruct-v1:0".to_string(),
                name: "Llama 3.1 8B (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(2_048),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(0.22),
                output_cost_per_million: Some(0.22),
            },
            ModelInfo {
                id: "us.mistral.mistral-large-2407-v1:0".to_string(),
                name: "Mistral Large (Bedrock)".to_string(),
                provider: "bedrock".to_string(),
                context_window: 128_000,
                max_output_tokens: Some(8_192),
                supports_vision: false,
                supports_tools: true,
                supports_streaming: true,
                input_cost_per_million: Some(2.0),
                output_cost_per_million: Some(6.0),
            },
        ])
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        let mut inference_config = json!({});
        inference_config["maxTokens"] = json!(request.max_tokens.unwrap_or(8192));
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }

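        // Bedrock model IDs end in a ":<version>" suffix (e.g. "-v1:0"); encode
        // the colon so the ID stays a single URL path segment.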
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // Truncate by characters rather than bytes so a long body cannot
            // panic on a UTF-8 boundary.
            let snippet: String = text.chars().take(500).collect();
            anyhow::bail!("Bedrock API error: {} {}", status, snippet);
        }

        let response: ConverseResponse = serde_json::from_str(&text).with_context(|| {
            format!(
                "Failed to parse Bedrock response: {}",
                text.chars().take(300).collect::<String>()
            )
        })?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

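    /// Note: this does not call Bedrock's streaming endpoint; it runs a full
    /// non-streaming completion and emits the resulting text as a single
    /// `StreamChunk::Text` chunk.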
    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}
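
// A minimal test sketch for the pure helpers in this module; it exercises only
// in-process logic (no network calls, no real AWS credentials).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_model_id_maps_aliases_and_passes_through_unknown_ids() {
        assert_eq!(
            BedrockProvider::resolve_model_id("claude-sonnet-4"),
            "us.anthropic.claude-sonnet-4-20250514-v1:0"
        );
        assert_eq!(
            BedrockProvider::resolve_model_id("nova-pro"),
            "amazon.nova-pro-v1:0"
        );
        // Unknown names are treated as already-qualified Bedrock model IDs.
        assert_eq!(
            BedrockProvider::resolve_model_id("us.mistral.mistral-large-2407-v1:0"),
            "us.mistral.mistral-large-2407-v1:0"
        );
    }

    #[test]
    fn debug_output_redacts_the_api_key() {
        let provider = BedrockProvider::new("super-secret-key".to_string()).unwrap();
        let rendered = format!("{:?}", provider);
        assert!(rendered.contains("<REDACTED>"));
        assert!(!rendered.contains("super-secret-key"));
    }
}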