1use std::collections::HashMap;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
22use serde::{Deserialize, Serialize};
23
24#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
34pub struct AnyValue<T> {
35 inner: Option<T>,
36}
37
38impl<T> AnyValue<T> {
39 pub fn new() -> Self {
41 Self { inner: None }
42 }
43
44 pub fn with(value: T) -> Self {
46 Self { inner: Some(value) }
47 }
48
49 pub fn set(&mut self, value: T) {
51 self.inner = Some(value);
52 }
53
54 pub fn get(&self) -> Option<&T> {
56 self.inner.as_ref()
57 }
58
59 pub fn take(&mut self) -> Option<T> {
61 self.inner.take()
62 }
63
64 pub fn is_empty(&self) -> bool {
66 self.inner.is_none()
67 }
68}
69
70#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
79pub struct Topic<T> {
80 queue: Vec<T>,
81}
82
83impl<T> Topic<T> {
84 pub fn new() -> Self {
86 Self { queue: Vec::new() }
87 }
88
89 pub fn send(&mut self, value: T) {
91 self.queue.push(value);
92 }
93
94 pub fn extend<I: IntoIterator<Item = T>>(&mut self, values: I) {
96 self.queue.extend(values);
97 }
98
99 pub fn drain(&mut self) -> Vec<T> {
101 std::mem::take(&mut self.queue)
102 }
103
104 pub fn peek(&self) -> &[T] {
106 &self.queue
107 }
108
109 pub fn len(&self) -> usize {
111 self.queue.len()
112 }
113
114 pub fn is_empty(&self) -> bool {
116 self.queue.is_empty()
117 }
118}
119
120#[derive(Debug, Clone, Default, Serialize, Deserialize)]
131pub struct BinaryOp<T> {
132 value: Option<T>,
133 #[serde(skip)]
134 op: Option<fn(&T, &T) -> T>,
135}
136
137impl<T: PartialEq> PartialEq for BinaryOp<T> {
138 fn eq(&self, other: &Self) -> bool {
139 self.value == other.value
140 }
141}
142
143impl<T: Eq> Eq for BinaryOp<T> {}
144
145impl<T: Clone> BinaryOp<T> {
146 pub fn new(op: fn(&T, &T) -> T) -> Self {
148 Self {
149 value: None,
150 op: Some(op),
151 }
152 }
153
154 pub fn with_initial(op: fn(&T, &T) -> T, initial: T) -> Self {
156 Self {
157 value: Some(initial),
158 op: Some(op),
159 }
160 }
161
162 pub fn rehydrate(mut self, op: fn(&T, &T) -> T) -> Self {
165 self.op = Some(op);
166 self
167 }
168
169 pub fn write(&mut self, value: T) -> cognis_core::Result<()> {
171 let op = self.op.ok_or_else(|| {
172 cognis_core::CognisError::Internal(
173 "BinaryOp: write called before rehydrate (no op set)".into(),
174 )
175 })?;
176 self.value = Some(match self.value.as_ref() {
177 Some(existing) => op(existing, &value),
178 None => value,
179 });
180 Ok(())
181 }
182
183 pub fn get(&self) -> Option<&T> {
185 self.value.as_ref()
186 }
187
188 pub fn take(&mut self) -> Option<T> {
190 self.value.take()
191 }
192}
193
/// Fan-out channel: sent items are kept in a shared log and each named
/// subscriber reads them independently through its own cursor.
///
/// Items stay buffered until `gc` removes those that every current
/// subscriber has already read.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Broadcast<T> {
    /// Buffered items, each paired with the sequence number `send` gave it.
    items: Vec<(u64, T)>,
    /// Per-subscriber cursor: the lowest sequence number that subscriber has
    /// not read yet.
    cursors: HashMap<String, u64>,
    /// Sequence number the next sent item will receive.
    next_seq: u64,
}
214
215impl<T> Default for Broadcast<T> {
216 fn default() -> Self {
217 Self {
218 items: Vec::new(),
219 cursors: HashMap::new(),
220 next_seq: 0,
221 }
222 }
223}
224
225impl<T: Clone> Broadcast<T> {
226 pub fn new() -> Self {
228 Self::default()
229 }
230
231 pub fn subscribe(&mut self, name: impl Into<String>) {
234 let name = name.into();
235 self.cursors.entry(name).or_insert(self.next_seq);
236 }
237
238 pub fn unsubscribe(&mut self, name: &str) {
241 self.cursors.remove(name);
242 }
243
244 pub fn send(&mut self, value: T) {
246 self.items.push((self.next_seq, value));
247 self.next_seq += 1;
248 }
249
250 pub fn read(&mut self, name: &str) -> Vec<T> {
253 let cursor = match self.cursors.get_mut(name) {
254 Some(c) => c,
255 None => return Vec::new(),
256 };
257 let out: Vec<T> = self
258 .items
259 .iter()
260 .filter(|(seq, _)| *seq >= *cursor)
261 .map(|(_, v)| v.clone())
262 .collect();
263 *cursor = self.next_seq;
264 out
265 }
266
267 pub fn gc(&mut self) {
270 if self.cursors.is_empty() {
271 self.items.clear();
273 return;
274 }
275 let min_cursor = self
276 .cursors
277 .values()
278 .copied()
279 .min()
280 .unwrap_or(self.next_seq);
281 self.items.retain(|(seq, _)| *seq >= min_cursor);
282 }
283
284 pub fn len(&self) -> usize {
286 self.items.len()
287 }
288
289 pub fn is_empty(&self) -> bool {
291 self.items.is_empty()
292 }
293}
294
/// Wrapper around a value that is deliberately left out of serialized state.
///
/// The serde impls in this file write an `Untracked` as a unit value and
/// rebuild it from `T::default()` when deserializing, so the wrapped value
/// itself does not survive a serde round-trip.
#[derive(Debug, Clone)]
pub struct Untracked<T> {
    /// The wrapped value; public, so callers can read or modify it directly.
    pub inner: T,
}
310
311impl<T: Default> Default for Untracked<T> {
312 fn default() -> Self {
313 Self {
314 inner: T::default(),
315 }
316 }
317}
318
319impl<T> Untracked<T> {
320 pub fn new(value: T) -> Self {
322 Self { inner: value }
323 }
324
325 pub fn into_inner(self) -> T {
327 self.inner
328 }
329}
330
331impl<T> serde::Serialize for Untracked<T> {
337 fn serialize<S: serde::Serializer>(
338 &self,
339 serializer: S,
340 ) -> std::result::Result<S::Ok, S::Error> {
341 serializer.serialize_unit()
345 }
346}
347
348impl<'de, T: Default> serde::Deserialize<'de> for Untracked<T> {
349 fn deserialize<D: serde::Deserializer<'de>>(
350 deserializer: D,
351 ) -> std::result::Result<Self, D::Error> {
352 serde::de::IgnoredAny::deserialize(deserializer)?;
357 Ok(Self::default())
358 }
359}
360
/// Boxed merge callback used by [`CustomChannel`]: called with a mutable
/// reference to the stored value and the incoming value, which it takes by
/// value. Must be `Send + Sync`.
pub type CustomMergeFn<T> = Box<dyn Fn(&mut T, T) + Send + Sync>;
368
/// Channel whose write behaviour is supplied by the caller as a merge
/// callback.
pub struct CustomChannel<T> {
    /// Name reported by `Channel::kind` for this channel.
    label: &'static str,
    /// Current stored value.
    value: T,
    /// Callback that folds each written value into `value`.
    on_write: CustomMergeFn<T>,
}
381
382impl<T: Default> CustomChannel<T> {
383 pub fn new<F>(label: &'static str, on_write: F) -> Self
386 where
387 F: Fn(&mut T, T) + Send + Sync + 'static,
388 {
389 Self {
390 label,
391 value: T::default(),
392 on_write: Box::new(on_write),
393 }
394 }
395}
396
397impl<T> CustomChannel<T> {
398 pub fn with_initial<F>(label: &'static str, initial: T, on_write: F) -> Self
400 where
401 F: Fn(&mut T, T) + Send + Sync + 'static,
402 {
403 Self {
404 label,
405 value: initial,
406 on_write: Box::new(on_write),
407 }
408 }
409
410 pub fn write(&mut self, value: T) {
412 (self.on_write)(&mut self.value, value);
413 }
414
415 pub fn get(&self) -> &T {
417 &self.value
418 }
419
420 pub fn get_mut(&mut self) -> &mut T {
422 &mut self.value
423 }
424
425 pub fn replace(&mut self, new: T) -> T {
427 std::mem::replace(&mut self.value, new)
428 }
429}
430
431impl<T: Send + Sync> Channel for CustomChannel<T> {
432 fn kind(&self) -> &'static str {
433 self.label
434 }
435}
436
437impl<T: std::fmt::Debug> std::fmt::Debug for CustomChannel<T> {
438 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439 f.debug_struct("CustomChannel")
440 .field("label", &self.label)
441 .field("value", &self.value)
442 .finish()
443 }
444}
445
/// Common interface of every channel type in this file.
///
/// Implementors must be `Send + Sync`. The trait has no generic methods, so
/// it can be used as `dyn Channel`.
pub trait Channel: Send + Sync {
    /// Short static name identifying the kind of channel, e.g. `"Topic"`.
    fn kind(&self) -> &'static str;
}
456
457impl<T: Send + Sync> Channel for AnyValue<T> {
458 fn kind(&self) -> &'static str {
459 "AnyValue"
460 }
461}
462impl<T: Send + Sync> Channel for Topic<T> {
463 fn kind(&self) -> &'static str {
464 "Topic"
465 }
466}
467impl<T: Send + Sync> Channel for BinaryOp<T> {
468 fn kind(&self) -> &'static str {
469 "BinaryOp"
470 }
471}
472impl<T: Send + Sync> Channel for Broadcast<T> {
473 fn kind(&self) -> &'static str {
474 "Broadcast"
475 }
476}
477impl<T: Send + Sync> Channel for Untracked<T> {
478 fn kind(&self) -> &'static str {
479 "Untracked"
480 }
481}
482
/// Shared, type-erased handle to any channel. Only the methods of
/// [`Channel`] are reachable through it.
pub type ChannelRef = Arc<dyn Channel>;
491
/// Marker type parameterised by `T` that stores no `T`: its only field is a
/// `PhantomData<fn() -> T>`.
///
/// Nothing in this file constructs it.
/// NOTE(review): presumably kept for type-level tagging, or to keep the
/// `PhantomData` import in use. Confirm before removing.
#[doc(hidden)]
pub struct _ChannelTag<T>(PhantomData<fn() -> T>);
494
#[cfg(test)]
mod tests {
    use super::*;

