use alloc::vec::Vec;
use zerodds_cdr::{BufferReader, BufferWriter, DecodeError, EncodeError};
use crate::association_options::AssociationOptions;
/// Authentication-layer (AS) context of a CSIv2 compound mechanism.
///
/// Encoded by [`AsContextSec::encode`] as two `u16` option masks followed by
/// two length-prefixed octet sequences.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AsContextSec {
    /// Association options the target supports at this layer.
    pub target_supports: AssociationOptions,
    /// Association options the target requires at this layer.
    pub target_requires: AssociationOptions,
    /// Raw bytes identifying the client authentication mechanism
    /// (the tests use the GSSUP OID in DER form).
    pub client_authentication_mech: Vec<u8>,
    /// Raw target-name bytes (e.g. a realm name in the tests).
    pub target_name: Vec<u8>,
}
/// Security-attribute-service (SAS) context of a CSIv2 compound mechanism.
///
/// NOTE(review): in the CSIIOP IDL each `privilege_authorities` entry is a
/// `ServiceConfiguration { syntax: u32, name: sequence<octet> }`; this model
/// (and its codec) carries only the name octets — confirm wire interop.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SasContextSec {
    /// Association options the target supports at this layer.
    pub target_supports: AssociationOptions,
    /// Association options the target requires at this layer.
    pub target_requires: AssociationOptions,
    /// Privilege authorities, each as an opaque octet sequence.
    pub privilege_authorities: Vec<Vec<u8>>,
    /// Supported naming mechanisms, each as an opaque octet sequence.
    pub supported_naming_mechanisms: Vec<Vec<u8>>,
    /// Supported identity token types; presumably a bit mask — TODO confirm.
    pub supported_identity_types: u32,
}
/// One CSIv2 compound security mechanism: a transport-layer component plus
/// the AS and SAS contexts layered on top of it.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CompoundSecMech {
    /// Association options the target requires for this mechanism overall.
    pub target_requires: AssociationOptions,
    /// Tag of the transport mechanism component (the tests use 36 for TLS).
    pub transport_mech_tag: u32,
    /// Opaque payload of the transport mechanism component.
    pub transport_mech_data: Vec<u8>,
    /// Authentication-layer context.
    pub as_context: AsContextSec,
    /// Security-attribute-layer context.
    pub sas_context: SasContextSec,
}
/// The list of compound security mechanisms a target advertises.
///
/// `Default` yields a non-stateful, empty list.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct CompoundSecMechList {
    /// Whether the target keeps security-context state across requests.
    pub stateful: bool,
    /// Advertised mechanisms, in wire order.
    pub mechanism_list: Vec<CompoundSecMech>,
}
/// Writes `bytes` as a CDR octet sequence: a `u32` length, then the raw bytes.
///
/// # Errors
///
/// Returns [`EncodeError::ValueOutOfRange`] when `bytes.len()` does not fit
/// in a `u32`, and propagates any error from the writer.
fn write_octet_seq(w: &mut BufferWriter, bytes: &[u8]) -> Result<(), EncodeError> {
    let len = match u32::try_from(bytes.len()) {
        Ok(len) => len,
        Err(_) => {
            return Err(EncodeError::ValueOutOfRange {
                message: "csiv2 octet/seq length exceeds u32::MAX",
            });
        }
    };
    w.write_u32(len)?;
    w.write_bytes(bytes)
}
/// Reads a CDR octet sequence (`u32` length, then that many bytes) into an
/// owned buffer.
///
/// # Errors
///
/// Propagates any [`DecodeError`] from reading the length or the bytes.
fn read_octet_seq(r: &mut BufferReader<'_>) -> Result<Vec<u8>, DecodeError> {
    let len = r.read_u32()? as usize;
    Ok(r.read_bytes(len)?.to_vec())
}
/// Writes a CDR sequence of octet sequences: a `u32` element count followed
/// by each element encoded with [`write_octet_seq`].
///
/// # Errors
///
/// Returns [`EncodeError::ValueOutOfRange`] when the element count (or any
/// element's length) exceeds `u32::MAX`; stops at the first writer error.
fn write_octet_seq_seq(w: &mut BufferWriter, items: &[Vec<u8>]) -> Result<(), EncodeError> {
    let count = match u32::try_from(items.len()) {
        Ok(count) => count,
        Err(_) => {
            return Err(EncodeError::ValueOutOfRange {
                message: "csiv2 octet/seq length exceeds u32::MAX",
            });
        }
    };
    w.write_u32(count)?;
    items.iter().try_for_each(|item| write_octet_seq(w, item))
}
/// Reads a CDR sequence of octet sequences (`u32` count, then each element
/// as read by [`read_octet_seq`]).
///
/// The count comes straight off the wire, so it is used only as a *bounded*
/// capacity hint. Passing it unchecked to `Vec::with_capacity` would let a
/// hostile count such as `0xFFFF_FFFF` request tens of GiB (and panic or
/// abort) before a single element has been validated. Beyond the cap, the
/// vector grows normally as elements are actually decoded.
///
/// # Errors
///
/// Propagates the first [`DecodeError`] from reading the count or any element.
fn read_octet_seq_seq(r: &mut BufferReader<'_>) -> Result<Vec<Vec<u8>>, DecodeError> {
    // Upper bound on elements preallocated on the strength of an untrusted count.
    const MAX_PREALLOC: usize = 64;
    let n = r.read_u32()? as usize;
    let mut out = Vec::with_capacity(n.min(MAX_PREALLOC));
    for _ in 0..n {
        out.push(read_octet_seq(r)?);
    }
    Ok(out)
}
impl AsContextSec {
    /// Encodes this context as CDR: `target_supports` and `target_requires`
    /// as `u16` masks, then `client_authentication_mech` and `target_name`
    /// as length-prefixed octet sequences.
    ///
    /// # Errors
    ///
    /// Propagates any [`EncodeError`] from the writer or an oversized sequence.
    pub fn encode(&self, w: &mut BufferWriter) -> Result<(), EncodeError> {
        for mask in [self.target_supports.0, self.target_requires.0] {
            w.write_u16(mask)?;
        }
        write_octet_seq(w, &self.client_authentication_mech)?;
        write_octet_seq(w, &self.target_name)
    }

