use std::sync::Arc;
use async_trait::async_trait;
use cognis_core::stream::Observer;
use cognis_core::{Event, Result, RunnableConfig};
use uuid::Uuid;
use crate::goto::Goto;
use crate::state::GraphState;
/// What a node hands back after it runs: a state update plus routing.
pub struct NodeOut<S: GraphState> {
    /// Update value of the state's associated `Update` type produced by this
    /// node (presumably merged into the state via `GraphState::apply` by the
    /// executor — confirm in the graph runner).
    pub update: S::Update,
    /// Where execution should go after this node.
    pub goto: Goto,
}

impl<S: GraphState> NodeOut<S> {
    /// Builds an output carrying `update` that routes to [`Goto::End`].
    pub fn end_with(update: S::Update) -> Self {
        let goto = Goto::End;
        Self { update, goto }
    }

    /// Builds an output that routes to `goto` and carries a default
    /// (`Default::default()`) update.
    pub fn goto_only(goto: Goto) -> Self {
        let update = S::Update::default();
        Self { update, goto }
    }
}
/// Per-invocation context passed to [`Node::execute`].
///
/// Holds the run identifier, a step counter, a borrowed [`RunnableConfig`],
/// and two optional crate-populated extras: a borrowed JSON payload and a
/// remaining-step budget.
pub struct NodeCtx<'a> {
    /// Identifier of the run; attached to events emitted by
    /// [`NodeCtx::write_custom`].
    pub run_id: Uuid,
    /// Step counter supplied to [`NodeCtx::new`].
    pub step: u64,
    /// Run configuration; event emission, cancellation checks and observers
    /// are all read through it.
    pub config: &'a RunnableConfig,
    payload: Option<&'a serde_json::Value>,
    remaining_steps: Option<u32>,
}

impl<'a> NodeCtx<'a> {
    /// Creates a context with no payload and no remaining-step budget.
    pub fn new(run_id: Uuid, step: u64, config: &'a RunnableConfig) -> Self {
        Self {
            run_id,
            step,
            config,
            payload: None,
            remaining_steps: None,
        }
    }

    /// Attaches a borrowed JSON payload, replacing any previous one.
    #[must_use]
    pub(crate) fn with_payload(mut self, payload: &'a serde_json::Value) -> Self {
        self.payload = Some(payload);
        self
    }

    /// Sets the remaining-step budget, replacing any previous value.
    #[must_use]
    pub(crate) fn with_remaining_steps(mut self, remaining: u32) -> Self {
        self.remaining_steps = Some(remaining);
        self
    }

