1use std::cell::RefCell;
46use std::collections::BTreeMap;
47use std::time::Duration;
48
49use serde_json::Value as JsonValue;
50
51use crate::triggers::test_util::clock;
52use crate::value::{VmError, VmValue};
53
54use super::TriggerEvent;
55
/// Upper bound on the number of events a single aggregation buffer holds.
///
/// Once a buffer has reached this size, [`accumulate`] logs a warning and
/// drops the oldest buffered event before pushing the new one.
pub const MAX_BUFFER_EVENTS: usize = 1024;

/// Diagnostic code prefixed to every error message built by
/// [`parse_aggregation_config`].
const HARN_CHN_005: &str = "HARN-CHN-005";
64
/// Policy tag attached to an aggregation buffer describing what should happen
/// to a partially filled batch once its window has elapsed.
///
/// This module only records the action and reports it on each
/// [`ExpiredFlush`]; acting on it is left to the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpireAction {
    /// Selected by `"fire"` / `"fire_partial"` in the `batch` config, and the
    /// default when `expire_action` is absent.
    // NOTE(review): the name suggests the partial batch is dispatched on
    // expiry — confirm against the code that consumes `ExpiredFlush`.
    FirePartial,
    /// Selected by `"discard"` in the `batch` config.
    // NOTE(review): the name suggests the partial batch is dropped on expiry
    // — confirm against the code that consumes `ExpiredFlush`.
    Discard,
}

