mod types;
pub use types::{
Barrier, BinaryAggregate, Channel, ChannelSet, ChannelState, ChannelUpdate, Delta, Ephemeral,
LastValue, Messages, NamedBarrier, Topic, Untracked,
};
use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;
use serde_json::Value;
use crate::graph::reducer::StateReducer;
use crate::{Result, TinyAgentsError};
/// Overwrite semantics: every write replaces whatever the channel held before.
impl Channel for LastValue {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "last_value"
    }

    /// Discards the previous value and stores `incoming` unchanged; never fails.
    fn merge(&self, _current: Option<&Value>, incoming: Value) -> Result<Value> {
        let replacement = incoming;
        Ok(replacement)
    }

    /// `LastValue` is `Copy`, so a plain copy of `self` is boxed.
    fn clone_box(&self) -> Box<dyn Channel> {
        let copy = *self;
        Box::new(copy)
    }
}
/// Append-only list channel: writes accumulate instead of replacing each other.
impl Channel for Topic {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "topic"
    }

    /// Appends `incoming` to the accumulated list and returns the new list.
    ///
    /// An array write contributes each of its elements. Any other value is
    /// appended as a single element. A non-array current value is treated as
    /// a one-element list. Never fails.
    fn merge(&self, current: Option<&Value>, incoming: Value) -> Result<Value> {
        let mut entries = current.map_or_else(Vec::new, |existing| match existing {
            Value::Array(items) => items.clone(),
            other => vec![other.clone()],
        });
        let additions = match incoming {
            Value::Array(batch) => batch,
            single => vec![single],
        };
        entries.extend(additions);
        Ok(Value::Array(entries))
    }

    /// Several writers may target this channel within one step.
    fn allows_concurrent(&self) -> bool {
        true
    }

    /// `Topic` is `Copy`; box a copy of `self`.
    fn clone_box(&self) -> Box<dyn Channel> {
        Box::new(*self)
    }
}
/// Numeric accumulator: each write is added to the running total.
impl Channel for Delta {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "delta"
    }

    /// Adds `incoming` to the current numeric value.
    ///
    /// The first write (no current value) is checked to be numeric and then
    /// stored unchanged. Later writes are summed as follows:
    /// - exactly as `i64` when both operands and the result fit;
    /// - otherwise exactly as `u64` when both operands and the result fit;
    /// - otherwise as `f64`.
    ///
    /// # Errors
    ///
    /// Returns `TinyAgentsError::Graph` in two cases:
    /// - either operand is not a JSON number;
    /// - a floating-point sum is not finite. JSON cannot represent infinities,
    ///   and `Value::from(f64)` would silently store `null`.
    fn merge(&self, current: Option<&Value>, incoming: Value) -> Result<Value> {
        let non_numeric =
            || TinyAgentsError::Graph("Delta channel only accepts numeric writes".to_string());
        let incoming_num = incoming.as_f64().ok_or_else(non_numeric)?;
        let Some(current) = current else {
            return Ok(incoming);
        };
        let current_num = current.as_f64().ok_or_else(non_numeric)?;

        // Prefer exact integer arithmetic. `checked_add` avoids the debug-mode
        // panic / release-mode wraparound that a plain `+` produces on overflow.
        if let (Some(a), Some(b)) = (current.as_i64(), incoming.as_i64()) {
            if let Some(sum) = a.checked_add(b) {
                return Ok(Value::from(sum));
            }
        }
        // Non-negative operands that overflow (or exceed) `i64` may still fit `u64`.
        if let (Some(a), Some(b)) = (current.as_u64(), incoming.as_u64()) {
            if let Some(sum) = a.checked_add(b) {
                return Ok(Value::from(sum));
            }
        }

        let sum = current_num + incoming_num;
        if !sum.is_finite() {
            return Err(TinyAgentsError::Graph(
                "Delta channel sum overflowed the range of a JSON number".to_string(),
            ));
        }
        Ok(Value::from(sum))
    }

    /// Several writers may add to this channel within one step.
    fn allows_concurrent(&self) -> bool {
        true
    }

    /// `Delta` is `Copy`; box a copy of `self`.
    fn clone_box(&self) -> Box<dyn Channel> {
        Box::new(*self)
    }
}
/// Message-history channel with upsert-by-id semantics.
impl Channel for Messages {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "messages"
    }

    /// Merges `incoming` (one message or an array of messages) into the history.
    ///
    /// Each incoming message whose `"id"` field is a string replaces, in place,
    /// the first existing message with the same string `"id"`. Otherwise it is
    /// appended. Messages are processed in order, so later duplicates win.
    ///
    /// # Errors
    ///
    /// Returns `TinyAgentsError::Graph` if the current value exists but is
    /// not a JSON array.
    fn merge(&self, current: Option<&Value>, incoming: Value) -> Result<Value> {
        let mut history: Vec<Value> = match current {
            None => Vec::new(),
            Some(Value::Array(existing)) => existing.clone(),
            Some(_) => {
                return Err(TinyAgentsError::Graph(
                    "Messages channel value must be a JSON array".to_string(),
                ));
            }
        };
        let batch = match incoming {
            Value::Array(items) => items,
            single => vec![single],
        };
        for message in batch {
            // Locate the slot to overwrite before moving `message` into the list.
            let slot = message.get("id").and_then(Value::as_str).and_then(|id| {
                history
                    .iter()
                    .position(|existing| existing.get("id").and_then(Value::as_str) == Some(id))
            });
            match slot {
                Some(index) => history[index] = message,
                None => history.push(message),
            }
        }
        Ok(Value::Array(history))
    }

    /// Several writers may append to this channel within one step.
    fn allows_concurrent(&self) -> bool {
        true
    }

    /// `Messages` is `Copy`; box a copy of `self`.
    fn clone_box(&self) -> Box<dyn Channel> {
        Box::new(*self)
    }
}
/// Overwrite channel whose value only lives for the current step.
impl Channel for Ephemeral {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "ephemeral"
    }

    /// Replaces the previous value with `incoming`; never fails.
    fn merge(&self, _current: Option<&Value>, incoming: Value) -> Result<Value> {
        let replacement = incoming;
        Ok(replacement)
    }

    /// Flags this channel so `ChannelSet::clear_ephemeral` drops its value.
    /// `ChannelState::merge` calls that when a new step begins.
    fn is_ephemeral(&self) -> bool {
        true
    }

    /// `Ephemeral` is `Copy`; box a copy of `self`.
    fn clone_box(&self) -> Box<dyn Channel> {
        let copy = *self;
        Box::new(copy)
    }
}
/// Overwrite channel that is hidden from snapshots.
impl Channel for Untracked {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "untracked"
    }

    /// Replaces the previous value with `incoming`; never fails.
    fn merge(&self, _current: Option<&Value>, incoming: Value) -> Result<Value> {
        let replacement = incoming;
        Ok(replacement)
    }

    /// Returning `false` excludes this channel's value from `ChannelSet::snapshot`.
    fn is_tracked(&self) -> bool {
        false
    }

    /// `Untracked` is `Copy`; box a copy of `self`.
    fn clone_box(&self) -> Box<dyn Channel> {
        let copy = *self;
        Box::new(copy)
    }
}
impl Barrier {
pub fn new(expected: usize) -> Self {
Self { expected }
}
}
/// Counting barrier: collects arrivals into a list and becomes ready once
/// enough of them have arrived.
impl Channel for Barrier {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "barrier"
    }

    /// Records `incoming` as one arrival, or as several if it is an array.
    ///
    /// A non-array current value is kept as the first list element. Never fails.
    fn merge(&self, current: Option<&Value>, incoming: Value) -> Result<Value> {
        let mut arrivals: Vec<Value> = Vec::new();
        match current {
            Some(Value::Array(previous)) => arrivals.extend(previous.iter().cloned()),
            Some(previous) => arrivals.push(previous.clone()),
            None => {}
        }
        match incoming {
            Value::Array(batch) => arrivals.extend(batch),
            single => arrivals.push(single),
        }
        Ok(Value::Array(arrivals))
    }

    /// Several writers may arrive within one step.
    fn allows_concurrent(&self) -> bool {
        true
    }

    /// Ready when the recorded arrival list has at least `expected` entries.
    /// With no list recorded, ready only if `expected` is zero.
    fn is_ready(&self, current: Option<&Value>) -> bool {
        match current.and_then(Value::as_array) {
            Some(arrivals) => arrivals.len() >= self.expected,
            None => self.expected == 0,
        }
    }

    /// `Barrier` is `Copy`; box a copy of `self`.
    fn clone_box(&self) -> Box<dyn Channel> {
        Box::new(*self)
    }
}
impl NamedBarrier {
    /// Creates a barrier that reports ready once every name in `expected` has
    /// been recorded as an arrival key.
    pub fn new(expected: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            expected: expected
                .into_iter()
                .map(|name| -> String { name.into() })
                .collect(),
        }
    }
}
/// Named barrier: arrivals are keyed by name in a JSON object, and the
/// barrier is ready once every expected name is present.
impl Channel for NamedBarrier {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "named_barrier"
    }

    /// Merges the named arrivals in `incoming` into the recorded ones.
    ///
    /// A name that arrives again overwrites its previous payload.
    ///
    /// # Errors
    ///
    /// Returns `TinyAgentsError::Graph` in two cases, checked in this order:
    /// - the current value exists but is not a JSON object;
    /// - `incoming` is not a JSON object.
    fn merge(&self, current: Option<&Value>, incoming: Value) -> Result<Value> {
        let mut arrived = match current {
            None => serde_json::Map::new(),
            Some(Value::Object(existing)) => existing.clone(),
            Some(_) => {
                return Err(TinyAgentsError::Graph(
                    "NamedBarrier channel value must be a JSON object".to_string(),
                ));
            }
        };
        let Value::Object(fresh) = incoming else {
            return Err(TinyAgentsError::Graph(
                "NamedBarrier writes must be JSON objects of named arrivals".to_string(),
            ));
        };
        // `Map::extend` inserts each pair, overwriting existing keys.
        arrived.extend(fresh);
        Ok(Value::Object(arrived))
    }

    /// Several writers may arrive within one step.
    fn allows_concurrent(&self) -> bool {
        true
    }

    /// Ready when every expected name is a key of the recorded object.
    /// Without a recorded object, ready only if nothing is expected.
    fn is_ready(&self, current: Option<&Value>) -> bool {
        match current {
            Some(Value::Object(arrived)) => {
                self.expected.iter().all(|name| arrived.contains_key(name))
            }
            _ => self.expected.is_empty(),
        }
    }

    /// Boxes a deep clone of this barrier, including its expected names.
    fn clone_box(&self) -> Box<dyn Channel> {
        let duplicate = self.clone();
        Box::new(duplicate)
    }
}
impl BinaryAggregate {
    /// Builds an aggregate channel from a binary fold `fold(current, incoming)`.
    ///
    /// The fold is stored behind an `Arc`, so clones of the channel share it.
    pub fn new<F>(fold: F) -> Self
    where
        F: Fn(Value, Value) -> Result<Value> + Send + Sync + 'static,
    {
        Self {
            fold: Arc::new(fold),
        }
    }

