use std::path::Path;
use lanscope_common::{FlowKey, FlowStats};
use ort::session::Session;
use ort::value::Value;
use crate::alert::{Alert, Severity};
use crate::error::{Error, Result};
use crate::features::{self, FEATURE_COUNT};
use crate::netfmt;
use super::Detector;
/// Flow detector backed by an ONNX classification model.
///
/// Every flow handed to the detector is converted into a feature vector and
/// scored by the model; a flow whose score is greater than or equal to
/// `threshold` produces a critical alert.
pub struct OnnxDetector {
    /// Minimum model score (inclusive) at which a flow is reported.
    threshold: f32,
    /// ONNX Runtime session holding the loaded model.
    session: Session,
}
impl OnnxDetector {
pub fn from_path(path: &Path) -> Result<Self> {
if !path.exists() {
return Err(Error::ModelNotFound(path.to_path_buf()));
}
let session = Session::builder()
.and_then(|mut b| b.commit_from_file(path))
.map_err(|e| Error::Config(format!("failed to load ONNX model: {e}")))?;
Ok(Self {
session,
threshold: 0.5,
})
}
fn score(&mut self, feats: &[f32; FEATURE_COUNT]) -> Option<f32> {
let input = Value::from_array(([1usize, FEATURE_COUNT], feats.to_vec())).ok()?;
let outputs = self.session.run(ort::inputs![input]).ok()?;
for (_name, value) in outputs.iter() {
if let Ok((_shape, data)) = value.try_extract_tensor::<f32>() {
if let Some(p) = data.last() {
return Some(*p);
}
}
}
None
}
}
impl Detector for OnnxDetector {
    /// Scores the flow with the ONNX model and emits a single critical
    /// `ml_malicious` alert when the score is at least the detector threshold.
    ///
    /// Returns an empty vector when the model yields no score (inference
    /// failure or no usable output) or the score is below the threshold.
    fn on_flow(&mut self, key: &FlowKey, stats: &FlowStats, now: i64) -> Vec<Alert> {
        let feature_vec = features::extract(key, stats);
        // Keep the `>=` form: it is false for a NaN score, so NaN never alerts.
        let Some(score) = self
            .score(&feature_vec)
            .filter(|&s| s >= self.threshold)
        else {
            return Vec::new();
        };

        let description = format!(
            "ML classifier flagged flow from {} → :{} (score {:.2})",
            netfmt::fmt_ipv4(key.src_ip),
            key.dst_port,
            score
        );
        vec![Alert::new(
            now,
            None,
            Severity::Critical,
            "ml_malicious",
            description,
        )]
    }
}