use std::sync::Mutex;
use std::time::{Duration, Instant};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::common::message::Message;
use crate::error::{Error, Result};
use crate::transform::step::Step;
/// Configuration for a [`WindowStep`].
///
/// At least one of `duration` or `size` must be set; [`WindowStep::new`]
/// rejects a configuration that has neither trigger.
///
/// NOTE(review): only deserialization of `duration` is customized (it is read
/// from a string such as `"30s"`). Serialization uses the derived impl, so the
/// serialized form of `duration` is presumably not accepted by this
/// deserializer again — confirm whether round-tripping is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WindowStepConfig {
/// Time-based trigger: a window is emitted when a message arrives and at
/// least this much time has passed since the previous emit. Read from a
/// string like `"500ms"`, `"30s"`, `"5m"` or `"2h"`; a bare number is seconds.
#[serde(default, deserialize_with = "deserialize_duration_opt")]
pub duration: Option<Duration>,
/// Count-based trigger: a window is emitted as soon as the buffer holds at
/// least this many messages. Must not exceed `max_messages`.
#[serde(default)]
pub size: Option<usize>,
/// How the buffered messages are combined into the emitted message.
#[serde(default)]
pub operation: WindowOperation,
/// Which message is picked when `operation` is `SelectOne`.
#[serde(default)]
pub strategy: SelectStrategy,
/// Upper bound on the number of buffered messages; must be greater than 0.
/// Defaults to 10000 when absent from the input.
#[serde(default = "default_max_messages")]
pub max_messages: usize,
/// What to do when a message arrives while the buffer already holds
/// `max_messages` messages.
#[serde(default)]
pub on_overflow: OverflowStrategy,
}
/// Serde default for the `max_messages` field of the window configuration.
fn default_max_messages() -> usize {
    const DEFAULT_MAX_MESSAGES: usize = 10_000;
    DEFAULT_MAX_MESSAGES
}
fn deserialize_duration_opt<'de, D>(
deserializer: D,
) -> std::result::Result<Option<Duration>, D::Error>
where
D: serde::Deserializer<'de>,
{
let opt: Option<String> = Option::deserialize(deserializer)?;
match opt {
None => Ok(None),
Some(s) => parse_duration(&s)
.map(Some)
.map_err(serde::de::Error::custom),
}
}
fn parse_duration(s: &str) -> Result<Duration> {
let s = s.trim();
if s.is_empty() {
return Err(Error::config("Empty duration string"));
}
let (num_str, unit) = s
.char_indices()
.find(|(_, c)| !c.is_ascii_digit())
.map(|(i, _)| (&s[..i], &s[i..]))
.unwrap_or((s, ""));
let num: u64 = num_str
.parse()
.map_err(|_| Error::config(format!("Invalid duration number: {}", num_str)))?;
let duration = match unit.trim() {
"ms" => Duration::from_millis(num),
"s" | "" => Duration::from_secs(num),
"m" => Duration::from_secs(num * 60),
"h" => Duration::from_secs(num * 3600),
other => return Err(Error::config(format!("Unknown duration unit: {}", other))),
};
Ok(duration)
}
/// How the messages buffered in a window are turned into the emitted message.
///
/// Variant names are (de)serialized in snake_case (`merge`, `select_one`);
/// the default is `Merge`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum WindowOperation {
/// Combine all buffered messages into one. When every payload is a JSON
/// object the keys are merged into a single object (later messages overwrite
/// earlier keys); otherwise the payloads are emitted as a JSON array. The
/// metadata of the first buffered message is kept.
#[default]
Merge,
/// Emit exactly one of the buffered messages, chosen by [`SelectStrategy`].
SelectOne,
}
/// Which buffered message is emitted when the window operation is `SelectOne`.
///
/// Variant names are (de)serialized in snake_case (`first`, `last`);
/// the default is `First`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SelectStrategy {
/// Emit the oldest message in the window.
#[default]
First,
/// Emit the most recently buffered message in the window.
Last,
}
/// What happens when a message arrives while the window buffer already holds
/// `max_messages` messages.
///
/// Variant names are (de)serialized in snake_case (`drop_oldest`, `error`);
/// the default is `DropOldest`.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum OverflowStrategy {
/// Discard the oldest buffered message to make room for the incoming one.
#[default]
DropOldest,
/// Fail the `process` call with an error instead of buffering the message.
Error,
}
/// Mutable runtime state of a window: the pending messages and the time the
/// window was last flushed.
struct WindowState {
    /// Messages collected since the last emit, oldest first.
    buffer: Vec<Message>,
    /// Instant of the last emit, or of creation if nothing was emitted yet.
    last_emit: Instant,
}

