use crate::device::types::DeviceBinding;
use crate::session::crypto::{CryptoError, SessionCrypto};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
/// Codec that turns a list of `DeviceBinding` values into a single string and back.
///
/// `encode` serializes with MessagePack, optionally encrypts the bytes, and
/// base64-encodes the result; `decode` applies the inverse steps.
#[derive(Clone)]
pub(crate) struct BindingsCodec {
/// `Some` routes payload bytes through `SessionCrypto::encrypt` / `decrypt`;
/// `None` leaves the serialized bytes as they are (they are still base64-encoded).
crypto: Option<SessionCrypto>,
}
impl BindingsCodec {
pub fn encrypted(crypto: SessionCrypto) -> Self {
Self {
crypto: Some(crypto),
}
}
pub fn plaintext() -> Self {
Self { crypto: None }
}
pub fn encode(&self, bindings: &[DeviceBinding]) -> Result<String, SqlDeviceStoreError> {
let bytes = rmp_serde::to_vec_named(bindings)?;
let payload = match &self.crypto {
Some(crypto) => crypto.encrypt(&bytes)?,
None => bytes,
};
Ok(BASE64.encode(payload))
}
pub fn decode(&self, stored: &str) -> Result<Vec<DeviceBinding>, SqlDeviceStoreError> {
if stored.is_empty() {
return Ok(Vec::new());
}
let payload = BASE64
.decode(stored)
.map_err(|_| SqlDeviceStoreError::Crypto(CryptoError))?;
let plaintext = match &self.crypto {
Some(crypto) => crypto.decrypt(&payload)?,
None => payload,
};
Ok(rmp_serde::from_slice(&plaintext)?)
}
}
/// Conversions between `DeviceTrustLevel` and its string representation.
pub(crate) mod trust_level_codec {
    use crate::device::types::DeviceTrustLevel;

    /// Every variant that `to_str` names; `from_str` searches this list.
    const KNOWN_LEVELS: [DeviceTrustLevel; 4] = [
        DeviceTrustLevel::Unknown,
        DeviceTrustLevel::Seen,
        DeviceTrustLevel::Trusted,
        DeviceTrustLevel::Revoked,
    ];

