use std::collections::VecDeque;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use rs_genai::session::{SessionEvent, SessionWriter};
use super::BoxFuture;
use crate::state::State;
/// A detector that decides, on each check, whether its pattern currently matches.
///
/// Implementors must be `Send + Sync`, and every method takes `&self`, so any
/// bookkeeping kept between calls has to use interior mutability.
pub trait PatternDetector: Send + Sync {
/// Evaluates the detector against the current `state`, the optional `event`
/// that prompted the check, and the caller-supplied timestamp `now`.
///
/// Returns `true` when the pattern matches.
fn check(&self, state: &State, event: Option<&SessionEvent>, now: Instant) -> bool;
/// Discards whatever bookkeeping earlier `check` calls accumulated.
fn reset(&self);
/// Reports whether this detector wants timer-driven checks; defaults to `false`.
///
/// NOTE(review): presumably the host uses this to decide whether to call
/// `check` periodically without an event — confirm against the caller.
fn needs_timer(&self) -> bool {
false
}
}
/// A named detector paired with the action to run whenever it matches.
///
/// An optional cooldown limits how often the action may be produced.
pub struct TemporalPattern {
    /// Identifier for this pattern.
    pub name: String,
    /// Detector consulted on every firing attempt.
    pub detector: Box<dyn PatternDetector>,
    /// Builds the future to run when the pattern fires. It receives a clone of
    /// the state and a clone of the session writer handle.
    pub action: Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>,
    /// Minimum time between two firings; `None` means no rate limiting.
    pub cooldown: Option<Duration>,
    /// Timestamp of the most recent firing; `None` until the first one.
    last_triggered: parking_lot::Mutex<Option<Instant>>,
}

impl TemporalPattern {
    /// Creates a pattern that has not fired yet.
    ///
    /// * `name` - identifier stored in [`TemporalPattern::name`].
    /// * `detector` - decides when the pattern matches.
    /// * `action` - invoked to build the future returned on a firing.
    /// * `cooldown` - optional minimum spacing between firings.
    pub fn new(
        name: impl Into<String>,
        detector: Box<dyn PatternDetector>,
        action: Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>,
        cooldown: Option<Duration>,
    ) -> Self {
        Self {
            name: name.into(),
            detector,
            action,
            cooldown,
            last_triggered: parking_lot::Mutex::new(None),
        }
    }

    /// Runs the detector and, if it matches outside the cooldown, returns the
    /// action's future.
    ///
    /// Returns `None` when the detector does not match, or when a cooldown is
    /// configured and less than that much time has passed since the previous
    /// firing. The returned future is not polled here.
    fn try_fire(
        &self,
        state: &State,
        event: Option<&SessionEvent>,
        writer: &Arc<dyn SessionWriter>,
        now: Instant,
    ) -> Option<BoxFuture<()>> {
        // The detector runs before the cooldown test so that it observes every
        // check, including the ones the cooldown ends up suppressing.
        if !self.detector.check(state, event, now) {
            return None;
        }

        {
            let mut last = self.last_triggered.lock();
            if let (Some(cooldown), Some(prev)) = (self.cooldown, *last) {
                if now.duration_since(prev) < cooldown {
                    return None;
                }
            }
            *last = Some(now);
        }
        // The guard is released before user code runs: the mutex is not
        // re-entrant, so an action that reached this pattern again while the
        // lock was held would deadlock.
        Some((self.action)(state.clone(), Arc::clone(writer)))
    }
}
/// A collection of [`TemporalPattern`]s that are evaluated together.
pub struct TemporalRegistry {
    /// Registered patterns, in the order they were added.
    patterns: Vec<TemporalPattern>,
}

impl Default for TemporalRegistry {
    /// Same as [`TemporalRegistry::new`]: a registry with no patterns.
    fn default() -> Self {
        TemporalRegistry::new()
    }
}

impl TemporalRegistry {
    /// Creates a registry with no patterns.
    pub fn new() -> Self {
        let patterns = Vec::new();
        Self { patterns }
    }

    /// Appends `pattern` after the patterns already registered.
    pub fn add(&mut self, pattern: TemporalPattern) {
        self.patterns.push(pattern);
    }

