1use std::sync::Arc;
4
5use async_trait::async_trait;
6use cognis_core::stream::Observer;
7use cognis_core::{Event, Result, RunnableConfig};
8use uuid::Uuid;
9
10use crate::goto::Goto;
11use crate::state::GraphState;
12
13pub struct NodeOut<S: GraphState> {
15 pub update: S::Update,
17 pub goto: Goto,
19}
20
21impl<S: GraphState> NodeOut<S> {
22 pub fn end_with(update: S::Update) -> Self {
24 Self {
25 update,
26 goto: Goto::End,
27 }
28 }
29
30 pub fn goto_only(goto: Goto) -> Self {
32 Self {
33 update: S::Update::default(),
34 goto,
35 }
36 }
37}
38
39pub struct NodeCtx<'a> {
43 pub run_id: Uuid,
45 pub step: u64,
47 pub config: &'a RunnableConfig,
49 payload: Option<&'a serde_json::Value>,
52 remaining_steps: Option<u32>,
56}
57
58impl<'a> NodeCtx<'a> {
59 pub fn new(run_id: Uuid, step: u64, config: &'a RunnableConfig) -> Self {
62 Self {
63 run_id,
64 step,
65 config,
66 payload: None,
67 remaining_steps: None,
68 }
69 }
70
71 pub(crate) fn with_payload(mut self, payload: &'a serde_json::Value) -> Self {
73 self.payload = Some(payload);
74 self
75 }
76
77 pub(crate) fn with_remaining_steps(mut self, remaining: u32) -> Self {
79 self.remaining_steps = Some(remaining);
80 self
81 }
82
83 pub fn payload(&self) -> Option<&serde_json::Value> {
86 self.payload
87 }
88
89 pub fn remaining_steps(&self) -> Option<u32> {
92 self.remaining_steps
93 }
94
95 pub fn is_last_step(&self) -> bool {
98 matches!(self.remaining_steps, Some(0) | Some(1))
99 }
100
101 pub fn emit(&self, event: &Event) {
103 self.config.emit(event);
104 }
105
106 pub fn write_custom(&self, kind: impl Into<String>, payload: serde_json::Value) {
110 self.config.emit(&Event::Custom {
111 kind: kind.into(),
112 payload,
113 run_id: self.run_id,
114 });
115 }
116
117 pub fn is_cancelled(&self) -> bool {
119 self.config.is_cancelled()
120 }
121
122 pub fn observers(&self) -> &[Arc<dyn Observer>] {
124 &self.config.observers
125 }
126}
127
/// Retry settings a node may return from `Node::retry_policy`.
///
/// The fields describe an exponential-backoff schedule; how they are applied is
/// decided by the graph executor (not in this module).
/// `PartialEq` is derived so policies can be compared directly. (`Eq` is impossible
/// because of the `f64` multiplier.)
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct NodeRetryPolicy {
    /// Maximum number of attempts.
    /// NOTE(review): confirm with the executor whether this includes the first try.
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds (per the `_ms` suffix).
    pub initial_delay_ms: u64,
    /// Factor the delay is presumably multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
    /// Upper bound on a single delay, in milliseconds.
    pub max_delay_ms: u64,
}