    /// Adapts a `Reducer<Value>` into an aggregate channel.
    ///
    /// The reducer is moved into the fold closure.
    pub fn from_reducer<R>(reducer: R) -> Self
    where
        R: crate::graph::Reducer<Value> + 'static,
    {
        Self::new(move |accumulated, next| reducer.reduce(accumulated, next))
    }
}
/// Custom-fold channel: merges writes with a user-supplied binary function.
impl Channel for BinaryAggregate {
    /// Identifier reported for this channel kind.
    fn kind(&self) -> &'static str {
        "binary_aggregate"
    }

    /// Stores the first write unchanged. Each later write is combined with a
    /// clone of the current value via the configured fold, whose errors are
    /// propagated as-is.
    fn merge(&self, current: Option<&Value>, incoming: Value) -> Result<Value> {
        let Some(previous) = current else {
            return Ok(incoming);
        };
        (self.fold)(previous.clone(), incoming)
    }

    /// Several writers may target this channel within one step.
    fn allows_concurrent(&self) -> bool {
        true
    }

    /// Boxes a clone of this channel. The fold itself is shared through its `Arc`.
    fn clone_box(&self) -> Box<dyn Channel> {
        let duplicate = self.clone();
        Box::new(duplicate)
    }
}
impl ChannelSet {
    /// Equivalent to `ChannelSet::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style variant of [`ChannelSet::add_channel`].
    pub fn with_channel(
        mut self,
        name: impl Into<String>,
        channel: impl Channel + 'static,
    ) -> Self {
        self.add_channel(name, channel);
        self
    }

