bitwarden-crypto 3.0.0

Internal crate for the bitwarden crate. Do not use.
Documentation
use std::fmt::DebugStruct;

use ciborium::Value;

use crate::{
    KEY_ID_SIZE,
    cose::{
        CONTAINED_KEY_ID, ContentNamespace, SAFE_CONTENT_NAMESPACE, SAFE_OBJECT_NAMESPACE,
        SafeObjectNamespace, extract_bytes, extract_integer,
    },
    keys::KeyId,
};

/// Reason why a safe namespace could not be taken from (or validated against) a COSE header.
///
/// The type is a plain fieldless enum, so it is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(super) enum ExtractionError {
    /// No usable namespace entry was found in the header.
    MissingNamespace,
    /// A namespace was found but is not acceptable (e.g. an unknown value, or it does not
    /// match the expected namespace during validation).
    InvalidNamespace,
}

/// Returns `true` if `header.rest` holds an entry under the integer label `label`,
/// whatever the CBOR type of its value.
fn header_contains_int_label(header: &coset::Header, label: i64) -> bool {
    header.rest.iter().any(|(existing_label, _)| {
        matches!(existing_label, coset::Label::Int(existing) if *existing == label)
    })
}

/// Classifies a failed integer extraction for `label`.
///
/// An entry that exists but could not be read as an integer is a malformed namespace
/// (`InvalidNamespace`). Only a genuinely absent entry counts as `MissingNamespace`, which
/// callers tolerate for backward compatibility. Without this distinction, a non-integer
/// value would silently bypass namespace validation.
fn namespace_extraction_failure(header: &coset::Header, label: i64) -> ExtractionError {
    if header_contains_int_label(header, label) {
        ExtractionError::InvalidNamespace
    } else {
        ExtractionError::MissingNamespace
    }
}

/// Extracts the [`SafeObjectNamespace`] stored under the `SAFE_OBJECT_NAMESPACE` label.
///
/// # Errors
/// - [`ExtractionError::MissingNamespace`] if the value cannot be extracted and `header.rest`
///   has no entry for the label.
/// - [`ExtractionError::InvalidNamespace`] if an entry for the label exists but is not an
///   integer, or if the integer does not convert to a known [`SafeObjectNamespace`].
pub(super) fn extract_safe_object_namespace(
    header: &coset::Header,
) -> Result<SafeObjectNamespace, ExtractionError> {
    let value = extract_integer(header, SAFE_OBJECT_NAMESPACE, "safe object namespace")
        .map_err(|_| namespace_extraction_failure(header, SAFE_OBJECT_NAMESPACE))?;
    value
        .try_into()
        .map_err(|_| ExtractionError::InvalidNamespace)
}

/// Extracts the content namespace `T` stored under the `SAFE_CONTENT_NAMESPACE` label.
///
/// # Errors
/// - [`ExtractionError::MissingNamespace`] if the value cannot be extracted and `header.rest`
///   has no entry for the label.
/// - [`ExtractionError::InvalidNamespace`] if an entry for the label exists but is not an
///   integer, or if the integer does not convert to a `T`.
pub(super) fn extract_safe_content_namespace<T: ContentNamespace>(
    header: &coset::Header,
) -> Result<T, ExtractionError> {
    let value = extract_integer(header, SAFE_CONTENT_NAMESPACE, "safe content namespace")
        .map_err(|_| namespace_extraction_failure(header, SAFE_CONTENT_NAMESPACE))?;
    value
        .try_into()
        .map_err(|_| ExtractionError::InvalidNamespace)
}

/// Adds the object and content namespaces found in `header` as fields of `debug_struct`.
///
/// Each namespace is emitted only if it can be extracted. A missing or undecodable namespace
/// is simply left out of the debug output rather than reported.
pub(super) fn debug_fmt<C: ContentNamespace>(
    debug_struct: &mut DebugStruct,
    header: &coset::Header,
) {
    let object_namespace = extract_safe_object_namespace(header).ok();
    let content_namespace = extract_safe_content_namespace::<C>(header).ok();

    if let Some(namespace) = &object_namespace {
        debug_struct.field("object_namespace", namespace);
    }
    if let Some(namespace) = &content_namespace {
        debug_struct.field("content_namespace", namespace);
    }
}