impl Default for NodeRetryPolicy {
    /// 3 attempts, 100 ms initial delay, ×2 backoff, delays capped at 30 000 ms.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay_ms: 100,
            backoff_multiplier: 2.0,
            max_delay_ms: 30_000,
        }
    }
}
153
154#[async_trait]
157pub trait Node<S: GraphState>: Send + Sync {
158 async fn execute(&self, state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>>;
160
161 fn name(&self) -> &str {
163 std::any::type_name::<Self>()
164 }
165
166 fn retry_policy(&self) -> Option<NodeRetryPolicy> {
170 None
171 }
172}
173
/// Adapter that exposes an async closure as a node; construct it with `node_fn`.
pub struct NodeFn<S, F> {
    /// Name reported by the `Node::name` implementation.
    name: String,
    /// The wrapped closure, called once per execution.
    f: F,
    /// Ties the adapter to `S` without storing one. Using `fn() -> S` keeps `NodeFn`'s
    /// `Send`/`Sync` independent of `S`, because function pointers are always both.
    _state: std::marker::PhantomData<fn() -> S>,
}
180
181pub fn node_fn<S, F, Fut>(name: impl Into<String>, f: F) -> NodeFn<S, F>
184where
185 S: GraphState,
186 F: Fn(&S, &NodeCtx<'_>) -> Fut + Send + Sync + 'static,
187 Fut: std::future::Future<Output = Result<NodeOut<S>>> + Send,
188{
189 NodeFn {
190 name: name.into(),
191 f,
192 _state: std::marker::PhantomData,
193 }
194}
195
196#[async_trait]
197impl<S, F, Fut> Node<S> for NodeFn<S, F>
198where
199 S: GraphState,
200 F: Fn(&S, &NodeCtx<'_>) -> Fut + Send + Sync + 'static,
201 Fut: std::future::Future<Output = Result<NodeOut<S>>> + Send,
202{
203 async fn execute(&self, state: &S, ctx: &NodeCtx<'_>) -> Result<NodeOut<S>> {
204 (self.f)(state, ctx).await
205 }
206
207 fn name(&self) -> &str {
208 &self.name
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::goto::Goto;
    use crate::state::GraphState;

    /// Minimal state for the tests: a single counter.
    #[derive(Default, Clone, Debug, PartialEq)]
    struct S {
        n: u32,
    }

    /// Update for `S`: the amount added to the counter by `apply`.
    #[derive(Default)]
    struct SU {
        n: u32,
    }

    impl GraphState for S {
        type Update = SU;
        fn apply(&mut self, update: Self::Update) {
            self.n += update.n;
        }
    }

    /// A node that overrides nothing but `execute`, to exercise the trait defaults.
    struct Noop;

    #[async_trait]
    impl Node<S> for Noop {
        async fn execute(&self, _state: &S, _ctx: &NodeCtx<'_>) -> Result<NodeOut<S>> {
            Ok(NodeOut::goto_only(Goto::end()))
        }
    }

    #[tokio::test]
    async fn node_fn_executes() {
        let n = node_fn::<S, _, _>("incr", |state, _ctx| {
            // Copy out of `state`: the future cannot borrow the closure arguments.
            let cur = state.n;
            async move {
                Ok(NodeOut {
                    update: SU { n: cur + 1 },
                    goto: Goto::end(),
                })
            }
        });
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        let s = S { n: 5 };
        let out = n.execute(&s, &ctx).await.unwrap();
        assert_eq!(out.update.n, 6);
        assert!(out.goto.is_end());
        assert_eq!(n.name(), "incr");
    }

    #[test]
    fn node_ctx_payload_default_none() {
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        assert!(ctx.payload().is_none());
    }

    #[test]
    fn node_ctx_with_payload() {
        let cfg = RunnableConfig::default();
        let payload = serde_json::json!({"x": 42});
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_payload(&payload);
        assert_eq!(ctx.payload().unwrap()["x"], 42);
    }

    #[test]
    fn node_ctx_remaining_steps_default_none() {
        let cfg = RunnableConfig::default();
        let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg);
        assert_eq!(ctx.remaining_steps(), None);
        assert!(!ctx.is_last_step());
    }

    #[test]
    fn node_ctx_is_last_step_boundaries() {
        let cfg = RunnableConfig::default();
        let cases: &[(u32, bool)] = &[(0, true), (1, true), (2, false), (u32::MAX, false)];
        for &(remaining, expected) in cases {
            let ctx = NodeCtx::new(Uuid::nil(), 0, &cfg).with_remaining_steps(remaining);
            assert_eq!(ctx.remaining_steps(), Some(remaining));
            assert_eq!(ctx.is_last_step(), expected, "remaining_steps = {}", remaining);
        }
    }

    #[test]
    fn nodeout_constructors() {
        let upd = SU { n: 10 };
        let no: NodeOut<S> = NodeOut::end_with(upd);
        assert!(no.goto.is_end());
        assert_eq!(no.update.n, 10);

        let no2: NodeOut<S> = NodeOut::goto_only(Goto::node("next"));
        assert_eq!(no2.update.n, 0);
        assert_eq!(no2.goto, Goto::node("next"));
    }

    #[test]
    fn retry_policy_default_values() {
        let p = NodeRetryPolicy::default();
        assert_eq!(p.max_attempts, 3);
        assert_eq!(p.initial_delay_ms, 100);
        assert_eq!(p.max_delay_ms, 30_000);
        assert!((p.backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }

    #[test]
    fn trait_default_retry_policy_is_none() {
        assert!(Noop.retry_policy().is_none());
    }
}