use super::record::{PacketRecord, TransformTrace};
use super::Result;
/// A packet transform: consumes one [`PacketRecord`] and emits zero or more records
/// through a caller-supplied emitter callback.
///
/// Every method takes `&self`/`&mut self` and has no generic parameters, so the trait can
/// be used as `Box<dyn PacketTransform>`.
pub trait PacketTransform {
    /// Returns the identifier of this transform.
    fn name(&self) -> &'static str;

    /// Processes `record`, passing each output record to `emit`.
    ///
    /// Implementations may call `emit` any number of times, including zero.
    ///
    /// # Errors
    ///
    /// Returns an error if the transform fails. Errors returned by `emit` are expected to
    /// be propagated to the caller, as the transforms in this module do.
    fn transform(
        &mut self,
        record: PacketRecord,
        emit: &mut dyn FnMut(PacketRecord) -> Result<()>,
    ) -> Result<()>;

    /// Runs [`transform`](Self::transform) on `record` and collects every emitted record
    /// into a [`TransformOutput`], preserving emission order.
    ///
    /// There is deliberately no `Self: Sized` bound, so this convenience is also callable
    /// through a `dyn PacketTransform` trait object.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`transform`](Self::transform). Records buffered
    /// before the failure are dropped together with the partially filled output.
    fn transform_to_output(&mut self, record: PacketRecord) -> Result<TransformOutput> {
        let mut output = TransformOutput::new();
        self.transform(record, &mut |emitted| output.emit(emitted))?;
        Ok(output)
    }
}
/// Ordered buffer of records emitted by a [`PacketTransform`].
///
/// Records are kept in the order in which they were pushed or emitted.
#[derive(Clone, Debug, Default)]
pub struct TransformOutput {
    /// Buffered records, earliest first.
    records: Vec<PacketRecord>,
}

impl TransformOutput {
    /// Creates an empty buffer.
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
        }
    }

    /// Borrows the buffered records in insertion order.
    pub fn records(&self) -> &[PacketRecord] {
        self.records.as_slice()
    }

    /// Number of buffered records.
    pub fn len(&self) -> usize {
        self.records().len()
    }

    /// Returns `true` when no records are buffered.
    pub fn is_empty(&self) -> bool {
        self.records().is_empty()
    }

    /// Appends `record` and returns `self` so calls can be chained.
    pub fn push(&mut self, record: PacketRecord) -> &mut Self {
        self.records.push(record);
        self
    }

    /// Appends `record`. Shaped like an emitter callback
    /// (`FnMut(PacketRecord) -> Result<()>`) so it can be called directly from one.
    ///
    /// # Errors
    ///
    /// Never returns an error; the `Result` exists only to match the emitter signature.
    pub fn emit(&mut self, record: PacketRecord) -> Result<()> {
        self.push(record);
        Ok(())
    }

    /// Removes every buffered record and returns `self` for chaining.
    pub fn clear(&mut self) -> &mut Self {
        self.records.clear();
        self
    }

    /// Consumes the buffer and returns its records in insertion order.
    pub fn into_records(self) -> Vec<PacketRecord> {
        let Self { records } = self;
        records
    }
}
/// Forwards every input record to the emitter unchanged.
///
/// Keeps two counters: records received, and records the emitter accepted.
#[derive(Clone, Debug, Default)]
pub struct PassThroughTransform {
    /// Records handed to `transform`, including ones whose emit failed.
    input_count: usize,
    /// Records the emitter accepted without error.
    emitted_count: usize,
}

impl PassThroughTransform {
    /// Creates a pass-through transform with both counters at zero.
    pub fn new() -> Self {
        Self {
            input_count: 0,
            emitted_count: 0,
        }
    }

    /// Number of records handed to `transform` so far.
    pub const fn input_count(&self) -> usize {
        self.input_count
    }

    /// Number of records the emitter has accepted so far.
    pub const fn emitted_count(&self) -> usize {
        self.emitted_count
    }
}

impl PacketTransform for PassThroughTransform {
    fn name(&self) -> &'static str {
        "pass-through"
    }

    /// Emits `record` as-is, propagating any emitter error.
    fn transform(
        &mut self,
        record: PacketRecord,
        emit: &mut dyn FnMut(PacketRecord) -> Result<()>,
    ) -> Result<()> {
        self.input_count += 1;
        // The emitted counter only advances once the emitter has accepted the record.
        emit(record).map(|()| self.emitted_count += 1)
    }
}
/// Discards every input record without emitting anything.
#[derive(Clone, Debug, Default)]
pub struct DropAllTransform {
    /// Records consumed and discarded so far.
    dropped_count: usize,
}

