mod capability;
mod macros;
mod protocol;
use bytes::{Buf, BytesMut};
use thiserror::Error;
use crate::decoding::Parsable;
use crate::encoding::Writable;
use crate::error::STAGE_DECODING;
use crate::{NotEnoughData, ProtocolError};
pub use capability::Capability;
pub use macros::{MacroStage, MacroStages};
pub use protocol::Protocol;
/// Option-negotiation packet (command byte `b'O'`).
///
/// On the wire the first three fields are encoded as big-endian `u32` words,
/// followed by the encoded macro stages.
#[derive(Clone, PartialEq, Debug)]
pub struct OptNeg {
    /// Protocol version.
    pub version: u32,
    /// Capability flags.
    pub capabilities: Capability,
    /// Protocol flags.
    pub protocol: Protocol,
    /// Macro stages; written after the fixed header, but not decoded when parsing.
    pub macro_stages: MacroStages,
}
impl Default for OptNeg {
fn default() -> Self {
Self {
version: Self::VERSION,
capabilities: Capability::default(),
protocol: Protocol::default(),
macro_stages: MacroStages::default(),
}
}
}
/// Error returned by [`OptNeg::merge_compatible`] when the peer's negotiation
/// cannot be combined with the local one.
#[derive(Debug, Error)]
pub enum CompatibilityError {
    /// The peer announced a protocol version newer than the local one.
    #[error("Received version {received} which is not compatible with {supported}")]
    UnsupportedVersion {
        /// Version announced by the peer.
        received: u32,
        /// Version of the local negotiation (the highest one accepted).
        supported: u32,
    },
}
impl OptNeg {
    /// Protocol version used by [`OptNeg::default`].
    const VERSION: u32 = 6;
    /// Size in bytes of the fixed part of the payload: three `u32` words
    /// (version, capabilities, protocol).
    const DATA_SIZE: usize = 4 + 4 + 4;
    /// Command byte identifying an option-negotiation packet.
    const CODE: u8 = b'O';

    /// Combines this (local) negotiation with the one received from the peer.
    ///
    /// The protocol and capability flags are merged through their respective
    /// `merge_regarding_version` helpers, keyed on `self.version`. The version
    /// and macro stages of `self` are kept unchanged.
    ///
    /// # Errors
    /// Returns [`CompatibilityError::UnsupportedVersion`] when `other.version`
    /// is greater than `self.version`.
    pub fn merge_compatible(mut self, other: &Self) -> Result<Self, CompatibilityError> {
        let supported = self.version;
        if other.version > supported {
            return Err(CompatibilityError::UnsupportedVersion {
                received: other.version,
                supported,
            });
        }

        self.protocol = self
            .protocol
            .merge_regarding_version(supported, other.protocol);
        self.capabilities = self
            .capabilities
            .merge_regarding_version(supported, other.capabilities);

        Ok(self)
    }
}
impl Parsable for OptNeg {
    const CODE: u8 = Self::CODE;

    /// Decodes an option-negotiation payload.
    ///
    /// The payload starts with three big-endian `u32` words: the protocol
    /// version, the capability flags and the protocol flags. Any bytes after
    /// those `DATA_SIZE` bytes are ignored, and `macro_stages` is set to its
    /// default.
    ///
    /// # Errors
    /// Returns a [`NotEnoughData`] error (converted into [`ProtocolError`]) when
    /// `buffer` holds fewer than `DATA_SIZE` bytes.
    fn parse(mut buffer: BytesMut) -> Result<Self, ProtocolError> {
        // `Writable::write` appends the macro stages after the fixed header, so
        // a payload longer than DATA_SIZE is legitimate. Only a shorter one is
        // an error, which is also the only case `NotEnoughData` describes.
        if buffer.len() < Self::DATA_SIZE {
            return Err(NotEnoughData::new(
                STAGE_DECODING,
                "Option negotiation",
                "not enough bits",
                Self::DATA_SIZE,
                buffer.len(),
                buffer,
            )
            .into());
        }

        // `Buf::get_u32` reads a big-endian word and advances the cursor;
        // the length check above guarantees all three reads are in bounds.
        let version = buffer.get_u32();
        let capabilities = Capability::from_bits_retain(buffer.get_u32());
        let protocol = Protocol::from_bits_retain(buffer.get_u32());

        // NOTE(review): trailing bytes (encoded macro stages) are not decoded yet.
        Ok(Self {
            version,
            capabilities,
            protocol,
            macro_stages: MacroStages::default(),
        })
    }
}
impl Writable for OptNeg {
    /// Appends the three fixed words as big-endian `u32`s (version,
    /// capabilities, protocol), followed by the encoded macro stages.
    fn write(&self, buffer: &mut BytesMut) {
        let header = [self.version, self.capabilities.bits(), self.protocol.bits()];
        for word in header {
            buffer.extend_from_slice(&word.to_be_bytes());
        }
        self.macro_stages.write(buffer);
    }