    /// Decodes a context previously written by [`AsContextSec::encode`].
    ///
    /// # Errors
    ///
    /// Propagates the first [`DecodeError`] from the underlying reads.
    pub fn decode(r: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
        // Struct-literal fields are evaluated in the order written, which
        // matches the wire order produced by `encode`.
        Ok(Self {
            target_supports: AssociationOptions(r.read_u16()?),
            target_requires: AssociationOptions(r.read_u16()?),
            client_authentication_mech: read_octet_seq(r)?,
            target_name: read_octet_seq(r)?,
        })
    }
}
impl SasContextSec {
    /// Encodes this context as CDR in field order: both option masks
    /// (`u16` each), the privilege-authority list and the naming-mechanism
    /// list (each a `u32` count of length-prefixed octet sequences), then
    /// `supported_identity_types` as a `u32`.
    ///
    /// # Errors
    ///
    /// Propagates any [`EncodeError`] from the writer or an oversized sequence.
    pub fn encode(&self, w: &mut BufferWriter) -> Result<(), EncodeError> {
        for mask in [self.target_supports.0, self.target_requires.0] {
            w.write_u16(mask)?;
        }
        for list in [&self.privilege_authorities, &self.supported_naming_mechanisms] {
            write_octet_seq_seq(w, list)?;
        }
        w.write_u32(self.supported_identity_types)
    }

    /// Decodes a context previously written by [`SasContextSec::encode`].
    ///
    /// # Errors
    ///
    /// Propagates the first [`DecodeError`] from the underlying reads.
    pub fn decode(r: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
        // Field initialisers run top to bottom, mirroring the encode order.
        Ok(Self {
            target_supports: AssociationOptions(r.read_u16()?),
            target_requires: AssociationOptions(r.read_u16()?),
            privilege_authorities: read_octet_seq_seq(r)?,
            supported_naming_mechanisms: read_octet_seq_seq(r)?,
            supported_identity_types: r.read_u32()?,
        })
    }
}
impl CompoundSecMech {
    /// Encodes this mechanism as CDR: `target_requires` (`u16`), the
    /// transport tag (`u32`) and its octet-sequence payload, then the AS and
    /// SAS contexts via their own `encode` methods.
    ///
    /// # Errors
    ///
    /// Propagates any [`EncodeError`] from the writer or a nested encoder.
    pub fn encode(&self, w: &mut BufferWriter) -> Result<(), EncodeError> {
        // Exhaustive destructuring: adding a field without encoding it fails to compile.
        let Self {
            target_requires,
            transport_mech_tag,
            transport_mech_data,
            as_context,
            sas_context,
        } = self;
        w.write_u16(target_requires.0)?;
        w.write_u32(*transport_mech_tag)?;
        write_octet_seq(w, transport_mech_data)?;
        as_context.encode(w)?;
        sas_context.encode(w)
    }

