use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
/// A single-slot channel: holds at most one value, and each `set` overwrites the previous one.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct AnyValue<T> {
    inner: Option<T>,
}

// Written by hand instead of derived: `#[derive(Default)]` would add a `T: Default` bound,
// although an empty slot (`None`) needs nothing from `T`. Same approach as `Broadcast`.
impl<T> Default for AnyValue<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> AnyValue<T> {
    /// Creates an empty slot.
    pub fn new() -> Self {
        Self { inner: None }
    }

    /// Creates a slot already holding `value`.
    pub fn with(value: T) -> Self {
        Self { inner: Some(value) }
    }

    /// Stores `value`, discarding any previously stored value.
    pub fn set(&mut self, value: T) {
        self.inner = Some(value);
    }

    /// Borrows the current value, if any.
    pub fn get(&self) -> Option<&T> {
        self.inner.as_ref()
    }

    /// Moves the current value out, leaving the slot empty.
    pub fn take(&mut self) -> Option<T> {
        self.inner.take()
    }

    /// Returns `true` when no value is stored.
    pub fn is_empty(&self) -> bool {
        self.inner.is_none()
    }
}
/// An append-only queue channel: values accumulate in send order until drained.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Topic<T> {
    queue: Vec<T>,
}

// Written by hand instead of derived: `#[derive(Default)]` would add a `T: Default` bound,
// although an empty queue needs nothing from `T`. Same approach as `Broadcast`.
impl<T> Default for Topic<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> Topic<T> {
    /// Creates an empty topic.
    pub fn new() -> Self {
        Self { queue: Vec::new() }
    }

    /// Appends a single value.
    pub fn send(&mut self, value: T) {
        self.queue.push(value);
    }

    /// Appends every value yielded by `values`, in iteration order.
    pub fn extend<I: IntoIterator<Item = T>>(&mut self, values: I) {
        self.queue.extend(values);
    }

    /// Removes and returns all queued values (oldest first), leaving the topic empty.
    pub fn drain(&mut self) -> Vec<T> {
        std::mem::take(&mut self.queue)
    }

    /// Borrows the queued values without removing them.
    pub fn peek(&self) -> &[T] {
        &self.queue
    }

    /// Number of queued values.
    pub fn len(&self) -> usize {
        self.queue.len()
    }

    /// Returns `true` when nothing is queued.
    pub fn is_empty(&self) -> bool {
        self.queue.is_empty()
    }
}
/// A reducing channel: each write is folded into the stored value with a binary operator.
///
/// The operator is a function pointer and is skipped by serde. A deserialized (or
/// `Default`-constructed) `BinaryOp` therefore carries no operator until
/// [`BinaryOp::rehydrate`] reattaches one; until then [`BinaryOp::write`] returns an error.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BinaryOp<T> {
    value: Option<T>,
    #[serde(skip)]
    op: Option<fn(&T, &T) -> T>,
}

// Equality looks only at the accumulated value. The operator is not part of the
// serialized state, so two channels with equal values compare equal regardless of `op`.
impl<T: PartialEq> PartialEq for BinaryOp<T> {
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}

impl<T: Eq> Eq for BinaryOp<T> {}

// No `T: Clone` bound is needed: nothing below clones a `T`. The operator borrows both
// inputs and returns a fresh value, so these methods work for any `T`.
impl<T> BinaryOp<T> {
    /// Creates an empty channel that folds writes with `op`.
    pub fn new(op: fn(&T, &T) -> T) -> Self {
        Self {
            value: None,
            op: Some(op),
        }
    }

    /// Creates a channel seeded with `initial` that folds writes with `op`.
    pub fn with_initial(op: fn(&T, &T) -> T, initial: T) -> Self {
        Self {
            value: Some(initial),
            op: Some(op),
        }
    }

    /// Reattaches the operator (e.g. after deserialization), keeping the stored value.
    pub fn rehydrate(mut self, op: fn(&T, &T) -> T) -> Self {
        self.op = Some(op);
        self
    }