    /// Registers `channel` under `name`.
    ///
    /// Uses map-insert semantics, so an earlier registration under the same
    /// name is replaced. Any value already stored for `name` is left untouched.
    pub fn add_channel(&mut self, name: impl Into<String>, channel: impl Channel + 'static) {
        let key = name.into();
        self.channels.insert(key, Box::new(channel));
    }

    /// Current value of channel `name`, or `None` if nothing is stored for it.
    pub fn get(&self, name: &str) -> Option<&Value> {
        self.values.get(name)
    }

    /// Whether a channel is registered under `name`. This says nothing about
    /// whether that channel currently holds a value.
    pub fn contains(&self, name: &str) -> bool {
        self.channels.contains_key(name)
    }

    /// Whether channel `name` accepts more than one writer per step.
    ///
    /// # Errors
    ///
    /// `TinyAgentsError::Graph` if no channel is registered under `name`.
    pub fn allows_concurrent(&self, name: &str) -> Result<bool> {
        let channel = self.channel(name)?;
        Ok(channel.allows_concurrent())
    }

    /// Asks channel `name` whether its current value satisfies its readiness rule.
    ///
    /// # Errors
    ///
    /// `TinyAgentsError::Graph` if no channel is registered under `name`.
    pub fn is_ready(&self, name: &str) -> Result<bool> {
        self.channel(name)
            .map(|channel| channel.is_ready(self.values.get(name)))
    }