    /// `set` overwrites the slot and `take` empties it again.
    #[test]
    fn any_value_set_and_take() {
        let mut slot: AnyValue<i32> = AnyValue::new();
        assert!(slot.is_empty());

        slot.set(1);
        slot.set(2);
        assert_eq!(slot.get(), Some(&2));

        let taken = slot.take();
        assert_eq!(taken, Some(2));
        assert!(slot.is_empty());
    }

    /// Values come back from `drain` in the order they were queued.
    #[test]
    fn topic_send_drain_round_trip() {
        let mut topic: Topic<&'static str> = Topic::new();
        topic.send("a");
        topic.send("b");
        topic.extend(["c", "d"]);
        assert_eq!(topic.len(), 4);

        assert_eq!(topic.drain(), vec!["a", "b", "c", "d"]);
        assert!(topic.is_empty());
    }

    /// Successive writes are folded into one accumulator.
    #[test]
    fn binary_op_folds_associatively() {
        let mut acc: BinaryOp<i32> = BinaryOp::new(|lhs, rhs| lhs + rhs);
        for n in 1..=3 {
            acc.write(n).unwrap();
        }
        assert_eq!(acc.get(), Some(&6));
    }

    /// Writing with no reducer attached reports an internal error.
    #[test]
    fn binary_op_without_rehydrate_errors() {
        let mut bare: BinaryOp<i32> = BinaryOp {
            value: None,
            op: None,
        };
        let outcome = bare.write(1);
        assert!(matches!(
            outcome.unwrap_err(),
            cognis_core::CognisError::Internal(_)
        ));
    }

