use std::path::{Path, PathBuf};
use std::sync::Arc;
use rsigma_eval::event::Event;
use rsigma_eval::{
CorrelationConfig, CorrelationEngine, CorrelationSnapshot, Engine, Pipeline, ProcessResult,
parse_pipeline_file,
};
use rsigma_parser::SigmaCollection;
use crate::sources::{self, SourceResolver, TemplateExpander};
/// Rule engine wrapper that remembers how to rebuild itself.
///
/// Holds the active engine plus the rule path, pipelines, and options that
/// `load_rules` applies each time it compiles a fresh engine.
pub struct RuntimeEngine {
/// Active engine; replaced by `load_rules` after a successful compile.
engine: EngineVariant,
/// Pipelines added to every engine built by `load_rules`.
pipelines: Vec<Pipeline>,
/// Pipeline files re-parsed on each `load_rules` call; when empty, `pipelines` is kept as is.
pipeline_paths: Vec<PathBuf>,
/// File or directory the rules are parsed from.
rules_path: std::path::PathBuf,
/// Configuration passed to `CorrelationEngine::new`.
corr_config: CorrelationConfig,
/// Forwarded to the engine's `set_include_event`.
include_event: bool,
/// Resolver used for dynamic pipelines; `None` disables dynamic resolution.
source_resolver: Option<Arc<dyn SourceResolver>>,
/// Passed to `sources::include::expand_includes` when expanding dynamic pipelines.
allow_remote_include: bool,
/// Forwarded to the engine's `set_bloom_prefilter`.
bloom_prefilter: bool,
/// Forwarded to the engine's `set_bloom_max_bytes` when set.
bloom_max_bytes: Option<usize>,
/// Forwarded to the engine's `set_cross_rule_ac`; only present with the `daachorse-index` feature.
#[cfg(feature = "daachorse-index")]
cross_rule_ac: bool,
}
/// The two engine flavours a `RuntimeEngine` can hold.
enum EngineVariant {
/// Plain detection engine: the initial state, and the result of loading a
/// collection without correlation rules.
DetectionOnly(Box<Engine>),
/// Correlation engine, built when the loaded collection contains correlation rules.
WithCorrelations(Box<CorrelationEngine>),
}
/// Rule and state counts reported for the currently active engine.
///
/// Plain data: cheap to copy, comparable, and printable for logs and tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EngineStats {
    /// Number of detection rules in the active engine.
    pub detection_rules: usize,
    /// Number of correlation rules; `0` for a detection-only engine.
    pub correlation_rules: usize,
    /// Number of correlation state entries; `0` for a detection-only engine.
    pub state_entries: usize,
}
impl RuntimeEngine {
pub fn new(
rules_path: std::path::PathBuf,
pipelines: Vec<Pipeline>,
corr_config: CorrelationConfig,
include_event: bool,
) -> Self {
RuntimeEngine {
engine: EngineVariant::DetectionOnly(Box::new(Engine::new())),
pipelines,
pipeline_paths: Vec::new(),
rules_path,
corr_config,
include_event,
source_resolver: None,
allow_remote_include: false,
bloom_prefilter: false,
bloom_max_bytes: None,
#[cfg(feature = "daachorse-index")]
cross_rule_ac: false,
}
}
/// Stores the bloom prefilter flag.
///
/// The value is forwarded to the engine built by the next `load_rules` call;
/// the currently active engine is not touched here.
pub fn set_bloom_prefilter(&mut self, enabled: bool) {
self.bloom_prefilter = enabled;
}
/// Stores a bloom byte budget.
///
/// The value is forwarded to the engine built by the next `load_rules` call.
/// Until this is called, no budget is passed to the engine.
pub fn set_bloom_max_bytes(&mut self, max_bytes: usize) {
self.bloom_max_bytes = Some(max_bytes);
}
/// Stores the cross-rule AC flag (only compiled with the `daachorse-index` feature).
///
/// The value is forwarded to the engine built by the next `load_rules` call.
#[cfg(feature = "daachorse-index")]
pub fn set_cross_rule_ac(&mut self, enabled: bool) {
self.cross_rule_ac = enabled;
}
/// Installs the resolver used to resolve dynamic pipeline sources.
///
/// Replaces any previously installed resolver.
pub fn set_source_resolver(&mut self, resolver: Arc<dyn SourceResolver>) {
self.source_resolver = Some(resolver);
}
/// Returns the installed source resolver, or `None` when none has been set.
pub fn source_resolver(&self) -> Option<&Arc<dyn SourceResolver>> {
self.source_resolver.as_ref()
}
/// Stores the flag that is passed to `sources::include::expand_includes`
/// when dynamic pipelines are expanded.
pub fn set_allow_remote_include(&mut self, allow: bool) {
self.allow_remote_include = allow;
}
/// Returns the stored remote-include flag.
pub fn allow_remote_include(&self) -> bool {
self.allow_remote_include
}
/// Registers the pipeline files to re-parse on reload.
///
/// When the list is non-empty, every `load_rules` call re-parses these files
/// and replaces the current pipelines with the result.
pub fn set_pipeline_paths(&mut self, paths: Vec<PathBuf>) {
self.pipeline_paths = paths;
}
/// Returns the registered pipeline file paths (empty when none were set).
pub fn pipeline_paths(&self) -> &[PathBuf] {
&self.pipeline_paths
}
pub async fn resolve_dynamic_pipelines(&mut self) -> Result<(), String> {
let Some(resolver) = &self.source_resolver else {
return Ok(());
};
let mut resolved_pipelines = Vec::with_capacity(self.pipelines.len());
for pipeline in &self.pipelines {
if pipeline.is_dynamic() {
match sources::resolve_all(resolver.as_ref(), &pipeline.sources).await {
Ok(resolved_data) => {
let mut expanded = TemplateExpander::expand(pipeline, &resolved_data);
sources::include::expand_includes(
&mut expanded,
&resolved_data,
self.allow_remote_include,
)?;
resolved_pipelines.push(expanded);
}
Err(e) => {
return Err(format!(
"Failed to resolve dynamic pipeline '{}': {e}",
pipeline.name
));
}
}
} else {
resolved_pipelines.push(pipeline.clone());
}
}
self.pipelines = resolved_pipelines;
Ok(())
}
/// Reloads pipelines and rules and swaps in a freshly compiled engine.
///
/// Steps, in order:
/// 1. When pipeline paths are registered, the pipeline files are re-parsed.
/// 2. When a resolver is installed and any pipeline is dynamic, dynamic
///    sources are re-resolved (best effort, see `refresh_dynamic_pipelines`).
/// 3. The rules at `rules_path` are parsed and compiled into a correlation
///    engine when the collection contains correlation rules, otherwise into
///    a detection-only engine.
/// 4. Correlation state exported from the previous engine is imported into a
///    new correlation engine; it is dropped when the new engine is
///    detection-only.
///
/// The active engine is only replaced after compilation succeeds. Note that
/// `self.pipelines` may already have been updated by steps 1-2 when a later
/// step fails.
///
/// # Errors
///
/// Returns a message when a pipeline file cannot be re-parsed, the rules
/// cannot be parsed, or the rules fail to compile.
pub fn load_rules(&mut self) -> Result<EngineStats, String> {
    let load_span = tracing::info_span!("load_rules", rules_path = %self.rules_path.display());
    let _enter = load_span.enter();
    let load_start = std::time::Instant::now();

    if !self.pipeline_paths.is_empty() {
        self.pipelines = reload_pipelines(&self.pipeline_paths)?;
    }
    self.refresh_dynamic_pipelines();

    // Snapshot the old engine's state before anything is replaced.
    let previous_state = self.export_state();
    let collection = load_collection(&self.rules_path)?;

    let rebuilt = if collection.correlations.is_empty() {
        let mut engine = Engine::new();
        engine.set_include_event(self.include_event);
        if let Some(budget) = self.bloom_max_bytes {
            engine.set_bloom_max_bytes(budget);
        }
        engine.set_bloom_prefilter(self.bloom_prefilter);
        #[cfg(feature = "daachorse-index")]
        engine.set_cross_rule_ac(self.cross_rule_ac);
        for p in &self.pipelines {
            engine.add_pipeline(p.clone());
        }
        engine
            .add_collection(&collection)
            .map_err(|e| format!("Error compiling rules: {e}"))?;
        EngineVariant::DetectionOnly(Box::new(engine))
    } else {
        let mut engine = CorrelationEngine::new(self.corr_config.clone());
        engine.set_include_event(self.include_event);
        if let Some(budget) = self.bloom_max_bytes {
            engine.set_bloom_max_bytes(budget);
        }
        engine.set_bloom_prefilter(self.bloom_prefilter);
        #[cfg(feature = "daachorse-index")]
        engine.set_cross_rule_ac(self.cross_rule_ac);
        for p in &self.pipelines {
            engine.add_pipeline(p.clone());
        }
        engine
            .add_collection(&collection)
            .map_err(|e| format!("Error compiling rules: {e}"))?;
        // Surface a rejected restore instead of silently dropping the result.
        let restored = previous_state.map(|snapshot| engine.import_state(snapshot));
        if restored == Some(false) {
            tracing::warn!("Failed to restore correlation state after rule reload");
        }
        EngineVariant::WithCorrelations(Box::new(engine))
    };

    self.engine = rebuilt;
    let stats = self.stats();
    tracing::debug!(
        detection_rules = stats.detection_rules,
        correlation_rules = stats.correlation_rules,
        duration_ms = load_start.elapsed().as_millis() as u64,
        "Rule load complete",
    );
    Ok(stats)
}

/// Re-resolves dynamic pipeline sources from the synchronous reload path.
///
/// Best effort: when no tokio runtime is available or resolution fails, a
/// warning is logged and the unresolved pipelines are kept. Does nothing when
/// no resolver is installed or no pipeline is dynamic.
///
/// NOTE(review): `tokio::task::block_in_place` is documented to panic on a
/// current-thread runtime, so callers are assumed to run on a multi-threaded
/// runtime — confirm.
fn refresh_dynamic_pipelines(&mut self) {
    let Some(resolver) = self.source_resolver.clone() else {
        return;
    };
    if !self.pipelines.iter().any(|p| p.is_dynamic()) {
        return;
    }
    let Ok(handle) = tokio::runtime::Handle::try_current() else {
        tracing::warn!("No tokio runtime available for dynamic source resolution");
        return;
    };

    let allow_remote = self.allow_remote_include;
    // Borrow rather than move the pipelines out, so they stay in place
    // whatever happens inside the blocking section.
    let pipelines = &self.pipelines;
    let resolved = tokio::task::block_in_place(|| {
        handle.block_on(resolve_pipelines_async(&resolver, pipelines, allow_remote))
    });

    match resolved {
        Ok(pipelines) => {
            self.pipelines = pipelines;
        }
        Err(e) => {
            tracing::warn!(error = %e, "Dynamic source resolution failed, using unresolved pipelines");
        }
    }
}
/// Runs a batch of events through the active engine.
///
/// A correlation engine returns its own results directly. For a
/// detection-only engine, each entry of the detection batch is wrapped in a
/// `ProcessResult` with an empty correlation list.
pub fn process_batch<E: Event + Sync>(&mut self, events: &[&E]) -> Vec<ProcessResult> {
    let detector = match &mut self.engine {
        EngineVariant::WithCorrelations(correlator) => return correlator.process_batch(events),
        EngineVariant::DetectionOnly(detector) => detector,
    };

    detector
        .evaluate_batch(events)
        .into_iter()
        .map(|detections| ProcessResult {
            detections,
            correlations: Vec::new(),
        })
        .collect()
}
/// Reports rule and state counts for the active engine.
///
/// A detection-only engine reports zero correlation rules and zero state
/// entries.
pub fn stats(&self) -> EngineStats {
    let (detection_rules, correlation_rules, state_entries) = match &self.engine {
        EngineVariant::WithCorrelations(correlator) => (
            correlator.detection_rule_count(),
            correlator.correlation_rule_count(),
            correlator.state_count(),
        ),
        EngineVariant::DetectionOnly(detector) => (detector.rule_count(), 0, 0),
    };

    EngineStats {
        detection_rules,
        correlation_rules,
        state_entries,
    }
}
/// Returns the file or directory path the rules are loaded from.
pub fn rules_path(&self) -> &Path {
&self.rules_path
}
/// Returns the pipelines currently held, i.e. those applied by the next `load_rules` call
/// unless that call reloads or re-resolves them first.
pub fn pipelines(&self) -> &[Pipeline] {
&self.pipelines
}
/// Returns the correlation configuration used when a correlation engine is built.
pub fn corr_config(&self) -> &CorrelationConfig {
&self.corr_config
}
/// Returns the `include_event` flag that is forwarded to engines built by `load_rules`.
pub fn include_event(&self) -> bool {
self.include_event
}
/// Exports the correlation state of the active engine.
///
/// Returns `None` when the active engine is detection-only.
pub fn export_state(&self) -> Option<CorrelationSnapshot> {
    if let EngineVariant::WithCorrelations(correlator) = &self.engine {
        Some(correlator.export_state())
    } else {
        None
    }
}
/// Imports a copy of `snapshot` into the active correlation engine.
///
/// Returns whatever the correlation engine's `import_state` reports. For a
/// detection-only engine nothing is imported and `true` is returned.
pub fn import_state(&mut self, snapshot: &CorrelationSnapshot) -> bool {
    match &mut self.engine {
        // The engine takes the snapshot by value, hence the clone.
        EngineVariant::WithCorrelations(correlator) => correlator.import_state(snapshot.clone()),
        EngineVariant::DetectionOnly(_) => true,
    }
}
}
fn load_collection(path: &Path) -> Result<SigmaCollection, String> {
let collection = if path.is_dir() {
rsigma_parser::parse_sigma_directory(path)
.map_err(|e| format!("Error loading rules from {}: {e}", path.display()))?
} else {
rsigma_parser::parse_sigma_file(path)
.map_err(|e| format!("Error loading rule {}: {e}", path.display()))?
};
if !collection.errors.is_empty() {
tracing::warn!(
count = collection.errors.len(),
"Parse errors while loading rules"
);
for (i, err) in collection.errors.iter().take(3).enumerate() {
tracing::warn!(index = i + 1, error = %err, "Rule parse error detail");
}
}
Ok(collection)
}
/// Re-parses every pipeline file in `paths` and returns the pipelines ordered
/// by ascending priority.
///
/// The sort is stable, so pipelines with equal priority keep the order of
/// `paths`.
///
/// # Errors
///
/// Stops at the first file that fails to parse and returns a message naming it.
fn reload_pipelines(paths: &[PathBuf]) -> Result<Vec<Pipeline>, String> {
    let mut loaded = paths
        .iter()
        .map(|path| {
            parse_pipeline_file(path)
                .map_err(|e| format!("Error reloading pipeline {}: {e}", path.display()))
        })
        .collect::<Result<Vec<_>, _>>()?;

    loaded.sort_by_key(|pipeline| pipeline.priority);
    Ok(loaded)
}
/// Produces a resolved copy of `pipelines`, preserving their order.
///
/// Non-dynamic pipelines are cloned as they are. For each dynamic pipeline its
/// sources are resolved through `resolver`, the templates are expanded with
/// the resolved data, and includes are expanded with `allow_remote_include`.
///
/// # Errors
///
/// Returns a message on the first pipeline whose sources cannot be resolved
/// or whose includes cannot be expanded.
async fn resolve_pipelines_async(
    resolver: &Arc<dyn SourceResolver>,
    pipelines: &[Pipeline],
    allow_remote_include: bool,
) -> Result<Vec<Pipeline>, String> {
    let mut output = Vec::with_capacity(pipelines.len());

    for pipeline in pipelines {
        // Static pipelines need no work.
        if !pipeline.is_dynamic() {
            output.push(pipeline.clone());
            continue;
        }

        let resolved_data = match sources::resolve_all(resolver.as_ref(), &pipeline.sources).await
        {
            Ok(data) => data,
            Err(e) => {
                return Err(format!(
                    "Failed to resolve dynamic pipeline '{}': {e}",
                    pipeline.name
                ));
            }
        };

        let mut expanded = TemplateExpander::expand(pipeline, &resolved_data);
        sources::include::expand_includes(&mut expanded, &resolved_data, allow_remote_include)?;
        output.push(expanded);
    }

    Ok(output)
}