    /// Folds `value` into the stored value.
    ///
    /// The first write is stored as-is. Later writes replace the stored value with
    /// `op(&existing, &value)`.
    ///
    /// # Errors
    ///
    /// Returns `CognisError::Internal` if no operator is attached. In that case the stored
    /// value is left unchanged.
    pub fn write(&mut self, value: T) -> cognis_core::Result<()> {
        let op = self.op.ok_or_else(|| {
            cognis_core::CognisError::Internal(
                "BinaryOp: write called before rehydrate (no op set)".into(),
            )
        })?;
        self.value = Some(match self.value.as_ref() {
            Some(existing) => op(existing, &value),
            None => value,
        });
        Ok(())
    }

    /// Borrows the accumulated value, if any.
    pub fn get(&self) -> Option<&T> {
        self.value.as_ref()
    }

    /// Moves the accumulated value out, leaving the channel empty (the operator is kept).
    pub fn take(&mut self) -> Option<T> {
        self.value.take()
    }
}
/// A fan-out channel: each message is handed to every subscriber on that subscriber's next
/// [`Broadcast::read`].
///
/// Messages are tagged with increasing sequence numbers. Each subscriber holds a cursor: the
/// first sequence number it has not read yet. A subscriber registered after a message was
/// sent never receives that message. Read messages stay retained until [`Broadcast::gc`]
/// drops the ones every current subscriber has already consumed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Broadcast<T> {
    /// `(sequence number, message)` pairs still retained, in send order.
    items: Vec<(u64, T)>,
    /// Subscriber name -> first sequence number not yet delivered to that subscriber.
    cursors: HashMap<String, u64>,
    /// Sequence number the next `send` will assign.
    next_seq: u64,
}

impl<T> Default for Broadcast<T> {
    fn default() -> Self {
        Self {
            items: Vec::new(),
            cursors: HashMap::new(),
            next_seq: 0,
        }
    }
}

// Bookkeeping methods: none of these copy a message, so they carry no `T: Clone` bound.
impl<T> Broadcast<T> {
    /// Creates an empty broadcast with no subscribers.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `name` to receive every message sent from now on.
    ///
    /// Subscribing a name that is already registered leaves its cursor unchanged.
    pub fn subscribe(&mut self, name: impl Into<String>) {
        self.cursors.entry(name.into()).or_insert(self.next_seq);
    }

    /// Removes subscriber `name`; unknown names are ignored.
    ///
    /// Its unread messages are no longer kept alive on its behalf by `gc`.
    pub fn unsubscribe(&mut self, name: &str) {
        self.cursors.remove(name);
    }

    /// Appends `value`, tagging it with the next sequence number.
    pub fn send(&mut self, value: T) {
        self.items.push((self.next_seq, value));
        self.next_seq += 1;
    }

    /// Drops retained messages that every current subscriber has already read.
    ///
    /// With no subscribers at all, every retained message is dropped.
    pub fn gc(&mut self) {
        match self.cursors.values().copied().min() {
            // Anything older than the slowest subscriber's cursor has been read by everyone.
            Some(min_cursor) => self.items.retain(|(seq, _)| *seq >= min_cursor),
            None => self.items.clear(),
        }
    }

    /// Number of retained messages, whether already read or not.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` when no messages are retained.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }
}

impl<T: Clone> Broadcast<T> {
    /// Returns clones of every retained message `name` has not read yet, in send order,
    /// and marks them as read for `name`.
    ///
    /// Returns an empty `Vec` if `name` is not subscribed.
    pub fn read(&mut self, name: &str) -> Vec<T> {
        let cursor = match self.cursors.get_mut(name) {
            Some(cursor) => cursor,
            None => return Vec::new(),
        };
        let unread: Vec<T> = self
            .items
            .iter()
            .filter(|(seq, _)| *seq >= *cursor)
            .map(|(_, value)| value.clone())
            .collect();
        // Everything sent so far has now been delivered to this subscriber.
        *cursor = self.next_seq;
        unread
    }
}
#[derive(Debug, Clone)]
pub struct Untracked<T> {
pub inner: T,
}
impl<T: Default> Default for Untracked<T> {
fn default() -> Self {
Self {
inner: T::default(),
}
}
}
impl<T> Untracked<T> {
pub fn new(value: T) -> Self {
Self { inner: value }
}
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T> serde::Serialize for Untracked<T> {
fn serialize<S: serde::Serializer>(
&self,
serializer: S,
) -> std::result::Result<S::Ok, S::Error> {
serializer.serialize_unit()
}
}
impl<'de, T: Default> serde::Deserialize<'de> for Untracked<T> {
fn deserialize<D: serde::Deserializer<'de>>(
deserializer: D,
) -> std::result::Result<Self, D::Error> {
serde::de::IgnoredAny::deserialize(deserializer)?;
Ok(Self::default())
}
}
/// Boxed merge callback used by [`CustomChannel`]: folds an incoming value into the current
/// one in place.
pub type CustomMergeFn<T> = Box<dyn Fn(&mut T, T) + Send + Sync>;