    /// Decodes a mechanism previously written by [`CompoundSecMech::encode`].
    ///
    /// # Errors
    ///
    /// Propagates the first [`DecodeError`] from the underlying reads.
    pub fn decode(r: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
        // Initialisers are evaluated in the order written == wire order.
        Ok(Self {
            target_requires: AssociationOptions(r.read_u16()?),
            transport_mech_tag: r.read_u32()?,
            transport_mech_data: read_octet_seq(r)?,
            as_context: AsContextSec::decode(r)?,
            sas_context: SasContextSec::decode(r)?,
        })
    }
}
impl CompoundSecMechList {
    /// Encodes the list as CDR: `stateful` as one octet (`0`/`1`), then a
    /// `u32` mechanism count followed by each [`CompoundSecMech`].
    ///
    /// # Errors
    ///
    /// Returns [`EncodeError::ValueOutOfRange`] if the list is longer than
    /// `u32::MAX`; otherwise propagates the first error from a mechanism's
    /// encoder or the writer.
    pub fn encode(&self, w: &mut BufferWriter) -> Result<(), EncodeError> {
        w.write_u8(u8::from(self.stateful))?;
        let n =
            u32::try_from(self.mechanism_list.len()).map_err(|_| EncodeError::ValueOutOfRange {
                message: "csiv2 octet/seq length exceeds u32::MAX",
            })?;
        w.write_u32(n)?;
        for mech in &self.mechanism_list {
            mech.encode(w)?;
        }
        Ok(())
    }

