Skip to main content

construct/providers/
bedrock.rs

1//! AWS Bedrock provider using the Converse API.
2//!
3//! Authentication: supports two methods:
4//! - **Bearer token**: set `BEDROCK_API_KEY` env var (takes precedence).
5//! - **SigV4 signing**: AWS AKSK (Access Key ID + Secret Access Key)
6//!   via environment variables or EC2 IMDSv2. SigV4 signing is implemented
7//!   manually using hmac/sha2 crates — no AWS SDK dependency.
8
9use crate::providers::traits::{
10    ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
11    Provider, ProviderCapabilities, TokenUsage, ToolCall as ProviderToolCall, ToolsPayload,
12};
13use crate::tools::ToolSpec;
14use async_trait::async_trait;
15use hmac::{Hmac, Mac};
16use reqwest::Client;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19
/// Hostname prefix for the Bedrock Runtime endpoint.
const ENDPOINT_PREFIX: &str = "bedrock-runtime";
/// SigV4 signing service name (AWS uses "bedrock", not "bedrock-runtime").
const SIGNING_SERVICE: &str = "bedrock";
/// Region used when neither `AWS_REGION` nor `AWS_DEFAULT_REGION` is set.
const DEFAULT_REGION: &str = "us-east-1";
/// `max_tokens` value every `BedrockProvider` constructor in this file starts with.
const DEFAULT_MAX_TOKENS: u32 = 4096;
26
27// ── Authentication ──────────────────────────────────────────────
28
/// Authentication method for Bedrock: either SigV4 (AKSK) or Bearer token.
enum BedrockAuth {
    /// AWS access-key credentials, consumed by `build_authorization_header`.
    SigV4(AwsCredentials),
    /// Opaque bearer token (from `BEDROCK_API_KEY` or `with_bearer_token`).
    BearerToken(String),
}
34
35// ── AWS Credentials ─────────────────────────────────────────────
36
/// Resolved AWS credentials for SigV4 signing.
struct AwsCredentials {
    /// Access key ID (`AWS_ACCESS_KEY_ID`, or `AccessKeyId` from IMDS).
    access_key_id: String,
    /// Secret access key (`AWS_SECRET_ACCESS_KEY`, or `SecretAccessKey` from IMDS).
    secret_access_key: String,
    /// Optional session token (`AWS_SESSION_TOKEN`, or `Token` from IMDS).
    session_token: Option<String>,
    /// AWS region; used for the endpoint host and the SigV4 credential scope.
    region: String,
}
44
45impl AwsCredentials {
46    /// Resolve credentials: first try environment variables, then EC2 IMDSv2.
47    fn from_env() -> anyhow::Result<Self> {
48        let access_key_id = env_required("AWS_ACCESS_KEY_ID")?;
49        let secret_access_key = env_required("AWS_SECRET_ACCESS_KEY")?;
50
51        let session_token = env_optional("AWS_SESSION_TOKEN");
52
53        let region = env_optional("AWS_REGION")
54            .or_else(|| env_optional("AWS_DEFAULT_REGION"))
55            .unwrap_or_else(|| DEFAULT_REGION.to_string());
56
57        Ok(Self {
58            access_key_id,
59            secret_access_key,
60            session_token,
61            region,
62        })
63    }
64
65    /// Fetch credentials from EC2 IMDSv2 instance metadata service.
66    async fn from_imds() -> anyhow::Result<Self> {
67        let client = reqwest::Client::builder()
68            .timeout(std::time::Duration::from_secs(3))
69            .build()?;
70
71        // Step 1: get IMDSv2 token
72        let token = client
73            .put("http://169.254.169.254/latest/api/token")
74            .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
75            .send()
76            .await?
77            .text()
78            .await?;
79
80        // Step 2: get IAM role name
81        let role = client
82            .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
83            .header("X-aws-ec2-metadata-token", &token)
84            .send()
85            .await?
86            .text()
87            .await?;
88        let role = role.trim().to_string();
89        anyhow::ensure!(!role.is_empty(), "No IAM role attached to this instance");
90
91        // Step 3: get credentials for that role
92        let creds_url = format!(
93            "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
94            role
95        );
96        let creds_json: serde_json::Value = client
97            .get(&creds_url)
98            .header("X-aws-ec2-metadata-token", &token)
99            .send()
100            .await?
101            .json()
102            .await?;
103
104        let access_key_id = creds_json["AccessKeyId"]
105            .as_str()
106            .ok_or_else(|| anyhow::anyhow!("Missing AccessKeyId in IMDS response"))?
107            .to_string();
108        let secret_access_key = creds_json["SecretAccessKey"]
109            .as_str()
110            .ok_or_else(|| anyhow::anyhow!("Missing SecretAccessKey in IMDS response"))?
111            .to_string();
112        let session_token = creds_json["Token"].as_str().map(|s| s.to_string());
113
114        // Step 4: get region from instance identity document
115        let region = match client
116            .get("http://169.254.169.254/latest/meta-data/placement/region")
117            .header("X-aws-ec2-metadata-token", &token)
118            .send()
119            .await
120        {
121            Ok(resp) => resp.text().await.unwrap_or_default(),
122            Err(_) => String::new(),
123        };
124        let region = if region.trim().is_empty() {
125            env_optional("AWS_REGION")
126                .or_else(|| env_optional("AWS_DEFAULT_REGION"))
127                .unwrap_or_else(|| DEFAULT_REGION.to_string())
128        } else {
129            region.trim().to_string()
130        };
131
132        tracing::info!(
133            "Loaded AWS credentials from EC2 instance metadata (role: {})",
134            role
135        );
136
137        Ok(Self {
138            access_key_id,
139            secret_access_key,
140            session_token,
141            region,
142        })
143    }
144
145    /// Resolve credentials: env vars first, then EC2 IMDS.
146    async fn resolve() -> anyhow::Result<Self> {
147        if let Ok(creds) = Self::from_env() {
148            return Ok(creds);
149        }
150        Self::from_imds().await
151    }
152
153    fn host(&self) -> String {
154        format!("{ENDPOINT_PREFIX}.{}.amazonaws.com", self.region)
155    }
156}
157
158fn env_required(name: &str) -> anyhow::Result<String> {
159    std::env::var(name)
160        .ok()
161        .map(|v| v.trim().to_string())
162        .filter(|v| !v.is_empty())
163        .ok_or_else(|| anyhow::anyhow!("Environment variable {name} is required for Bedrock"))
164}
165
/// Read an optional environment variable and return its trimmed value.
///
/// Yields `None` when the variable is unset, not valid Unicode, or blank
/// after trimming.
fn env_optional(name: &str) -> Option<String> {
    let raw = std::env::var(name).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
172
173// ── AWS SigV4 Signing ───────────────────────────────────────────
174
175fn sha256_hex(data: &[u8]) -> String {
176    let mut hasher = Sha256::new();
177    hasher.update(data);
178    hex::encode(hasher.finalize())
179}
180
181fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec<u8> {
182    let mut mac = Hmac::<Sha256>::new_from_slice(key).expect("HMAC can take key of any size");
183    mac.update(data);
184    mac.finalize().into_bytes().to_vec()
185}
186
187/// Derive the SigV4 signing key via HMAC chain.
188fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec<u8> {
189    let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes());
190    let k_region = hmac_sha256(&k_date, region.as_bytes());
191    let k_service = hmac_sha256(&k_region, service.as_bytes());
192    hmac_sha256(&k_service, b"aws4_request")
193}
194
195/// Build the SigV4 `Authorization` header value.
196///
197/// `headers` must be sorted by lowercase header name.
198fn build_authorization_header(
199    credentials: &AwsCredentials,
200    method: &str,
201    canonical_uri: &str,
202    query_string: &str,
203    headers: &[(String, String)],
204    payload: &[u8],
205    timestamp: &chrono::DateTime<chrono::Utc>,
206) -> String {
207    let date_stamp = timestamp.format("%Y%m%d").to_string();
208    let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
209
210    let mut canonical_headers = String::new();
211    for (k, v) in headers {
212        canonical_headers.push_str(k);
213        canonical_headers.push(':');
214        canonical_headers.push_str(v);
215        canonical_headers.push('\n');
216    }
217
218    let signed_headers: String = headers
219        .iter()
220        .map(|(k, _)| k.as_str())
221        .collect::<Vec<_>>()
222        .join(";");
223
224    let payload_hash = sha256_hex(payload);
225
226    let canonical_request = format!(
227        "{method}\n{canonical_uri}\n{query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}"
228    );
229
230    let credential_scope = format!(
231        "{date_stamp}/{}/{SIGNING_SERVICE}/aws4_request",
232        credentials.region
233    );
234
235    let string_to_sign = format!(
236        "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}",
237        sha256_hex(canonical_request.as_bytes())
238    );
239
240    let signing_key = derive_signing_key(
241        &credentials.secret_access_key,
242        &date_stamp,
243        &credentials.region,
244        SIGNING_SERVICE,
245    );
246
247    let signature = hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()));
248
249    format!(
250        "AWS4-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}",
251        credentials.access_key_id
252    )
253}
254
255// ── Converse API Types (Request) ────────────────────────────────
256
/// JSON body of a Converse request. Field names serialize in camelCase and
/// optional sections that are `None` are left out of the output entirely.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ConverseRequest {
    /// Conversation turns (the system prompt travels separately in `system`).
    messages: Vec<ConverseMessage>,
    /// System prompt blocks, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<Vec<SystemBlock>>,
    /// Sampling settings (serialized as `inferenceConfig`).
    #[serde(skip_serializing_if = "Option::is_none")]
    inference_config: Option<InferenceConfig>,
    /// Tool definitions (serialized as `toolConfig`).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<ToolConfig>,
}
268
/// One conversation turn: a role plus its ordered content blocks.
#[derive(Debug, Serialize, Deserialize)]
struct ConverseMessage {
    /// Speaker of the turn; this file only ever writes `"user"` or `"assistant"`.
    role: String,
    /// Content blocks in the order they are sent.
    content: Vec<ContentBlock>,
}
274
/// Content blocks use Bedrock's union style:
/// `{"text": "..."}`, `{"toolUse": {...}}`, `{"toolResult": {...}}`,
/// `{"cachePoint": {...}}`, `{"image": {...}}`.
///
/// Note: `text` is a simple string value, not a nested object. `toolUse` and `toolResult`
/// are nested objects. We use `#[serde(untagged)]` with manual struct wrappers to
/// match this mixed format: each variant wraps a struct whose single field
/// name supplies the JSON key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
enum ContentBlock {
    /// `{"text": "..."}`
    Text(TextBlock),
    /// `{"toolUse": {...}}`
    ToolUse(ToolUseWrapper),
    /// `{"toolResult": {...}}`
    ToolResult(ToolResultWrapper),
    /// `{"cachePoint": {...}}`
    CachePointBlock(CachePointWrapper),
    /// `{"image": {...}}`
    Image(ImageWrapper),
}
290
/// Wrapper that gives an image block its `"image"` JSON key.
#[derive(Debug, Serialize, Deserialize)]
struct ImageWrapper {
    image: ImageBlock,
}
295
/// Image payload: a format name plus the image source.
#[derive(Debug, Serialize, Deserialize)]
struct ImageBlock {
    /// Format name; `parse_user_content_blocks` writes `png`, `gif`, `webp` or `jpeg`.
    format: String,
    /// Where the image data comes from.
    source: ImageSource,
}
301
/// Inline image data.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ImageSource {
    /// Image data as text; `parse_user_content_blocks` stores the base64
    /// portion of a `data:` URI here unchanged.
    bytes: String,
}
307
/// Plain text block, serialized as `{"text": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
struct TextBlock {
    text: String,
}
312
/// Wrapper that gives a tool-use block its `"toolUse"` JSON key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolUseWrapper {
    tool_use: ToolUseBlock,
}
318
/// A tool invocation; used both when sending history and when reading responses.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolUseBlock {
    /// Identifier that a later `toolResult` refers back to (`toolUseId` in JSON).
    tool_use_id: String,
    /// Name of the tool being called.
    name: String,
    /// Tool arguments as arbitrary JSON.
    input: serde_json::Value,
}
326
/// Wrapper that gives a tool-result block its `"toolResult"` JSON key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolResultWrapper {
    tool_result: ToolResultBlock,
}
332
/// Outcome of a tool invocation, sent back inside a `user` message.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ToolResultBlock {
    /// Identifier of the `toolUse` block this result answers (`toolUseId` in JSON).
    tool_use_id: String,
    /// Result payload as text items.
    content: Vec<ToolResultContent>,
    /// This file writes `"success"` for parsed results and `"error"` for the fallback path.
    status: String,
}
340
/// Wrapper that gives a cache point its `"cachePoint"` JSON key.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct CachePointWrapper {
    cache_point: CachePoint,
}
346
/// Text item inside a tool result, serialized as `{"text": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
struct ToolResultContent {
    text: String,
}
351
/// Cache point marker, serialized as `{"type": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
struct CachePoint {
    /// Written to JSON as `type` (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    cache_type: String,
}
357
358impl CachePoint {
359    fn default_cache() -> Self {
360        Self {
361            cache_type: "default".to_string(),
362        }
363    }
364}
365
/// System prompt blocks: either `{"text": "..."}` or `{"cachePoint": {...}}`.
///
/// Serialize-only; `#[serde(untagged)]` means the wrapped struct alone
/// determines the JSON shape.
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum SystemBlock {
    /// `{"text": "..."}`
    Text(TextBlock),
    /// `{"cachePoint": {...}}`
    CachePoint(CachePointWrapper),
}
373
/// Sampling settings sent with a request (camelCase in JSON).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct InferenceConfig {
    /// Output token limit (`maxTokens` in JSON).
    max_tokens: u32,
    /// Sampling temperature.
    temperature: f64,
}
380
/// Tool section of a request: the list of tools offered to the model.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolConfig {
    tools: Vec<ToolDefinition>,
}
386
/// Wrapper that gives a tool definition its `"toolSpec"` JSON key.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolDefinition {
    tool_spec: ToolSpecDef,
}
392
/// Description of a single tool (camelCase in JSON).
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct ToolSpecDef {
    /// Tool name.
    name: String,
    /// Human-readable description of the tool.
    description: String,
    /// Argument schema (`inputSchema` in JSON).
    input_schema: InputSchema,
}
400
/// Tool argument schema, serialized as `{"json": <schema>}`.
#[derive(Debug, Serialize)]
struct InputSchema {
    /// Schema value; `convert_tools_to_converse` copies `ToolSpec::parameters` here.
    json: serde_json::Value,
}
405
406// ── Converse API Types (Response) ───────────────────────────────
407
/// Top-level Converse response. Every field is optional and defaults to
/// `None` when absent, so a sparse response still deserializes.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConverseResponse {
    /// Model output, if present.
    #[serde(default)]
    output: Option<ConverseOutput>,
    /// `stopReason` from the response; deserialized but marked `dead_code`.
    #[serde(default)]
    #[allow(dead_code)]
    stop_reason: Option<String>,
    /// Token accounting, if present.
    #[serde(default)]
    usage: Option<BedrockUsage>,
}
419
/// Token counts reported in a response (`inputTokens` / `outputTokens` in JSON).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BedrockUsage {
    #[serde(default)]
    input_tokens: Option<u64>,
    #[serde(default)]
    output_tokens: Option<u64>,
}
428
/// `output` section of a response; holds the generated message when present.
#[derive(Debug, Deserialize)]
struct ConverseOutput {
    #[serde(default)]
    message: Option<ConverseOutputMessage>,
}
434
/// Message inside a response's `output`.
#[derive(Debug, Deserialize)]
struct ConverseOutputMessage {
    /// Role of the message; deserialized (required) but marked `dead_code`.
    #[allow(dead_code)]
    role: String,
    /// Content blocks returned by the model.
    content: Vec<ResponseContentBlock>,
}
441
/// Response content blocks from the Converse API.
///
/// Uses `#[serde(untagged)]` to match Bedrock's union format where `text` is a
/// simple string value and `toolUse` is a nested object. Unknown block types
/// (e.g. `reasoningContent`, `guardContent`) are captured as `Other` to prevent
/// deserialization failures.
///
/// Untagged enums are matched variant by variant in declaration order, so
/// `Other` (which accepts any JSON value) must stay last.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ResponseContentBlock {
    /// `{"toolUse": {...}}`
    ToolUse(ResponseToolUseWrapper),
    /// `{"text": "..."}`
    Text(TextBlock),
    /// Any other block shape, kept as raw JSON.
    Other(serde_json::Value),
}
455
/// Response-side wrapper that reads a tool-use block from its `"toolUse"` key.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ResponseToolUseWrapper {
    tool_use: ToolUseBlock,
}
461
462// ── BedrockProvider ─────────────────────────────────────────────
463
/// AWS Bedrock provider state.
pub struct BedrockProvider {
    /// Authentication captured at construction time; `None` when no
    /// credentials could be found then.
    auth: Option<BedrockAuth>,
    /// Output token limit; starts at `DEFAULT_MAX_TOKENS`, overridable via
    /// `with_max_tokens`.
    max_tokens: u32,
}
468
469impl BedrockProvider {
470    pub fn new() -> Self {
471        // Bearer token takes precedence over SigV4 credentials.
472        if let Some(token) = env_optional("BEDROCK_API_KEY") {
473            return Self {
474                auth: Some(BedrockAuth::BearerToken(token)),
475                max_tokens: DEFAULT_MAX_TOKENS,
476            };
477        }
478        Self {
479            auth: AwsCredentials::from_env().ok().map(BedrockAuth::SigV4),
480            max_tokens: DEFAULT_MAX_TOKENS,
481        }
482    }
483
484    pub async fn new_async() -> Self {
485        // Bearer token takes precedence over SigV4 credentials.
486        if let Some(token) = env_optional("BEDROCK_API_KEY") {
487            return Self {
488                auth: Some(BedrockAuth::BearerToken(token)),
489                max_tokens: DEFAULT_MAX_TOKENS,
490            };
491        }
492        let auth = AwsCredentials::resolve().await.ok().map(BedrockAuth::SigV4);
493        Self {
494            auth,
495            max_tokens: DEFAULT_MAX_TOKENS,
496        }
497    }
498
499    /// Create a provider using a Bearer token for authentication.
500    pub fn with_bearer_token(token: &str) -> Self {
501        Self {
502            auth: Some(BedrockAuth::BearerToken(token.to_string())),
503            max_tokens: DEFAULT_MAX_TOKENS,
504        }
505    }
506
507    /// Override the maximum output tokens for API requests.
508    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
509        self.max_tokens = max_tokens;
510        self
511    }
512
    /// Build the HTTP client used for Bedrock requests.
    ///
    /// Delegates to `crate::config::build_runtime_proxy_client_with_timeouts`
    /// with the key `"provider.bedrock"`.
    // NOTE(review): `120` and `10` are presumably request/connect timeouts in
    // seconds — confirm against the helper's signature.
    fn http_client(&self) -> Client {
        crate::config::build_runtime_proxy_client_with_timeouts("provider.bedrock", 120, 10)
    }
