1use crate::providers::traits::{
10 ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
11 Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
12};
13use crate::tools::ToolSpec;
14use async_trait::async_trait;
15use hmac::{Hmac, Mac};
16use reqwest::Client;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19
/// Hostname prefix of the regional Bedrock runtime endpoint
/// (`<prefix>.<region>.amazonaws.com`).
const ENDPOINT_PREFIX: &'static str = "bedrock-runtime";
/// Service name placed in the SigV4 credential scope.
const SIGNING_SERVICE: &'static str = "bedrock";
/// Region used when neither the environment nor instance metadata provides one.
const DEFAULT_REGION: &'static str = "us-east-1";
/// `maxTokens` sent with each request unless overridden via `with_max_tokens`.
const DEFAULT_MAX_TOKENS: u32 = 4_096;
26
27enum BedrockAuth {
31 SigV4(AwsCredentials),
32 BearerToken(String),
33}
34
/// AWS credentials plus the region they are used against.
///
/// Intentionally not `Debug`: the secret key and session token must not reach logs.
struct AwsCredentials {
    /// Access key ID; appears in the `Credential=` part of the SigV4 header.
    access_key_id: String,
    /// Secret access key; only used to derive the signing key.
    secret_access_key: String,
    /// Session token for temporary credentials, sent as `x-amz-security-token`.
    session_token: Option<String>,
    /// Region used for both the endpoint host and the SigV4 credential scope.
    region: String,
}
44
45impl AwsCredentials {
46 fn from_env() -> anyhow::Result<Self> {
48 let access_key_id = env_required("AWS_ACCESS_KEY_ID")?;
49 let secret_access_key = env_required("AWS_SECRET_ACCESS_KEY")?;
50
51 let session_token = env_optional("AWS_SESSION_TOKEN");
52
53 let region = env_optional("AWS_REGION")
54 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
55 .unwrap_or_else(|| DEFAULT_REGION.to_string());
56
57 Ok(Self {
58 access_key_id,
59 secret_access_key,
60 session_token,
61 region,
62 })
63 }
64
65 async fn from_imds() -> anyhow::Result<Self> {
67 let client = reqwest::Client::builder()
68 .timeout(std::time::Duration::from_secs(3))
69 .build()?;
70
71 let token = client
73 .put("http://169.254.169.254/latest/api/token")
74 .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
75 .send()
76 .await?
77 .text()
78 .await?;
79
80 let role = client
82 .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
83 .header("X-aws-ec2-metadata-token", &token)
84 .send()
85 .await?
86 .text()
87 .await?;
88 let role = role.trim().to_string();
89 anyhow::ensure!(!role.is_empty(), "No IAM role attached to this instance");
90
91 let creds_url = format!(
93 "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
94 role
95 );
96 let creds_json: serde_json::Value = client
97 .get(&creds_url)
98 .header("X-aws-ec2-metadata-token", &token)
99 .send()
100 .await?
101 .json()
102 .await?;
103
104 let access_key_id = creds_json["AccessKeyId"]
105 .as_str()
106 .ok_or_else(|| anyhow::anyhow!("Missing AccessKeyId in IMDS response"))?
107 .to_string();
108 let secret_access_key = creds_json["SecretAccessKey"]
109 .as_str()
110 .ok_or_else(|| anyhow::anyhow!("Missing SecretAccessKey in IMDS response"))?
111 .to_string();
112 let session_token = creds_json["Token"].as_str().map(|s| s.to_string());
113
114 let region = match client
116 .get("http://169.254.169.254/latest/meta-data/placement/region")
117 .header("X-aws-ec2-metadata-token", &token)
118 .send()
119 .await
120 {
121 Ok(resp) => resp.text().await.unwrap_or_default(),
122 Err(_) => String::new(),
123 };
124 let region = if region.trim().is_empty() {
125 env_optional("AWS_REGION")
126 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
127 .unwrap_or_else(|| DEFAULT_REGION.to_string())
128 } else {
129 region.trim().to_string()
130 };
131
132 tracing::info!(
133 "Loaded AWS credentials from EC2 instance metadata (role: {})",
134 role
135 );
136
137 Ok(Self {
138 access_key_id,
139 secret_access_key,
140 session_token,
141 region,
142 })
143 }
144
145 async fn resolve() -> anyhow::Result<Self> {
147 if let Ok(creds) = Self::from_env() {
148 return Ok(creds);
149 }
150 Self::from_imds().await
151 }
152
153 fn host(&self) -> String {
154 format!("{ENDPOINT_PREFIX}.{}.amazonaws.com", self.region)
155 }
156}
157
158fn env_required(name: &str) -> anyhow::Result<String> {
159 std::env::var(name)
160 .ok()
161 .map(|v| v.trim().to_string())
162 .filter(|v| !v.is_empty())
163 .ok_or_else(|| anyhow::anyhow!("Environment variable {name} is required for Bedrock"))
164}
165
/// Returns the trimmed value of environment variable `name`.
///
/// Returns `None` when the variable is unset, not valid Unicode, or blank after trimming.
fn env_optional(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
172
173fn sha256_hex(data: &[u8]) -> String {
176 let mut hasher = Sha256::new();
177 hasher.update(data);
178 hex::encode(hasher.finalize())
179}
180
181fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
182 let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
183 mac.update(data);
184 mac.finalize().into_bytes().to_vec()
185}
186
187fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
189 let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes());
190 let k_region = hmac_sha256(&k_date, region.as_bytes());
191 let k_service = hmac_sha256(&k_region, service.as_bytes());
192 hmac_sha256(&k_service, b"aws4_request")
193}
194
195fn build_authorization_header(
199 credentials: &AwsCredentials,
200 method: &str,
201 canonical_uri: &str,
202 query_string: &str,
203 headers: &[(String, String)],
204 payload: &[u8],
205 timestamp: &chrono::DateTime<chrono::Utc>,
206) -> String {
207 let date_stamp = timestamp.format("%Y%m%d").to_string();
208 let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
209
210 let mut canonical_headers = String::new();
211 for (k, v) in headers {
212 canonical_headers.push_str(k);
213 canonical_headers.push(':');
214 canonical_headers.push_str(v);
215 canonical_headers.push('\n');
216 }
217
218 let signed_headers: String = headers
219 .iter()
220 .map(|(k, _)| k.as_str())
221 .collect::<Vec<_>>()
222 .join(";");
223
224 let payload_hash = sha256_hex(payload);
225
226 let canonical_request = format!(
227 "{method}\n{canonical_uri}\n{query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
228 );
229
230 let credential_scope = format!(
231 "{date_stamp}/{}/{SIGNING_SERVICE}/aws4_request",
232 credentials.region
233 );
234
235 let string_to_sign = format!(
236 "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
237 sha256_hex(canonical_request.as_bytes())
238 );
239
240 let signing_key = derive_signing_key(
241 &credentials.secret_access_key,
242 &date_stamp,
243 &credentials.region,
244 SIGNING_SERVICE,
245 );
246
247 let signature = hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()));
248
249 format!(
250 "AWS4-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
251 credentials.access_key_id
252 )
253}
254
255#[derive(Debug, Serialize)]
258#[serde(rename_all = "camelCase")]
259struct ConverseRequest {
260 messages: Vec<ConverseMessage>,
261 #[serde(skip_serializing_if = "Option::is_none")]
262 system: Option<Vec<SystemBlock>>,
263 #[serde(skip_serializing_if = "Option::is_none")]
264 inference_config: Option<InferenceConfig>,
265 #[serde(skip_serializing_if = "Option::is_none")]
266 tool_config: Option<ToolConfig>,
267}
268
269#[derive(Debug, Serialize, Deserialize)]
270struct ConverseMessage {
271 role: String,
272 content: Vec<ContentBlock>,
273}
274
275#[derive(Debug, Serialize, Deserialize)]
282#[serde(untagged)]
283enum ContentBlock {
284 Text(TextBlock),
285 ToolUse(ToolUseWrapper),
286 ToolResult(ToolResultWrapper),
287 CachePointBlock(CachePointWrapper),
288 Image(ImageWrapper),
289}
290
291#[derive(Debug, Serialize, Deserialize)]
292struct ImageWrapper {
293 image: ImageBlock,
294}
295
296#[derive(Debug, Serialize, Deserialize)]
297struct ImageBlock {
298 format: String,
299 source: ImageSource,
300}
301
302#[derive(Debug, Serialize, Deserialize)]
303#[serde(rename_all = "camelCase")]
304struct ImageSource {
305 bytes: String,
306}
307
308#[derive(Debug, Serialize, Deserialize)]
309struct TextBlock {
310 text: String,
311}
312
313#[derive(Debug, Serialize, Deserialize)]
314#[serde(rename_all = "camelCase")]
315struct ToolUseWrapper {
316 tool_use: ToolUseBlock,
317}
318
319#[derive(Debug, Serialize, Deserialize)]
320#[serde(rename_all = "camelCase")]
321struct ToolUseBlock {
322 tool_use_id: String,
323 name: String,
324 input: serde_json::Value,
325}
326
327#[derive(Debug, Serialize, Deserialize)]
328#[serde(rename_all = "camelCase")]
329struct ToolResultWrapper {
330 tool_result: ToolResultBlock,
331}
332
333#[derive(Debug, Serialize, Deserialize)]
334#[serde(rename_all = "camelCase")]
335struct ToolResultBlock {
336 tool_use_id: String,
337 content: Vec<ToolResultContent>,
338 status: String,
339}
340
341#[derive(Debug, Serialize, Deserialize)]
342#[serde(rename_all = "camelCase")]
343struct CachePointWrapper {
344 cache_point: CachePoint,
345}
346
347#[derive(Debug, Serialize, Deserialize)]
348struct ToolResultContent {
349 text: String,
350}
351
352#[derive(Debug, Serialize, Deserialize)]
353struct CachePoint {
354 #[serde(rename = "type")]
355 cache_type: String,
356}
357
358impl CachePoint {
359 fn default_cache() -> Self {
360 Self {
361 cache_type: "default".to_string(),
362 }
363 }
364}
365
366#[derive(Debug, Serialize)]
368#[serde(untagged)]
369enum SystemBlock {
370 Text(TextBlock),
371 CachePoint(CachePointWrapper),
372}
373
374#[derive(Debug, Serialize)]
375#[serde(rename_all = "camelCase")]
376struct InferenceConfig {
377 max_tokens: u32,
378 temperature: f64,
379}
380
381#[derive(Debug, Serialize)]
382#[serde(rename_all = "camelCase")]
383struct ToolConfig {
384 tools: Vec<ToolDefinition>,
385}
386
387#[derive(Debug, Serialize)]
388#[serde(rename_all = "camelCase")]
389struct ToolDefinition {
390 tool_spec: ToolSpecDef,
391}
392
393#[derive(Debug, Serialize)]
394#[serde(rename_all = "camelCase")]
395struct ToolSpecDef {
396 name: String,
397 description: String,
398 input_schema: InputSchema,
399}
400
401#[derive(Debug, Serialize)]
402struct InputSchema {
403 json: serde_json::Value,
404}
405
406#[derive(Debug, Deserialize)]
409#[serde(rename_all = "camelCase")]
410struct ConverseResponse {
411 #[serde(default)]
412 output: Option<ConverseOutput>,
413 #[serde(default)]
414 #[allow(dead_code)]
415 stop_reason: Option<String>,
416 #[serde(default)]
417 usage: Option<BedrockUsage>,
418}
419
420#[derive(Debug, Deserialize)]
421#[serde(rename_all = "camelCase")]
422struct BedrockUsage {
423 #[serde(default)]
424 input_tokens: Option<u64>,
425 #[serde(default)]
426 output_tokens: Option<u64>,
427}
428
429#[derive(Debug, Deserialize)]
430struct ConverseOutput {
431 #[serde(default)]
432 message: Option<ConverseOutputMessage>,
433}
434
435#[derive(Debug, Deserialize)]
436struct ConverseOutputMessage {
437 #[allow(dead_code)]
438 role: String,
439 content: Vec<ResponseContentBlock>,
440}
441
442#[derive(Debug, Deserialize)]
449#[serde(untagged)]
450enum ResponseContentBlock {
451 ToolUse(ResponseToolUseWrapper),
452 Text(TextBlock),
453 Other(serde_json::Value),
454}
455
456#[derive(Debug, Deserialize)]
457#[serde(rename_all = "camelCase")]
458struct ResponseToolUseWrapper {
459 tool_use: ToolUseBlock,
460}
461
462pub struct BedrockProvider {
465 auth: Option<BedrockAuth>,
466 max_tokens: u32,
467}
468
469impl BedrockProvider {
470 pub fn new() -> Self {
471 if let Some(token) = env_optional("BEDROCK_API_KEY") {
473 return Self {
474 auth: Some(BedrockAuth::BearerToken(token)),
475 max_tokens: DEFAULT_MAX_TOKENS,
476 };
477 }
478 Self {
479 auth: AwsCredentials::from_env().ok().map(BedrockAuth::SigV4),
480 max_tokens: DEFAULT_MAX_TOKENS,
481 }
482 }
483
484 pub async fn new_async() -> Self {
485 if let Some(token) = env_optional("BEDROCK_API_KEY") {
487 return Self {
488 auth: Some(BedrockAuth::BearerToken(token)),
489 max_tokens: DEFAULT_MAX_TOKENS,
490 };
491 }
492 let auth = AwsCredentials::resolve().await.ok().map(BedrockAuth::SigV4);
493 Self {
494 auth,
495 max_tokens: DEFAULT_MAX_TOKENS,
496 }
497 }
498
499 pub fn with_bearer_token(token: &str) -> Self {
501 Self {
502 auth: Some(BedrockAuth::BearerToken(token.to_string())),
503 max_tokens: DEFAULT_MAX_TOKENS,
504 }
505 }
506
507 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
509 self.max_tokens = max_tokens;
510 self
511 }
512
513 fn http_client(&self) -> Client {
514 crate::config::build_runtime_proxy_client_with_timeouts("provider.bedrock", 120, 10)
515 }
516
517 fn encode_model_path(model_id: &str) -> String {
521 model_id.replace(':', "%3A")
522 }
523
524 fn resolve_region() -> String {
526 env_optional("AWS_REGION")
527 .or_else(|| env_optional("AWS_DEFAULT_REGION"))
528 .unwrap_or_else(|| DEFAULT_REGION.to_string())
529 }
530
531 fn endpoint_url(region: &str, model_id: &str) -> String {
533 format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse")
534 }
535
536 fn canonical_uri(model_id: &str) -> String {
540 let encoded = Self::encode_model_path(model_id);
541 format!("/model/{encoded}/converse")
542 }
543
544 fn require_auth(&self) -> anyhow::Result<&BedrockAuth> {
545 self.auth.as_ref().ok_or_else(|| {
546 anyhow::anyhow!(
547 "AWS Bedrock credentials not set. Set BEDROCK_API_KEY for Bearer \
548 token auth, or AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY for \
549 SigV4 auth, or run on an EC2 instance with an IAM role attached."
550 )
551 })
552 }
553
554 async fn resolve_auth(&self) -> anyhow::Result<BedrockAuth> {
556 if let Some(ref auth) = self.auth {
558 match auth {
559 BedrockAuth::BearerToken(token) => {
560 return Ok(BedrockAuth::BearerToken(token.clone()));
561 }
562 BedrockAuth::SigV4(_) => {
563 }
565 }
566 }
567 if let Some(token) = env_optional("BEDROCK_API_KEY") {
569 return Ok(BedrockAuth::BearerToken(token));
570 }
571 if let Ok(creds) = AwsCredentials::from_env() {
573 return Ok(BedrockAuth::SigV4(creds));
574 }
575 Ok(BedrockAuth::SigV4(AwsCredentials::from_imds().await?))
576 }
577
578 fn should_cache_system(text: &str) -> bool {
582 text.len() > 3072
583 }
584
585 fn should_cache_conversation(messages: &[ChatMessage]) -> bool {
587 messages.iter().filter(|m| m.role != "system").count() > 4
588 }
589
590 fn convert_messages(
593 messages: &[ChatMessage],
594 ) -> (Option<Vec<SystemBlock>>, Vec<ConverseMessage>) {
595 let mut system_blocks = Vec::new();
596 let mut converse_messages = Vec::new();
597
598 for msg in messages {
599 match msg.role.as_str() {
600 "system" => {
601 if system_blocks.is_empty() {
602 system_blocks.push(SystemBlock::Text(TextBlock {
603 text: msg.content.clone(),
604 }));
605 }
606 }
607 "assistant" => {
608 if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
609 converse_messages.push(ConverseMessage {
610 role: "assistant".to_string(),
611 content: blocks,
612 });
613 } else {
614 converse_messages.push(ConverseMessage {
615 role: "assistant".to_string(),
616 content: vec![ContentBlock::Text(TextBlock {
617 text: msg.content.clone(),
618 })],
619 });
620 }
621 }
622 "tool" => {
623 let tool_result_msg = Self::parse_tool_result_message(&msg.content)
624 .unwrap_or_else(|| {
625 let tool_use_id = Self::extract_tool_call_id(&msg.content)
629 .or_else(|| Self::last_pending_tool_use_id(&converse_messages))
630 .unwrap_or_else(|| "unknown".to_string());
631
632 tracing::warn!(
633 "Failed to parse tool result message, creating error \
634 toolResult for tool_use_id={}",
635 tool_use_id
636 );
637
638 ConverseMessage {
639 role: "user".to_string(),
640 content: vec![ContentBlock::ToolResult(ToolResultWrapper {
641 tool_result: ToolResultBlock {
642 tool_use_id,
643 content: vec![ToolResultContent {
644 text: msg.content.clone(),
645 }],
646 status: "error".to_string(),
647 },
648 })],
649 }
650 });
651
652 if let Some(last) = converse_messages.last_mut() {
656 if last.role == "user"
657 && last
658 .content
659 .iter()
660 .all(|b| matches!(b, ContentBlock::ToolResult(_)))
661 {
662 last.content.extend(tool_result_msg.content);
663 continue;
664 }
665 }
666 converse_messages.push(tool_result_msg);
667 }
668 _ => {
669 let content_blocks = Self::parse_user_content_blocks(&msg.content);
670 converse_messages.push(ConverseMessage {
671 role: "user".to_string(),
672 content: content_blocks,
673 });
674 }
675 }
676 }
677
678 let system = if system_blocks.is_empty() {
679 None
680 } else {
681 Some(system_blocks)
682 };
683 (system, converse_messages)
684 }
685
686 fn extract_tool_call_id(content: &str) -> Option<String> {
688 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
689 value
690 .get("tool_call_id")
691 .or_else(|| value.get("tool_use_id"))
692 .or_else(|| value.get("toolUseId"))
693 .and_then(serde_json::Value::as_str)
694 .map(String::from)
695 }
696
697 fn last_pending_tool_use_id(converse_messages: &[ConverseMessage]) -> Option<String> {
703 let last_assistant = converse_messages
704 .iter()
705 .rev()
706 .find(|m| m.role == "assistant")?;
707
708 let tool_use_ids: Vec<&str> = last_assistant
709 .content
710 .iter()
711 .filter_map(|b| match b {
712 ContentBlock::ToolUse(wrapper) => Some(wrapper.tool_use.tool_use_id.as_str()),
713 _ => None,
714 })
715 .collect();
716
717 let answered_ids: Vec<&str> = converse_messages
718 .iter()
719 .rev()
720 .take_while(|m| m.role == "user")
721 .flat_map(|m| m.content.iter())
722 .filter_map(|b| match b {
723 ContentBlock::ToolResult(wrapper) => Some(wrapper.tool_result.tool_use_id.as_str()),
724 _ => None,
725 })
726 .collect();
727
728 tool_use_ids
729 .into_iter()
730 .find(|id| !answered_ids.contains(id))
731 .map(String::from)
732 }
733
734 fn parse_user_content_blocks(content: &str) -> Vec<ContentBlock> {
736 let mut blocks: Vec<ContentBlock> = Vec::new();
737 let mut remaining = content;
738 let has_image = content.contains("[IMAGE:");
739 tracing::info!(
740 "parse_user_content_blocks called, len={}, has_image={}",
741 content.len(),
742 has_image
743 );
744
745 while let Some(start) = remaining.find("[IMAGE:") {
746 let text_before = &remaining[..start];
748 if !text_before.trim().is_empty() {
749 blocks.push(ContentBlock::Text(TextBlock {
750 text: text_before.to_string(),
751 }));
752 }
753
754 let after = &remaining[start + 7..]; if let Some(end) = after.find(']') {
756 let src = &after[..end];
757 remaining = &after[end + 1..];
758
759 if let Some(rest) = src.strip_prefix("data:") {
761 if let Some(semi) = rest.find(';') {
762 let mime = &rest[..semi];
763 let after_semi = &rest[semi + 1..];
764 if let Some(b64) = after_semi.strip_prefix("base64,") {
765 let format = match mime {
766 "image/png" => "png",
767 "image/gif" => "gif",
768 "image/webp" => "webp",
769 _ => "jpeg",
770 };
771 blocks.push(ContentBlock::Image(ImageWrapper {
772 image: ImageBlock {
773 format: format.to_string(),
774 source: ImageSource {
775 bytes: b64.to_string(),
776 },
777 },
778 }));
779 continue;
780 }
781 }
782 }
783 blocks.push(ContentBlock::Text(TextBlock {
785 text: format!("[image: {}]", src),
786 }));
787 } else {
788 blocks.push(ContentBlock::Text(TextBlock {
790 text: remaining.to_string(),
791 }));
792 break;
793 }
794 }
795
796 if !remaining.trim().is_empty() {
798 blocks.push(ContentBlock::Text(TextBlock {
799 text: remaining.to_string(),
800 }));
801 }
802
803 if blocks.is_empty() {
804 blocks.push(ContentBlock::Text(TextBlock {
805 text: content.to_string(),
806 }));
807 }
808
809 blocks
810 }
811
812 fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<ContentBlock>> {
814 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
815 let tool_calls = value
816 .get("tool_calls")
817 .and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;
818
819 let mut blocks = Vec::new();
820 if let Some(text) = value
821 .get("content")
822 .and_then(serde_json::Value::as_str)
823 .map(str::trim)
824 .filter(|t| !t.is_empty())
825 {
826 blocks.push(ContentBlock::Text(TextBlock {
827 text: text.to_string(),
828 }));
829 }
830 for call in tool_calls {
831 let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
832 .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
833 blocks.push(ContentBlock::ToolUse(ToolUseWrapper {
834 tool_use: ToolUseBlock {
835 tool_use_id: call.id,
836 name: call.name,
837 input,
838 },
839 }));
840 }
841 Some(blocks)
842 }
843
844 fn parse_tool_result_message(content: &str) -> Option<ConverseMessage> {
846 let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
847 let tool_use_id = value
848 .get("tool_call_id")
849 .or_else(|| value.get("tool_use_id"))
850 .or_else(|| value.get("toolUseId"))
851 .and_then(serde_json::Value::as_str)?
852 .to_string();
853 let result = value
854 .get("content")
855 .and_then(serde_json::Value::as_str)
856 .unwrap_or("")
857 .to_string();
858 Some(ConverseMessage {
859 role: "user".to_string(),
860 content: vec![ContentBlock::ToolResult(ToolResultWrapper {
861 tool_result: ToolResultBlock {
862 tool_use_id,
863 content: vec![ToolResultContent { text: result }],
864 status: "success".to_string(),
865 },
866 })],
867 })
868 }
869
870 fn convert_tools_to_converse(tools: Option<&[ToolSpec]>) -> Option<ToolConfig> {
873 let items = tools?;
874 if items.is_empty() {
875 return None;
876 }
877 let tool_defs: Vec<ToolDefinition> = items
878 .iter()
879 .map(|tool| ToolDefinition {
880 tool_spec: ToolSpecDef {
881 name: tool.name.clone(),
882 description: tool.description.clone(),
883 input_schema: InputSchema {
884 json: tool.parameters.clone(),
885 },
886 },
887 })
888 .collect();
889 Some(ToolConfig { tools: tool_defs })
890 }
891
892 fn parse_converse_response(response: ConverseResponse) -> ProviderChatResponse {
895 let mut text_parts = Vec::new();
896 let mut tool_calls = Vec::new();
897
898 let usage = response.usage.map(|u| TokenUsage {
899 input_tokens: u.input_tokens,
900 output_tokens: u.output_tokens,
901 cached_input_tokens: None,
902 });
903
904 if let Some(output) = response.output {
905 if let Some(message) = output.message {
906 for block in message.content {
907 match block {
908 ResponseContentBlock::Text(tb) => {
909 let trimmed = tb.text.trim().to_string();
910 if !trimmed.is_empty() {
911 text_parts.push(trimmed);
912 }
913 }
914 ResponseContentBlock::ToolUse(wrapper) => {
915 if !wrapper.tool_use.name.is_empty() {
916 tool_calls.push(ProviderToolCall {
917 id: wrapper.tool_use.tool_use_id,
918 name: wrapper.tool_use.name,
919 arguments: wrapper.tool_use.input.to_string(),
920 });
921 }
922 }
923 ResponseContentBlock::Other(_) => {}
924 }
925 }
926 }
927 }
928
929 ProviderChatResponse {
930 text: if text_parts.is_empty() {
931 None
932 } else {
933 Some(text_parts.join("\n"))
934 },
935 tool_calls,
936 usage,
937 reasoning_content: None,
938 }
939 }
940
941 async fn send_converse_request(
944 &self,
945 auth: &BedrockAuth,
946 model: &str,
947 request_body: &ConverseRequest,
948 ) -> anyhow::Result<ConverseResponse> {
949 let payload = serde_json::to_vec(request_body)?;
950
951 if let Ok(debug_val) = serde_json::from_slice::<serde_json::Value>(&payload) {
953 if let Some(msgs) = debug_val.get("messages").and_then(|m| m.as_array()) {
954 for msg in msgs {
955 if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
956 for block in content {
957 if block.get("image").is_some() {
958 let mut b = block.clone();
959 if let Some(img) = b.get_mut("image") {
960 if let Some(src) = img.get_mut("source") {
961 if let Some(bytes) = src.get_mut("bytes") {
962 if let Some(s) = bytes.as_str() {
963 *bytes = serde_json::json!(format!(
964 "<base64 {} chars>",
965 s.len()
966 ));
967 }
968 }
969 }
970 }
971 tracing::info!(
972 "Bedrock image block: {}",
973 serde_json::to_string(&b).unwrap_or_default()
974 );
975 }
976 }
977 }
978 }
979 }
980 }
981
982 let response: reqwest::Response = match auth {
983 BedrockAuth::BearerToken(token) => {
984 let region = Self::resolve_region();
985 let url = Self::endpoint_url(®ion, model);
986
987 self.http_client()
988 .post(&url)
989 .header("content-type", "application/json")
990 .header("Authorization", format!("Bearer {token}"))
991 .body(payload)
992 .send()
993 .await?
994 }
995 BedrockAuth::SigV4(credentials) => {
996 let url = Self::endpoint_url(&credentials.region, model);
997 let canonical_uri = Self::canonical_uri(model);
998 let now = chrono::Utc::now();
999 let host = credentials.host();
1000 let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
1001
1002 let mut headers_to_sign = vec![
1003 ("content-type".to_string(), "application/json".to_string()),
1004 ("host".to_string(), host),
1005 ("x-amz-date".to_string(), amz_date.clone()),
1006 ];
1007 if let Some(ref session_token) = credentials.session_token {
1008 headers_to_sign
1009 .push(("x-amz-security-token".to_string(), session_token.clone()));
1010 }
1011 headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
1012
1013 let authorization = build_authorization_header(
1014 credentials,
1015 "POST",
1016 &canonical_uri,
1017 "",
1018 &headers_to_sign,
1019 &payload,
1020 &now,
1021 );
1022
1023 let mut request = self
1024 .http_client()
1025 .post(&url)
1026 .header("content-type", "application/json")
1027 .header("x-amz-date", &amz_date)
1028 .header("authorization", &authorization);
1029
1030 if let Some(ref session_token) = credentials.session_token {
1031 request = request.header("x-amz-security-token", session_token);
1032 }
1033
1034 request.body(payload).send().await?
1035 }
1036 };
1037
1038 if !response.status().is_success() {
1039 return Err(super::api_error("Bedrock", response).await);
1040 }
1041
1042 let converse_response: ConverseResponse = response.json().await?;
1043 Ok(converse_response)
1044 }
1045}
1046
1047#[async_trait]
1050impl Provider for BedrockProvider {
1051 fn capabilities(&self) -> ProviderCapabilities {
1052 ProviderCapabilities {
1053 native_tool_calling: true,
1054 vision: true,
1055 prompt_caching: false,
1056 }
1057 }
1058
1059 fn supports_native_tools(&self) -> bool {
1060 true
1061 }
1062
1063 fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
1064 let tool_values: Vec<serde_json::Value> = tools
1065 .iter()
1066 .map(|t| {
1067 serde_json::json!({
1068 "toolSpec": {
1069 "name": t.name,
1070 "description": t.description,
1071 "inputSchema": { "json": t.parameters }
1072 }
1073 })
1074 })
1075 .collect();
1076 ToolsPayload::Anthropic { tools: tool_values }
1077 }
1078
1079 async fn chat_with_system(
1080 &self,
1081 system_prompt: Option<&str>,
1082 message: &str,
1083 model: &str,
1084 temperature: f64,
1085 ) -> anyhow::Result<String> {
1086 let auth = self.resolve_auth().await?;
1087
1088 let system = system_prompt.map(|text| {
1089 let mut blocks = vec![SystemBlock::Text(TextBlock {
1090 text: text.to_string(),
1091 })];
1092 if Self::should_cache_system(text) {
1093 blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1094 cache_point: CachePoint::default_cache(),
1095 }));
1096 }
1097 blocks
1098 });
1099
1100 let request = ConverseRequest {
1101 system,
1102 messages: vec![ConverseMessage {
1103 role: "user".to_string(),
1104 content: Self::parse_user_content_blocks(message),
1105 }],
1106 inference_config: Some(InferenceConfig {
1107 max_tokens: self.max_tokens,
1108 temperature,
1109 }),
1110 tool_config: None,
1111 };
1112
1113 let response = self.send_converse_request(&auth, model, &request).await?;
1114
1115 Self::parse_converse_response(response)
1116 .text
1117 .ok_or_else(|| anyhow::anyhow!("No response from Bedrock"))
1118 }
1119
1120 async fn chat(
1121 &self,
1122 request: ProviderChatRequest<'_>,
1123 model: &str,
1124 temperature: f64,
1125 ) -> anyhow::Result<ProviderChatResponse> {
1126 let auth = self.resolve_auth().await?;
1127
1128 let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
1129
1130 let system = system_blocks.map(|mut blocks| {
1132 let has_large_system = blocks
1133 .iter()
1134 .any(|b| matches!(b, SystemBlock::Text(tb) if Self::should_cache_system(&tb.text)));
1135 if has_large_system {
1136 blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1137 cache_point: CachePoint::default_cache(),
1138 }));
1139 }
1140 blocks
1141 });
1142
1143 if Self::should_cache_conversation(request.messages) {
1145 if let Some(last_msg) = converse_messages.last_mut() {
1146 last_msg
1147 .content
1148 .push(ContentBlock::CachePointBlock(CachePointWrapper {
1149 cache_point: CachePoint::default_cache(),
1150 }));
1151 }
1152 }
1153
1154 let tool_config = Self::convert_tools_to_converse(request.tools);
1155
1156 let converse_request = ConverseRequest {
1157 system,
1158 messages: converse_messages,
1159 inference_config: Some(InferenceConfig {
1160 max_tokens: self.max_tokens,
1161 temperature,
1162 }),
1163 tool_config,
1164 };
1165
1166 let response = self
1167 .send_converse_request(&auth, model, &converse_request)
1168 .await?;
1169
1170 Ok(Self::parse_converse_response(response))
1171 }
1172
1173 async fn warmup(&self) -> anyhow::Result<()> {
1174 let region = match self.auth {
1175 Some(BedrockAuth::SigV4(ref creds)) => creds.region.clone(),
1176 Some(BedrockAuth::BearerToken(_)) => Self::resolve_region(),
1177 None => return Ok(()),
1178 };
1179 let url = format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/");
1180 let _ = self.http_client().get(&url).send().await;
1181 Ok(())
1182 }
1183}
1184
1185#[cfg(test)]
1188mod tests {
1189 use super::*;
1190 use crate::providers::traits::ChatMessage;
1191
1192 struct EnvGuard {
1194 key: String,
1195 original: Option<String>,
1196 }
1197
1198 impl EnvGuard {
1199 fn set(key: &str, value: Option<&str>) -> Self {
1200 let original = std::env::var(key).ok();
1201 match value {
1202 Some(v) => unsafe { std::env::set_var(key, v) },
1204 None => unsafe { std::env::remove_var(key) },
1206 }
1207 Self {
1208 key: key.to_string(),
1209 original,
1210 }
1211 }
1212 }
1213
1214 impl Drop for EnvGuard {
1215 fn drop(&mut self) {
1216 match &self.original {
1217 Some(v) => unsafe { std::env::set_var(&self.key, v) },
1219 None => unsafe { std::env::remove_var(&self.key) },
1221 }
1222 }
1223 }
1224
1225 #[test]
1228 fn sha256_hex_empty_string() {
1229 assert_eq!(
1231 sha256_hex(b""),
1232 "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
1233 );
1234 }
1235
1236 #[test]
1237 fn sha256_hex_known_input() {
1238 assert_eq!(
1240 sha256_hex(b"hello"),
1241 "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
1242 );
1243 }
1244
1245 const TEST_VECTOR_SECRET: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
1247
1248 #[test]
1249 fn hmac_sha256_known_input() {
1250 let test_key: &[u8] = b"key";
1251 let result = hmac_sha256(test_key, b"message");
1252 assert_eq!(
1253 hex::encode(&result),
1254 "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a"
1255 );
1256 }
1257
1258 #[test]
1259 fn derive_signing_key_structure() {
1260 let key = derive_signing_key(TEST_VECTOR_SECRET, "20150830", "us-east-1", "iam");
1262 assert_eq!(key.len(), 32);
1263 }
1264
1265 #[test]
1266 fn derive_signing_key_known_test_vector() {
1267 let key = derive_signing_key(TEST_VECTOR_SECRET, "20150830", "us-east-1", "iam");
1269 assert_eq!(
1270 hex::encode(&key),
1271 "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9"
1272 );
1273 }
1274
/// The SigV4 `Authorization` header names the algorithm and the access key id.
/// It also carries the credential scope (`<yyyymmdd>/<region>/bedrock/aws4_request`),
/// the signed-header list, and a signature.
#[test]
fn build_authorization_header_format() {
    let credentials = AwsCredentials {
        access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
        secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
        session_token: None,
        region: "us-east-1".to_string(),
    };

    let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);

    let headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        (
            "host".to_string(),
            "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
        ),
        ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
    ];

    let auth = build_authorization_header(
        &credentials,
        "POST",
        "/model/anthropic.claude-3-sonnet/converse",
        "",
        &headers,
        b"{}",
        &timestamp,
    );

    assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/"));
    // Pin the whole credential scope: the date stamp must come from the signing
    // timestamp (2024-01-15), followed by region, signing service and terminator.
    assert!(
        auth.contains("Credential=AKIAIOSFODNN7EXAMPLE/20240115/us-east-1/bedrock/aws4_request"),
        "unexpected credential scope in: {auth}"
    );
    assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
    assert!(auth.contains("Signature="));
}
1313
/// With temporary credentials, `x-amz-security-token` must be listed in SignedHeaders.
/// It appears next to the other signed headers, in sorted lowercase order.
#[test]
fn build_authorization_header_includes_security_token_in_signed_headers() {
    let credentials = AwsCredentials {
        access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
        secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
        session_token: Some("session-token-value".to_string()),
        region: "us-east-1".to_string(),
    };

    let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
        .unwrap()
        .with_timezone(&chrono::Utc);

    let headers = vec![
        ("content-type".to_string(), "application/json".to_string()),
        (
            "host".to_string(),
            "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
        ),
        ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
        (
            "x-amz-security-token".to_string(),
            "session-token-value".to_string(),
        ),
    ];

    let auth = build_authorization_header(
        &credentials,
        "POST",
        "/model/test-model/converse",
        "",
        &headers,
        b"{}",
        &timestamp,
    );

    // Assert the full SignedHeaders list rather than a bare substring, so the test
    // actually verifies what its name claims: the token header is part of the signature.
    assert!(
        auth.contains("SignedHeaders=content-type;host;x-amz-date;x-amz-security-token"),
        "security token header missing from SignedHeaders: {auth}"
    );
}
1352
/// `host()` builds the regional bedrock-runtime hostname from the credentials' region.
#[test]
fn credentials_host_formats_correctly() {
    let creds = AwsCredentials {
        region: String::from("us-west-2"),
        session_token: None,
        access_key_id: String::from("AKID"),
        secret_access_key: String::from("secret"),
    };
    let expected_host = "bedrock-runtime.us-west-2.amazonaws.com";
    assert_eq!(creds.host(), expected_host);
}
1365
/// Constructing the provider must not panic, whatever credentials the environment holds.
#[test]
fn creates_without_credentials() {
    drop(BedrockProvider::new());
}
1373
1374 #[tokio::test]
1375 async fn chat_fails_without_credentials() {
1376 let provider = BedrockProvider {
1377 auth: None,
1378 max_tokens: DEFAULT_MAX_TOKENS,
1379 };
1380 let result = provider
1381 .chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", 0.7)
1382 .await;
1383 assert!(result.is_err());
1384 let err = result.unwrap_err().to_string();
1385 assert!(
1386 err.contains("credentials not set")
1387 || err.contains("169.254.169.254")
1388 || err.to_lowercase().contains("credential")
1389 || err.to_lowercase().contains("builder error"),
1390 "Expected missing-credentials style error, got: {err}"
1391 );
1392 }
1393
/// `with_bearer_token` stores the given key verbatim as `BedrockAuth::BearerToken`.
#[test]
fn creates_with_bearer_token() {
    let provider = BedrockProvider::with_bearer_token("test-api-key");
    match provider.auth {
        Some(BedrockAuth::BearerToken(ref token)) => assert_eq!(token, "test-api-key"),
        Some(_) => panic!("expected bearer-token auth variant"),
        None => panic!("expected auth to be configured"),
    }
}
1404
/// With `BEDROCK_API_KEY` set and the AWS key pair unset, `new()` selects bearer-token auth.
///
/// NOTE(review): this mutates process-wide environment variables. `bearer_token_precedence`
/// sets the same variables to different values. Unless `EnvGuard` serializes access
/// (not visible here), the two can race under the parallel test runner. Confirm.
#[test]
fn bearer_token_from_env() {
    // Guards restore the previous values on drop (reverse declaration order).
    let _api_key = EnvGuard::set("BEDROCK_API_KEY", Some("env-bearer-token"));
    let _access_key = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
    let _secret_key = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);

    let provider = BedrockProvider::new();
    let picked_env_token = matches!(
        provider.auth,
        Some(BedrockAuth::BearerToken(ref token)) if token == "env-bearer-token"
    );
    assert!(picked_env_token);
}
1418
/// When both `BEDROCK_API_KEY` and an AWS key pair are present, the bearer token wins.
///
/// NOTE(review): this mutates the same process-wide environment variables as
/// `bearer_token_from_env`. If `EnvGuard` does not serialize access, the two tests
/// can interfere when run in parallel. Confirm.
#[test]
fn bearer_token_precedence() {
    let _api_key = EnvGuard::set("BEDROCK_API_KEY", Some("bearer-key"));
    let _access_key = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("AKIAEXAMPLE"));
    let _secret_key = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret"));

    let provider = BedrockProvider::new();
    let bearer_won = matches!(
        provider.auth,
        Some(BedrockAuth::BearerToken(ref token)) if token == "bearer-key"
    );
    assert!(bearer_won);
}
1432
/// `endpoint_url` builds the regional HTTPS Converse URL for a model id.
#[test]
fn endpoint_url_formats_correctly() {
    let expected =
        "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse";
    assert_eq!(
        BedrockProvider::endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6"),
        expected
    );
}
1443
/// The request URL keeps a versioned model id's `:` unencoded. Only the SigV4
/// canonical URI percent-encodes it.
#[test]
fn endpoint_url_keeps_raw_colon() {
    let model_id = "anthropic.claude-3-5-haiku-20241022-v1:0";
    let url = BedrockProvider::endpoint_url("us-west-2", model_id);
    assert!(url.contains(&format!("/model/{model_id}/converse")));
}
1451
/// `canonical_uri` percent-encodes the `:` in a versioned model id as `%3A`.
#[test]
fn canonical_uri_encodes_colon() {
    let expected = "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse";
    let actual = BedrockProvider::canonical_uri("anthropic.claude-3-5-haiku-20241022-v1:0");
    assert_eq!(actual, expected);
}
1461
/// Model ids without a colon pass through `canonical_uri` with no encoding.
#[test]
fn canonical_uri_no_colon_unchanged() {
    let expected = "/model/anthropic.claude-sonnet-4-6/converse";
    assert_eq!(BedrockProvider::canonical_uri("anthropic.claude-sonnet-4-6"), expected);
}
1467
/// A system message is lifted into the separate system blocks and removed from the
/// conversation list.
#[test]
fn convert_messages_system_extracted() {
    let messages = vec![ChatMessage::system("You are helpful"), ChatMessage::user("Hello")];
    let (system, msgs) = BedrockProvider::convert_messages(&messages);

    let system_blocks = system.expect("system message should be extracted");
    assert_eq!(system_blocks.len(), 1);

    assert_eq!(msgs.len(), 1);
    assert_eq!(msgs[0].role, "user");
}
1483
/// Without a system message, user and assistant turns keep their roles and order.
#[test]
fn convert_messages_user_and_assistant() {
    let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("Hi there")];
    let (system, msgs) = BedrockProvider::convert_messages(&messages);
    assert!(system.is_none());

    let roles: Vec<&str> = msgs.iter().map(|m| m.role.as_str()).collect();
    assert_eq!(roles, ["user", "assistant"]);
}
1496
/// A tool-role JSON message becomes a user-role message carrying a ToolResult block.
#[test]
fn convert_messages_tool_role_to_tool_result() {
    let payload = r#"{"tool_call_id": "call_123", "content": "Result data"}"#;
    let messages = vec![ChatMessage::tool(payload)];
    let (_, msgs) = BedrockProvider::convert_messages(&messages);

    assert_eq!(msgs.len(), 1);
    let converted = &msgs[0];
    assert_eq!(converted.role, "user");
    assert!(matches!(converted.content[0], ContentBlock::ToolResult(_)));
}
1506
/// An assistant message that encodes tool calls as JSON becomes a Text block
/// followed by a ToolUse block.
#[test]
fn convert_messages_assistant_tool_calls_parsed() {
    let payload = r#"{"content": "Let me check", "tool_calls": [{"id": "call_1", "name": "shell", "arguments": "{\"command\":\"ls\"}"}]}"#;
    let messages = vec![ChatMessage::assistant(payload)];
    let (_, msgs) = BedrockProvider::convert_messages(&messages);

    assert_eq!(msgs.len(), 1);
    let converted = &msgs[0];
    assert_eq!(converted.role, "assistant");

    let blocks = &converted.content;
    assert_eq!(blocks.len(), 2);
    assert!(matches!(blocks[0], ContentBlock::Text(_)));
    assert!(matches!(blocks[1], ContentBlock::ToolUse(_)));
}
1518
/// A plain (non-JSON) assistant string is converted into a Text block.
#[test]
fn convert_messages_plain_assistant_text() {
    let (_, msgs) =
        BedrockProvider::convert_messages(&vec![ChatMessage::assistant("Just text")]);
    assert_eq!(msgs.len(), 1);
    let first_block = &msgs[0].content[0];
    assert!(matches!(first_block, ContentBlock::Text(_)));
}
1526
/// Short system prompts are not marked for caching.
#[test]
fn should_cache_system_small_prompt() {
    let cache_it = BedrockProvider::should_cache_system("Short prompt");
    assert!(!cache_it);
}
1533
/// A system prompt just over 3072 characters is marked for caching.
#[test]
fn should_cache_system_large_prompt() {
    let prompt = "a".repeat(3 * 1024 + 1);
    assert!(BedrockProvider::should_cache_system(&prompt));
}
1539
/// The caching threshold is strict: 3072 characters do not cache, 3073 do.
#[test]
fn should_cache_system_boundary() {
    for (len, expected) in [(3072, false), (3073, true)] {
        assert_eq!(
            BedrockProvider::should_cache_system(&"a".repeat(len)),
            expected
        );
    }
}
1545
/// A system prompt plus a single user/assistant exchange is too short to cache.
#[test]
fn should_cache_conversation_short() {
    let conversation = vec![
        ChatMessage::system("System"),
        ChatMessage::user("Hello"),
        ChatMessage::assistant("Hi"),
    ];
    let cache_it = BedrockProvider::should_cache_conversation(&conversation);
    assert!(!cache_it);
}
1555
/// A system prompt followed by five alternating user/assistant turns is long enough
/// to enable conversation caching.
#[test]
fn should_cache_conversation_long() {
    let mut messages = vec![ChatMessage::system("System")];
    messages.extend((0..5).map(|i| ChatMessage {
        role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(),
        content: format!("Message {i}"),
    }));
    assert!(BedrockProvider::should_cache_conversation(&messages));
}
1567
/// A single `ToolSpec` becomes a Converse tool config with one entry of the same name.
#[test]
fn convert_tools_to_converse_formats_correctly() {
    let tools = vec![ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}}),
    }];
    let config = BedrockProvider::convert_tools_to_converse(Some(&tools))
        .expect("non-empty tool list should produce a tool config");

    assert_eq!(config.tools.len(), 1);
    assert_eq!(config.tools[0].tool_spec.name, "shell");
}
1583
/// An empty tool slice and `None` both yield no tool config.
#[test]
fn convert_tools_to_converse_empty_returns_none() {
    let from_empty = BedrockProvider::convert_tools_to_converse(Some(&[]));
    assert!(from_empty.is_none());
    let from_none = BedrockProvider::convert_tools_to_converse(None);
    assert!(from_none.is_none());
}
1589
/// A request with `system: None` omits the `system` key entirely. The inference
/// config serializes with the camelCase `maxTokens` name.
#[test]
fn converse_request_serializes_without_system() {
    let greeting = ConverseMessage {
        role: String::from("user"),
        content: vec![ContentBlock::Text(TextBlock {
            text: String::from("Hello"),
        })],
    };
    let request = ConverseRequest {
        system: None,
        messages: vec![greeting],
        inference_config: Some(InferenceConfig {
            max_tokens: 4096,
            temperature: 0.7,
        }),
        tool_config: None,
    };

    let serialized = serde_json::to_string(&request).unwrap();
    assert!(!serialized.contains("system"));
    assert!(serialized.contains("Hello"));
    assert!(serialized.contains("maxTokens"));
}
1613
/// A text-only Converse response yields its text and no tool calls.
#[test]
fn converse_response_deserializes_text() {
    let body = r#"{
        "output": {
            "message": {
                "role": "assistant",
                "content": [{"text": "Hello from Bedrock"}]
            }
        },
        "stopReason": "end_turn"
    }"#;
    let response: ConverseResponse = serde_json::from_str(body).unwrap();
    let parsed = BedrockProvider::parse_converse_response(response);

    assert_eq!(parsed.text.as_deref(), Some("Hello from Bedrock"));
    assert!(parsed.tool_calls.is_empty());
}
1630
/// A `toolUse` content block becomes a tool call carrying its `toolUseId` and name,
/// and the parsed result has no text.
#[test]
fn converse_response_deserializes_tool_use() {
    let body = r#"{
        "output": {
            "message": {
                "role": "assistant",
                "content": [
                    {"toolUse": {"toolUseId": "call_1", "name": "shell", "input": {"command": "ls"}}}
                ]
            }
        },
        "stopReason": "tool_use"
    }"#;
    let parsed = BedrockProvider::parse_converse_response(
        serde_json::from_str::<ConverseResponse>(body).unwrap(),
    );

    assert!(parsed.text.is_none());
    assert_eq!(parsed.tool_calls.len(), 1);
    let call = &parsed.tool_calls[0];
    assert_eq!(call.name, "shell");
    assert_eq!(call.id, "call_1");
}
1651
/// `null` output and stop reason parse into an empty result: no text and no tool calls.
#[test]
fn converse_response_empty_output() {
    let response: ConverseResponse =
        serde_json::from_str(r#"{"output": null, "stopReason": null}"#).unwrap();
    let parsed = BedrockProvider::parse_converse_response(response);
    assert!(parsed.text.is_none());
    assert!(parsed.tool_calls.is_empty());
}
1660
/// Text blocks serialize to the flat `{"text": "..."}` shape, not a nested object.
#[test]
fn content_block_text_serializes_as_flat_string() {
    let expected = r#"{"text":"Hello"}"#;
    let text_block = ContentBlock::Text(TextBlock {
        text: String::from("Hello"),
    });
    assert_eq!(serde_json::to_string(&text_block).unwrap(), expected);
}
1670
/// Tool-use blocks nest under a `toolUse` key and use the camelCase `toolUseId` field.
#[test]
fn content_block_tool_use_serializes_with_nested_object() {
    let tool_use = ToolUseBlock {
        tool_use_id: "call_1".to_string(),
        name: "shell".to_string(),
        input: serde_json::json!({"command": "ls"}),
    };
    let json = serde_json::to_string(&ContentBlock::ToolUse(ToolUseWrapper { tool_use })).unwrap();

    for expected_fragment in [r#""toolUse""#, r#""toolUseId":"call_1""#] {
        assert!(json.contains(expected_fragment));
    }
}
1684
/// A cache-point content block serializes as `{"cachePoint":{"type":"default"}}`.
#[test]
fn content_block_cache_point_serializes() {
    let wrapper = CachePointWrapper {
        cache_point: CachePoint::default_cache(),
    };
    let serialized = serde_json::to_string(&ContentBlock::CachePointBlock(wrapper)).unwrap();
    assert_eq!(serialized, r#"{"cachePoint":{"type":"default"}}"#);
}
1693
/// A text block survives a serialize → deserialize round trip unchanged.
#[test]
fn content_block_text_round_trips() {
    let original = ContentBlock::Text(TextBlock {
        text: "Hello".to_string(),
    });
    let json = serde_json::to_string(&original).unwrap();
    match serde_json::from_str::<ContentBlock>(&json).unwrap() {
        ContentBlock::Text(block) => assert_eq!(block.text, "Hello"),
        _ => panic!("expected a text block after round-tripping"),
    }
}
1703
/// `CachePoint::default_cache()` serializes as `{"type":"default"}`.
#[test]
fn cache_point_serializes() {
    let serialized = serde_json::to_string(&CachePoint::default_cache()).unwrap();
    assert_eq!(serialized, r#"{"type":"default"}"#);
}
1710
1711 #[tokio::test]
1712 async fn warmup_without_credentials_is_noop() {
1713 let provider = BedrockProvider {
1714 auth: None,
1715 max_tokens: DEFAULT_MAX_TOKENS,
1716 };
1717 let result = provider.warmup().await;
1718 assert!(result.is_ok());
1719 }
1720
/// The provider advertises native tool calling, even with no auth configured.
#[test]
fn capabilities_reports_native_tool_calling() {
    let unauthenticated = BedrockProvider {
        auth: None,
        max_tokens: DEFAULT_MAX_TOKENS,
    };
    assert!(unauthenticated.capabilities().native_tool_calling);
}
1730
/// The response's `usage` object (`inputTokens` / `outputTokens`) is deserialized
/// into optional token counts.
#[test]
fn converse_response_parses_usage() {
    // The text block uses the flat `{"text": "..."}` shape of the Converse API, the
    // same shape the other response fixtures and the ContentBlock serialization tests use.
    let json = r#"{
        "output": {"message": {"role": "assistant", "content": [{"text": "Hello"}]}},
        "usage": {"inputTokens": 500, "outputTokens": 100}
    }"#;
    let resp: ConverseResponse = serde_json::from_str(json).unwrap();
    let usage = resp.usage.as_ref().expect("usage should be present");
    assert_eq!(usage.input_tokens, Some(500));
    assert_eq!(usage.output_tokens, Some(100));
}
1742
/// A response with no `usage` object parses with `usage == None`.
#[test]
fn converse_response_parses_without_usage() {
    let body = r#"{"output": {"message": {"role": "assistant", "content": []}}}"#;
    let response = serde_json::from_str::<ConverseResponse>(body).unwrap();
    assert!(response.usage.is_none());
}
1749
/// A tool-role message whose content is not JSON is still emitted as a ToolResult
/// block inside a user-role message, rather than as plain text.
#[test]
fn fallback_tool_result_emits_tool_result_block_not_text() {
    let raw_tool_output = ChatMessage {
        role: "tool".to_string(),
        content: "not valid json".to_string(),
    };
    let messages = vec![
        ChatMessage::user("do something"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"tool_1","name":"shell","arguments":"{}"}]}"#,
        ),
        raw_tool_output,
    ];
    let (_, msgs) = BedrockProvider::convert_messages(&messages);

    let tool_msg = &msgs[2];
    assert_eq!(tool_msg.role, "user");
    let first_block = &tool_msg.content[0];
    assert!(
        matches!(first_block, ContentBlock::ToolResult(_)),
        "Expected ToolResult block, got {:?}",
        first_block
    );
}
1775
/// When a tool message is not JSON, the tool-use id is recovered from the preceding
/// assistant tool call, and the result carries status `error`.
#[test]
fn fallback_recovers_tool_use_id_from_assistant() {
    let messages = vec![
        ChatMessage::user("run it"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"tool_abc","name":"shell","arguments":"{}"}]}"#,
        ),
        ChatMessage {
            role: "tool".to_string(),
            content: "raw output with no json".to_string(),
        },
    ];
    let (_, msgs) = BedrockProvider::convert_messages(&messages);

    match &msgs[2].content[0] {
        ContentBlock::ToolResult(wrapper) => {
            assert_eq!(wrapper.tool_result.tool_use_id, "tool_abc");
            assert_eq!(wrapper.tool_result.status, "error");
        }
        _ => panic!("Expected ToolResult block"),
    }
}
1796
/// Back-to-back tool results are merged into one user message holding one ToolResult
/// block per result.
#[test]
fn consecutive_tool_results_merged_into_single_message() {
    let messages = vec![
        ChatMessage::user("do two things"),
        ChatMessage::assistant(
            r#"{"content":"","tool_calls":[{"id":"t1","name":"a","arguments":"{}"},{"id":"t2","name":"b","arguments":"{}"}]}"#,
        ),
        ChatMessage::tool(r#"{"tool_call_id":"t1","content":"result 1"}"#),
        ChatMessage::tool(r#"{"tool_call_id":"t2","content":"result 2"}"#),
    ];
    let (_, msgs) = BedrockProvider::convert_messages(&messages);
    assert_eq!(msgs.len(), 3, "Expected 3 messages, got {}", msgs.len());

    let merged = &msgs[2];
    assert_eq!(merged.role, "user");
    assert_eq!(
        merged.content.len(),
        2,
        "Expected 2 tool results in one message"
    );
    assert!(merged
        .content
        .iter()
        .all(|block| matches!(block, ContentBlock::ToolResult(_))));
}
1819
/// `extract_tool_call_id` accepts `tool_call_id`, `tool_use_id` and `toolUseId`, and
/// returns `None` for input that is not JSON.
#[test]
fn extract_tool_call_id_tries_multiple_field_names() {
    let cases = [
        (r#"{"tool_call_id":"a"}"#, Some("a")),
        (r#"{"tool_use_id":"b"}"#, Some("b")),
        (r#"{"toolUseId":"c"}"#, Some("c")),
        ("not json at all", None),
    ];
    for (input, expected) in cases {
        assert_eq!(
            BedrockProvider::extract_tool_call_id(input),
            expected.map(str::to_string)
        );
    }
}
1839
/// `parse_tool_result_message` accepts `tool_use_id` as an alternative id field.
#[test]
fn parse_tool_result_accepts_alternate_id_fields() {
    let parsed =
        BedrockProvider::parse_tool_result_message(r#"{"tool_use_id":"x","content":"ok"}"#)
            .expect("tool result with tool_use_id should parse");
    match &parsed.content[0] {
        ContentBlock::ToolResult(wrapper) => assert_eq!(wrapper.tool_result.tool_use_id, "x"),
        _ => panic!("Expected ToolResult"),
    }
}
1851}