impl ExpireAction {
    /// Returns the canonical snake_case name of the action.
    pub fn as_str(self) -> &'static str {
        match self {
            ExpireAction::Discard => "discard",
            ExpireAction::FirePartial => "fire_partial",
        }
    }
}
83
/// Parsed `batch` settings for a trigger binding, as produced by
/// [`parse_aggregation_config`].
#[derive(Clone, Debug)]
pub struct TriggerAggregationConfig {
    /// Number of buffered events at which [`accumulate`] returns
    /// [`AccumulateOutcome::Ready`].
    pub count: u32,
    /// Length of the aggregation window; converted to milliseconds when a
    /// buffer is created and compared against the time since that creation.
    pub window: Duration,
    /// Optional dot-separated path into the event payload used by
    /// [`partition_key_for_event`]; `None` means no partition key is derived.
    pub key_path: Option<String>,
    /// Action recorded on each newly created buffer and reported back when
    /// the buffer is drained as expired.
    pub expire_action: ExpireAction,
}
97
98#[derive(Debug)]
100struct AggregationBuffer {
101 events: Vec<TriggerEvent>,
102 window_start_ms: i64,
107 window_ms: i64,
108 expire_action: ExpireAction,
109}
110
111impl AggregationBuffer {
112 fn new(window_ms: i64, expire_action: ExpireAction) -> Self {
113 Self {
114 events: Vec::new(),
115 window_start_ms: clock::now_ms(),
116 window_ms,
117 expire_action,
118 }
119 }
120
121 fn expired_at(&self, now_ms: i64) -> bool {
122 now_ms.saturating_sub(self.window_start_ms) >= self.window_ms
123 }
124}
125
/// Result of handing one event to [`accumulate`].
#[derive(Debug)]
pub enum AccumulateOutcome {
    /// The event was stored; the buffer has not reached the configured count.
    Buffered,
    /// The buffer reached the configured count. It has been removed from the
    /// registry and its events are returned here.
    Ready(Vec<TriggerEvent>),
}
137
/// One buffer removed by [`drain_expired_aggregations`] because its window
/// had elapsed.
#[derive(Debug)]
pub struct ExpiredFlush {
    /// Binding the buffer was registered under.
    pub binding_key: String,
    /// Partition the buffer belonged to; `None` when it was stored under the
    /// empty-string partition key.
    pub partition_key: Option<String>,
    /// The expire action recorded on the buffer when it was created.
    pub action: ExpireAction,
    /// Events that were still buffered when the window elapsed.
    pub events: Vec<TriggerEvent>,
}
146
/// All live aggregation buffers known to the current thread.
#[derive(Default)]
struct AggregationRegistry {
    /// Binding key -> partition key -> buffer. Events accumulated without a
    /// partition key are stored under the empty string.
    buffers: BTreeMap<String, BTreeMap<String, AggregationBuffer>>,
}
153
// Thread-local registry: every thread gets its own, initially empty,
// `AggregationRegistry`, mutated through the `RefCell`.
thread_local! {
    static REGISTRY: RefCell<AggregationRegistry> =
        RefCell::new(AggregationRegistry::default());
}
158
159pub fn clear_aggregation_state() {
162 REGISTRY.with(|slot| {
163 *slot.borrow_mut() = AggregationRegistry::default();
164 });
165}
166
167pub fn drop_binding_aggregation(binding_key: &str) -> Vec<TriggerEvent> {
174 REGISTRY.with(|slot| {
175 let mut registry = slot.borrow_mut();
176 registry
177 .buffers
178 .remove(binding_key)
179 .into_iter()
180 .flat_map(|partitions| partitions.into_values())
181 .flat_map(|buffer| buffer.events.into_iter())
182 .collect()
183 })
184}
185
186pub fn accumulate(
191 binding_key: &str,
192 config: &TriggerAggregationConfig,
193 partition_key: Option<&str>,
194 event: TriggerEvent,
195) -> AccumulateOutcome {
196 let partition_key_owned = partition_key.unwrap_or("").to_string();
197 let window_ms = config.window.as_millis() as i64;
198 let count = config.count;
199 let expire_action = config.expire_action;
200
201 REGISTRY.with(|slot| {
202 let mut registry = slot.borrow_mut();
203 let partitions = registry.buffers.entry(binding_key.to_string()).or_default();
204 let buffer = partitions
205 .entry(partition_key_owned.clone())
206 .or_insert_with(|| AggregationBuffer::new(window_ms, expire_action));
207
208 if buffer.events.len() >= MAX_BUFFER_EVENTS {
212 let mut overflow_meta = std::collections::BTreeMap::new();
213 overflow_meta.insert("binding_key".to_string(), serde_json::json!(binding_key));
214 overflow_meta.insert(
215 "partition_key".to_string(),
216 serde_json::json!(partition_key.unwrap_or("")),
217 );
218 overflow_meta.insert(
219 "max_events".to_string(),
220 serde_json::json!(MAX_BUFFER_EVENTS),
221 );
222 crate::events::log_warn_meta(
223 "triggers.aggregation.buffer_overflow",
224 "aggregation buffer exceeded MAX_BUFFER_EVENTS; dropping oldest entry",
225 overflow_meta,
226 );
227 buffer.events.remove(0);
228 }
229
230 buffer.events.push(event);
231
232 if buffer.events.len() as u32 >= count {
233 let buffer = partitions
234 .remove(&partition_key_owned)
235 .expect("buffer just inserted");
236 if partitions.is_empty() {
239 registry.buffers.remove(binding_key);
240 }
241 return AccumulateOutcome::Ready(buffer.events);
242 }
243 AccumulateOutcome::Buffered
244 })
245}
246
247pub fn drain_expired_aggregations() -> Vec<ExpiredFlush> {
255 let now_ms = clock::now_ms();
256 REGISTRY.with(|slot| {
257 let mut registry = slot.borrow_mut();
258 let mut expired = Vec::new();
259 let mut empty_bindings = Vec::new();
260 for (binding_key, partitions) in registry.buffers.iter_mut() {
261 let expired_partition_keys: Vec<String> = partitions
262 .iter()
263 .filter(|(_, buffer)| buffer.expired_at(now_ms) && !buffer.events.is_empty())
264 .map(|(key, _)| key.clone())
265 .collect();
266 for partition_key in expired_partition_keys {
267 let buffer = partitions
268 .remove(&partition_key)
269 .expect("partition just observed");
270 let action = buffer.expire_action;
271 expired.push(ExpiredFlush {
272 binding_key: binding_key.clone(),
273 partition_key: if partition_key.is_empty() {
274 None
275 } else {
276 Some(partition_key)
277 },
278 action,
279 events: buffer.events,
280 });
281 }
282 if partitions.is_empty() {
283 empty_bindings.push(binding_key.clone());
284 }
285 }
286 for binding_key in empty_bindings {
287 registry.buffers.remove(&binding_key);
288 }
289 expired
290 })
291}
292
293pub fn parse_aggregation_config(
298 raw: &VmValue,
299) -> Result<Option<TriggerAggregationConfig>, VmError> {
300 let map = match raw {
301 VmValue::Nil => return Ok(None),
302 VmValue::Dict(map) => map,
303 other => {
304 return Err(VmError::Runtime(format!(
305 "{HARN_CHN_005} trigger_register: `batch` must be a dict, got {}",
306 other.type_name()
307 )))
308 }
309 };
310
311 let count = map
312 .get("count")
313 .ok_or_else(|| {
314 VmError::Runtime(format!(
315 "{HARN_CHN_005} trigger_register: batch.count is required"
316 ))
317 })?
318 .as_int()
319 .ok_or_else(|| {
320 VmError::Runtime(format!(
321 "{HARN_CHN_005} trigger_register: batch.count must be a positive integer"
322 ))
323 })?;
324 if count <= 0 {
325 return Err(VmError::Runtime(format!(
326 "{HARN_CHN_005} trigger_register: batch.count must be greater than 0, got {count}"
327 )));
328 }
329 let count = count as u32;
330
331 let window_raw = match map.get("window") {
332 Some(VmValue::String(text)) => text.to_string(),
333 Some(other) => {
334 return Err(VmError::Runtime(format!(
335 "{HARN_CHN_005} trigger_register: batch.window must be a string like \"10m\", got {}",
336 other.type_name()
337 )))
338 }
339 None => {
340 return Err(VmError::Runtime(format!(
341 "{HARN_CHN_005} trigger_register: batch.window is required"
342 )))
343 }
344 };
345 let window = super::flow_control::parse_flow_control_duration(&window_raw).map_err(|err| {
346 VmError::Runtime(format!(
347 "{HARN_CHN_005} trigger_register: batch.window {err}"
348 ))
349 })?;
350
351 let key_path = match map.get("key") {
352 None | Some(VmValue::Nil) => None,
353 Some(VmValue::String(text)) => {
354 let trimmed = text.trim();
355 if trimmed.is_empty() {
356 None
357 } else {
358 Some(trimmed.to_string())
359 }
360 }
361 Some(other) => {
362 return Err(VmError::Runtime(format!(
363 "{HARN_CHN_005} trigger_register: batch.key must be a string JSON path, got {}",
364 other.type_name()
365 )))
366 }
367 };
368
369 let expire_action = match map.get("expire_action") {
370 None | Some(VmValue::Nil) => ExpireAction::FirePartial,
371 Some(VmValue::String(text)) => match text.as_ref() {
372 "fire" | "fire_partial" => ExpireAction::FirePartial,
376 "discard" => ExpireAction::Discard,
377 other => {
378 return Err(VmError::Runtime(format!(
379 "{HARN_CHN_005} trigger_register: unknown batch.expire_action '{other}', expected fire_partial|discard"
380 )))
381 }
382 },
383 Some(other) => {
384 return Err(VmError::Runtime(format!(
385 "{HARN_CHN_005} trigger_register: batch.expire_action must be a string, got {}",
386 other.type_name()
387 )))
388 }
389 };
390
391 Ok(Some(TriggerAggregationConfig {
392 count,
393 window,
394 key_path,
395 expire_action,
396 }))
397}
398
399pub fn partition_key_for_event(
406 config: &TriggerAggregationConfig,
407 payload: &JsonValue,
408) -> Option<String> {
409 let path = config.key_path.as_ref()?;
410 let value = json_path_lookup(payload, path)?;
411 Some(stringify_partition_key(value))
412}
413
414fn stringify_partition_key(value: &JsonValue) -> String {
415 match value {
416 JsonValue::String(text) => text.clone(),
417 JsonValue::Null => "null".to_string(),
418 JsonValue::Bool(value) => value.to_string(),
419 JsonValue::Number(value) => value.to_string(),
420 other => serde_json::to_string(other).unwrap_or_else(|_| "<unserializable>".to_string()),
421 }
422}
423
424fn json_path_lookup<'a>(value: &'a JsonValue, path: &str) -> Option<&'a JsonValue> {
425 let mut current = value;
426 for segment in path.split('.') {
427 if segment.is_empty() {
428 return None;
429 }
430 current = match current {
431 JsonValue::Object(map) => map.get(segment)?,
432 _ => return None,
433 };
434 }
435 Some(current)
436}
437
#[cfg(test)]
mod tests {
    use super::*;
    use crate::triggers::event::{GenericWebhookPayload, KnownProviderPayload};
    use crate::triggers::{ProviderId, ProviderPayload, SignatureStatus};
    use std::collections::BTreeMap;
    use std::time::Duration;

