//! pushwire-core 0.1.1
//!
//! Shared types and codecs for the push-wire multiplexed push protocol.
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;

/// Length of the fragment header, laid out as `[fragment_id:16][total:2][index:2]`.
///
/// The three terms mirror the in-memory sizes of a `Uuid` (16 bytes) and of two
/// `u16` values (2 bytes each), matching the fields of `FragmentShard`.
/// NOTE(review): the codec that reads/writes this header is not visible here —
/// confirm the unit (bytes) and the field byte order against it.
pub const FRAGMENT_HEADER_LEN: usize = 16 + 2 + 2;

/// One piece of a larger payload that has been split into `total` shards.
///
/// Shards that share a `fragment_id` are recombined by `FragmentAssembler::push`,
/// which concatenates their payloads in ascending `index` order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FragmentShard {
    /// Identifier shared by every shard of the same original payload.
    pub fragment_id: Uuid,
    /// Number of shards the payload was split into; the assembler rejects zero.
    pub total: u16,
    /// Zero-based position of this shard; the assembler requires `index < total`.
    pub index: u16,
    /// This shard's portion of the original payload bytes.
    pub payload: Vec<u8>,
}

/// Reasons `FragmentAssembler::push` can reject a shard.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum FragmentError {
    /// The shard declares `total == 0`, or its `index` is not below the total.
    #[error("invalid fragment size")]
    InvalidSize,
    /// A shard with this `index` was already stored for the same fragment id.
    #[error("duplicate fragment {index}/{total}")]
    Duplicate { index: u16, total: u16 },
    /// The shard's `total` differs from the total recorded for its fragment id.
    #[error("unexpected total mismatch")]
    TotalMismatch,
}

/// Assembles fragments keyed by `fragment_id` until all parts arrive.
///
/// `push` never evicts anything on its own; call `enforce_retention` to cap
/// how many incomplete payloads are kept.
#[derive(Default, Debug)]
pub struct FragmentAssembler {
    /// Partially received payloads, keyed by fragment id.
    inflight: HashMap<Uuid, Inflight>,
    /// Fragment ids in arrival order; eviction removes from the front.
    order: Vec<Uuid>,
}

/// Limits on how much partially assembled state may be kept around.
#[derive(Debug, Clone, Copy)]
pub struct FragmentRetention {
    /// Upper bound on the number of incomplete payloads retained at once.
    pub max_inflight: usize,
}

impl Default for FragmentRetention {
    /// Returns a policy that allows up to 1024 incomplete payloads.
    fn default() -> Self {
        let max_inflight = 1024;
        FragmentRetention { max_inflight }
    }
}

/// Reassembly state for a single fragment id.
#[derive(Debug)]
struct Inflight {
    /// Shard count declared when this entry was created.
    total: u16,
    /// One slot per shard index; `None` until that shard has been stored.
    received: Vec<Option<Vec<u8>>>,
    /// Number of slots filled so far.
    seen: u16,
}

impl FragmentAssembler {
    pub fn new() -> Self {
        Self {
            inflight: HashMap::new(),
            order: Vec::new(),
        }
    }

    /// Push a new fragment; returns Some(full_payload) when completed.
    pub fn push(&mut self, shard: FragmentShard) -> Result<Option<Vec<u8>>, FragmentError> {
        if shard.total == 0 {
            return Err(FragmentError::InvalidSize);
        }

        let entry = self
            .inflight
            .entry(shard.fragment_id)
            .or_insert_with(|| Inflight::new(shard.total));
        if entry.seen == 0 {
            self.order.push(shard.fragment_id);
        }

        if entry.total != shard.total {
            return Err(FragmentError::TotalMismatch);
        }

        if shard.index as usize >= entry.received.len() {
            return Err(FragmentError::InvalidSize);
        }

        if entry.received[shard.index as usize].is_some() {
            return Err(FragmentError::Duplicate {
                index: shard.index,
                total: shard.total,
            });
        }

        entry.received[shard.index as usize] = Some(shard.payload);
        entry.seen += 1;

        if entry.seen == shard.total {
            let mut buf = Vec::new();
            for part in entry.received.iter_mut() {
                if let Some(mut chunk) = part.take() {
                    buf.append(&mut chunk);
                }
            }
            self.inflight.remove(&shard.fragment_id);
            self.order.retain(|id| id != &shard.fragment_id);
            Ok(Some(buf))
        } else {
            Ok(None)
        }
    }

