1use crate::config::sinks::SinkType;
29use crate::events::sinks::Sink;
30use anyhow::{Result, anyhow};
31use async_trait::async_trait;
32use serde_json::Value;
33use std::collections::HashMap;
34use std::sync::Arc;
35use tokio::sync::{Mutex, RwLock};
36
/// Arithmetic that a [`CounterSink`] applies to a numeric entity field.
#[derive(Debug, Clone, PartialEq)]
pub enum CounterOperation {
    /// Add the amount to the current value.
    Increment,
    /// Subtract the amount from the current value; the result is floored at zero.
    Decrement,
    /// Replace the current value with the amount.
    Set,
}
47
48impl CounterOperation {
49 pub fn parse(s: &str) -> Result<Self> {
51 match s {
52 "increment" | "inc" | "add" => Ok(Self::Increment),
53 "decrement" | "dec" | "sub" | "subtract" => Ok(Self::Decrement),
54 "set" => Ok(Self::Set),
55 _ => Err(anyhow!(
56 "invalid counter operation '{}': expected 'increment', 'decrement', or 'set'",
57 s
58 )),
59 }
60 }
61
62 pub fn apply(&self, current: f64, amount: f64) -> f64 {
64 match self {
65 Self::Increment => current + amount,
66 Self::Decrement => (current - amount).max(0.0), Self::Set => amount,
68 }
69 }
70}
71
72#[async_trait]
77pub trait EntityFieldUpdater: Send + Sync + std::fmt::Debug {
78 async fn read_field(&self, entity_type: &str, entity_id: &str, field: &str) -> Result<f64>;
82
83 async fn write_field(
85 &self,
86 entity_type: &str,
87 entity_id: &str,
88 field: &str,
89 value: f64,
90 ) -> Result<()>;
91}
92
93#[derive(Debug, Clone)]
95pub struct CounterConfig {
96 pub field: String,
98
99 pub operation: CounterOperation,
101}
102
103#[derive(Debug)]
111pub struct CounterSink {
112 config: CounterConfig,
114
115 updater: Arc<dyn EntityFieldUpdater>,
117
118 key_locks: RwLock<HashMap<String, Arc<Mutex<()>>>>,
121}
122
123impl CounterSink {
124 pub fn new(updater: Arc<dyn EntityFieldUpdater>, config: CounterConfig) -> Self {
126 Self {
127 config,
128 updater,
129 key_locks: RwLock::new(HashMap::new()),
130 }
131 }
132
133 async fn get_lock(&self, key: &str) -> Arc<Mutex<()>> {
135 {
137 let locks = self.key_locks.read().await;
138 if let Some(lock) = locks.get(key) {
139 return lock.clone();
140 }
141 }
142
143 let mut locks = self.key_locks.write().await;
145 locks
147 .entry(key.to_string())
148 .or_insert_with(|| Arc::new(Mutex::new(())))
149 .clone()
150 }
151}
152
153#[async_trait]
154impl Sink for CounterSink {
155 async fn deliver(
156 &self,
157 payload: Value,
158 _recipient_id: Option<&str>,
159 context_vars: &HashMap<String, Value>,
160 ) -> Result<()> {
161 let entity_type = payload
163 .get("entity_type")
164 .and_then(|v| v.as_str())
165 .or_else(|| context_vars.get("entity_type").and_then(|v| v.as_str()))
166 .ok_or_else(|| anyhow!("counter sink: entity_type not found in payload or context"))?
167 .to_string();
168
169 let entity_id = payload
170 .get("entity_id")
171 .and_then(|v| v.as_str())
172 .or_else(|| context_vars.get("entity_id").and_then(|v| v.as_str()))
173 .ok_or_else(|| anyhow!("counter sink: entity_id not found in payload or context"))?
174 .to_string();
175
176 let field = payload
178 .get("field")
179 .and_then(|v| v.as_str())
180 .unwrap_or(&self.config.field)
181 .to_string();
182
183 let operation = if let Some(op_str) = payload.get("operation").and_then(|v| v.as_str()) {
185 CounterOperation::parse(op_str)?
186 } else {
187 self.config.operation.clone()
188 };
189
190 let amount = payload.get("value").and_then(|v| v.as_f64()).unwrap_or(1.0);
192
193 let lock_key = format!("{}:{}:{}", entity_type, entity_id, field);
195 let lock = self.get_lock(&lock_key).await;
196 let _guard = lock.lock().await;
197
198 let current = self
200 .updater
201 .read_field(&entity_type, &entity_id, &field)
202 .await?;
203
204 let new_value = operation.apply(current, amount);
206
207 tracing::debug!(
208 entity_type = %entity_type,
209 entity_id = %entity_id,
210 field = %field,
211 current = current,
212 operation = ?operation,
213 amount = amount,
214 new_value = new_value,
215 "counter sink: updating field"
216 );
217
218 self.updater
220 .write_field(&entity_type, &entity_id, &field, new_value)
221 .await?;
222
223 Ok(())
224 }
225
226 fn name(&self) -> &str {
227 "counter"
228 }
229
230 fn sink_type(&self) -> SinkType {
231 SinkType::Counter
232 }
233}
234
#[cfg(test)]
mod tests {
    // `super::*` already brings `RwLock`, `HashMap`, `Arc` and `Value` into scope.
    use super::*;
    use serde_json::json;

