extern crate alloc;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use crate::error::{SecurityError, SecurityErrorKind, SecurityResult};
/// A string-valued token property as it appears on the wire: a `name` / `value` pair.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct WireProperty {
    /// Property key, e.g. `dds.cert.sn`.
    pub name: String,
    /// Property value.
    pub value: String,
}
impl WireProperty {
    /// Creates a property from any pair of values convertible into `String`.
    #[must_use]
    pub fn new(name: impl Into<String>, value: impl Into<String>) -> Self {
        let name = name.into();
        let value = value.into();
        Self { name, value }
    }
}
/// An octet-sequence token property: a string `name` paired with raw bytes.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct BinaryProperty {
    /// Property key.
    pub name: String,
    /// Raw property bytes.
    pub value: Vec<u8>,
}
impl BinaryProperty {
    /// Creates a binary property from anything convertible into `String` / `Vec<u8>`.
    #[must_use]
    pub fn new(name: impl Into<String>, value: impl Into<Vec<u8>>) -> Self {
        let name = name.into();
        let value = value.into();
        Self { name, value }
    }
}
/// Generic token container modelled on the DDS-Security `DataHolder`: a class id
/// plus ordered lists of string and binary properties.
///
/// Property names are not required to be unique: the `with_*` builders always
/// append, the `set_*` methods update the first entry with a matching name, and
/// the lookup methods return the first match.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct DataHolder {
    /// Identifies the kind of token (see the `class_id` module for known values).
    pub class_id: String,
    /// String-valued properties, kept in insertion (and wire) order.
    pub properties: Vec<WireProperty>,
    /// Octet-sequence properties, kept in insertion (and wire) order.
    pub binary_properties: Vec<BinaryProperty>,
}
impl DataHolder {
    /// Creates a holder with the given class id and no properties.
    #[must_use]
    pub fn new(class_id: impl Into<String>) -> Self {
        Self {
            class_id: class_id.into(),
            ..Self::default()
        }
    }

    /// Builder form: appends a string property. An existing property with the
    /// same name is NOT replaced, and lookups keep returning the earlier one.
    #[must_use]
    pub fn with_property(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
        let entry = WireProperty::new(name, value);
        self.properties.push(entry);
        self
    }

    /// Builder form: appends a binary property. An existing property with the
    /// same name is NOT replaced.
    #[must_use]
    pub fn with_binary_property(
        mut self,
        name: impl Into<String>,
        value: impl Into<Vec<u8>>,
    ) -> Self {
        let entry = BinaryProperty::new(name, value);
        self.binary_properties.push(entry);
        self
    }

    /// Sets the value of the first string property called `name`, appending a
    /// new property when there is none.
    pub fn set_property(&mut self, name: impl Into<String>, value: impl Into<String>) {
        let name = name.into();
        match self.properties.iter_mut().find(|p| p.name == name) {
            Some(slot) => slot.value = value.into(),
            None => self.properties.push(WireProperty::new(name, value)),
        }
    }

    /// Sets the value of the first binary property called `name`, appending a
    /// new property when there is none.
    pub fn set_binary_property(&mut self, name: impl Into<String>, value: impl Into<Vec<u8>>) {
        let name = name.into();
        match self.binary_properties.iter_mut().find(|p| p.name == name) {
            Some(slot) => slot.value = value.into(),
            None => self
                .binary_properties
                .push(BinaryProperty::new(name, value)),
        }
    }

    /// Returns the value of the first string property called `name`, if any.
    #[must_use]
    pub fn property(&self, name: &str) -> Option<&str> {
        self.properties.iter().find_map(|p| {
            if p.name == name {
                Some(p.value.as_str())
            } else {
                None
            }
        })
    }

    /// Returns the bytes of the first binary property called `name`, if any.
    #[must_use]
    pub fn binary_property(&self, name: &str) -> Option<&[u8]> {
        self.binary_properties.iter().find_map(|p| {
            if p.name == name {
                Some(p.value.as_slice())
            } else {
                None
            }
        })
    }

    /// Serializes to little-endian plain CDR (no encapsulation header).
    #[must_use]
    pub fn to_cdr_le(&self) -> Vec<u8> {
        encode(self, /* le = */ true)
    }