    /// Offers the same timestamp to every pattern and gathers the futures of
    /// those that fire.
    ///
    /// Patterns are visited in insertion order. The futures are returned
    /// without being polled.
    pub fn check_all(
        &self,
        state: &State,
        event: Option<&SessionEvent>,
        writer: &Arc<dyn SessionWriter>,
    ) -> Vec<BoxFuture<()>> {
        // One clock reading is shared so all patterns see the same instant.
        let now = Instant::now();
        let mut fired = Vec::new();
        for pattern in &self.patterns {
            if let Some(action) = pattern.try_fire(state, event, writer, now) {
                fired.push(action);
            }
        }
        fired
    }

    /// Returns `true` if any registered detector reports `needs_timer()`.
    pub fn needs_timer(&self) -> bool {
        for pattern in &self.patterns {
            if pattern.detector.needs_timer() {
                return true;
            }
        }
        false
    }
}
/// Detector that matches once a condition has held continuously for at least
/// a given duration.
pub struct SustainedDetector {
    /// Predicate evaluated against the state on every check.
    condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
    /// How long the predicate must stay true before the detector matches.
    duration: Duration,
    /// When the predicate was first seen true in the current run, if any.
    became_true_at: parking_lot::Mutex<Option<Instant>>,
}

impl SustainedDetector {
    /// Builds a detector with no run in progress.
    ///
    /// * `condition` - predicate over the state.
    /// * `duration` - how long `condition` must hold.
    pub fn new(condition: Arc<dyn Fn(&State) -> bool + Send + Sync>, duration: Duration) -> Self {
        let became_true_at = parking_lot::Mutex::new(None);
        Self {
            condition,
            duration,
            became_true_at,
        }
    }
}
impl PatternDetector for SustainedDetector {
    /// Matches once the condition has been true for at least `duration`.
    ///
    /// The first check that sees the condition true only records `now` and
    /// returns `false`. Any check that sees it false clears the recorded start.
    fn check(&self, state: &State, _event: Option<&SessionEvent>, now: Instant) -> bool {
        let holds = (self.condition)(state);
        let mut started = self.became_true_at.lock();
        if !holds {
            *started = None;
            return false;
        }
        if let Some(since) = *started {
            now.duration_since(since) >= self.duration
        } else {
            *started = Some(now);
            false
        }
    }

    /// Forgets the recorded start time, so the next true check begins a new run.
    fn reset(&self) {
        let mut started = self.became_true_at.lock();
        *started = None;
    }

    /// Always `true` for this detector.
    fn needs_timer(&self) -> bool {
        true
    }
}
/// Detector that matches when enough accepted events fall inside a trailing
/// time window.
pub struct RateDetector {
    /// Selects which events are counted.
    filter: Arc<dyn Fn(&SessionEvent) -> bool + Send + Sync>,
    /// Number of retained timestamps required for a match.
    count: u32,
    /// Maximum age of a timestamp before it is dropped.
    window: Duration,
    /// Timestamps of accepted events, oldest at the front.
    timestamps: parking_lot::Mutex<VecDeque<Instant>>,
}

impl RateDetector {
    /// Builds a detector with no recorded events.
    ///
    /// * `filter` - predicate choosing the events to count.
    /// * `count` - how many counted events trigger a match.
    /// * `window` - how far back counted events are retained.
    pub fn new(
        filter: Arc<dyn Fn(&SessionEvent) -> bool + Send + Sync>,
        count: u32,
        window: Duration,
    ) -> Self {
        let timestamps = parking_lot::Mutex::new(VecDeque::new());
        Self {
            filter,
            count,
            window,
            timestamps,
        }
    }
}
impl PatternDetector for RateDetector {
    /// Records `now` if `event` passes the filter, drops timestamps older than
    /// the window, and matches when at least `count` timestamps remain.
    ///
    /// A check with no event, or with a rejected one, still prunes old entries.
    fn check(&self, _state: &State, event: Option<&SessionEvent>, now: Instant) -> bool {
        let mut recent = self.timestamps.lock();

        let accepted = match event {
            Some(evt) => (self.filter)(evt),
            None => false,
        };
        if accepted {
            recent.push_back(now);
        }

        // Entries are pushed in call order, so pruning stops at the first
        // timestamp that is still inside the window.
        while recent
            .front()
            .map_or(false, |&oldest| now.duration_since(oldest) > self.window)
        {
            recent.pop_front();
        }

        recent.len() as u32 >= self.count
    }