    /// In-memory [`EntityFieldUpdater`] keyed by `"{entity_type}:{entity_id}:{field}"`;
    /// fields that were never written read as `0.0`.
    #[derive(Debug)]
    struct MockEntityStore {
        fields: RwLock<HashMap<String, f64>>,
    }

    impl MockEntityStore {
        fn new() -> Self {
            Self {
                fields: RwLock::new(HashMap::new()),
            }
        }

        fn key(entity_type: &str, entity_id: &str, field: &str) -> String {
            format!("{}:{}:{}", entity_type, entity_id, field)
        }

        async fn set(&self, entity_type: &str, entity_id: &str, field: &str, value: f64) {
            self.fields
                .write()
                .await
                .insert(Self::key(entity_type, entity_id, field), value);
        }
    }

    #[async_trait]
    impl EntityFieldUpdater for MockEntityStore {
        async fn read_field(&self, entity_type: &str, entity_id: &str, field: &str) -> Result<f64> {
            let store = self.fields.read().await;
            Ok(*store
                .get(&Self::key(entity_type, entity_id, field))
                .unwrap_or(&0.0))
        }

        async fn write_field(
            &self,
            entity_type: &str,
            entity_id: &str,
            field: &str,
            value: f64,
        ) -> Result<()> {
            self.fields
                .write()
                .await
                .insert(Self::key(entity_type, entity_id, field), value);
            Ok(())
        }
    }

    fn increment_config(field: &str) -> CounterConfig {
        CounterConfig {
            field: field.to_string(),
            operation: CounterOperation::Increment,
        }
    }

