use std::ops::ControlFlow;
use std::path::PathBuf;
use std::time::Instant;
use packet_dissector::registry::DissectorRegistry;
use serde::Deserialize;
use super::limits::ResourceLimits;
use crate::decode_as;
use crate::error::{DsctError, Result as DsctResult, ResultExt, format_error};
use crate::esp_sa;
use crate::filter::normalize_protocol_name;
use crate::stats;
fn string_or_vec<'de, D>(deserializer: D) -> std::result::Result<Vec<String>, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de;
struct StringOrVec;
impl<'de> de::Visitor<'de> for StringOrVec {
type Value = Vec<String>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("a string or an array of strings")
}
fn visit_str<E: de::Error>(self, value: &str) -> std::result::Result<Vec<String>, E> {
Ok(vec![value.to_owned()])
}
fn visit_seq<A: de::SeqAccess<'de>>(
self,
mut seq: A,
) -> std::result::Result<Vec<String>, A::Error> {
let mut v = Vec::with_capacity(seq.size_hint().unwrap_or(0));
while let Some(s) = seq.next_element()? {
v.push(s);
}
Ok(v)
}
}
deserializer.deserialize_any(StringOrVec)
}
/// Arguments accepted by [`do_get_stats`], deserialized from the caller's
/// JSON object.
///
/// Only `file` is required; every other field carries `#[serde(default)]`
/// and falls back to its `Default` value when the key is absent.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct DsctGetStatsParams {
/// Path of the capture file to read.
pub file: String,
/// Protocol names used to derive the statistics flags. May be supplied as
/// a single string or as an array of strings; each name is normalized
/// before use.
#[serde(default, deserialize_with = "string_or_vec")]
pub protocols: Vec<String>,
/// Forwarded to `stats::StatsFlags::from_protocols`.
/// NOTE(review): presumably enables the top-talkers report — confirm there.
#[serde(default)]
pub top_talkers: bool,
/// Requests TCP stream tracking. Only honoured when `protocols` is empty
/// or contains `tcp` after normalization.
#[serde(default)]
pub stream_summary: bool,
/// Value passed to the collector's `finalize`; 10 is used when omitted.
#[serde(default)]
pub top: Option<usize>,
/// Directives handed to `decode_as::parse_and_apply` to adjust the
/// dissector registry. The directive syntax is defined by that module.
#[serde(default)]
pub decode_as: Vec<String>,
/// Entries handed to `esp_sa::parse_and_apply`. The entry syntax is
/// defined by that module.
#[serde(default)]
pub esp_sa: Vec<String>,
}
/// Arguments accepted by [`do_list_fields`], deserialized from the caller's
/// JSON object.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct DsctListFieldsParams {
/// Protocol short names to restrict the listing to. May be supplied as a
/// single string or as an array of strings. When empty (the default),
/// fields of every registered protocol are listed.
#[serde(default, deserialize_with = "string_or_vec")]
pub protocols: Vec<String>,
}
/// Arguments accepted by [`do_get_schema`], deserialized from the caller's
/// JSON object.
#[derive(Debug, Clone, Deserialize)]
pub(crate) struct DsctGetSchemaParams {
/// Name of the command whose schema is requested: `"read"` or `"stats"`.
/// When omitted, `"read"` is assumed.
#[serde(default)]
pub command: Option<String>,
}
/// Parses `arguments` into [`DsctGetStatsParams`] and computes capture
/// statistics within the given resource `limits`.
///
/// # Errors
///
/// Returns `"invalid arguments: …"` when `arguments` does not deserialize,
/// otherwise the `format_error` rendering of whatever the statistics run
/// reported.
pub(crate) fn do_get_stats(
    arguments: serde_json::Value,
    limits: &ResourceLimits,
) -> std::result::Result<serde_json::Value, String> {
    match serde_json::from_value::<DsctGetStatsParams>(arguments) {
        Ok(parsed) => do_get_stats_inner(parsed, limits).map_err(|err| format_error(&err)),
        Err(err) => Err(format!("invalid arguments: {err}")),
    }
}
/// Computes statistics for the capture file named in `params`.
///
/// The work proceeds in this order:
/// 1. Stat the file and reject it when it is larger than
///    `limits.max_file_size`.
/// 2. Build a default dissector registry and apply the `decode_as` and
///    `esp_sa` directives to it.
/// 3. Derive the statistics flags from the normalized protocol names,
///    `top_talkers`, and the TCP stream-summary request.
/// 4. Walk every packet, recording its timestamp and, when dissection
///    succeeds, feeding it to the collector.
/// 5. Finalize the collector with `top` (10 when unset) and serialize the
///    result to JSON.
///
/// # Errors
///
/// Fails when the file cannot be stat'ed or opened, when it exceeds the size
/// limit, when a `decode_as` / `esp_sa` directive is rejected, when packet
/// iteration reports an error, or when the result cannot be serialized.
fn do_get_stats_inner(
    params: DsctGetStatsParams,
    limits: &ResourceLimits,
) -> DsctResult<serde_json::Value> {
    let file = PathBuf::from(&params.file);
    let file_meta =
        std::fs::metadata(&file).context(format!("failed to stat file: {}", file.display()))?;
    if file_meta.len() > limits.max_file_size {
        return Err(DsctError::msg(format!(
            "file size ({} bytes) exceeds limit ({} bytes)",
            file_meta.len(),
            limits.max_file_size
        )));
    }
    let top_n = params.top.unwrap_or(10);

    let mut registry = DissectorRegistry::default();
    decode_as::parse_and_apply(&mut registry, &params.decode_as)?;
    esp_sa::parse_and_apply(&registry, &params.esp_sa)?;

    let proto_norm: Vec<String> = params
        .protocols
        .iter()
        .map(|p| normalize_protocol_name(p))
        .collect();
    // TCP stream tracking is only enabled when requested AND the protocol
    // selection is either unrestricted or explicitly includes "tcp".
    let enable_tcp_streams =
        params.stream_summary && (proto_norm.is_empty() || proto_norm.iter().any(|p| p == "tcp"));
    let flags =
        stats::StatsFlags::from_protocols(&proto_norm, params.top_talkers, enable_tcp_streams);
    let mut collector = stats::StatsCollector::from_flags(&flags);

    let deadline = Instant::now() + limits.timeout;
    let reader = crate::input::CaptureReader::open(&file).context("failed to open capture file")?;
    let mut packets_seen: u64 = 0;
    let mut dissect_buf = packet_dissector_core::packet::DissectBuffer::new();
    reader.for_each_packet(|meta, data| {
        packets_seen += 1;
        // The clock is consulted only once every 1024 packets. On expiry the
        // walk stops early and that packet is not recorded.
        // NOTE(review): presumably `for_each_packet` returns `Ok` after a
        // `Break`, so the statistics gathered so far are still returned —
        // confirm against `CaptureReader`.
        if packets_seen.is_multiple_of(1024) && Instant::now() > deadline {
            return Ok(ControlFlow::Break(()));
        }
        // Timestamps are recorded for every packet, dissectable or not.
        collector.record_meta(meta.timestamp_secs, meta.timestamp_usecs);
        let dissect_buf = dissect_buf.clear_into();
        // Packets that fail to dissect are skipped without raising an error.
        if registry
            .dissect_with_link_type(data, meta.link_type, dissect_buf)
            .is_ok()
        {
            let packet = packet_dissector_core::packet::Packet::new(dissect_buf, data);
            collector.process_packet(
                &packet,
                meta.timestamp_secs,
                meta.timestamp_usecs,
                meta.original_length,
            );
        }
        Ok(ControlFlow::Continue(()))
    })?;

    let output = collector.finalize(top_n);
    serde_json::to_value(&output).context("failed to serialize stats")
}
/// Lists every protocol known to the default dissector registry.
///
/// Returns a JSON array with one object per protocol schema, each carrying
/// `name` (the schema's short name) and `full_name` (the schema's name).
/// This function never returns `Err`.
pub(crate) fn do_list_protocols() -> std::result::Result<serde_json::Value, String> {
    let registry = DissectorRegistry::default();
    let schemas = registry.all_field_schemas();

    let mut protocols = Vec::new();
    for schema in &schemas {
        protocols.push(serde_json::json!({
            "name": schema.short_name,
            "full_name": schema.name,
        }));
    }
    Ok(serde_json::Value::Array(protocols))
}
/// Parses `arguments` into [`DsctListFieldsParams`] and lists the field
/// descriptors of the selected protocols.
///
/// # Errors
///
/// Returns `"invalid arguments: …"` when `arguments` does not deserialize,
/// otherwise the `format_error` rendering of the underlying failure.
pub(crate) fn do_list_fields(
    arguments: serde_json::Value,
) -> std::result::Result<serde_json::Value, String> {
    match serde_json::from_value::<DsctListFieldsParams>(arguments) {
        Ok(parsed) => do_list_fields_inner(parsed).map_err(|err| format_error(&err)),
        Err(err) => Err(format!("invalid arguments: {err}")),
    }
}
/// Builds the JSON array of field descriptors for the requested protocols.
///
/// Protocol names from `params` and each schema's short name are compared
/// after normalization. An empty request selects every schema. Each field
/// descriptor of a selected schema is rendered through
/// `crate::schema::fd_to_json`, receiving the schema's short name (twice)
/// and its name.
fn do_list_fields_inner(params: DsctListFieldsParams) -> DsctResult<serde_json::Value> {
    let registry = DissectorRegistry::default();
    let schemas = registry.all_field_schemas();

    let mut wanted: Vec<String> = Vec::with_capacity(params.protocols.len());
    for requested in &params.protocols {
        wanted.push(normalize_protocol_name(requested));
    }

    let mut fields = Vec::new();
    for schema in &schemas {
        let normalized = normalize_protocol_name(schema.short_name);
        let selected = wanted.is_empty() || wanted.iter().any(|name| *name == normalized);
        if selected {
            for descriptor in schema.fields {
                let rendered = crate::schema::fd_to_json(
                    descriptor,
                    schema.short_name,
                    schema.short_name,
                    schema.name,
                );
                fields.push(rendered);
            }
        }
    }
    Ok(serde_json::Value::Array(fields))
}
/// Parses `arguments` into [`DsctGetSchemaParams`] and returns the schema of
/// the requested command.
///
/// # Errors
///
/// Returns `"invalid arguments: …"` when `arguments` does not deserialize,
/// otherwise the `format_error` rendering of the underlying failure (for
/// example an unknown command name).
pub(crate) fn do_get_schema(
    arguments: serde_json::Value,
) -> std::result::Result<serde_json::Value, String> {
    match serde_json::from_value::<DsctGetSchemaParams>(arguments) {
        Ok(parsed) => do_get_schema_inner(parsed).map_err(|err| format_error(&err)),
        Err(err) => Err(format!("invalid arguments: {err}")),
    }
}
/// Selects the schema for `params.command`.
///
/// A missing command is treated the same as `"read"`. `"stats"` yields the
/// stats schema; any other name is an invalid-argument error that lists the
/// available commands.
fn do_get_schema_inner(params: DsctGetSchemaParams) -> DsctResult<serde_json::Value> {
    match params.command.as_deref() {
        None | Some("read") => Ok(crate::schema::read_schema()),
        Some("stats") => Ok(crate::schema::stats_schema()),
        Some(other) => Err(DsctError::invalid_argument(format!(
            "unknown command '{other}'. Available: read, stats"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Borrows `value` as a JSON array, panicking when it is not one.
    fn json_array(value: &serde_json::Value) -> &Vec<serde_json::Value> {
        value.as_array().expect("should be a JSON array")
    }

    /// Asserts that `value` is a non-empty array whose entries all belong to
    /// the DNS protocol (compared case-insensitively).
    fn assert_only_dns_entries(value: &serde_json::Value) {
        let entries = json_array(value);
        assert!(!entries.is_empty());
        for entry in entries {
            assert_eq!(entry["protocol"].as_str().unwrap().to_lowercase(), "dns");
        }
    }

    // The protocol listing is a non-empty array of {name, full_name} objects.
    #[test]
    fn list_protocols_returns_json_array() {
        let listing = do_list_protocols().expect("list_protocols should succeed");
        let protocols = json_array(&listing);
        assert!(!protocols.is_empty());
        let head = &protocols[0];
        assert!(head.get("name").is_some());
        assert!(head.get("full_name").is_some());
    }

    // With no filter, the field listing is a non-empty array.
    #[test]
    fn list_fields_returns_json_array() {
        let listing = do_list_fields(serde_json::json!({})).expect("list_fields should succeed");
        assert!(!json_array(&listing).is_empty());
    }

    // An array-valued `protocols` filter restricts the listing.
    #[test]
    fn list_fields_filtered_by_protocol() {
        let args = serde_json::json!({"protocols": ["dns"]});
        let listing = do_list_fields(args).expect("list_fields should succeed");
        assert_only_dns_entries(&listing);
    }

    // A bare-string `protocols` filter is accepted as well.
    #[test]
    fn list_fields_filtered_by_protocol_single_string() {
        let args = serde_json::json!({"protocols": "dns"});
        let listing = do_list_fields(args).expect("list_fields should accept a single string");
        assert_only_dns_entries(&listing);
    }

    #[test]
    fn get_schema_read() {
        let schema = do_get_schema(serde_json::json!({"command": "read"}))
            .expect("get_schema read should succeed");
        assert_eq!(schema["title"], "dsct read packet record");
    }

    #[test]
    fn get_schema_stats() {
        let schema = do_get_schema(serde_json::json!({"command": "stats"}))
            .expect("get_schema stats should succeed");
        assert_eq!(schema["title"], "dsct stats output");
    }

    // Omitting `command` falls back to the read schema.
    #[test]
    fn get_schema_default_is_read() {
        let schema =
            do_get_schema(serde_json::json!({})).expect("get_schema default should succeed");
        assert_eq!(schema["title"], "dsct read packet record");
    }

    #[test]
    fn get_schema_unknown_returns_error() {
        let outcome = do_get_schema(serde_json::json!({"command": "nonexistent"}));
        assert!(outcome.is_err());
    }

    #[test]
    fn get_stats_missing_file_returns_error() {
        let limits = ResourceLimits::default();
        let args = serde_json::json!({
            "file": "/nonexistent/file.pcap",
        });
        let outcome = do_get_stats(args, &limits);
        assert!(outcome.is_err());
    }
}