use std::collections::HashMap;
use std::sync::RwLock;
use std::time::{SystemTime, UNIX_EPOCH};
/// A rule that fires when an incoming event's embedding is similar enough to
/// the trigger's own embedding.
#[derive(Debug, Clone)]
pub struct SemanticTrigger {
    /// Unique identifier; registration rejects an empty ID.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Free-form description.
    pub description: String,
    /// Query text this trigger represents. NOTE(review): not read by the
    /// matching code in this file; presumably the text `embedding` was
    /// computed from — confirm with callers.
    pub query: String,
    /// Embedding compared against event embeddings with cosine similarity.
    /// A trigger registered without an embedding is never considered for matching.
    pub embedding: Option<Vec<f32>>,
    /// Minimum cosine similarity an event must reach for this trigger to fire.
    pub threshold: f32,
    /// Action run by `TriggerIndex::execute_action` for a match of this trigger.
    pub action: TriggerAction,
    /// Disabled triggers are skipped during matching.
    pub enabled: bool,
    /// Ordering key among one event's matches; lower values sort first.
    pub priority: i32,
    /// Maximum number of fires per rate-limit window; `None` disables rate limiting.
    pub max_fires_per_window: Option<usize>,
    /// Rate-limit window length in seconds; treated as 60 when `None`.
    pub rate_limit_window_secs: Option<u64>,
    /// Labels; not read by the matching code in this file.
    pub tags: Vec<String>,
    /// Key/value data; not read by the matching code in this file.
    pub metadata: HashMap<String, String>,
    /// Creation time in seconds since the UNIX epoch. A value of `0.0` is
    /// replaced with the current time by `TriggerIndex::register_trigger`.
    pub created_at: f64,
}
/// What to do when a trigger fires.
///
/// NOTE(review): `TriggerIndex::execute_action` in this file only formats a
/// description string for each variant. No notification, HTTP request, agent
/// spawn or callback is actually performed here.
#[derive(Debug, Clone)]
pub enum TriggerAction {
    /// Notify `channel`, optionally with a named `template`.
    Notify {
        channel: String,
        template: Option<String>,
    },
    /// Route to `target`, optionally with extra `context`.
    Route {
        target: String,
        context: Option<String>,
    },
    /// Escalate at `level`, optionally with a `reason`.
    Escalate {
        level: EscalationLevel,
        reason: Option<String>,
    },
    /// Spawn an agent of `agent_type`; `config` is not read by `execute_action`.
    SpawnAgent {
        agent_type: String,
        config: HashMap<String, String>,
    },
    /// Log at `level`; when `message` is `None` the trigger ID is used instead.
    Log {
        level: LogLevel,
        message: Option<String>,
    },
    /// Call `url` using `method` (presumably an HTTP verb — confirm);
    /// `headers` is not read by `execute_action`.
    Webhook {
        url: String,
        method: String,
        headers: HashMap<String, String>,
    },
    /// Invoke the callback named `function`; `args` is not read by `execute_action`.
    Callback {
        function: String,
        args: HashMap<String, String>,
    },
    /// Run each action in order, stopping at the first error.
    Chain(Vec<TriggerAction>),
}
/// Severity of an escalation.
///
/// Variants are declared from least to most severe. The derived ordering
/// therefore compares severity: `Low < Medium < High < Critical`. `Hash`
/// allows levels to be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EscalationLevel {
    Low,
    Medium,
    High,
    Critical,
}
/// Log severity used by [`TriggerAction::Log`].
///
/// Variants are declared from least to most severe. The derived ordering
/// therefore compares severity: `Debug < Info < Warn < Error`. `Hash` allows
/// levels to be used as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}
/// An incoming event to be matched against registered triggers.
#[derive(Debug, Clone)]
pub struct TriggerEvent {
    /// Event identifier, copied into every resulting `TriggerMatch`.
    pub id: String,
    /// Event text. NOTE(review): not read by the matching code in this file.
    pub content: String,
    /// Embedding compared against trigger embeddings; an event without one
    /// produces no matches.
    pub embedding: Option<Vec<f32>>,
    /// Where the event came from; not read by the matching code in this file.
    pub source: EventSource,
    /// Key/value data; not read by the matching code in this file.
    pub metadata: HashMap<String, String>,
    /// Caller-supplied event time. NOTE(review): `process_event` stamps matches
    /// with its own clock and does not read this field.
    pub timestamp: f64,
}
/// Origin of a [`TriggerEvent`].
///
/// `Custom` carries a caller-chosen label. `Hash` is derived alongside `Eq`
/// so sources can be used as map or set keys (e.g. for per-source counters).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum EventSource {
    UserMessage,
    SystemEvent,
    DataInsert,
    MemoryCompaction,
    ExternalApi,
    AgentAction,
    Custom(String),
}
/// One trigger firing for one event, as produced by `TriggerIndex::process_event`.
#[derive(Debug, Clone)]
pub struct TriggerMatch {
    /// ID of the trigger that fired.
    pub trigger_id: String,
    /// Cosine similarity between the event and trigger embeddings.
    pub score: f32,
    /// ID of the event that caused the match.
    pub event_id: String,
    /// Time the event was processed, in seconds since the UNIX epoch.
    pub timestamp: f64,
    /// `false` when produced; set to `true` once `TriggerIndex::execute_action`
    /// succeeds for this match.
    pub action_executed: bool,
    /// Description string returned by the executed action, if it has run.
    pub execution_result: Option<String>,
}
/// Aggregate counters maintained by `TriggerIndex`.
#[derive(Debug, Clone, Default)]
pub struct TriggerStats {
    /// Number of `process_event` calls, including events without an embedding.
    pub events_processed: usize,
    /// Number of (event, trigger) matches that passed the threshold and rate limit.
    pub triggers_matched: usize,
    /// Number of successful `execute_action` calls.
    pub actions_executed: usize,
    /// Match count per trigger, keyed by trigger ID.
    pub matches_by_trigger: HashMap<String, usize>,
    /// Number of above-threshold (event, trigger) pairs suppressed by rate limiting.
    pub rate_limited: usize,
}
/// Registry of semantic triggers that matches incoming events against
/// trigger embeddings.
///
/// All state sits behind `RwLock`s, so the methods take `&self`.
pub struct TriggerIndex {
    /// Registered triggers keyed by ID.
    triggers: RwLock<HashMap<String, SemanticTrigger>>,
    /// `(trigger ID, embedding)` pairs, scanned linearly for similarity search.
    trigger_embeddings: RwLock<Vec<(String, Vec<f32>)>>,
    /// Per-trigger rate-limit state: `(fires in current window, window start secs)`.
    rate_limits: RwLock<HashMap<String, (usize, f64)>>,
    /// Recent matches, oldest first, capped at `max_recent_matches`.
    recent_matches: RwLock<Vec<TriggerMatch>>,
    /// Aggregate counters.
    stats: RwLock<TriggerStats>,
    /// Maximum number of entries kept in `recent_matches`.
    max_recent_matches: usize,
}
impl TriggerIndex {
pub fn new() -> Self {
Self {
triggers: RwLock::new(HashMap::new()),
trigger_embeddings: RwLock::new(Vec::new()),
rate_limits: RwLock::new(HashMap::new()),
recent_matches: RwLock::new(Vec::new()),
stats: RwLock::new(TriggerStats::default()),
max_recent_matches: 1000,
}
}
pub fn register_trigger(&self, mut trigger: SemanticTrigger) -> Result<(), TriggerError> {
if trigger.id.is_empty() {
return Err(TriggerError::InvalidTrigger(
"ID cannot be empty".to_string(),
));
}
if trigger.created_at == 0.0 {
trigger.created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
}
{
let mut triggers = self.triggers.write().unwrap();
triggers.insert(trigger.id.clone(), trigger.clone());
}
if let Some(embedding) = &trigger.embedding {
let mut embeddings = self.trigger_embeddings.write().unwrap();
embeddings.push((trigger.id.clone(), embedding.clone()));
}
Ok(())
}
pub fn remove_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
let removed = {
let mut triggers = self.triggers.write().unwrap();
triggers.remove(trigger_id)
};
if removed.is_some() {
let mut embeddings = self.trigger_embeddings.write().unwrap();
embeddings.retain(|(id, _)| id != trigger_id);
}
removed
}
pub fn set_enabled(&self, trigger_id: &str, enabled: bool) -> bool {
let mut triggers = self.triggers.write().unwrap();
if let Some(trigger) = triggers.get_mut(trigger_id) {
trigger.enabled = enabled;
true
} else {
false
}
}
pub fn set_threshold(&self, trigger_id: &str, threshold: f32) -> bool {
let mut triggers = self.triggers.write().unwrap();
if let Some(trigger) = triggers.get_mut(trigger_id) {
trigger.threshold = threshold.clamp(0.0, 1.0);
true
} else {
false
}
}
pub fn process_event(&self, event: &TriggerEvent) -> Vec<TriggerMatch> {
let mut matches = Vec::new();
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs_f64();
{
let mut stats = self.stats.write().unwrap();
stats.events_processed += 1;
}
let event_embedding = match &event.embedding {
Some(emb) => emb.clone(),
None => {
return matches;
}
};
let candidates = self.find_candidates(&event_embedding, 10);
let triggers = self.triggers.read().unwrap();
for (trigger_id, score) in candidates {
if let Some(trigger) = triggers.get(&trigger_id) {
if !trigger.enabled {
continue;
}
if score < trigger.threshold {
continue;
}
if !self.check_rate_limit(&trigger_id, trigger, now) {
let mut stats = self.stats.write().unwrap();
stats.rate_limited += 1;
continue;
}
let trigger_match = TriggerMatch {
trigger_id: trigger_id.clone(),
score,
event_id: event.id.clone(),
timestamp: now,
action_executed: false,
execution_result: None,
};
matches.push(trigger_match);
{
let mut stats = self.stats.write().unwrap();
stats.triggers_matched += 1;
*stats
.matches_by_trigger
.entry(trigger_id.clone())
.or_insert(0) += 1;
}
}
}
matches.sort_by(|a, b| {
let trigger_a = triggers.get(&a.trigger_id);
let trigger_b = triggers.get(&b.trigger_id);
match (trigger_a, trigger_b) {
(Some(ta), Some(tb)) => ta.priority.cmp(&tb.priority).then_with(|| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
}),
_ => std::cmp::Ordering::Equal,
}
});
{
let mut recent = self.recent_matches.write().unwrap();
for m in &matches {
recent.push(m.clone());
}
while recent.len() > self.max_recent_matches {
recent.remove(0);
}
}
matches
}
fn find_candidates(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
let embeddings = self.trigger_embeddings.read().unwrap();
if embeddings.is_empty() {
return Vec::new();
}
let mut candidates: Vec<(String, f32)> = embeddings
.iter()
.map(|(id, emb)| {
let score = cosine_similarity(query, emb);
(id.clone(), score)
})
.collect();
candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
candidates.truncate(k);
candidates
}
fn check_rate_limit(&self, trigger_id: &str, trigger: &SemanticTrigger, now: f64) -> bool {
let max_fires = match trigger.max_fires_per_window {
Some(max) => max,
None => return true, };
let window_secs = trigger.rate_limit_window_secs.unwrap_or(60);
let mut rate_limits = self.rate_limits.write().unwrap();
let entry = rate_limits
.entry(trigger_id.to_string())
.or_insert((0, now));
if now - entry.1 > window_secs as f64 {
entry.0 = 1;
entry.1 = now;
return true;
}
if entry.0 < max_fires {
entry.0 += 1;
return true;
}
false
}
pub fn execute_action(&self, trigger_match: &mut TriggerMatch) -> Result<(), TriggerError> {
let triggers = self.triggers.read().unwrap();
let trigger = triggers
.get(&trigger_match.trigger_id)
.ok_or_else(|| TriggerError::TriggerNotFound(trigger_match.trigger_id.clone()))?;
let result = self.execute_action_impl(&trigger.action, trigger_match)?;
trigger_match.action_executed = true;
trigger_match.execution_result = Some(result);
{
let mut stats = self.stats.write().unwrap();
stats.actions_executed += 1;
}
Ok(())
}
fn execute_action_impl(
&self,
action: &TriggerAction,
trigger_match: &TriggerMatch,
) -> Result<String, TriggerError> {
match action {
TriggerAction::Notify { channel, template } => {
Ok(format!(
"Notified channel '{}' (template: {:?})",
channel, template
))
}
TriggerAction::Route { target, context } => {
Ok(format!("Routed to '{}' (context: {:?})", target, context))
}
TriggerAction::Escalate { level, reason } => Ok(format!(
"Escalated at level {:?} (reason: {:?})",
level, reason
)),
TriggerAction::SpawnAgent {
agent_type,
config: _,
} => Ok(format!("Spawned agent of type '{}'", agent_type)),
TriggerAction::Log { level, message } => {
let msg = message.as_deref().unwrap_or(&trigger_match.trigger_id);
Ok(format!("Logged at {:?}: {}", level, msg))
}
TriggerAction::Webhook {
url,
method,
headers: _,
} => {
Ok(format!("Called webhook {} {}", method, url))
}
TriggerAction::Callback { function, args: _ } => {
Ok(format!("Called callback function '{}'", function))
}
TriggerAction::Chain(actions) => {
let mut results = Vec::new();
for sub_action in actions {
let result = self.execute_action_impl(sub_action, trigger_match)?;
results.push(result);
}
Ok(format!("Chain executed: [{}]", results.join(", ")))
}
}
}
pub fn list_triggers(&self) -> Vec<SemanticTrigger> {
self.triggers.read().unwrap().values().cloned().collect()
}
pub fn get_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
self.triggers.read().unwrap().get(trigger_id).cloned()
}
pub fn recent_matches(&self, limit: usize) -> Vec<TriggerMatch> {
let matches = self.recent_matches.read().unwrap();
matches.iter().rev().take(limit).cloned().collect()
}
pub fn stats(&self) -> TriggerStats {
self.stats.read().unwrap().clone()
}
pub fn clear_stats(&self) {
let mut stats = self.stats.write().unwrap();
*stats = TriggerStats::default();
}
}
impl Default for TriggerIndex {
fn default() -> Self {
Self::new()
}
}
/// Fluent builder for [`SemanticTrigger`].
///
/// Defaults:
/// - `name` equal to `id`, empty description, no embedding;
/// - threshold `0.8` and a `Log { level: Info, message: None }` action;
/// - enabled, priority `0`, no rate limit, no tags or metadata;
/// - `created_at = 0.0`, which `TriggerIndex::register_trigger` replaces with
///   the registration time.
#[must_use = "a TriggerBuilder has no effect until `build` is called"]
pub struct TriggerBuilder {
    trigger: SemanticTrigger,
}