    /// Drops every recorded timestamp.
    fn reset(&self) {
        let mut recent = self.timestamps.lock();
        recent.clear();
    }
}
/// Detector that matches after a condition has held on enough consecutive
/// checks.
pub struct TurnCountDetector {
    /// Predicate evaluated against the state on every check.
    condition: Arc<dyn Fn(&State) -> bool + Send + Sync>,
    /// Number of consecutive true checks needed for a match.
    required: u32,
    /// Length of the current run of true checks.
    consecutive: AtomicU32,
}

impl TurnCountDetector {
    /// Builds a detector whose run counter starts at zero.
    ///
    /// * `condition` - predicate over the state.
    /// * `required` - consecutive true checks needed to match.
    pub fn new(condition: Arc<dyn Fn(&State) -> bool + Send + Sync>, required: u32) -> Self {
        let consecutive = AtomicU32::new(0);
        Self {
            condition,
            required,
            consecutive,
        }
    }
}
impl PatternDetector for TurnCountDetector {
    /// Counts consecutive checks on which the condition is true and matches
    /// once the run reaches `required`.
    ///
    /// A check that sees the condition false resets the run to zero. The event
    /// and timestamp arguments are ignored.
    fn check(&self, state: &State, _event: Option<&SessionEvent>, _now: Instant) -> bool {
        if !(self.condition)(state) {
            self.consecutive.store(0, Ordering::SeqCst);
            return false;
        }

        // Saturate instead of wrapping: a wrapping increment would roll the
        // run back to zero at `u32::MAX` and silently stop the detector from
        // matching, and `prev + 1` would overflow at the same point.
        let bump = |n: u32| Some(n.saturating_add(1));
        let prev = match self
            .consecutive
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, bump)
        {
            Ok(prev) | Err(prev) => prev,
        };
        prev.saturating_add(1) >= self.required
    }

    /// Sets the run counter back to zero.
    fn reset(&self) {
        self.consecutive.store(0, Ordering::SeqCst);
    }
}
/// Detector that matches after a tool's failure flag has been seen set on
/// enough consecutive checks.
pub struct ConsecutiveFailureDetector {
    /// Tool name used to build the state key that is read on each check.
    tool_name: String,
    /// Number of consecutive failing checks needed for a match.
    threshold: u32,
    /// Length of the current run of failing checks.
    consecutive: AtomicU32,
}

