Skip to main content

harn_vm/triggers/
aggregation.rs

1//! CH-04 (#1875): aggregation triggers (`batch { count, window, key, expire_action }`).
2//!
3//! Inngest-shape primitive — no other major durable-execution or agent
4//! system has first-class fire-after-N-events. Lets a Harn trigger declare:
5//!
6//! ```text
7//! trigger {
8//!   source: "channel:pr.merged",
9//!   batch: { count: 3, window: "10m", key: "repo", expire_action: "fire" },
10//!   handler: ...,
11//! }
12//! ```
13//!
14//! Semantics (see issue #1875):
15//! - Counter increments on each filter-passing event.
16//! - Counter resets on fire or on window expire.
17//! - `key`: optional JSON path into the channel payload; each distinct key
18//!   value maintains its own counter + window.
//! - `expire_action`: `fire_partial` (default; `fire` is accepted as an
//!   alias) flushes the partial batch when the window elapses; `discard`
//!   drops it silently.
21//!
22//! State lives in a per-process thread-local registry keyed by binding key
23//! (id@version) and partition key. The buffer is capped at
24//! [`MAX_BUFFER_EVENTS`] to keep a stuck handler from running the runtime
25//! out of memory; overflow drops the oldest entries with a structured
26//! warning (`triggers.aggregation.buffer_overflow`).
27//!
28//! Window expiration is driven by two complementary mechanisms:
29//!
//! 1. **Implicit sweep**: every emit pass through
//!    [`crate::channels::dispatch_channel_emit_to_triggers`] first calls
//!    [`drain_expired_aggregations`] to flush any buffers whose window
//!    has elapsed.
33//! 2. **Explicit flush**: the `flush_trigger_aggregations()` builtin
34//!    drains all expired buffers immediately. Tests use this together
35//!    with `advance_time(ms)` to get deterministic window-expire
36//!    coverage; the Rust mock clock advances `clock::now_ms` so expired
37//!    buffers come through synchronously.
38//!
//! Production runtimes that need a real-time fallback can also call
//! [`drain_expired_aggregations`] periodically from a tokio task. v1 keeps
//! the contract synchronous-only because every dispatch path already runs
//! through `emit_channel`, so the implicit sweep covers the common case
//! without adding a background timer that would complicate replay
//! determinism.
44
45use std::cell::RefCell;
46use std::collections::BTreeMap;
47use std::time::Duration;
48
49use serde_json::Value as JsonValue;
50
51use crate::triggers::test_util::clock;
52use crate::value::{VmError, VmValue};
53
54use super::TriggerEvent;
55
/// Upper bound on the number of events held in a single
/// (binding, partition_key) buffer. Without a cap, a batch that never
/// fires would keep growing; once the cap is reached [`accumulate`]
/// evicts the *oldest* entry so the most recent context is retained, and
/// logs a structured warning so operators can notice the stuck batch.
pub const MAX_BUFFER_EVENTS: usize = 1024;

/// Diagnostic code prefixed to every `batch` configuration error produced
/// by [`parse_aggregation_config`].
const HARN_CHN_005: &str = "HARN-CHN-005";
64
/// What to do with a partially filled batch when its aggregation window
/// elapses before `count` events have arrived.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpireAction {
    /// Default: hand the partial batch (fewer than `count` events) to the
    /// handler.
    FirePartial,
    /// Throw the buffered events away without invoking the handler.
    Discard,
}

