use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Template family a stored entry belongs to.
///
/// It is part of [`TemplateStoreKey`] alongside the numeric template ID, so the
/// same ID used by different families maps to distinct store entries.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum TemplateKind {
    /// NetFlow v9 data template.
    V9Data,
    /// NetFlow v9 options template.
    V9Options,
    /// IPFIX data template.
    IpfixData,
    /// IPFIX options template.
    IpfixOptions,
    /// Presumably a v9-format data template carried inside an IPFIX stream — TODO confirm
    /// against the IPFIX parser.
    IpfixV9Data,
    /// Presumably a v9-format options template carried inside an IPFIX stream — TODO confirm
    /// against the IPFIX parser.
    IpfixV9Options,
}
/// Lookup key for one template inside a [`TemplateStore`].
///
/// Two keys are equal only when scope, kind and template ID all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TemplateStoreKey {
    /// Caller-chosen namespace string. The tests use an exporter `address:port`
    /// such as `"1.2.3.4:2055"`; presumably one scope per exporter — confirm with callers.
    pub scope: Arc<str>,
    /// Template family (v9 / IPFIX, data / options).
    pub kind: TemplateKind,
    /// Template ID as carried in the flow packet.
    pub template_id: u16,
}

impl TemplateStoreKey {
    /// Builds a key; `scope` accepts anything convertible into `Arc<str>` (e.g. `&str`, `String`).
    pub fn new(scope: impl Into<Arc<str>>, kind: TemplateKind, template_id: u16) -> Self {
        let scope: Arc<str> = scope.into();
        Self {
            template_id,
            kind,
            scope,
        }
    }
}
/// Failure reported by a [`TemplateStore`] backend or by the template wire codec.
#[derive(Debug)]
pub enum TemplateStoreError {
    /// The storage backend failed; the wrapped error is exposed through `source()`.
    Backend(Box<dyn std::error::Error + Send + Sync>),
    /// An encoded template payload was malformed; the string explains what was wrong.
    Codec(String),
}

impl std::fmt::Display for TemplateStoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Backend(cause) => write!(f, "template store backend error: {}", cause),
            Self::Codec(detail) => write!(f, "template store codec error: {}", detail),
        }
    }
}

impl std::error::Error for TemplateStoreError {
    /// Only backend failures carry an underlying error; codec failures are leaf errors.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Backend(cause) => Some(&**cause),
            Self::Codec(..) => None,
        }
    }
}
/// Pluggable byte-oriented storage for encoded templates.
///
/// Values are opaque byte payloads; in this module they are presumably produced by the
/// `encode_*` functions and read back with the matching `decode_*` — confirm with callers.
/// Implementors must be `Send + Sync` so a store can be shared across threads.
pub trait TemplateStore: Send + Sync + std::fmt::Debug {
    /// Fetches the payload stored under `key`.
    ///
    /// [`InMemoryTemplateStore`] returns `Ok(None)` when nothing is stored for `key`.
    fn get(&self, key: &TemplateStoreKey) -> Result<Option<Vec<u8>>, TemplateStoreError>;

    /// Stores `value` under `key`.
    ///
    /// [`InMemoryTemplateStore`] replaces any payload already stored for `key`.
    fn put(&self, key: &TemplateStoreKey, value: &[u8]) -> Result<(), TemplateStoreError>;

    /// Deletes the payload stored under `key`.
    ///
    /// [`InMemoryTemplateStore`] returns `Ok(())` even when `key` was absent.
    fn remove(&self, key: &TemplateStoreKey) -> Result<(), TemplateStoreError>;
}
/// Process-local [`TemplateStore`] backed by a mutex-guarded `HashMap`.
///
/// Entries live only as long as the store value itself; nothing is persisted.
#[derive(Debug, Default)]
pub struct InMemoryTemplateStore {
    inner: Mutex<HashMap<TemplateStoreKey, Vec<u8>>>,
}

impl InMemoryTemplateStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of payloads currently stored.
    pub fn len(&self) -> usize {
        self.lock_map().len()
    }

    /// `true` when no payload is stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Locks the map, recovering the guard if a previous holder panicked.
    ///
    /// All locking in this type goes through this helper. Each guard is held for a
    /// single `HashMap` call, with keys and values built before the lock is taken.
    /// A poisoned lock therefore cannot hide a half-applied update. Recovering keeps one
    /// panicking thread from turning every later store call into a panic.
    fn lock_map(&self) -> std::sync::MutexGuard<'_, HashMap<TemplateStoreKey, Vec<u8>>> {
        self.inner
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}