    /// Serializes to big-endian plain CDR (no encapsulation header).
    #[must_use]
    pub fn to_cdr_be(&self) -> Vec<u8> {
        encode(self, /* le = */ false)
    }

    /// Parses little-endian plain CDR as produced by `to_cdr_le`.
    ///
    /// # Errors
    /// Returns a `BadArgument` error for oversized, truncated or malformed input.
    pub fn from_cdr_le(bytes: &[u8]) -> SecurityResult<Self> {
        decode(bytes, /* le = */ true)
    }

    /// Parses big-endian plain CDR as produced by `to_cdr_be`.
    ///
    /// # Errors
    /// Returns a `BadArgument` error for oversized, truncated or malformed input.
    pub fn from_cdr_be(bytes: &[u8]) -> SecurityResult<Self> {
        decode(bytes, /* le = */ false)
    }
}
// Every token type below is a plain alias of `DataHolder`: the compiler does not
// distinguish them, so a token's kind is carried only by its `class_id` string.
// NOTE(review): a newtype per token would make mixing token kinds a compile error.
/// DDS-Security identity token (alias of `DataHolder`).
pub type IdentityToken = DataHolder;
/// DDS-Security permissions token (alias of `DataHolder`).
pub type PermissionsToken = DataHolder;
/// DDS-Security identity status token (alias of `DataHolder`).
pub type IdentityStatusToken = DataHolder;
/// DDS-Security permissions credential token (alias of `DataHolder`).
pub type PermissionsCredentialToken = DataHolder;
/// DDS-Security authentication request message token (alias of `DataHolder`).
pub type AuthRequestMessageToken = DataHolder;
/// DDS-Security handshake message token (alias of `DataHolder`).
pub type HandshakeMessageToken = DataHolder;
/// DDS-Security crypto token (alias of `DataHolder`).
pub type CryptoToken = DataHolder;
/// Well-known `class_id` strings stored in `DataHolder::class_id`.
pub mod class_id {
    /// Class id set by `IdentityToken::pki_dh_v12` (PKI-DH authentication, v1.2).
    pub const AUTH_PKI_DH_V12: &str = "DDS:Auth:PKI-DH:1.2";
    /// Class id set by `permissions_v12` (permissions access control, v1.2).
    pub const ACCESS_PERMISSIONS_V12: &str = "DDS:Access:Permissions:1.2";
    /// Class id for AES-GCM-GMAC crypto tokens (v1.2).
    // NOTE(review): confirm the exact spelling against the DDS-Security spec
    // revision being targeted; class ids must match byte-for-byte.
    pub const CRYPTO_AES_GCM_GMAC_V12: &str = "DDS:Crypto:AES-GCM-GMAC:1.2";
    /// Class id for permissions credential tokens (carries no version suffix).
    pub const ACCESS_PERMISSIONS_CREDENTIAL: &str = "DDS:Access:PermissionsCredential";
}
/// Property names written by the token constructors.
pub mod prop {
    /// Holds the `cert_sn` argument of `IdentityToken::pki_dh_v12`.
    pub const CERT_SN: &str = "dds.cert.sn";
    /// Holds the `cert_algo` argument of `IdentityToken::pki_dh_v12`.
    pub const CERT_ALGO: &str = "dds.cert.algo";
    /// Holds the `ca_sn` argument of `IdentityToken::pki_dh_v12`.
    pub const CA_SN: &str = "dds.ca.sn";
    /// Holds the `ca_algo` argument of `IdentityToken::pki_dh_v12`.
    pub const CA_ALGO: &str = "dds.ca.algo";
    /// Holds the `perm_ca_sn` argument of `permissions_v12`.
    pub const PERM_CA_SN: &str = "dds.perm_ca.sn";
    /// Holds the `perm_ca_algo` argument of `permissions_v12`.
    pub const PERM_CA_ALGO: &str = "dds.perm_ca.algo";
}
impl IdentityToken {
#[must_use]
pub fn pki_dh_v12(
cert_sn: impl Into<String>,
cert_algo: impl Into<String>,
ca_sn: impl Into<String>,
ca_algo: impl Into<String>,
) -> Self {
DataHolder::new(class_id::AUTH_PKI_DH_V12)
.with_property(prop::CERT_SN, cert_sn)
.with_property(prop::CERT_ALGO, cert_algo)
.with_property(prop::CA_SN, ca_sn)
.with_property(prop::CA_ALGO, ca_algo)
}
#[must_use]
pub fn permissions_v12(perm_ca_sn: impl Into<String>, perm_ca_algo: impl Into<String>) -> Self {
DataHolder::new(class_id::ACCESS_PERMISSIONS_V12)
.with_property(prop::PERM_CA_SN, perm_ca_sn)
.with_property(prop::PERM_CA_ALGO, perm_ca_algo)
}
}
/// Largest serialized token `decode` will parse (64 KiB); bigger inputs are
/// rejected before any parsing as a denial-of-service guard.
const MAX_TOKEN_BYTES: usize = 65_536;
/// Maximum number of string properties accepted in one decoded token.
const MAX_PROPS: u32 = 256;
/// Maximum number of binary properties accepted in one decoded token.
const MAX_BIN_PROPS: u32 = 256;
/// Maximum encoded string length in bytes, counting the trailing NUL (8 KiB).
const MAX_STRING_LEN: u32 = 8_192;
/// Maximum binary property length in bytes (8 KiB).
const MAX_BINARY_LEN: u32 = 8_192;
fn encode(d: &DataHolder, le: bool) -> Vec<u8> {
let mut out = Vec::with_capacity(64);
encode_string(&mut out, &d.class_id, le);
encode_u32(&mut out, d.properties.len() as u32, le);
for p in &d.properties {
encode_string(&mut out, &p.name, le);
encode_string(&mut out, &p.value, le);
}
encode_u32(&mut out, d.binary_properties.len() as u32, le);
for p in &d.binary_properties {
encode_string(&mut out, &p.name, le);
encode_octet_seq(&mut out, &p.value, le);
}
out
}
fn decode(bytes: &[u8], le: bool) -> SecurityResult<DataHolder> {
if bytes.len() > MAX_TOKEN_BYTES {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: payload exceeds DoS cap",
));
}
let mut cur = Cursor::new(bytes);
let class_id = cur.read_string(le)?;
let prop_count = cur.read_u32(le)?;
if prop_count > MAX_PROPS {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: property count exceeds cap",
));
}
let mut properties = Vec::with_capacity(prop_count as usize);
for _ in 0..prop_count {
let name = cur.read_string(le)?;
let value = cur.read_string(le)?;
properties.push(WireProperty { name, value });
}
let bin_count = cur.read_u32(le)?;
if bin_count > MAX_BIN_PROPS {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: binary_property count exceeds cap",
));
}
let mut binary_properties = Vec::with_capacity(bin_count as usize);
for _ in 0..bin_count {
let name = cur.read_string(le)?;
let value = cur.read_octet_seq(le)?;
binary_properties.push(BinaryProperty { name, value });
}
Ok(DataHolder {
class_id,
properties,
binary_properties,
})
}
/// Appends zero bytes until `out.len()` is a multiple of `n`; alignment is
/// measured from the start of `out`. Panics if `n` is zero (remainder by zero).
fn align_to(out: &mut Vec<u8>, n: usize) {
    let padded_len = out.len() + (n - out.len() % n) % n;
    out.resize(padded_len, 0);
}