impl WindowState {
    /// Creates an empty state whose window starts at the moment of the call.
    fn new() -> Self {
        let window_start = Instant::now();
        WindowState {
            buffer: Vec::new(),
            last_emit: window_start,
        }
    }
}
/// Transform step that buffers incoming messages and emits one aggregated
/// message per window.
///
/// The configuration fields are copied from [`WindowStepConfig`] by
/// [`WindowStep::new`]; the buffering state lives behind a mutex so that it can
/// be updated through a shared reference.
pub struct WindowStep {
// Time-based trigger; `None` disables it.
duration: Option<Duration>,
// Count-based trigger; `None` disables it.
size: Option<usize>,
// How a full window is aggregated.
operation: WindowOperation,
// Selection rule used when `operation` is `SelectOne`.
strategy: SelectStrategy,
// Maximum number of messages held in the buffer.
max_messages: usize,
// Behavior when the buffer is full and another message arrives.
on_overflow: OverflowStrategy,
// Pending messages and the time of the last emit.
state: Mutex<WindowState>,
}
/// Hand-written `Debug`: prints the configuration fields and marks the output
/// as non-exhaustive instead of printing the mutex-guarded runtime state.
impl std::fmt::Debug for WindowStep {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("WindowStep");
        out.field("duration", &self.duration);
        out.field("size", &self.size);
        out.field("operation", &self.operation);
        out.field("strategy", &self.strategy);
        out.field("max_messages", &self.max_messages);
        out.field("on_overflow", &self.on_overflow);
        out.finish_non_exhaustive()
    }
}
impl WindowStep {
pub fn new(config: WindowStepConfig) -> Result<Self> {
if config.duration.is_none() && config.size.is_none() {
return Err(Error::config(
"Window step requires at least one trigger: 'duration' or 'size'",
));
}
if config.max_messages == 0 {
return Err(Error::config("max_messages must be greater than 0"));
}
if let Some(size) = config.size
&& size > config.max_messages
{
return Err(Error::config(format!(
"size ({}) cannot exceed max_messages ({})",
size, config.max_messages
)));
}
Ok(Self {
duration: config.duration,
size: config.size,
operation: config.operation,
strategy: config.strategy,
max_messages: config.max_messages,
on_overflow: config.on_overflow,
state: Mutex::new(WindowState::new()),
})
}
fn should_emit(&self, state: &WindowState) -> bool {
if let Some(size) = self.size
&& state.buffer.len() >= size
{
return true;
}
if let Some(duration) = self.duration
&& state.last_emit.elapsed() >= duration
{
return true;
}
false
}
fn aggregate(&self, messages: Vec<Message>) -> Option<Message> {
if messages.is_empty() {
return None;
}
match self.operation {
WindowOperation::Merge => Some(self.merge_messages(messages)),
WindowOperation::SelectOne => self.select_one(messages),
}
}
fn merge_messages(&self, messages: Vec<Message>) -> Message {
let mut merged = serde_json::Map::new();
let first_meta = messages[0].meta.clone();
let payloads: Vec<Value> = messages.into_iter().map(|m| m.payload).collect();
let all_objects = payloads.iter().all(|p| p.is_object());
let merged_payload = if all_objects {
for payload in payloads {
if let Value::Object(obj) = payload {
for (k, v) in obj {
merged.insert(k, v);
}
}
}
Value::Object(merged)
} else {
Value::Array(payloads)
};
Message {
meta: first_meta,
payload: merged_payload,
}
}
fn select_one(&self, mut messages: Vec<Message>) -> Option<Message> {
match self.strategy {
SelectStrategy::First => messages.into_iter().next(),
SelectStrategy::Last => messages.pop(),
}
}
}
impl Step for WindowStep {
    /// Identifier of this step kind.
    fn step_type(&self) -> &'static str {
        "window"
    }

    /// Buffers `msg` and emits the aggregated window once a trigger fires.
    ///
    /// Returns `Ok(None)` while the window is still filling and
    /// `Ok(Some(message))` when the window is emitted. Triggers are only
    /// evaluated when a message arrives.
    ///
    /// When the buffer already holds `max_messages` messages:
    /// * `DropOldest` discards the oldest buffered message and continues.
    /// * `Error` rejects `msg` with an error while the time window is still
    ///   open. If the time window has already elapsed, the buffered window is
    ///   emitted and `msg` becomes the first message of the next window;
    ///   otherwise a full buffer could never be drained again and every
    ///   following call would fail.
    ///
    /// # Errors
    ///
    /// Returns a transform error when the state lock is poisoned, or when the
    /// buffer is full under the `Error` overflow strategy and the window is not
    /// yet due.
    fn process(&self, msg: Message) -> Result<Option<Message>> {
        let mut state = self
            .state
            .lock()
            .map_err(|_| Error::transform("Lock poisoned"))?;

        if state.buffer.len() >= self.max_messages {
            match self.on_overflow {
                OverflowStrategy::DropOldest => {
                    // Front removal shifts the remaining elements of the Vec.
                    state.buffer.remove(0);
                    tracing::debug!(step = "window", "Buffer full, dropped oldest message");
                }
                OverflowStrategy::Error => {
                    let window_due = self
                        .duration
                        .is_some_and(|duration| state.last_emit.elapsed() >= duration);
                    if !window_due {
                        return Err(Error::transform(format!(
                            "Window buffer full (max_messages={})",
                            self.max_messages
                        )));
                    }
                    // The window is overdue: flush it rather than rejecting
                    // messages forever, and open the next window with `msg`.
                    let window = std::mem::take(&mut state.buffer);
                    state.last_emit = Instant::now();
                    state.buffer.push(msg);
                    return Ok(self.aggregate(window));
                }
            }
        }

        state.buffer.push(msg);
        if !self.should_emit(&state) {
            return Ok(None);
        }

        let window = std::mem::take(&mut state.buffer);
        state.last_emit = Instant::now();
        Ok(self.aggregate(window))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wraps a JSON payload in a test message.
    fn make_msg(payload: Value) -> Message {
        Message::new("test", payload)
    }

    /// Baseline configuration: size-triggered window of three messages, merge
    /// operation, generous buffer limit. Tests override individual fields.
    fn base_config() -> WindowStepConfig {
        WindowStepConfig {
            duration: None,
            size: Some(3),
            operation: WindowOperation::Merge,
            strategy: SelectStrategy::First,
            max_messages: 1000,
            on_overflow: OverflowStrategy::DropOldest,
        }
    }

    #[test]
    fn test_parse_duration_seconds() {
        assert_eq!(parse_duration("30s").unwrap(), Duration::from_secs(30));
        assert_eq!(parse_duration("1s").unwrap(), Duration::from_secs(1));
    }

    #[test]
    fn test_parse_duration_milliseconds() {
        assert_eq!(parse_duration("500ms").unwrap(), Duration::from_millis(500));
    }

    #[test]
    fn test_parse_duration_minutes() {
        assert_eq!(parse_duration("5m").unwrap(), Duration::from_secs(300));
    }

    #[test]
    fn test_parse_duration_hours() {
        assert_eq!(parse_duration("2h").unwrap(), Duration::from_secs(7200));
    }

    #[test]
    fn test_parse_duration_bare_number_defaults_to_seconds() {
        assert_eq!(parse_duration("45").unwrap(), Duration::from_secs(45));
    }

    #[test]
    fn test_parse_duration_ignores_whitespace() {
        assert_eq!(parse_duration("  5 m ").unwrap(), Duration::from_secs(300));
    }

    #[test]
    fn test_parse_duration_rejects_invalid_input() {
        for input in ["", "   ", "ms", "abc", "5d", "-5s", "1.5s"] {
            assert!(
                parse_duration(input).is_err(),
                "expected {:?} to be rejected",
                input
            );
        }
    }

    #[test]
    fn test_config_requires_trigger() {
        let config = WindowStepConfig {
            size: None,
            ..base_config()
        };
        let result = WindowStep::new(config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("trigger"));
    }

    #[test]
    fn test_config_rejects_zero_max_messages() {
        let config = WindowStepConfig {
            max_messages: 0,
            ..base_config()
        };
        let result = WindowStep::new(config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("max_messages"));
    }

    #[test]
    fn test_config_size_exceeds_max() {
        let config = WindowStepConfig {
            size: Some(100),
            max_messages: 50,
            ..base_config()
        };
        let result = WindowStep::new(config);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceed"));
    }

    #[test]
    fn test_window_count_trigger() {
        let step = WindowStep::new(base_config()).unwrap();
        assert!(step.process(make_msg(json!({"a": 1}))).unwrap().is_none());
        assert!(step.process(make_msg(json!({"b": 2}))).unwrap().is_none());
        let result = step.process(make_msg(json!({"c": 3}))).unwrap();
        assert!(result.is_some());
        let output = result.unwrap();
        assert_eq!(output.payload["a"], 1);
        assert_eq!(output.payload["b"], 2);
        assert_eq!(output.payload["c"], 3);
    }

    #[test]
    fn test_window_resets_after_emit() {
        let step = WindowStep::new(base_config()).unwrap();
        step.process(make_msg(json!({"a": 1}))).unwrap();
        step.process(make_msg(json!({"b": 2}))).unwrap();
        assert!(step.process(make_msg(json!({"c": 3}))).unwrap().is_some());

        // The next window starts empty and must not contain earlier keys.
        assert!(step.process(make_msg(json!({"d": 4}))).unwrap().is_none());
        assert!(step.process(make_msg(json!({"e": 5}))).unwrap().is_none());
        let output = step
            .process(make_msg(json!({"f": 6})))
            .unwrap()
            .expect("second window should be emitted");
        assert!(output.payload.get("a").is_none());
        assert_eq!(output.payload["d"], 4);
        assert_eq!(output.payload["f"], 6);
    }

    #[test]
    fn test_merge_later_keys_override_earlier() {
        let step = WindowStep::new(base_config()).unwrap();
        step.process(make_msg(json!({"k": 1}))).unwrap();
        step.process(make_msg(json!({"k": 2}))).unwrap();
        let output = step
            .process(make_msg(json!({"k": 3})))
            .unwrap()
            .expect("window should be emitted");
        assert_eq!(output.payload["k"], 3);
    }

    #[test]
    fn test_window_select_first() {
        let config = WindowStepConfig {
            operation: WindowOperation::SelectOne,
            strategy: SelectStrategy::First,
            ..base_config()
        };
        let step = WindowStep::new(config).unwrap();
        step.process(make_msg(json!({"value": "first"}))).unwrap();
        step.process(make_msg(json!({"value": "second"}))).unwrap();
        let result = step.process(make_msg(json!({"value": "third"}))).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().payload["value"], "first");
    }

    #[test]
    fn test_window_select_last() {
        let config = WindowStepConfig {
            operation: WindowOperation::SelectOne,
            strategy: SelectStrategy::Last,
            ..base_config()
        };
        let step = WindowStep::new(config).unwrap();
        step.process(make_msg(json!({"value": "first"}))).unwrap();
        step.process(make_msg(json!({"value": "second"}))).unwrap();
        let result = step.process(make_msg(json!({"value": "third"}))).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().payload["value"], "third");
    }

    #[test]
    fn test_size_equal_to_max_emits_without_dropping() {
        // With `size == max_messages` the size trigger fires before the buffer
        // can overflow, so nothing is dropped and the first message survives.
        let config = WindowStepConfig {
            duration: Some(Duration::from_secs(3600)),
            size: Some(3),
            operation: WindowOperation::SelectOne,
            strategy: SelectStrategy::First,
            max_messages: 3,
            ..base_config()
        };
        let step = WindowStep::new(config).unwrap();
        step.process(make_msg(json!({"value": 1}))).unwrap();
        step.process(make_msg(json!({"value": 2}))).unwrap();
        let result = step.process(make_msg(json!({"value": 3}))).unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap().payload["value"], 1);
    }

    #[test]
    fn test_overflow_drop_oldest() {
        let config = WindowStepConfig {
            duration: Some(Duration::from_secs(3600)),
            size: None,
            max_messages: 2,
            on_overflow: OverflowStrategy::DropOldest,
            ..base_config()
        };
        let mut step = WindowStep::new(config).unwrap();
        assert!(step.process(make_msg(json!({"a": 1}))).unwrap().is_none());
        assert!(step.process(make_msg(json!({"b": 2}))).unwrap().is_none());

        // The buffer is full now. Shrinking the window to zero makes the next
        // message both overflow the buffer and trigger the emit, without
        // sleeping in the test.
        step.duration = Some(Duration::ZERO);
        let output = step
            .process(make_msg(json!({"c": 3})))
            .unwrap()
            .expect("window should be emitted");
        assert!(output.payload.get("a").is_none());
        assert_eq!(output.payload["b"], 2);
        assert_eq!(output.payload["c"], 3);
    }

    #[test]
    fn test_overflow_error() {
        let config = WindowStepConfig {
            duration: Some(Duration::from_secs(3600)),
            size: None,
            max_messages: 2,
            on_overflow: OverflowStrategy::Error,
            ..base_config()
        };
        let step = WindowStep::new(config).unwrap();
        step.process(make_msg(json!({"a": 1}))).unwrap();
        step.process(make_msg(json!({"b": 2}))).unwrap();
        let result = step.process(make_msg(json!({"c": 3})));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("buffer full"));
    }

    #[test]
    fn test_merge_non_objects_as_array() {
        let step = WindowStep::new(base_config()).unwrap();
        step.process(make_msg(json!(1))).unwrap();
        step.process(make_msg(json!(2))).unwrap();
        let result = step.process(make_msg(json!(3))).unwrap();
        assert!(result.is_some());
        let payload = result.unwrap().payload;
        assert!(payload.is_array());
        assert_eq!(payload.as_array().unwrap().len(), 3);
        assert_eq!(payload, json!([1, 2, 3]));
    }
}