    /// Binding key shared by every test in this module.
    const BINDING: &str = "t1@v1";

    /// Builds an unsigned webhook-style event whose id and payload carry `id`.
    fn mk_event(id: &str) -> TriggerEvent {
        let payload = GenericWebhookPayload {
            source: Some("aggregation-test".to_string()),
            content_type: Some("application/json".to_string()),
            raw: serde_json::json!({"id": id}),
        };
        TriggerEvent::new(
            ProviderId::from("channel"),
            "channel.emit",
            None,
            id.to_string(),
            None,
            BTreeMap::new(),
            ProviderPayload::Known(KnownProviderPayload::Webhook(payload)),
            SignatureStatus::Unsigned,
        )
    }

    /// Unpartitioned config with a 60 second window that fires at `count`.
    fn cfg(count: u32) -> TriggerAggregationConfig {
        TriggerAggregationConfig {
            count,
            window: Duration::from_secs(60),
            key_path: None,
            expire_action: ExpireAction::FirePartial,
        }
    }

    #[test]
    fn accumulate_fires_when_count_reached() {
        clear_aggregation_state();
        let config = cfg(3);
        for id in ["a", "b"] {
            let outcome = accumulate(BINDING, &config, None, mk_event(id));
            assert!(
                matches!(outcome, AccumulateOutcome::Buffered),
                "fired too early"
            );
        }
        match accumulate(BINDING, &config, None, mk_event("c")) {
            AccumulateOutcome::Ready(events) => assert_eq!(events.len(), 3),
            AccumulateOutcome::Buffered => panic!("should have fired"),
        }
        clear_aggregation_state();
    }