impl TriggerBuilder {
    /// Starts a trigger with the given ID and query text, using the defaults
    /// listed on the type.
    pub fn new(id: &str, query: &str) -> Self {
        Self {
            trigger: SemanticTrigger {
                id: id.to_string(),
                name: id.to_string(),
                description: String::new(),
                query: query.to_string(),
                embedding: None,
                threshold: 0.8,
                action: TriggerAction::Log {
                    level: LogLevel::Info,
                    message: None,
                },
                enabled: true,
                priority: 0,
                max_fires_per_window: None,
                rate_limit_window_secs: None,
                tags: Vec::new(),
                metadata: HashMap::new(),
                created_at: 0.0,
            },
        }
    }

    /// Sets the human-readable name.
    pub fn name(mut self, name: &str) -> Self {
        self.trigger.name = name.to_string();
        self
    }

    /// Sets the description.
    pub fn description(mut self, description: &str) -> Self {
        self.trigger.description = description.to_string();
        self
    }

    /// Sets the embedding. Without one, a registered trigger can never match.
    pub fn embedding(mut self, embedding: Vec<f32>) -> Self {
        self.trigger.embedding = Some(embedding);
        self
    }

    /// Sets the similarity threshold, clamped to `[0.0, 1.0]`.
    pub fn threshold(mut self, threshold: f32) -> Self {
        self.trigger.threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Sets an arbitrary action, replacing the current one.
    pub fn action(mut self, action: TriggerAction) -> Self {
        self.trigger.action = action;
        self
    }

    /// Shorthand for a `Notify` action on `channel` with no template.
    pub fn notify(mut self, channel: &str) -> Self {
        self.trigger.action = TriggerAction::Notify {
            channel: channel.to_string(),
            template: None,
        };
        self
    }

    /// Shorthand for a `Route` action to `target` with no context.
    pub fn route(mut self, target: &str) -> Self {
        self.trigger.action = TriggerAction::Route {
            target: target.to_string(),
            context: None,
        };
        self
    }

    /// Shorthand for an `Escalate` action at `level` with no reason.
    pub fn escalate(mut self, level: EscalationLevel) -> Self {
        self.trigger.action = TriggerAction::Escalate {
            level,
            reason: None,
        };
        self
    }

    /// Sets the priority; lower values sort first among an event's matches.
    pub fn priority(mut self, priority: i32) -> Self {
        self.trigger.priority = priority;
        self
    }

    /// Allows at most `max_fires` fires per `window_secs`-second window.
    pub fn rate_limit(mut self, max_fires: usize, window_secs: u64) -> Self {
        self.trigger.max_fires_per_window = Some(max_fires);
        self.trigger.rate_limit_window_secs = Some(window_secs);
        self
    }

    /// Appends a tag.
    pub fn tag(mut self, tag: &str) -> Self {
        self.trigger.tags.push(tag.to_string());
        self
    }

    /// Inserts a metadata entry, overwriting any existing value for `key`.
    pub fn metadata(mut self, key: &str, value: &str) -> Self {
        self.trigger
            .metadata
            .insert(key.to_string(), value.to_string());
        self
    }

    /// Sets whether the trigger starts enabled.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.trigger.enabled = enabled;
        self
    }