    /// Returns the JSON payload attached to this invocation, if any.
    ///
    /// The reference lives for `'a` (the payload's own borrow), not merely for
    /// the borrow of `self`, so it may be kept after the context goes away.
    pub fn payload(&self) -> Option<&'a serde_json::Value> {
        self.payload
    }

    /// Returns the remaining-step budget, or `None` if none was set.
    pub fn remaining_steps(&self) -> Option<u32> {
        self.remaining_steps
    }

    /// Returns `true` when a remaining-step budget is set and is 0 or 1.
    ///
    /// Returns `false` when no budget has been set.
    pub fn is_last_step(&self) -> bool {
        // Both 0 and 1 count as "last"; whether the budget includes the current
        // step depends on the caller that sets it — NOTE(review): confirm.
        matches!(self.remaining_steps, Some(0..=1))
    }

    /// Forwards `event` to the run configuration's `emit`.
    pub fn emit(&self, event: &Event) {
        self.config.emit(event);
    }

    /// Emits an [`Event::Custom`] with the given `kind` and `payload`, tagged
    /// with this context's `run_id`.
    pub fn write_custom(&self, kind: impl Into<String>, payload: serde_json::Value) {
        // Route through `emit` so all event emission shares one path.
        self.emit(&Event::Custom {
            kind: kind.into(),
            payload,
            run_id: self.run_id,
        });
    }

    /// Returns whatever the run configuration's `is_cancelled` reports.
    pub fn is_cancelled(&self) -> bool {
        self.config.is_cancelled()
    }

    /// Returns the observers registered on the run configuration.
    ///
    /// The slice is borrowed for `'a` (the config's borrow), not just for the
    /// borrow of `self`.
    pub fn observers(&self) -> &'a [Arc<dyn Observer>] {
        &self.config.observers
    }
}
/// Retry settings a node can return from `Node::retry_policy`.
///
/// The fields describe a capped exponential-backoff schedule. Exactly how they
/// are interpreted (for example, whether `max_attempts` counts the first try)
/// is decided by the executor that consumes the policy — NOTE(review): confirm
/// there.
///
/// Derives `PartialEq` so policies can be compared. `Eq` is not derived because
/// `backoff_multiplier` is an `f64`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NodeRetryPolicy {
    /// Maximum number of attempts.
    pub max_attempts: u32,
    /// Initial delay in milliseconds.
    pub initial_delay_ms: u64,
    /// Factor by which the delay grows between attempts.
    pub backoff_multiplier: f64,
    /// Upper bound on a single delay, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for NodeRetryPolicy {
    /// Defaults: 3 attempts, 100 ms initial delay, multiplier 2.0, and delays
    /// capped at 30 000 ms.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay_ms: 100,
            backoff_multiplier: 2.0,
            max_delay_ms: 30_000,
        }
    }
}
/// A unit of work in a graph whose state type is `S`.
///
/// Implementors must be `Send + Sync`. `execute` is async via `async_trait`.
#[async_trait]
pub trait Node<S: GraphState>: Send + Sync {
    /// Runs the node against the current `state` with the given context and
    /// returns its update and routing decision.
    async fn execute(&self, state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>>;

    /// Retry policy requested by this node. Returns `None` unless overridden
    /// (presumably meaning "no retries" to the executor — confirm there).
    fn retry_policy(&self) -> Option<NodeRetryPolicy> {
        None
    }

    /// Display name of the node. Defaults to the implementing type's name as
    /// reported by `std::any::type_name`.
    fn name(&self) -> &str {
        let type_name: &'static str = std::any::type_name::<Self>();
        type_name
    }
}
/// Adapter that exposes a closure as a [`Node`]. Build one with [`node_fn`].
pub struct NodeFn<S, F> {
    name: String,
    f: F,
    // `fn() -> S` records the state type without owning an `S`, and is always
    // `Send + Sync`, so it does not constrain the adapter's auto traits.
    _state: std::marker::PhantomData<fn() -> S>,
}

/// Wraps the closure `f` as a node called `name`.
///
/// The closure gets the current state and the node context and returns a
/// `Send` future yielding the node's output. `Fut` is a single type that cannot
/// depend on the argument lifetimes, so the future cannot borrow `state` or
/// `ctx`. Copy what you need out of them before building the future.
pub fn node_fn<S, F, Fut>(name: impl Into<String>, f: F) -> NodeFn<S, F>
where
    S: GraphState,
    F: Fn(&S, &NodeCtx<'_>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<NodeOut<S>>> + Send,
{
    let owned_name: String = name.into();
    NodeFn {
        name: owned_name,
        f,
        _state: std::marker::PhantomData,
    }
}

#[async_trait]
impl<S, F, Fut> Node<S> for NodeFn<S, F>
where
    S: GraphState,
    F: Fn(&S, &NodeCtx<'_>) -> Fut + Send + Sync + 'static,
    Fut: std::future::Future<Output = Result<NodeOut<S>>> + Send,
{
    /// Calls the wrapped closure and awaits the future it returns.
    async fn execute(&self, state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>> {
        let pending = (self.f)(state, ctx);
        pending.await
    }

    /// Returns the name passed to [`node_fn`].
    fn name(&self) -> &str {
        self.name.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::state::GraphState;

    /// Minimal state: a counter that updates add to.
    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }

    /// Update type for `S`: an amount to add.
    #[derive(Default)]
    struct SU {
        n: u32,
    }

    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, update: Self::Update) {
            self.n += update.n;
        }
    }

    #[tokio::test]
    async fn node_fn_executes() {
        let n = node_fn::<S, _, _>("incr", |state, _ctx| {
            // Copy out of `state` before the future: `Fut` cannot borrow it.
            let cur = state.n;
            async move {
                Ok(NodeOut {
                    update: SU { n: cur + 1 },
                    goto: Goto::end(),
                })
            }
        });
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        let s = S { n: 5 };
        let out = n.execute(&s, &ctx).await.unwrap();
        assert_eq!(out.update.n, 6);
        assert!(out.goto.is_end());
        assert_eq!(n.name(), "incr");
    }

    #[test]
    fn node_fn_default_retry_policy_is_none() {
        let n = node_fn::<S, _, _>("noop", |_state, _ctx| async {
            Ok(NodeOut::<S>::goto_only(Goto::end()))
        });
        assert!(n.retry_policy().is_none());
    }

    #[test]
    fn node_ctx_payload_default_none() {
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        assert!(ctx.payload().is_none());
    }

    #[test]
    fn node_ctx_with_payload() {
        let cfg = RunnableConfig::default();
        let payload = serde_json::json!({"x": 42});
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&payload);
        assert_eq!(ctx.payload().unwrap()["x"], 42);
    }

    #[test]
    fn node_ctx_remaining_steps_default_none() {
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        assert_eq!(ctx.remaining_steps(), None);
        assert!(!ctx.is_last_step());
    }

    #[test]
    fn node_ctx_is_last_step_boundaries() {
        let cfg = RunnableConfig::default();
        for (remaining, expected) in [(0, true), (1, true), (2, false), (10, false)] {
            let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_remaining_steps(remaining);
            assert_eq!(ctx.remaining_steps(), Some(remaining));
            assert_eq!(ctx.is_last_step(), expected, "remaining = {}", remaining);
        }
    }

    #[test]
    fn nodeout_constructors() {
        let upd = SU { n: 10 };
        let no: NodeOut<S> = NodeOut::end_with(upd);
        assert!(no.goto.is_end());
        let no2: NodeOut<S> = NodeOut::goto_only(Goto::node("next"));
        assert_eq!(no2.update.n, 0);
        assert_eq!(no2.goto, Goto::node("next"));
    }
}