/// Writes `v` at the next 4-byte boundary in the requested byte order.
fn encode_u32(out: &mut Vec<u8>, v: u32, le: bool) {
    align_to(out, 4);
    let raw = if le { v.to_le_bytes() } else { v.to_be_bytes() };
    out.extend_from_slice(&raw);
}

/// Writes a CDR string: a `u32` length that counts the terminating NUL, the
/// UTF-8 bytes, then the NUL itself.
fn encode_string(out: &mut Vec<u8>, s: &str, le: bool) {
    // `as u32` truncates only for strings over 4 GiB, far beyond MAX_STRING_LEN.
    encode_u32(out, (s.len() + 1) as u32, le);
    out.extend(s.bytes());
    out.push(0);
}

/// Writes an octet sequence: a `u32` byte count followed by the raw bytes
/// (no terminator and no trailing padding).
fn encode_octet_seq(out: &mut Vec<u8>, v: &[u8], le: bool) {
    encode_u32(out, v.len() as u32, le);
    out.extend(v.iter().copied());
}
struct Cursor<'a> {
buf: &'a [u8],
pos: usize,
}
impl<'a> Cursor<'a> {
fn new(buf: &'a [u8]) -> Self {
Self { buf, pos: 0 }
}
fn align_to(&mut self, n: usize) {
let pad = (n - self.pos % n) % n;
self.pos = self.pos.saturating_add(pad);
}
fn read_u32(&mut self, le: bool) -> SecurityResult<u32> {
self.align_to(4);
if self.pos + 4 > self.buf.len() {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: truncated u32",
));
}
let raw = [
self.buf[self.pos],
self.buf[self.pos + 1],
self.buf[self.pos + 2],
self.buf[self.pos + 3],
];
self.pos += 4;
Ok(if le {
u32::from_le_bytes(raw)
} else {
u32::from_be_bytes(raw)
})
}
fn read_string(&mut self, le: bool) -> SecurityResult<String> {
let len = self.read_u32(le)?;
if len > MAX_STRING_LEN {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: string exceeds cap",
));
}
if len == 0 {
return Ok(String::new());
}
if self.pos + len as usize > self.buf.len() {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: truncated string",
));
}
let body = &self.buf[self.pos..self.pos + len as usize];
self.pos += len as usize;
if let Some((_, rest)) = body.split_last() {
if body.last() != Some(&0) {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: string missing trailing NUL",
));
}
let s = core::str::from_utf8(rest).map_err(|_| {
SecurityError::new(SecurityErrorKind::BadArgument, "token: string not UTF-8")
})?;
Ok(s.to_string())
} else {
Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: zero-length string body",
))
}
}
fn read_octet_seq(&mut self, le: bool) -> SecurityResult<Vec<u8>> {
let len = self.read_u32(le)?;
if len > MAX_BINARY_LEN {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: binary value exceeds cap",
));
}
if self.pos + len as usize > self.buf.len() {
return Err(SecurityError::new(
SecurityErrorKind::BadArgument,
"token: truncated binary",
));
}
let v = self.buf[self.pos..self.pos + len as usize].to_vec();
self.pos += len as usize;
Ok(v)
}
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn empty_data_holder_roundtrip_le() {
        let dh = DataHolder::new("DDS:Auth:PKI-DH:1.2");
        let bytes = dh.to_cdr_le();
        let back = DataHolder::from_cdr_le(&bytes).unwrap();
        assert_eq!(dh, back);
    }

    #[test]
    fn empty_data_holder_roundtrip_be() {
        let dh = DataHolder::new("DDS:Auth:PKI-DH:1.2");
        let bytes = dh.to_cdr_be();
        let back = DataHolder::from_cdr_be(&bytes).unwrap();
        assert_eq!(dh, back);
    }

    #[test]
    fn pki_dh_identity_token_roundtrip() {
        let tok =
            IdentityToken::pki_dh_v12("01:23:45:67", "ECDSA-SHA256", "FA:CE:0B:01", "RSA-SHA256");
        let bytes = tok.to_cdr_le();
        let back = DataHolder::from_cdr_le(&bytes).unwrap();
        assert_eq!(tok, back);
        assert_eq!(back.class_id, "DDS:Auth:PKI-DH:1.2");
        assert_eq!(back.property("dds.cert.sn"), Some("01:23:45:67"));
        assert_eq!(back.property("dds.cert.algo"), Some("ECDSA-SHA256"));
        assert_eq!(back.property("dds.ca.sn"), Some("FA:CE:0B:01"));
        assert_eq!(back.property("dds.ca.algo"), Some("RSA-SHA256"));
        assert!(back.binary_properties.is_empty());
    }

    #[test]
    fn permissions_token_roundtrip() {
        let tok = IdentityToken::permissions_v12("DE:AD:BE:EF", "ECDSA-SHA256");
        let le = tok.to_cdr_le();
        let be = tok.to_cdr_be();
        assert_eq!(tok, DataHolder::from_cdr_le(&le).unwrap());
        assert_eq!(tok, DataHolder::from_cdr_be(&be).unwrap());
        // Message means "BE/LE streams differ".
        assert_ne!(le, be, "BE/LE Streams unterscheiden sich");
    }

    #[test]
    fn token_with_binary_property_roundtrip() {
        let tok = DataHolder::new("DDS:Auth:PKI-DH:1.2")
            .with_property("dds.cert.sn", "01:23")
            .with_binary_property("dds.cert.bytes", vec![0xCA, 0xFE, 0xBA, 0xBE, 0xDE]);
        let bytes = tok.to_cdr_le();
        let back = DataHolder::from_cdr_le(&bytes).unwrap();
        assert_eq!(tok, back);
        assert_eq!(
            back.binary_property("dds.cert.bytes"),
            Some(&[0xCA, 0xFE, 0xBA, 0xBE, 0xDE][..])
        );
    }

    #[test]
    fn cdr_le_layout_class_id_only() {
        let dh = DataHolder::new("A");
        let bytes = dh.to_cdr_le();
        assert_eq!(
            bytes,
            vec![
                0x02, 0x00, 0x00, 0x00, b'A', 0x00, 0x00, 0x00, 0, 0, 0, 0, 0, 0, 0, 0
            ]
        );
    }

    #[test]
    fn cdr_be_layout_class_id_only() {
        let bytes = DataHolder::new("A").to_cdr_be();
        assert_eq!(
            bytes,
            vec![
                0x00, 0x00, 0x00, 0x02, b'A', 0x00, 0x00, 0x00, 0, 0, 0, 0, 0, 0, 0, 0
            ]
        );
    }

    #[test]
    fn set_property_replaces_existing_value() {
        let mut dh = DataHolder::new("X");
        dh.set_property("k", "1");
        dh.set_property("k", "2");
        dh.set_binary_property("b", vec![1u8]);
        dh.set_binary_property("b", vec![2u8, 3]);
        assert_eq!(dh.properties.len(), 1);
        assert_eq!(dh.property("k"), Some("2"));
        assert_eq!(dh.binary_properties.len(), 1);
        assert_eq!(dh.binary_property("b"), Some(&[2u8, 3][..]));
        assert_eq!(dh.property("missing"), None);
        assert_eq!(dh.binary_property("missing"), None);
    }

    #[test]
    fn zero_length_string_decodes_as_empty() {
        // class_id length 0, property count 0, binary property count 0.
        let bytes = vec![0u8; 12];
        let back = DataHolder::from_cdr_le(&bytes).unwrap();
        assert_eq!(back, DataHolder::new(""));
    }

    #[test]
    fn truncated_buffer_is_error() {
        let err = DataHolder::from_cdr_le(&[0x10, 0x00, 0x00]).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn truncated_binary_value_is_error() {
        let mut bytes = DataHolder::new("X")
            .with_binary_property("b", vec![1u8, 2, 3])
            .to_cdr_le();
        bytes.pop();
        let err = DataHolder::from_cdr_le(&bytes).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn property_count_cap_rejects_huge() {
        let mut bytes = Vec::new();
        encode_string(&mut bytes, "X", true);
        encode_u32(&mut bytes, 1_000_000, true);
        let err = DataHolder::from_cdr_le(&bytes).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn binary_property_count_cap_rejects_huge() {
        let mut bytes = Vec::new();
        encode_string(&mut bytes, "X", true);
        encode_u32(&mut bytes, 0, true);
        encode_u32(&mut bytes, MAX_BIN_PROPS + 1, true);
        let err = DataHolder::from_cdr_le(&bytes).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn string_length_cap_rejected() {
        let mut bytes = Vec::new();
        encode_u32(&mut bytes, MAX_STRING_LEN + 1, true);
        let err = DataHolder::from_cdr_le(&bytes).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn binary_length_cap_rejected() {
        let mut bytes = Vec::new();
        encode_string(&mut bytes, "X", true);
        encode_u32(&mut bytes, 0, true);
        encode_u32(&mut bytes, 1, true);
        encode_string(&mut bytes, "b", true);
        encode_u32(&mut bytes, MAX_BINARY_LEN + 1, true);
        let err = DataHolder::from_cdr_le(&bytes).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn missing_trailing_nul_rejected() {
        let bytes = vec![0x01, 0x00, 0x00, 0x00, b'A'];
        let err = DataHolder::from_cdr_le(&bytes).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn dos_cap_overall_payload() {
        let big = vec![0u8; MAX_TOKEN_BYTES + 1];
        let err = DataHolder::from_cdr_le(&big).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }
}