    /// Finishes building and returns the trigger.
    pub fn build(self) -> SemanticTrigger {
        self.trigger
    }
}
/// Errors returned by trigger registration and action execution.
#[derive(Debug, Clone)]
pub enum TriggerError {
    /// The trigger failed validation; the payload explains why.
    InvalidTrigger(String),
    /// No trigger is registered under the payload ID.
    TriggerNotFound(String),
    /// An action failed; the payload explains why.
    ActionFailed(String),
    /// The trigger with the payload ID exceeded its rate limit.
    RateLimitExceeded(String),
    /// An embedding-related failure; the payload explains why.
    EmbeddingError(String),
}

impl std::fmt::Display for TriggerError {
    /// Renders as `"<label>: <payload>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            TriggerError::InvalidTrigger(msg) => ("Invalid trigger", msg),
            TriggerError::TriggerNotFound(id) => ("Trigger not found", id),
            TriggerError::ActionFailed(msg) => ("Action failed", msg),
            TriggerError::RateLimitExceeded(id) => ("Rate limit exceeded for trigger", id),
            TriggerError::EmbeddingError(msg) => ("Embedding error", msg),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for TriggerError {}
/// Cosine similarity of `a` and `b`; within `[-1, 1]` (up to rounding) for
/// finite inputs.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has a
/// norm below `1e-10`.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a < 1e-10 || mag_b < 1e-10 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
/// Builds an enabled trigger that notifies `channel` when an event is similar
/// enough to `embedding`. All other settings keep the `TriggerBuilder`
/// defaults (threshold 0.8, priority 0, no rate limit).
pub fn create_notify_trigger(
    id: &str,
    query: &str,
    channel: &str,
    embedding: Vec<f32>,
) -> SemanticTrigger {
    let builder = TriggerBuilder::new(id, query).embedding(embedding);
    builder.notify(channel).build()
}
/// Builds an escalation trigger at `level` for events similar to `embedding`.
///
/// The trigger gets priority `-1`. `process_event` sorts matches by ascending
/// priority, so escalations come before default-priority (0) matches of the
/// same event.
pub fn create_escalation_trigger(
    id: &str,
    query: &str,
    level: EscalationLevel,
    embedding: Vec<f32>,
) -> SemanticTrigger {
    let builder = TriggerBuilder::new(id, query)
        .embedding(embedding)
        .escalate(level);
    builder.priority(-1).build()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 128-dimensional test vector; equal seeds give equal vectors.
    fn mock_embedding(seed: u64) -> Vec<f32> {
        let mut values = Vec::with_capacity(128);
        for i in 0u64..128 {
            values.push(((i + seed) % 100) as f32 / 100.0 - 0.5);
        }
        values
    }

    /// Event `event_1` carrying `mock_embedding(1)`, i.e. identical to any
    /// trigger registered with that embedding.
    fn seed_one_event(content: &str, source: EventSource) -> TriggerEvent {
        TriggerEvent {
            id: "event_1".to_string(),
            content: content.to_string(),
            embedding: Some(mock_embedding(1)),
            source,
            metadata: HashMap::new(),
            timestamp: 0.0,
        }
    }

    /// Builder for a `"test"`-query trigger on `mock_embedding(1)` at threshold 0.5.
    fn seed_one_trigger(id: &str) -> TriggerBuilder {
        TriggerBuilder::new(id, "test")
            .embedding(mock_embedding(1))
            .threshold(0.5)
    }

    #[test]
    fn test_trigger_registration() {
        let index = TriggerIndex::new();
        let privacy = TriggerBuilder::new("privacy_concern", "user mentions privacy concerns")
            .embedding(mock_embedding(1))
            .threshold(0.75)
            .escalate(EscalationLevel::High)
            .build();
        index.register_trigger(privacy).unwrap();

        let registered = index.list_triggers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id, "privacy_concern");
    }

