car-inference 0.26.0

Local model inference for CAR — Candle backend with Qwen3 models
//! Minimal AWS Signature Version 4 signer for the Bedrock Converse protocol.
//!
//! Hand-rolled over `hmac`/`sha2` — deliberately NOT the ~200-crate AWS SDK.
//! Covers the header-based signing `bedrock-runtime` needs: a POST with a JSON
//! body, region/service scope, and an optional session token. Validated against
//! AWS's published SigV4 `get-vanilla` test vector (see tests).

use hmac::{Hmac, Mac};
use sha2::{Digest, Sha256};

/// HMAC-SHA256: the MAC applied at every step of the SigV4 signing-key
/// derivation and to produce the final request signature.
type HmacSha256 = Hmac<Sha256>;

/// AWS credentials for signing. `session_token` is set for temporary
/// (STS / role) credentials; when present, [`signed_headers`] sends and signs
/// it as `x-amz-security-token`.
///
/// `Debug` is implemented by hand (not derived) so the secret key and the
/// session token are never written to logs.
#[derive(Clone)]
pub struct AwsCredentials {
    /// Access key id; appears in clear in the `Credential=` part of the
    /// `Authorization` header.
    pub access_key_id: String,
    /// Secret key; used only as HMAC key material, never sent.
    pub secret_access_key: String,
    /// Session token for temporary credentials, if any.
    pub session_token: Option<String>,
}

impl std::fmt::Debug for AwsCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AwsCredentials")
            .field("access_key_id", &self.access_key_id)
            .field("secret_access_key", &"<redacted>")
            // Keep the Some/None distinction (useful when debugging) but hide the value.
            .field(
                "session_token",
                &self.session_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}

impl AwsCredentials {
    /// Resolve from the standard AWS environment variables
    /// (`AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, optional
    /// `AWS_SESSION_TOKEN`).
    ///
    /// Returns `None` when either half of the required key pair is unset,
    /// empty, or not valid Unicode, so callers can surface a clear
    /// "configure AWS credentials" error instead of signing with empties.
    /// An unset/empty/non-Unicode `AWS_SESSION_TOKEN` is simply treated as
    /// absent.
    pub fn from_env() -> Option<Self> {
        Some(Self {
            access_key_id: Self::non_empty_env("AWS_ACCESS_KEY_ID")?,
            secret_access_key: Self::non_empty_env("AWS_SECRET_ACCESS_KEY")?,
            session_token: Self::non_empty_env("AWS_SESSION_TOKEN"),
        })
    }

    /// Value of environment variable `name`, or `None` if it is unset, empty,
    /// or not valid Unicode.
    fn non_empty_env(name: &str) -> Option<String> {
        std::env::var(name).ok().filter(|value| !value.is_empty())
    }
}

/// One HMAC-SHA256 application: authenticate `data` under `key` and return
/// the raw tag bytes. Chained for the SigV4 signing-key derivation and used
/// once more for the final signature.
fn hmac(key: &[u8], data: &[u8]) -> Vec<u8> {
    // HMAC hashes over-long keys and pads short ones, so any length is valid.
    let mut authenticator =
        HmacSha256::new_from_slice(key).expect("HMAC accepts any key length");
    authenticator.update(data);
    let tag = authenticator.finalize().into_bytes();
    tag.iter().copied().collect()
}

/// Lowercase hex encoding of the SHA-256 digest of `data`; SigV4 uses this
/// form for both the payload hash and the canonical-request hash.
fn sha256_hex(data: &[u8]) -> String {
    let digest = Sha256::digest(data);
    hex::encode(digest)
}

