use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use rsigma_eval::{EvaluationResult, ResultBody};
use tokio::sync::Semaphore;
use crate::metrics::{MetricsHook, NoopMetrics};
mod command;
mod http;
pub mod http_cache;
mod lookup;
mod scope;
mod template;
#[cfg(test)]
mod tests;
pub use command::{CommandEnricher, OutputFormat};
pub use http::{HttpEnricher, HttpEnricherClient, build_default_http_client};
pub use http_cache::{CacheKey, CacheOutcome, HttpResponseCache};
pub use lookup::LookupEnricher;
pub use scope::Scope;
pub use template::{TemplateEnricher, TemplateError, validate_template_namespace};
/// Selects which kind of evaluation result an enricher applies to.
///
/// Compared against a result's `ResultBody` variant by `EnricherKind::matches`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EnricherKind {
/// Pairs with `ResultBody::Detection`.
Detection,
/// Pairs with `ResultBody::Correlation`.
Correlation,
}
impl EnricherKind {
    /// Lower-case label for this kind (`"detection"` or `"correlation"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Detection => "detection",
            Self::Correlation => "correlation",
        }
    }

    /// Returns `true` when `body` is the result variant this kind pairs with.
    pub fn matches(&self, body: &ResultBody) -> bool {
        match (self, body) {
            (Self::Detection, ResultBody::Detection(_)) => true,
            (Self::Correlation, ResultBody::Correlation(_)) => true,
            // Any other pairing (including body variants not listed here) is a mismatch.
            _ => false,
        }
    }
}
/// What the pipeline does with a result when an enricher fails or times out.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OnError {
/// Log a warning and keep the result without injecting a value (the default).
#[default]
Skip,
/// Log a warning and inject JSON `null` under the enricher's inject field.
Null,
/// Log a warning and remove the result from the batch.
Drop,
}
/// Failure reported for a single enricher invocation.
///
/// Displays as `enricher '<enricher_id>': <kind>`.
#[derive(Debug, Clone)]
pub struct EnrichError {
/// Id of the enricher that failed.
pub enricher_id: String,
/// Category of the failure, with a message where the variant carries one.
pub kind: EnrichErrorKind,
}
impl std::fmt::Display for EnrichError {
    /// Renders the error as `enricher '<id>': <kind>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { enricher_id, kind } = self;
        write!(f, "enricher '{enricher_id}': {kind}")
    }
}
// Empty impl: `EnrichError` holds only an id and a kind, so the trait's default methods are used.
impl std::error::Error for EnrichError {}
/// Category of an enrichment failure.
///
/// Every variant except `Timeout` carries a message that is appended to the variant's
/// fixed prefix when displayed.
#[derive(Debug, Clone)]
pub enum EnrichErrorKind {
    /// The enricher did not finish in time; displays as `timeout`.
    Timeout,
    /// Displays as `fetch failed: <message>`.
    Fetch(String),
    /// Displays as `parse failed: <message>`.
    Parse(String),
    /// Displays as `extract failed: <message>`.
    Extract(String),
    /// Displays as `template render failed: <message>`.
    TemplateRender(String),
}