/// A channel whose write behaviour is supplied by the caller as a merge closure.
pub struct CustomChannel<T> {
    // Caller-chosen name for this channel.
    label: &'static str,
    // Current value; every write is merged into it.
    value: T,
    // Called as `on_write(&mut value, incoming)` on each `write`.
    on_write: CustomMergeFn<T>,
}

impl<T: Default> CustomChannel<T> {
    /// Creates a channel labelled `label` that starts from `T::default()` and merges writes
    /// with `on_write`.
    pub fn new<F>(label: &'static str, on_write: F) -> Self
    where
        F: Fn(&mut T, T) + Send + Sync + 'static,
    {
        Self::with_initial(label, T::default(), on_write)
    }
}

impl<T> CustomChannel<T> {
    /// Creates a channel labelled `label` that starts from `initial` and merges writes with
    /// `on_write`.
    pub fn with_initial<F>(label: &'static str, initial: T, on_write: F) -> Self
    where
        F: Fn(&mut T, T) + Send + Sync + 'static,
    {
        let on_write: CustomMergeFn<T> = Box::new(on_write);
        CustomChannel {
            label,
            value: initial,
            on_write,
        }
    }

    /// Merges `value` into the current value through the stored callback.
    pub fn write(&mut self, value: T) {
        let merge = &self.on_write;
        merge(&mut self.value, value);
    }

    /// Borrows the current value.
    pub fn get(&self) -> &T {
        &self.value
    }

    /// Mutably borrows the current value; changes made this way bypass the merge callback.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.value
    }

    /// Stores `new` without invoking the merge callback and returns the previous value.
    pub fn replace(&mut self, new: T) -> T {
        std::mem::replace(&mut self.value, new)
    }
}
// A custom channel reports its caller-chosen label as its kind.
impl<T> Channel for CustomChannel<T>
where
    T: Send + Sync,
{
    fn kind(&self) -> &'static str {
        let Self { label, .. } = *self;
        label
    }
}
// Prints the label and current value. The boxed merge closure has no `Debug` impl, so it is
// left out.
impl<T> std::fmt::Debug for CustomChannel<T>
where
    T: std::fmt::Debug,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CustomChannel");
        builder.field("label", &self.label);
        builder.field("value", &self.value);
        builder.finish()
    }
}
/// Common interface implemented by every channel type; usable as a trait object.
///
/// Implementors must be `Send + Sync` (declared here as a supertrait bound on `Self`).
pub trait Channel
where
    Self: Send + Sync,
{
    /// Short static name identifying the channel's kind.
    fn kind(&self) -> &'static str;
}
impl<T: Send + Sync> Channel for AnyValue<T> {
fn kind(&self) -> &'static str {
"AnyValue"
}
}
impl<T: Send + Sync> Channel for Topic<T> {
fn kind(&self) -> &'static str {
"Topic"
}
}
impl<T: Send + Sync> Channel for BinaryOp<T> {
fn kind(&self) -> &'static str {
"BinaryOp"
}
}
impl<T: Send + Sync> Channel for Broadcast<T> {
fn kind(&self) -> &'static str {
"Broadcast"
}
}
impl<T: Send + Sync> Channel for Untracked<T> {
fn kind(&self) -> &'static str {
"Untracked"
}
}
/// Shared, type-erased handle to any [`Channel`].
///
/// `'static` is already the default object bound for `Arc<dyn Channel>`; it is only spelled
/// out here.
pub type ChannelRef = Arc<dyn Channel + 'static>;

/// Zero-sized tag that names a channel value type `T` without storing a `T`.
///
/// `PhantomData<fn() -> T>` keeps the tag covariant in `T` and `Send + Sync` whatever `T` is.
#[doc(hidden)]
pub struct _ChannelTag<T>(PhantomData<fn() -> T>);
#[cfg(test)]
mod tests {
    // Unit tests for every channel type. Each test documents one piece of observable
    // behaviour: delivery, folding, garbage collection, and serde round-trips.
    use super::*;

