use bacnet_types::error::Error;
use bytes::Bytes;
use std::collections::HashMap;
/// Kind of segmented PDU whose per-segment overhead is being accounted for.
///
/// `max_segment_payload` uses this to pick how many bytes of each APDU are
/// unavailable for payload data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SegmentedPduType {
/// Segmented confirmed request; `max_segment_payload` reserves 6 bytes for it.
ConfirmedRequest,
/// Segmented complex ACK; `max_segment_payload` reserves 5 bytes for it.
ComplexAck,
}
/// Returns the number of payload bytes that fit in one segment when the APDU
/// is limited to `max_apdu_length` bytes.
///
/// The overhead subtracted depends on `pdu_type`: 6 bytes for
/// [`SegmentedPduType::ConfirmedRequest`] and 5 bytes for
/// [`SegmentedPduType::ComplexAck`]. The subtraction saturates, so a
/// `max_apdu_length` smaller than the overhead yields `0` instead of wrapping.
pub fn max_segment_payload(max_apdu_length: u16, pdu_type: SegmentedPduType) -> usize {
    const CONFIRMED_REQUEST_OVERHEAD: usize = 6;
    const COMPLEX_ACK_OVERHEAD: usize = 5;

    // u16 -> usize is lossless on every supported target.
    let apdu_len = usize::from(max_apdu_length);
    match pdu_type {
        SegmentedPduType::ConfirmedRequest => apdu_len.saturating_sub(CONFIRMED_REQUEST_OVERHEAD),
        SegmentedPduType::ComplexAck => apdu_len.saturating_sub(COMPLEX_ACK_OVERHEAD),
    }
}
/// Splits `payload` into consecutive segments of at most `max_segment_size`
/// bytes each.
///
/// Every segment except possibly the last is exactly `max_segment_size` bytes
/// long, and concatenating the returned segments in order reproduces
/// `payload`.
///
/// The payload is returned as a single, unsplit segment when any of the
/// following holds:
///
/// - `max_segment_size` is `0`;
/// - `payload` is empty (the result is one empty segment);
/// - more than 256 segments would be required. 256 is also the upper bound
///   that `SegmentReceiver::reassemble` accepts.
///
/// In the first and last case the single returned segment may be larger than
/// `max_segment_size`; callers that need a hard size guarantee must check the
/// result themselves.
pub fn split_payload(payload: &[u8], max_segment_size: usize) -> Vec<Bytes> {
    const MAX_SEGMENTS: usize = 256;

    let unsplit = || vec![Bytes::copy_from_slice(payload)];

    if max_segment_size == 0 || payload.is_empty() {
        return unsplit();
    }

    let chunks = payload.chunks(max_segment_size);
    // `Chunks` is an `ExactSizeIterator`, so the segment count is known up
    // front. Checking it here avoids copying every chunk only to throw the
    // copies away when the limit is exceeded.
    if chunks.len() > MAX_SEGMENTS {
        return unsplit();
    }

    chunks.map(Bytes::copy_from_slice).collect()
}
/// Collects the segments of a segmented message, keyed by sequence number,
/// so that they can be concatenated back into the original payload.
///
/// Segments may be stored in any order; reassembly reads them back in
/// ascending sequence-number order starting at 0.
#[derive(Debug, Clone)]
pub struct SegmentReceiver {
    /// Segment payloads received so far, indexed by their 8-bit sequence number.
    segments: HashMap<u8, Bytes>,
}
impl Default for SegmentReceiver {
fn default() -> Self {
Self::new()
}
}
impl SegmentReceiver {
pub fn new() -> Self {
Self {
segments: HashMap::new(),
}
}
const MAX_SEGMENT_SIZE: usize = 1476;
pub fn receive(&mut self, sequence_number: u8, data: Bytes) -> Result<(), Error> {
if data.len() > Self::MAX_SEGMENT_SIZE {
return Err(Error::Segmentation(format!(
"segment size {} exceeds maximum {}",
data.len(),
Self::MAX_SEGMENT_SIZE
)));
}
self.segments.insert(sequence_number, data);
Ok(())
}
pub fn has_segment(&self, sequence_number: u8) -> bool {
self.segments.contains_key(&sequence_number)
}
pub fn received_count(&self) -> usize {
self.segments.len()
}
pub fn reassemble(&self, total_segments: usize) -> Result<Vec<u8>, Error> {
if total_segments > 256 {
return Err(Error::Segmentation(format!(
"total_segments {total_segments} exceeds maximum BACnet value (256)"
)));
}
let mut result = Vec::with_capacity(total_segments * 480);
for i in 0..total_segments {
let seq = i as u8;
match self.segments.get(&seq) {
Some(data) => result.extend_from_slice(data),
None => {
return Err(Error::Segmentation(format!(
"missing segment {} of {}",
i, total_segments
)));
}
}
}
Ok(result)
}
}
/// Unit tests for payload splitting, per-segment size computation and
/// segment reassembly, including the boundary cases around the 256-segment
/// and per-segment size limits.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn max_segment_payload_confirmed_request() {
        assert_eq!(
            max_segment_payload(480, SegmentedPduType::ConfirmedRequest),
            474
        );
        assert_eq!(
            max_segment_payload(1476, SegmentedPduType::ConfirmedRequest),
            1470
        );
    }

    #[test]
    fn max_segment_payload_complex_ack() {
        assert_eq!(max_segment_payload(480, SegmentedPduType::ComplexAck), 475);
        assert_eq!(
            max_segment_payload(1476, SegmentedPduType::ComplexAck),
            1471
        );
    }

    // An APDU limit smaller than the overhead must clamp to zero, not wrap.
    #[test]
    fn max_segment_payload_saturates_at_zero() {
        assert_eq!(
            max_segment_payload(3, SegmentedPduType::ConfirmedRequest),
            0
        );
        assert_eq!(max_segment_payload(5, SegmentedPduType::ComplexAck), 0);
    }

    #[test]
    fn split_payload_fits_single_segment() {
        let payload = vec![0u8; 100];
        let segments = split_payload(&payload, 200);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0], payload);
    }

    #[test]
    fn split_payload_exact_fit() {
        let payload = vec![0u8; 200];
        let segments = split_payload(&payload, 100);
        assert_eq!(segments.len(), 2);
        assert_eq!(segments[0].len(), 100);
        assert_eq!(segments[1].len(), 100);
    }

    #[test]
    fn split_payload_remainder() {
        let payload = vec![0u8; 250];
        let segments = split_payload(&payload, 100);
        assert_eq!(segments.len(), 3);
        assert_eq!(segments[0].len(), 100);
        assert_eq!(segments[1].len(), 100);
        assert_eq!(segments[2].len(), 50);
    }

    #[test]
    fn split_empty_payload() {
        let segments = split_payload(&[], 100);
        assert_eq!(segments.len(), 1);
        assert!(segments[0].is_empty());
    }

    #[test]
    fn split_payload_zero_segment_size() {
        let result = split_payload(&[1, 2, 3], 0);
        assert_eq!(result, vec![vec![1, 2, 3]]);
    }

    // Exactly 256 segments is still allowed.
    #[test]
    fn split_payload_at_segment_limit() {
        let payload = vec![0u8; 256];
        let segments = split_payload(&payload, 1);
        assert_eq!(segments.len(), 256);
        assert!(segments.iter().all(|segment| segment.len() == 1));
    }

    // One segment past the limit falls back to returning the payload unsplit.
    #[test]
    fn split_payload_over_segment_limit_returns_unsplit() {
        let payload = vec![7u8; 257];
        let segments = split_payload(&payload, 1);
        assert_eq!(segments.len(), 1);
        assert_eq!(segments[0], payload);
    }

    #[test]
    fn receive_enforces_size_limit() {
        let limit = SegmentReceiver::MAX_SEGMENT_SIZE;
        let mut receiver = SegmentReceiver::new();
        assert!(receiver.receive(0, Bytes::from(vec![0u8; limit])).is_ok());
        assert!(receiver
            .receive(1, Bytes::from(vec![0u8; limit + 1]))
            .is_err());
        // The rejected segment must not have been stored.
        assert!(receiver.has_segment(0));
        assert!(!receiver.has_segment(1));
        assert_eq!(receiver.received_count(), 1);
    }

    #[test]
    fn receive_duplicate_replaces_segment() {
        let mut receiver = SegmentReceiver::new();
        receiver.receive(0, Bytes::from_static(&[1, 2])).unwrap();
        receiver.receive(0, Bytes::from_static(&[9])).unwrap();
        assert_eq!(receiver.received_count(), 1);
        assert_eq!(receiver.reassemble(1).unwrap(), vec![9]);
    }

    #[test]
    fn reassemble_ordered_segments() {
        let original = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let segments = split_payload(&original, 3);
        assert_eq!(segments.len(), 4);
        let mut receiver = SegmentReceiver::new();
        for (i, seg) in segments.iter().enumerate() {
            receiver.receive(i as u8, seg.clone()).unwrap();
        }
        let reassembled = receiver.reassemble(segments.len()).unwrap();
        assert_eq!(reassembled, original);
    }

    #[test]
    fn reassemble_out_of_order() {
        let mut receiver = SegmentReceiver::new();
        receiver.receive(2, Bytes::from_static(&[5, 6])).unwrap();
        receiver.receive(0, Bytes::from_static(&[1, 2])).unwrap();
        receiver.receive(1, Bytes::from_static(&[3, 4])).unwrap();
        let reassembled = receiver.reassemble(3).unwrap();
        assert_eq!(reassembled, vec![1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn reassemble_missing_segment_fails() {
        let mut receiver = SegmentReceiver::new();
        receiver.receive(0, Bytes::from_static(&[1, 2])).unwrap();
        receiver.receive(2, Bytes::from_static(&[5, 6])).unwrap();
        assert!(receiver.reassemble(3).is_err());
    }

    // Segments beyond the requested count are left out of the result.
    #[test]
    fn reassemble_ignores_extra_segments() {
        let mut receiver = SegmentReceiver::new();
        receiver.receive(0, Bytes::from_static(&[1])).unwrap();
        receiver.receive(1, Bytes::from_static(&[2])).unwrap();
        receiver.receive(2, Bytes::from_static(&[3])).unwrap();
        assert_eq!(receiver.reassemble(2).unwrap(), vec![1, 2]);
    }

    #[test]
    fn reassemble_zero_segments_is_empty() {
        let receiver = SegmentReceiver::new();
        assert!(receiver.reassemble(0).unwrap().is_empty());
    }

    #[test]
    fn reassemble_rejects_more_than_256_segments() {
        let receiver = SegmentReceiver::new();
        assert!(receiver.reassemble(257).is_err());
    }

    // Round-trips every possible sequence number (0..=255).
    #[test]
    fn reassemble_full_sequence_space() {
        let mut receiver = SegmentReceiver::new();
        for seq in 0..=u8::MAX {
            receiver
                .receive(seq, Bytes::copy_from_slice(&[seq]))
                .unwrap();
        }
        assert_eq!(receiver.received_count(), 256);
        let expected: Vec<u8> = (0..=u8::MAX).collect();
        assert_eq!(receiver.reassemble(256).unwrap(), expected);
    }
}