impl DropAllTransform {
    /// Creates a drop-all transform with its counter at zero.
    pub fn new() -> Self {
        Self { dropped_count: 0 }
    }

    /// Number of records discarded so far.
    pub const fn dropped_count(&self) -> usize {
        self.dropped_count
    }
}

impl PacketTransform for DropAllTransform {
    fn name(&self) -> &'static str {
        "drop-all"
    }

    /// Counts the record as dropped. The emitter is never invoked, so this always
    /// returns `Ok(())`.
    fn transform(
        &mut self,
        _record: PacketRecord,
        _emit: &mut dyn FnMut(PacketRecord) -> Result<()>,
    ) -> Result<()> {
        self.dropped_count += 1;
        Ok(())
    }
}
/// Emits every input record twice: first a clone, then the original.
#[derive(Clone, Debug, Default)]
pub struct DuplicateTransform {
    /// Records handed to `transform` so far.
    input_count: usize,
    /// Individual copies the emitter has accepted so far.
    emitted_count: usize,
}

impl DuplicateTransform {
    /// Creates a duplicating transform with both counters at zero.
    pub fn new() -> Self {
        Self {
            input_count: 0,
            emitted_count: 0,
        }
    }

    /// Number of records handed to `transform` so far.
    pub const fn input_count(&self) -> usize {
        self.input_count
    }

    /// Number of copies the emitter has accepted so far.
    pub const fn emitted_count(&self) -> usize {
        self.emitted_count
    }
}

impl PacketTransform for DuplicateTransform {
    fn name(&self) -> &'static str {
        "duplicate"
    }

    /// Emits a clone of `record` followed by `record` itself. The first emitter error
    /// aborts the loop and is returned, and the remaining copy is not emitted.
    fn transform(
        &mut self,
        record: PacketRecord,
        emit: &mut dyn FnMut(PacketRecord) -> Result<()>,
    ) -> Result<()> {
        self.input_count += 1;
        // Clone up front so the original can be moved into the final emit.
        let copies = [record.clone(), record];
        for copy in copies {
            emit(copy)?;
            self.emitted_count += 1;
        }
        Ok(())
    }
}
/// Appends a [`TransformTrace`] entry to each record's metadata, then forwards the record.
///
/// Each trace is labelled with this transform's name. If a note was configured with
/// [`with_note`](Self::with_note), the trace also carries a copy of that note.
#[derive(Clone, Debug)]
pub struct TraceAppendTransform {
    /// Identifier reported by `name()` and written into each trace entry.
    name: &'static str,
    /// Optional note attached to every trace entry.
    note: Option<String>,
    /// Records handed to `transform` so far.
    input_count: usize,
    /// Records the emitter has accepted so far.
    emitted_count: usize,
}

impl TraceAppendTransform {
    /// Creates a transform labelled `name`, with no note and both counters at zero.
    pub const fn new(name: &'static str) -> Self {
        Self {
            input_count: 0,
            emitted_count: 0,
            note: None,
            name,
        }
    }

    /// Creates a transform labelled `"trace-append"`.
    pub const fn trace_append() -> Self {
        Self::new("trace-append")
    }

    /// Attaches `note` to every trace entry this transform appends, replacing any
    /// previously configured note.
    pub fn with_note(self, note: impl Into<String>) -> Self {
        Self {
            note: Some(note.into()),
            ..self
        }
    }

    /// Number of records handed to `transform` so far.
    pub const fn input_count(&self) -> usize {
        self.input_count
    }

    /// Number of records the emitter has accepted so far.
    pub const fn emitted_count(&self) -> usize {
        self.emitted_count
    }
}

impl Default for TraceAppendTransform {
    /// Same as [`TraceAppendTransform::trace_append`].
    fn default() -> Self {
        Self::trace_append()
    }
}

impl PacketTransform for TraceAppendTransform {
    fn name(&self) -> &'static str {
        self.name
    }

    /// Pushes a trace entry onto `record`'s metadata and emits the record, propagating
    /// any emitter error.
    fn transform(
        &mut self,
        mut record: PacketRecord,
        emit: &mut dyn FnMut(PacketRecord) -> Result<()>,
    ) -> Result<()> {
        self.input_count += 1;

        let bare = TransformTrace::new(self.name);
        let trace = match &self.note {
            // The note is cloned because the transform keeps its own copy for later records.
            Some(note) => bare.with_note(note.clone()),
            None => bare,
        };
        record.metadata_mut().push_transform_trace(trace);

        // The emitted counter only advances once the emitter has accepted the record.
        emit(record).map(|()| self.emitted_count += 1)
    }
}
#[cfg(test)]
mod tests {
    use super::super::record::{BackendKind, PacketOrigin};
    use super::super::WireError;
    use super::*;
    use crate::Raw;