    /// Decodes a list previously written by [`CompoundSecMechList::encode`].
    ///
    /// Any non-zero `stateful` octet is accepted as `true`. The mechanism
    /// count is untrusted wire data, so it only seeds a *bounded* capacity
    /// hint. An unchecked `Vec::with_capacity(n)` with a hostile count would
    /// request billions of `CompoundSecMech` slots and panic or abort before
    /// any mechanism had been validated.
    ///
    /// # Errors
    ///
    /// Propagates the first [`DecodeError`] from the underlying reads.
    pub fn decode(r: &mut BufferReader<'_>) -> Result<Self, DecodeError> {
        // Upper bound on mechanisms preallocated on the strength of an untrusted count.
        const MAX_PREALLOC: usize = 16;
        let stateful = r.read_u8()? != 0;
        let n = r.read_u32()? as usize;
        let mut mechanism_list = Vec::with_capacity(n.min(MAX_PREALLOC));
        for _ in 0..n {
            mechanism_list.push(CompoundSecMech::decode(r)?);
        }
        Ok(Self {
            stateful,
            mechanism_list,
        })
    }
}
#[cfg(test)]
// Tests report failures by panicking, so expect/unwrap/panic are intentional here.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// A single TLS (tag 36) + GSSUP mechanism with populated SAS lists,
    /// shared by the CDR round-trip and truncation tests.
    fn populated_list() -> CompoundSecMechList {
        CompoundSecMechList {
            stateful: true,
            mechanism_list: alloc::vec![CompoundSecMech {
                target_requires: AssociationOptions::default()
                    .with(AssociationOptions::INTEGRITY)
                    .with(AssociationOptions::CONFIDENTIALITY),
                transport_mech_tag: 36,
                transport_mech_data: alloc::vec![0xab, 0xcd, 0xef],
                as_context: AsContextSec {
                    target_supports: AssociationOptions::from_bits(
                        AssociationOptions::ESTABLISH_TRUST_IN_CLIENT,
                    ),
                    target_requires: AssociationOptions::default(),
                    client_authentication_mech: super::super::gssup::GSSUP_OID_DER.to_vec(),
                    target_name: b"REALM.LAB".to_vec(),
                },
                sas_context: SasContextSec {
                    target_supports: AssociationOptions::default()
                        .with(AssociationOptions::IDENTITY_ASSERTION),
                    target_requires: AssociationOptions::default(),
                    privilege_authorities: alloc::vec![b"auth-1".to_vec(), b"auth-2".to_vec()],
                    supported_naming_mechanisms: alloc::vec![b"GSS_KRB5".to_vec()],
                    supported_identity_types: 0x0000_0007,
                },
            }],
        }
    }

    #[test]
    fn build_mech_list_with_tls_and_gssup() {
        let mech = CompoundSecMech {
            target_requires: AssociationOptions::default()
                .with(AssociationOptions::INTEGRITY)
                .with(AssociationOptions::CONFIDENTIALITY),
            transport_mech_tag: 36,
            transport_mech_data: alloc::vec![0xab, 0xcd],
            as_context: AsContextSec {
                target_supports: AssociationOptions::from_bits(
                    AssociationOptions::ESTABLISH_TRUST_IN_CLIENT,
                ),
                target_requires: AssociationOptions::default(),
                client_authentication_mech: super::super::gssup::GSSUP_OID_DER.to_vec(),
                target_name: b"REALM.LAB".to_vec(),
            },
            sas_context: SasContextSec {
                target_supports: AssociationOptions::default()
                    .with(AssociationOptions::IDENTITY_ASSERTION),
                target_requires: AssociationOptions::default(),
                privilege_authorities: alloc::vec![],
                supported_naming_mechanisms: alloc::vec![],
                supported_identity_types: 0,
            },
        };
        let list = CompoundSecMechList {
            stateful: true,
            mechanism_list: alloc::vec![mech],
        };
        assert_eq!(list.mechanism_list.len(), 1);
        assert_eq!(list.mechanism_list[0].transport_mech_tag, 36);
        assert_eq!(
            list.mechanism_list[0].as_context.client_authentication_mech,
            super::super::gssup::GSSUP_OID_DER
        );
    }

    #[test]
    fn empty_list_has_no_mechanisms() {
        let list = CompoundSecMechList::default();
        assert!(list.mechanism_list.is_empty());
        assert!(!list.stateful);
    }

    #[test]
    fn cdr_roundtrip_compound_sec_mech_list() {
        use zerodds_cdr::{BufferReader, BufferWriter, Endianness};
        let original = populated_list();
        let mut w = BufferWriter::new(Endianness::Little);
        original.encode(&mut w).expect("encode");
        let bytes = w.into_bytes();
        assert!(!bytes.is_empty());
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let decoded = CompoundSecMechList::decode(&mut r).expect("decode");
        assert_eq!(original, decoded);
    }

    #[test]
    fn cdr_roundtrip_empty_list() {
        use zerodds_cdr::{BufferReader, BufferWriter, Endianness};
        let original = CompoundSecMechList::default();
        let mut w = BufferWriter::new(Endianness::Little);
        original.encode(&mut w).expect("encode");
        let bytes = w.into_bytes();
        // 1 octet `stateful` + 3 padding to align the u32 count + 4-byte count.
        assert_eq!(bytes.len(), 8);
        let mut r = BufferReader::new(&bytes, Endianness::Little);
        let decoded = CompoundSecMechList::decode(&mut r).expect("decode");
        assert_eq!(original, decoded);
    }

    #[test]
    fn decode_rejects_every_truncated_prefix() {
        use zerodds_cdr::{BufferReader, BufferWriter, Endianness};
        let mut w = BufferWriter::new(Endianness::Little);
        populated_list().encode(&mut w).expect("encode");
        let bytes = w.into_bytes();
        // Encoding ends with the final u32, so every strict prefix must fail.
        for cut in 0..bytes.len() {
            let mut r = BufferReader::new(&bytes[..cut], Endianness::Little);
            assert!(
                CompoundSecMechList::decode(&mut r).is_err(),
                "decoding a {}-byte prefix unexpectedly succeeded",
                cut
            );
        }
    }
}