    /// `rehydrate` attaches a reducer and keeps the existing value.
    #[test]
    fn binary_op_rehydrate_reattaches_op() {
        let dormant: BinaryOp<i32> = BinaryOp {
            value: Some(5),
            op: None,
        };
        let mut revived = dormant.rehydrate(|lhs, rhs| lhs + rhs);
        revived.write(2).unwrap();
        assert_eq!(revived.get(), Some(&7));
    }

    /// Every subscriber sees every item exactly once.
    #[test]
    fn broadcast_delivers_to_all_subscribers() {
        let mut bus: Broadcast<i32> = Broadcast::new();
        bus.subscribe("a");
        bus.subscribe("b");
        bus.send(1);
        bus.send(2);
        assert_eq!(bus.read("a"), vec![1, 2]);
        assert_eq!(bus.read("b"), vec![1, 2]);
        // A second read with nothing new yields nothing.
        assert!(bus.read("a").is_empty());

        bus.send(3);
        assert_eq!(bus.read("a"), vec![3]);
        assert_eq!(bus.read("b"), vec![3]);
    }

    /// Items read by the only subscriber are removed by `gc`.
    #[test]
    fn broadcast_gc_drops_consumed_items() {
        let mut bus: Broadcast<i32> = Broadcast::new();
        bus.subscribe("only");
        bus.send(1);
        bus.send(2);
        let _ = bus.read("only");
        bus.gc();
        assert_eq!(bus.len(), 0);
    }

