1use std::collections::HashMap;
52use std::sync::RwLock;
53use std::time::{SystemTime, UNIX_EPOCH};
54
55#[derive(Debug, Clone)]
61pub struct SemanticTrigger {
62 pub id: String,
64
65 pub name: String,
67
68 pub description: String,
70
71 pub query: String,
73
74 pub embedding: Option<Vec<f32>>,
76
77 pub threshold: f32,
79
80 pub action: TriggerAction,
82
83 pub enabled: bool,
85
86 pub priority: i32,
88
89 pub max_fires_per_window: Option<usize>,
91
92 pub rate_limit_window_secs: Option<u64>,
94
95 pub tags: Vec<String>,
97
98 pub metadata: HashMap<String, String>,
100
101 pub created_at: f64,
103}
104
105#[derive(Debug, Clone)]
107pub enum TriggerAction {
108 Notify {
110 channel: String,
111 template: Option<String>,
112 },
113
114 Route {
116 target: String,
117 context: Option<String>,
118 },
119
120 Escalate {
122 level: EscalationLevel,
123 reason: Option<String>,
124 },
125
126 SpawnAgent {
128 agent_type: String,
129 config: HashMap<String, String>,
130 },
131
132 Log {
134 level: LogLevel,
135 message: Option<String>,
136 },
137
138 Webhook {
140 url: String,
141 method: String,
142 headers: HashMap<String, String>,
143 },
144
145 Callback {
147 function: String,
148 args: HashMap<String, String>,
149 },
150
151 Chain(Vec<TriggerAction>),
153}
154
/// Severity of an escalation, declared from least to most severe.
///
/// The derived ordering follows declaration order, so
/// `Low < Medium < High < Critical`. Levels can therefore be compared against a
/// cut-off (e.g. `level >= EscalationLevel::High`) and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum EscalationLevel {
    Low,
    Medium,
    High,
    Critical,
}
163
/// Log severity, declared from least to most severe.
///
/// The derived ordering follows declaration order, so
/// `Debug < Info < Warn < Error`. This allows filtering by a minimum level and
/// using levels as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum LogLevel {
    Debug,
    Info,
    Warn,
    Error,
}
172
173#[derive(Debug, Clone)]
179pub struct TriggerEvent {
180 pub id: String,
182
183 pub content: String,
185
186 pub embedding: Option<Vec<f32>>,
188
189 pub source: EventSource,
191
192 pub metadata: HashMap<String, String>,
194
195 pub timestamp: f64,
197}
198
/// Origin of a `TriggerEvent`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum EventSource {
    /// A message written by a user.
    UserMessage,
    /// An event raised by the system itself.
    SystemEvent,
    /// A data insertion.
    DataInsert,
    /// A memory-compaction pass.
    MemoryCompaction,
    /// A call from an external API.
    ExternalApi,
    /// An action taken by an agent.
    AgentAction,
    /// Any other source, identified by name.
    Custom(String),
}
217
/// A single trigger firing produced while processing an event.
#[derive(Clone, Debug)]
pub struct TriggerMatch {
    /// ID of the trigger that fired.
    pub trigger_id: String,
    /// Cosine similarity between the event and trigger embeddings.
    pub score: f32,
    /// ID of the event that caused the firing.
    pub event_id: String,
    /// Time the match was recorded, in UNIX seconds.
    pub timestamp: f64,
    /// Whether the trigger's action has been executed for this match.
    pub action_executed: bool,
    /// Description returned by the executed action, if it ran.
    pub execution_result: Option<String>,
}
239
/// Running counters maintained by a `TriggerIndex`.
#[derive(Clone, Debug, Default)]
pub struct TriggerStats {
    /// Events passed to `process_event`, including those without an embedding.
    pub events_processed: usize,
    /// Trigger matches produced across all events.
    pub triggers_matched: usize,
    /// Actions executed through `execute_action`.
    pub actions_executed: usize,
    /// Match counts keyed by trigger ID.
    pub matches_by_trigger: HashMap<String, usize>,
    /// Otherwise-eligible firings suppressed by rate limiting.
    pub rate_limited: usize,
}
258
259pub struct TriggerIndex {
265 triggers: RwLock<HashMap<String, SemanticTrigger>>,
267
268 trigger_embeddings: RwLock<Vec<(String, Vec<f32>)>>,
270
271 rate_limits: RwLock<HashMap<String, (usize, f64)>>,
273
274 recent_matches: RwLock<Vec<TriggerMatch>>,
276
277 stats: RwLock<TriggerStats>,
279
280 max_recent_matches: usize,
282}
283
284impl TriggerIndex {
285 pub fn new() -> Self {
287 Self {
288 triggers: RwLock::new(HashMap::new()),
289 trigger_embeddings: RwLock::new(Vec::new()),
290 rate_limits: RwLock::new(HashMap::new()),
291 recent_matches: RwLock::new(Vec::new()),
292 stats: RwLock::new(TriggerStats::default()),
293 max_recent_matches: 1000,
294 }
295 }
296
297 pub fn register_trigger(&self, mut trigger: SemanticTrigger) -> Result<(), TriggerError> {
299 if trigger.id.is_empty() {
300 return Err(TriggerError::InvalidTrigger(
301 "ID cannot be empty".to_string(),
302 ));
303 }
304
305 if trigger.created_at == 0.0 {
307 trigger.created_at = SystemTime::now()
308 .duration_since(UNIX_EPOCH)
309 .unwrap_or_default()
310 .as_secs_f64();
311 }
312
313 {
315 let mut triggers = self.triggers.write().unwrap();
316 triggers.insert(trigger.id.clone(), trigger.clone());
317 }
318
319 if let Some(embedding) = &trigger.embedding {
321 let mut embeddings = self.trigger_embeddings.write().unwrap();
322 embeddings.push((trigger.id.clone(), embedding.clone()));
323 }
324
325 Ok(())
326 }
327
328 pub fn remove_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
330 let removed = {
331 let mut triggers = self.triggers.write().unwrap();
332 triggers.remove(trigger_id)
333 };
334
335 if removed.is_some() {
336 let mut embeddings = self.trigger_embeddings.write().unwrap();
337 embeddings.retain(|(id, _)| id != trigger_id);
338 }
339
340 removed
341 }
342
343 pub fn set_enabled(&self, trigger_id: &str, enabled: bool) -> bool {
345 let mut triggers = self.triggers.write().unwrap();
346 if let Some(trigger) = triggers.get_mut(trigger_id) {
347 trigger.enabled = enabled;
348 true
349 } else {
350 false
351 }
352 }
353
354 pub fn set_threshold(&self, trigger_id: &str, threshold: f32) -> bool {
356 let mut triggers = self.triggers.write().unwrap();
357 if let Some(trigger) = triggers.get_mut(trigger_id) {
358 trigger.threshold = threshold.clamp(0.0, 1.0);
359 true
360 } else {
361 false
362 }
363 }
364
365 pub fn process_event(&self, event: &TriggerEvent) -> Vec<TriggerMatch> {
367 let mut matches = Vec::new();
368 let now = SystemTime::now()
369 .duration_since(UNIX_EPOCH)
370 .unwrap_or_default()
371 .as_secs_f64();
372
373 {
375 let mut stats = self.stats.write().unwrap();
376 stats.events_processed += 1;
377 }
378
379 let event_embedding = match &event.embedding {
381 Some(emb) => emb.clone(),
382 None => {
383 return matches;
385 }
386 };
387
388 let candidates = self.find_candidates(&event_embedding, 10);
390
391 let triggers = self.triggers.read().unwrap();
392
393 for (trigger_id, score) in candidates {
394 if let Some(trigger) = triggers.get(&trigger_id) {
395 if !trigger.enabled {
397 continue;
398 }
399
400 if score < trigger.threshold {
402 continue;
403 }
404
405 if !self.check_rate_limit(&trigger_id, trigger, now) {
407 let mut stats = self.stats.write().unwrap();
408 stats.rate_limited += 1;
409 continue;
410 }
411
412 let trigger_match = TriggerMatch {
414 trigger_id: trigger_id.clone(),
415 score,
416 event_id: event.id.clone(),
417 timestamp: now,
418 action_executed: false,
419 execution_result: None,
420 };
421
422 matches.push(trigger_match);
423
424 {
426 let mut stats = self.stats.write().unwrap();
427 stats.triggers_matched += 1;
428 *stats
429 .matches_by_trigger
430 .entry(trigger_id.clone())
431 .or_insert(0) += 1;
432 }
433 }
434 }
435
436 matches.sort_by(|a, b| {
438 let trigger_a = triggers.get(&a.trigger_id);
439 let trigger_b = triggers.get(&b.trigger_id);
440
441 match (trigger_a, trigger_b) {
442 (Some(ta), Some(tb)) => ta.priority.cmp(&tb.priority).then_with(|| {
443 b.score
444 .partial_cmp(&a.score)
445 .unwrap_or(std::cmp::Ordering::Equal)
446 }),
447 _ => std::cmp::Ordering::Equal,
448 }
449 });
450
451 {
453 let mut recent = self.recent_matches.write().unwrap();
454 for m in &matches {
455 recent.push(m.clone());
456 }
457 while recent.len() > self.max_recent_matches {
459 recent.remove(0);
460 }
461 }
462
463 matches
464 }
465
466 fn find_candidates(&self, query: &[f32], k: usize) -> Vec<(String, f32)> {
468 let embeddings = self.trigger_embeddings.read().unwrap();
469
470 if embeddings.is_empty() {
471 return Vec::new();
472 }
473
474 let mut candidates: Vec<(String, f32)> = embeddings
476 .iter()
477 .map(|(id, emb)| {
478 let score = cosine_similarity(query, emb);
479 (id.clone(), score)
480 })
481 .collect();
482
483 candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
485
486 candidates.truncate(k);
487 candidates
488 }
489
490 fn check_rate_limit(&self, trigger_id: &str, trigger: &SemanticTrigger, now: f64) -> bool {
492 let max_fires = match trigger.max_fires_per_window {
493 Some(max) => max,
494 None => return true, };
496
497 let window_secs = trigger.rate_limit_window_secs.unwrap_or(60);
498
499 let mut rate_limits = self.rate_limits.write().unwrap();
500 let entry = rate_limits
501 .entry(trigger_id.to_string())
502 .or_insert((0, now));
503
504 if now - entry.1 > window_secs as f64 {
506 entry.0 = 1;
507 entry.1 = now;
508 return true;
509 }
510
511 if entry.0 < max_fires {
513 entry.0 += 1;
514 return true;
515 }
516
517 false
518 }
519
520 pub fn execute_action(&self, trigger_match: &mut TriggerMatch) -> Result<(), TriggerError> {
522 let triggers = self.triggers.read().unwrap();
523 let trigger = triggers
524 .get(&trigger_match.trigger_id)
525 .ok_or_else(|| TriggerError::TriggerNotFound(trigger_match.trigger_id.clone()))?;
526
527 let result = self.execute_action_impl(&trigger.action, trigger_match)?;
529
530 trigger_match.action_executed = true;
531 trigger_match.execution_result = Some(result);
532
533 {
535 let mut stats = self.stats.write().unwrap();
536 stats.actions_executed += 1;
537 }
538
539 Ok(())
540 }
541
542 fn execute_action_impl(
544 &self,
545 action: &TriggerAction,
546 trigger_match: &TriggerMatch,
547 ) -> Result<String, TriggerError> {
548 match action {
549 TriggerAction::Notify { channel, template } => {
550 Ok(format!(
552 "Notified channel '{}' (template: {:?})",
553 channel, template
554 ))
555 }
556
557 TriggerAction::Route { target, context } => {
558 Ok(format!("Routed to '{}' (context: {:?})", target, context))
559 }
560
561 TriggerAction::Escalate { level, reason } => Ok(format!(
562 "Escalated at level {:?} (reason: {:?})",
563 level, reason
564 )),
565
566 TriggerAction::SpawnAgent {
567 agent_type,
568 config: _,
569 } => Ok(format!("Spawned agent of type '{}'", agent_type)),
570
571 TriggerAction::Log { level, message } => {
572 let msg = message.as_deref().unwrap_or(&trigger_match.trigger_id);
573 Ok(format!("Logged at {:?}: {}", level, msg))
574 }
575
576 TriggerAction::Webhook {
577 url,
578 method,
579 headers: _,
580 } => {
581 Ok(format!("Called webhook {} {}", method, url))
583 }
584
585 TriggerAction::Callback { function, args: _ } => {
586 Ok(format!("Called callback function '{}'", function))
587 }
588
589 TriggerAction::Chain(actions) => {
590 let mut results = Vec::new();
591 for sub_action in actions {
592 let result = self.execute_action_impl(sub_action, trigger_match)?;
593 results.push(result);
594 }
595 Ok(format!("Chain executed: [{}]", results.join(", ")))
596 }
597 }
598 }
599
600 pub fn list_triggers(&self) -> Vec<SemanticTrigger> {
602 self.triggers.read().unwrap().values().cloned().collect()
603 }
604
605 pub fn get_trigger(&self, trigger_id: &str) -> Option<SemanticTrigger> {
607 self.triggers.read().unwrap().get(trigger_id).cloned()
608 }
609
610 pub fn recent_matches(&self, limit: usize) -> Vec<TriggerMatch> {
612 let matches = self.recent_matches.read().unwrap();
613 matches.iter().rev().take(limit).cloned().collect()
614 }
615
616 pub fn stats(&self) -> TriggerStats {
618 self.stats.read().unwrap().clone()
619 }
620
621 pub fn clear_stats(&self) {
623 let mut stats = self.stats.write().unwrap();
624 *stats = TriggerStats::default();
625 }
626}
627
628impl Default for TriggerIndex {
629 fn default() -> Self {
630 Self::new()
631 }
632}
633
634pub struct TriggerBuilder {
640 trigger: SemanticTrigger,
641}
642
643impl TriggerBuilder {
644 pub fn new(id: &str, query: &str) -> Self {
646 Self {
647 trigger: SemanticTrigger {
648 id: id.to_string(),
649 name: id.to_string(),
650 description: String::new(),
651 query: query.to_string(),
652 embedding: None,
653 threshold: 0.8,
654 action: TriggerAction::Log {
655 level: LogLevel::Info,
656 message: None,
657 },
658 enabled: true,
659 priority: 0,
660 max_fires_per_window: None,
661 rate_limit_window_secs: None,
662 tags: Vec::new(),
663 metadata: HashMap::new(),
664 created_at: 0.0,
665 },
666 }
667 }
668
669 pub fn name(mut self, name: &str) -> Self {
671 self.trigger.name = name.to_string();
672 self
673 }
674
675 pub fn description(mut self, description: &str) -> Self {
677 self.trigger.description = description.to_string();
678 self
679 }
680
681 pub fn embedding(mut self, embedding: Vec<f32>) -> Self {
683 self.trigger.embedding = Some(embedding);
684 self
685 }
686
687 pub fn threshold(mut self, threshold: f32) -> Self {
689 self.trigger.threshold = threshold.clamp(0.0, 1.0);
690 self
691 }
692
693 pub fn action(mut self, action: TriggerAction) -> Self {
695 self.trigger.action = action;
696 self
697 }
698
699 pub fn notify(mut self, channel: &str) -> Self {
701 self.trigger.action = TriggerAction::Notify {
702 channel: channel.to_string(),
703 template: None,
704 };
705 self
706 }
707
708 pub fn route(mut self, target: &str) -> Self {
710 self.trigger.action = TriggerAction::Route {
711 target: target.to_string(),
712 context: None,
713 };
714 self
715 }
716
717 pub fn escalate(mut self, level: EscalationLevel) -> Self {
719 self.trigger.action = TriggerAction::Escalate {
720 level,
721 reason: None,
722 };
723 self
724 }
725
726 pub fn priority(mut self, priority: i32) -> Self {
728 self.trigger.priority = priority;
729 self
730 }
731
732 pub fn rate_limit(mut self, max_fires: usize, window_secs: u64) -> Self {
734 self.trigger.max_fires_per_window = Some(max_fires);
735 self.trigger.rate_limit_window_secs = Some(window_secs);
736 self
737 }
738
739 pub fn tag(mut self, tag: &str) -> Self {
741 self.trigger.tags.push(tag.to_string());
742 self
743 }
744
745 pub fn enabled(mut self, enabled: bool) -> Self {
747 self.trigger.enabled = enabled;
748 self
749 }
750
751 pub fn build(self) -> SemanticTrigger {
753 self.trigger
754 }
755}
756
/// Errors reported by trigger operations.
#[derive(Clone, Debug)]
pub enum TriggerError {
    /// The trigger failed validation (for example, an empty ID).
    InvalidTrigger(String),
    /// No trigger is registered under the given ID.
    TriggerNotFound(String),
    /// Executing an action failed.
    ActionFailed(String),
    /// The trigger with the given ID exceeded its rate limit.
    RateLimitExceeded(String),
    /// Embedding-related failure.
    EmbeddingError(String),
}