impl std::fmt::Display for EnrichErrorKind {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Timeout` has no message; every other variant is "<prefix>: <message>".
        let (prefix, message) = match self {
            Self::Timeout => return f.write_str("timeout"),
            Self::Fetch(message) => ("fetch failed", message),
            Self::Parse(message) => ("parse failed", message),
            Self::Extract(message) => ("extract failed", message),
            Self::TemplateRender(message) => ("template render failed", message),
        };
        write!(f, "{prefix}: {message}")
    }
}
/// A single enrichment step that the pipeline applies to evaluation results.
///
/// The pipeline stores implementations as `Box<dyn Enricher>`, hence the `Send + Sync` bound.
#[async_trait]
pub trait Enricher: Send + Sync {
/// Kind of result this enricher applies to; the pipeline skips results whose body does not match.
fn kind(&self) -> EnricherKind;
/// Identifier the pipeline uses for this enricher in logs, metrics and timeout errors.
fn id(&self) -> &str;
/// Key under which the pipeline injects `null` when `on_error` is `OnError::Null`.
/// NOTE(review): presumably also the key the enricher writes its own output to — confirm
/// against the implementations.
fn inject_field(&self) -> &str;
/// Upper bound the pipeline places on a single `enrich` call. Defaults to 5 seconds.
fn timeout(&self) -> Duration {
Duration::from_secs(5)
}
/// Filter the pipeline evaluates against a result before calling `enrich`.
fn scope(&self) -> &Scope;
/// How the pipeline reacts when `enrich` fails or times out. Defaults to `OnError::Skip`.
fn on_error(&self) -> OnError {
OnError::Skip
}
/// Enriches `result` in place, or reports why it could not.
async fn enrich(&self, result: &mut EvaluationResult) -> Result<(), EnrichError>;
}
/// Internal result of running one enricher against one evaluation result.
enum EnrichOutcome {
/// The enricher completed successfully.
Ok,
/// The enricher failed and its policy was `OnError::Skip`.
Skip,
/// The enricher failed and `null` was injected (`OnError::Null`).
Null,
/// The enricher failed and the result must be removed (`OnError::Drop`).
Drop,
/// The enricher was not run because its kind or scope did not match the result.
Filtered,
}
/// Ordered set of enrichers applied to batches of evaluation results.
pub struct EnrichmentPipeline {
// Enrichers, run in this order for each result.
enrichers: Vec<Box<dyn Enricher>>,
// One permit is held while a result is being enriched.
semaphore: Arc<Semaphore>,
// Receives registration, queue-depth and completion callbacks.
metrics: Arc<dyn MetricsHook>,
}
impl std::fmt::Debug for EnrichmentPipeline {
    /// Prints the number of enrichers and the currently available permits; the enrichers
    /// themselves and the metrics hook are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let enricher_count = self.enrichers.len();
        let free_permits = self.semaphore.available_permits();

        let mut out = f.debug_struct("EnrichmentPipeline");
        out.field("enrichers", &enricher_count);
        out.field("permits", &free_permits);
        out.finish()
    }
}
impl EnrichmentPipeline {
pub fn new(enrichers: Vec<Box<dyn Enricher>>, max_concurrent_enrichments: usize) -> Self {
let permits = if max_concurrent_enrichments == 0 {
16
} else {
max_concurrent_enrichments
};
Self {
enrichers,
semaphore: Arc::new(Semaphore::new(permits)),
metrics: Arc::new(NoopMetrics),
}
}
pub fn with_metrics(mut self, metrics: Arc<dyn MetricsHook>) -> Self {
for enricher in &self.enrichers {
metrics.register_enricher(enricher.id(), enricher.kind().as_str());
}
self.metrics = metrics;
self
}
pub fn is_empty(&self) -> bool {
self.enrichers.is_empty()
}
pub fn len(&self) -> usize {
self.enrichers.len()
}
pub fn enrichers(&self) -> impl Iterator<Item = &dyn Enricher> {
self.enrichers.iter().map(|e| &**e)
}
pub async fn run(&self, results: &mut Vec<EvaluationResult>) {
if self.enrichers.is_empty() || results.is_empty() {
return;
}
let mut drop_indices: Vec<usize> = Vec::new();
for (idx, result) in results.iter_mut().enumerate() {
let permit = self.semaphore.clone().acquire_owned().await.ok();
if permit.is_none() {
tracing::debug!("Enrichment semaphore closed, draining remaining results");
return;
}
let _permit = permit.unwrap();
let mut should_drop = false;
for enricher in &self.enrichers {
match Self::run_one(enricher.as_ref(), result, self.metrics.as_ref()).await {
EnrichOutcome::Drop => {
should_drop = true;
break;
}
EnrichOutcome::Ok
| EnrichOutcome::Skip
| EnrichOutcome::Null
| EnrichOutcome::Filtered => {}
}
}
if should_drop {
drop_indices.push(idx);
}
}
if !drop_indices.is_empty() {
for idx in drop_indices.into_iter().rev() {
results.swap_remove(idx);
}
}
}
async fn run_one(
enricher: &dyn Enricher,
result: &mut EvaluationResult,
metrics: &dyn MetricsHook,
) -> EnrichOutcome {
if !enricher.kind().matches(&result.body) {
return EnrichOutcome::Filtered;
}
if !enricher.scope().matches(result) {
return EnrichOutcome::Filtered;
}
let inject_field = enricher.inject_field().to_string();
let timeout = enricher.timeout();
let id = enricher.id().to_string();
let kind_label = enricher.kind().as_str();
let on_error = enricher.on_error();
metrics.on_enrichment_queue_depth_change(1);
let started = std::time::Instant::now();
let outcome = tokio::time::timeout(timeout, enricher.enrich(result)).await;
let elapsed = started.elapsed().as_secs_f64();
metrics.on_enrichment_queue_depth_change(-1);
let err = match outcome {
Ok(Ok(())) => {
metrics.on_enrichment_completed(&id, kind_label, "success", elapsed);
return EnrichOutcome::Ok;
}
Ok(Err(e)) => e,
Err(_) => EnrichError {
enricher_id: id.clone(),
kind: EnrichErrorKind::Timeout,
},
};
let is_timeout = matches!(err.kind, EnrichErrorKind::Timeout);
match on_error {
OnError::Skip => {
tracing::warn!(
enricher_id = %id,
kind = %kind_label,
error = %err,
"Enricher failed, skipping"
);
metrics.on_enrichment_completed(
&id,
kind_label,
if is_timeout { "timeout" } else { "skip" },
elapsed,
);
EnrichOutcome::Skip
}
OnError::Null => {
tracing::warn!(
enricher_id = %id,
kind = %kind_label,
error = %err,
"Enricher failed, injecting null"
);
let map = result
.header
.enrichments
.get_or_insert_with(serde_json::Map::new);
map.insert(inject_field, serde_json::Value::Null);
metrics.on_enrichment_completed(
&id,
kind_label,
if is_timeout { "timeout" } else { "error" },
elapsed,
);
EnrichOutcome::Null
}
OnError::Drop => {
tracing::warn!(
enricher_id = %id,
kind = %kind_label,
error = %err,
"Enricher failed, dropping result"
);
metrics.on_enrichment_completed(&id, kind_label, "drop", elapsed);
EnrichOutcome::Drop
}
}
}
}
impl Default for EnrichmentPipeline {
fn default() -> Self {
Self::new(Vec::new(), 16)
}
}
impl Clone for EnrichmentPipeline {
    /// Produces a pipeline that shares this one's semaphore and metrics hook but has
    /// NO enrichers: the boxed enrichers are not copied.
    ///
    /// NOTE(review): a clone therefore behaves as an empty pipeline — confirm that callers
    /// rely on this and do not expect the enrichers to carry over.
    fn clone(&self) -> Self {
        let semaphore = Arc::clone(&self.semaphore);
        let metrics = Arc::clone(&self.metrics);
        Self {
            enrichers: Vec::new(),
            semaphore,
            metrics,
        }
    }
}
/// Stores `value` under `inject_field` in the result header's enrichment map, creating the
/// map if the header has none yet. An existing entry with the same key is overwritten.
pub fn inject_enrichment(
    result: &mut EvaluationResult,
    inject_field: &str,
    value: serde_json::Value,
) {
    let key = inject_field.to_string();
    match result.header.enrichments.as_mut() {
        Some(existing) => {
            existing.insert(key, value);
        }
        None => {
            let mut fresh = serde_json::Map::new();
            fresh.insert(key, value);
            result.header.enrichments = Some(fresh);
        }
    }
}
/// Shared, thread-safe constructor that builds a boxed enricher from a JSON value, or
/// returns an error message when it cannot.
///
/// NOTE(review): the JSON argument is presumably the enricher's configuration — confirm
/// against the callers.
pub type EnricherFactory =
Arc<dyn Fn(&serde_json::Value) -> Result<Box<dyn Enricher>, String> + Send + Sync>;
/// Returns the process-wide table of registered enricher factories, keyed by type name.
/// The table is created empty on first access.
fn registry() -> &'static std::sync::RwLock<std::collections::HashMap<String, EnricherFactory>> {
    use std::collections::HashMap;
    use std::sync::{OnceLock, RwLock};

    static REGISTRY: OnceLock<RwLock<HashMap<String, EnricherFactory>>> = OnceLock::new();
    REGISTRY.get_or_init(RwLock::default)
}
/// Registers `factory` under `name` in the global enricher registry.
///
/// # Errors
///
/// Returns an error message when `name` is one of the reserved primitive names
/// (`template`, `lookup`, `http`, `command`), when the registry lock is poisoned, or when
/// `name` is already registered. An existing registration is never replaced.
pub fn register_builtin(name: &str, factory: EnricherFactory) -> Result<(), String> {
    use std::collections::hash_map::Entry;

    const RESERVED: [&str; 4] = ["template", "lookup", "http", "command"];
    if RESERVED.iter().any(|reserved| *reserved == name) {
        return Err(format!(
            "cannot register '{name}': name is reserved for a built-in primitive"
        ));
    }

    let mut table = match registry().write() {
        Ok(guard) => guard,
        Err(_) => return Err("enricher registry poisoned".to_string()),
    };
    match table.entry(name.to_string()) {
        Entry::Occupied(_) => Err(format!("enricher type '{name}' is already registered")),
        Entry::Vacant(slot) => {
            slot.insert(factory);
            Ok(())
        }
    }
}
pub fn lookup_builtin(name: &str) -> Option<EnricherFactory> {
let reg = registry();
let guard = reg.read().ok()?;
guard.get(name).cloned()
}
/// Test-only helper that empties the global enricher registry.
/// Does nothing when the registry lock is poisoned.
#[cfg(test)]
pub(crate) fn clear_builtin_registry() {
    let Ok(mut table) = registry().write() else {
        return;
    };
    table.clear();
}