use crate::datasource::ValueConstraint;
use crate::interval::AllenRelation;
use std::collections::{HashMap, HashSet};
use std::fmt;
/// A named pattern variable.
///
/// Thin newtype around the variable's name. The `Display` impl renders it
/// with a leading `?` (for example `?e1`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Var(pub String);

impl Var {
    /// Builds a variable from anything that converts into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let name: String = name.into();
        Var(name)
    }
}

impl fmt::Display for Var {
    /// Writes the variable as `?` followed by its name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Var(name) = self;
        f.write_str("?")?;
        f.write_str(name)
    }
}
/// The target side of a [`Clause`].
///
/// Generic over `V`, the value type carried by literals and constraints.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum Target<V> {
/// A variable. Presumably bound to whatever the clause points at during
/// matching -- NOTE(review): confirm against the matcher.
Bind(Var),
/// A concrete value of type `V`.
Literal(V),
/// A [`ValueConstraint`] over `V` rather than a single concrete value.
Constraint(ValueConstraint<V>),
}
/// A single condition: a source variable, a label and a [`Target`].
///
/// Generic over `L` (label type) and `V` (value type used by the target).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Clause<L, V> {
/// Variable the clause starts from.
pub source: Var,
/// Label of the clause.
pub label: L,
/// What the clause points at: a variable, a literal or a constraint.
pub target: Target<V>,
/// Negation flag. NOTE(review): its interpretation is decided by the
/// consumer of the clause and is not visible in this file.
pub negated: bool,
}
/// Optional numeric lower and upper bounds, attached to a
/// [`TemporalConstraint`] through its `gap` field.
///
/// NOTE(review): units and whether the bounds are inclusive are not defined
/// in this file -- confirm against the code that evaluates the gap.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MetricGap {
/// Lower bound, if any.
pub min: Option<f64>,
/// Upper bound, if any.
pub max: Option<f64>,
}
/// A temporal relation between two variables, with an optional metric gap.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct TemporalConstraint {
/// Variable on the left-hand side of the relation.
pub left: Var,
/// The [`AllenRelation`] relating `left` to `right`.
pub relation: AllenRelation,
/// Variable on the right-hand side of the relation.
pub right: Var,
/// Optional numeric bounds accompanying the relation.
pub gap: Option<MetricGap>,
}
/// A group of clauses attached to a [`Pattern`] as a negation.
///
/// NOTE(review): presumably the clauses must NOT match in the window
/// delimited by `between_start` and `between_end` -- confirm against the
/// matcher; this file only stores the data.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Negation<L, V> {
/// Variable marking the start of the window.
pub between_start: Var,
/// Variable marking the end of the window, if there is one.
pub between_end: Option<Var>,
/// Clauses belonging to this negation.
pub clauses: Vec<Clause<L, V>>,
// Excluded from generated docs via `#[doc(hidden)]`. Semantics are not
// visible in this file -- TODO confirm what "global" means to the matcher.
#[doc(hidden)]
pub is_global: bool,
}
/// A named pattern: a list of stages plus temporal constraints, negations
/// and additional options.
///
/// Generic over `L` (clause label type) and `V` (clause value type).
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Pattern<L, V> {
/// Name of the pattern.
pub name: String,
/// Stages of the pattern. `unordered_groups` refers to them by index.
pub stages: Vec<Stage<L, V>>,
/// Temporal constraints between variables.
pub temporal: Vec<TemporalConstraint>,
/// Negations attached to the pattern.
pub negations: Vec<Negation<L, V>>,
/// Optional group name. NOTE(review): how groups are used is not visible here.
pub group: Option<String>,
/// Free-form string key/value metadata.
pub metadata: HashMap<String, String>,
/// Optional deadline in ticks. NOTE(review): tick length and enforcement
/// are defined elsewhere.
pub deadline_ticks: Option<u64>,
/// Optional repetition settings for a range of stages.
pub repeat_range: Option<RepeatRange>,
/// Groups of stage indices, queried by `unordered_group_for` and
/// `same_unordered_group`. With the `serde` feature, a missing field
/// deserializes to an empty list.
#[cfg_attr(feature = "serde", serde(default))]
pub unordered_groups: Vec<Vec<usize>>,
/// Marks the pattern as private. With the `serde` feature, a missing field
/// deserializes to `false`.
#[cfg_attr(feature = "serde", serde(default))]
pub private: bool,
}
/// Repetition settings for a range of stages within a [`Pattern`].
///
/// NOTE(review): whether `stage_end` is inclusive, and what `max_reps: None`
/// means (presumably "unbounded"), are not defined in this file -- confirm
/// against the matcher.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct RepeatRange {
/// Index of the first stage in the range.
pub stage_start: usize,
/// Index of the last stage in the range (inclusivity: see note above).
pub stage_end: usize,
/// Minimum number of repetitions.
pub min_reps: usize,
/// Maximum number of repetitions, if any.
pub max_reps: Option<usize>,
/// Names of variables shared by the repetitions (stored as plain strings,
/// not as `Var`).
pub shared_vars: HashSet<String>,
}
/// One stage of a [`Pattern`]: an anchor variable and the clauses attached to it.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Stage<L, V> {
/// Variable that anchors the stage.
pub anchor: Var,
/// Clauses belonging to this stage.
pub clauses: Vec<Clause<L, V>>,
}
impl<V> Target<V> {
pub fn map<V2>(&self, f: &impl Fn(&V) -> V2) -> Target<V2> {
match self {
Target::Bind(v) => Target::Bind(v.clone()),
Target::Literal(v) => Target::Literal(f(v)),
Target::Constraint(c) => Target::Constraint(c.map(f)),
}
}
}
impl<L, V> Clause<L, V> {
    /// Converts this clause to different label and value types.
    ///
    /// `label_fn` is applied to the label and `value_fn` to the target (via
    /// `Target::map`). The source variable and the `negated` flag are copied
    /// unchanged.
    pub fn map_types<L2, V2>(
        &self,
        label_fn: &impl Fn(&L) -> L2,
        value_fn: &impl Fn(&V) -> V2,
    ) -> Clause<L2, V2> {
        let Clause {
            source,
            label,
            target,
            negated,
        } = self;

        Clause {
            source: source.clone(),
            label: label_fn(label),
            target: target.map(value_fn),
            negated: *negated,
        }
    }
}
impl<L, V> Stage<L, V> {
    /// Converts this stage to different label and value types.
    ///
    /// The anchor is cloned as-is; every clause is converted with
    /// `Clause::map_types`, preserving clause order.
    pub fn map_types<L2, V2>(
        &self,
        label_fn: &impl Fn(&L) -> L2,
        value_fn: &impl Fn(&V) -> V2,
    ) -> Stage<L2, V2> {
        let mut converted = Vec::with_capacity(self.clauses.len());
        for clause in &self.clauses {
            converted.push(clause.map_types(label_fn, value_fn));
        }

        Stage {
            anchor: self.anchor.clone(),
            clauses: converted,
        }
    }
}
impl<L, V> Negation<L, V> {
pub fn map_types<L2, V2>(
&self,
label_fn: &impl Fn(&L) -> L2,
value_fn: &impl Fn(&V) -> V2,
) -> Negation<L2, V2> {
Negation {
between_start: self.between_start.clone(),
between_end: self.between_end.clone(),
clauses: self
.clauses
.iter()
.map(|c| c.map_types(label_fn, value_fn))
.collect(),
is_global: self.is_global,
}
}
}
impl<L, V> Pattern<L, V> {
pub fn map_types<L2, V2>(
&self,
label_fn: impl Fn(&L) -> L2,
value_fn: impl Fn(&V) -> V2,
) -> Pattern<L2, V2> {
Pattern {
name: self.name.clone(),
stages: self
.stages
.iter()
.map(|s| s.map_types(&label_fn, &value_fn))
.collect(),
temporal: self.temporal.clone(),
negations: self
.negations
.iter()
.map(|n| n.map_types(&label_fn, &value_fn))
.collect(),
group: self.group.clone(),
metadata: self.metadata.clone(),
deadline_ticks: self.deadline_ticks,
repeat_range: self.repeat_range.clone(),
unordered_groups: self.unordered_groups.clone(),
private: self.private,
}
}
pub fn unordered_group_for(&self, stage_idx: usize) -> Option<&Vec<usize>> {
self.unordered_groups
.iter()
.find(|g| g.contains(&stage_idx))
}
pub fn same_unordered_group(&self, a: usize, b: usize) -> bool {
self.unordered_groups
.iter()
.any(|g| g.contains(&a) && g.contains(&b))
}
pub fn all_vars(&self) -> Vec<&Var> {
let mut vars = Vec::new();
for stage in &self.stages {
vars.push(&stage.anchor);
for clause in &stage.clauses {
vars.push(&clause.source);
if let Target::Bind(ref v) = clause.target {
vars.push(v);
}
}
}
for neg in &self.negations {
vars.push(&neg.between_start);
if let Some(ref v) = neg.between_end {
vars.push(v);
}
for clause in &neg.clauses {
vars.push(&clause.source);
if let Target::Bind(ref v) = clause.target {
vars.push(v);
}
}
}
vars.sort_by_key(|v| &v.0);
vars.dedup();
vars
}
pub fn condition_count(&self) -> usize {
self.stages.iter().map(|s| s.clauses.len()).sum()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::builder::PatternBuilder;
    use crate::datasource::ValueConstraint;
    use crate::interval::AllenRelation;

    /// Labels and literal values go through the converters; bound variables
    /// stay `Target::Bind`.
    #[test]
    fn map_types_transforms_labels_and_values() {
        let original = PatternBuilder::<String, String>::new("test")
            .stage("e1", |stage| {
                stage
                    .edge("e1", "eventType".into(), "betray".into())
                    .edge_bind("e1", "actor".into(), "char")
            })
            .build();

        let converted = original.map_types(|l| l.len() as u32, |v| v.len() as i64);
        assert_eq!(converted.name, "test");

        let clauses = &converted.stages[0].clauses;
        // "eventType" has 9 bytes, "betray" has 6.
        assert_eq!(clauses[0].label, 9);
        assert_eq!(clauses[0].target, Target::Literal(6));
        // "actor" has 5 bytes; its target is a binding, not a value.
        assert_eq!(clauses[1].label, 5);
        assert!(matches!(clauses[1].target, Target::Bind(_)));
    }

    /// Stage, temporal and negation structure survives the conversion.
    #[test]
    fn map_types_preserves_structure() {
        let original = PatternBuilder::<String, String>::new("arc")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "setup".into()))
            .stage("e2", |stage| stage.edge("e2", "type".into(), "payoff".into()))
            .temporal("e1", AllenRelation::Before, "e2")
            .unless_between("e1", "e2", |neg| {
                neg.edge("mid", "type".into(), "block".into())
            })
            .build();