    #[tokio::test]
    async fn test_counter_increment() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 5.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 6.0);
    }

    #[tokio::test]
    async fn test_counter_increment_from_zero() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 1.0);
    }

    #[tokio::test]
    async fn test_counter_decrement() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 5.0).await;

        let sink = CounterSink::new(
            store.clone(),
            CounterConfig {
                field: "like_count".to_string(),
                operation: CounterOperation::Decrement,
            },
        );

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 4.0);
    }

    #[tokio::test]
    async fn test_counter_decrement_floor_at_zero() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 0.0).await;

        let sink = CounterSink::new(
            store.clone(),
            CounterConfig {
                field: "like_count".to_string(),
                operation: CounterOperation::Decrement,
            },
        );

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 0.0);
    }

    #[tokio::test]
    async fn test_counter_set() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 5.0).await;

        let sink = CounterSink::new(
            store.clone(),
            CounterConfig {
                field: "like_count".to_string(),
                operation: CounterOperation::Set,
            },
        );

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1",
            "value": 42
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 42.0);
    }

    #[tokio::test]
    async fn test_counter_custom_amount() {
        let store = Arc::new(MockEntityStore::new());
        store.set("user", "u-1", "follower_count", 10.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("follower_count"));

        let payload = json!({
            "entity_type": "user",
            "entity_id": "u-1",
            "value": 5
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("user", "u-1", "follower_count")
            .await
            .unwrap();
        assert_eq!(value, 15.0);
    }

    #[tokio::test]
    async fn test_counter_override_field_and_operation() {
        let store = Arc::new(MockEntityStore::new());
        store.set("user", "u-1", "comment_count", 3.0).await;

        // Configured for like_count/increment; the payload overrides both.
        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "user",
            "entity_id": "u-1",
            "field": "comment_count",
            "operation": "decrement"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let value = store
            .read_field("user", "u-1", "comment_count")
            .await
            .unwrap();
        assert_eq!(value, 2.0);
    }

    #[tokio::test]
    async fn test_counter_entity_from_context() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 0.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        // The payload carries no entity; both identifiers come from the context vars.
        let payload = json!({});
        let mut vars = HashMap::new();
        vars.insert(
            "entity_type".to_string(),
            Value::String("capture".to_string()),
        );
        vars.insert("entity_id".to_string(), Value::String("cap-1".to_string()));

        sink.deliver(payload, None, &vars).await.unwrap();

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(value, 1.0);
    }

    #[tokio::test]
    async fn test_counter_payload_overrides_context() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });
        let mut vars = HashMap::new();
        vars.insert(
            "entity_type".to_string(),
            Value::String("capture".to_string()),
        );
        vars.insert("entity_id".to_string(), Value::String("cap-2".to_string()));

        sink.deliver(payload, None, &vars).await.unwrap();

        let chosen = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        let ignored = store
            .read_field("capture", "cap-2", "like_count")
            .await
            .unwrap();
        assert_eq!(chosen, 1.0);
        assert_eq!(ignored, 0.0);
    }

    #[tokio::test]
    async fn test_counter_only_touches_target_entity() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 2.0).await;
        store.set("capture", "cap-2", "like_count", 7.0).await;

        let sink = CounterSink::new(store.clone(), increment_config("like_count"));

        let payload = json!({
            "entity_type": "capture",
            "entity_id": "cap-1"
        });

        sink.deliver(payload, None, &HashMap::new()).await.unwrap();

        let target = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        let other = store
            .read_field("capture", "cap-2", "like_count")
            .await
            .unwrap();
        assert_eq!(target, 3.0);
        assert_eq!(other, 7.0);
    }

    #[tokio::test]
    async fn test_counter_missing_entity_type_error() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store, increment_config("like_count"));

        let payload = json!({"entity_id": "cap-1"});
        let result = sink.deliver(payload, None, &HashMap::new()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("entity_type"));
    }

    #[tokio::test]
    async fn test_counter_missing_entity_id_error() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store, increment_config("like_count"));

        let payload = json!({"entity_type": "capture"});
        let result = sink.deliver(payload, None, &HashMap::new()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("entity_id"));
    }

    #[test]
    fn test_counter_operation_parse() {
        for name in ["increment", "inc", "add"] {
            assert_eq!(
                CounterOperation::parse(name).unwrap(),
                CounterOperation::Increment,
                "alias {name:?}"
            );
        }
        for name in ["decrement", "dec", "sub", "subtract"] {
            assert_eq!(
                CounterOperation::parse(name).unwrap(),
                CounterOperation::Decrement,
                "alias {name:?}"
            );
        }
        assert_eq!(
            CounterOperation::parse("set").unwrap(),
            CounterOperation::Set
        );
        assert!(CounterOperation::parse("invalid").is_err());
        assert!(CounterOperation::parse("").is_err());
    }

    #[test]
    fn test_counter_operation_apply() {
        assert_eq!(CounterOperation::Increment.apply(2.0, 3.0), 5.0);
        assert_eq!(CounterOperation::Decrement.apply(5.0, 2.0), 3.0);
        assert_eq!(CounterOperation::Decrement.apply(1.0, 5.0), 0.0);
        assert_eq!(CounterOperation::Set.apply(7.0, 3.0), 3.0);
    }

    #[test]
    fn test_counter_sink_name_and_type() {
        let store = Arc::new(MockEntityStore::new());
        let sink = CounterSink::new(store, increment_config("like_count"));
        assert_eq!(sink.name(), "counter");
        assert_eq!(sink.sink_type(), SinkType::Counter);
    }

    #[tokio::test]
    async fn test_counter_concurrent_increments() {
        let store = Arc::new(MockEntityStore::new());
        store.set("capture", "cap-1", "like_count", 0.0).await;

        let sink = Arc::new(CounterSink::new(
            store.clone(),
            increment_config("like_count"),
        ));

        let mut handles = Vec::new();
        for _ in 0..50 {
            let sink = sink.clone();
            handles.push(tokio::spawn(async move {
                let payload = json!({
                    "entity_type": "capture",
                    "entity_id": "cap-1"
                });
                sink.deliver(payload, None, &HashMap::new()).await.unwrap();
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        let value = store
            .read_field("capture", "cap-1", "like_count")
            .await
            .unwrap();
        assert_eq!(
            value, 50.0,
            "All 50 increments should be applied atomically"
        );
    }
}