    #[test]
    fn any_value_set_and_take() {
        let mut a: AnyValue<i32> = AnyValue::new();
        assert!(a.is_empty());
        a.set(1);
        a.set(2);
        assert_eq!(a.get(), Some(&2));
        assert_eq!(a.take(), Some(2));
        assert!(a.is_empty());
    }

    #[test]
    fn any_value_with_starts_populated_and_default_is_empty() {
        let a = AnyValue::with(5i32);
        assert!(!a.is_empty());
        assert_eq!(a.get(), Some(&5));
        let d: AnyValue<i32> = AnyValue::default();
        assert!(d.is_empty());
    }

    #[test]
    fn topic_send_drain_round_trip() {
        let mut t: Topic<&'static str> = Topic::new();
        t.send("a");
        t.send("b");
        t.extend(["c", "d"]);
        assert_eq!(t.len(), 4);
        let drained = t.drain();
        assert_eq!(drained, vec!["a", "b", "c", "d"]);
        assert!(t.is_empty());
    }

    #[test]
    fn topic_peek_does_not_consume() {
        let mut t: Topic<i32> = Topic::new();
        t.send(1);
        t.send(2);
        assert_eq!(t.peek(), &[1, 2]);
        assert_eq!(t.len(), 2);
        assert_eq!(t.drain(), vec![1, 2]);
        assert!(t.drain().is_empty());
    }

    #[test]
    fn binary_op_folds_associatively() {
        let mut b: BinaryOp<i32> = BinaryOp::new(|a, b| a + b);
        b.write(1).unwrap();
        b.write(2).unwrap();
        b.write(3).unwrap();
        assert_eq!(b.get(), Some(&6));
    }

    #[test]
    fn binary_op_with_initial_folds_into_seed() {
        let mut b: BinaryOp<i32> = BinaryOp::with_initial(|a, b| a * b, 2);
        b.write(3).unwrap();
        b.write(4).unwrap();
        assert_eq!(b.get(), Some(&24));
        assert_eq!(b.take(), Some(24));
        assert_eq!(b.get(), None);
    }

    #[test]
    fn binary_op_without_rehydrate_errors() {
        let mut b: BinaryOp<i32> = BinaryOp {
            value: None,
            op: None,
        };
        let err = b.write(1).unwrap_err();
        assert!(matches!(err, cognis_core::CognisError::Internal(_)));
    }

    #[test]
    fn binary_op_rehydrate_reattaches_op() {
        let b: BinaryOp<i32> = BinaryOp {
            value: Some(5),
            op: None,
        };
        let mut b = b.rehydrate(|a, b| a + b);
        b.write(2).unwrap();
        assert_eq!(b.get(), Some(&7));
    }

    #[test]
    fn binary_op_serde_drops_op_until_rehydrated() {
        let b: BinaryOp<i32> = BinaryOp::with_initial(|a, b| a + b, 5);
        let json = serde_json::to_string(&b).unwrap();
        let mut restored: BinaryOp<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(restored, b);
        // The operator is not persisted, so writes fail (and leave the value alone)...
        assert!(restored.write(1).is_err());
        // ...until it is reattached.
        let mut restored = restored.rehydrate(|a, b| a + b);
        restored.write(1).unwrap();
        assert_eq!(restored.get(), Some(&6));
    }