impl ExpireAction {
    /// Canonical configuration spelling of this action, as accepted by the
    /// `batch.expire_action` field.
    pub fn as_str(self) -> &'static str {
        match self {
            ExpireAction::Discard => "discard",
            ExpireAction::FirePartial => "fire_partial",
        }
    }
}
83
84/// Aggregation config attached to a trigger binding. Cloned into the
85/// binding at registration and read on every emit; intentionally cheap to
86/// clone.
87#[derive(Clone, Debug)]
88pub struct TriggerAggregationConfig {
89    pub count: u32,
90    pub window: Duration,
91    /// Dot-path into the channel payload (e.g. `"repo"`,
92    /// `"pull_request.user.login"`). When `None`, all events share a
93    /// single global counter for the binding.
94    pub key_path: Option<String>,
95    pub expire_action: ExpireAction,
96}
97
98/// A single in-memory buffer for one (binding, partition_key) pair.
99#[derive(Debug)]
100struct AggregationBuffer {
101    events: Vec<TriggerEvent>,
102    /// Wall-clock millisecond timestamp at which the window opened. Used
103    /// so [`drain_expired_aggregations`] can compare against
104    /// `clock::now_ms()` — honoring `mock_time(...)` /
105    /// `advance_time(...)` in tests.
106    window_start_ms: i64,
107    window_ms: i64,
108    expire_action: ExpireAction,
109}
110
111impl AggregationBuffer {
112    fn new(window_ms: i64, expire_action: ExpireAction) -> Self {
113        Self {
114            events: Vec::new(),
115            window_start_ms: clock::now_ms(),
116            window_ms,
117            expire_action,
118        }
119    }
120
121    fn expired_at(&self, now_ms: i64) -> bool {
122        now_ms.saturating_sub(self.window_start_ms) >= self.window_ms
123    }
124}
125
126/// Outcome of accumulating a single event into the buffer. The caller
127/// (channel fan-out) maps this to a dispatch call.
128#[derive(Debug)]
129pub enum AccumulateOutcome {
130    /// Buffer still under threshold; nothing to dispatch yet.
131    Buffered,
132    /// Threshold reached; dispatch the batched events now and reset the
133    /// buffer. The vector always contains exactly `count` entries — the
134    /// new one plus everything previously buffered.
135    Ready(Vec<TriggerEvent>),
136}
137
138/// Outcome of a window-expire sweep over a single buffer.
139#[derive(Debug)]
140pub struct ExpiredFlush {
141    pub binding_key: String,
142    pub partition_key: Option<String>,
143    pub action: ExpireAction,
144    pub events: Vec<TriggerEvent>,
145}
146
147#[derive(Default)]
148struct AggregationRegistry {
149    /// (binding_key, partition_key) → buffer. partition_key is the
150    /// stringified JSON value at `key_path`, or "" when `key_path` is None.
151    buffers: BTreeMap<String, BTreeMap<String, AggregationBuffer>>,
152}
153
154thread_local! {
155    static REGISTRY: RefCell<AggregationRegistry> =
156        RefCell::new(AggregationRegistry::default());
157}
158
159/// Reset all aggregation state. Called from the per-test reset hook so
160/// buffers do not leak between tests.
161pub fn clear_aggregation_state() {
162    REGISTRY.with(|slot| {
163        *slot.borrow_mut() = AggregationRegistry::default();
164    });
165}
166
167/// Drop any buffers owned by `binding_key`. Called when a trigger binding
168/// is drained or terminated so a long-lived buffer doesn't keep firing
169/// after the trigger went away. Returns any events that were still
170/// pending — callers can choose to flush them (currently they are
171/// discarded, matching the existing trigger drain contract that
172/// in-flight events stop on terminate).
173pub fn drop_binding_aggregation(binding_key: &str) -> Vec<TriggerEvent> {
174    REGISTRY.with(|slot| {
175        let mut registry = slot.borrow_mut();
176        registry
177            .buffers
178            .remove(binding_key)
179            .into_iter()
180            .flat_map(|partitions| partitions.into_values())
181            .flat_map(|buffer| buffer.events.into_iter())
182            .collect()
183    })
184}
185
186/// Accumulate `event` into the buffer for (binding_key, partition_key).
187/// Returns [`AccumulateOutcome::Ready`] when the buffer hits
188/// `config.count`; the buffer is removed in that case so the next event
189/// starts a fresh window.
190pub fn accumulate(
191    binding_key: &str,
192    config: &TriggerAggregationConfig,
193    partition_key: Option<&str>,
194    event: TriggerEvent,
195) -> AccumulateOutcome {
196    let partition_key_owned = partition_key.unwrap_or("").to_string();
197    let window_ms = config.window.as_millis() as i64;
198    let count = config.count;
199    let expire_action = config.expire_action;
200
201    REGISTRY.with(|slot| {
202        let mut registry = slot.borrow_mut();
203        let partitions = registry.buffers.entry(binding_key.to_string()).or_default();
204        let buffer = partitions
205            .entry(partition_key_owned.clone())
206            .or_insert_with(|| AggregationBuffer::new(window_ms, expire_action));
207
208        // Enforce the per-buffer cap. We drop the *oldest* entry rather
209        // than refusing the new one so the freshest event wins. Operators
210        // get a warn-level structured event so a stuck batch is visible.
211        if buffer.events.len() >= MAX_BUFFER_EVENTS {
212            let mut overflow_meta = std::collections::BTreeMap::new();
213            overflow_meta.insert("binding_key".to_string(), serde_json::json!(binding_key));
214            overflow_meta.insert(
215                "partition_key".to_string(),
216                serde_json::json!(partition_key.unwrap_or("")),
217            );
218            overflow_meta.insert(
219                "max_events".to_string(),
220                serde_json::json!(MAX_BUFFER_EVENTS),
221            );
222            crate::events::log_warn_meta(
223                "triggers.aggregation.buffer_overflow",
224                "aggregation buffer exceeded MAX_BUFFER_EVENTS; dropping oldest entry",
225                overflow_meta,
226            );
227            buffer.events.remove(0);
228        }
229
230        buffer.events.push(event);
231
232        if buffer.events.len() as u32 >= count {
233            let buffer = partitions
234                .remove(&partition_key_owned)
235                .expect("buffer just inserted");
236            // Clean up the partition map when its last partition leaves
237            // so the binding entry doesn't grow without bound.
238            if partitions.is_empty() {
239                registry.buffers.remove(binding_key);
240            }
241            return AccumulateOutcome::Ready(buffer.events);
242        }
243        AccumulateOutcome::Buffered
244    })
245}
246
247/// Sweep all buffers and return entries whose window has elapsed. The
248/// returned `events` vector is always non-empty for `FirePartial` (the
249/// caller dispatches it as a batched event) and is also non-empty for
250/// `Discard` (the caller drops it). An empty buffer never makes it into
251/// the result.
252///
253/// Idempotent: calling repeatedly returns successive expirations only.
254pub fn drain_expired_aggregations() -> Vec<ExpiredFlush> {
255    let now_ms = clock::now_ms();
256    REGISTRY.with(|slot| {
257        let mut registry = slot.borrow_mut();
258        let mut expired = Vec::new();
259        let mut empty_bindings = Vec::new();
260        for (binding_key, partitions) in registry.buffers.iter_mut() {
261            let expired_partition_keys: Vec<String> = partitions
262                .iter()
263                .filter(|(_, buffer)| buffer.expired_at(now_ms) && !buffer.events.is_empty())
264                .map(|(key, _)| key.clone())
265                .collect();
266            for partition_key in expired_partition_keys {
267                let buffer = partitions
268                    .remove(&partition_key)
269                    .expect("partition just observed");
270                let action = buffer.expire_action;
271                expired.push(ExpiredFlush {
272                    binding_key: binding_key.clone(),
273                    partition_key: if partition_key.is_empty() {
274                        None
275                    } else {
276                        Some(partition_key)
277                    },
278                    action,
279                    events: buffer.events,
280                });
281            }
282            if partitions.is_empty() {
283                empty_bindings.push(binding_key.clone());
284            }
285        }
286        for binding_key in empty_bindings {
287            registry.buffers.remove(&binding_key);
288        }
289        expired
290    })
291}
292
293/// Parse a `batch` field on a trigger config dict into a typed config.
294///
295/// Returns `HARN-CHN-005` on bad input: count <= 0, missing window,
296/// unparseable window, unknown `expire_action`, or wrong types.
297pub fn parse_aggregation_config(
298    raw: &VmValue,
299) -> Result<Option<TriggerAggregationConfig>, VmError> {
300    let map = match raw {
301        VmValue::Nil => return Ok(None),
302        VmValue::Dict(map) => map,
303        other => {
304            return Err(VmError::Runtime(format!(
305                "{HARN_CHN_005} trigger_register: `batch` must be a dict, got {}",
306                other.type_name()
307            )))
308        }
309    };
310
311    let count = map
312        .get("count")
313        .ok_or_else(|| {
314            VmError::Runtime(format!(
315                "{HARN_CHN_005} trigger_register: batch.count is required"
316            ))
317        })?
318        .as_int()
319        .ok_or_else(|| {
320            VmError::Runtime(format!(
321                "{HARN_CHN_005} trigger_register: batch.count must be a positive integer"
322            ))
323        })?;
324    if count <= 0 {
325        return Err(VmError::Runtime(format!(
326            "{HARN_CHN_005} trigger_register: batch.count must be greater than 0, got {count}"
327        )));
328    }
329    let count = count as u32;
330
331    let window_raw = match map.get("window") {
332        Some(VmValue::String(text)) => text.to_string(),
333        Some(other) => {
334            return Err(VmError::Runtime(format!(
335            "{HARN_CHN_005} trigger_register: batch.window must be a string like \"10m\", got {}",
336            other.type_name()
337        )))
338        }
339        None => {
340            return Err(VmError::Runtime(format!(
341                "{HARN_CHN_005} trigger_register: batch.window is required"
342            )))
343        }
344    };
345    let window = super::flow_control::parse_flow_control_duration(&window_raw).map_err(|err| {
346        VmError::Runtime(format!(
347            "{HARN_CHN_005} trigger_register: batch.window {err}"
348        ))
349    })?;
350
351    let key_path = match map.get("key") {
352        None | Some(VmValue::Nil) => None,
353        Some(VmValue::String(text)) => {
354            let trimmed = text.trim();
355            if trimmed.is_empty() {
356                None
357            } else {
358                Some(trimmed.to_string())
359            }
360        }
361        Some(other) => {
362            return Err(VmError::Runtime(format!(
363                "{HARN_CHN_005} trigger_register: batch.key must be a string JSON path, got {}",
364                other.type_name()
365            )))
366        }
367    };
368
369    let expire_action = match map.get("expire_action") {
370        None | Some(VmValue::Nil) => ExpireAction::FirePartial,
371        Some(VmValue::String(text)) => match text.as_ref() {
372            // Accept both names. The spec uses "fire" / "fire_partial"
373            // interchangeably for the "invoke handler with N<count
374            // events" case; "discard" drops the buffer.
375            "fire" | "fire_partial" => ExpireAction::FirePartial,
376            "discard" => ExpireAction::Discard,
377            other => {
378                return Err(VmError::Runtime(format!(
379                    "{HARN_CHN_005} trigger_register: unknown batch.expire_action '{other}', expected fire_partial|discard"
380                )))
381            }
382        },
383        Some(other) => {
384            return Err(VmError::Runtime(format!(
385                "{HARN_CHN_005} trigger_register: batch.expire_action must be a string, got {}",
386                other.type_name()
387            )))
388        }
389    };
390
391    Ok(Some(TriggerAggregationConfig {
392        count,
393        window,
394        key_path,
395        expire_action,
396    }))
397}
398
399/// Resolve the partition key for `event` against `config.key_path`.
400///
401/// Returns `None` when `key_path` is not set OR when the path doesn't
402/// resolve in the channel payload. A missing path collapses into the
403/// global ("" / `None`) bucket so a misconfigured key doesn't crash the
404/// emit; this matches `SpawnToPool`'s "missing path = default" pattern.
405pub fn partition_key_for_event(
406    config: &TriggerAggregationConfig,
407    payload: &JsonValue,
408) -> Option<String> {
409    let path = config.key_path.as_ref()?;
410    let value = json_path_lookup(payload, path)?;
411    Some(stringify_partition_key(value))
412}
413
414fn stringify_partition_key(value: &JsonValue) -> String {
415    match value {
416        JsonValue::String(text) => text.clone(),
417        JsonValue::Null => "null".to_string(),
418        JsonValue::Bool(value) => value.to_string(),
419        JsonValue::Number(value) => value.to_string(),
420        other => serde_json::to_string(other).unwrap_or_else(|_| "<unserializable>".to_string()),
421    }
422}
423
424fn json_path_lookup<'a>(value: &'a JsonValue, path: &str) -> Option<&'a JsonValue> {
425    let mut current = value;
426    for segment in path.split('.') {
427        if segment.is_empty() {
428            return None;
429        }
430        current = match current {
431            JsonValue::Object(map) => map.get(segment)?,
432            _ => return None,
433        };
434    }
435    Some(current)
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::triggers::event::{GenericWebhookPayload, KnownProviderPayload};
    use crate::triggers::{ProviderId, ProviderPayload, SignatureStatus};
    use std::collections::BTreeMap;
    use std::time::Duration;

    /// Builds a minimal unsigned channel event whose payload is `{"id": id}`.
    fn mk_event(id: &str) -> TriggerEvent {
        let payload = GenericWebhookPayload {
            source: Some("aggregation-test".to_string()),
            content_type: Some("application/json".to_string()),
            raw: serde_json::json!({"id": id}),
        };
        TriggerEvent::new(
            ProviderId::from("channel"),
            "channel.emit",
            None,
            id.to_string(),
            None,
            BTreeMap::new(),
            ProviderPayload::Known(KnownProviderPayload::Webhook(payload)),
            SignatureStatus::Unsigned,
        )
    }

    /// Unkeyed fire-partial config with a 60 second window and the given
    /// threshold.
    fn cfg(count: u32) -> TriggerAggregationConfig {
        TriggerAggregationConfig {
            count,
            window: Duration::from_secs(60),
            key_path: None,
            expire_action: ExpireAction::FirePartial,
        }
    }

    #[test]
    fn accumulate_fires_when_count_reached() {
        clear_aggregation_state();
        let config = cfg(3);

        // The first two events stay below the threshold.
        for id in ["a", "b"] {
            let outcome = accumulate("t1@v1", &config, None, mk_event(id));
            assert!(
                matches!(outcome, AccumulateOutcome::Buffered),
                "fired too early"
            );
        }

        // The third event completes the batch.
        match accumulate("t1@v1", &config, None, mk_event("c")) {
            AccumulateOutcome::Ready(events) => assert_eq!(events.len(), 3),
            AccumulateOutcome::Buffered => panic!("should have fired"),
        }
        clear_aggregation_state();
    }

    #[test]
    fn keyed_buffers_are_independent() {
        clear_aggregation_state();
        let config = cfg(2);

        // Interleave two partitions; each must count on its own.
        let _ = accumulate("t1@v1", &config, Some("repoA"), mk_event("a1"));
        let _ = accumulate("t1@v1", &config, Some("repoB"), mk_event("b1"));
        let second_a = accumulate("t1@v1", &config, Some("repoA"), mk_event("a2"));
        let second_b = accumulate("t1@v1", &config, Some("repoB"), mk_event("b2"));

        assert!(matches!(second_a, AccumulateOutcome::Ready(_)));
        assert!(matches!(second_b, AccumulateOutcome::Ready(_)));
        clear_aggregation_state();
    }

    #[test]
    fn drop_binding_removes_buffers() {
        clear_aggregation_state();
        let config = cfg(5);
        for id in ["a", "b"] {
            let _ = accumulate("t1@v1", &config, None, mk_event(id));
        }

        let leftover = drop_binding_aggregation("t1@v1");
        assert_eq!(leftover.len(), 2);

        // Re-accumulating after drop starts fresh.
        let outcome = accumulate("t1@v1", &cfg(2), None, mk_event("c"));
        assert!(matches!(outcome, AccumulateOutcome::Buffered));
        clear_aggregation_state();
    }
}