Skip to main content

ranvier_std/nodes/
logic.rs

1use async_trait::async_trait;
2use ranvier_core::{bus::Bus, outcome::Outcome, transition::Transition};
3use schemars::JsonSchema;
4use serde::{Deserialize, Serialize};
5use std::marker::PhantomData;
6
7#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
8pub struct RandomBranchNode<T> {
9    pub probability: f64,
10    pub jump_target: String,
11    #[serde(skip)]
12    pub _marker: PhantomData<T>,
13}
14
15impl<T> RandomBranchNode<T> {
16    pub fn new(probability: f64, jump_target: impl Into<String>) -> Self {
17        Self {
18            probability,
19            jump_target: jump_target.into(),
20            _marker: PhantomData,
21        }
22    }
23}
24
25#[async_trait]
26impl<T> Transition<T, T> for RandomBranchNode<T>
27where
28    T: Send + Sync + 'static + Clone + Serialize,
29{
30    type Error = String;
31    type Resources = ();
32
33    async fn run(
34        &self,
35        input: T,
36        _resources: &Self::Resources,
37        _bus: &mut Bus,
38    ) -> Outcome<T, Self::Error> {
39        if rand::random::<f64>() < self.probability {
40            Outcome::next(input)
41        } else {
42            let payload = serde_json::to_value(&input).ok();
43            Outcome::branch(self.jump_target.clone(), payload)
44        }
45    }
46}
47
48use std::sync::Arc;
49
50#[derive(Serialize, Deserialize, JsonSchema)]
51pub struct FilterNode<T, F> {
52    #[serde(skip)]
53    pub predicate: Arc<F>,
54    #[serde(skip)]
55    pub _marker: PhantomData<T>,
56}
57
58impl<T, F> FilterNode<T, F>
59where
60    F: Fn(&T) -> bool + Send + Sync + 'static,
61{
62    pub fn new(predicate: F) -> Self {
63        Self {
64            predicate: Arc::new(predicate),
65            _marker: PhantomData,
66        }
67    }
68}
69
70// Clone is now always derived-able if we manually impl or just rely on Arc
71impl<T, F> Clone for FilterNode<T, F> {
72    fn clone(&self) -> Self {
73        Self {
74            predicate: self.predicate.clone(),
75            _marker: PhantomData,
76        }
77    }
78}
79
80impl<T, F> std::fmt::Debug for FilterNode<T, F> {
81    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
82        f.debug_struct("FilterNode").finish()
83    }
84}
85
86#[async_trait]
87impl<T, F> Transition<T, T> for FilterNode<T, F>
88where
89    T: Send + Sync + 'static + Serialize,
90    F: Fn(&T) -> bool + Send + Sync + 'static,
91{
92    type Error = String;
93    type Resources = ();
94
95    async fn run(
96        &self,
97        input: T,
98        _resources: &Self::Resources,
99        _bus: &mut Bus,
100    ) -> Outcome<T, Self::Error> {
101        if (self.predicate)(&input) {
102            Outcome::next(input)
103        } else {
104            let payload = serde_json::to_value(&input).ok();
105            Outcome::branch("rejected".to_string(), payload)
106        }
107    }
108}
109
110#[derive(Serialize, Deserialize, JsonSchema)]
111pub struct SwitchNode<T, F> {
112    #[serde(skip)]
113    pub matcher: Arc<F>,
114    #[serde(skip)]
115    pub _marker: PhantomData<T>,
116}
117
118impl<T, F> SwitchNode<T, F>
119where
120    F: Fn(&T) -> String + Send + Sync + 'static,
121{
122    pub fn new(matcher: F) -> Self {
123        Self {
124            matcher: Arc::new(matcher),
125            _marker: PhantomData,
126        }
127    }
128}
129
130impl<T, F> Clone for SwitchNode<T, F> {
131    fn clone(&self) -> Self {
132        Self {
133            matcher: self.matcher.clone(),
134            _marker: PhantomData,
135        }
136    }
137}
138
139impl<T, F> std::fmt::Debug for SwitchNode<T, F> {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        f.debug_struct("SwitchNode").finish()
142    }
143}
144
145#[async_trait]
146impl<T, F> Transition<T, T> for SwitchNode<T, F>
147where
148    T: Send + Sync + 'static + Serialize,
149    F: Fn(&T) -> String + Send + Sync + 'static,
150{
151    type Error = String;
152    type Resources = ();
153
154    async fn run(
155        &self,
156        input: T,
157        _resources: &Self::Resources,
158        _bus: &mut Bus,
159    ) -> Outcome<T, Self::Error> {
160        let branch_id = (self.matcher)(&input);
161        let payload = serde_json::to_value(&input).ok();
162        Outcome::branch(branch_id, payload)
163    }
164}