    #[test]
    fn test_trigger_matching() {
        let index = TriggerIndex::new();
        let alert = TriggerBuilder::new("security_alert", "security vulnerability")
            .embedding(mock_embedding(1))
            .threshold(0.5)
            .notify("security-team")
            .build();
        index.register_trigger(alert).unwrap();

        let event = seed_one_event("possible security issue detected", EventSource::SystemEvent);
        let found = index.process_event(&event);
        assert!(!found.is_empty());
        let top = &found[0];
        assert_eq!(top.trigger_id, "security_alert");
        assert!(top.score > 0.5);
    }

    #[test]
    fn test_trigger_disable() {
        let index = TriggerIndex::new();
        index
            .register_trigger(seed_one_trigger("test_trigger").build())
            .unwrap();
        index.set_enabled("test_trigger", false);

        let found = index.process_event(&seed_one_event("test", EventSource::UserMessage));
        assert!(found.is_empty());
    }

    #[test]
    fn test_rate_limiting() {
        let index = TriggerIndex::new();
        index
            .register_trigger(seed_one_trigger("rate_limited").rate_limit(2, 60).build())
            .unwrap();

        let event = seed_one_event("test", EventSource::UserMessage);
        let fired: Vec<bool> = (0..3)
            .map(|_| !index.process_event(&event).is_empty())
            .collect();
        assert_eq!(fired, vec![true, true, false]);
        assert!(index.stats().rate_limited >= 1);
    }

