intuicio_nodes/
server.rs

1use crate::nodes::*;
2use intuicio_core::registry::Registry;
3use serde::{Deserialize, Serialize};
4use std::{
5    collections::{HashMap, HashSet},
6    error::Error,
7};
8use typid::ID;
9
/// Identifier of a [`NodeGraph`] managed by a `NodeGraphServer`.
pub type NodeGraphId<T> = ID<NodeGraph<T>>;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct RequestAdd<T: NodeDefinition> {
14    pub nodes: Vec<Node<T>>,
15    pub connections: Vec<NodeConnection<T>>,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RequestRemove<T: NodeDefinition> {
20    pub nodes: Vec<NodeId<T>>,
21    pub connections: Vec<NodeConnection<T>>,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct RequestUpdate<T: NodeDefinition> {
26    pub nodes: Vec<Node<T>>,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct RequestQueryRegion {
31    pub fx: i64,
32    pub fy: i64,
33    pub tx: i64,
34    pub ty: i64,
35    pub extrude: i64,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ResponseQuery<T: NodeDefinition> {
40    pub nodes: Vec<Node<T>>,
41    pub connections: Vec<NodeConnection<T>>,
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub enum NodeGraphServerError {
46    NodeGraphDoesNotExists(String),
47    NodeNotFound { graph: String, node: String },
48    ValidationErrors { graph: String, errors: Vec<String> },
49}
50
51impl std::fmt::Display for NodeGraphServerError {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            NodeGraphServerError::NodeGraphDoesNotExists(id) => {
55                write!(f, "Node graph does not exists: {id}")
56            }
57            NodeGraphServerError::NodeNotFound { graph, node } => {
58                write!(f, "Node graph: {graph} does not have node: {node}")
59            }
60            NodeGraphServerError::ValidationErrors { graph, errors } => {
61                write!(f, "Node graph: {graph} validation errors:")?;
62                for error in errors {
63                    write!(f, "{error}")?;
64                }
65                Ok(())
66            }
67        }
68    }
69}
70
71impl Error for NodeGraphServerError {}
72
73pub struct NodeGraphServer<T: NodeDefinition + Clone> {
74    graphs: HashMap<NodeGraphId<T>, NodeGraph<T>>,
75}
76
77impl<T: NodeDefinition + Clone> Default for NodeGraphServer<T> {
78    fn default() -> Self {
79        Self {
80            graphs: Default::default(),
81        }
82    }
83}
84
85impl<T: NodeDefinition + Clone> NodeGraphServer<T> {
86    pub fn graph(&self, id: NodeGraphId<T>) -> Result<&NodeGraph<T>, NodeGraphServerError> {
87        self.graphs
88            .get(&id)
89            .ok_or_else(|| NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
90    }
91
92    pub fn graph_mut(
93        &mut self,
94        id: NodeGraphId<T>,
95    ) -> Result<&mut NodeGraph<T>, NodeGraphServerError> {
96        self.graphs
97            .get_mut(&id)
98            .ok_or_else(|| NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
99    }
100
101    pub fn create(&mut self) -> NodeGraphId<T> {
102        let id = NodeGraphId::new();
103        self.graphs.insert(id, NodeGraph::default());
104        id
105    }
106
107    pub fn destroy(&mut self, id: NodeGraphId<T>) -> Result<NodeGraph<T>, NodeGraphServerError> {
108        self.graphs
109            .remove(&id)
110            .ok_or_else(|| NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
111    }
112
113    pub fn list(&self) -> impl Iterator<Item = &NodeGraphId<T>> {
114        self.graphs.keys()
115    }
116
117    pub fn add(
118        &mut self,
119        id: NodeGraphId<T>,
120        request: RequestAdd<T>,
121        registry: &Registry,
122    ) -> Result<(), NodeGraphServerError> {
123        if let Some(graph) = self.graphs.get_mut(&id) {
124            for node in request.nodes {
125                graph.add_node(node, registry);
126            }
127            for connection in request.connections {
128                graph.connect_nodes(connection);
129            }
130            graph.refresh_spatial_cache();
131            Ok(())
132        } else {
133            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
134        }
135    }
136
137    pub fn remove(
138        &mut self,
139        id: NodeGraphId<T>,
140        request: RequestRemove<T>,
141        registry: &Registry,
142    ) -> Result<(), NodeGraphServerError> {
143        if let Some(graph) = self.graphs.get_mut(&id) {
144            for connection in request.connections {
145                graph.disconnect_nodes(
146                    connection.from_node,
147                    connection.to_node,
148                    &connection.from_pin,
149                    &connection.to_pin,
150                );
151            }
152            for id in request.nodes {
153                graph.remove_node(id, registry);
154            }
155            graph.refresh_spatial_cache();
156            Ok(())
157        } else {
158            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
159        }
160    }
161
162    pub fn update(
163        &mut self,
164        id: NodeGraphId<T>,
165        request: RequestUpdate<T>,
166    ) -> Result<(), NodeGraphServerError> {
167        if let Some(graph) = self.graphs.get_mut(&id) {
168            for source in &request.nodes {
169                if graph.node(source.id()).is_none() {
170                    return Err(NodeGraphServerError::NodeNotFound {
171                        graph: id.to_string(),
172                        node: source.id().to_string(),
173                    });
174                }
175            }
176            for source in request.nodes {
177                let id = source.id();
178                *graph.node_mut(id).unwrap() = source;
179            }
180            graph.refresh_spatial_cache();
181            Ok(())
182        } else {
183            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
184        }
185    }
186
187    pub fn clear(&mut self, id: NodeGraphId<T>) -> Result<(), NodeGraphServerError> {
188        if let Some(graph) = self.graphs.get_mut(&id) {
189            graph.clear();
190            graph.refresh_spatial_cache();
191            Ok(())
192        } else {
193            Err(NodeGraphServerError::NodeGraphDoesNotExists(id.to_string()))
194        }
195    }
196
197    pub fn query_all(
198        &self,
199        graph: NodeGraphId<T>,
200    ) -> Result<ResponseQuery<T>, NodeGraphServerError> {
201        if let Some(graph) = self.graphs.get(&graph) {
202            Ok(ResponseQuery {
203                nodes: graph.nodes().cloned().collect(),
204                connections: graph.connections().cloned().collect(),
205            })
206        } else {
207            Err(NodeGraphServerError::NodeGraphDoesNotExists(
208                graph.to_string(),
209            ))
210        }
211    }
212
213    pub fn query_region(
214        &self,
215        graph: NodeGraphId<T>,
216        request: RequestQueryRegion,
217    ) -> Result<ResponseQuery<T>, NodeGraphServerError> {
218        if let Some(graph) = self.graphs.get(&graph) {
219            let RequestQueryRegion {
220                fx,
221                fy,
222                tx,
223                ty,
224                extrude,
225            } = request;
226            let nodes = graph
227                .query_region_nodes(fx, fy, tx, ty, extrude)
228                .filter_map(|id| graph.node(id))
229                .cloned()
230                .collect::<Vec<_>>();
231            let connections = nodes
232                .iter()
233                .flat_map(|node| graph.node_connections(node.id()))
234                .cloned()
235                .collect::<HashSet<_>>()
236                .into_iter()
237                .collect();
238            Ok(ResponseQuery { nodes, connections })
239        } else {
240            Err(NodeGraphServerError::NodeGraphDoesNotExists(
241                graph.to_string(),
242            ))
243        }
244    }
245
246    pub fn suggest_all_nodes(
247        x: i64,
248        y: i64,
249        registry: &Registry,
250    ) -> Vec<ResponseSuggestionNode<T>> {
251        NodeGraph::suggest_all_nodes(x, y, registry)
252    }
253
254    pub fn validate(
255        &self,
256        graph: NodeGraphId<T>,
257        registry: &Registry,
258    ) -> Result<(), NodeGraphServerError> {
259        if let Some(item) = self.graphs.get(&graph) {
260            match item.validate(registry) {
261                Ok(_) => Ok(()),
262                Err(errors) => Err(NodeGraphServerError::ValidationErrors {
263                    graph: graph.to_string(),
264                    errors: errors.into_iter().map(|error| error.to_string()).collect(),
265                }),
266            }
267        } else {
268            Err(NodeGraphServerError::NodeGraphDoesNotExists(
269                graph.to_string(),
270            ))
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use crate::{
        nodes::{
            Node, NodeConnection, NodeDefinition, NodePin, NodeSuggestion, NodeTypeInfo,
            ResponseSuggestionNode,
        },
        server::{NodeGraphServer, NodeGraphServerError, RequestAdd, RequestRemove, RequestUpdate},
    };
    use intuicio_core::{registry::Registry, types::TypeQuery};
    use serde::{Deserialize, Serialize, de::DeserializeOwned};

    /// Placeholder pin type info: renders as an empty string and reports every pair of
    /// instances as compatible.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct TypeInfo;

    impl std::fmt::Display for TypeInfo {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "")
        }
    }

    impl NodeTypeInfo for TypeInfo {
        fn type_query(&'_ self) -> TypeQuery<'_> {
            Default::default()
        }

        fn are_compatible(&self, _: &Self) -> bool {
            true
        }
    }

    /// Minimal node set used by `test_server` to build a small graph.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    enum Nodes {
        Start,
        Expression(i32),
        Result,
        Convert(String),
    }

    impl NodeDefinition for Nodes {
        type TypeInfo = TypeInfo;

        fn node_label(&self, _: &Registry) -> String {
            format!("{self:?}")
        }

        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![],
                Nodes::Expression(_) => vec![NodePin::property("Value")],
                Nodes::Result => vec![
                    NodePin::execute("In", false),
                    NodePin::parameter("Data", TypeInfo),
                ],
                Nodes::Convert(_) => vec![
                    NodePin::execute("In", false),
                    NodePin::property("Name"),
                    NodePin::parameter("Data in", TypeInfo),
                ],
            }
        }

        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![NodePin::execute("Out", false)],
                Nodes::Expression(_) => vec![NodePin::parameter("Data", TypeInfo)],
                Nodes::Result => vec![],
                Nodes::Convert(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data out", TypeInfo),
                ],
            }
        }

        fn node_is_start(&self, _: &Registry) -> bool {
            matches!(self, Self::Start)
        }

        fn node_suggestions(
            _: i64,
            _: i64,
            _: NodeSuggestion<Self>,
            _: &Registry,
        ) -> Vec<ResponseSuggestionNode<Self>> {
            vec![]
        }
    }

    /// Round-trips `value` through JSON, simulating a request sent over the wire.
    fn mock_transfer<T: Serialize + DeserializeOwned>(value: T) -> T {
        let content = serde_json::to_string(&value).unwrap();
        serde_json::from_str(&content).unwrap()
    }

    /// Exercises add / remove / update (including a rejected update) on one graph.
    #[test]
    fn test_server() {
        let registry = Registry::default().with_basic_types();
        let mut server = NodeGraphServer::default();
        let graph = server.create();
        let start = Node::new(0, 0, Nodes::Start);
        let expression = Node::new(0, 0, Nodes::Expression(42));
        let convert = Node::new(0, 0, Nodes::Convert("foo".to_owned()));
        let result = Node::new(0, 0, Nodes::Result);
        // Execution chain Start -> Convert -> Result, with Expression feeding Convert's data input.
        server
            .add(
                graph,
                mock_transfer(RequestAdd {
                    connections: vec![
                        NodeConnection::new(start.id(), convert.id(), "Out", "In"),
                        NodeConnection::new(convert.id(), result.id(), "Out", "In"),
                        NodeConnection::new(expression.id(), convert.id(), "Data", "Data in"),
                    ],
                    nodes: vec![
                        start.clone(),
                        expression.clone(),
                        convert.clone(),
                        result.clone(),
                    ],
                }),
                &registry,
            )
            .unwrap();
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 4);
        assert_eq!(temp.connections.len(), 3);
        // All three connections touch `convert`, so removing it must leave no connections.
        server
            .remove(
                graph,
                RequestRemove {
                    nodes: vec![result.id(), convert.id()],
                    connections: vec![],
                },
                &registry,
            )
            .unwrap();
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 2);
        assert_eq!(temp.connections.len(), 0);
        // `convert` no longer exists, so this update must be rejected as a whole.
        assert!(matches!(
            server.update(
                graph,
                mock_transfer(RequestUpdate {
                    nodes: vec![expression.clone(), convert.clone()],
                }),
            ),
            Err(NodeGraphServerError::NodeNotFound { .. })
        ));
        // The rejected update must not have changed the graph.
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 2);
        assert_eq!(temp.connections.len(), 0);
        // Updating only nodes that still exist succeeds.
        server
            .update(
                graph,
                mock_transfer(RequestUpdate {
                    nodes: vec![expression.clone(), start.clone()],
                }),
            )
            .unwrap();
        let temp = server.query_all(graph).unwrap();
        assert_eq!(temp.nodes.len(), 2);
        assert_eq!(temp.connections.len(), 0);
    }
}