use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;

const DEFAULT_REGION: &str = "us-east-1";

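/// Provider for Amazon Bedrock's Converse API, authenticated with a Bedrock
/// API key sent as a bearer token.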
pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish_non_exhaustive()
    }
}

impl BedrockProvider {
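    /// Creates a provider for the default region (us-east-1).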
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

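    /// Creates a provider targeting a specific AWS region.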
    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

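    /// Rejects an empty API key before any network call is made.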
    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

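    /// Runtime endpoint used for inference (Converse) calls.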
    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

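    /// Management endpoint used for model discovery.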
    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

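    /// Maps friendly aliases (e.g. "claude-sonnet-4.5") to Bedrock model or
    /// inference-profile IDs. Unrecognized names pass through unchanged.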
    fn resolve_model_id(model: &str) -> &str {
        match model {
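            // Anthropic Claude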
77 "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
79 "claude-opus-4.5" | "claude-4.5-opus" => {
80 "us.anthropic.claude-opus-4-5-20251101-v1:0"
81 }
82 "claude-opus-4.1" | "claude-4.1-opus" => {
83 "us.anthropic.claude-opus-4-1-20250805-v1:0"
84 }
85 "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
86 "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
87 "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
88 }
89 "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
90 "claude-haiku-4.5" | "claude-4.5-haiku" => {
91 "us.anthropic.claude-haiku-4-5-20251001-v1:0"
92 }
93 "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
94 "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
95 }
96 "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
97 "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
98 }
99 "claude-3.5-haiku" | "claude-haiku-3.5" => {
100 "us.anthropic.claude-3-5-haiku-20241022-v1:0"
101 }
102 "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
103 "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
104 }
105 "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
106 "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
107 "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",
108
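            // Amazon Nova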
109 "nova-pro" => "amazon.nova-pro-v1:0",
111 "nova-lite" => "amazon.nova-lite-v1:0",
112 "nova-micro" => "amazon.nova-micro-v1:0",
113 "nova-premier" => "us.amazon.nova-premier-v1:0",
114
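            // Meta Llama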
115 "llama-4-maverick" | "llama4-maverick" => {
117 "us.meta.llama4-maverick-17b-instruct-v1:0"
118 }
119 "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
120 "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
121 "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
122 "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
123 "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
124 "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
125 "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
126 "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
127 "llama-3-70b" | "llama3-70b" => "us.meta.llama3-70b-instruct-v1:0",
128 "llama-3-8b" | "llama3-8b" => "us.meta.llama3-8b-instruct-v1:0",
129
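            // Mistral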
130 "mistral-large-3" | "mistral-large" => "us.mistral.mistral-large-3-675b-instruct",
132 "mistral-large-2402" => "us.mistral.mistral-large-2402-v1:0",
133 "mistral-small" => "us.mistral.mistral-small-2402-v1:0",
134 "mixtral-8x7b" => "us.mistral.mixtral-8x7b-instruct-v0:1",
135 "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
136 "magistral-small" => "us.mistral.magistral-small-2509",
137
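            // DeepSeek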
138 "deepseek-r1" => "us.deepseek.r1-v1:0",
140 "deepseek-v3" | "deepseek-v3.2" => "us.deepseek.v3.2",
141
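            // Cohere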
142 "command-r" => "us.cohere.command-r-v1:0",
144 "command-r-plus" => "us.cohere.command-r-plus-v1:0",
145
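            // Qwen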
146 "qwen3-32b" => "us.qwen.qwen3-32b-v1:0",
148 "qwen3-coder" | "qwen3-coder-next" => "us.qwen.qwen3-coder-next",
149 "qwen3-coder-30b" => "us.qwen.qwen3-coder-30b-a3b-v1:0",
150
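            // Google Gemma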
151 "gemma-3-27b" => "us.google.gemma-3-27b-it",
153 "gemma-3-12b" => "us.google.gemma-3-12b-it",
154 "gemma-3-4b" => "us.google.gemma-3-4b-it",
155
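            // Moonshot AI Kimi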
156 "kimi-k2" | "kimi-k2-thinking" => "us.moonshot.kimi-k2-thinking",
158 "kimi-k2.5" => "us.moonshotai.kimi-k2.5",
159
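            // AI21 Jamba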
160 "jamba-1.5-large" => "us.ai21.jamba-1-5-large-v1:0",
162 "jamba-1.5-mini" => "us.ai21.jamba-1-5-mini-v1:0",
163
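            // MiniMax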
164 "minimax-m2" => "us.minimax.minimax-m2",
166 "minimax-m2.1" => "us.minimax.minimax-m2.1",
167
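            // NVIDIA Nemotron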
168 "nemotron-nano-30b" => "us.nvidia.nemotron-nano-3-30b",
170 "nemotron-nano-12b" => "us.nvidia.nemotron-nano-12b-v2",
171 "nemotron-nano-9b" => "us.nvidia.nemotron-nano-9b-v2",
172
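            // Z.ai GLM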
173 "glm-4.7" => "us.zai.glm-4.7",
175 "glm-4.7-flash" => "us.zai.glm-4.7-flash",
176
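            // Anything else is assumed to already be a valid Bedrock model ID.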
            other => other,
        }
    }

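    /// Builds the model list dynamically from the management API: foundation
    /// models plus US cross-region inference profiles, de-duplicated by ID and
    /// filtered down to invocable text-generation models.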
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

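        // Foundation models (the ListFoundationModels API).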
        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self
            .client
            .get(&fm_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| {
                                    a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>()
                                })
                                .unwrap_or_default();

                            if !output_modalities.contains(&"TEXT")
                                || (!inference_types.contains(&"ON_DEMAND")
                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
                            {
                                continue;
                            }

                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else {
                                format!("us.{}", model_id)
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

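        // US cross-region inference profiles (the ListInferenceProfiles API);
        // only add profiles not already discovered above.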
        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self
            .client
            .get(&ip_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            if !pid.starts_with("us.") {
                                continue;
                            }

                            if models.contains_key(pid) {
                                continue;
                            }

                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4")
                                || pid.contains("claude-haiku-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

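    /// Best-effort context-window estimate keyed off the model family;
    /// unknown models fall back to a conservative 32k.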
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro") || id.contains("nova-lite") || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

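    /// Best-effort max-output estimate keyed off the model family; unknown
    /// models fall back to a conservative 4k.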
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-opus-4") {
            // Covers Opus 4, 4.1, 4.5, and 4.6.
            32_000
        } else if id.contains("claude-sonnet-4") || id.contains("claude-3-7") {
            64_000
        } else if id.contains("claude-haiku-4-5") {
            16_384
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

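    /// Converts internal messages to Converse format, returning
    /// `(system_blocks, turn_messages)`. Consecutive same-role messages are
    /// merged, since Converse expects strictly alternating user/assistant
    /// turns, and tool results are sent as user-role `toolResult` blocks.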
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
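                            // Non-text user content (e.g. images) is not
                            // forwarded.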
                            _ => {}
                        }
                    }
                    if !content_parts.is_empty() {
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) = last.get_mut("content").and_then(|c| c.as_array_mut()) {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        content_parts.push(json!({"text": ""}));
                    }
                    if let Some(last) = api_messages.last_mut() {
                        if last.get("role").and_then(|r| r.as_str()) == Some("assistant") {
                            if let Some(arr) = last.get_mut("content").and_then(|c| c.as_array_mut()) {
                                arr.extend(content_parts);
                                continue;
                            }
                        }
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        if let Some(last) = api_messages.last_mut() {
                            if last.get("role").and_then(|r| r.as_str()) == Some("user") {
                                if let Some(arr) = last.get_mut("content").and_then(|c| c.as_array_mut()) {
                                    arr.extend(content_parts);
                                    continue;
                                }
                            }
                        }
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }

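    /// Converts tool definitions into Converse `toolSpec` entries.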
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

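/// Minimal subset of the Converse response that this provider consumes.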
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
    // Catch-all so unrecognized content blocks (e.g. reasoning output) don't
    // fail deserialization of the whole message.
    Other(Value),
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        let mut inference_config = json!({});
        inference_config["maxTokens"] = json!(request.max_tokens.unwrap_or(8192));
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }

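        // Model IDs can contain ':' (e.g. "...-v1:0"), which must be
        // percent-encoded in the URL path.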
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // Truncate on char boundaries; byte slicing can panic mid-UTF-8.
            let snippet: String = text.chars().take(500).collect();
            anyhow::bail!("Bedrock API error: {} {}", status, snippet);
        }

        let response: ConverseResponse = serde_json::from_str(&text).context(format!(
            "Failed to parse Bedrock response: {}",
            text.chars().take(300).collect::<String>()
        ))?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
                // Unknown block types are skipped rather than failing the response.
                ConverseContent::Other(_) => {}
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
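        // Bedrock streaming uses AWS event-stream framing, which is not
        // decoded here; fall back to the blocking request and emit the result
        // as a single text chunk.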
        let response = self.complete(request).await?;
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}