516
517    /// Percent-encode the model ID for URL path: only encode `:` to `%3A`.
518    /// Colons in model IDs (e.g. `v1:0`) must be encoded because `reqwest::Url`
519    /// may misparse them. Dots, hyphens, and alphanumerics are safe.
520    fn encode_model_path(model_id: &str) -> String {
521        model_id.replace(':', "%3A")
522    }
523
524    /// Resolve the AWS region from environment variables.
525    fn resolve_region() -> String {
526        env_optional("AWS_REGION")
527            .or_else(|| env_optional("AWS_DEFAULT_REGION"))
528            .unwrap_or_else(|| DEFAULT_REGION.to_string())
529    }
530
531    /// Build the actual request URL. Uses raw model ID (reqwest sends colons as-is).
532    fn endpoint_url(region: &str, model_id: &str) -> String {
533        format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse")
534    }
535
536    /// Build the canonical URI for SigV4 signing. Must URI-encode the path
537    /// per SigV4 spec: colons become `%3A`. AWS verifies the signature against
538    /// the encoded form even though the wire request uses raw colons.
539    fn canonical_uri(model_id: &str) -> String {
540        let encoded = Self::encode_model_path(model_id);
541        format!("/model/{encoded}/converse")
542    }
543
544    fn require_auth(&self) -> anyhow::Result<&BedrockAuth> {
545        self.auth.as_ref().ok_or_else(|| {
546            anyhow::anyhow!(
547                "AWS Bedrock credentials not set. Set BEDROCK_API_KEY for Bearer \
548                 token auth, or AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY for \
549                 SigV4 auth, or run on an EC2 instance with an IAM role attached."
550            )
551        })
552    }
553
554    /// Resolve auth: use cached if available, otherwise try env vars then IMDS.
555    async fn resolve_auth(&self) -> anyhow::Result<BedrockAuth> {
556        // If we already have auth cached, re-resolve from the same source.
557        if let Some(ref auth) = self.auth {
558            match auth {
559                BedrockAuth::BearerToken(token) => {
560                    return Ok(BedrockAuth::BearerToken(token.clone()));
561                }
562                BedrockAuth::SigV4(_) => {
563                    // Re-resolve SigV4 credentials (they may have rotated).
564                }
565            }
566        }
567        // Check Bearer token first.
568        if let Some(token) = env_optional("BEDROCK_API_KEY") {
569            return Ok(BedrockAuth::BearerToken(token));
570        }
571        // Fall back to SigV4.
572        if let Ok(creds) = AwsCredentials::from_env() {
573            return Ok(BedrockAuth::SigV4(creds));
574        }
575        Ok(BedrockAuth::SigV4(AwsCredentials::from_imds().await?))
576    }
577
578    // ── Cache heuristics (same thresholds as AnthropicProvider) ──
579
580    /// Cache system prompts larger than ~1024 tokens (3KB of text).
581    fn should_cache_system(text: &str) -> bool {
582        text.len() > 3072
583    }
584
585    /// Cache conversations with more than 4 messages (excluding system).
586    fn should_cache_conversation(messages: &[ChatMessage]) -> bool {
587        messages.iter().filter(|m| m.role != "system").count() > 4
588    }
589
590    // ── Message conversion ──────────────────────────────────────
591
    /// Convert provider-agnostic chat messages into Converse request parts.
    ///
    /// Returns `(system, messages)`:
    /// - `system` holds the text of the FIRST `"system"` message only (later
    ///   system messages are ignored), or `None` when there is none.
    /// - `"assistant"` messages become tool-use blocks when their content
    ///   parses as a structured tool-call payload, otherwise a single text block.
    /// - `"tool"` messages become `toolResult` blocks inside a `"user"`
    ///   message; consecutive tool results are merged into one message.
    /// - Any other role is treated as a user message.
    fn convert_messages(
        messages: &[ChatMessage],
    ) -> (Option<Vec<SystemBlock>>, Vec<ConverseMessage>) {
        let mut system_blocks = Vec::new();
        let mut converse_messages = Vec::new();

        for msg in messages {
            match msg.role.as_str() {
                "system" => {
                    // Only the first system message is kept.
                    if system_blocks.is_empty() {
                        system_blocks.push(SystemBlock::Text(TextBlock {
                            text: msg.content.clone(),
                        }));
                    }
                }
                "assistant" => {
                    if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
                        converse_messages.push(ConverseMessage {
                            role: "assistant".to_string(),
                            content: blocks,
                        });
                    } else {
                        converse_messages.push(ConverseMessage {
                            role: "assistant".to_string(),
                            content: vec![ContentBlock::Text(TextBlock {
                                text: msg.content.clone(),
                            })],
                        });
                    }
                }
                "tool" => {
                    let tool_result_msg = Self::parse_tool_result_message(&msg.content)
                        .unwrap_or_else(|| {
                            // Fallback: always emit a toolResult block so the
                            // Bedrock API contract (every toolUse needs a matching
                            // toolResult) is never violated.
                            // ID lookup order: the raw content, then the oldest
                            // unanswered toolUse of the last assistant turn, then
                            // the literal "unknown".
                            let tool_use_id = Self::extract_tool_call_id(&msg.content)
                                .or_else(|| Self::last_pending_tool_use_id(&converse_messages))
                                .unwrap_or_else(|| "unknown".to_string());

                            tracing::warn!(
                                "Failed to parse tool result message, creating error \
                                 toolResult for tool_use_id={}",
                                tool_use_id
                            );

                            ConverseMessage {
                                role: "user".to_string(),
                                content: vec![ContentBlock::ToolResult(ToolResultWrapper {
                                    tool_result: ToolResultBlock {
                                        tool_use_id,
                                        content: vec![ToolResultContent {
                                            text: msg.content.clone(),
                                        }],
                                        status: "error".to_string(),
                                    },
                                })],
                            }
                        });

                    // Merge consecutive tool results into a single user message.
                    // Bedrock requires all toolResult blocks for a multi-tool-call
                    // turn to appear in one user message.
                    if let Some(last) = converse_messages.last_mut() {
                        // Merge only into a user message made up purely of
                        // toolResult blocks, never into ordinary user text.
                        if last.role == "user"
                            && last
                                .content
                                .iter()
                                .all(|b| matches!(b, ContentBlock::ToolResult(_)))
                        {
                            last.content.extend(tool_result_msg.content);
                            continue;
                        }
                    }
                    converse_messages.push(tool_result_msg);
                }
                _ => {
                    // "user" and any unrecognized role.
                    let content_blocks = Self::parse_user_content_blocks(&msg.content);
                    converse_messages.push(ConverseMessage {
                        role: "user".to_string(),
                        content: content_blocks,
                    });
                }
            }
        }

        let system = if system_blocks.is_empty() {
            None
        } else {
            Some(system_blocks)
        };
        (system, converse_messages)
    }
