anp 0.8.4

Rust SDK for Agent Network Protocol (ANP)
Documentation
use std::collections::BTreeMap;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

use regex::Regex;
use serde_json::Value;

use crate::keys::PrivateKeyMaterial;

use super::did_wba::{
    generate_auth_header, generate_auth_header_with_overrides, AuthenticationError,
};
use super::http_signatures::{generate_http_signature_headers, HttpSignatureOptions};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMode {
    HttpSignatures,
    LegacyDidWba,
    Auto,
}

impl AuthMode {
    pub fn from_str(value: &str) -> Self {
        match value.to_ascii_lowercase().as_str() {
            "legacy_didwba" => Self::LegacyDidWba,
            "auto" => Self::Auto,
            _ => Self::HttpSignatures,
        }
    }
}

#[derive(Debug)]
pub struct DIDWbaAuthHeader {
    did_document_path: PathBuf,
    private_key_path: PathBuf,
    auth_mode: AuthMode,
    did_document_cache: Option<Value>,
    tokens: HashMap<String, String>,
}

impl DIDWbaAuthHeader {
    pub fn new(
        did_document_path: impl AsRef<Path>,
        private_key_path: impl AsRef<Path>,
        auth_mode: AuthMode,
    ) -> Self {
        Self {
            did_document_path: did_document_path.as_ref().to_path_buf(),
            private_key_path: private_key_path.as_ref().to_path_buf(),
            auth_mode,
            did_document_cache: None,
            tokens: HashMap::new(),
        }
    }

    pub fn get_auth_header(
        &mut self,
        server_url: &str,
        force_new: bool,
        method: &str,
        headers: Option<&std::collections::BTreeMap<String, String>>,
        body: Option<&[u8]>,
    ) -> Result<std::collections::BTreeMap<String, String>, AuthenticationError> {
        let domain = extract_domain(server_url);
        if !force_new {
            if let Some(token) = self.tokens.get(&domain) {
                let mut result = std::collections::BTreeMap::new();
                result.insert("Authorization".to_string(), format!("Bearer {}", token));
                return Ok(result);
            }
        }

        let did_document = self.load_did_document()?.clone();
        let private_key = self.load_private_key()?;
        match self.auth_mode {
            AuthMode::HttpSignatures | AuthMode::Auto => generate_http_signature_headers(
                &did_document,
                server_url,
                method,
                &private_key,
                headers,
                body,
                HttpSignatureOptions::default(),
            )
            .map_err(|_| AuthenticationError::SignatureGenerationFailed),
            AuthMode::LegacyDidWba => {
                let value = generate_auth_header(&did_document, &domain, &private_key, "1.1")?;
                let mut result = std::collections::BTreeMap::new();
                result.insert("Authorization".to_string(), value);
                Ok(result)
            }
        }
    }

    pub fn update_token(
        &mut self,
        server_url: &str,
        headers: &std::collections::BTreeMap<String, String>,
    ) -> Option<String> {
        let domain = extract_domain(server_url);
        if let Some(value) = get_header_case_insensitive(headers, "Authentication-Info") {
            let parsed = parse_authentication_info(value);
            if let Some(token) = parsed.get("access_token") {
                self.tokens.insert(domain, token.clone());
                return Some(token.clone());
            }
        }
        if let Some(value) = get_header_case_insensitive(headers, "Authorization") {
            if let Some(token) = value.strip_prefix("Bearer ") {
                let token = token.to_string();
                self.tokens.insert(domain, token.clone());
                return Some(token);
            }
        }
        None
    }

    pub fn clear_token(&mut self, server_url: &str) {
        let domain = extract_domain(server_url);
        self.tokens.remove(&domain);
    }

    pub fn clear_all_tokens(&mut self) {
        self.tokens.clear();
    }

    pub fn should_retry_after_401(&self, response_headers: &BTreeMap<String, String>) -> bool {
        let Some(www_authenticate) =
            get_header_case_insensitive(response_headers, "WWW-Authenticate")
        else {
            return false;
        };
        let challenge = parse_www_authenticate(www_authenticate);
        if challenge.get("nonce").is_some() {
            return true;
        }
        !matches!(
            challenge.get("error").map(|value| value.as_str()),
            Some("invalid_did") | Some("invalid_verification_method") | Some("forbidden_did")
        )
    }