impl ConsecutiveFailureDetector {
    /// Builds a detector for `tool_name` whose run counter starts at zero.
    ///
    /// * `tool_name` - name of the tool to watch.
    /// * `threshold` - consecutive failing checks needed to match.
    pub fn new(tool_name: impl Into<String>, threshold: u32) -> Self {
        let tool_name: String = tool_name.into();
        let consecutive = AtomicU32::new(0);
        Self {
            tool_name,
            threshold,
            consecutive,
        }
    }
}
impl PatternDetector for ConsecutiveFailureDetector {
    /// Reads the boolean state entry `bg:<tool_name>_failed` and matches once
    /// it has been true on `threshold` consecutive checks.
    ///
    /// A missing entry is treated as `false`. A check that sees the flag false
    /// resets the run to zero. The event and timestamp arguments are ignored.
    fn check(&self, state: &State, _event: Option<&SessionEvent>, _now: Instant) -> bool {
        let key = format!("bg:{}_failed", self.tool_name);
        let failed: bool = state.get(&key).unwrap_or(false);
        if !failed {
            self.consecutive.store(0, Ordering::SeqCst);
            return false;
        }

        // Saturate instead of wrapping: a wrapping increment would roll the
        // run back to zero at `u32::MAX` and silently stop the detector from
        // matching, and `prev + 1` would overflow at the same point.
        let bump = |n: u32| Some(n.saturating_add(1));
        let prev = match self
            .consecutive
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, bump)
        {
            Ok(prev) | Err(prev) => prev,
        };
        prev.saturating_add(1) >= self.threshold
    }

    /// Sets the run counter back to zero.
    fn reset(&self) {
        self.consecutive.store(0, Ordering::SeqCst);
    }
}
// Unit tests for the detectors, the per-pattern cooldown, and the registry.
// Timestamps are built by offsetting a single `Instant::now()` reading, so no
// test sleeps.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
// Session writer whose every method returns `Ok(())` without doing anything.
struct MockWriter;
#[async_trait::async_trait]
impl SessionWriter for MockWriter {
async fn send_audio(&self, _: Vec<u8>) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn send_text(&self, _: String) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn send_tool_response(
&self,
_: Vec<rs_genai::protocol::FunctionResponse>,
) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn send_client_content(
&self,
_: Vec<rs_genai::protocol::Content>,
_: bool,
) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn send_video(&self, _: Vec<u8>) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn update_instruction(
&self,
_: String,
) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn signal_activity_start(&self) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn signal_activity_end(&self) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
async fn disconnect(&self) -> Result<(), rs_genai::session::SessionError> {
Ok(())
}
}
// Wraps a `MockWriter` in the trait-object handle the patterns expect.
fn mock_writer() -> Arc<dyn SessionWriter> {
Arc::new(MockWriter)
}
// Builds an action whose future increments `counter` by one when awaited.
fn counting_action(
counter: Arc<AtomicU32>,
) -> Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync> {
Arc::new(move |_state, _writer| {
let c = counter.clone();
Box::pin(async move {
c.fetch_add(1, Ordering::SeqCst);
})
})
}
// The first check only records the start; the detector matches once 5 s have
// elapsed and keeps matching after that.
#[test]
fn sustained_fires_after_duration() {
let state = State::new();
state.set("hot", true);
let detector = SustainedDetector::new(
Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
Duration::from_secs(5),
);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));
assert!(detector.check(&state, None, t0 + Duration::from_secs(5)));
assert!(detector.check(&state, None, t0 + Duration::from_secs(6)));
}
// A false reading at t0+2s clears the run; the new run starts at t0+3s, so the
// 5 s requirement is first met at t0+8s.
#[test]
fn sustained_resets_on_false() {
let state = State::new();
state.set("hot", true);
let detector = SustainedDetector::new(
Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
Duration::from_secs(5),
);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
state.set("hot", false);
assert!(!detector.check(&state, None, t0 + Duration::from_secs(2)));
state.set("hot", true);
assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));
assert!(!detector.check(&state, None, t0 + Duration::from_secs(7)));
assert!(detector.check(&state, None, t0 + Duration::from_secs(8)));
}
// After `reset()` the run restarts at the next check (t0+4s), so the detector
// matches at t0+9s.
#[test]
fn sustained_reset_clears_state() {
let state = State::new();
state.set("hot", true);
let detector = SustainedDetector::new(
Arc::new(|s: &State| s.get::<bool>("hot").unwrap_or(false)),
Duration::from_secs(5),
);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
detector.reset();
assert!(!detector.check(&state, None, t0 + Duration::from_secs(4)));
assert!(detector.check(&state, None, t0 + Duration::from_secs(9)));
}
// Three accepted events within the 10 s window make the detector match.
#[test]
fn rate_fires_when_count_reached() {
let state = State::new();
let detector = RateDetector::new(
Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
3,
Duration::from_secs(10),
);
let t0 = Instant::now();
let event = SessionEvent::TurnComplete;
assert!(!detector.check(&state, Some(&event), t0));
assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
assert!(detector.check(&state, Some(&event), t0 + Duration::from_secs(2)));
}
// By t0+10s the first two events are older than the 5 s window and have been
// pruned, so the third event alone does not reach the count.
#[test]
fn rate_does_not_fire_when_events_outside_window() {
let state = State::new();
let detector = RateDetector::new(
Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
3,
Duration::from_secs(5),
);
let t0 = Instant::now();
let event = SessionEvent::TurnComplete;
assert!(!detector.check(&state, Some(&event), t0));
assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(10)));
}
// Events rejected by the filter, and checks with no event, are never counted.
#[test]
fn rate_filter_rejects_events() {
let state = State::new();
let detector = RateDetector::new(
Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
2,
Duration::from_secs(10),
);
let t0 = Instant::now();
let text_event = SessionEvent::TextDelta("hello".to_string());
assert!(!detector.check(&state, Some(&text_event), t0));
assert!(!detector.check(&state, Some(&text_event), t0 + Duration::from_secs(1)));
assert!(!detector.check(&state, Some(&text_event), t0 + Duration::from_secs(2)));
assert!(!detector.check(&state, None, t0 + Duration::from_secs(3)));
}
// The detector matches on the third consecutive check with the condition true.
#[test]
fn turn_count_fires_after_n_consecutive() {
let state = State::new();
state.set("confused", true);
let detector = TurnCountDetector::new(
Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
3,
);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
assert!(detector.check(&state, None, t0));
}
// A false check in the middle restarts the count, so three more true checks
// are needed.
#[test]
fn turn_count_resets_on_false() {
let state = State::new();
state.set("confused", true);
let detector = TurnCountDetector::new(
Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
3,
);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
state.set("confused", false);
assert!(!detector.check(&state, None, t0));
state.set("confused", true);
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
assert!(detector.check(&state, None, t0));
}
// With `bg:search_failed` set, the detector matches on the third check.
#[test]
fn consecutive_failure_fires_after_threshold() {
let state = State::new();
state.set("bg:search_failed", true);
let detector = ConsecutiveFailureDetector::new("search", 3);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
assert!(detector.check(&state, None, t0));
}
// Clearing the failure flag for one check restarts the count.
#[test]
fn consecutive_failure_resets_on_success() {
let state = State::new();
state.set("bg:search_failed", true);
let detector = ConsecutiveFailureDetector::new("search", 3);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
state.set("bg:search_failed", false);
assert!(!detector.check(&state, None, t0));
state.set("bg:search_failed", true);
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
assert!(detector.check(&state, None, t0));
}
// Even with a zero duration the sustained detector uses its first check to
// record the start. The second check fires, the third falls inside the 10 s
// cooldown, and the check at t0+11s fires again.
#[tokio::test]
async fn pattern_cooldown_prevents_rapid_refiring() {
let counter = Arc::new(AtomicU32::new(0));
let state = State::new();
state.set("active", true);
let writer = mock_writer();
let pattern = TemporalPattern::new(
"test-cooldown",
Box::new(SustainedDetector::new(
Arc::new(|s: &State| s.get::<bool>("active").unwrap_or(false)),
Duration::from_secs(0), )),
counting_action(counter.clone()),
Some(Duration::from_secs(10)), );
let t0 = Instant::now();
assert!(pattern.try_fire(&state, None, &writer, t0).is_none());
let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_millis(1));
assert!(fut.is_some());
fut.unwrap().await;
assert_eq!(counter.load(Ordering::SeqCst), 1);
assert!(pattern
.try_fire(&state, None, &writer, t0 + Duration::from_millis(2))
.is_none());
let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_secs(11));
assert!(fut.is_some());
fut.unwrap().await;
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
// A registry with one immediately-matching pattern returns one future, and
// awaiting it runs the action once.
#[tokio::test]
async fn registry_check_all_returns_actions() {
let counter = Arc::new(AtomicU32::new(0));
let state = State::new();
state.set("confused", true);
let writer = mock_writer();
let mut registry = TemporalRegistry::new();
registry.add(TemporalPattern::new(
"confusion",
Box::new(TurnCountDetector::new(
Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
1,
)),
counting_action(counter.clone()),
None,
));
let actions = registry.check_all(&state, None, &writer);
assert_eq!(actions.len(), 1);
for fut in actions {
fut.await;
}
assert_eq!(counter.load(Ordering::SeqCst), 1);
}
// A registry holding a sustained detector reports that it needs a timer.
#[test]
fn needs_timer_true_with_sustained_detector() {
let counter = Arc::new(AtomicU32::new(0));
let mut registry = TemporalRegistry::new();
registry.add(TemporalPattern::new(
"sustained",
Box::new(SustainedDetector::new(
Arc::new(|_: &State| true),
Duration::from_secs(5),
)),
counting_action(counter),
None,
));
assert!(registry.needs_timer());
}
// Turn-count and rate detectors use the default `needs_timer()`, so a registry
// holding only those does not need a timer.
#[test]
fn needs_timer_false_without_sustained_detector() {
let counter = Arc::new(AtomicU32::new(0));
let mut registry = TemporalRegistry::new();
registry.add(TemporalPattern::new(
"turn-count",
Box::new(TurnCountDetector::new(Arc::new(|_: &State| true), 3)),
counting_action(counter.clone()),
None,
));
registry.add(TemporalPattern::new(
"rate",
Box::new(RateDetector::new(
Arc::new(|_: &SessionEvent| true),
5,
Duration::from_secs(10),
)),
counting_action(counter),
None,
));
assert!(!registry.needs_timer());
}
// `Default` yields a registry with no patterns, which needs no timer.
#[test]
fn default_creates_empty_registry() {
let registry = TemporalRegistry::default();
assert!(!registry.needs_timer());
}
// `reset()` drops the first timestamp, so two further events are needed to
// reach a count of 2.
#[test]
fn rate_reset_clears_timestamps() {
let state = State::new();
let detector = RateDetector::new(
Arc::new(|evt: &SessionEvent| matches!(evt, SessionEvent::TurnComplete)),
2,
Duration::from_secs(10),
);
let t0 = Instant::now();
let event = SessionEvent::TurnComplete;
assert!(!detector.check(&state, Some(&event), t0));
detector.reset();
assert!(!detector.check(&state, Some(&event), t0 + Duration::from_secs(1)));
assert!(detector.check(&state, Some(&event), t0 + Duration::from_secs(2)));
}
// `reset()` zeroes the run, so three more true checks are needed.
#[test]
fn turn_count_reset_clears_counter() {
let state = State::new();
state.set("confused", true);
let detector = TurnCountDetector::new(
Arc::new(|s: &State| s.get::<bool>("confused").unwrap_or(false)),
3,
);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
detector.reset();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
assert!(detector.check(&state, None, t0));
}
// `reset()` zeroes the failure run, so three more failing checks are needed.
#[test]
fn consecutive_failure_reset_clears_counter() {
let state = State::new();
state.set("bg:search_failed", true);
let detector = ConsecutiveFailureDetector::new("search", 3);
let t0 = Instant::now();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
detector.reset();
assert!(!detector.check(&state, None, t0));
assert!(!detector.check(&state, None, t0));
assert!(detector.check(&state, None, t0));
}
// The sustained detector overrides `needs_timer()` to return `true`.
#[test]
fn sustained_detector_needs_timer() {
let detector = SustainedDetector::new(Arc::new(|_: &State| true), Duration::from_secs(5));
assert!(detector.needs_timer());
}
// The rate detector keeps the default `needs_timer()` of `false`.
#[test]
fn rate_detector_does_not_need_timer() {
let detector = RateDetector::new(
Arc::new(|_: &SessionEvent| true),
5,
Duration::from_secs(10),
);
assert!(!detector.needs_timer());
}
// The turn-count detector keeps the default `needs_timer()` of `false`.
#[test]
fn turn_count_detector_does_not_need_timer() {
let detector = TurnCountDetector::new(Arc::new(|_: &State| true), 3);
assert!(!detector.needs_timer());
}
// With no cooldown and a detector that matches on every check, each of the
// five attempts fires.
#[tokio::test]
async fn pattern_without_cooldown_fires_every_time() {
let counter = Arc::new(AtomicU32::new(0));
let state = State::new();
state.set("active", true);
let writer = mock_writer();
let pattern = TemporalPattern::new(
"no-cooldown",
Box::new(TurnCountDetector::new(
Arc::new(|s: &State| s.get::<bool>("active").unwrap_or(false)),
1,
)),
counting_action(counter.clone()),
None, );
let t0 = Instant::now();
for i in 0..5u32 {
let fut = pattern.try_fire(&state, None, &writer, t0 + Duration::from_millis(i as u64));
assert!(fut.is_some(), "should fire on iteration {i}");
fut.unwrap().await;
}
assert_eq!(counter.load(Ordering::SeqCst), 5);
}
}