Skip to main content

cognis_graph/
channels.rs

1//! Channel value types — embeddable inside a [`crate::state::GraphState`]
2//! struct to express V1-style channel semantics on top of V2's typed
3//! state model.
4//!
5//! These are plain Rust types you put in your state struct, not a separate
6//! abstraction. The state's reducer is responsible for routing updates
7//! into them.
8//!
//! - [`AnyValue<T>`] — accept any number of writes per step; the last
//!   write wins.
11//! - [`Topic<T>`] — append-only queue with `drain()` consume semantics.
12//! - [`BinaryOp<T>`] — fold writes via an associative binary operation.
13//! - [`Broadcast<T>`] — multi-consumer queue; each consumer has its own
14//!   cursor.
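//! - [`Untracked<T>`] — wrapper whose value is never persisted: it
//!   serializes as a unit placeholder and is rebuilt from `Default`, so it
//!   stays out of checkpoints.
//! - [`CustomChannel<T>`] — a value plus a user-supplied merge closure, for
//!   semantics the stock channels don't cover.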
17
18use std::collections::HashMap;
19use std::marker::PhantomData;
20use std::sync::Arc;
21
22use serde::{Deserialize, Serialize};
23
24// ---------------------------------------------------------------------------
25// AnyValue
26// ---------------------------------------------------------------------------
27
28/// Accepts any number of writes per superstep; the **last** write wins
29/// (mirrors V1 `AnyValue` semantics).
30///
31/// Useful when more than one node may write the same field in a step and
32/// you don't care which wins (e.g. a "done" flag set by either branch).
33#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
34pub struct AnyValue<T> {
35    inner: Option<T>,
36}
37
38impl<T> AnyValue<T> {
39    /// Empty channel.
40    pub fn new() -> Self {
41        Self { inner: None }
42    }
43
44    /// Build with an initial value.
45    pub fn with(value: T) -> Self {
46        Self { inner: Some(value) }
47    }
48
49    /// Set the channel's value (any-write-wins).
50    pub fn set(&mut self, value: T) {
51        self.inner = Some(value);
52    }
53
54    /// Borrow the current value, if any.
55    pub fn get(&self) -> Option<&T> {
56        self.inner.as_ref()
57    }
58
59    /// Consume the channel and return its value.
60    pub fn take(&mut self) -> Option<T> {
61        self.inner.take()
62    }
63
64    /// True if no value has been written.
65    pub fn is_empty(&self) -> bool {
66        self.inner.is_none()
67    }
68}
69
70// ---------------------------------------------------------------------------
71// Topic
72// ---------------------------------------------------------------------------
73
74/// Append-only queue. Producers `send()`, consumers `drain()` (single
75/// consumer) — mirrors V1 `Topic` channel.
76///
77/// Use [`Broadcast`] when more than one consumer needs to see every event.
78#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
79pub struct Topic<T> {
80    queue: Vec<T>,
81}
82
83impl<T> Topic<T> {
84    /// New empty topic.
85    pub fn new() -> Self {
86        Self { queue: Vec::new() }
87    }
88
89    /// Append `value` to the queue.
90    pub fn send(&mut self, value: T) {
91        self.queue.push(value);
92    }
93
94    /// Append every item from `values`.
95    pub fn extend<I: IntoIterator<Item = T>>(&mut self, values: I) {
96        self.queue.extend(values);
97    }
98
99    /// Drain the queue, returning everything currently buffered.
100    pub fn drain(&mut self) -> Vec<T> {
101        std::mem::take(&mut self.queue)
102    }
103
104    /// Borrow the current queue contents without consuming them.
105    pub fn peek(&self) -> &[T] {
106        &self.queue
107    }
108
109    /// Number of pending items.
110    pub fn len(&self) -> usize {
111        self.queue.len()
112    }
113
114    /// True if no items are pending.
115    pub fn is_empty(&self) -> bool {
116        self.queue.is_empty()
117    }
118}
119
120// ---------------------------------------------------------------------------
121// BinaryOp
122// ---------------------------------------------------------------------------
123
124/// Channel that folds writes via an associative binary operation
125/// (mirrors V1 `BinaryOp` channel).
126///
127/// The op is supplied as a `fn(&T, &T) -> T` so the type stays
128/// `Serialize`/`Deserialize`. Persisting only the value (not the op) is
129/// fine: the op is deterministic and re-supplied on reconstruction.
130#[derive(Debug, Clone, Default, Serialize, Deserialize)]
131pub struct BinaryOp<T> {
132    value: Option<T>,
133    #[serde(skip)]
134    op: Option<fn(&T, &T) -> T>,
135}
136
137impl<T: PartialEq> PartialEq for BinaryOp<T> {
138    fn eq(&self, other: &Self) -> bool {
139        self.value == other.value
140    }
141}
142
143impl<T: Eq> Eq for BinaryOp<T> {}
144
145impl<T: Clone> BinaryOp<T> {
146    /// Build an empty channel with operation `op`.
147    pub fn new(op: fn(&T, &T) -> T) -> Self {
148        Self {
149            value: None,
150            op: Some(op),
151        }
152    }
153
154    /// Build pre-seeded with `initial`.
155    pub fn with_initial(op: fn(&T, &T) -> T, initial: T) -> Self {
156        Self {
157            value: Some(initial),
158            op: Some(op),
159        }
160    }
161
162    /// Re-attach the binary op after deserialization (the op itself is
163    /// `#[serde(skip)]`; persistence only stores the value).
164    pub fn rehydrate(mut self, op: fn(&T, &T) -> T) -> Self {
165        self.op = Some(op);
166        self
167    }
168
169    /// Fold `value` into the channel via the configured op.
170    pub fn write(&mut self, value: T) -> cognis_core::Result<()> {
171        let op = self.op.ok_or_else(|| {
172            cognis_core::CognisError::Internal(
173                "BinaryOp: write called before rehydrate (no op set)".into(),
174            )
175        })?;
176        self.value = Some(match self.value.as_ref() {
177            Some(existing) => op(existing, &value),
178            None => value,
179        });
180        Ok(())
181    }
182
183    /// Borrow the current accumulated value.
184    pub fn get(&self) -> Option<&T> {
185        self.value.as_ref()
186    }
187
188    /// Consume the channel and return the accumulated value.
189    pub fn take(&mut self) -> Option<T> {
190        self.value.take()
191    }
192}
193
194// ---------------------------------------------------------------------------
195// Broadcast
196// ---------------------------------------------------------------------------
197
198/// Multi-consumer broadcast queue (mirrors V1 `Broadcast` channel).
199///
200/// Producers call `send()`. Each consumer first calls `subscribe()` to
201/// obtain a cursor, then calls `read(cursor)` to receive every item
202/// produced since its last read. Items are kept in the buffer until **all**
203/// active subscribers have consumed them, at which point they're garbage
204/// collected on `gc()`.
205#[derive(Debug, Clone, Serialize, Deserialize)]
206pub struct Broadcast<T> {
207    /// All buffered items + their global sequence numbers.
208    items: Vec<(u64, T)>,
209    /// Per-subscriber high-water cursor (next sequence number to deliver).
210    cursors: HashMap<String, u64>,
211    /// Next sequence number to assign.
212    next_seq: u64,
213}
214
215impl<T> Default for Broadcast<T> {
216    fn default() -> Self {
217        Self {
218            items: Vec::new(),
219            cursors: HashMap::new(),
220            next_seq: 0,
221        }
222    }
223}
224
225impl<T: Clone> Broadcast<T> {
226    /// New empty broadcast.
227    pub fn new() -> Self {
228        Self::default()
229    }
230
231    /// Register a subscriber. The cursor's name should be unique per
232    /// consumer; subsequent calls with the same name are idempotent.
233    pub fn subscribe(&mut self, name: impl Into<String>) {
234        let name = name.into();
235        self.cursors.entry(name).or_insert(self.next_seq);
236    }
237
238    /// Drop a subscriber. Once dropped its cursor no longer prevents
239    /// garbage collection of older items.
240    pub fn unsubscribe(&mut self, name: &str) {
241        self.cursors.remove(name);
242    }
243
244    /// Append an item to the buffer. All current subscribers will receive it.
245    pub fn send(&mut self, value: T) {
246        self.items.push((self.next_seq, value));
247        self.next_seq += 1;
248    }
249
250    /// Read every item the given subscriber has not yet consumed.
251    /// Advances the subscriber's cursor.
252    pub fn read(&mut self, name: &str) -> Vec<T> {
253        let cursor = match self.cursors.get_mut(name) {
254            Some(c) => c,
255            None => return Vec::new(),
256        };
257        let out: Vec<T> = self
258            .items
259            .iter()
260            .filter(|(seq, _)| *seq >= *cursor)
261            .map(|(_, v)| v.clone())
262            .collect();
263        *cursor = self.next_seq;
264        out
265    }
266
267    /// Drop items every subscriber has already consumed. Safe to call
268    /// any time; cheap if there's nothing to evict.
269    pub fn gc(&mut self) {
270        if self.cursors.is_empty() {
271            // No subscribers — buffer can be cleared.
272            self.items.clear();
273            return;
274        }
275        let min_cursor = self
276            .cursors
277            .values()
278            .copied()
279            .min()
280            .unwrap_or(self.next_seq);
281        self.items.retain(|(seq, _)| *seq >= min_cursor);
282    }
283
284    /// Total buffered items not yet GC'd.
285    pub fn len(&self) -> usize {
286        self.items.len()
287    }
288
289    /// True if buffer is empty.
290    pub fn is_empty(&self) -> bool {
291        self.items.is_empty()
292    }
293}
294
295// ---------------------------------------------------------------------------
296// Untracked
297// ---------------------------------------------------------------------------
298
/// Wrapper whose contents are **never persisted**: the wrapped value is not
/// written into serialized state and therefore never reaches a checkpoint
/// (mirrors the V1 `Untracked` channel).
///
/// Good for in-memory caches, bulky intermediate results, or values whose
/// type isn't `Serialize` but that should still live inside a `GraphState`.
#[derive(Debug, Clone)]
pub struct Untracked<T> {
    /// The wrapped value; public so callers can read and write it directly.
    pub inner: T,
}

impl<T: Default> Default for Untracked<T> {
    fn default() -> Self {
        Self::new(T::default())
    }
}

impl<T> Untracked<T> {
    /// Wrap `value`, taking ownership of it.
    pub fn new(value: T) -> Self {
        Untracked { inner: value }
    }

    /// Unwrap, handing ownership of the inner value back to the caller.
    pub fn into_inner(self) -> T {
        let Untracked { inner } = self;
        inner
    }
}
330
331// `Untracked<T>` is intentionally not `Serialize` even when `T: Serialize`.
332// Users put it inside a `GraphState` struct and either annotate the field
333// `#[serde(skip)]` or use a serializer that tolerates skipped fields.
334// We provide marker impls below so it can sit inside a `Serialize`-derived
335// struct that uses `#[serde(skip)]`.
336impl<T> serde::Serialize for Untracked<T> {
337    fn serialize<S: serde::Serializer>(
338        &self,
339        serializer: S,
340    ) -> std::result::Result<S::Ok, S::Error> {
341        // Serializing should be skipped at the field level via
342        // `#[serde(skip)]`. Calling this directly produces a unit so
343        // tooling that ignores skip directives still gets a stable shape.
344        serializer.serialize_unit()
345    }
346}
347
348impl<'de, T: Default> serde::Deserialize<'de> for Untracked<T> {
349    fn deserialize<D: serde::Deserializer<'de>>(
350        deserializer: D,
351    ) -> std::result::Result<Self, D::Error> {
352        // Consume any payload (most commonly `null` written by our
353        // `Serialize` impl) and reconstruct the inner value from
354        // `Default`. We use `IgnoredAny` so any shape — null, missing,
355        // an arbitrary blob — is accepted.
356        serde::de::IgnoredAny::deserialize(deserializer)?;
357        Ok(Self::default())
358    }
359}
360
361// ---------------------------------------------------------------------------
362// CustomChannel — pluggable inline channel without subclassing.
363// ---------------------------------------------------------------------------
364
/// Boxed merge function used by [`CustomChannel`]. It receives
/// `(slot, incoming)` and folds `incoming` into `slot` in place.
pub type CustomMergeFn<T> = Box<dyn Fn(&mut T, T) + Send + Sync>;

/// Custom channel: holds a value of type `T` plus one user-supplied merge
/// closure that every [`write`](CustomChannel::write) goes through.
///
/// There are no separate read or reset hooks:
/// - read with [`get`](CustomChannel::get) /
///   [`get_mut`](CustomChannel::get_mut);
/// - reset with [`replace`](CustomChannel::replace).
///
/// The channel implements [`Channel`], reporting the caller-chosen label as
/// its `kind`, so it is registry-compatible alongside the stock channels.
///
/// Use this when none of the built-in channels match your semantics
/// but you don't want a whole new struct + trait impl.
pub struct CustomChannel<T> {
    /// Label reported as the channel's `kind`.
    label: &'static str,
    /// Current value that writes are merged into.
    value: T,
    /// Merge closure applied by `write`.
    on_write: CustomMergeFn<T>,
}

impl<T: Default> CustomChannel<T> {
    /// Build a channel whose value starts at `T::default()`. Every write is
    /// folded in by `on_write(slot, incoming)`.
    ///
    /// Use [`with_initial`](CustomChannel::with_initial) to start from a
    /// different value.
    pub fn new<F>(label: &'static str, on_write: F) -> Self
    where
        F: Fn(&mut T, T) + Send + Sync + 'static,
    {
        Self::with_initial(label, T::default(), on_write)
    }
}

impl<T> CustomChannel<T> {
    /// Build with an explicit initial value and a `write(slot, incoming)`
    /// merge function.
    pub fn with_initial<F>(label: &'static str, initial: T, on_write: F) -> Self
    where
        F: Fn(&mut T, T) + Send + Sync + 'static,
    {
        Self {
            label,
            value: initial,
            on_write: Box::new(on_write),
        }
    }

    /// Apply an incoming write through the user-defined merge fn.
    pub fn write(&mut self, value: T) {
        (self.on_write)(&mut self.value, value);
    }

    /// Borrow the current value.
    pub fn get(&self) -> &T {
        &self.value
    }

    /// Mutably borrow the current value, bypassing the merge fn.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.value
    }

    /// Replace the wrapped value, bypassing the merge fn, and return the
    /// previous one.
    pub fn replace(&mut self, new: T) -> T {
        std::mem::replace(&mut self.value, new)
    }
}
430
431impl<T: Send + Sync> Channel for CustomChannel<T> {
432    fn kind(&self) -> &'static str {
433        self.label
434    }
435}
436
437impl<T: std::fmt::Debug> std::fmt::Debug for CustomChannel<T> {
438    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439        f.debug_struct("CustomChannel")
440            .field("label", &self.label)
441            .field("value", &self.value)
442            .finish()
443    }
444}
445
446// ---------------------------------------------------------------------------
447// Type-erased channel reference (rarely needed, exposed for tooling).
448// ---------------------------------------------------------------------------
449
/// Common, object-safe interface implemented by every channel type.
///
/// It exposes only a kind label. That is enough for diagnostic tooling to
/// tell channels apart behind `dyn Channel` without knowing their concrete
/// types. `Send + Sync` is a supertrait, so `dyn Channel` values can be
/// shared across threads.
pub trait Channel: Send + Sync {
    /// Short label naming the channel kind, e.g. `"AnyValue"` or `"Topic"`.
    fn kind(&self) -> &'static str;
}
456
457impl<T: Send + Sync> Channel for AnyValue<T> {
458    fn kind(&self) -> &'static str {
459        "AnyValue"
460    }
461}
462impl<T: Send + Sync> Channel for Topic<T> {
463    fn kind(&self) -> &'static str {
464        "Topic"
465    }
466}
467impl<T: Send + Sync> Channel for BinaryOp<T> {
468    fn kind(&self) -> &'static str {
469        "BinaryOp"
470    }
471}
472impl<T: Send + Sync> Channel for Broadcast<T> {
473    fn kind(&self) -> &'static str {
474        "Broadcast"
475    }
476}
477impl<T: Send + Sync> Channel for Untracked<T> {
478    fn kind(&self) -> &'static str {
479        "Untracked"
480    }
481}
482
483// ---------------------------------------------------------------------------
484// `Arc<dyn Channel>` registry — used by tools that want a uniform handle.
485// ---------------------------------------------------------------------------
486
487/// Type-erased channel handle. Useful when multiple node implementations
488/// share the same channel and you want runtime polymorphism without
489/// generic plumbing.
490pub type ChannelRef = Arc<dyn Channel>;
491
/// Zero-sized marker parameterised by a type `T`.
///
/// `PhantomData<fn() -> T>` records `T` without storing or owning one, so
/// the tag is `Send`/`Sync` and covariant in `T` whatever `T` is.
/// NOTE(review): nothing in this file uses the tag — confirm its callers.
#[doc(hidden)]
pub struct _ChannelTag<T>(PhantomData<fn() -> T>);
494
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn any_value_set_and_take() {
        let mut a: AnyValue<i32> = AnyValue::new();
        assert!(a.is_empty());
        a.set(1);
        a.set(2);
        assert_eq!(a.get(), Some(&2));
        assert_eq!(a.take(), Some(2));
        assert!(a.is_empty());
    }

    #[test]
    fn any_value_round_trips_through_serde() {
        let a = AnyValue::with(5i32);
        let json = serde_json::to_string(&a).unwrap();
        let back: AnyValue<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(back, a);
    }

    #[test]
    fn topic_send_drain_round_trip() {
        let mut t: Topic<&'static str> = Topic::new();
        t.send("a");
        t.send("b");
        t.extend(["c", "d"]);
        assert_eq!(t.len(), 4);
        let drained = t.drain();
        assert_eq!(drained, vec!["a", "b", "c", "d"]);
        assert!(t.is_empty());
    }

    #[test]
    fn topic_peek_does_not_consume() {
        let mut t: Topic<i32> = Topic::new();
        t.send(1);
        t.send(2);
        assert_eq!(t.peek(), &[1, 2]);
        assert_eq!(t.len(), 2);
        assert_eq!(t.drain(), vec![1, 2]);
        assert!(t.peek().is_empty());
    }

    #[test]
    fn binary_op_folds_associatively() {
        let mut b: BinaryOp<i32> = BinaryOp::new(|a, b| a + b);
        b.write(1).unwrap();
        b.write(2).unwrap();
        b.write(3).unwrap();
        assert_eq!(b.get(), Some(&6));
    }

    #[test]
    fn binary_op_without_rehydrate_errors() {
        let mut b: BinaryOp<i32> = BinaryOp {
            value: None,
            op: None,
        };
        let err = b.write(1).unwrap_err();
        assert!(matches!(err, cognis_core::CognisError::Internal(_)));
    }

    #[test]
    fn binary_op_rehydrate_reattaches_op() {
        let b: BinaryOp<i32> = BinaryOp {
            value: Some(5),
            op: None,
        };
        let mut b = b.rehydrate(|a, b| a + b);
        b.write(2).unwrap();
        assert_eq!(b.get(), Some(&7));
    }

    #[test]
    fn binary_op_take_keeps_op() {
        let mut b: BinaryOp<i32> = BinaryOp::new(|a, b| a + b);
        b.write(2).unwrap();
        assert_eq!(b.take(), Some(2));
        assert!(b.get().is_none());
        // The op survives `take`, so writing again still works.
        b.write(5).unwrap();
        assert_eq!(b.get(), Some(&5));
    }

    #[test]
    fn binary_op_equality_ignores_op() {
        let sum: BinaryOp<i32> = BinaryOp::with_initial(|a, b| a + b, 3);
        let product: BinaryOp<i32> = BinaryOp::with_initial(|a, b| a * b, 3);
        assert!(sum == product);
    }

    #[test]
    fn binary_op_serde_drops_op_until_rehydrated() {
        let b: BinaryOp<i32> = BinaryOp::with_initial(|a, b| a + b, 4);
        let json = serde_json::to_string(&b).unwrap();
        let mut back: BinaryOp<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.get(), Some(&4));
        // The op is `#[serde(skip)]`, so writes fail until rehydration.
        assert!(back.write(1).is_err());
        let mut back = back.rehydrate(|a, b| a + b);
        back.write(1).unwrap();
        assert_eq!(back.get(), Some(&5));
    }

    #[test]
    fn broadcast_delivers_to_all_subscribers() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("a");
        b.subscribe("b");
        b.send(1);
        b.send(2);
        assert_eq!(b.read("a"), vec![1, 2]);
        assert_eq!(b.read("b"), vec![1, 2]);
        // Second read returns nothing for a until new sends happen.
        assert!(b.read("a").is_empty());
        b.send(3);
        assert_eq!(b.read("a"), vec![3]);
        assert_eq!(b.read("b"), vec![3]);
    }

    #[test]
    fn broadcast_gc_drops_consumed_items() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("only");
        b.send(1);
        b.send(2);
        let _ = b.read("only");
        b.gc();
        assert_eq!(b.len(), 0);
    }

    #[test]
    fn broadcast_gc_keeps_items_for_lagging_subscriber() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("fast");
        b.subscribe("slow");
        b.send(1);
        b.send(2);
        assert_eq!(b.read("fast"), vec![1, 2]);
        b.gc();
        assert_eq!(b.len(), 2);
        assert_eq!(b.read("slow"), vec![1, 2]);
        b.gc();
        assert!(b.is_empty());
    }

    #[test]
    fn broadcast_late_subscriber_skips_history() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.send(1);
        b.subscribe("late");
        assert!(b.read("late").is_empty());
        b.send(2);
        assert_eq!(b.read("late"), vec![2]);
    }

    #[test]
    fn broadcast_unsubscribe_releases_items() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("a");
        b.subscribe("b");
        b.send(1);
        let _ = b.read("a");
        b.gc();
        assert_eq!(b.len(), 1);
        b.unsubscribe("b");
        b.gc();
        assert!(b.is_empty());
    }

    #[test]
    fn broadcast_unknown_subscriber_reads_empty() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.send(1);
        assert!(b.read("ghost").is_empty());
    }

    #[test]
    fn untracked_round_trips_through_serde_to_default() {
        let u = Untracked::new(42i32);
        let json = serde_json::to_string(&u).unwrap();
        // Serialized as unit so it occupies no payload space.
        assert_eq!(json, "null");
        // Deserializing reconstructs Default::default().
        let back: Untracked<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.inner, 0);
    }

    #[test]
    fn channel_kind_strings() {
        let a: AnyValue<i32> = AnyValue::new();
        let t: Topic<i32> = Topic::new();
        let b: BinaryOp<i32> = BinaryOp::new(|a, b| a + b);
        let bc: Broadcast<i32> = Broadcast::new();
        let u: Untracked<i32> = Untracked::default();
        assert_eq!(a.kind(), "AnyValue");
        assert_eq!(t.kind(), "Topic");
        assert_eq!(b.kind(), "BinaryOp");
        assert_eq!(bc.kind(), "Broadcast");
        assert_eq!(u.kind(), "Untracked");
    }

    #[test]
    fn channel_ref_erases_concrete_type() {
        let a: ChannelRef = std::sync::Arc::new(AnyValue::<i32>::new());
        let t: ChannelRef = std::sync::Arc::new(Topic::<i32>::new());
        assert_eq!(a.kind(), "AnyValue");
        assert_eq!(t.kind(), "Topic");
    }

    #[test]
    fn custom_channel_applies_user_merge() {
        // A "running max" channel.
        let mut c: CustomChannel<i32> = CustomChannel::new("Max", |slot, incoming| {
            if incoming > *slot {
                *slot = incoming;
            }
        });
        c.write(3);
        c.write(1);
        c.write(7);
        c.write(5);
        assert_eq!(*c.get(), 7);
        assert_eq!(c.kind(), "Max");
    }

    #[test]
    fn custom_channel_with_initial_seeds_value() {
        let mut c: CustomChannel<Vec<i32>> =
            CustomChannel::with_initial("Concat", vec![1, 2], |slot, incoming| {
                slot.extend(incoming);
            });
        c.write(vec![3, 4]);
        assert_eq!(c.get(), &vec![1, 2, 3, 4]);
    }

    #[test]
    fn custom_channel_replace_and_get_mut() {
        let mut c: CustomChannel<i32> =
            CustomChannel::with_initial("Last", 1, |slot: &mut i32, incoming: i32| {
                *slot = incoming;
            });
        assert_eq!(c.replace(10), 1);
        *c.get_mut() += 1;
        assert_eq!(*c.get(), 11);
        c.write(4);
        assert_eq!(*c.get(), 4);
    }
}