impl TemplateStore for InMemoryTemplateStore {
    fn get(&self, key: &TemplateStoreKey) -> Result<Option<Vec<u8>>, TemplateStoreError> {
        Ok(self.lock_map().get(key).cloned())
    }

    fn put(&self, key: &TemplateStoreKey, value: &[u8]) -> Result<(), TemplateStoreError> {
        // Allocate the owned key/value before locking to keep the critical section minimal.
        let owned_key = key.clone();
        let owned_value = value.to_vec();
        self.lock_map().insert(owned_key, owned_value);
        Ok(())
    }

    fn remove(&self, key: &TemplateStoreKey) -> Result<(), TemplateStoreError> {
        // Removing an absent key is not an error.
        self.lock_map().remove(key);
        Ok(())
    }
}
/// Leading byte written by every `encode_*` function; `WireReader::expect_version`
/// rejects payloads that start with any other value.
pub(crate) const WIRE_VERSION: u8 = 0x01;
use crate::variable_versions::ipfix::lookup::IPFixField;
use crate::variable_versions::ipfix::{
OptionsTemplate as IpfixOptionsTemplate, Template as IpfixTemplate,
TemplateField as IpfixTemplateField,
};
use crate::variable_versions::v9::lookup::{ScopeFieldType, V9Field};
use crate::variable_versions::v9::{
OptionsTemplate as V9OptionsTemplate, OptionsTemplateScopeField, Template as V9Template,
TemplateField as V9TemplateField,
};
pub(crate) fn encode_v9_template(t: &V9Template) -> Vec<u8> {
let mut out = Vec::with_capacity(5 + t.fields.len() * 4);
out.push(WIRE_VERSION);
out.extend_from_slice(&t.template_id.to_be_bytes());
out.extend_from_slice(&t.field_count.to_be_bytes());
for f in &t.fields {
out.extend_from_slice(&f.field_type_number.to_be_bytes());
out.extend_from_slice(&f.field_length.to_be_bytes());
}
out
}
pub(crate) fn decode_v9_template(bytes: &[u8]) -> Result<V9Template, TemplateStoreError> {
let mut r = WireReader::new(bytes);
r.expect_version()?;
let template_id = r.u16()?;
let field_count = r.u16()?;
let mut fields = Vec::with_capacity(usize::from(field_count));
for _ in 0..field_count {
let field_type_number = r.u16()?;
let field_length = r.u16()?;
fields.push(V9TemplateField {
field_type_number,
field_type: V9Field::from(field_type_number),
field_length,
});
}
Ok(V9Template {
template_id,
field_count,
fields,
})
}
/// Serializes a v9 options template.
///
/// Layout: version byte, then big-endian `template_id`, `options_scope_length` and
/// `options_length`, then a 4-byte `(type, length)` entry for every scope field
/// followed by every option field. Field counts are not written separately; the
/// decoder derives them from the two length values.
pub(crate) fn encode_v9_options_template(t: &V9OptionsTemplate) -> Vec<u8> {
    let entry_total = t.scope_fields.len() + t.option_fields.len();
    let mut buf = Vec::with_capacity(7 + entry_total * 4);
    buf.push(WIRE_VERSION);
    for word in [t.template_id, t.options_scope_length, t.options_length] {
        buf.extend_from_slice(&word.to_be_bytes());
    }
    let scope_pairs = t
        .scope_fields
        .iter()
        .map(|s| (s.field_type_number, s.field_length));
    let option_pairs = t
        .option_fields
        .iter()
        .map(|o| (o.field_type_number, o.field_length));
    for (type_number, length) in scope_pairs.chain(option_pairs) {
        buf.extend_from_slice(&type_number.to_be_bytes());
        buf.extend_from_slice(&length.to_be_bytes());
    }
    buf
}
/// Decodes a payload produced by [`encode_v9_options_template`].
///
/// Scope and option field counts are derived from `options_scope_length / 4` and
/// `options_length / 4`, since every entry occupies 4 bytes.
///
/// # Errors
/// Returns [`TemplateStoreError::Codec`] when:
/// - the version byte is wrong;
/// - either length is not a multiple of 4;
/// - the payload is truncated;
/// - bytes remain after the last option field.
pub(crate) fn decode_v9_options_template(
    bytes: &[u8],
) -> Result<V9OptionsTemplate, TemplateStoreError> {
    let mut r = WireReader::new(bytes);
    r.expect_version()?;
    let template_id = r.u16()?;
    let options_scope_length = r.u16()?;
    let options_length = r.u16()?;
    if !options_scope_length.is_multiple_of(4) || !options_length.is_multiple_of(4) {
        return Err(TemplateStoreError::Codec(format!(
            "v9 options template length not aligned to 4: scope={} options={}",
            options_scope_length, options_length
        )));
    }
    let scope_count = usize::from(options_scope_length / 4);
    let option_count = usize::from(options_length / 4);
    let mut scope_fields = Vec::with_capacity(scope_count);
    for _ in 0..scope_count {
        let field_type_number = r.u16()?;
        let field_length = r.u16()?;
        scope_fields.push(OptionsTemplateScopeField {
            field_type_number,
            field_type: ScopeFieldType::from(field_type_number),
            field_length,
        });
    }
    let mut option_fields = Vec::with_capacity(option_count);
    for _ in 0..option_count {
        let field_type_number = r.u16()?;
        let field_length = r.u16()?;
        option_fields.push(V9TemplateField {
            field_type_number,
            field_type: V9Field::from(field_type_number),
            field_length,
        });
    }
    // The two lengths fully determine the payload size; anything left over means the
    // stored bytes are corrupted or belong to another template kind.
    let unread = bytes.len() - r.pos;
    if unread != 0 {
        return Err(TemplateStoreError::Codec(format!(
            "trailing bytes after v9 options template: {} unread",
            unread
        )));
    }
    Ok(V9OptionsTemplate {
        template_id,
        options_scope_length,
        options_length,
        scope_fields,
        option_fields,
    })
}
/// Serializes an IPFIX data template.
///
/// Layout: version byte, big-endian `template_id` and `field_count`, then each field
/// as written by `encode_ipfix_field` (5 bytes, or 9 with an enterprise number).
pub(crate) fn encode_ipfix_template(t: &IpfixTemplate) -> Vec<u8> {
    let mut buf = Vec::with_capacity(5 + 9 * t.fields.len());
    buf.push(WIRE_VERSION);
    for word in [t.template_id, t.field_count] {
        buf.extend_from_slice(&word.to_be_bytes());
    }
    t.fields
        .iter()
        .for_each(|field| encode_ipfix_field(&mut buf, field));
    buf
}
pub(crate) fn decode_ipfix_template(bytes: &[u8]) -> Result<IpfixTemplate, TemplateStoreError> {
let mut r = WireReader::new(bytes);
r.expect_version()?;
let template_id = r.u16()?;
let field_count = r.u16()?;
let mut fields = Vec::with_capacity(usize::from(field_count));
for _ in 0..field_count {
fields.push(decode_ipfix_field(&mut r)?);
}
Ok(IpfixTemplate {
template_id,
field_count,
fields,
})
}
/// Serializes an IPFIX options template.
///
/// Layout: version byte, big-endian `template_id`, `field_count` and
/// `scope_field_count`, then every entry of `fields` as written by `encode_ipfix_field`.
pub(crate) fn encode_ipfix_options_template(t: &IpfixOptionsTemplate) -> Vec<u8> {
    let mut buf = Vec::with_capacity(7 + 9 * t.fields.len());
    buf.push(WIRE_VERSION);
    for word in [t.template_id, t.field_count, t.scope_field_count] {
        buf.extend_from_slice(&word.to_be_bytes());
    }
    for field in &t.fields {
        encode_ipfix_field(&mut buf, field);
    }
    buf
}
pub(crate) fn decode_ipfix_options_template(
bytes: &[u8],
) -> Result<IpfixOptionsTemplate, TemplateStoreError> {
let mut r = WireReader::new(bytes);
r.expect_version()?;
let template_id = r.u16()?;
let field_count = r.u16()?;
let scope_field_count = r.u16()?;
let mut fields = Vec::with_capacity(usize::from(field_count));
for _ in 0..field_count {
fields.push(decode_ipfix_field(&mut r)?);
}
Ok(IpfixOptionsTemplate {
template_id,
field_count,
scope_field_count,
fields,
})
}
/// Appends one IPFIX field entry to `buf`.
///
/// Layout: big-endian type number and length, then a flag byte: `0` for no enterprise
/// number, or `1` followed by the big-endian `u32` enterprise number.
fn encode_ipfix_field(buf: &mut Vec<u8>, field: &IpfixTemplateField) {
    buf.extend_from_slice(&field.field_type_number.to_be_bytes());
    buf.extend_from_slice(&field.field_length.to_be_bytes());
    if let Some(enterprise) = field.enterprise_number {
        buf.push(1);
        buf.extend_from_slice(&enterprise.to_be_bytes());
    } else {
        buf.push(0);
    }
}
/// Reads one IPFIX field entry in the layout written by `encode_ipfix_field`.
///
/// The flag byte must be `0` (no enterprise number) or `1` (a `u32` follows);
/// any other value is a [`TemplateStoreError::Codec`].
fn decode_ipfix_field(
    reader: &mut WireReader<'_>,
) -> Result<IpfixTemplateField, TemplateStoreError> {
    let field_type_number = reader.u16()?;
    let field_length = reader.u16()?;
    let flag = reader.u8()?;
    let enterprise_number = if flag == 0 {
        None
    } else if flag == 1 {
        Some(reader.u32()?)
    } else {
        return Err(TemplateStoreError::Codec(format!(
            "invalid enterprise flag: {}",
            flag
        )));
    };
    let field_type = IPFixField::new(field_type_number, enterprise_number);
    Ok(IpfixTemplateField {
        field_type_number,
        field_length,
        enterprise_number,
        field_type,
    })
}
/// Forward-only cursor over an encoded template payload.
///
/// Integers are read big-endian. Running past the end of the buffer yields a
/// [`TemplateStoreError::Codec`] and leaves the cursor where it was.
struct WireReader<'a> {
    /// The complete payload being decoded.
    buf: &'a [u8],
    /// Offset of the next unread byte; only advanced after a successful read.
    pos: usize,
}