/// Percent-encode a single URI path segment using the RFC 3986 unreserved
/// set, as SigV4 prescribes.
///
/// ASCII letters, digits and `-` `_` `.` `~` are copied through unchanged.
/// Every other byte of the UTF-8 input is written as `%XX` with uppercase hex
/// digits. That includes `/`, `%`, each byte of a multi-byte character, and
/// notably `:`. Bedrock model IDs contain `:` (e.g.
/// `anthropic.claude-...-v2:0`), which this escapes as `%3A`.
///
/// Note: for every service except S3, the SigV4 canonical request
/// URI-encodes each path segment once more. The canonical form of an
/// already-encoded `v2%3A0` segment is therefore `v2%253A0`, not `v2%3A0`.
pub fn uri_encode_segment(s: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    s.bytes()
        .fold(String::with_capacity(s.len()), |mut encoded, byte| {
            let unreserved =
                byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
            if unreserved {
                encoded.push(char::from(byte));
            } else {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
            encoded
        })
}

/// Compute the SigV4 header set for a request. `canonical_path` is the
/// already-percent-encoded path (e.g. `/model/<encoded id>/converse`).
/// `base_headers` must include `host`. Returns the full header set to SEND:
/// the base headers plus `x-amz-date`, `authorization`, and (when a session
/// token is present) `x-amz-security-token`. `amz_date` is `YYYYMMDDTHHMMSSZ`;
/// `date_stamp` is `YYYYMMDD` (the date portion of the same instant).
#[allow(clippy::too_many_arguments)]
pub fn signed_headers(
    creds: &AwsCredentials,
    region: &str,
    service: &str,
    method: &str,
    canonical_path: &str,
    canonical_query: &str,
    base_headers: &[(String, String)],
    body: &[u8],
    amz_date: &str,
    date_stamp: &str,
) -> Vec<(String, String)> {
    // Header set that is both signed and sent.
    let mut headers: Vec<(String, String)> = base_headers.to_vec();
    headers.push(("x-amz-date".into(), amz_date.to_string()));
    if let Some(tok) = &creds.session_token {
        headers.push(("x-amz-security-token".into(), tok.clone()));
    }

    // Canonical headers: lowercase name, trimmed value, sorted by name, each
    // rendered "name:value\n". signed_headers: the same names, ';'-joined.
    let mut canon: Vec<(String, String)> = headers
        .iter()
        .map(|(k, v)| (k.to_lowercase(), v.trim().to_string()))
        .collect();
    canon.sort_by(|a, b| a.0.cmp(&b.0));
    let canonical_headers: String = canon.iter().map(|(k, v)| format!("{k}:{v}\n")).collect();
    let signed_header_names = canon
        .iter()
        .map(|(k, _)| k.clone())
        .collect::<Vec<_>>()
        .join(";");

    let payload_hash = sha256_hex(body);

    // Note the blank line: `canonical_headers` already ends in '\n', then the
    // spec adds another before `signed_header_names`.
    let canonical_request = format!(
        "{method}\n{canonical_path}\n{canonical_query}\n{canonical_headers}\n{signed_header_names}\n{payload_hash}"
    );

    let scope = format!("{date_stamp}/{region}/{service}/aws4_request");
    let string_to_sign = format!(
        "AWS4-HMAC-SHA256\n{amz_date}\n{scope}\n{}",
        sha256_hex(canonical_request.as_bytes())
    );

    let k_date = hmac(
        format!("AWS4{}", creds.secret_access_key).as_bytes(),
        date_stamp.as_bytes(),
    );
    let k_region = hmac(&k_date, region.as_bytes());
    let k_service = hmac(&k_region, service.as_bytes());
    let k_signing = hmac(&k_service, b"aws4_request");
    let signature = hex::encode(hmac(&k_signing, string_to_sign.as_bytes()));

    let authorization = format!(
        "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
        creds.access_key_id, scope, signed_header_names, signature
    );
    headers.push(("authorization".into(), authorization));
    headers
}

#[cfg(test)]
mod tests {
    use super::*;

    /// AWS's published SigV4 `get-vanilla` test vector. If this matches, the
    /// canonical-request / string-to-sign / signing-key / signature chain is
    /// correct — independent of Bedrock.
    #[test]
    fn matches_aws_get_vanilla_vector() {
        let creds = AwsCredentials {
            access_key_id: "AKIDEXAMPLE".into(),
            secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".into(),
            session_token: None,
        };
        let base = vec![("host".to_string(), "example.amazonaws.com".to_string())];
        let headers = signed_headers(
            &creds,
            "us-east-1",
            "service",
            "GET",
            "/",
            "",
            &base,
            b"",
            "20150830T123600Z",
            "20150830",
        );
        let auth = headers
            .iter()
            .find(|(k, _)| k == "authorization")
            .map(|(_, v)| v.as_str())
            .unwrap();
        assert_eq!(
            auth,
            "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \
             SignedHeaders=host;x-amz-date, \
             Signature=5fa00fa31553b73ebf1942676e86291e8372ff2a2260956d9b8aae1d763fbf31"
        );
    }

    #[test]
    fn uri_encode_handles_bedrock_model_ids() {
        // The `:` in a Bedrock model id must encode to %3A.
        assert_eq!(
            uri_encode_segment("anthropic.claude-3-5-sonnet-20241022-v2:0"),
            "anthropic.claude-3-5-sonnet-20241022-v2%3A0"
        );
    }

    #[test]
    fn uri_encode_escapes_everything_outside_the_unreserved_set() {
        // Unreserved characters pass through untouched...
        assert_eq!(uri_encode_segment("AZaz09-_.~"), "AZaz09-_.~");
        // ...reserved ASCII (including '/' and '%') becomes uppercase %XX...
        assert_eq!(uri_encode_segment("a b/c%"), "a%20b%2Fc%25");
        // ...and so does every byte of a multi-byte UTF-8 character.
        assert_eq!(uri_encode_segment("\u{fc}"), "%C3%BC");
        assert_eq!(uri_encode_segment(""), "");
    }

    #[test]
    fn session_token_is_added_to_signed_headers() {
        let creds = AwsCredentials {
            access_key_id: "AKID".into(),
            secret_access_key: "secret".into(),
            session_token: Some("token123".into()),
        };
        let base = vec![("host".to_string(), "h".to_string())];
        let headers = signed_headers(
            &creds,
            "us-east-1",
            "bedrock",
            "POST",
            "/",
            "",
            &base,
            b"{}",
            "20150830T123600Z",
            "20150830",
        );
        assert!(headers
            .iter()
            .any(|(k, v)| k == "x-amz-security-token" && v == "token123"));
        // ...and it's covered by SignedHeaders, in sorted position.
        let auth = headers.iter().find(|(k, _)| k == "authorization").unwrap();
        assert!(auth
            .1
            .contains("SignedHeaders=host;x-amz-date;x-amz-security-token,"));
    }

    #[test]
    fn sent_headers_are_base_then_date_then_authorization() {
        let creds = AwsCredentials {
            access_key_id: "AKID".into(),
            secret_access_key: "secret".into(),
            session_token: None,
        };
        let base = vec![("host".to_string(), "h".to_string())];
        let headers = signed_headers(
            &creds,
            "us-east-1",
            "bedrock",
            "POST",
            "/",
            "",
            &base,
            b"{}",
            "20150830T123600Z",
            "20150830",
        );
        let names: Vec<&str> = headers.iter().map(|(k, _)| k.as_str()).collect();
        assert_eq!(names, ["host", "x-amz-date", "authorization"]);
        assert!(headers
            .iter()
            .any(|(k, v)| k == "x-amz-date" && v == "20150830T123600Z"));
    }
}