    pub fn inflight(&self) -> usize {
        self.inflight.len()
    }

    /// Apply retention: evict oldest if exceeding max inflight.
    pub fn enforce_retention(&mut self, retention: FragmentRetention) {
        if self.inflight.len() <= retention.max_inflight {
            return;
        }
        while self.inflight.len() > retention.max_inflight {
            if let Some(oldest) = self.order.first().cloned() {
                self.inflight.remove(&oldest);
                self.order.remove(0);
            } else {
                break;
            }
        }
    }
}

impl Inflight {
    /// Builds the tracking state for a payload split into `total` shards,
    /// starting with every slot empty and nothing counted as received.
    fn new(total: u16) -> Self {
        let received: Vec<Option<Vec<u8>>> = (0..total).map(|_| None).collect();
        Self {
            total,
            received,
            seen: 0,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a shard of `fragment_id` holding `payload` at position `index` of `total`.
    fn shard(fragment_id: Uuid, total: u16, index: u16, payload: &[u8]) -> FragmentShard {
        FragmentShard {
            fragment_id,
            total,
            index,
            payload: payload.to_vec(),
        }
    }

    /// Shards delivered in index order produce the concatenated payload.
    #[test]
    fn assembles_in_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert_eq!(assembler.push(shard(id, 3, 0, b"hel")), Ok(None));
        assert_eq!(assembler.push(shard(id, 3, 1, b"lo ")), Ok(None));

        let assembled = assembler
            .push(shard(id, 3, 2, b"world"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
        assert_eq!(assembler.inflight(), 0);
    }

    /// Arrival order does not matter; output follows shard indices.
    #[test]
    fn assembles_out_of_order() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert_eq!(assembler.push(shard(id, 2, 1, b"world")), Ok(None));

        let assembled = assembler
            .push(shard(id, 2, 0, b"hello "))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"hello world");
    }

    /// Re-sending a stored index is an error and does not block completion.
    #[test]
    fn rejects_duplicates() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();
        let first = shard(id, 2, 0, b"abc");

        assert_eq!(assembler.push(first.clone()), Ok(None));
        let err = assembler.push(first).unwrap_err();
        assert_eq!(err, FragmentError::Duplicate { index: 0, total: 2 });

        let assembled = assembler
            .push(shard(id, 2, 1, b"xyz"))
            .unwrap()
            .expect("should assemble");
        assert_eq!(assembled, b"abcxyz");
    }

    /// A shard whose total disagrees with the tracked total is rejected.
    #[test]
    fn mismatched_total_is_error() {
        let id = Uuid::new_v4();
        let mut assembler = FragmentAssembler::new();

        assert!(assembler.push(shard(id, 2, 0, &[])).is_ok());

        let err = assembler.push(shard(id, 3, 1, &[])).unwrap_err();
        assert_eq!(err, FragmentError::TotalMismatch);
    }

    /// With three incomplete payloads and a cap of two, only the first is evicted.
    #[test]
    fn retention_evicts_oldest() {
        let mut assembler = FragmentAssembler::new();
        let retention = FragmentRetention { max_inflight: 2 };

        let ids: Vec<Uuid> = std::iter::repeat_with(Uuid::new_v4).take(3).collect();
        for &id in &ids {
            let _ = assembler.push(shard(id, 2, 0, b"x")).unwrap();
        }

        assembler.enforce_retention(retention);

        for (position, id) in ids.iter().enumerate() {
            let should_remain = position != 0;
            assert_eq!(assembler.inflight.contains_key(id), should_remain);
        }
    }
}