    /// Merges `value` into channel `name` using that channel's merge rule.
    ///
    /// # Errors
    ///
    /// Fails in two cases, leaving the stored value unchanged:
    /// - `TinyAgentsError::Graph` if the channel is unknown;
    /// - whatever error the channel's `merge` returns.
    pub fn apply_update(&mut self, name: &str, value: Value) -> Result<()> {
        let merged = {
            let channel = self.channel(name)?;
            channel.merge(self.values.get(name), value)?
        };
        self.values.insert(name.to_owned(), merged);
        Ok(())
    }

    /// Copies the stored values, omitting channels whose `is_tracked()` is false.
    /// Values with no registered channel are included.
    pub fn snapshot(&self) -> BTreeMap<String, Value> {
        let mut tracked = BTreeMap::new();
        for (name, value) in &self.values {
            let keep = match self.channels.get(name) {
                Some(channel) => channel.is_tracked(),
                None => true,
            };
            if keep {
                tracked.insert(name.clone(), value.clone());
            }
        }
        tracked
    }

    /// Drops the stored values of every channel flagged `is_ephemeral()`.
    pub(crate) fn clear_ephemeral(&mut self) {
        let doomed: Vec<String> = self
            .channels
            .iter()
            .filter_map(|(name, channel)| channel.is_ephemeral().then(|| name.clone()))
            .collect();
        for name in &doomed {
            self.values.remove(name);
        }
    }

    /// Looks up the channel registered under `name`.
    fn channel(&self, name: &str) -> Result<&dyn Channel> {
        match self.channels.get(name) {
            Some(boxed) => Ok(&**boxed),
            None => Err(TinyAgentsError::Graph(format!("unknown channel `{name}`"))),
        }
    }
}
impl ChannelUpdate {
    /// Equivalent to `ChannelUpdate::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues a write of `value` to channel `name`.
    ///
    /// Writes keep their call order, and repeated writes to the same channel
    /// are all retained.
    pub fn set(mut self, name: impl Into<String>, value: impl Into<Value>) -> Self {
        let write = (name.into(), value.into());
        self.writes.push(write);
        self
    }