    /// Reading as a name that never subscribed yields nothing.
    #[test]
    fn broadcast_unknown_subscriber_reads_empty() {
        let mut bus: Broadcast<i32> = Broadcast::new();
        bus.send(1);
        let seen = bus.read("ghost");
        assert!(seen.is_empty());
    }

    /// The wrapped value is serialized as `null` and comes back as default.
    #[test]
    fn untracked_round_trips_through_serde_to_default() {
        let original = Untracked::new(42i32);
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, "null");

        let restored: Untracked<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.inner, 0);
    }

    /// Each built-in channel reports its own kind string.
    #[test]
    fn channel_kind_strings() {
        let any: AnyValue<i32> = AnyValue::new();
        let topic: Topic<i32> = Topic::new();
        let binary: BinaryOp<i32> = BinaryOp::new(|lhs, rhs| lhs + rhs);
        let broadcast: Broadcast<i32> = Broadcast::new();
        let untracked: Untracked<i32> = Untracked::default();

        assert_eq!(any.kind(), "AnyValue");
        assert_eq!(topic.kind(), "Topic");
        assert_eq!(binary.kind(), "BinaryOp");
        assert_eq!(broadcast.kind(), "Broadcast");
        assert_eq!(untracked.kind(), "Untracked");
    }

    /// A user-supplied merge keeps the maximum; the label becomes the kind.
    #[test]
    fn custom_channel_applies_user_merge() {
        let mut max: CustomChannel<i32> = CustomChannel::new("Max", |slot, incoming| {
            if incoming > *slot {
                *slot = incoming;
            }
        });
        for n in [3, 1, 7, 5] {
            max.write(n);
        }
        assert_eq!(*max.get(), 7);
        assert_eq!(max.kind(), "Max");
    }

    /// `with_initial` seeds the stored value before any write.
    #[test]
    fn custom_channel_with_initial_seeds_value() {
        let mut concat: CustomChannel<Vec<i32>> =
            CustomChannel::with_initial("Concat", vec![1, 2], |slot, incoming| {
                slot.extend(incoming);
            });
        concat.write(vec![3, 4]);
        assert_eq!(concat.get(), &vec![1, 2, 3, 4]);
    }
}