    #[test]
    fn test_action_execution() {
        let index = TriggerIndex::new();
        let logger = seed_one_trigger("log_trigger")
            .action(TriggerAction::Log {
                level: LogLevel::Info,
                message: Some("Test message".to_string()),
            })
            .build();
        index.register_trigger(logger).unwrap();

        let mut found = index.process_event(&seed_one_event("test", EventSource::UserMessage));
        assert!(!found.is_empty());
        let first = &mut found[0];
        index.execute_action(first).unwrap();
        assert!(first.action_executed);
        assert!(first.execution_result.is_some());
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&x_axis, &same) - 1.0).abs() < 0.01);

        let y_axis = vec![0.0, 1.0, 0.0];
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < 0.01);
    }

    #[test]
    fn test_trigger_builder() {
        let built = TriggerBuilder::new("test", "test query")
            .name("Test Trigger")
            .description("A test trigger")
            .threshold(0.85)
            .priority(5)
            .tag("test")
            .tag("example")
            .notify("test-channel")
            .rate_limit(10, 300)
            .build();

        assert_eq!(built.id, "test");
        assert_eq!(built.name, "Test Trigger");
        assert_eq!(built.threshold, 0.85);
        assert_eq!(built.priority, 5);
        assert_eq!(built.tags.len(), 2);
        assert_eq!(built.max_fires_per_window, Some(10));
    }
}