use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;

const DEFAULT_REGION: &str = "us-east-1";

pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
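    /// Creates a provider for the default region (`us-east-1`).
    ///
    /// A minimal usage sketch; the environment variable name is an
    /// assumption for illustration, not something this module defines:
    ///
    /// ```ignore
    /// let key = std::env::var("BEDROCK_API_KEY")?; // hypothetical variable name
    /// let provider = BedrockProvider::new(key.clone())?;
    /// // Or pin a region explicitly:
    /// let eu = BedrockProvider::with_region(key, "eu-west-1".to_string())?;
    /// ```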
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

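    /// Data-plane host for `Converse` requests, e.g.
    /// `https://bedrock-runtime.us-east-1.amazonaws.com` for the default
    /// region. The control-plane host in `management_url` below differs only
    /// in its `bedrock.` (vs `bedrock-runtime.`) prefix.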
    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

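    /// Resolves a friendly alias to a Bedrock model ID (usually a `us.`
    /// cross-region inference profile); unrecognized names pass through
    /// unchanged, so full Bedrock IDs keep working. Two examples drawn from
    /// the table below:
    ///
    /// ```ignore
    /// assert_eq!(
    ///     BedrockProvider::resolve_model_id("claude-sonnet-4"),
    ///     "us.anthropic.claude-sonnet-4-20250514-v1:0"
    /// );
    /// // Unknown strings are returned verbatim.
    /// assert_eq!(
    ///     BedrockProvider::resolve_model_id("amazon.nova-pro-v1:0"),
    ///     "amazon.nova-pro-v1:0"
    /// );
    /// ```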
    fn resolve_model_id(model: &str) -> &str {
        match model {
            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
            "claude-opus-4.5" | "claude-4.5-opus" => "us.anthropic.claude-opus-4-5-20251101-v1:0",
            "claude-opus-4.1" | "claude-4.1-opus" => "us.anthropic.claude-opus-4-1-20250805-v1:0",
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
            }
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-haiku-4.5" | "claude-4.5-haiku" => {
                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
            }
            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            }
            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
            }
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
            }
            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",

            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "us.amazon.nova-premier-v1:0",

            "llama-4-maverick" | "llama4-maverick" => "us.meta.llama4-maverick-17b-instruct-v1:0",
            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3-70b" | "llama3-70b" => "us.meta.llama3-70b-instruct-v1:0",
            "llama-3-8b" | "llama3-8b" => "us.meta.llama3-8b-instruct-v1:0",

            "mistral-large-3" | "mistral-large" => "us.mistral.mistral-large-3-675b-instruct",
            "mistral-large-2402" => "us.mistral.mistral-large-2402-v1:0",
            "mistral-small" => "us.mistral.mistral-small-2402-v1:0",
            "mixtral-8x7b" => "us.mistral.mixtral-8x7b-instruct-v0:1",
            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
            "magistral-small" => "us.mistral.magistral-small-2509",

            "deepseek-r1" => "us.deepseek.r1-v1:0",
            "deepseek-v3" | "deepseek-v3.2" => "us.deepseek.v3.2",

            "command-r" => "us.cohere.command-r-v1:0",
            "command-r-plus" => "us.cohere.command-r-plus-v1:0",

            "qwen3-32b" => "us.qwen.qwen3-32b-v1:0",
            "qwen3-coder" | "qwen3-coder-next" => "us.qwen.qwen3-coder-next",
            "qwen3-coder-30b" => "us.qwen.qwen3-coder-30b-a3b-v1:0",

            "gemma-3-27b" => "us.google.gemma-3-27b-it",
            "gemma-3-12b" => "us.google.gemma-3-12b-it",
            "gemma-3-4b" => "us.google.gemma-3-4b-it",

            "kimi-k2" | "kimi-k2-thinking" => "us.moonshot.kimi-k2-thinking",
            "kimi-k2.5" => "us.moonshotai.kimi-k2.5",

            "jamba-1.5-large" => "us.ai21.jamba-1-5-large-v1:0",
            "jamba-1.5-mini" => "us.ai21.jamba-1-5-mini-v1:0",

            "minimax-m2" => "us.minimax.minimax-m2",
            "minimax-m2.1" => "us.minimax.minimax-m2.1",

            "nemotron-nano-30b" => "us.nvidia.nemotron-nano-3-30b",
            "nemotron-nano-12b" => "us.nvidia.nemotron-nano-12b-v2",
            "nemotron-nano-9b" => "us.nvidia.nemotron-nano-9b-v2",

            "glm-4.7" => "us.zai.glm-4.7",
            "glm-4.7-flash" => "us.zai.glm-4.7-flash",

            other => other,
        }
    }

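    /// Discovers models from the Bedrock control plane by merging two
    /// listings: `GET /foundation-models` and
    /// `GET /inference-profiles?typeEquals=SYSTEM_DEFINED`. Only the fields
    /// read below matter; a sketch of the first response's shape (values
    /// illustrative, not exhaustive):
    ///
    /// ```ignore
    /// {
    ///   "modelSummaries": [{
    ///     "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
    ///     "modelName": "Claude 3 Haiku",
    ///     "providerName": "Anthropic",
    ///     "inputModalities": ["TEXT", "IMAGE"],
    ///     "outputModalities": ["TEXT"],
    ///     "inferenceTypesSupported": ["ON_DEMAND"],
    ///     "responseStreamingSupported": true
    ///   }]
    /// }
    /// ```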
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self
            .client
            .get(&fm_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            // Keep only text-generating models that are invocable
                            // on demand or through an inference profile.
                            if !output_modalities.contains(&"TEXT")
                                || (!inference_types.contains(&"ON_DEMAND")
                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
                            {
                                continue;
                            }

                            // Skip non-chat model families by name.
                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            // Amazon models are invoked by bare ID; others go
                            // through the `us.` cross-region inference profile.
                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else {
                                format!("us.{}", model_id)
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self
            .client
            .get(&ip_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            if !pid.starts_with("us.") {
                                continue;
                            }

                            if models.contains_key(pid) {
                                continue;
                            }

                            // Skip image, video, audio, and embedding profiles.
                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4")
                                || pid.contains("claude-haiku-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

    /// Heuristic context-window sizes by model family; used because the
    /// listing endpoints do not return token limits.
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

    /// Heuristic max-output-token limits by model family; branch order
    /// matters, since more specific IDs must match before their prefixes.
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-opus-4-6") {
            32_000
        } else if id.contains("claude-opus-4-5") {
            32_000
        } else if id.contains("claude-opus-4-1") {
            32_000
        } else if id.contains("claude-sonnet-4-5")
            || id.contains("claude-sonnet-4")
            || id.contains("claude-3-7")
        {
            64_000
        } else if id.contains("claude-haiku-4-5") {
            16_384
        } else if id.contains("claude-opus-4") {
            32_000
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

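    /// Converts internal messages to Converse form: system text is collected
    /// separately, tool results become `toolResult` blocks inside a `user`
    /// message (Converse has no tool role), and consecutive same-role
    /// messages are merged so user/assistant turns alternate. A sketch of
    /// one converted user turn:
    ///
    /// ```ignore
    /// { "role": "user", "content": [{ "text": "hello" }] }
    /// ```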
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            _ => {}
                        }
                    }
                    if !content_parts.is_empty() {
                        // Merge into a preceding user message so roles alternate.
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) =
                                    last.get_mut("content").and_then(|c| c.as_array_mut())
                                {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": ""}));
                    }
                    // Merge into a preceding assistant message so roles alternate.
                    if let Some(last) = api_messages.last_mut() {
                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
                            if let Some(arr) =
                                last.get_mut("content").and_then(|c| c.as_array_mut())
                            {
                                arr.extend(content_parts);
                                continue;
                            }
                        }
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Converse has no tool role; tool results are sent as
                    // `toolResult` blocks inside a user message.
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) =
                                    last.get_mut("content").and_then(|c| c.as_array_mut())
                                {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

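    /// Wraps each tool definition in the Converse `toolSpec` envelope; the
    /// output for one tool looks like (sketch):
    ///
    /// ```ignore
    /// {
    ///   "toolSpec": {
    ///     "name": "...",
    ///     "description": "...",
    ///     "inputSchema": { "json": { /* JSON Schema from t.parameters */ } }
    ///   }
    /// }
    /// ```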
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

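// Wire types for the non-streaming Converse response. A successful body
// looks roughly like this (sketch; only the fields deserialized below):
//
// {
//   "output": { "message": { "role": "assistant", "content": [
//     { "text": "..." },
//     { "toolUse": { "toolUseId": "...", "name": "...", "input": {} } }
//   ]}},
//   "stopReason": "end_turn",
//   "usage": { "inputTokens": 10, "outputTokens": 20, "totalTokens": 30 }
// }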
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        // Converse requires maxTokens; default to 8192 when unspecified.
        let mut inference_config = json!({});
        if let Some(max_tokens) = request.max_tokens {
            inference_config["maxTokens"] = json!(max_tokens);
        } else {
            inference_config["maxTokens"] = json!(8192);
        }
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }

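        // The assembled body follows the Converse request schema; a minimal
        // sketch of what the json! construction above produces:
        //
        // {
        //   "messages": [{ "role": "user", "content": [{ "text": "hi" }] }],
        //   "system": [{ "text": "You are helpful." }],
        //   "inferenceConfig": { "maxTokens": 8192 },
        //   "toolConfig": { "tools": [ /* toolSpec entries */ ] }
        // }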
        // Model IDs may contain ':', which must be percent-encoded in the path.
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // `get` truncates without panicking on a non-char-boundary slice.
            anyhow::bail!(
                "Bedrock API error: {} {}",
                status,
                text.get(..500).unwrap_or(&text)
            );
        }

        let response: ConverseResponse = serde_json::from_str(&text).with_context(|| {
            format!(
                "Failed to parse Bedrock response: {}",
                text.get(..300).unwrap_or(&text)
            )
        })?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
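        // Non-streaming fallback: Bedrock's converse-stream endpoint replies
        // with AWS event-stream framing, which this client does not parse, so
        // we buffer a full completion and emit its text as a single chunk.
        // Tool calls and usage from the response are not forwarded here.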
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}
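
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal test sketches for the pure helpers above; they exercise only
    // code in this file and make no network calls.

    #[test]
    fn resolves_known_aliases() {
        assert_eq!(
            BedrockProvider::resolve_model_id("claude-sonnet-4"),
            "us.anthropic.claude-sonnet-4-20250514-v1:0"
        );
        assert_eq!(
            BedrockProvider::resolve_model_id("nova-pro"),
            "amazon.nova-pro-v1:0"
        );
    }

    #[test]
    fn passes_unknown_ids_through() {
        assert_eq!(
            BedrockProvider::resolve_model_id("amazon.titan-text-express-v1"),
            "amazon.titan-text-express-v1"
        );
    }

    #[test]
    fn estimates_are_sane() {
        assert_eq!(
            BedrockProvider::estimate_context_window(
                "us.anthropic.claude-sonnet-4-20250514-v1:0",
                "Anthropic"
            ),
            200_000
        );
        assert_eq!(
            BedrockProvider::estimate_max_output("us.meta.llama3-1-8b-instruct-v1:0", "Meta"),
            4_096
        );
    }
}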