/// Stores `value` under the integer label `label` in `header.rest`.
///
/// If an entry with that label already exists, the first such entry has its value replaced
/// in place. Otherwise a new `(label, value)` pair is appended.
fn set_header_value(header: &mut coset::Header, label: i64, value: Value) {
    let slot = header.rest.iter().position(|(existing_label, _)| {
        matches!(existing_label, coset::Label::Int(existing) if *existing == label)
    });

    match slot {
        Some(index) => header.rest[index].1 = value,
        None => header.rest.push((coset::Label::Int(label), value)),
    }
}

/// Writes both safe namespace layers into `header`.
///
/// The object namespace goes under `SAFE_OBJECT_NAMESPACE` and the content namespace under
/// `SAFE_CONTENT_NAMESPACE`, both as CBOR integers. Existing entries for either label are
/// overwritten rather than duplicated.
pub(super) fn set_safe_namespaces<T: ContentNamespace>(
    header: &mut coset::Header,
    object_namespace: SafeObjectNamespace,
    content_namespace: T,
) {
    let entries = [
        (
            SAFE_OBJECT_NAMESPACE,
            Value::from(i128::from(object_namespace)),
        ),
        (
            SAFE_CONTENT_NAMESPACE,
            Value::from(content_namespace.into()),
        ),
    ];

    for (label, value) in entries {
        set_header_value(header, label, value);
    }
}

/// Checks that `header` carries the expected object and content namespaces.
///
/// Each namespace layer is checked on its own:
/// - absent (`MissingNamespace`): accepted, for backward compatibility with headers written
///   before namespaces existed;
/// - present and equal to the expected value: accepted;
/// - present but different, or reported as `InvalidNamespace` by the extractor: rejected
///   with [`ExtractionError::InvalidNamespace`].
///
/// Either layer may be missing while the other is present. The object layer is checked
/// first; if it is rejected, the content layer is not inspected.
pub(super) fn validate_safe_namespaces<T: ContentNamespace>(
    header: &coset::Header,
    expected_object_namespace: SafeObjectNamespace,
    expected_content_namespace: T,
) -> Result<(), ExtractionError> {
    /// Applies the per-layer acceptance rule to one extraction result.
    fn check_layer<N: PartialEq>(
        extracted: Result<N, ExtractionError>,
        expected: N,
    ) -> Result<(), ExtractionError> {
        match extracted {
            Ok(found) if found == expected => Ok(()),
            // Tolerated so that legacy headers without namespaces keep validating.
            Err(ExtractionError::MissingNamespace) => Ok(()),
            // Present-but-wrong and present-but-undecodable are both hard failures.
            Ok(_) | Err(ExtractionError::InvalidNamespace) => {
                Err(ExtractionError::InvalidNamespace)
            }
        }
    }

    check_layer(
        extract_safe_object_namespace(header),
        expected_object_namespace,
    )?;
    check_layer(
        extract_safe_content_namespace::<T>(header),
        expected_content_namespace,
    )
}

/// Reads the contained key ID from `header`.
///
/// Returns:
/// - `Ok(None)` when `extract_bytes` fails for the `CONTAINED_KEY_ID` label;
/// - `Ok(Some(key_id))` when the extracted bytes are exactly `KEY_ID_SIZE` long;
/// - `Err(())` when bytes were extracted but have any other length.
pub(super) fn extract_contained_key_id(header: &coset::Header) -> Result<Option<KeyId>, ()> {
    match extract_bytes(header, CONTAINED_KEY_ID, "key id") {
        Err(_) => Ok(None),
        Ok(raw) => {
            let fixed_size: Result<[u8; KEY_ID_SIZE], _> = raw.as_slice().try_into();
            fixed_size
                .map(|array| Some(KeyId::from(array)))
                .map_err(|_| ())
        }
    }
}

#[cfg(test)]
mod tests {
    use ciborium::Value;

    use super::*;
    use crate::{cose::SAFE_OBJECT_NAMESPACE, safe::DataEnvelopeNamespace};