    /// Encoded size: the fixed header plus the encoded macro stages.
    fn len(&self) -> usize {
        self.macro_stages.len() + Self::DATA_SIZE
    }

    fn code(&self) -> u8 {
        Self::CODE
    }

    /// True only for a zero-length encoding. Because the fixed header alone is
    /// `DATA_SIZE` bytes, this does not happen for a well-formed value.
    fn is_empty(&self) -> bool {
        0 == self.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Big-endian encodings of the (version, capabilities, protocol) words.
    // A named alias keeps the helper signatures readable, so no clippy
    // `type_complexity` allow is needed.
    type OptNegWords = ([u8; 4], [u8; 4], [u8; 4]);

    /// Expected wire words of `OptNeg::default()`.
    fn ver_caps_prot() -> OptNegWords {
        let version = [0u8, 0u8, 0u8, 6u8];
        let capabilities = [0u8, 0u8, 0u8, 255u8];
        let protocol = [0u8, 0u8, 0u8, 0u8];
        (version, capabilities, protocol)
    }

    /// Builds a fixed-size option-negotiation payload from [`ver_caps_prot`].
    #[cfg(feature = "count-allocations")]
    fn create_optneg_from_bytes() -> (BytesMut, OptNegWords) {
        let mut buffer = BytesMut::new();
        let (version, capabilities, protocol) = ver_caps_prot();
        buffer.extend_from_slice(&version);
        buffer.extend_from_slice(&capabilities);
        buffer.extend_from_slice(&protocol);
        (buffer, (version, capabilities, protocol))
    }

    #[cfg(feature = "count-allocations")]
    #[test]
    fn test_parse_optneg() {
        let (buffer, _) = create_optneg_from_bytes();
        let info = allocation_counter::measure(|| {
            let res = OptNeg::parse(buffer);
            allocation_counter::opt_out(|| {
                println!("{res:?}");
                assert!(res.is_ok());
            });
        });
        println!("{}", &info.count_total);
        assert_eq!(info.count_total, 0);
    }

    #[test]
    fn test_write_optneg() {
        let (version, capabilities, protocol) = ver_caps_prot();
        let mut expected = Vec::new();
        expected.extend_from_slice(&version);
        expected.extend_from_slice(&capabilities);
        expected.extend_from_slice(&protocol);
        let mut buffer = BytesMut::new();
        let optneg = OptNeg::default();
        optneg.write(&mut buffer);
        assert_eq!(optneg.len(), buffer.len());
        assert_eq!(optneg.code(), b'O');
        assert_eq!(expected, buffer.to_vec());
    }

    /// A payload shorter than the fixed header must be rejected, not panic.
    #[test]
    fn test_parse_rejects_short_payload() {
        let buffer = BytesMut::from(&[0u8; OptNeg::DATA_SIZE - 1][..]);
        assert!(OptNeg::parse(buffer).is_err());
    }

    /// Writing the default negotiation and parsing it back yields the same value.
    #[test]
    fn test_write_then_parse_roundtrip() {
        let optneg = OptNeg::default();
        let mut buffer = BytesMut::new();
        optneg.write(&mut buffer);
        let parsed = OptNeg::parse(buffer).expect("default OptNeg must parse");
        assert_eq!(optneg, parsed);
    }
}