impl<'a> WireReader<'a> {
    /// Positions a new cursor at the first byte of `buf`.
    fn new(buf: &'a [u8]) -> Self {
        WireReader { buf, pos: 0 }
    }

    /// Consumes the leading version byte and fails unless it equals [`WIRE_VERSION`].
    fn expect_version(&mut self) -> Result<(), TemplateStoreError> {
        match self.u8()? {
            WIRE_VERSION => Ok(()),
            found => Err(TemplateStoreError::Codec(format!(
                "unsupported wire version: {} (expected {})",
                found, WIRE_VERSION
            ))),
        }
    }

    fn u8(&mut self) -> Result<u8, TemplateStoreError> {
        self.take(1).map(|one| one[0])
    }

    fn u16(&mut self) -> Result<u16, TemplateStoreError> {
        let mut raw = [0u8; 2];
        raw.copy_from_slice(self.take(2)?);
        Ok(u16::from_be_bytes(raw))
    }

    fn u32(&mut self) -> Result<u32, TemplateStoreError> {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(self.take(4)?);
        Ok(u32::from_be_bytes(raw))
    }

    /// Returns the next `n` bytes and advances past them, or errors without advancing.
    fn take(&mut self, n: usize) -> Result<&'a [u8], TemplateStoreError> {
        let start = self.pos;
        let end = start + n;
        let chunk = self.buf.get(start..end).ok_or_else(|| {
            TemplateStoreError::Codec(format!(
                "unexpected end of payload at offset {} (need {} more)",
                start, n
            ))
        })?;
        self.pos = end;
        Ok(chunk)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn in_memory_round_trip() {
        let store = InMemoryTemplateStore::new();
        let key = TemplateStoreKey::new("1.2.3.4:2055", TemplateKind::V9Data, 256);
        assert!(store.get(&key).unwrap().is_none());
        store.put(&key, b"hello").unwrap();
        assert_eq!(store.get(&key).unwrap().as_deref(), Some(&b"hello"[..]));
        assert_eq!(store.len(), 1);
        store.remove(&key).unwrap();
        assert!(store.get(&key).unwrap().is_none());
        assert!(store.is_empty());
    }

