1use crate::nodes::definition::{NodeDataType, TraceDataType};
2use crate::nodes::extensions::NodeHandlerExtensions;
3use crate::nodes::function::v2::function::Function;
4use crate::nodes::result::{NodeResponse, NodeResult};
5use crate::nodes::NodeError;
6use crate::ZEN_CONFIG;
7use ahash::AHasher;
8use jsonschema::ValidationError;
9use serde::Serialize;
10use serde_json::Value;
11use std::cell::RefCell;
12use std::fmt::{Display, Formatter};
13use std::hash::Hasher;
14use std::sync::atomic::Ordering;
15use std::sync::Arc;
16use thiserror::Error;
17use zen_types::variable::{ToVariable, Variable};
18
19#[derive(Clone)]
20pub struct NodeContext<NodeData, TraceData>
21where
22 NodeData: NodeDataType,
23 TraceData: TraceDataType,
24{
25 pub id: Arc<str>,
26 pub name: Arc<str>,
27 pub node: NodeData,
28 pub input: Variable,
29 pub trace: Option<RefCell<TraceData>>,
30 pub extensions: NodeHandlerExtensions,
31 pub iteration: u8,
32 pub config: NodeContextConfig,
33}
34
35impl<NodeData, TraceData> NodeContext<NodeData, TraceData>
36where
37 NodeData: NodeDataType,
38 TraceData: TraceDataType,
39{
40 pub fn from_base(base: NodeContextBase, data: NodeData) -> Self {
41 Self {
42 id: base.id,
43 name: base.name,
44 input: base.input,
45 extensions: base.extensions,
46 iteration: base.iteration,
47 trace: base.config.trace.then(|| Default::default()),
48 node: data,
49 config: base.config,
50 }
51 }
52
53 pub fn trace<Function>(&self, mutator: Function)
54 where
55 Function: FnOnce(&mut TraceData),
56 {
57 if let Some(trace) = &self.trace {
58 mutator(&mut *trace.borrow_mut());
59 }
60 }
61
62 pub fn error<Error>(&self, error: Error) -> NodeResult
63 where
64 Error: Into<Box<dyn std::error::Error>>,
65 {
66 Err(self.make_error(error))
67 }
68
69 pub fn success(&self, output: Variable) -> NodeResult {
70 Ok(NodeResponse {
71 output,
72 trace_data: self.trace.as_ref().map(|v| (*v.borrow()).to_variable()),
73 })
74 }
75
76 pub(crate) fn make_error<Error>(&self, error: Error) -> NodeError
77 where
78 Error: Into<Box<dyn std::error::Error>>,
79 {
80 NodeError {
81 node_id: self.id.clone(),
82 trace: self.trace.as_ref().map(|v| (*v.borrow()).to_variable()),
83 source: error.into(),
84 }
85 }
86
87 pub(crate) async fn function_runtime(&self) -> Result<&Function, NodeError> {
88 self.extensions.function_runtime().await.node_context(self)
89 }
90
91 pub fn validate(&self, schema: &Value, value: &Value) -> Result<(), NodeError> {
92 let validator_cache = self.extensions.validator_cache();
93 let hash = self.hash_node();
94
95 let validator = validator_cache
96 .get_or_insert(hash, schema)
97 .node_context(self)?;
98
99 validator
100 .validate(value)
101 .map_err(|err| ValidationErrorJson::from(err))
102 .node_context(self)?;
103
104 Ok(())
105 }
106
107 fn hash_node(&self) -> u64 {
108 let mut hasher = AHasher::default();
109 hasher.write(self.id.as_bytes());
110 hasher.write(self.name.as_bytes());
111 hasher.finish()
112 }
113}
114
115pub trait NodeContextExt<T, Context>: Sized {
116 type Error: Into<Box<dyn std::error::Error>>;
117
118 fn with_node_context<Function, NewError>(
119 self,
120 ctx: &Context,
121 f: Function,
122 ) -> Result<T, NodeError>
123 where
124 Function: FnOnce(Self::Error) -> NewError,
125 NewError: Into<Box<dyn std::error::Error>>;
126
127 fn node_context(self, ctx: &Context) -> Result<T, NodeError> {
128 self.with_node_context(ctx, |e| e.into())
129 }
130
131 fn node_context_message(self, ctx: &Context, message: &str) -> Result<T, NodeError> {
132 self.with_node_context(ctx, |err| format!("{}: {}", message, err.into()))
133 }
134}
135
136impl<T, E, NodeData, TraceData> NodeContextExt<T, NodeContext<NodeData, TraceData>> for Result<T, E>
137where
138 E: Into<Box<dyn std::error::Error>>,
139 NodeData: NodeDataType,
140 TraceData: TraceDataType,
141{
142 type Error = E;
143
144 fn with_node_context<Function, NewError>(
145 self,
146 ctx: &NodeContext<NodeData, TraceData>,
147 f: Function,
148 ) -> Result<T, NodeError>
149 where
150 Function: FnOnce(Self::Error) -> NewError,
151 NewError: Into<Box<dyn std::error::Error>>,
152 {
153 self.map_err(|err| ctx.make_error(f(err)))
154 }
155}
156
157impl<T, NodeData, TraceData> NodeContextExt<T, NodeContext<NodeData, TraceData>> for Option<T>
158where
159 NodeData: NodeDataType,
160 TraceData: TraceDataType,
161{
162 type Error = &'static str;
163
164 fn with_node_context<Function, NewError>(
165 self,
166 ctx: &NodeContext<NodeData, TraceData>,
167 f: Function,
168 ) -> Result<T, NodeError>
169 where
170 Function: FnOnce(Self::Error) -> NewError,
171 NewError: Into<Box<dyn std::error::Error>>,
172 {
173 self.ok_or_else(|| ctx.make_error(f("None")))
174 }
175
176 fn node_context_message(
177 self,
178 ctx: &NodeContext<NodeData, TraceData>,
179 message: &str,
180 ) -> Result<T, NodeError> {
181 self.with_node_context(ctx, |_| message.to_string())
182 }
183}
184
185#[derive(Clone)]
186pub struct NodeContextBase {
187 pub id: Arc<str>,
188 pub name: Arc<str>,
189 pub input: Variable,
190 pub iteration: u8,
191 pub extensions: NodeHandlerExtensions,
192 pub config: NodeContextConfig,
193 pub trace: Option<RefCell<Variable>>,
194}
195
196impl NodeContextBase {
197 pub fn error<Error>(&self, error: Error) -> NodeResult
198 where
199 Error: Into<Box<dyn std::error::Error>>,
200 {
201 Err(self.make_error(error))
202 }
203
204 pub fn success(&self, output: Variable) -> NodeResult {
205 Ok(NodeResponse {
206 output,
207 trace_data: self.trace.as_ref().map(|v| v.borrow().to_variable()),
208 })
209 }
210
211 fn make_error<Error>(&self, error: Error) -> NodeError
212 where
213 Error: Into<Box<dyn std::error::Error>>,
214 {
215 NodeError {
216 node_id: self.id.clone(),
217 trace: self.trace.as_ref().map(|t| t.borrow().to_variable()),
218 source: error.into(),
219 }
220 }
221
222 pub fn trace<Function>(&self, mutator: Function)
223 where
224 Function: FnOnce(&mut Variable),
225 {
226 if let Some(trace) = &self.trace {
227 mutator(&mut *trace.borrow_mut());
228 }
229 }
230}
231
232impl<NodeData, TraceData> From<NodeContext<NodeData, TraceData>> for NodeContextBase
233where
234 NodeData: NodeDataType,
235 TraceData: TraceDataType,
236{
237 fn from(value: NodeContext<NodeData, TraceData>) -> Self {
238 let trace = match value.config.trace {
239 true => Some(RefCell::new(Variable::Null)),
240 false => None,
241 };
242
243 Self {
244 id: value.id,
245 name: value.name,
246 input: value.input,
247 extensions: value.extensions,
248 iteration: value.iteration,
249 config: value.config,
250 trace,
251 }
252 }
253}
254
255impl<T, E> NodeContextExt<T, NodeContextBase> for Result<T, E>
256where
257 E: Into<Box<dyn std::error::Error>>,
258{
259 type Error = E;
260
261 fn with_node_context<Function, NewError>(
262 self,
263 ctx: &NodeContextBase,
264 f: Function,
265 ) -> Result<T, NodeError>
266 where
267 Function: FnOnce(Self::Error) -> NewError,
268 NewError: Into<Box<dyn std::error::Error>>,
269 {
270 self.map_err(|err| ctx.make_error(f(err)))
271 }
272}
273
274impl<T> NodeContextExt<T, NodeContextBase> for Option<T> {
275 type Error = &'static str;
276
277 fn with_node_context<Function, NewError>(
278 self,
279 ctx: &NodeContextBase,
280 f: Function,
281 ) -> Result<T, NodeError>
282 where
283 Function: FnOnce(Self::Error) -> NewError,
284 NewError: Into<Box<dyn std::error::Error>>,
285 {
286 self.ok_or_else(|| ctx.make_error(f("None")))
287 }
288
289 fn node_context_message(self, ctx: &NodeContextBase, message: &str) -> Result<T, NodeError> {
290 self.with_node_context(ctx, |_| message.to_string())
291 }
292}
293
294#[derive(Debug, Serialize, Error)]
295#[serde(rename_all = "camelCase")]
296struct ValidationErrorJson {
297 path: String,
298 message: String,
299}
300
301impl Display for ValidationErrorJson {
302 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
303 write!(f, "{}: {}", self.path, self.message)
304 }
305}
306
307impl<'a> From<ValidationError<'a>> for ValidationErrorJson {
308 fn from(value: ValidationError<'a>) -> Self {
309 ValidationErrorJson {
310 path: value.instance_path.to_string(),
311 message: format!("{}", value),
312 }
313 }
314}
315
/// Execution settings attached to every node context.
#[derive(Clone)]
pub struct NodeContextConfig {
    /// Whether trace data is collected. When `true`, contexts allocate a trace slot.
    pub trace: bool,
    /// Taken from `ZEN_CONFIG.nodes_in_context` by default. NOTE(review): semantics not visible here.
    pub nodes_in_context: bool,
    /// Presumably the maximum nesting depth for nested evaluation (confirm). Defaults to 5.
    pub max_depth: u8,
    /// Function execution timeout; milliseconds per the field name (confirm at the use site).
    pub function_timeout_millis: u64,
    /// Taken from `ZEN_CONFIG.http_auth` by default. NOTE(review): semantics not visible here.
    pub http_auth: bool,
}
324
325impl Default for NodeContextConfig {
326 fn default() -> Self {
327 Self {
328 trace: false,
329 nodes_in_context: ZEN_CONFIG.nodes_in_context.load(Ordering::Relaxed),
330 function_timeout_millis: ZEN_CONFIG.function_timeout_millis.load(Ordering::Relaxed),
331 http_auth: ZEN_CONFIG.http_auth.load(Ordering::Relaxed),
332 max_depth: 5,
333 }
334 }
335}