use super::{
    CompletionRequest, CompletionResponse, ContentPart, FinishReason, Message, ModelInfo, Provider,
    Role, StreamChunk, ToolDefinition, Usage,
};
use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use serde_json::{Value, json};
use std::collections::HashMap;

const DEFAULT_REGION: &str = "us-east-1";
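
/// Provider for Amazon Bedrock's Converse API. Authenticates with a Bedrock
/// API key sent as a bearer token, rather than SigV4 request signing.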
pub struct BedrockProvider {
    client: Client,
    api_key: String,
    region: String,
}

// Manual Debug impl so the API key never appears in logs.
impl std::fmt::Debug for BedrockProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BedrockProvider")
            .field("api_key", &"<REDACTED>")
            .field("region", &self.region)
            .finish()
    }
}

impl BedrockProvider {
    pub fn new(api_key: String) -> Result<Self> {
        Self::with_region(api_key, DEFAULT_REGION.to_string())
    }

    pub fn with_region(api_key: String, region: String) -> Result<Self> {
        tracing::debug!(
            provider = "bedrock",
            region = %region,
            api_key_len = api_key.len(),
            "Creating Bedrock provider"
        );
        Ok(Self {
            client: Client::new(),
            api_key,
            region,
        })
    }

    fn validate_api_key(&self) -> Result<()> {
        if self.api_key.is_empty() {
            anyhow::bail!("Bedrock API key is empty");
        }
        Ok(())
    }

    // Runtime endpoint used for inference (Converse) calls.
    fn base_url(&self) -> String {
        format!("https://bedrock-runtime.{}.amazonaws.com", self.region)
    }

    // Control-plane endpoint used for model and inference-profile discovery.
    fn management_url(&self) -> String {
        format!("https://bedrock.{}.amazonaws.com", self.region)
    }

    /// Map friendly model aliases onto full Bedrock model or
    /// inference-profile IDs. Unrecognized names pass through unchanged,
    /// so callers can also supply raw Bedrock IDs directly.
    fn resolve_model_id(model: &str) -> &str {
        match model {
            // Anthropic Claude
            "claude-opus-4.6" | "claude-4.6-opus" => "us.anthropic.claude-opus-4-6-v1",
            "claude-opus-4.5" | "claude-4.5-opus" => {
                "us.anthropic.claude-opus-4-5-20251101-v1:0"
            }
            "claude-opus-4.1" | "claude-4.1-opus" => {
                "us.anthropic.claude-opus-4-1-20250805-v1:0"
            }
            "claude-opus-4" | "claude-4-opus" => "us.anthropic.claude-opus-4-20250514-v1:0",
            "claude-sonnet-4.5" | "claude-4.5-sonnet" => {
                "us.anthropic.claude-sonnet-4-5-20250929-v1:0"
            }
            "claude-sonnet-4" | "claude-4-sonnet" => "us.anthropic.claude-sonnet-4-20250514-v1:0",
            "claude-haiku-4.5" | "claude-4.5-haiku" => {
                "us.anthropic.claude-haiku-4-5-20251001-v1:0"
            }
            "claude-3.7-sonnet" | "claude-sonnet-3.7" => {
                "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
            }
            "claude-3.5-sonnet-v2" | "claude-sonnet-3.5-v2" => {
                "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
            }
            "claude-3.5-haiku" | "claude-haiku-3.5" => {
                "us.anthropic.claude-3-5-haiku-20241022-v1:0"
            }
            "claude-3.5-sonnet" | "claude-sonnet-3.5" => {
                "us.anthropic.claude-3-5-sonnet-20240620-v1:0"
            }
            "claude-3-opus" | "claude-opus-3" => "us.anthropic.claude-3-opus-20240229-v1:0",
            "claude-3-haiku" | "claude-haiku-3" => "us.anthropic.claude-3-haiku-20240307-v1:0",
            "claude-3-sonnet" | "claude-sonnet-3" => "us.anthropic.claude-3-sonnet-20240229-v1:0",

            // Amazon Nova
            "nova-pro" => "amazon.nova-pro-v1:0",
            "nova-lite" => "amazon.nova-lite-v1:0",
            "nova-micro" => "amazon.nova-micro-v1:0",
            "nova-premier" => "us.amazon.nova-premier-v1:0",

            // Meta Llama
            "llama-4-maverick" | "llama4-maverick" => {
                "us.meta.llama4-maverick-17b-instruct-v1:0"
            }
            "llama-4-scout" | "llama4-scout" => "us.meta.llama4-scout-17b-instruct-v1:0",
            "llama-3.3-70b" | "llama3.3-70b" => "us.meta.llama3-3-70b-instruct-v1:0",
            "llama-3.2-90b" | "llama3.2-90b" => "us.meta.llama3-2-90b-instruct-v1:0",
            "llama-3.2-11b" | "llama3.2-11b" => "us.meta.llama3-2-11b-instruct-v1:0",
            "llama-3.2-3b" | "llama3.2-3b" => "us.meta.llama3-2-3b-instruct-v1:0",
            "llama-3.2-1b" | "llama3.2-1b" => "us.meta.llama3-2-1b-instruct-v1:0",
            "llama-3.1-70b" | "llama3.1-70b" => "us.meta.llama3-1-70b-instruct-v1:0",
            "llama-3.1-8b" | "llama3.1-8b" => "us.meta.llama3-1-8b-instruct-v1:0",
            "llama-3-70b" | "llama3-70b" => "us.meta.llama3-70b-instruct-v1:0",
            "llama-3-8b" | "llama3-8b" => "us.meta.llama3-8b-instruct-v1:0",

            // Mistral AI
            "mistral-large-3" | "mistral-large" => "us.mistral.mistral-large-3-675b-instruct",
            "mistral-large-2402" => "us.mistral.mistral-large-2402-v1:0",
            "mistral-small" => "us.mistral.mistral-small-2402-v1:0",
            "mixtral-8x7b" => "us.mistral.mixtral-8x7b-instruct-v0:1",
            "pixtral-large" => "us.mistral.pixtral-large-2502-v1:0",
            "magistral-small" => "us.mistral.magistral-small-2509",

            // DeepSeek
            "deepseek-r1" => "us.deepseek.r1-v1:0",
            "deepseek-v3" | "deepseek-v3.2" => "us.deepseek.v3.2",

            // Cohere
            "command-r" => "us.cohere.command-r-v1:0",
            "command-r-plus" => "us.cohere.command-r-plus-v1:0",

            // Qwen
            "qwen3-32b" => "us.qwen.qwen3-32b-v1:0",
            "qwen3-coder" | "qwen3-coder-next" => "us.qwen.qwen3-coder-next",
            "qwen3-coder-30b" => "us.qwen.qwen3-coder-30b-a3b-v1:0",

            // Google Gemma
            "gemma-3-27b" => "us.google.gemma-3-27b-it",
            "gemma-3-12b" => "us.google.gemma-3-12b-it",
            "gemma-3-4b" => "us.google.gemma-3-4b-it",

            // Moonshot Kimi
            "kimi-k2" | "kimi-k2-thinking" => "us.moonshot.kimi-k2-thinking",
            "kimi-k2.5" => "us.moonshotai.kimi-k2.5",

            // AI21 Jamba
            "jamba-1.5-large" => "us.ai21.jamba-1-5-large-v1:0",
            "jamba-1.5-mini" => "us.ai21.jamba-1-5-mini-v1:0",

            // MiniMax
            "minimax-m2" => "us.minimax.minimax-m2",
            "minimax-m2.1" => "us.minimax.minimax-m2.1",

            // NVIDIA Nemotron
            "nemotron-nano-30b" => "us.nvidia.nemotron-nano-3-30b",
            "nemotron-nano-12b" => "us.nvidia.nemotron-nano-12b-v2",
            "nemotron-nano-9b" => "us.nvidia.nemotron-nano-9b-v2",

            // Z.ai GLM
            "glm-4.7" => "us.zai.glm-4.7",
            "glm-4.7-flash" => "us.zai.glm-4.7-flash",

            // Anything else is assumed to already be a full Bedrock ID.
            other => other,
        }
    }

    /// Discover available models from the Bedrock control plane: the
    /// foundation-model listing first, then system-defined inference
    /// profiles, which cover cross-region models the first call omits.
    async fn discover_models(&self) -> Result<Vec<ModelInfo>> {
        let mut models: HashMap<String, ModelInfo> = HashMap::new();

        // List foundation models available in this region.
        let fm_url = format!("{}/foundation-models", self.management_url());
        let fm_resp = self
            .client
            .get(&fm_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = fm_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(summaries) = data.get("modelSummaries").and_then(|v| v.as_array()) {
                        for m in summaries {
                            let model_id = m.get("modelId").and_then(|v| v.as_str()).unwrap_or("");
                            let model_name =
                                m.get("modelName").and_then(|v| v.as_str()).unwrap_or("");
                            let provider_name =
                                m.get("providerName").and_then(|v| v.as_str()).unwrap_or("");

                            let output_modalities: Vec<&str> = m
                                .get("outputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let input_modalities: Vec<&str> = m
                                .get("inputModalities")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            let inference_types: Vec<&str> = m
                                .get("inferenceTypesSupported")
                                .and_then(|v| v.as_array())
                                .map(|a| a.iter().filter_map(|v| v.as_str()).collect::<Vec<_>>())
                                .unwrap_or_default();

                            // Keep only text-output models that can actually be invoked.
                            if !output_modalities.contains(&"TEXT")
                                || (!inference_types.contains(&"ON_DEMAND")
                                    && !inference_types.contains(&"INFERENCE_PROFILE"))
                            {
                                continue;
                            }

                            // Skip rerankers, embedders, and other non-chat models.
                            let name_lower = model_name.to_lowercase();
                            if name_lower.contains("rerank")
                                || name_lower.contains("embed")
                                || name_lower.contains("safeguard")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                            {
                                continue;
                            }

                            let streaming = m
                                .get("responseStreamingSupported")
                                .and_then(|v| v.as_bool())
                                .unwrap_or(false);
                            let vision = input_modalities.contains(&"IMAGE");

                            // Amazon first-party models are invoked by bare ID;
                            // everything else goes through a "us." inference profile.
                            let actual_id = if model_id.starts_with("amazon.") {
                                model_id.to_string()
                            } else {
                                format!("us.{}", model_id)
                            };

                            let display_name = format!("{} (Bedrock)", model_name);

                            models.insert(
                                actual_id.clone(),
                                ModelInfo {
                                    id: actual_id,
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        model_id,
                                        provider_name,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        model_id,
                                        provider_name,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: streaming,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        // The foundation-model listing omits models that are only exposed
        // through system-defined cross-region inference profiles.
        let ip_url = format!(
            "{}/inference-profiles?typeEquals=SYSTEM_DEFINED&maxResults=200",
            self.management_url()
        );
        let ip_resp = self
            .client
            .get(&ip_url)
            .bearer_auth(&self.api_key)
            .send()
            .await;

        if let Ok(resp) = ip_resp {
            if resp.status().is_success() {
                if let Ok(data) = resp.json::<Value>().await {
                    if let Some(profiles) = data
                        .get("inferenceProfileSummaries")
                        .and_then(|v| v.as_array())
                    {
                        for p in profiles {
                            let pid = p
                                .get("inferenceProfileId")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");
                            let pname = p
                                .get("inferenceProfileName")
                                .and_then(|v| v.as_str())
                                .unwrap_or("");

                            // Only keep US cross-region profiles.
                            if !pid.starts_with("us.") {
                                continue;
                            }

                            // Already discovered via the foundation-model listing.
                            if models.contains_key(pid) {
                                continue;
                            }

                            // Skip image, video, audio, and embedding profiles.
                            let name_lower = pname.to_lowercase();
                            if name_lower.contains("image")
                                || name_lower.contains("stable ")
                                || name_lower.contains("upscale")
                                || name_lower.contains("embed")
                                || name_lower.contains("marengo")
                                || name_lower.contains("outpaint")
                                || name_lower.contains("inpaint")
                                || name_lower.contains("erase")
                                || name_lower.contains("recolor")
                                || name_lower.contains("replace")
                                || name_lower.contains("style ")
                                || name_lower.contains("background")
                                || name_lower.contains("sketch")
                                || name_lower.contains("control")
                                || name_lower.contains("transfer")
                                || name_lower.contains("sonic")
                                || name_lower.contains("pegasus")
                                || name_lower.contains("rerank")
                            {
                                continue;
                            }

                            // Profiles do not report modalities; infer vision
                            // support from known model families.
                            let vision = pid.contains("llama3-2-11b")
                                || pid.contains("llama3-2-90b")
                                || pid.contains("pixtral")
                                || pid.contains("claude-3")
                                || pid.contains("claude-sonnet-4")
                                || pid.contains("claude-opus-4")
                                || pid.contains("claude-haiku-4");

                            let display_name = pname.replace("US ", "");
                            let display_name = format!("{} (Bedrock)", display_name.trim());

                            // e.g. "us.meta.llama3-3-70b-instruct-v1:0" -> "meta"
                            let provider_hint = pid
                                .strip_prefix("us.")
                                .unwrap_or(pid)
                                .split('.')
                                .next()
                                .unwrap_or("");

                            models.insert(
                                pid.to_string(),
                                ModelInfo {
                                    id: pid.to_string(),
                                    name: display_name,
                                    provider: "bedrock".to_string(),
                                    context_window: Self::estimate_context_window(
                                        pid,
                                        provider_hint,
                                    ),
                                    max_output_tokens: Some(Self::estimate_max_output(
                                        pid,
                                        provider_hint,
                                    )),
                                    supports_vision: vision,
                                    supports_tools: true,
                                    supports_streaming: true,
                                    input_cost_per_million: None,
                                    output_cost_per_million: None,
                                },
                            );
                        }
                    }
                }
            }
        }

        let mut result: Vec<ModelInfo> = models.into_values().collect();
        result.sort_by(|a, b| a.id.cmp(&b.id));

        tracing::info!(
            provider = "bedrock",
            model_count = result.len(),
            "Discovered Bedrock models dynamically"
        );

        Ok(result)
    }

    /// Rough context-window estimates by model family, used because the
    /// discovery APIs do not report context sizes.
    fn estimate_context_window(model_id: &str, provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("anthropic") || id.contains("claude") {
            200_000
        } else if id.contains("nova-pro")
            || id.contains("nova-lite")
            || id.contains("nova-premier")
        {
            300_000
        } else if id.contains("nova-micro") || id.contains("nova-2") {
            128_000
        } else if id.contains("deepseek") {
            128_000
        } else if id.contains("llama4") {
            256_000
        } else if id.contains("llama3") {
            128_000
        } else if id.contains("mistral-large-3") || id.contains("magistral") {
            128_000
        } else if id.contains("mistral") {
            32_000
        } else if id.contains("qwen") {
            128_000
        } else if id.contains("kimi") {
            128_000
        } else if id.contains("jamba") {
            256_000
        } else if id.contains("glm") {
            128_000
        } else if id.contains("minimax") {
            128_000
        } else if id.contains("gemma") {
            128_000
        } else if id.contains("cohere") || id.contains("command") {
            128_000
        } else if id.contains("nemotron") {
            128_000
        } else if provider.to_lowercase().contains("amazon") {
            128_000
        } else {
            32_000
        }
    }

    /// Rough max-output-token estimates by model family.
    fn estimate_max_output(model_id: &str, _provider: &str) -> usize {
        let id = model_id.to_lowercase();
        if id.contains("claude-opus-4") {
            // Covers the opus-4, 4.1, 4.5, and 4.6 variants.
            32_000
        } else if id.contains("claude-sonnet-4") || id.contains("claude-3-7") {
            64_000
        } else if id.contains("claude-haiku-4-5") {
            16_384
        } else if id.contains("claude") {
            8_192
        } else if id.contains("nova") {
            5_000
        } else if id.contains("deepseek") {
            16_384
        } else if id.contains("llama4") {
            16_384
        } else if id.contains("llama") {
            4_096
        } else if id.contains("mistral-large-3") {
            16_384
        } else if id.contains("mistral") || id.contains("mixtral") {
            8_192
        } else if id.contains("qwen") {
            8_192
        } else if id.contains("kimi") {
            8_192
        } else if id.contains("jamba") {
            4_096
        } else {
            4_096
        }
    }

    /// Convert internal messages into Converse form, returning
    /// `(system, messages)`. System prompts are collected into the
    /// top-level `system` array; assistant tool calls become `toolUse`
    /// blocks; tool results are sent back as `user` messages containing
    /// `toolResult` blocks, as the Converse API expects.
    fn convert_messages(messages: &[Message]) -> (Vec<Value>, Vec<Value>) {
        let mut system_parts: Vec<Value> = Vec::new();
        let mut api_messages: Vec<Value> = Vec::new();

        for msg in messages {
            match msg.role {
                Role::System => {
                    let text: String = msg
                        .content
                        .iter()
                        .filter_map(|p| match p {
                            ContentPart::Text { text } => Some(text.clone()),
                            _ => None,
                        })
                        .collect::<Vec<_>>()
                        .join("\n");
                    system_parts.push(json!({"text": text}));
                }
                Role::User => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        // Only text parts are forwarded; other content is dropped.
                        if let ContentPart::Text { text } = part {
                            if !text.is_empty() {
                                content_parts.push(json!({"text": text}));
                            }
                        }
                    }
                    if !content_parts.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
                Role::Assistant => {
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        match part {
                            ContentPart::Text { text } => {
                                if !text.is_empty() {
                                    content_parts.push(json!({"text": text}));
                                }
                            }
                            ContentPart::ToolCall {
                                id,
                                name,
                                arguments,
                            } => {
                                // Arguments arrive as a JSON string; wrap
                                // non-JSON payloads so the call is never lost.
                                let input: Value = serde_json::from_str(arguments)
                                    .unwrap_or_else(|_| json!({"raw": arguments}));
                                content_parts.push(json!({
                                    "toolUse": {
                                        "toolUseId": id,
                                        "name": name,
                                        "input": input
                                    }
                                }));
                            }
                            _ => {}
                        }
                    }
                    if content_parts.is_empty() {
                        // Keep at least one (empty) text block so the
                        // assistant message is never empty.
                        content_parts.push(json!({"text": ""}));
                    }
                    api_messages.push(json!({
                        "role": "assistant",
                        "content": content_parts
                    }));
                }
                Role::Tool => {
                    // Converse expects tool results inside a user message.
                    let mut content_parts: Vec<Value> = Vec::new();
                    for part in &msg.content {
                        if let ContentPart::ToolResult {
                            tool_call_id,
                            content,
                        } = part
                        {
                            content_parts.push(json!({
                                "toolResult": {
                                    "toolUseId": tool_call_id,
                                    "content": [{"text": content}],
                                    "status": "success"
                                }
                            }));
                        }
                    }
                    if !content_parts.is_empty() {
                        api_messages.push(json!({
                            "role": "user",
                            "content": content_parts
                        }));
                    }
                }
            }
        }

        (system_parts, api_messages)
    }
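
    // For example, a tool-result turn such as:
    //   Message { role: Role::Tool, content: [ToolResult { tool_call_id: "t1", content: "42" }] }
    // becomes the Converse user message:
    //   {"role": "user", "content": [{"toolResult": {"toolUseId": "t1",
    //     "content": [{"text": "42"}], "status": "success"}}]}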

    /// Convert tool definitions into Converse `toolSpec` entries.
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<Value> {
        tools
            .iter()
            .map(|t| {
                json!({
                    "toolSpec": {
                        "name": t.name,
                        "description": t.description,
                        "inputSchema": {
                            "json": t.parameters
                        }
                    }
                })
            })
            .collect()
    }
}

// Typed views of the Converse response payload.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    output: ConverseOutput,
    #[serde(default)]
    stop_reason: Option<String>,
    #[serde(default)]
    usage: Option<ConverseUsage>,
}

#[derive(Debug, Deserialize)]
struct ConverseOutput {
    message: ConverseMessage,
}

#[derive(Debug, Deserialize)]
struct ConverseMessage {
    #[allow(dead_code)]
    role: String,
    content: Vec<ConverseContent>,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ConverseContent {
    Text {
        text: String,
    },
    ToolUse {
        #[serde(rename = "toolUse")]
        tool_use: ConverseToolUse,
    },
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseToolUse {
    tool_use_id: String,
    name: String,
    input: Value,
}

#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseUsage {
    #[serde(default)]
    input_tokens: usize,
    #[serde(default)]
    output_tokens: usize,
    #[serde(default)]
    total_tokens: usize,
}

#[derive(Debug, Deserialize)]
struct BedrockError {
    message: String,
}

#[async_trait]
impl Provider for BedrockProvider {
    fn name(&self) -> &str {
        "bedrock"
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.validate_api_key()?;
        self.discover_models().await
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let model_id = Self::resolve_model_id(&request.model);

        tracing::debug!(
            provider = "bedrock",
            model = %model_id,
            original_model = %request.model,
            message_count = request.messages.len(),
            tool_count = request.tools.len(),
            "Starting Bedrock Converse request"
        );

        self.validate_api_key()?;

        let (system_parts, messages) = Self::convert_messages(&request.messages);
        let tools = Self::convert_tools(&request.tools);

        let mut body = json!({
            "messages": messages,
        });

        if !system_parts.is_empty() {
            body["system"] = json!(system_parts);
        }

        // Always send maxTokens, defaulting to 8192 when the caller set none.
        let mut inference_config = json!({});
        if let Some(max_tokens) = request.max_tokens {
            inference_config["maxTokens"] = json!(max_tokens);
        } else {
            inference_config["maxTokens"] = json!(8192);
        }
        if let Some(temp) = request.temperature {
            inference_config["temperature"] = json!(temp);
        }
        if let Some(top_p) = request.top_p {
            inference_config["topP"] = json!(top_p);
        }
        body["inferenceConfig"] = inference_config;

        if !tools.is_empty() {
            body["toolConfig"] = json!({"tools": tools});
        }
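
        // The assembled body follows the Converse request shape, roughly:
        //   {
        //     "system": [{"text": "..."}],
        //     "messages": [{"role": "user", "content": [{"text": "..."}]}],
        //     "inferenceConfig": {"maxTokens": 8192, "temperature": 0.7},
        //     "toolConfig": {"tools": [{"toolSpec": {...}}]}
        //   }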

        // Model IDs can contain ':', which must be percent-encoded in the URL path.
        let encoded_model_id = model_id.replace(':', "%3A");
        let url = format!("{}/model/{}/converse", self.base_url(), encoded_model_id);
        tracing::debug!("Bedrock request URL: {}", url);

        let response = self
            .client
            .post(&url)
            .bearer_auth(&self.api_key)
            .header("content-type", "application/json")
            .header("accept", "application/json")
            .json(&body)
            .send()
            .await
            .context("Failed to send request to Bedrock")?;

        let status = response.status();
        let text = response
            .text()
            .await
            .context("Failed to read Bedrock response")?;

        if !status.is_success() {
            if let Ok(err) = serde_json::from_str::<BedrockError>(&text) {
                anyhow::bail!("Bedrock API error ({}): {}", status, err.message);
            }
            // Truncate by characters rather than bytes so a multi-byte UTF-8
            // boundary can never panic the slice.
            let snippet: String = text.chars().take(500).collect();
            anyhow::bail!("Bedrock API error: {} {}", status, snippet);
        }

        let response: ConverseResponse = serde_json::from_str(&text).with_context(|| {
            format!(
                "Failed to parse Bedrock response: {}",
                text.chars().take(300).collect::<String>()
            )
        })?;

        tracing::debug!(
            stop_reason = ?response.stop_reason,
            "Received Bedrock response"
        );

        let mut content = Vec::new();
        let mut has_tool_calls = false;

        for part in &response.output.message.content {
            match part {
                ConverseContent::Text { text } => {
                    if !text.is_empty() {
                        content.push(ContentPart::Text { text: text.clone() });
                    }
                }
                ConverseContent::ToolUse { tool_use } => {
                    has_tool_calls = true;
                    content.push(ContentPart::ToolCall {
                        id: tool_use.tool_use_id.clone(),
                        name: tool_use.name.clone(),
                        arguments: serde_json::to_string(&tool_use.input).unwrap_or_default(),
                    });
                }
            }
        }

        let finish_reason = if has_tool_calls {
            FinishReason::ToolCalls
        } else {
            match response.stop_reason.as_deref() {
                Some("end_turn") | Some("stop") | Some("stop_sequence") => FinishReason::Stop,
                Some("max_tokens") => FinishReason::Length,
                Some("tool_use") => FinishReason::ToolCalls,
                Some("content_filtered") => FinishReason::ContentFilter,
                _ => FinishReason::Stop,
            }
        };

        let usage = response.usage.as_ref();

        Ok(CompletionResponse {
            message: Message {
                role: Role::Assistant,
                content,
            },
            usage: Usage {
                prompt_tokens: usage.map(|u| u.input_tokens).unwrap_or(0),
                completion_tokens: usage.map(|u| u.output_tokens).unwrap_or(0),
                total_tokens: usage.map(|u| u.total_tokens).unwrap_or(0),
                cache_read_tokens: None,
                cache_write_tokens: None,
            },
            finish_reason,
        })
    }

    async fn complete_stream(
        &self,
        request: CompletionRequest,
    ) -> Result<futures::stream::BoxStream<'static, StreamChunk>> {
        // Fallback implementation: Bedrock's native streaming uses AWS
        // event-stream framing, which this client does not implement, so run
        // the non-streaming request and emit its text as a single chunk.
        let response = self.complete(request).await?;
        // Only text parts are forwarded; tool calls are dropped here.
        let text = response
            .message
            .content
            .iter()
            .filter_map(|p| match p {
                ContentPart::Text { text } => Some(text.clone()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        Ok(Box::pin(futures::stream::once(async move {
            StreamChunk::Text(text)
        })))
    }
}
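
#[cfg(test)]
mod tests {
    // Minimal sketches exercising the pure helpers above. They need no
    // network access or credentials, so they run as ordinary unit tests.
    use super::*;

    #[test]
    fn resolves_model_aliases() {
        // Friendly aliases map to full Bedrock inference-profile IDs.
        assert_eq!(
            BedrockProvider::resolve_model_id("claude-sonnet-4"),
            "us.anthropic.claude-sonnet-4-20250514-v1:0"
        );
        // Unknown names pass through unchanged, so raw IDs still work.
        assert_eq!(
            BedrockProvider::resolve_model_id("my.custom.model-v1:0"),
            "my.custom.model-v1:0"
        );
    }

    #[test]
    fn estimates_follow_model_family() {
        assert_eq!(
            BedrockProvider::estimate_context_window(
                "us.anthropic.claude-sonnet-4-5-20250929-v1:0",
                "Anthropic"
            ),
            200_000
        );
        assert_eq!(
            BedrockProvider::estimate_max_output("us.meta.llama3-1-8b-instruct-v1:0", "Meta"),
            4_096
        );
    }
}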