    #[test]
    fn broadcast_delivers_to_all_subscribers() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("a");
        b.subscribe("b");
        b.send(1);
        b.send(2);
        assert_eq!(b.read("a"), vec![1, 2]);
        assert_eq!(b.read("b"), vec![1, 2]);
        assert!(b.read("a").is_empty());
        b.send(3);
        assert_eq!(b.read("a"), vec![3]);
        assert_eq!(b.read("b"), vec![3]);
    }

    #[test]
    fn broadcast_late_subscriber_only_sees_later_messages() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.send(1);
        b.subscribe("late");
        b.send(2);
        assert_eq!(b.read("late"), vec![2]);
    }

    #[test]
    fn broadcast_resubscribe_keeps_cursor() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("a");
        b.send(1);
        b.subscribe("a");
        assert_eq!(b.read("a"), vec![1]);
    }

    #[test]
    fn broadcast_gc_drops_consumed_items() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("only");
        b.send(1);
        b.send(2);
        let _ = b.read("only");
        b.gc();
        assert_eq!(b.len(), 0);
    }

    #[test]
    fn broadcast_gc_retains_items_for_lagging_subscriber() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("fast");
        b.subscribe("slow");
        b.send(1);
        b.send(2);
        assert_eq!(b.read("fast"), vec![1, 2]);
        b.gc();
        assert_eq!(b.len(), 2);
        assert_eq!(b.read("slow"), vec![1, 2]);
        b.gc();
        assert!(b.is_empty());
    }

    #[test]
    fn broadcast_gc_without_subscribers_clears_everything() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.send(1);
        b.send(2);
        b.gc();
        assert!(b.is_empty());
    }

    #[test]
    fn broadcast_unsubscribe_stops_delivery() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("a");
        b.send(1);
        b.unsubscribe("a");
        assert!(b.read("a").is_empty());
    }

    #[test]
    fn broadcast_unknown_subscriber_reads_empty() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.send(1);
        assert!(b.read("ghost").is_empty());
    }

    #[test]
    fn broadcast_serde_round_trip_preserves_cursors() {
        let mut b: Broadcast<i32> = Broadcast::new();
        b.subscribe("a");
        b.send(1);
        let json = serde_json::to_string(&b).unwrap();
        let mut restored: Broadcast<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.read("a"), vec![1]);
        restored.send(2);
        assert_eq!(restored.read("a"), vec![2]);
    }

    #[test]
    fn untracked_round_trips_through_serde_to_default() {
        let u = Untracked::new(42i32);
        let json = serde_json::to_string(&u).unwrap();
        assert_eq!(json, "null");
        let back: Untracked<i32> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.inner, 0);
    }

    #[test]
    fn untracked_into_inner_returns_wrapped_value() {
        let u = Untracked::new(String::from("cache"));
        assert_eq!(u.into_inner(), "cache");
    }

    #[test]
    fn channel_kind_strings() {
        let a: AnyValue<i32> = AnyValue::new();
        let t: Topic<i32> = Topic::new();
        let b: BinaryOp<i32> = BinaryOp::new(|a, b| a + b);
        let bc: Broadcast<i32> = Broadcast::new();
        let u: Untracked<i32> = Untracked::default();
        assert_eq!(a.kind(), "AnyValue");
        assert_eq!(t.kind(), "Topic");
        assert_eq!(b.kind(), "BinaryOp");
        assert_eq!(bc.kind(), "Broadcast");
        assert_eq!(u.kind(), "Untracked");
    }

    #[test]
    fn channels_work_as_trait_objects() {
        let topic: ChannelRef = std::sync::Arc::new(Topic::<i32>::new());
        let custom: ChannelRef =
            std::sync::Arc::new(CustomChannel::<i32>::new("Last", |slot, v| *slot = v));
        assert_eq!(topic.kind(), "Topic");
        assert_eq!(custom.kind(), "Last");
    }

    #[test]
    fn custom_channel_applies_user_merge() {
        let mut c: CustomChannel<i32> = CustomChannel::new("Max", |slot, incoming| {
            if incoming > *slot {
                *slot = incoming;
            }
        });
        c.write(3);
        c.write(1);
        c.write(7);
        c.write(5);
        assert_eq!(*c.get(), 7);
        assert_eq!(c.kind(), "Max");
    }

    #[test]
    fn custom_channel_with_initial_seeds_value() {
        let mut c: CustomChannel<Vec<i32>> =
            CustomChannel::with_initial("Concat", vec![1, 2], |slot, incoming| {
                slot.extend(incoming);
            });
        c.write(vec![3, 4]);
        assert_eq!(c.get(), &vec![1, 2, 3, 4]);
    }

    #[test]
    fn custom_channel_get_mut_and_replace_bypass_merge() {
        let mut c: CustomChannel<i32> =
            CustomChannel::new("Sum", |slot, incoming| *slot += incoming);
        c.write(2);
        *c.get_mut() += 10;
        assert_eq!(c.replace(0), 12);
        c.write(1);
        assert_eq!(*c.get(), 1);
    }

    #[test]
    fn custom_channel_debug_shows_label_and_value() {
        let c: CustomChannel<i32> = CustomChannel::with_initial("Max", 3, |_, _| {});
        let rendered = format!("{:?}", c);
        assert!(rendered.contains("Max"));
        assert!(rendered.contains('3'));
    }
}