impl std::fmt::Display for TriggerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant renders as "<prefix>: <payload>".
        let (prefix, detail) = match self {
            Self::InvalidTrigger(msg) => ("Invalid trigger", msg),
            Self::TriggerNotFound(id) => ("Trigger not found", id),
            Self::ActionFailed(msg) => ("Action failed", msg),
            Self::RateLimitExceeded(id) => ("Rate limit exceeded for trigger", id),
            Self::EmbeddingError(msg) => ("Embedding error", msg),
        };
        write!(f, "{}: {}", prefix, detail)
    }
}

impl std::error::Error for TriggerError {}
789
/// Cosine similarity of `a` and `b`, in `[-1.0, 1.0]` up to rounding.
///
/// Returns `0.0` when the slices differ in length, are empty, or either has a
/// (near-)zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let comparable = !a.is_empty() && a.len() == b.len();
    if !comparable {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (magnitude_a, magnitude_b) = (magnitude(a), magnitude(b));

    // Avoid dividing by (almost) zero for degenerate vectors.
    if magnitude_a < 1e-10 || magnitude_b < 1e-10 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    dot / (magnitude_a * magnitude_b)
}
810
811pub fn create_notify_trigger(
817 id: &str,
818 query: &str,
819 channel: &str,
820 embedding: Vec<f32>,
821) -> SemanticTrigger {
822 TriggerBuilder::new(id, query)
823 .embedding(embedding)
824 .notify(channel)
825 .build()
826}
827
828pub fn create_escalation_trigger(
830 id: &str,
831 query: &str,
832 level: EscalationLevel,
833 embedding: Vec<f32>,
834) -> SemanticTrigger {
835 TriggerBuilder::new(id, query)
836 .embedding(embedding)
837 .escalate(level)
838 .priority(-1) .build()
840}
841
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic 128-dimensional test vector; `seed` rotates the pattern.
    fn mock_embedding(seed: u64) -> Vec<f32> {
        (0..128u64)
            .map(|i| ((i + seed) % 100) as f32 / 100.0 - 0.5)
            .collect()
    }

    /// An event with ID `event_1` whose embedding equals `mock_embedding(1)`.
    fn event(content: &str, source: EventSource) -> TriggerEvent {
        TriggerEvent {
            id: "event_1".to_string(),
            content: content.to_string(),
            embedding: Some(mock_embedding(1)),
            source,
            metadata: HashMap::new(),
            timestamp: 0.0,
        }
    }

    /// A fresh index with `trigger` already registered.
    fn index_with(trigger: SemanticTrigger) -> TriggerIndex {
        let index = TriggerIndex::new();
        index.register_trigger(trigger).unwrap();
        index
    }

    #[test]
    fn test_trigger_registration() {
        let index = index_with(
            TriggerBuilder::new("privacy_concern", "user mentions privacy concerns")
                .embedding(mock_embedding(1))
                .threshold(0.75)
                .escalate(EscalationLevel::High)
                .build(),
        );

        let registered = index.list_triggers();
        assert_eq!(registered.len(), 1);
        assert_eq!(registered[0].id, "privacy_concern");
    }

    #[test]
    fn test_trigger_matching() {
        let index = index_with(
            TriggerBuilder::new("security_alert", "security vulnerability")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .notify("security-team")
                .build(),
        );

        let matches = index.process_event(&event(
            "possible security issue detected",
            EventSource::SystemEvent,
        ));

        let first = matches.first().expect("expected at least one match");
        assert_eq!(first.trigger_id, "security_alert");
        assert!(first.score > 0.5);
    }

    #[test]
    fn test_trigger_disable() {
        let index = index_with(
            TriggerBuilder::new("test_trigger", "test")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .build(),
        );
        index.set_enabled("test_trigger", false);

        let matches = index.process_event(&event("test", EventSource::UserMessage));
        assert!(matches.is_empty());
    }

    #[test]
    fn test_rate_limiting() {
        let index = index_with(
            TriggerBuilder::new("rate_limited", "test")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .rate_limit(2, 60)
                .build(),
        );
        let incoming = event("test", EventSource::UserMessage);

        // Two firings fit in the window; the third is suppressed.
        let fired: Vec<bool> = (0..3)
            .map(|_| !index.process_event(&incoming).is_empty())
            .collect();
        assert_eq!(fired, vec![true, true, false]);

        assert!(index.stats().rate_limited >= 1);
    }

    #[test]
    fn test_action_execution() {
        let index = index_with(
            TriggerBuilder::new("log_trigger", "test")
                .embedding(mock_embedding(1))
                .threshold(0.5)
                .action(TriggerAction::Log {
                    level: LogLevel::Info,
                    message: Some("Test message".to_string()),
                })
                .build(),
        );

        let mut matches = index.process_event(&event("test", EventSource::UserMessage));
        assert!(!matches.is_empty());

        let first = &mut matches[0];
        index.execute_action(first).unwrap();
        assert!(first.action_executed);
        assert!(first.execution_result.is_some());
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis: [f32; 3] = [1.0, 0.0, 0.0];
        let y_axis: [f32; 3] = [0.0, 1.0, 0.0];

        assert!((cosine_similarity(&x_axis, &x_axis) - 1.0).abs() < 0.01);
        assert!(cosine_similarity(&x_axis, &y_axis).abs() < 0.01);
    }

    #[test]
    fn test_trigger_builder() {
        let trigger = TriggerBuilder::new("test", "test query")
            .name("Test Trigger")
            .description("A test trigger")
            .threshold(0.85)
            .priority(5)
            .tag("test")
            .tag("example")
            .notify("test-channel")
            .rate_limit(10, 300)
            .build();

        assert_eq!(trigger.id, "test");
        assert_eq!(trigger.name, "Test Trigger");
        assert_eq!(trigger.threshold, 0.85);
        assert_eq!(trigger.priority, 5);
        assert_eq!(trigger.tags.len(), 2);
        assert_eq!(trigger.max_fires_per_window, Some(10));
    }
}
1036}