    #[test]
    fn keyed_buffers_are_independent() {
        clear_aggregation_state();
        let config = cfg(2);
        // Prime one event per partition; neither may fire yet on its own.
        for (partition, id) in [("repoA", "a1"), ("repoB", "b1")] {
            let _ = accumulate(BINDING, &config, Some(partition), mk_event(id));
        }
        let a2 = accumulate(BINDING, &config, Some("repoA"), mk_event("a2"));
        let b2 = accumulate(BINDING, &config, Some("repoB"), mk_event("b2"));
        assert!(matches!(a2, AccumulateOutcome::Ready(_)));
        assert!(matches!(b2, AccumulateOutcome::Ready(_)));
        clear_aggregation_state();
    }

    #[test]
    fn drop_binding_removes_buffers() {
        clear_aggregation_state();
        let config = cfg(5);
        for id in ["a", "b"] {
            let _ = accumulate(BINDING, &config, None, mk_event(id));
        }
        let leftover = drop_binding_aggregation(BINDING);
        assert_eq!(leftover.len(), 2);
        // The old buffer is gone, so this event starts a fresh one and a
        // threshold of 2 is not yet met.
        let outcome = accumulate(BINDING, &cfg(2), None, mk_event("c"));
        assert!(matches!(outcome, AccumulateOutcome::Buffered));
        clear_aggregation_state();
    }
}