    /// Builds a generated, memory-backed record on interface `lo` wrapping `payload`.
    fn record(payload: &'static str) -> PacketRecord {
        let packet = Raw::from(payload);
        PacketRecord::new(packet)
            .with_origin(PacketOrigin::Generated)
            .with_backend(BackendKind::Memory)
            .with_interface("lo")
    }

    #[test]
    fn transform_output_buffers_records_in_order() {
        let mut buffer = TransformOutput::new();
        assert!(buffer.is_empty());

        buffer.push(record("one"));
        buffer.emit(record("two")).unwrap();
        assert_eq!(buffer.len(), 2);
        for buffered in buffer.records() {
            assert_eq!(buffered.packet().summary(), "Raw(len=3)");
        }

        buffer.clear();
        assert!(buffer.is_empty());
    }

    #[test]
    fn pass_through_emits_one_record_unchanged() {
        let mut forwarder = PassThroughTransform::new();
        let output = forwarder.transform_to_output(record("payload")).unwrap();

        assert_eq!(forwarder.name(), "pass-through");
        assert_eq!(forwarder.input_count(), 1);
        assert_eq!(forwarder.emitted_count(), 1);
        assert_eq!(output.len(), 1);

        let emitted = &output.records()[0];
        assert_eq!(emitted.packet().summary(), "Raw(len=7)");
        let metadata = emitted.metadata();
        assert_eq!(metadata.origin(), PacketOrigin::Generated);
        assert_eq!(metadata.backend(), &BackendKind::Memory);
        assert_eq!(metadata.interface(), Some("lo"));
    }

    #[test]
    fn drop_all_emits_zero_records() {
        let mut dropper = DropAllTransform::new();
        let output = dropper.transform_to_output(record("payload")).unwrap();

        assert_eq!(dropper.name(), "drop-all");
        assert_eq!(dropper.dropped_count(), 1);
        assert!(output.is_empty());
    }

    #[test]
    fn duplicate_emits_two_records_per_input() {
        let mut duplicator = DuplicateTransform::new();
        let output = duplicator.transform_to_output(record("payload")).unwrap();

        assert_eq!(duplicator.name(), "duplicate");
        assert_eq!(duplicator.input_count(), 1);
        assert_eq!(duplicator.emitted_count(), 2);

        let copies = output.records();
        assert_eq!(copies.len(), 2);
        for copy in copies {
            assert_eq!(copy.packet().summary(), "Raw(len=7)");
        }
        assert_eq!(copies[0].metadata().origin(), PacketOrigin::Generated);
        assert_eq!(copies[1].metadata().backend(), &BackendKind::Memory);
    }

    #[test]
    fn trace_append_adds_transform_history() {
        let mut tracer = TraceAppendTransform::new("decode-ip").with_note("decoded");
        let output = tracer.transform_to_output(record("payload")).unwrap();

        assert_eq!(tracer.name(), "decode-ip");
        assert_eq!(tracer.input_count(), 1);
        assert_eq!(tracer.emitted_count(), 1);
        assert_eq!(output.len(), 1);

        let traced = &output.records()[0];
        let traces = traced.metadata().transforms();
        assert_eq!(traces.len(), 1);
        assert_eq!(traces[0].name(), "decode-ip");
        assert_eq!(traces[0].note(), Some("decoded"));
        assert_eq!(traced.packet().summary(), "Raw(len=7)");
    }

    #[test]
    fn packet_transform_is_object_safe() {
        let mut boxed: Box<dyn PacketTransform> = Box::new(PassThroughTransform::new());
        let mut output = TransformOutput::new();
        boxed
            .transform(record("payload"), &mut |emitted| output.emit(emitted))
            .unwrap();

        assert_eq!(boxed.name(), "pass-through");
        assert_eq!(output.len(), 1);
        assert_eq!(output.records()[0].packet().summary(), "Raw(len=7)");
    }

    #[test]
    fn transform_propagates_emitter_errors() {
        let mut duplicator = DuplicateTransform::new();
        let err = duplicator
            .transform(record("payload"), &mut |_| {
                Err(WireError::transform("collector", "closed"))
            })
            .unwrap_err();

        assert_eq!(err.to_string(), "wire transform 'collector' failed: closed");
        // The first emit failed, so the input was counted but nothing was emitted.
        assert_eq!(duplicator.input_count(), 1);
        assert_eq!(duplicator.emitted_count(), 0);
    }
}