    /// Counts how many entries in `header.rest` use the integer label `label`.
    fn count_label(header: &coset::Header, label: i64) -> usize {
        header
            .rest
            .iter()
            .filter(|(existing_label, _)| {
                matches!(existing_label, coset::Label::Int(existing) if *existing == label)
            })
            .count()
    }

    /// Extracts both namespace layers, failing if either one cannot be extracted.
    fn extract_safe_namespaces<T: ContentNamespace>(
        header: &coset::Header,
    ) -> Result<(SafeObjectNamespace, T), ExtractionError> {
        let object_namespace = extract_safe_object_namespace(header)?;
        let content_namespace = extract_safe_content_namespace(header)?;

        Ok((object_namespace, content_namespace))
    }

    #[test]
    fn set_safe_namespaces_sets_both_namespace_labels() {
        let mut header = coset::HeaderBuilder::new().build();

        set_safe_namespaces(
            &mut header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );

        let extracted = extract_safe_namespaces::<DataEnvelopeNamespace>(&header);
        assert!(matches!(
            extracted,
            Ok((
                SafeObjectNamespace::DataEnvelope,
                DataEnvelopeNamespace::ExampleNamespace
            ))
        ));
    }

    #[test]
    fn set_safe_namespaces_overwrites_existing_namespace_values() {
        let mut header = coset::HeaderBuilder::new()
            .value(SAFE_OBJECT_NAMESPACE, Value::from(999_i64))
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        set_safe_namespaces(
            &mut header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );

        assert_eq!(count_label(&header, SAFE_OBJECT_NAMESPACE), 1);
        assert_eq!(count_label(&header, SAFE_CONTENT_NAMESPACE), 1);
        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(&header),
            Ok((
                SafeObjectNamespace::DataEnvelope,
                DataEnvelopeNamespace::ExampleNamespace
            ))
        ));
    }

    #[test]
    fn extract_safe_namespaces_fails_when_namespace_missing() {
        let header = coset::HeaderBuilder::new().build();

        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(&header),
            Err(ExtractionError::MissingNamespace)
        ));
    }

    #[test]
    fn extract_safe_namespaces_fails_when_namespace_invalid() {
        let header = coset::HeaderBuilder::new()
            .value(
                SAFE_OBJECT_NAMESPACE,
                Value::from(SafeObjectNamespace::DataEnvelope as i64),
            )
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        assert!(matches!(
            extract_safe_namespaces::<DataEnvelopeNamespace>(&header),
            Err(ExtractionError::InvalidNamespace)
        ));
    }

    #[test]
    fn validate_safe_namespaces_allows_missing_labels_for_backwards_compat() {
        let header = coset::HeaderBuilder::new().build();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_accepts_matching_namespaces() {
        let mut header = coset::HeaderBuilder::new().build();
        set_safe_namespaces(
            &mut header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_allows_only_object_namespace_present() {
        // The content layer is absent; per the documented contract this is tolerated.
        let header = coset::HeaderBuilder::new()
            .value(
                SAFE_OBJECT_NAMESPACE,
                Value::from(SafeObjectNamespace::DataEnvelope as i64),
            )
            .build();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(result.is_ok());
    }

    #[test]
    fn validate_safe_namespaces_rejects_unknown_content_namespace_value() {
        // The object layer is absent (tolerated), but the content layer holds an integer
        // that maps to no known namespace, which must be rejected.
        let header = coset::HeaderBuilder::new()
            .value(SAFE_CONTENT_NAMESPACE, Value::from(999_i64))
            .build();

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );
        assert!(matches!(result, Err(ExtractionError::InvalidNamespace)));
    }

    #[test]
    fn validate_safe_namespaces_rejects_namespace_mismatch() {
        let mut header = coset::HeaderBuilder::new().build();
        set_safe_namespaces(
            &mut header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace,
        );

        let result = validate_safe_namespaces(
            &header,
            SafeObjectNamespace::DataEnvelope,
            DataEnvelopeNamespace::ExampleNamespace2,
        );
        assert!(matches!(result, Err(ExtractionError::InvalidNamespace)));
    }
}