    #[test]
    fn in_memory_put_overwrites() {
        let store = InMemoryTemplateStore::new();
        let key = TemplateStoreKey::new("1.2.3.4:2055", TemplateKind::IpfixData, 300);
        store.put(&key, b"first").unwrap();
        store.put(&key, b"second").unwrap();
        assert_eq!(store.get(&key).unwrap().as_deref(), Some(&b"second"[..]));
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn v9_template_wire_round_trip() {
        let original = V9Template {
            template_id: 256,
            field_count: 3,
            fields: vec![
                V9TemplateField {
                    field_type_number: 8,
                    field_type: V9Field::from(8),
                    field_length: 4,
                },
                V9TemplateField {
                    field_type_number: 12,
                    field_type: V9Field::from(12),
                    field_length: 4,
                },
                V9TemplateField {
                    field_type_number: 1,
                    field_type: V9Field::from(1),
                    field_length: 8,
                },
            ],
        };
        let bytes = encode_v9_template(&original);
        let decoded = decode_v9_template(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn v9_options_template_wire_round_trip() {
        let original = V9OptionsTemplate {
            template_id: 257,
            options_scope_length: 4,
            options_length: 8,
            scope_fields: vec![OptionsTemplateScopeField {
                field_type_number: 1,
                field_type: ScopeFieldType::from(1u16),
                field_length: 4,
            }],
            option_fields: vec![
                V9TemplateField {
                    field_type_number: 34,
                    field_type: V9Field::from(34u16),
                    field_length: 4,
                },
                V9TemplateField {
                    field_type_number: 36,
                    field_type: V9Field::from(36u16),
                    field_length: 2,
                },
            ],
        };
        let bytes = encode_v9_options_template(&original);
        let decoded = decode_v9_options_template(&bytes).unwrap();
        assert_eq!(decoded.template_id, 257);
        assert_eq!(decoded.options_scope_length, 4);
        assert_eq!(decoded.options_length, 8);
        let scope: Vec<(u16, u16)> = decoded
            .scope_fields
            .iter()
            .map(|f| (f.field_type_number, f.field_length))
            .collect();
        assert_eq!(scope, vec![(1, 4)]);
        let options: Vec<(u16, u16)> = decoded
            .option_fields
            .iter()
            .map(|f| (f.field_type_number, f.field_length))
            .collect();
        assert_eq!(options, vec![(34, 4), (36, 2)]);
    }

    #[test]
    fn ipfix_template_wire_round_trip_with_enterprise() {
        let original = IpfixTemplate {
            template_id: 300,
            field_count: 2,
            fields: vec![
                IpfixTemplateField {
                    field_type_number: 8,
                    field_length: 4,
                    enterprise_number: None,
                    field_type: IPFixField::new(8, None),
                },
                IpfixTemplateField {
                    field_type_number: 1,
                    field_length: 8,
                    enterprise_number: Some(9),
                    field_type: IPFixField::new(1, Some(9)),
                },
            ],
        };
        let bytes = encode_ipfix_template(&original);
        let decoded = decode_ipfix_template(&bytes).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn ipfix_options_template_wire_round_trip() {
        let original = IpfixOptionsTemplate {
            template_id: 400,
            field_count: 2,
            scope_field_count: 1,
            fields: vec![
                IpfixTemplateField {
                    field_type_number: 149,
                    field_length: 4,
                    enterprise_number: None,
                    field_type: IPFixField::new(149, None),
                },
                IpfixTemplateField {
                    field_type_number: 5,
                    field_length: 2,
                    enterprise_number: Some(29305),
                    field_type: IPFixField::new(5, Some(29305)),
                },
            ],
        };
        let bytes = encode_ipfix_options_template(&original);
        let decoded = decode_ipfix_options_template(&bytes).unwrap();
        assert_eq!(decoded.template_id, 400);
        assert_eq!(decoded.field_count, 2);
        assert_eq!(decoded.scope_field_count, 1);
        let fields: Vec<(u16, u16, Option<u32>)> = decoded
            .fields
            .iter()
            .map(|f| (f.field_type_number, f.field_length, f.enterprise_number))
            .collect();
        assert_eq!(fields, vec![(149, 4, None), (5, 2, Some(29305))]);
    }

    #[test]
    fn rejects_bad_version() {
        let bad = [99u8, 0, 0, 0, 0];
        let err = decode_v9_template(&bad).unwrap_err();
        assert!(matches!(err, TemplateStoreError::Codec(_)));
    }

    #[test]
    fn rejects_truncated_payload() {
        let bytes = [WIRE_VERSION, 0];
        let err = decode_v9_template(&bytes).unwrap_err();
        assert!(matches!(err, TemplateStoreError::Codec(_)));
    }

    #[test]
    fn rejects_invalid_enterprise_flag() {
        // version, template 300, one field: type 8, length 4, flag 2 (only 0/1 are valid)
        let bytes = [WIRE_VERSION, 0x01, 0x2C, 0x00, 0x01, 0x00, 0x08, 0x00, 0x04, 0x02];
        let err = decode_ipfix_template(&bytes).unwrap_err();
        assert!(matches!(err, TemplateStoreError::Codec(_)));
    }

    #[test]
    fn rejects_misaligned_v9_options_lengths() {
        // options_scope_length = 6 is not a multiple of the 4-byte entry size
        let bytes = [WIRE_VERSION, 0x01, 0x00, 0x00, 0x06, 0x00, 0x04];
        let err = decode_v9_options_template(&bytes).unwrap_err();
        assert!(matches!(err, TemplateStoreError::Codec(_)));
    }
}