    /// Returns the fixed string name of `level`.
    ///
    /// The match is exhaustive on purpose: adding a variant to
    /// `DeviceTrustLevel` fails to compile until a name is chosen here.
    pub fn to_str(level: DeviceTrustLevel) -> &'static str {
        match level {
            DeviceTrustLevel::Revoked => "Revoked",
            DeviceTrustLevel::Trusted => "Trusted",
            DeviceTrustLevel::Seen => "Seen",
            DeviceTrustLevel::Unknown => "Unknown",
        }
    }

    /// Parses a name produced by [`to_str`] back into its variant.
    ///
    /// Matching is exact and case-sensitive; any other input yields `None`.
    pub fn from_str(s: &str) -> Option<DeviceTrustLevel> {
        KNOWN_LEVELS
            .iter()
            .copied()
            .find(|&candidate| to_str(candidate) == s)
    }
}
/// Error type returned by `BindingsCodec::encode` / `BindingsCodec::decode`.
///
/// Variants marked `#[from]` are created implicitly by `?` from the wrapped
/// error type.
// NOTE(review): the `Db`, `UnknownTrustLevel` and `MalformedRow` variants are
// not constructed in this part of the file; presumably the SQL store code
// elsewhere raises them — confirm against the callers.
#[derive(Debug, thiserror::Error)]
pub enum SqlDeviceStoreError {
/// Wraps a `sqlx::Error`.
#[error("database error: {0}")]
Db(#[from] sqlx::Error),
/// MessagePack serialization of the bindings failed (`rmp_serde::encode::Error`).
#[error("device-bindings MessagePack encoding failed: {0}")]
Encode(#[from] rmp_serde::encode::Error),
/// MessagePack deserialization of the bindings failed (`rmp_serde::decode::Error`).
#[error("device-bindings MessagePack decoding failed: {0}")]
Decode(#[from] rmp_serde::decode::Error),
/// Wraps a `CryptoError`. `BindingsCodec::decode` also returns this variant
/// when the stored string is not valid base64.
#[error("encryption/decryption error: {0}")]
Crypto(#[from] CryptoError),
/// Carries a string shown as the unrecognised `trust_level` value.
// Presumably the text for which `trust_level_codec::from_str` returned `None` — confirm.
#[error("unrecognised trust_level value: {0}")]
UnknownTrustLevel(String),
/// Carries a string describing a malformed stored value.
#[error("malformed stored value: {0}")]
MalformedRow(String),
}
/// Unit tests for `BindingsCodec` and `trust_level_codec`.
#[cfg(test)]
mod device_sql_common_tests {
    use super::*;
    use crate::device::types::{
        AttestationClass, DeviceBinding, DeviceTrustLevel, FingerprintHash,
    };
    use chrono::Utc;

    /// A plaintext codec must return exactly what was encoded.
    #[test]
    fn bindings_codec_round_trips_non_empty_vector() {
        let codec = BindingsCodec::plaintext();
        let now = Utc::now();
        let bindings = vec![
            DeviceBinding::Cookie {
                token_hash: FingerprintHash::from_bytes([0xAA; 32]),
                issued_at: now,
                last_used_at: now,
            },
            DeviceBinding::WebAuthn {
                credential_id: "cred-AX028".to_string(),
                attestation_class: AttestationClass::None,
                bound_at: now,
                last_used_at: now,
            },
        ];
        let encoded = codec.encode(&bindings).expect("encode");
        let decoded = codec.decode(&encoded).expect("decode");
        assert_eq!(
            decoded.len(),
            2,
            "decode must round-trip 2 bindings, not return empty"
        );
        assert_eq!(decoded, bindings, "round-trip must preserve content");
    }

    /// `decode` short-circuits on an empty string instead of trying to parse it.
    #[test]
    fn bindings_codec_decodes_empty_string_as_no_bindings() {
        let codec = BindingsCodec::plaintext();
        let decoded = codec
            .decode("")
            .expect("empty stored value must decode successfully");
        assert!(
            decoded.is_empty(),
            "empty stored value must yield no bindings"
        );
    }

    /// Pins the current error mapping for input that is not valid base64.
    /// Update this test together with `decode` if that mapping ever changes.
    #[test]
    fn bindings_codec_reports_invalid_base64_as_crypto_error() {
        let codec = BindingsCodec::plaintext();
        let err = codec
            .decode("not*base64!")
            .expect_err("invalid base64 must be rejected");
        assert!(
            matches!(err, SqlDeviceStoreError::Crypto(_)),
            "invalid base64 must map to the Crypto variant, got {err:?}"
        );
    }

    /// The stored strings are part of the persisted format; pin each one.
    #[test]
    fn trust_level_to_str_pins_variant_strings() {
        assert_eq!(
            trust_level_codec::to_str(DeviceTrustLevel::Unknown),
            "Unknown"
        );
        assert_eq!(trust_level_codec::to_str(DeviceTrustLevel::Seen), "Seen");
        assert_eq!(
            trust_level_codec::to_str(DeviceTrustLevel::Trusted),
            "Trusted"
        );
        assert_eq!(
            trust_level_codec::to_str(DeviceTrustLevel::Revoked),
            "Revoked"
        );
    }

    /// Each known string parses to its variant; anything else is `None`.
    #[test]
    fn trust_level_from_str_round_trips_each_variant() {
        assert_eq!(
            trust_level_codec::from_str("Unknown"),
            Some(DeviceTrustLevel::Unknown)
        );
        assert_eq!(
            trust_level_codec::from_str("Seen"),
            Some(DeviceTrustLevel::Seen)
        );
        assert_eq!(
            trust_level_codec::from_str("Trusted"),
            Some(DeviceTrustLevel::Trusted)
        );
        assert_eq!(
            trust_level_codec::from_str("Revoked"),
            Some(DeviceTrustLevel::Revoked)
        );
        assert_eq!(
            trust_level_codec::from_str("not-a-level"),
            None,
            "unknown input must yield None, not Some(Default)"
        );
    }

    /// `from_str(to_str(level))` must give back `level` for every variant.
    #[test]
    fn trust_level_codec_is_a_bijection_on_known_variants() {
        for level in [
            DeviceTrustLevel::Unknown,
            DeviceTrustLevel::Seen,
            DeviceTrustLevel::Trusted,
            DeviceTrustLevel::Revoked,
        ] {
            let s = trust_level_codec::to_str(level);
            assert_eq!(
                trust_level_codec::from_str(s),
                Some(level),
                "{level:?} must round-trip via to_str/from_str"
            );
        }
    }
}