use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// One observation of new views for a single item.
///
/// The detector divides `view_count_delta` by the seconds elapsed since the item's previous
/// accepted signal, so it is treated as the views accumulated over that interval.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ViewSignal {
    /// Identifier of the item that was viewed.
    pub item_id: String,
    /// Observation time in whole seconds. A signal older than the item's last accepted
    /// signal is ignored by the detector.
    pub timestamp_secs: u64,
    /// Views added since this item's previous signal.
    pub view_count_delta: u32,
}
/// Coarse trend label derived from an item's smoothed acceleration.
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
pub enum TrendClass {
    /// Not viral, and the acceleration is positive.
    Rising,
    /// Not viral, and the acceleration is neither positive nor negative (exactly zero, or NaN).
    Stable,
    /// Not viral, and the acceleration is negative.
    Declining,
    /// Acceleration is strictly above the detector's `viral_threshold`.
    ///
    /// The payload is `|acceleration / velocity|`, or `|acceleration|` when
    /// `|velocity| <= f32::EPSILON`.
    Viral(f32),
}
/// Snapshot of one item's trend state, built on demand from the detector's internal state.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TrendingScore {
    /// Identifier of the scored item.
    pub item_id: String,
    /// Exponentially smoothed rate of `view_count_delta` per second.
    pub velocity: f32,
    /// Exponentially smoothed per-second change of the smoothed velocity.
    pub acceleration: f32,
    /// Label derived from `acceleration` (with `velocity` feeding the viral multiplier).
    pub trend_class: TrendClass,
}
/// Half-life, in seconds, of the exponential smoothing applied to velocity (five minutes).
const VELOCITY_HALF_LIFE_SECS: f64 = 5.0 * 60.0;
/// Half-life, in seconds, of the exponential smoothing applied to acceleration (ten minutes).
const ACCELERATION_HALF_LIFE_SECS: f64 = 10.0 * 60.0;
/// Lower bound, in seconds, applied to the elapsed time between an item's signals before
/// that time is used as a divisor.
const MIN_ELAPSED_SECS: f64 = 1_f64;
/// Per-item smoothing state that `TrendingDetector` keeps between signals.
#[derive(Clone, Debug)]
struct ItemState {
    /// Exponential moving average of the instantaneous rate (views per elapsed second).
    velocity_ema: f64,
    /// Exponential moving average of the per-second change in `velocity_ema`.
    acceleration_ema: f64,
    /// Timestamp, in seconds, of the most recent signal accepted for this item.
    last_ts: u64,
}
impl ItemState {
    /// Creates state for an item first seen at `last_ts`, with both averages starting at zero.
    fn new(last_ts: u64) -> Self {
        Self {
            last_ts,
            velocity_ema: 0.0,
            acceleration_ema: 0.0,
        }
    }
}
/// Tracks smoothed per-item view velocity and acceleration, and ranks items by trend.
pub struct TrendingDetector {
    /// An item is "in window" when its last accepted signal is at most this many seconds
    /// behind the newest timestamp seen across all items.
    pub time_window_secs: u64,
    /// Smoothed acceleration strictly above this value classifies an item as viral.
    pub viral_threshold: f32,
    /// Newest signal timestamp seen across all items; the reference point for the window.
    global_latest_ts: u64,
    /// Smoothing state keyed by item id.
    item_states: HashMap<String, ItemState>,
}
impl TrendingDetector {
#[must_use]
pub fn new(time_window_secs: u64, viral_threshold: f32) -> Self {
Self {
time_window_secs,
viral_threshold,
item_states: HashMap::new(),
global_latest_ts: 0,
}
}
pub fn update(&mut self, signal: ViewSignal) {
if signal.timestamp_secs > self.global_latest_ts {
self.global_latest_ts = signal.timestamp_secs;
}
let state = self
.item_states
.entry(signal.item_id.clone())
.or_insert_with(|| ItemState::new(signal.timestamp_secs));
if signal.timestamp_secs < state.last_ts {
return;
}
let elapsed = (signal.timestamp_secs - state.last_ts).max(1) as f64;
let elapsed_clamped = elapsed.max(MIN_ELAPSED_SECS);
let v_inst = signal.view_count_delta as f64 / elapsed_clamped;
let alpha_v = 1.0 - f64::exp(-elapsed_clamped * f64::ln(2.0) / VELOCITY_HALF_LIFE_SECS);
let prev_velocity = state.velocity_ema;
state.velocity_ema = alpha_v * v_inst + (1.0 - alpha_v) * prev_velocity;
let a_inst = (state.velocity_ema - prev_velocity) / elapsed_clamped;
let alpha_a = 1.0 - f64::exp(-elapsed_clamped * f64::ln(2.0) / ACCELERATION_HALF_LIFE_SECS);
state.acceleration_ema = alpha_a * a_inst + (1.0 - alpha_a) * state.acceleration_ema;
state.last_ts = signal.timestamp_secs;
}
#[must_use]
pub fn trending_items(&self, n: usize) -> Vec<TrendingScore> {
if n == 0 {
return Vec::new();
}
let mut scores: Vec<TrendingScore> = self
.item_states
.iter()
.map(|(id, state)| self.build_score(id, state))
.collect();
scores.sort_by(|a, b| {
let a_in = self.is_in_window(&self.item_states[&a.item_id]);
let b_in = self.is_in_window(&self.item_states[&b.item_id]);
match (a_in, b_in) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => b
.velocity
.partial_cmp(&a.velocity)
.unwrap_or(std::cmp::Ordering::Equal),
}
});
scores.truncate(n);
scores
}
#[must_use]
pub fn detect_viral(&self) -> Vec<TrendingScore> {
let mut viral: Vec<TrendingScore> = self
.item_states
.iter()
.filter_map(|(id, state)| {
if state.acceleration_ema as f32 > self.viral_threshold {
Some(self.build_score(id, state))
} else {
None
}
})
.collect();
viral.sort_by(|a, b| {
b.acceleration
.partial_cmp(&a.acceleration)
.unwrap_or(std::cmp::Ordering::Equal)
});
viral
}
fn is_in_window(&self, state: &ItemState) -> bool {
self.global_latest_ts.saturating_sub(state.last_ts) <= self.time_window_secs
}
fn build_score(&self, id: &str, state: &ItemState) -> TrendingScore {
let velocity = state.velocity_ema as f32;
let acceleration = state.acceleration_ema as f32;
let trend_class = self.classify(velocity, acceleration);
TrendingScore {
item_id: id.to_string(),
velocity,
acceleration,
trend_class,
}
}
fn classify(&self, velocity: f32, acceleration: f32) -> TrendClass {
if acceleration > self.viral_threshold {
let multiplier = if velocity.abs() > f32::EPSILON {
(acceleration / velocity).abs()
} else {
acceleration.abs()
};
TrendClass::Viral(multiplier)
} else if acceleration > 0.0 {
TrendClass::Rising
} else if acceleration < 0.0 {
TrendClass::Declining
} else {
TrendClass::Stable
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a signal for `item_id` at `ts` carrying `delta` new views.
    fn sig(item_id: &str, ts: u64, delta: u32) -> ViewSignal {
        ViewSignal {
            item_id: item_id.to_string(),
            timestamp_secs: ts,
            view_count_delta: delta,
        }
    }

    #[test]
    fn test_steady_views_stable() {
        let mut det = TrendingDetector::new(3600, 5.0);
        for i in 0..10u64 {
            det.update(sig("steady", 1000 + i * 60, 100));
        }
        let top = det.trending_items(1);
        assert_eq!(top.len(), 1);
        assert!(
            matches!(top[0].trend_class, TrendClass::Stable | TrendClass::Rising),
            "steady stream should be Stable or Rising, got {:?}",
            top[0].trend_class
        );
    }

    #[test]
    fn test_spike_triggers_viral() {
        let mut det = TrendingDetector::new(3600, 0.01);
        for i in 0..5u64 {
            det.update(sig("spike_item", 1000 + i * 60, 10));
        }
        det.update(sig("spike_item", 1300, 50_000));
        det.update(sig("spike_item", 1360, 50_000));
        let viral = det.detect_viral();
        assert!(
            viral.iter().any(|v| v.item_id == "spike_item"),
            "spike_item should be detected as viral (got: {viral:?})"
        );
    }

    #[test]
    fn test_declining_traffic() {
        let mut det = TrendingDetector::new(7200, 1_000_000.0);
        for i in 0..20u64 {
            let views = 2000u32.saturating_sub(i as u32 * 100);
            det.update(sig("fading", 1_000_000 + i * 300, views));
        }
        let top = det.trending_items(1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].item_id, "fading", "fading should be top item");
        assert!(
            matches!(
                top[0].trend_class,
                TrendClass::Declining | TrendClass::Stable
            ),
            "declining traffic should be Declining or Stable, got {:?}",
            top[0].trend_class
        );
    }

    #[test]
    fn test_top_n_order_by_velocity() {
        let mut det = TrendingDetector::new(3600, 100.0);
        for i in 0..5u64 {
            det.update(sig("fast", 1000 + i * 60, 500));
            det.update(sig("medium", 1000 + i * 60, 200));
            det.update(sig("slow", 1000 + i * 60, 50));
        }
        let top = det.trending_items(3);
        assert_eq!(top.len(), 3);
        assert_eq!(top[0].item_id, "fast");
        assert_eq!(top[1].item_id, "medium");
        assert_eq!(top[2].item_id, "slow");
    }

    #[test]
    fn test_n_zero_returns_empty() {
        let mut det = TrendingDetector::new(3600, 5.0);
        det.update(sig("x", 1000, 100));
        assert!(det.trending_items(0).is_empty());
    }

    #[test]
    fn test_no_signals_empty() {
        let det = TrendingDetector::new(3600, 5.0);
        assert!(det.trending_items(10).is_empty());
        assert!(det.detect_viral().is_empty());
    }

    #[test]
    fn test_detect_viral_threshold_respected() {
        let mut det = TrendingDetector::new(3600, 0.1);
        for i in 0..5u64 {
            det.update(sig("normal", 1000 + i * 60, 50));
        }
        det.update(sig("viral", 2000, 1));
        det.update(sig("viral", 2060, 500_000));
        let viral = det.detect_viral();
        // `all` is trivially true on an empty list, so first pin that something was found.
        assert!(
            viral.iter().any(|v| v.item_id == "viral"),
            "the spiking item should be reported (got: {viral:?})"
        );
        assert!(
            viral.iter().all(|v| v.item_id != "normal"),
            "a steady item should not be reported (got: {viral:?})"
        );
        assert!(
            viral.iter().all(|v| v.acceleration > det.viral_threshold),
            "all viral items should have acceleration above threshold"
        );
    }

    #[test]
    fn test_out_of_order_signals_ignored() {
        let mut det = TrendingDetector::new(3600, 5.0);
        det.update(sig("item", 2000, 100));
        det.update(sig("item", 1000, 999));
        det.update(sig("item", 2060, 100));

        // The same stream without the stale signal must yield identical state.
        let mut reference = TrendingDetector::new(3600, 5.0);
        reference.update(sig("item", 2000, 100));
        reference.update(sig("item", 2060, 100));

        let top = det.trending_items(1);
        let expected = reference.trending_items(1);
        assert_eq!(top.len(), 1);
        assert!(top[0].velocity >= 0.0);
        assert_eq!(top[0].velocity, expected[0].velocity);
        assert_eq!(top[0].acceleration, expected[0].acceleration);
    }

    #[test]
    fn test_viral_multiplier_positive() {
        let mut det = TrendingDetector::new(3600, 0.01);
        det.update(sig("v", 1000, 1));
        det.update(sig("v", 1060, 100_000));
        let viral = det.detect_viral();
        let score = viral
            .iter()
            .find(|s| s.item_id == "v")
            .expect("spiking item should be detected as viral");
        match &score.trend_class {
            TrendClass::Viral(m) => {
                assert!(*m > 0.0, "viral multiplier should be positive, got {m}");
            }
            other => panic!("expected Viral, got {other:?}"),
        }
    }

    #[test]
    fn test_stale_items_ranked_after_in_window() {
        let mut det = TrendingDetector::new(100, 1_000.0);
        det.update(sig("stale", 1000, 50_000));
        det.update(sig("fresh", 5000, 1));
        let top = det.trending_items(2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].item_id, "fresh", "in-window item should rank first");
        assert_eq!(top[1].item_id, "stale");
        assert!(
            top[1].velocity > top[0].velocity,
            "ordering should come from the window, not velocity"
        );
    }

    #[test]
    fn test_detect_viral_sorted_by_acceleration() {
        let mut det = TrendingDetector::new(3600, 0.01);
        for (id, spike) in [("small", 100_000u32), ("big", 500_000u32)] {
            det.update(sig(id, 1000, 1));
            det.update(sig(id, 1060, spike));
        }
        let viral = det.detect_viral();
        let order: Vec<&str> = viral.iter().map(|v| v.item_id.as_str()).collect();
        assert_eq!(order, ["big", "small"]);
    }

    #[test]
    fn test_trending_items_length_capped() {
        let mut det = TrendingDetector::new(3600, 100.0);
        for idx in 0..20u64 {
            det.update(sig(
                &format!("item_{idx}"),
                1000 + idx * 10,
                (idx as u32 + 1) * 10,
            ));
        }
        let top = det.trending_items(5);
        assert_eq!(top.len(), 5);
        assert!(
            top.windows(2).all(|w| w[0].velocity >= w[1].velocity),
            "results should be ordered by descending velocity"
        );
        assert_eq!(top[0].item_id, "item_19");
    }
}