    /// Tags this update with the step it belongs to. `ChannelState::merge`
    /// uses the tag for per-step bookkeeping.
    pub fn at_step(mut self, step: usize) -> Self {
        self.step = Some(step);
        self
    }

    /// `true` when no writes have been queued. A step tag alone does not count.
    pub fn is_empty(&self) -> bool {
        self.writes.is_empty()
    }
}
impl ChannelState {
    /// Equivalent to `ChannelState::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder-style registration of `channel` under `name` in the underlying set.
    pub fn with_channel(
        mut self,
        name: impl Into<String>,
        channel: impl Channel + 'static,
    ) -> Self {
        self.set.add_channel(name, channel);
        self
    }

    /// Borrows the underlying channel set.
    pub fn channels(&self) -> &ChannelSet {
        &self.set
    }

    /// Current value of channel `name`, if any.
    pub fn get(&self, name: &str) -> Option<&Value> {
        self.set.get(name)
    }

    /// Values of all tracked channels (delegates to `ChannelSet::snapshot`).
    pub fn snapshot(&self) -> BTreeMap<String, Value> {
        self.set.snapshot()
    }

    /// Whether channel `name` reports ready. Errors if the channel is unknown.
    pub fn is_ready(&self, name: &str) -> Result<bool> {
        self.set.is_ready(name)
    }

    /// Applies every write in `update` and returns the resulting state.
    ///
    /// Step bookkeeping, based on `update.step`:
    /// - `Some(s)` with `s` different from the current step starts a new step.
    ///   Per-step write counts are reset and ephemeral channel values are
    ///   cleared before any write is applied.
    /// - `Some(s)` equal to the current step keeps accumulating counts. Several
    ///   updates tagged with the same step therefore count as concurrent writers.
    /// - `None` resets the per-step counts, so the update is treated as its own
    ///   step, but ephemeral values are not cleared.
    ///
    /// Repeated writes to one channel within a single `update` count as one
    /// writer. They are all merged, in order.
    ///
    /// # Errors
    ///
    /// - `TinyAgentsError::Graph` if a written channel is not registered.
    /// - `TinyAgentsError::InvalidConcurrentUpdate` if a channel that does not
    ///   allow concurrent writes is written by more than one update in a step.
    /// - Any error returned by a channel's `merge`, propagated as-is.
    ///
    /// `self` is consumed, so on error the partially-modified state is dropped
    /// rather than handed back.
    pub fn merge(mut self, update: ChannelUpdate) -> Result<Self> {
        match update.step {
            Some(step) if step != self.current_step => {
                self.current_step = step;
                self.step_writes.clear();
                self.set.clear_ephemeral();
            }
            Some(_) => {}
            None => self.step_writes.clear(),
        }

        // Distinct channel names in first-write order. Order is preserved so the
        // reported error is deterministic. The set gives O(1) membership checks
        // instead of a quadratic `Vec::contains` scan.
        let distinct: Vec<&str> = {
            let mut seen = HashSet::new();
            update
                .writes
                .iter()
                .map(|(name, _)| name.as_str())
                .filter(|name| seen.insert(*name))
                .collect()
        };

        // Validate every channel before touching counts or values.
        for &name in &distinct {
            let allows = self.set.allows_concurrent(name)?;
            let count = self.step_writes.get(name).copied().unwrap_or(0) + 1;
            if count > 1 && !allows {
                return Err(TinyAgentsError::InvalidConcurrentUpdate(format!(
                    "channel `{name}` received {count} concurrent writes in one step but is not an aggregate channel"
                )));
            }
        }

        // `distinct` is already deduplicated, so each touched channel is counted once.
        for name in distinct {
            *self.step_writes.entry(name.to_string()).or_insert(0) += 1;
        }

        for (name, value) in update.writes {
            self.set.apply_update(&name, value)?;
        }
        Ok(self)
    }
}
impl StateReducer<ChannelState, ChannelUpdate> for ChannelState {
fn apply(&self, state: ChannelState, update: ChannelUpdate) -> Result<ChannelState> {
state.merge(update)
}
}
#[cfg(test)]
mod test;