    pub fn get_challenge_auth_header(
        &mut self,
        server_url: &str,
        response_headers: &BTreeMap<String, String>,
        method: &str,
        headers: Option<&BTreeMap<String, String>>,
        body: Option<&[u8]>,
    ) -> Result<BTreeMap<String, String>, AuthenticationError> {
        let www_authenticate = get_header_case_insensitive(response_headers, "WWW-Authenticate");
        let accept_signature = get_header_case_insensitive(response_headers, "Accept-Signature");
        let challenge = www_authenticate
            .map(|value| parse_www_authenticate(value))
            .unwrap_or_default();
        let covered_components = normalize_covered_components(
            accept_signature
                .map(|value| parse_accept_signature(value))
                .as_ref(),
            headers,
            body,
        );
        let nonce = challenge.get("nonce").map(String::as_str);

        let did_document = self.load_did_document()?.clone();
        let private_key = self.load_private_key()?;
        match self.auth_mode {
            AuthMode::HttpSignatures | AuthMode::Auto => generate_http_signature_headers(
                &did_document,
                server_url,
                method,
                &private_key,
                headers,
                body,
                HttpSignatureOptions {
                    nonce: nonce.map(ToOwned::to_owned),
                    covered_components,
                    ..HttpSignatureOptions::default()
                },
            )
            .map_err(|_| AuthenticationError::SignatureGenerationFailed),
            AuthMode::LegacyDidWba => {
                let value = generate_auth_header_with_overrides(
                    &did_document,
                    &extract_domain(server_url),
                    &private_key,
                    "1.1",
                    nonce,
                    None,
                )?;
                let mut result = BTreeMap::new();
                result.insert("Authorization".to_string(), value);
                Ok(result)
            }
        }
    }

    fn load_did_document(&mut self) -> Result<&Value, AuthenticationError> {
        if self.did_document_cache.is_none() {
            let content = fs::read_to_string(&self.did_document_path)
                .map_err(|_| AuthenticationError::IoFailure)?;
            let value =
                serde_json::from_str(&content).map_err(|_| AuthenticationError::JsonFailure)?;
            self.did_document_cache = Some(value);
        }
        self.did_document_cache
            .as_ref()
            .ok_or(AuthenticationError::InvalidDidDocument)
    }

    fn load_private_key(&self) -> Result<PrivateKeyMaterial, AuthenticationError> {
        let content = fs::read_to_string(&self.private_key_path)
            .map_err(|_| AuthenticationError::IoFailure)?;
        PrivateKeyMaterial::from_pem(&content).map_err(|_| AuthenticationError::InvalidDidDocument)
    }
}

fn extract_domain(server_url: &str) -> String {
    url::Url::parse(server_url)
        .ok()
        .and_then(|value| value.host_str().map(|host| host.to_string()))
        .unwrap_or_else(|| server_url.to_string())
}

fn get_header_case_insensitive<'a>(
    headers: &'a std::collections::BTreeMap<String, String>,
    name: &str,
) -> Option<&'a String> {
    headers
        .iter()
        .find(|(key, _)| key.eq_ignore_ascii_case(name))
        .map(|(_, value)| value)
}

fn parse_authentication_info(value: &str) -> HashMap<String, String> {
    value
        .split(',')
        .filter_map(|item| item.trim().split_once('='))
        .map(|(key, raw)| {
            (
                key.trim().to_string(),
                raw.trim().trim_matches('"').to_string(),
            )
        })
        .collect()
}

fn parse_www_authenticate(value: &str) -> HashMap<String, String> {
    let normalized = value
        .trim()
        .strip_prefix("DIDWba ")
        .or_else(|| value.trim().strip_prefix("didwba "))
        .unwrap_or(value.trim());
    let regex = Regex::new(r#"([\w-]+)=("[^"]*"|[^,]+)"#).expect("regex should compile");
    regex
        .captures_iter(normalized)
        .filter_map(|captures| {
            let key = captures.get(1)?.as_str().trim().to_string();
            let value = captures
                .get(2)?
                .as_str()
                .trim()
                .trim_matches('"')
                .to_string();
            Some((key, value))
        })
        .collect()
}

fn parse_accept_signature(value: &str) -> Vec<String> {
    let regex = Regex::new(r#""([^"]+)""#).expect("regex should compile");
    regex
        .captures_iter(value)
        .filter_map(|captures| captures.get(1).map(|matched| matched.as_str().to_string()))
        .collect()
}

fn normalize_covered_components(
    covered_components: Option<&Vec<String>>,
    headers: Option<&BTreeMap<String, String>>,
    body: Option<&[u8]>,
) -> Option<Vec<String>> {
    let covered_components = covered_components?;
    let body_present = body.map(|bytes| !bytes.is_empty()).unwrap_or(false);
    let normalized_headers = headers
        .cloned()
        .unwrap_or_default()
        .into_iter()
        .map(|(key, value)| (key.to_ascii_lowercase(), value))
        .collect::<BTreeMap<_, _>>();

    let mut result = Vec::new();
    for component in covered_components {
        let lower = component.to_ascii_lowercase();
        if lower == "content-digest" && !body_present {
            continue;
        }
        if lower == "content-length"
            && !body_present
            && !normalized_headers.contains_key("content-length")
        {
            continue;
        }
        if lower == "content-type" && !normalized_headers.contains_key("content-type") {
            continue;
        }
        if !lower.starts_with('@')
            && lower != "content-length"
            && lower != "content-digest"
            && !normalized_headers.contains_key(&lower)
        {
            continue;
        }
        result.push(component.clone());
    }
    Some(result)
}