        let converted = original.map_types(|l| l.to_uppercase(), |v| v.to_uppercase());

        assert_eq!(converted.stages.len(), 2);
        assert_eq!(converted.temporal.len(), 1);
        assert_eq!(converted.negations.len(), 1);

        let constraint = &converted.temporal[0];
        assert_eq!(constraint.relation, AllenRelation::Before);
        assert_eq!(constraint.left, Var::new("e1"));
        assert_eq!(constraint.right, Var::new("e2"));

        assert_eq!(converted.negations[0].between_start, Var::new("e1"));

        let first_clause = &converted.stages[0].clauses[0];
        assert_eq!(first_clause.label, "TYPE");
        assert_eq!(first_clause.target, Target::Literal("SETUP".into()));
    }

    /// `condition_count` adds up stage clauses only.
    #[test]
    fn condition_count_sums_clauses_across_stages() {
        let two_stage = PatternBuilder::<String, String>::new("test")
            .stage("e1", |stage| {
                stage
                    .edge("e1", "eventType".into(), "betray".into())
                    .edge_bind("e1", "actor".into(), "char")
            })
            .stage("e2", |stage| {
                stage.edge("e2", "eventType".into(), "betray".into())
            })
            .build();
        assert_eq!(two_stage.condition_count(), 3);

        let no_stages = PatternBuilder::<String, String>::new("empty").build();
        assert_eq!(no_stages.condition_count(), 0);

        let guarded = PatternBuilder::<String, String>::new("neg")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "start".into()))
            .stage("e2", |stage| stage.edge("e2", "type".into(), "end".into()))
            .unless_between("e1", "e2", |neg| {
                neg.edge("mid", "type".into(), "block".into())
            })
            .build();
        // The clause inside the negation does not count.
        assert_eq!(guarded.condition_count(), 2);
    }

    /// Every `ValueConstraint` variant is mapped, directly and through a clause.
    #[test]
    fn map_types_handles_all_constraint_variants() {
        let double = |v: &i32| i64::from(*v * 2);

        // (input, expected after doubling)
        let cases = vec![
            (ValueConstraint::Eq(5), ValueConstraint::Eq(10)),
            (ValueConstraint::Lt(3), ValueConstraint::Lt(6)),
            (ValueConstraint::Gt(4), ValueConstraint::Gt(8)),
            (ValueConstraint::Lte(2), ValueConstraint::Lte(4)),
            (ValueConstraint::Gte(1), ValueConstraint::Gte(2)),
            (
                ValueConstraint::Between(1, 10),
                ValueConstraint::Between(2, 20),
            ),
            (ValueConstraint::<i32>::Any, ValueConstraint::<i64>::Any),
        ];
        for (input, expected) in cases {
            assert_eq!(input.map(&double), expected);
        }

        let scored = Clause {
            source: Var::new("e"),
            label: "score".to_string(),
            target: Target::Constraint(ValueConstraint::Gt(50)),
            negated: false,
        };
        let converted = scored.map_types(&|l: &String| l.clone(), &double);
        assert_eq!(
            converted.target,
            Target::Constraint(ValueConstraint::Gt(100))
        );
    }

    /// Metadata set on the builder ends up on the pattern.
    #[test]
    fn builder_metadata_propagates() {
        let built = PatternBuilder::<String, String>::new("test")
            .metadata("severity", "high")
            .metadata("mitre", "T1078")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "x".into()))
            .build();

        let meta = &built.metadata;
        assert_eq!(meta.get("severity").map(String::as_str), Some("high"));
        assert_eq!(meta.get("mitre").map(String::as_str), Some("T1078"));
        assert_eq!(meta.len(), 2);
    }

    /// A pattern built without metadata calls has an empty metadata map.
    #[test]
    fn metadata_empty_by_default() {
        let built = PatternBuilder::<String, String>::new("test")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "x".into()))
            .build();
        assert!(built.metadata.is_empty());
    }

    /// `private()` sets the flag; patterns are not private otherwise.
    #[test]
    fn private_pattern_field() {
        let hidden = PatternBuilder::<String, String>::new("helper")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "test".into()))
            .private()
            .build();
        assert!(hidden.private);

        let visible = PatternBuilder::<String, String>::new("visible")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "test".into()))
            .build();
        assert!(!visible.private);
    }

    /// Metadata is copied as-is by `map_types` (not run through the converters).
    #[test]
    fn map_types_preserves_metadata() {
        let original = PatternBuilder::<String, String>::new("test")
            .metadata("key", "value")
            .stage("e1", |stage| stage.edge("e1", "type".into(), "x".into()))
            .build();

        let converted = original.map_types(|l| l.to_uppercase(), |v| v.to_uppercase());
        assert_eq!(
            converted.metadata.get("key").map(String::as_str),
            Some("value")
        );
    }
}