685
686    /// Try to extract a tool_call_id from partially-valid JSON content.
687    fn extract_tool_call_id(content: &str) -> Option<String> {
688        let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
689        value
690            .get("tool_call_id")
691            .or_else(|| value.get("tool_use_id"))
692            .or_else(|| value.get("toolUseId"))
693            .and_then(serde_json::Value::as_str)
694            .map(String::from)
695    }
696
697    /// Find the first unmatched tool_use_id from the last assistant message.
698    ///
699    /// When a tool result can't be parsed at all (not even the ID), we fall
700    /// back to matching it against the preceding assistant turn's toolUse
701    /// blocks that don't yet have a corresponding toolResult.
702    fn last_pending_tool_use_id(converse_messages: &[ConverseMessage]) -> Option<String> {
703        let last_assistant = converse_messages
704            .iter()
705            .rev()
706            .find(|m| m.role == "assistant")?;
707
708        let tool_use_ids: Vec<&str> = last_assistant
709            .content
710            .iter()
711            .filter_map(|b| match b {
712                ContentBlock::ToolUse(wrapper) => Some(wrapper.tool_use.tool_use_id.as_str()),
713                _ => None,
714            })
715            .collect();
716
717        let answered_ids: Vec<&str> = converse_messages
718            .iter()
719            .rev()
720            .take_while(|m| m.role == "user")
721            .flat_map(|m| m.content.iter())
722            .filter_map(|b| match b {
723                ContentBlock::ToolResult(wrapper) => Some(wrapper.tool_result.tool_use_id.as_str()),
724                _ => None,
725            })
726            .collect();
727
728        tool_use_ids
729            .into_iter()
730            .find(|id| !answered_ids.contains(id))
731            .map(String::from)
732    }
733
734    /// Parse user message content, extracting [IMAGE:data:...] markers into image blocks.
735    fn parse_user_content_blocks(content: &str) -> Vec<ContentBlock> {
736        let mut blocks: Vec<ContentBlock> = Vec::new();
737        let mut remaining = content;
738        let has_image = content.contains("[IMAGE:");
739        tracing::info!(
740            "parse_user_content_blocks called, len={}, has_image={}",
741            content.len(),
742            has_image
743        );
744
745        while let Some(start) = remaining.find("[IMAGE:") {
746            // Add any text before the marker
747            let text_before = &remaining[..start];
748            if !text_before.trim().is_empty() {
749                blocks.push(ContentBlock::Text(TextBlock {
750                    text: text_before.to_string(),
751                }));
752            }
753
754            let after = &remaining[start + 7..]; // skip "[IMAGE:"
755            if let Some(end) = after.find(']') {
756                let src = &after[..end];
757                remaining = &after[end + 1..];
758
759                // Only handle data URIs (base64 encoded images)
760                if let Some(rest) = src.strip_prefix("data:") {
761                    if let Some(semi) = rest.find(';') {
762                        let mime = &rest[..semi];
763                        let after_semi = &rest[semi + 1..];
764                        if let Some(b64) = after_semi.strip_prefix("base64,") {
765                            let format = match mime {
766                                "image/png" => "png",
767                                "image/gif" => "gif",
768                                "image/webp" => "webp",
769                                _ => "jpeg",
770                            };
771                            blocks.push(ContentBlock::Image(ImageWrapper {
772                                image: ImageBlock {
773                                    format: format.to_string(),
774                                    source: ImageSource {
775                                        bytes: b64.to_string(),
776                                    },
777                                },
778                            }));
779                            continue;
780                        }
781                    }
782                }
783                // Non-data-uri image: just include as text reference
784                blocks.push(ContentBlock::Text(TextBlock {
785                    text: format!("[image: {}]", src),
786                }));
787            } else {
788                // No closing bracket, treat rest as text
789                blocks.push(ContentBlock::Text(TextBlock {
790                    text: remaining.to_string(),
791                }));
792                break;
793            }
794        }
795
796        // Add any remaining text
797        if !remaining.trim().is_empty() {
798            blocks.push(ContentBlock::Text(TextBlock {
799                text: remaining.to_string(),
800            }));
801        }
802
803        if blocks.is_empty() {
804            blocks.push(ContentBlock::Text(TextBlock {
805                text: content.to_string(),
806            }));
807        }
808
809        blocks
810    }
811
812    /// Parse assistant message containing structured tool calls.
813    fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<ContentBlock>> {
814        let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
815        let tool_calls = value
816            .get("tool_calls")
817            .and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;
818
819        let mut blocks = Vec::new();
820        if let Some(text) = value
821            .get("content")
822            .and_then(serde_json::Value::as_str)
823            .map(str::trim)
824            .filter(|t| !t.is_empty())
825        {
826            blocks.push(ContentBlock::Text(TextBlock {
827                text: text.to_string(),
828            }));
829        }
830        for call in tool_calls {
831            let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
832                .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
833            blocks.push(ContentBlock::ToolUse(ToolUseWrapper {
834                tool_use: ToolUseBlock {
835                    tool_use_id: call.id,
836                    name: call.name,
837                    input,
838                },
839            }));
840        }
841        Some(blocks)
842    }
843
844    /// Parse tool result message into a user message with ToolResult block.
845    fn parse_tool_result_message(content: &str) -> Option<ConverseMessage> {
846        let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
847        let tool_use_id = value
848            .get("tool_call_id")
849            .or_else(|| value.get("tool_use_id"))
850            .or_else(|| value.get("toolUseId"))
851            .and_then(serde_json::Value::as_str)?
852            .to_string();
853        let result = value
854            .get("content")
855            .and_then(serde_json::Value::as_str)
856            .unwrap_or("")
857            .to_string();
858        Some(ConverseMessage {
859            role: "user".to_string(),
860            content: vec![ContentBlock::ToolResult(ToolResultWrapper {
861                tool_result: ToolResultBlock {
862                    tool_use_id,
863                    content: vec![ToolResultContent { text: result }],
864                    status: "success".to_string(),
865                },
866            })],
867        })
868    }
869
870    // ── Tool conversion ─────────────────────────────────────────
871
872    fn convert_tools_to_converse(tools: Option<&[ToolSpec]>) -> Option<ToolConfig> {
873        let items = tools?;
874        if items.is_empty() {
875            return None;
876        }
877        let tool_defs: Vec<ToolDefinition> = items
878            .iter()
879            .map(|tool| ToolDefinition {
880                tool_spec: ToolSpecDef {
881                    name: tool.name.clone(),
882                    description: tool.description.clone(),
883                    input_schema: InputSchema {
884                        json: tool.parameters.clone(),
885                    },
886                },
887            })
888            .collect();
889        Some(ToolConfig { tools: tool_defs })
890    }
891
892    // ── Response parsing ────────────────────────────────────────
893
894    fn parse_converse_response(response: ConverseResponse) -> ProviderChatResponse {
895        let mut text_parts = Vec::new();
896        let mut tool_calls = Vec::new();
897
898        let usage = response.usage.map(|u| TokenUsage {
899            input_tokens: u.input_tokens,
900            output_tokens: u.output_tokens,
901            cached_input_tokens: None,
902        });
903
904        if let Some(output) = response.output {
905            if let Some(message) = output.message {
906                for block in message.content {
907                    match block {
908                        ResponseContentBlock::Text(tb) => {
909                            let trimmed = tb.text.trim().to_string();
910                            if !trimmed.is_empty() {
911                                text_parts.push(trimmed);
912                            }
913                        }
914                        ResponseContentBlock::ToolUse(wrapper) => {
915                            if !wrapper.tool_use.name.is_empty() {
916                                tool_calls.push(ProviderToolCall {
917                                    id: wrapper.tool_use.tool_use_id,
918                                    name: wrapper.tool_use.name,
919                                    arguments: wrapper.tool_use.input.to_string(),
920                                });
921                            }
922                        }
923                        ResponseContentBlock::Other(_) => {}
924                    }
925                }
926            }
927        }
928
929        ProviderChatResponse {
930            text: if text_parts.is_empty() {
931                None
932            } else {
933                Some(text_parts.join("\n"))
934            },
935            tool_calls,
936            usage,
937            reasoning_content: None,
938        }
939    }
940
941    // ── HTTP request ────────────────────────────────────────────
942
943    async fn send_converse_request(
944        &self,
945        auth: &BedrockAuth,
946        model: &str,
947        request_body: &ConverseRequest,
948    ) -> anyhow::Result<ConverseResponse> {
949        let payload = serde_json::to_vec(request_body)?;
950
951        // Debug: log image blocks in payload (truncated)
952        if let Ok(debug_val) = serde_json::from_slice::<serde_json::Value>(&payload) {
953            if let Some(msgs) = debug_val.get("messages").and_then(|m| m.as_array()) {
954                for msg in msgs {
955                    if let Some(content) = msg.get("content").and_then(|c| c.as_array()) {
956                        for block in content {
957                            if block.get("image").is_some() {
958                                let mut b = block.clone();
959                                if let Some(img) = b.get_mut("image") {
960                                    if let Some(src) = img.get_mut("source") {
961                                        if let Some(bytes) = src.get_mut("bytes") {
962                                            if let Some(s) = bytes.as_str() {
963                                                *bytes = serde_json::json!(format!(
964                                                    "<base64 {} chars>",
965                                                    s.len()
966                                                ));
967                                            }
968                                        }
969                                    }
970                                }
971                                tracing::info!(
972                                    "Bedrock image block: {}",
973                                    serde_json::to_string(&b).unwrap_or_default()
974                                );
975                            }
976                        }
977                    }
978                }
979            }
980        }
981
982        let response: reqwest::Response = match auth {
983            BedrockAuth::BearerToken(token) => {
984                let region = Self::resolve_region();
985                let url = Self::endpoint_url(&region, model);
986
987                self.http_client()
988                    .post(&url)
989                    .header("content-type", "application/json")
990                    .header("Authorization", format!("Bearer {token}"))
991                    .body(payload)
992                    .send()
993                    .await?
994            }
995            BedrockAuth::SigV4(credentials) => {
996                let url = Self::endpoint_url(&credentials.region, model);
997                let canonical_uri = Self::canonical_uri(model);
998                let now = chrono::Utc::now();
999                let host = credentials.host();
1000                let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string();
1001
1002                let mut headers_to_sign = vec![
1003                    ("content-type".to_string(), "application/json".to_string()),
1004                    ("host".to_string(), host),
1005                    ("x-amz-date".to_string(), amz_date.clone()),
1006                ];
1007                if let Some(ref session_token) = credentials.session_token {
1008                    headers_to_sign
1009                        .push(("x-amz-security-token".to_string(), session_token.clone()));
1010                }
1011                headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0));
1012
1013                let authorization = build_authorization_header(
1014                    credentials,
1015                    "POST",
1016                    &canonical_uri,
1017                    "",
1018                    &headers_to_sign,
1019                    &payload,
1020                    &now,
1021                );
1022
1023                let mut request = self
1024                    .http_client()
1025                    .post(&url)
1026                    .header("content-type", "application/json")
1027                    .header("x-amz-date", &amz_date)
1028                    .header("authorization", &authorization);
1029
1030                if let Some(ref session_token) = credentials.session_token {
1031                    request = request.header("x-amz-security-token", session_token);
1032                }
1033
1034                request.body(payload).send().await?
1035            }
1036        };
1037
1038        if !response.status().is_success() {
1039            return Err(super::api_error("Bedrock", response).await);
1040        }
1041
1042        let converse_response: ConverseResponse = response.json().await?;
1043        Ok(converse_response)
1044    }
1045}
1046
1047// ── Provider trait implementation ───────────────────────────────
1048
1049#[async_trait]
1050impl Provider for BedrockProvider {
1051    fn capabilities(&self) -> ProviderCapabilities {
1052        ProviderCapabilities {
1053            native_tool_calling: true,
1054            vision: true,
1055            prompt_caching: false,
1056        }
1057    }
1058
1059    fn supports_native_tools(&self) -> bool {
1060        true
1061    }
1062
1063    fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
1064        let tool_values: Vec<serde_json::Value> = tools
1065            .iter()
1066            .map(|t| {
1067                serde_json::json!({
1068                    "toolSpec": {
1069                        "name": t.name,
1070                        "description": t.description,
1071                        "inputSchema": { "json": t.parameters }
1072                    }
1073                })
1074            })
1075            .collect();
1076        ToolsPayload::Anthropic { tools: tool_values }
1077    }
1078
1079    async fn chat_with_system(
1080        &self,
1081        system_prompt: Option<&str>,
1082        message: &str,
1083        model: &str,
1084        temperature: f64,
1085    ) -> anyhow::Result<String> {
1086        let auth = self.resolve_auth().await?;
1087
1088        let system = system_prompt.map(|text| {
1089            let mut blocks = vec![SystemBlock::Text(TextBlock {
1090                text: text.to_string(),
1091            })];
1092            if Self::should_cache_system(text) {
1093                blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1094                    cache_point: CachePoint::default_cache(),
1095                }));
1096            }
1097            blocks
1098        });
1099
1100        let request = ConverseRequest {
1101            system,
1102            messages: vec![ConverseMessage {
1103                role: "user".to_string(),
1104                content: Self::parse_user_content_blocks(message),
1105            }],
1106            inference_config: Some(InferenceConfig {
1107                max_tokens: self.max_tokens,
1108                temperature,
1109            }),
1110            tool_config: None,
1111        };
1112
1113        let response = self.send_converse_request(&auth, model, &request).await?;
1114
1115        Self::parse_converse_response(response)
1116            .text
1117            .ok_or_else(|| anyhow::anyhow!("No response from Bedrock"))
1118    }
1119
1120    async fn chat(
1121        &self,
1122        request: ProviderChatRequest<'_>,
1123        model: &str,
1124        temperature: f64,
1125    ) -> anyhow::Result<ProviderChatResponse> {
1126        let auth = self.resolve_auth().await?;
1127
1128        let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages);
1129
1130        // Apply cachePoint to system if large.
1131        let system = system_blocks.map(|mut blocks| {
1132            let has_large_system = blocks
1133                .iter()
1134                .any(|b| matches!(b, SystemBlock::Text(tb) if Self::should_cache_system(&tb.text)));
1135            if has_large_system {
1136                blocks.push(SystemBlock::CachePoint(CachePointWrapper {
1137                    cache_point: CachePoint::default_cache(),
1138                }));
1139            }
1140            blocks
1141        });
1142
1143        // Apply cachePoint to last message if conversation is long.
1144        if Self::should_cache_conversation(request.messages) {
1145            if let Some(last_msg) = converse_messages.last_mut() {
1146                last_msg
1147                    .content
1148                    .push(ContentBlock::CachePointBlock(CachePointWrapper {
1149                        cache_point: CachePoint::default_cache(),
1150                    }));
1151            }
1152        }
1153
1154        let tool_config = Self::convert_tools_to_converse(request.tools);
1155
1156        let converse_request = ConverseRequest {
1157            system,
1158            messages: converse_messages,
1159            inference_config: Some(InferenceConfig {
1160                max_tokens: self.max_tokens,
1161                temperature,
1162            }),
1163            tool_config,
1164        };
1165
1166        let response = self
1167            .send_converse_request(&auth, model, &converse_request)
1168            .await?;
1169
1170        Ok(Self::parse_converse_response(response))
1171    }
1172
1173    async fn warmup(&self) -> anyhow::Result<()> {
1174        let region = match self.auth {
1175            Some(BedrockAuth::SigV4(ref creds)) => creds.region.clone(),
1176            Some(BedrockAuth::BearerToken(_)) => Self::resolve_region(),
1177            None => return Ok(()),
1178        };
1179        let url = format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/");
1180        let _ = self.http_client().get(&url).send().await;
1181        Ok(())
1182    }
1183}
1184
1185// ── Tests ───────────────────────────────────────────────────────
1186
1187#[cfg(test)]
1188mod tests {
1189    use super::*;
1190    use crate::providers::traits::ChatMessage;
1191
    /// Test helper that overrides (or removes) one environment variable via
    /// [`EnvGuard::set`] and puts the previous state back when dropped.
    struct EnvGuard {
        // Name of the environment variable being overridden.
        key: String,
        // Value the variable held before the override; `None` when reading it failed
        // (for example because it was unset), in which case drop removes the variable.
        original: Option<String>,
    }
1197
1198    impl EnvGuard {
1199        fn set(key: &str, value: Option<&str>) -> Self {
1200            let original = std::env::var(key).ok();
1201            match value {
1202                // SAFETY: test-only, single-threaded test runner.
1203                Some(v) => unsafe { std::env::set_var(key, v) },
1204                // SAFETY: test-only, single-threaded test runner.
1205                None => unsafe { std::env::remove_var(key) },
1206            }
1207            Self {
1208                key: key.to_string(),
1209                original,
1210            }
1211        }
1212    }
1213
1214    impl Drop for EnvGuard {
1215        fn drop(&mut self) {
1216            match &self.original {
1217                // SAFETY: test-only, single-threaded test runner.
1218                Some(v) => unsafe { std::env::set_var(&self.key, v) },
1219                // SAFETY: test-only, single-threaded test runner.
1220                None => unsafe { std::env::remove_var(&self.key) },
1221            }
1222        }
1223    }
1224
1225    // ── SigV4 signing tests ─────────────────────────────────────
1226
1227    #[test]
1228    fn sha256_hex_empty_string() {
1229        // Known SHA-256 of empty input
1230        assert_eq!(
1231            sha256_hex(b""),
1232            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
1233        );
1234    }
1235
1236    #[test]
1237    fn sha256_hex_known_input() {
1238        // SHA-256 of "hello"
1239        assert_eq!(
1240            sha256_hex(b"hello"),
1241            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
1242        );
1243    }
1244
    /// Example secret access key from the AWS documentation, used as the input of
    /// the SigV4 key-derivation test vectors below (not a real credential).
    const TEST_VECTOR_SECRET: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
1247
1248    #[test]
1249    fn hmac_sha256_known_input() {
1250        let test_key: &[u8] = b"key";
1251        let result = hmac_sha256(test_key, b"message");
1252        assert_eq!(
1253            hex::encode(&result),
1254            "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a"
1255        );
1256    }
1257
1258    #[test]
1259    fn derive_signing_key_structure() {
1260        // Verify the key derivation produces a 32-byte key (SHA-256 output).
1261        let key = derive_signing_key(TEST_VECTOR_SECRET, "20150830", "us-east-1", "iam");
1262        assert_eq!(key.len(), 32);
1263    }
1264
1265    #[test]
1266    fn derive_signing_key_known_test_vector() {
1267        // AWS SigV4 test vector from documentation.
1268        let key = derive_signing_key(TEST_VECTOR_SECRET, "20150830", "us-east-1", "iam");
1269        assert_eq!(
1270            hex::encode(&key),
1271            "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9"
1272        );
1273    }
1274
1275    #[test]
1276    fn build_authorization_header_format() {
1277        let credentials = AwsCredentials {
1278            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
1279            secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
1280            session_token: None,
1281            region: "us-east-1".to_string(),
1282        };
1283
1284        let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
1285            .unwrap()
1286            .with_timezone(&chrono::Utc);
1287
1288        let headers = vec![
1289            ("content-type".to_string(), "application/json".to_string()),
1290            (
1291                "host".to_string(),
1292                "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
1293            ),
1294            ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
1295        ];
1296
1297        let auth = build_authorization_header(
1298            &credentials,
1299            "POST",
1300            "/model/anthropic.claude-3-sonnet/converse",
1301            "",
1302            &headers,
1303            b"{}",
1304            &timestamp,
1305        );
1306
1307        // Verify structure
1308        assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/"));
1309        assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date"));
1310        assert!(auth.contains("Signature="));
1311        assert!(auth.contains("/us-east-1/bedrock/aws4_request"));
1312    }
1313
1314    #[test]
1315    fn build_authorization_header_includes_security_token_in_signed_headers() {
1316        let credentials = AwsCredentials {
1317            access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
1318            secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(),
1319            session_token: Some("session-token-value".to_string()),
1320            region: "us-east-1".to_string(),
1321        };
1322
1323        let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z")
1324            .unwrap()
1325            .with_timezone(&chrono::Utc);
1326
1327        let headers = vec![
1328            ("content-type".to_string(), "application/json".to_string()),
1329            (
1330                "host".to_string(),
1331                "bedrock-runtime.us-east-1.amazonaws.com".to_string(),
1332            ),
1333            ("x-amz-date".to_string(), "20240115T120000Z".to_string()),
1334            (
1335                "x-amz-security-token".to_string(),
1336                "session-token-value".to_string(),
1337            ),
1338        ];
1339
1340        let auth = build_authorization_header(
1341            &credentials,
1342            "POST",
1343            "/model/test-model/converse",
1344            "",
1345            &headers,
1346            b"{}",
1347            &timestamp,
1348        );
1349
1350        assert!(auth.contains("x-amz-security-token"));
1351    }
1352
1353    // ── Credential tests ────────────────────────────────────────
1354
1355    #[test]
1356    fn credentials_host_formats_correctly() {
1357        let creds = AwsCredentials {
1358            access_key_id: "AKID".to_string(),
1359            secret_access_key: "secret".to_string(),
1360            session_token: None,
1361            region: "us-west-2".to_string(),
1362        };
1363        assert_eq!(creds.host(), "bedrock-runtime.us-west-2.amazonaws.com");
1364    }
1365
1366    // ── Provider construction tests ─────────────────────────────
1367
1368    #[test]
1369    fn creates_without_credentials() {
1370        // Provider should construct even without env vars.
1371        let _provider = BedrockProvider::new();
1372    }
1373
1374    #[tokio::test]
1375    async fn chat_fails_without_credentials() {
1376        let provider = BedrockProvider {
1377            auth: None,
1378            max_tokens: DEFAULT_MAX_TOKENS,
1379        };
1380        let result = provider
1381            .chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", 0.7)
1382            .await;
1383        assert!(result.is_err());
1384        let err = result.unwrap_err().to_string();
1385        assert!(
1386            err.contains("credentials not set")
1387                || err.contains("169.254.169.254")
1388                || err.to_lowercase().contains("credential")
1389                || err.to_lowercase().contains("builder error"),
1390            "Expected missing-credentials style error, got: {err}"
1391        );
1392    }
1393
1394    // ── Bearer token tests ──────────────────────────────────────
1395
1396    #[test]
1397    fn creates_with_bearer_token() {
1398        let provider = BedrockProvider::with_bearer_token("test-api-key");
1399        assert!(provider.auth.is_some());
1400        assert!(
1401            matches!(provider.auth, Some(BedrockAuth::BearerToken(ref t)) if t == "test-api-key")
1402        );
1403    }
1404
1405    #[test]
1406    fn bearer_token_from_env() {
1407        let _guard = EnvGuard::set("BEDROCK_API_KEY", Some("env-bearer-token"));
1408        // Clear SigV4 vars to ensure Bearer is chosen.
1409        let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", None);
1410        let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", None);
1411
1412        let provider = BedrockProvider::new();
1413        assert!(matches!(
1414            provider.auth,
1415            Some(BedrockAuth::BearerToken(ref t)) if t == "env-bearer-token"
1416        ));
1417    }
1418
1419    #[test]
1420    fn bearer_token_precedence() {
1421        let _bearer_guard = EnvGuard::set("BEDROCK_API_KEY", Some("bearer-key"));
1422        let _ak_guard = EnvGuard::set("AWS_ACCESS_KEY_ID", Some("AKIAEXAMPLE"));
1423        let _sk_guard = EnvGuard::set("AWS_SECRET_ACCESS_KEY", Some("secret"));
1424
1425        let provider = BedrockProvider::new();
1426        // Bearer token should take priority over SigV4 credentials.
1427        assert!(matches!(
1428            provider.auth,
1429            Some(BedrockAuth::BearerToken(ref t)) if t == "bearer-key"
1430        ));
1431    }
1432
1433    // ── Endpoint URL tests ──────────────────────────────────────
1434
1435    #[test]
1436    fn endpoint_url_formats_correctly() {
1437        let url = BedrockProvider::endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6");
1438        assert_eq!(
1439            url,
1440            "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse"
1441        );
1442    }
1443
1444    #[test]
1445    fn endpoint_url_keeps_raw_colon() {
1446        // Endpoint URL uses raw colon so reqwest sends `:` on the wire.
1447        let url =
1448            BedrockProvider::endpoint_url("us-west-2", "anthropic.claude-3-5-haiku-20241022-v1:0");
1449        assert!(url.contains("/model/anthropic.claude-3-5-haiku-20241022-v1:0/converse"));
1450    }
1451
1452    #[test]
1453    fn canonical_uri_encodes_colon() {
1454        // Canonical URI must encode `:` as `%3A` for SigV4 signing.
1455        let uri = BedrockProvider::canonical_uri("anthropic.claude-3-5-haiku-20241022-v1:0");
1456        assert_eq!(
1457            uri,
1458            "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse"
1459        );
1460    }
1461
1462    #[test]
1463    fn canonical_uri_no_colon_unchanged() {
1464        let uri = BedrockProvider::canonical_uri("anthropic.claude-sonnet-4-6");
1465        assert_eq!(uri, "/model/anthropic.claude-sonnet-4-6/converse");
1466    }
1467
1468    // ── Message conversion tests ────────────────────────────────
1469
1470    #[test]
1471    fn convert_messages_system_extracted() {
1472        let messages = vec![
1473            ChatMessage::system("You are helpful"),
1474            ChatMessage::user("Hello"),
1475        ];
1476        let (system, msgs) = BedrockProvider::convert_messages(&messages);
1477        assert!(system.is_some());
1478        let system_blocks = system.unwrap();
1479        assert_eq!(system_blocks.len(), 1);
1480        assert_eq!(msgs.len(), 1);
1481        assert_eq!(msgs[0].role, "user");
1482    }
1483
1484    #[test]
1485    fn convert_messages_user_and_assistant() {
1486        let messages = vec![
1487            ChatMessage::user("Hello"),
1488            ChatMessage::assistant("Hi there"),
1489        ];
1490        let (system, msgs) = BedrockProvider::convert_messages(&messages);
1491        assert!(system.is_none());
1492        assert_eq!(msgs.len(), 2);
1493        assert_eq!(msgs[0].role, "user");
1494        assert_eq!(msgs[1].role, "assistant");
1495    }
1496
1497    #[test]
1498    fn convert_messages_tool_role_to_tool_result() {
1499        let tool_json = r#"{"tool_call_id": "call_123", "content": "Result data"}"#;
1500        let messages = vec![ChatMessage::tool(tool_json)];
1501        let (_, msgs) = BedrockProvider::convert_messages(&messages);
1502        assert_eq!(msgs.len(), 1);
1503        assert_eq!(msgs[0].role, "user");
1504        assert!(matches!(msgs[0].content[0], ContentBlock::ToolResult(_)));
1505    }
1506
1507    #[test]
1508    fn convert_messages_assistant_tool_calls_parsed() {
1509        let tool_call_json = r#"{"content": "Let me check", "tool_calls": [{"id": "call_1", "name": "shell", "arguments": "{\"command\":\"ls\"}"}]}"#;
1510        let messages = vec![ChatMessage::assistant(tool_call_json)];
1511        let (_, msgs) = BedrockProvider::convert_messages(&messages);
1512        assert_eq!(msgs.len(), 1);
1513        assert_eq!(msgs[0].role, "assistant");
1514        assert_eq!(msgs[0].content.len(), 2);
1515        assert!(matches!(msgs[0].content[0], ContentBlock::Text(_)));
1516        assert!(matches!(msgs[0].content[1], ContentBlock::ToolUse(_)));
1517    }
1518
1519    #[test]
1520    fn convert_messages_plain_assistant_text() {
1521        let messages = vec![ChatMessage::assistant("Just text")];
1522        let (_, msgs) = BedrockProvider::convert_messages(&messages);
1523        assert_eq!(msgs.len(), 1);
1524        assert!(matches!(msgs[0].content[0], ContentBlock::Text(_)));
1525    }
1526
1527    // ── Cache tests ─────────────────────────────────────────────
1528
1529    #[test]
1530    fn should_cache_system_small_prompt() {
1531        assert!(!BedrockProvider::should_cache_system("Short prompt"));
1532    }
1533
1534    #[test]
1535    fn should_cache_system_large_prompt() {
1536        let large = "a".repeat(3073);
1537        assert!(BedrockProvider::should_cache_system(&large));
1538    }
1539
1540    #[test]
1541    fn should_cache_system_boundary() {
1542        assert!(!BedrockProvider::should_cache_system(&"a".repeat(3072)));
1543        assert!(BedrockProvider::should_cache_system(&"a".repeat(3073)));
1544    }
1545
1546    #[test]
1547    fn should_cache_conversation_short() {
1548        let messages = vec![
1549            ChatMessage::system("System"),
1550            ChatMessage::user("Hello"),
1551            ChatMessage::assistant("Hi"),
1552        ];
1553        assert!(!BedrockProvider::should_cache_conversation(&messages));
1554    }
1555
1556    #[test]
1557    fn should_cache_conversation_long() {
1558        let mut messages = vec![ChatMessage::system("System")];
1559        for i in 0..5 {
1560            messages.push(ChatMessage {
1561                role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(),
1562                content: format!("Message {i}"),
1563            });
1564        }
1565        assert!(BedrockProvider::should_cache_conversation(&messages));
1566    }
1567
1568    // ── Tool conversion tests ───────────────────────────────────
1569
1570    #[test]
1571    fn convert_tools_to_converse_formats_correctly() {
1572        let tools = vec![ToolSpec {
1573            name: "shell".to_string(),
1574            description: "Run commands".to_string(),
1575            parameters: serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}}),
1576        }];
1577        let config = BedrockProvider::convert_tools_to_converse(Some(&tools));
1578        assert!(config.is_some());
1579        let config = config.unwrap();
1580        assert_eq!(config.tools.len(), 1);
1581        assert_eq!(config.tools[0].tool_spec.name, "shell");
1582    }
1583
1584    #[test]
1585    fn convert_tools_to_converse_empty_returns_none() {
1586        assert!(BedrockProvider::convert_tools_to_converse(Some(&[])).is_none());
1587        assert!(BedrockProvider::convert_tools_to_converse(None).is_none());
1588    }
1589
1590    // ── Serde tests ─────────────────────────────────────────────
1591
1592    #[test]
1593    fn converse_request_serializes_without_system() {
1594        let req = ConverseRequest {
1595            system: None,
1596            messages: vec![ConverseMessage {
1597                role: "user".to_string(),
1598                content: vec![ContentBlock::Text(TextBlock {
1599                    text: "Hello".to_string(),
1600                })],
1601            }],
1602            inference_config: Some(InferenceConfig {
1603                max_tokens: 4096,
1604                temperature: 0.7,
1605            }),
1606            tool_config: None,
1607        };
1608        let json = serde_json::to_string(&req).unwrap();
1609        assert!(!json.contains("system"));
1610        assert!(json.contains("Hello"));
1611        assert!(json.contains("maxTokens"));
1612    }
1613
1614    #[test]
1615    fn converse_response_deserializes_text() {
1616        let json = r#"{
1617            "output": {
1618                "message": {
1619                    "role": "assistant",
1620                    "content": [{"text": "Hello from Bedrock"}]
1621                }
1622            },
1623            "stopReason": "end_turn"
1624        }"#;
1625        let resp: ConverseResponse = serde_json::from_str(json).unwrap();
1626        let parsed = BedrockProvider::parse_converse_response(resp);
1627        assert_eq!(parsed.text.as_deref(), Some("Hello from Bedrock"));
1628        assert!(parsed.tool_calls.is_empty());
1629    }
1630
1631    #[test]
1632    fn converse_response_deserializes_tool_use() {
1633        let json = r#"{
1634            "output": {
1635                "message": {
1636                    "role": "assistant",
1637                    "content": [
1638                        {"toolUse": {"toolUseId": "call_1", "name": "shell", "input": {"command": "ls"}}}
1639                    ]
1640                }
1641            },
1642            "stopReason": "tool_use"
1643        }"#;
1644        let resp: ConverseResponse = serde_json::from_str(json).unwrap();
1645        let parsed = BedrockProvider::parse_converse_response(resp);
1646        assert!(parsed.text.is_none());
1647        assert_eq!(parsed.tool_calls.len(), 1);
1648        assert_eq!(parsed.tool_calls[0].name, "shell");
1649        assert_eq!(parsed.tool_calls[0].id, "call_1");
1650    }
1651
1652    #[test]
1653    fn converse_response_empty_output() {
1654        let json = r#"{"output": null, "stopReason": null}"#;
1655        let resp: ConverseResponse = serde_json::from_str(json).unwrap();
1656        let parsed = BedrockProvider::parse_converse_response(resp);
1657        assert!(parsed.text.is_none());
1658        assert!(parsed.tool_calls.is_empty());
1659    }
1660
1661    #[test]
1662    fn content_block_text_serializes_as_flat_string() {
1663        let block = ContentBlock::Text(TextBlock {
1664            text: "Hello".to_string(),
1665        });
1666        let json = serde_json::to_string(&block).unwrap();
1667        // Must be {"text":"Hello"}, NOT {"text":{"text":"Hello"}}
1668        assert_eq!(json, r#"{"text":"Hello"}"#);
1669    }
1670
1671    #[test]
1672    fn content_block_tool_use_serializes_with_nested_object() {
1673        let block = ContentBlock::ToolUse(ToolUseWrapper {
1674            tool_use: ToolUseBlock {
1675                tool_use_id: "call_1".to_string(),
1676                name: "shell".to_string(),
1677                input: serde_json::json!({"command": "ls"}),
1678            },
1679        });
1680        let json = serde_json::to_string(&block).unwrap();
1681        assert!(json.contains(r#""toolUse""#));
1682        assert!(json.contains(r#""toolUseId":"call_1""#));
1683    }
1684
1685    #[test]
1686    fn content_block_cache_point_serializes() {
1687        let block = ContentBlock::CachePointBlock(CachePointWrapper {
1688            cache_point: CachePoint::default_cache(),
1689        });
1690        let json = serde_json::to_string(&block).unwrap();
1691        assert_eq!(json, r#"{"cachePoint":{"type":"default"}}"#);
1692    }
1693
1694    #[test]
1695    fn content_block_text_round_trips() {
1696        let original = ContentBlock::Text(TextBlock {
1697            text: "Hello".to_string(),
1698        });
1699        let json = serde_json::to_string(&original).unwrap();
1700        let deserialized: ContentBlock = serde_json::from_str(&json).unwrap();
1701        assert!(matches!(deserialized, ContentBlock::Text(tb) if tb.text == "Hello"));
1702    }
1703
1704    #[test]
1705    fn cache_point_serializes() {
1706        let cp = CachePoint::default_cache();
1707        let json = serde_json::to_string(&cp).unwrap();
1708        assert_eq!(json, r#"{"type":"default"}"#);
1709    }
1710
1711    #[tokio::test]
1712    async fn warmup_without_credentials_is_noop() {
1713        let provider = BedrockProvider {
1714            auth: None,
1715            max_tokens: DEFAULT_MAX_TOKENS,
1716        };
1717        let result = provider.warmup().await;
1718        assert!(result.is_ok());
1719    }
1720
1721    #[test]
1722    fn capabilities_reports_native_tool_calling() {
1723        let provider = BedrockProvider {
1724            auth: None,
1725            max_tokens: DEFAULT_MAX_TOKENS,
1726        };
1727        let caps = provider.capabilities();
1728        assert!(caps.native_tool_calling);
1729    }
1730
1731    #[test]
1732    fn converse_response_parses_usage() {
1733        let json = r#"{
1734            "output": {"message": {"role": "assistant", "content": [{"text": {"text": "Hello"}}]}},
1735            "usage": {"inputTokens": 500, "outputTokens": 100}
1736        }"#;
1737        let resp: ConverseResponse = serde_json::from_str(json).unwrap();
1738        let usage = resp.usage.unwrap();
1739        assert_eq!(usage.input_tokens, Some(500));
1740        assert_eq!(usage.output_tokens, Some(100));
1741    }
1742
1743    #[test]
1744    fn converse_response_parses_without_usage() {
1745        let json = r#"{"output": {"message": {"role": "assistant", "content": []}}}"#;
1746        let resp: ConverseResponse = serde_json::from_str(json).unwrap();
1747        assert!(resp.usage.is_none());
1748    }
1749
1750    // ── Tool result fallback & merge tests ───────────────────────
1751
1752    #[test]
1753    fn fallback_tool_result_emits_tool_result_block_not_text() {
1754        // When tool message content is not valid JSON, we should still get
1755        // a toolResult block (not a plain text user message).
1756        let messages = vec![
1757            ChatMessage::user("do something"),
1758            ChatMessage::assistant(
1759                r#"{"content":"","tool_calls":[{"id":"tool_1","name":"shell","arguments":"{}"}]}"#,
1760            ),
1761            ChatMessage {
1762                role: "tool".to_string(),
1763                content: "not valid json".to_string(),
1764            },
1765        ];
1766        let (_, msgs) = BedrockProvider::convert_messages(&messages);
1767        let tool_msg = &msgs[2];
1768        assert_eq!(tool_msg.role, "user");
1769        assert!(
1770            matches!(&tool_msg.content[0], ContentBlock::ToolResult(_)),
1771            "Expected ToolResult block, got {:?}",
1772            tool_msg.content[0]
1773        );
1774    }
1775
1776    #[test]
1777    fn fallback_recovers_tool_use_id_from_assistant() {
1778        let messages = vec![
1779            ChatMessage::user("run it"),
1780            ChatMessage::assistant(
1781                r#"{"content":"","tool_calls":[{"id":"tool_abc","name":"shell","arguments":"{}"}]}"#,
1782            ),
1783            ChatMessage {
1784                role: "tool".to_string(),
1785                content: "raw output with no json".to_string(),
1786            },
1787        ];
1788        let (_, msgs) = BedrockProvider::convert_messages(&messages);
1789        if let ContentBlock::ToolResult(ref wrapper) = msgs[2].content[0] {
1790            assert_eq!(wrapper.tool_result.tool_use_id, "tool_abc");
1791            assert_eq!(wrapper.tool_result.status, "error");
1792        } else {
1793            panic!("Expected ToolResult block");
1794        }
1795    }
1796
1797    #[test]
1798    fn consecutive_tool_results_merged_into_single_message() {
1799        let messages = vec![
1800            ChatMessage::user("do two things"),
1801            ChatMessage::assistant(
1802                r#"{"content":"","tool_calls":[{"id":"t1","name":"a","arguments":"{}"},{"id":"t2","name":"b","arguments":"{}"}]}"#,
1803            ),
1804            ChatMessage::tool(r#"{"tool_call_id":"t1","content":"result 1"}"#),
1805            ChatMessage::tool(r#"{"tool_call_id":"t2","content":"result 2"}"#),
1806        ];
1807        let (_, msgs) = BedrockProvider::convert_messages(&messages);
1808        // Should be: user, assistant, user (merged tool results)
1809        assert_eq!(msgs.len(), 3, "Expected 3 messages, got {}", msgs.len());
1810        assert_eq!(msgs[2].role, "user");
1811        assert_eq!(
1812            msgs[2].content.len(),
1813            2,
1814            "Expected 2 tool results in one message"
1815        );
1816        assert!(matches!(&msgs[2].content[0], ContentBlock::ToolResult(_)));
1817        assert!(matches!(&msgs[2].content[1], ContentBlock::ToolResult(_)));
1818    }
1819
1820    #[test]
1821    fn extract_tool_call_id_tries_multiple_field_names() {
1822        assert_eq!(
1823            BedrockProvider::extract_tool_call_id(r#"{"tool_call_id":"a"}"#),
1824            Some("a".to_string())
1825        );
1826        assert_eq!(
1827            BedrockProvider::extract_tool_call_id(r#"{"tool_use_id":"b"}"#),
1828            Some("b".to_string())
1829        );
1830        assert_eq!(
1831            BedrockProvider::extract_tool_call_id(r#"{"toolUseId":"c"}"#),
1832            Some("c".to_string())
1833        );
1834        assert_eq!(
1835            BedrockProvider::extract_tool_call_id("not json at all"),
1836            None
1837        );
1838    }
1839
1840    #[test]
1841    fn parse_tool_result_accepts_alternate_id_fields() {
1842        let msg =
1843            BedrockProvider::parse_tool_result_message(r#"{"tool_use_id":"x","content":"ok"}"#);
1844        assert!(msg.is_some());
1845        if let ContentBlock::ToolResult(ref wrapper) = msg.unwrap().content[0] {
1846            assert_eq!(wrapper.tool_result.tool_use_id, "x");
1847        } else {
1848            panic!("Expected ToolResult");
1849        }
1850    }
1851}