intuicio_nodes/
nodes.rs

1use intuicio_core::{registry::Registry, types::TypeQuery};
2use rstar::{AABB, Envelope, Point, PointDistance, RTree, RTreeObject};
3use serde::{Deserialize, Serialize};
4use serde_intermediate::{
5    Intermediate, de::intermediate::DeserializeMode, error::Result as IntermediateResult,
6};
7use std::{
8    collections::{HashMap, HashSet},
9    error::Error,
10    fmt::Display,
11    hash::{Hash, Hasher},
12};
13use typid::ID;
14
/// Strongly-typed identifier of a [`Node`] holding definition data `T`.
pub type NodeId<T> = ID<Node<T>>;
/// Conversion mode used when reading a [`PropertyValue`] back into a concrete type
/// (an alias of `serde_intermediate`'s `DeserializeMode`).
pub type PropertyCastMode = DeserializeMode;
17
/// Type-erased property value kept in `serde_intermediate` form.
///
/// Built from any `Serialize` value with [`PropertyValue::new`] and read back with
/// [`PropertyValue::get`] or its `get_exact` / `get_interpret` shorthands.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct PropertyValue {
    /// Serialized representation of the property.
    value: Intermediate,
}
22
23impl PropertyValue {
24    pub fn new<T: Serialize>(value: &T) -> IntermediateResult<Self> {
25        Ok(Self {
26            value: serde_intermediate::to_intermediate(value)?,
27        })
28    }
29
30    pub fn get<'a, T: Deserialize<'a>>(&'a self, mode: PropertyCastMode) -> IntermediateResult<T> {
31        serde_intermediate::from_intermediate_as(&self.value, mode)
32    }
33
34    pub fn get_exact<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
35        self.get(PropertyCastMode::Exact)
36    }
37
38    pub fn get_interpret<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
39        self.get(PropertyCastMode::Interpret)
40    }
41
42    pub fn into_inner(self) -> Intermediate {
43        self.value
44    }
45}
46
/// Type descriptor attached to parameter pins ([`NodePin::Parameter`]).
///
/// Implementors are cloneable, comparable, printable (`Display` is used when reporting
/// type mismatches) and serializable so pins can be stored with a graph.
pub trait NodeTypeInfo:
    Clone + std::fmt::Debug + Display + PartialEq + Serialize + for<'de> Deserialize<'de>
{
    /// Returns a `TypeQuery` describing this type (presumably for registry lookups).
    fn type_query(&'_ self) -> TypeQuery<'_>;
    /// Returns `true` when a source pin of type `self` may be connected to a target
    /// pin of type `other`; graph validation calls it as `source.are_compatible(target)`.
    fn are_compatible(&self, other: &Self) -> bool;
}
53
/// Behaviour of a node kind stored in a [`NodeGraph`].
///
/// A definition describes its pins, whether it is a start (entry) node and which nodes
/// to suggest; [`NodeGraph`] uses these to validate connections and to drive
/// [`NodeGraphVisitor`] traversal.
pub trait NodeDefinition: Sized {
    /// Type descriptor used by this definition's parameter pins.
    type TypeInfo: NodeTypeInfo;

    /// Human-readable label; used as the label of a [`ResponseSuggestionNode`].
    fn node_label(&self, registry: &Registry) -> String;
    /// Input pins; a connection's `to_pin` name is resolved against this list.
    fn node_pins_in(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
    /// Output pins; a connection's `from_pin` name is resolved against this list.
    fn node_pins_out(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
    /// Whether this node is an entry point. `NodeGraph::add_node` rejects adding another
    /// start node when one already exists, `NodeGraph::remove_node` refuses to remove one,
    /// and `NodeGraph::visit` begins traversal from start nodes.
    fn node_is_start(&self, registry: &Registry) -> bool;
    /// Produces candidate nodes to place at `(x, y)` for the context in `suggestion`.
    fn node_suggestions(
        x: i64,
        y: i64,
        suggestion: NodeSuggestion<Self>,
        registry: &Registry,
    ) -> Vec<ResponseSuggestionNode<Self>>;

    /// Definition-specific connection check, called on the *target* node with the
    /// *source* node's data after pin kinds and types have been validated.
    /// The default accepts every connection.
    ///
    /// # Errors
    /// Any error returned is reported as [`ConnectionError::Custom`].
    // The default body ignores its parameters.
    #[allow(unused_variables)]
    fn validate_connection(
        &self,
        source: &Self,
        registry: &Registry,
    ) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Reads the property called `name`; the default exposes no properties.
    // The default body ignores its parameters.
    #[allow(unused_variables)]
    fn get_property(&self, name: &str) -> Option<PropertyValue> {
        None
    }

    /// Writes the property called `name`; the default ignores the write.
    // The default body ignores its parameters.
    #[allow(unused_variables)]
    fn set_property(&mut self, name: &str, value: PropertyValue) {}
}
85
/// A named connection point on a node.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// Replaces serde's inferred bounds; `NodeTypeInfo` already implies `Serialize` and
// `Deserialize`.
#[serde(bound = "TI: NodeTypeInfo")]
pub enum NodePin<TI: NodeTypeInfo> {
    /// Execution-flow pin. When `subscope` is set, graph traversal visits the connected
    /// chain as a separate nested scope instead of continuing through it.
    Execute { name: String, subscope: bool },
    /// Data pin carrying a value described by `type_info`.
    Parameter { name: String, type_info: TI },
    /// Pin identified only by name (presumably tied to a node property of that name).
    Property { name: String },
}
93
94impl<TI: NodeTypeInfo> NodePin<TI> {
95    pub fn execute(name: impl ToString, subscope: bool) -> Self {
96        Self::Execute {
97            name: name.to_string(),
98            subscope,
99        }
100    }
101
102    pub fn parameter(name: impl ToString, type_info: TI) -> Self {
103        Self::Parameter {
104            name: name.to_string(),
105            type_info,
106        }
107    }
108
109    pub fn property(name: impl ToString) -> Self {
110        Self::Property {
111            name: name.to_string(),
112        }
113    }
114
115    pub fn is_execute(&self) -> bool {
116        matches!(self, Self::Execute { .. })
117    }
118
119    pub fn is_parameter(&self) -> bool {
120        matches!(self, Self::Parameter { .. })
121    }
122
123    pub fn is_property(&self) -> bool {
124        matches!(self, Self::Property { .. })
125    }
126
127    pub fn name(&self) -> &str {
128        match self {
129            Self::Execute { name, .. }
130            | Self::Parameter { name, .. }
131            | Self::Property { name, .. } => name,
132        }
133    }
134
135    pub fn has_subscope(&self) -> bool {
136        matches!(self, Self::Execute { subscope: true, .. })
137    }
138
139    pub fn type_info(&self) -> Option<&TI> {
140        match self {
141            Self::Parameter { type_info, .. } => Some(type_info),
142            _ => None,
143        }
144    }
145}
146
/// Context passed to [`NodeDefinition::node_suggestions`].
pub enum NodeSuggestion<'a, T: NodeDefinition> {
    /// No pin context (used by `NodeGraph::suggest_all_nodes`).
    All,
    /// Suggestions for input pin `.1` of node data `.0`.
    NodeInputPin(&'a T, &'a NodePin<T::TypeInfo>),
    /// Suggestions for output pin `.1` of node data `.0`.
    NodeOutputPin(&'a T, &'a NodePin<T::TypeInfo>),
}
152
/// A node offered as a suggestion, grouped under `category`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseSuggestionNode<T: NodeDefinition> {
    /// Grouping name for the suggestion.
    pub category: String,
    /// Display label; [`ResponseSuggestionNode::new`] fills it from `node_label`.
    pub label: String,
    /// The suggested node.
    pub node: Node<T>,
}
159
160impl<T: NodeDefinition> ResponseSuggestionNode<T> {
161    pub fn new(category: impl ToString, node: Node<T>, registry: &Registry) -> Self {
162        Self {
163            category: category.to_string(),
164            label: node.data.node_label(registry),
165            node,
166        }
167    }
168}
169
/// A node instance: an identifier, a position and definition-specific data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Node<T: NodeDefinition> {
    /// Identifier assigned at construction; read it through [`Node::id`].
    id: NodeId<T>,
    /// Horizontal position, also used by the graph's spatial index.
    pub x: i64,
    /// Vertical position, also used by the graph's spatial index.
    pub y: i64,
    /// Definition-specific payload.
    pub data: T,
}
177
178impl<T: NodeDefinition> Node<T> {
179    pub fn new(x: i64, y: i64, data: T) -> Self {
180        Self {
181            id: Default::default(),
182            x,
183            y,
184            data,
185        }
186    }
187
188    pub fn id(&self) -> NodeId<T> {
189        self.id
190    }
191}
192
/// A directed link from an output pin of one node to an input pin of another.
#[derive(Clone, Serialize, Deserialize)]
pub struct NodeConnection<T: NodeDefinition> {
    /// Source node.
    pub from_node: NodeId<T>,
    /// Target node.
    pub to_node: NodeId<T>,
    /// Name of the source node's output pin.
    pub from_pin: String,
    /// Name of the target node's input pin.
    pub to_pin: String,
}
200
201impl<T: NodeDefinition> NodeConnection<T> {
202    pub fn new(from_node: NodeId<T>, to_node: NodeId<T>, from_pin: &str, to_pin: &str) -> Self {
203        Self {
204            from_node,
205            to_node,
206            from_pin: from_pin.to_owned(),
207            to_pin: to_pin.to_owned(),
208        }
209    }
210}
211
212impl<T: NodeDefinition> std::fmt::Debug for NodeConnection<T> {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        f.debug_struct("NodeConnection")
215            .field("from_node", &self.from_node)
216            .field("to_node", &self.to_node)
217            .field("from_pin", &self.from_pin)
218            .field("to_pin", &self.to_pin)
219            .finish()
220    }
221}
222
223impl<T: NodeDefinition> PartialEq for NodeConnection<T> {
224    fn eq(&self, other: &Self) -> bool {
225        self.from_node == other.from_node
226            && self.to_node == other.to_node
227            && self.from_pin == other.from_pin
228            && self.to_pin == other.to_pin
229    }
230}
231
232impl<T: NodeDefinition> Eq for NodeConnection<T> {}
233
234impl<T: NodeDefinition> Hash for NodeConnection<T> {
235    fn hash<H: Hasher>(&self, state: &mut H) {
236        self.from_node.hash(state);
237        self.to_node.hash(state);
238        self.from_pin.hash(state);
239        self.to_pin.hash(state);
240    }
241}
242
/// Reasons a single connection (or the graph's connection structure) is invalid.
#[derive(Debug)]
pub enum ConnectionError {
    /// The connection's source and target are the same node (holds the node id).
    InternalConnection(String),
    /// The source node id does not exist in the graph.
    SourceNodeNotFound(String),
    /// The target node id does not exist in the graph.
    TargetNodeNotFound(String),
    /// Neither the source nor the target node exists.
    NodesNotFound {
        from: String,
        to: String,
    },
    /// The source node has no output pin with this name.
    SourcePinNotFound {
        node: String,
        pin: String,
    },
    /// The target node has no input pin with this name.
    TargetPinNotFound {
        node: String,
        pin: String,
    },
    /// Both pins are parameters but `NodeTypeInfo::are_compatible` rejected the types.
    MismatchTypes {
        from_node: String,
        from_pin: String,
        from_type_info: String,
        to_node: String,
        to_pin: String,
        to_type_info: String,
    },
    /// The two pins are of different kinds (execute / parameter / property).
    MismatchPins {
        from_node: String,
        from_pin: String,
        to_node: String,
        to_pin: String,
    },
    /// A directed cycle was found; holds the id of the node that closes it.
    CycleNodeFound(String),
    /// Error returned by `NodeDefinition::validate_connection`.
    Custom(Box<dyn Error>),
}
277
278impl std::fmt::Display for ConnectionError {
279    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
280        match self {
281            Self::InternalConnection(node) => {
282                write!(f, "Trying to connect node: {node} to itself")
283            }
284            Self::SourceNodeNotFound(node) => write!(f, "Source node: {node} not found"),
285            Self::TargetNodeNotFound(node) => write!(f, "Target node: {node} not found"),
286            Self::NodesNotFound { from, to } => {
287                write!(f, "Source: {from} and target: {to} nodes not found")
288            }
289            Self::SourcePinNotFound { node, pin } => {
290                write!(f, "Source pin: {pin} for node: {node} not found")
291            }
292            Self::TargetPinNotFound { node, pin } => {
293                write!(f, "Target pin: {pin} for node: {node} not found")
294            }
295            Self::MismatchTypes {
296                from_node,
297                from_pin,
298                from_type_info,
299                to_node,
300                to_pin,
301                to_type_info,
302            } => {
303                write!(
304                    f,
305                    "Source type: {from_type_info} of pin: {from_pin} for node: {from_node} does not match target type: {to_type_info} of pin: {to_pin} for node: {to_node}"
306                )
307            }
308            Self::MismatchPins {
309                from_node,
310                from_pin,
311                to_node,
312                to_pin,
313            } => {
314                write!(
315                    f,
316                    "Source pin: {from_pin} kind for node: {from_node} does not match target pin: {to_pin} kind for node: {to_node}"
317                )
318            }
319            Self::CycleNodeFound(node) => write!(f, "Found cycle node: {node}"),
320            Self::Custom(error) => error.fmt(f),
321        }
322    }
323}
324
/// Standard error integration; relies on the default (no `source`) implementation.
impl Error for ConnectionError {}
326
/// Graph-level validation errors.
#[derive(Debug)]
pub enum NodeGraphError {
    /// A connection problem (see [`ConnectionError`]).
    Connection(ConnectionError),
    /// A function input name appears more than once with different types.
    /// Not produced in this file's visible code — presumably raised by graph consumers.
    DuplicateFunctionInputNames(String),
    /// A function output name appears more than once with different types.
    /// Not produced in this file's visible code — presumably raised by graph consumers.
    DuplicateFunctionOutputNames(String),
}
333
334impl std::fmt::Display for NodeGraphError {
335    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336        match self {
337            Self::Connection(connection) => connection.fmt(f),
338            Self::DuplicateFunctionInputNames(name) => {
339                write!(
340                    f,
341                    "Found duplicate `{name}` function input with different types"
342                )
343            }
344            Self::DuplicateFunctionOutputNames(name) => {
345                write!(
346                    f,
347                    "Found duplicate `{name}` function output with different types"
348                )
349            }
350        }
351    }
352}
353
/// Standard error integration; relies on the default (no `source`) implementation.
impl Error for NodeGraphError {}
355
/// Copy of a node's id and position stored in the graph's R-tree spatial index.
#[derive(Clone)]
struct SpatialNode<T: NodeDefinition> {
    id: NodeId<T>,
    x: i64,
    y: i64,
}
362
363impl<T: NodeDefinition> RTreeObject for SpatialNode<T> {
364    type Envelope = AABB<[i64; 2]>;
365
366    fn envelope(&self) -> Self::Envelope {
367        AABB::from_point([self.x, self.y])
368    }
369}
370
371impl<T: NodeDefinition> PointDistance for SpatialNode<T> {
372    fn distance_2(
373        &self,
374        point: &<Self::Envelope as Envelope>::Point,
375    ) -> <<Self::Envelope as Envelope>::Point as Point>::Scalar {
376        let dx = self.x - point[0];
377        let dy = self.y - point[1];
378        dx * dx + dy * dy
379    }
380}
381
/// A graph of nodes and the connections between their pins.
#[derive(Clone, Serialize, Deserialize)]
pub struct NodeGraph<T: NodeDefinition> {
    nodes: Vec<Node<T>>,
    connections: Vec<NodeConnection<T>>,
    // Derived spatial index over node positions: not serialized (empty after
    // deserialization) and populated from `nodes` only by `refresh_spatial_cache`.
    #[serde(skip, default)]
    rtree: RTree<SpatialNode<T>>,
}
389
390impl<T: NodeDefinition> Default for NodeGraph<T> {
391    fn default() -> Self {
392        Self {
393            nodes: vec![],
394            connections: vec![],
395            rtree: Default::default(),
396        }
397    }
398}
399
400impl<T: NodeDefinition> NodeGraph<T> {
401    pub fn clear(&mut self) {
402        self.nodes.clear();
403        self.connections.clear();
404    }
405
406    pub fn refresh_spatial_cache(&mut self) {
407        self.rtree = RTree::bulk_load(
408            self.nodes
409                .iter()
410                .map(|node| SpatialNode {
411                    id: node.id,
412                    x: node.x,
413                    y: node.y,
414                })
415                .collect(),
416        );
417    }
418
419    pub fn query_nearest_nodes(&self, x: i64, y: i64) -> impl Iterator<Item = NodeId<T>> + '_ {
420        self.rtree
421            .nearest_neighbor_iter(&[x, y])
422            .map(|node| node.id)
423    }
424
425    pub fn query_region_nodes(
426        &self,
427        fx: i64,
428        fy: i64,
429        tx: i64,
430        ty: i64,
431        extrude: i64,
432    ) -> impl Iterator<Item = NodeId<T>> + '_ {
433        self.rtree
434            .locate_in_envelope(&AABB::from_corners(
435                [fx - extrude, fy - extrude],
436                [tx - extrude, ty - extrude],
437            ))
438            .map(|node| node.id)
439    }
440
441    pub fn suggest_all_nodes(
442        x: i64,
443        y: i64,
444        registry: &Registry,
445    ) -> Vec<ResponseSuggestionNode<T>> {
446        T::node_suggestions(x, y, NodeSuggestion::All, registry)
447    }
448
449    pub fn suggest_node_input_pin(
450        &self,
451        x: i64,
452        y: i64,
453        id: NodeId<T>,
454        name: &str,
455        registry: &Registry,
456    ) -> Vec<ResponseSuggestionNode<T>> {
457        if let Some(node) = self.node(id)
458            && let Some(pin) = node
459                .data
460                .node_pins_in(registry)
461                .into_iter()
462                .find(|pin| pin.name() == name)
463        {
464            return T::node_suggestions(
465                x,
466                y,
467                NodeSuggestion::NodeInputPin(&node.data, &pin),
468                registry,
469            );
470        }
471        vec![]
472    }
473
474    pub fn suggest_node_output_pin(
475        &self,
476        x: i64,
477        y: i64,
478        id: NodeId<T>,
479        name: &str,
480        registry: &Registry,
481    ) -> Vec<ResponseSuggestionNode<T>> {
482        if let Some(node) = self.node(id)
483            && let Some(pin) = node
484                .data
485                .node_pins_out(registry)
486                .into_iter()
487                .find(|pin| pin.name() == name)
488        {
489            return T::node_suggestions(
490                x,
491                y,
492                NodeSuggestion::NodeOutputPin(&node.data, &pin),
493                registry,
494            );
495        }
496        vec![]
497    }
498
499    pub fn node(&self, id: NodeId<T>) -> Option<&Node<T>> {
500        self.nodes.iter().find(|node| node.id == id)
501    }
502
503    pub fn node_mut(&mut self, id: NodeId<T>) -> Option<&mut Node<T>> {
504        self.nodes.iter_mut().find(|node| node.id == id)
505    }
506
507    pub fn nodes(&self) -> impl Iterator<Item = &Node<T>> {
508        self.nodes.iter()
509    }
510
511    pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<T>> {
512        self.nodes.iter_mut()
513    }
514
515    pub fn add_node(&mut self, node: Node<T>, registry: &Registry) -> Option<NodeId<T>> {
516        if node.data.node_is_start(registry)
517            && self
518                .nodes
519                .iter()
520                .any(|node| node.data.node_is_start(registry))
521        {
522            return None;
523        }
524        let id = node.id;
525        if let Some(index) = self.nodes.iter().position(|node| node.id == id) {
526            self.nodes.swap_remove(index);
527        }
528        self.nodes.push(node);
529        Some(id)
530    }
531
532    pub fn remove_node(&mut self, id: NodeId<T>, registry: &Registry) -> Option<Node<T>> {
533        if let Some(index) = self
534            .nodes
535            .iter()
536            .position(|node| node.id == id && !node.data.node_is_start(registry))
537        {
538            self.disconnect_node(id, None);
539            Some(self.nodes.swap_remove(index))
540        } else {
541            None
542        }
543    }
544
545    pub fn connect_nodes(&mut self, connection: NodeConnection<T>) {
546        if !self.connections.iter().any(|other| &connection == other) {
547            self.disconnect_node(connection.from_node, Some(&connection.from_pin));
548            self.disconnect_node(connection.to_node, Some(&connection.to_pin));
549            self.connections.push(connection);
550        }
551    }
552
553    pub fn disconnect_nodes(
554        &mut self,
555        from_node: NodeId<T>,
556        to_node: NodeId<T>,
557        from_pin: &str,
558        to_pin: &str,
559    ) {
560        if let Some(index) = self.connections.iter().position(|connection| {
561            connection.from_node == from_node
562                && connection.to_node == to_node
563                && connection.from_pin == from_pin
564                && connection.to_pin == to_pin
565        }) {
566            self.connections.swap_remove(index);
567        }
568    }
569
570    pub fn disconnect_node(&mut self, node: NodeId<T>, pin: Option<&str>) {
571        let to_remove = self
572            .connections
573            .iter()
574            .enumerate()
575            .filter_map(|(index, connection)| {
576                if let Some(pin) = pin {
577                    if connection.from_node == node && connection.from_pin == pin {
578                        return Some(index);
579                    }
580                    if connection.to_node == node && connection.to_pin == pin {
581                        return Some(index);
582                    }
583                } else if connection.from_node == node || connection.to_node == node {
584                    return Some(index);
585                }
586                None
587            })
588            .collect::<Vec<_>>();
589        for index in to_remove.into_iter().rev() {
590            self.connections.swap_remove(index);
591        }
592    }
593
594    pub fn connections(&self) -> impl Iterator<Item = &NodeConnection<T>> {
595        self.connections.iter()
596    }
597
598    pub fn node_connections(&self, id: NodeId<T>) -> impl Iterator<Item = &NodeConnection<T>> {
599        self.connections
600            .iter()
601            .filter(move |connection| connection.from_node == id || connection.to_node == id)
602    }
603
604    pub fn node_connections_in<'a>(
605        &'a self,
606        id: NodeId<T>,
607        pin: Option<&'a str>,
608    ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
609        self.connections.iter().filter(move |connection| {
610            connection.to_node == id && pin.map(|pin| connection.to_pin == pin).unwrap_or(true)
611        })
612    }
613
614    pub fn node_connections_out<'a>(
615        &'a self,
616        id: NodeId<T>,
617        pin: Option<&'a str>,
618    ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
619        self.connections.iter().filter(move |connection| {
620            connection.from_node == id && pin.map(|pin| connection.from_pin == pin).unwrap_or(true)
621        })
622    }
623
624    pub fn node_neighbors_in<'a>(
625        &'a self,
626        id: NodeId<T>,
627        pin: Option<&'a str>,
628    ) -> impl Iterator<Item = NodeId<T>> + 'a {
629        self.node_connections_in(id, pin)
630            .map(move |connection| connection.from_node)
631    }
632
633    pub fn node_neighbors_out<'a>(
634        &'a self,
635        id: NodeId<T>,
636        pin: Option<&'a str>,
637    ) -> impl Iterator<Item = NodeId<T>> + 'a {
638        self.node_connections_out(id, pin)
639            .map(move |connection| connection.to_node)
640    }
641
642    pub fn validate(&self, registry: &Registry) -> Result<(), Vec<NodeGraphError>> {
643        let mut errors = self
644            .connections
645            .iter()
646            .filter_map(|connection| self.validate_connection(connection, registry))
647            .map(NodeGraphError::Connection)
648            .collect::<Vec<_>>();
649        if let Some(error) = self.detect_cycles() {
650            errors.push(NodeGraphError::Connection(error));
651        }
652        if errors.is_empty() {
653            Ok(())
654        } else {
655            Err(errors)
656        }
657    }
658
659    fn validate_connection(
660        &self,
661        connection: &NodeConnection<T>,
662        registry: &Registry,
663    ) -> Option<ConnectionError> {
664        if connection.from_node == connection.to_node {
665            return Some(ConnectionError::InternalConnection(
666                connection.from_node.to_string(),
667            ));
668        }
669        let from = self
670            .nodes
671            .iter()
672            .find(|node| node.id == connection.from_node);
673        let to = self.nodes.iter().find(|node| node.id == connection.to_node);
674        let (from_node, to_node) = match (from, to) {
675            (Some(from), Some(to)) => (from, to),
676            (Some(_), None) => {
677                return Some(ConnectionError::TargetNodeNotFound(
678                    connection.to_node.to_string(),
679                ));
680            }
681            (None, Some(_)) => {
682                return Some(ConnectionError::SourceNodeNotFound(
683                    connection.from_node.to_string(),
684                ));
685            }
686            (None, None) => {
687                return Some(ConnectionError::NodesNotFound {
688                    from: connection.from_node.to_string(),
689                    to: connection.to_node.to_string(),
690                });
691            }
692        };
693        let from_pins_out = from_node.data.node_pins_out(registry);
694        let from_pin = match from_pins_out
695            .iter()
696            .find(|pin| pin.name() == connection.from_pin)
697        {
698            Some(pin) => pin,
699            None => {
700                return Some(ConnectionError::SourcePinNotFound {
701                    node: connection.from_node.to_string(),
702                    pin: connection.from_pin.to_owned(),
703                });
704            }
705        };
706        let to_pins_in = to_node.data.node_pins_in(registry);
707        let to_pin = match to_pins_in
708            .iter()
709            .find(|pin| pin.name() == connection.to_pin)
710        {
711            Some(pin) => pin,
712            None => {
713                return Some(ConnectionError::TargetPinNotFound {
714                    node: connection.to_node.to_string(),
715                    pin: connection.to_pin.to_owned(),
716                });
717            }
718        };
719        match (from_pin, to_pin) {
720            (NodePin::Execute { .. }, NodePin::Execute { .. }) => {}
721            (NodePin::Parameter { type_info: a, .. }, NodePin::Parameter { type_info: b, .. }) => {
722                if !a.are_compatible(b) {
723                    return Some(ConnectionError::MismatchTypes {
724                        from_node: connection.from_node.to_string(),
725                        from_pin: connection.from_pin.to_owned(),
726                        to_node: connection.to_node.to_string(),
727                        to_pin: connection.to_pin.to_owned(),
728                        from_type_info: a.to_string(),
729                        to_type_info: b.to_string(),
730                    });
731                }
732            }
733            (NodePin::Property { .. }, NodePin::Property { .. }) => {}
734            _ => {
735                return Some(ConnectionError::MismatchPins {
736                    from_node: connection.from_node.to_string(),
737                    from_pin: connection.from_pin.to_owned(),
738                    to_node: connection.to_node.to_string(),
739                    to_pin: connection.to_pin.to_owned(),
740                });
741            }
742        }
743        if let Err(error) = to_node.data.validate_connection(&from_node.data, registry) {
744            return Some(ConnectionError::Custom(error));
745        }
746        None
747    }
748
749    fn detect_cycles(&self) -> Option<ConnectionError> {
750        let mut visited = HashSet::with_capacity(self.nodes.len());
751        let mut available = self.nodes.iter().map(|node| node.id).collect::<Vec<_>>();
752        while let Some(id) = available.first() {
753            if let Some(error) = self.detect_cycle(*id, &mut available, &mut visited) {
754                return Some(error);
755            }
756            available.swap_remove(0);
757        }
758        None
759    }
760
761    fn detect_cycle(
762        &self,
763        id: NodeId<T>,
764        available: &mut Vec<NodeId<T>>,
765        visited: &mut HashSet<NodeId<T>>,
766    ) -> Option<ConnectionError> {
767        if visited.contains(&id) {
768            return Some(ConnectionError::CycleNodeFound(id.to_string()));
769        }
770        visited.insert(id);
771        for id in self.node_neighbors_out(id, None) {
772            if let Some(index) = available.iter().position(|item| item == &id) {
773                available.swap_remove(index);
774                if let Some(error) = self.detect_cycle(id, available, visited) {
775                    return Some(error);
776                }
777            }
778        }
779        None
780    }
781
782    pub fn visit<V: NodeGraphVisitor<T>>(
783        &self,
784        visitor: &mut V,
785        registry: &Registry,
786    ) -> Vec<V::Output> {
787        let starts = self
788            .nodes
789            .iter()
790            .filter(|node| node.data.node_is_start(registry))
791            .map(|node| node.id)
792            .collect::<HashSet<_>>();
793        let mut result = Vec::with_capacity(self.nodes.len());
794        for id in starts {
795            self.visit_statement(id, &mut result, visitor, registry);
796        }
797        result
798    }
799
800    fn visit_statement<V: NodeGraphVisitor<T>>(
801        &self,
802        id: NodeId<T>,
803        result: &mut Vec<V::Output>,
804        visitor: &mut V,
805        registry: &Registry,
806    ) {
807        if let Some(node) = self.node(id) {
808            let inputs = node
809                .data
810                .node_pins_in(registry)
811                .into_iter()
812                .filter(|pin| pin.is_parameter())
813                .filter_map(|pin| {
814                    self.node_neighbors_in(id, Some(pin.name()))
815                        .next()
816                        .map(|id| (pin.name().to_owned(), id))
817                })
818                .filter_map(|(name, id)| {
819                    self.visit_expression(id, visitor, registry)
820                        .map(|input| (name, input))
821                })
822                .collect();
823            let pins_out = node.data.node_pins_out(registry);
824            let scopes = pins_out
825                .iter()
826                .filter(|pin| pin.has_subscope())
827                .filter_map(|pin| {
828                    let id = self.node_neighbors_out(id, Some(pin.name())).next()?;
829                    Some((id, pin.name().to_owned()))
830                })
831                .map(|(id, name)| {
832                    let mut result = Vec::with_capacity(self.nodes.len());
833                    self.visit_statement(id, &mut result, visitor, registry);
834                    (name, result)
835                })
836                .collect();
837            if visitor.visit_statement(node, inputs, scopes, result) {
838                for pin in pins_out {
839                    if pin.is_execute() && !pin.has_subscope() {
840                        for id in self.node_neighbors_out(id, Some(pin.name())) {
841                            self.visit_statement(id, result, visitor, registry);
842                        }
843                    }
844                }
845            }
846        }
847    }
848
849    fn visit_expression<V: NodeGraphVisitor<T>>(
850        &self,
851        id: NodeId<T>,
852        visitor: &mut V,
853        registry: &Registry,
854    ) -> Option<V::Input> {
855        if let Some(node) = self.node(id) {
856            let inputs = node
857                .data
858                .node_pins_in(registry)
859                .into_iter()
860                .filter(|pin| pin.is_parameter())
861                .filter_map(|pin| {
862                    self.node_neighbors_in(id, Some(pin.name()))
863                        .next()
864                        .map(|id| (pin.name().to_owned(), id))
865                })
866                .filter_map(|(name, id)| {
867                    self.visit_expression(id, visitor, registry)
868                        .map(|input| (name, input))
869                })
870                .collect();
871            return visitor.visit_expression(node, inputs);
872        }
873        None
874    }
875}
876
877impl<T: NodeDefinition + std::fmt::Debug> std::fmt::Debug for NodeGraph<T> {
878    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
879        f.debug_struct("NodeGraph")
880            .field("nodes", &self.nodes)
881            .field("connections", &self.connections)
882            .finish()
883    }
884}
885
/// Callbacks used by [`NodeGraph::visit`] to turn a graph into outputs.
///
/// Nodes reached through execute pins are visited as statements; nodes feeding
/// parameter input pins are visited as expressions.
pub trait NodeGraphVisitor<T: NodeDefinition> {
    /// Value produced by visiting an expression node.
    type Input;
    /// Value collected into the traversal result.
    type Output;

    /// Visits a statement node.
    ///
    /// `inputs` maps parameter input pin names to the values produced for the first node
    /// connected to each pin (pins without a value are absent). `scopes` maps each
    /// connected subscope execute output pin name to the outputs collected by visiting
    /// its chain. Implementations append their outputs to `result`. Returning `true`
    /// continues traversal through the node's non-subscope execute output pins.
    fn visit_statement(
        &mut self,
        node: &Node<T>,
        inputs: HashMap<String, Self::Input>,
        scopes: HashMap<String, Vec<Self::Output>>,
        result: &mut Vec<Self::Output>,
    ) -> bool;

    /// Visits an expression node with its already-evaluated parameter `inputs`;
    /// returning `None` leaves the value out of the consuming node's inputs.
    fn visit_expression(
        &mut self,
        node: &Node<T>,
        inputs: HashMap<String, Self::Input>,
    ) -> Option<Self::Input>;
}
904
905#[cfg(test)]
906mod tests {
907    use crate::nodes::{
908        Node, NodeConnection, NodeDefinition, NodeGraph, NodeGraphVisitor, NodePin, NodeSuggestion,
909        NodeTypeInfo, PropertyValue, ResponseSuggestionNode,
910    };
911    use intuicio_core::{registry::Registry, types::TypeQuery};
912    use std::collections::HashMap;
913
914    #[derive(Debug, Clone, PartialEq)]
915    enum Script {
916        Literal(i32),
917        Return,
918        Call(String),
919        Scope(Vec<Script>),
920    }
921
922    impl NodeTypeInfo for String {
923        fn type_query(&'_ self) -> TypeQuery<'_> {
924            TypeQuery {
925                name: Some(self.into()),
926                ..Default::default()
927            }
928        }
929
930        fn are_compatible(&self, other: &Self) -> bool {
931            self == other
932        }
933    }
934
935    #[derive(Debug, Clone)]
936    enum Nodes {
937        Start,
938        Expression(i32),
939        Result,
940        Convert(String),
941        Child,
942    }
943
944    impl NodeDefinition for Nodes {
945        type TypeInfo = String;
946
947        fn node_label(&self, _: &Registry) -> String {
948            format!("{self:?}")
949        }
950
951        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
952            match self {
953                Nodes::Start => vec![],
954                Nodes::Expression(_) => {
955                    vec![NodePin::execute("In", false), NodePin::property("Value")]
956                }
957                Nodes::Result => vec![
958                    NodePin::execute("In", false),
959                    NodePin::parameter("Data", "i32".to_owned()),
960                ],
961                Nodes::Convert(_) => vec![
962                    NodePin::execute("In", false),
963                    NodePin::property("Name"),
964                    NodePin::parameter("Data in", "i32".to_owned()),
965                ],
966                Nodes::Child => vec![NodePin::execute("In", false)],
967            }
968        }
969
970        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
971            match self {
972                Nodes::Start => vec![NodePin::execute("Out", false)],
973                Nodes::Expression(_) => vec![
974                    NodePin::execute("Out", false),
975                    NodePin::parameter("Data", "i32".to_owned()),
976                ],
977                Nodes::Result => vec![],
978                Nodes::Convert(_) => vec![
979                    NodePin::execute("Out", false),
980                    NodePin::parameter("Data out", "i32".to_owned()),
981                ],
982                Nodes::Child => vec![
983                    NodePin::execute("Out", false),
984                    NodePin::execute("Body", true),
985                ],
986            }
987        }
988
989        fn node_is_start(&self, _: &Registry) -> bool {
990            matches!(self, Self::Start)
991        }
992
993        fn node_suggestions(
994            _: i64,
995            _: i64,
996            _: NodeSuggestion<Self>,
997            _: &Registry,
998        ) -> Vec<ResponseSuggestionNode<Self>> {
999            vec![]
1000        }
1001
1002        fn get_property(&self, property_name: &str) -> Option<PropertyValue> {
1003            match self {
1004                Nodes::Expression(value) => match property_name {
1005                    "Value" => PropertyValue::new(value).ok(),
1006                    _ => None,
1007                },
1008                Nodes::Convert(name) => match property_name {
1009                    "Name" => PropertyValue::new(name).ok(),
1010                    _ => None,
1011                },
1012                _ => None,
1013            }
1014        }
1015
1016        fn set_property(&mut self, property_name: &str, property_value: PropertyValue) {
1017            #[allow(clippy::single_match)]
1018            match self {
1019                Nodes::Expression(value) => match property_name {
1020                    "Value" => {
1021                        if let Ok(v) = property_value.get_exact::<i32>() {
1022                            *value = v;
1023                        }
1024                    }
1025                    _ => {}
1026                },
1027                Nodes::Convert(name) => {
1028                    if let Ok(v) = property_value.get_exact::<String>() {
1029                        *name = v;
1030                    }
1031                }
1032                _ => {}
1033            }
1034        }
1035    }
1036
1037    struct CompileNodesToScript;
1038
1039    impl NodeGraphVisitor<Nodes> for CompileNodesToScript {
1040        type Input = ();
1041        type Output = Script;
1042
1043        fn visit_statement(
1044            &mut self,
1045            node: &Node<Nodes>,
1046            _: HashMap<String, Self::Input>,
1047            mut scopes: HashMap<String, Vec<Self::Output>>,
1048            result: &mut Vec<Self::Output>,
1049        ) -> bool {
1050            match &node.data {
1051                Nodes::Result => result.push(Script::Return),
1052                Nodes::Convert(name) => result.push(Script::Call(name.to_owned())),
1053                Nodes::Child => {
1054                    if let Some(body) = scopes.remove("Body") {
1055                        result.push(Script::Scope(body));
1056                    }
1057                }
1058                Nodes::Expression(value) => result.push(Script::Literal(*value)),
1059                _ => {}
1060            }
1061            true
1062        }
1063
1064        fn visit_expression(
1065            &mut self,
1066            _: &Node<Nodes>,
1067            _: HashMap<String, Self::Input>,
1068        ) -> Option<Self::Input> {
1069            None
1070        }
1071    }
1072
1073    #[test]
1074    fn test_nodes() {
1075        let registry = Registry::default().with_basic_types();
1076        let mut graph = NodeGraph::default();
1077        let start = graph
1078            .add_node(Node::new(0, 0, Nodes::Start), &registry)
1079            .unwrap();
1080        let expression_child = graph
1081            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
1082            .unwrap();
1083        let convert_child = graph
1084            .add_node(Node::new(0, 0, Nodes::Convert("foo".to_owned())), &registry)
1085            .unwrap();
1086        let result_child = graph
1087            .add_node(Node::new(0, 0, Nodes::Result), &registry)
1088            .unwrap();
1089        let child = graph
1090            .add_node(Node::new(0, 0, Nodes::Child), &registry)
1091            .unwrap();
1092        let expression = graph
1093            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
1094            .unwrap();
1095        let convert = graph
1096            .add_node(Node::new(0, 0, Nodes::Convert("bar".to_owned())), &registry)
1097            .unwrap();
1098        let result = graph
1099            .add_node(Node::new(0, 0, Nodes::Result), &registry)
1100            .unwrap();
1101        graph.connect_nodes(NodeConnection::new(start, child, "Out", "In"));
1102        graph.connect_nodes(NodeConnection::new(child, expression_child, "Body", "In"));
1103        graph.connect_nodes(NodeConnection::new(
1104            expression_child,
1105            convert_child,
1106            "Out",
1107            "In",
1108        ));
1109        graph.connect_nodes(NodeConnection::new(
1110            expression_child,
1111            convert_child,
1112            "Data",
1113            "Data in",
1114        ));
1115        graph.connect_nodes(NodeConnection::new(
1116            convert_child,
1117            result_child,
1118            "Out",
1119            "In",
1120        ));
1121        graph.connect_nodes(NodeConnection::new(
1122            convert_child,
1123            result_child,
1124            "Data out",
1125            "Data",
1126        ));
1127        graph.connect_nodes(NodeConnection::new(child, expression, "Out", "In"));
1128        graph.connect_nodes(NodeConnection::new(expression, convert, "Out", "In"));
1129        graph.connect_nodes(NodeConnection::new(expression, convert, "Data", "Data in"));
1130        graph.connect_nodes(NodeConnection::new(convert, result, "Out", "In"));
1131        graph.connect_nodes(NodeConnection::new(convert, result, "Data out", "Data"));
1132        graph.validate(&registry).unwrap();
1133        assert_eq!(
1134            graph.visit(&mut CompileNodesToScript, &registry),
1135            vec![
1136                Script::Scope(vec![
1137                    Script::Literal(42),
1138                    Script::Call("foo".to_owned()),
1139                    Script::Return
1140                ]),
1141                Script::Literal(42),
1142                Script::Call("bar".to_owned()),
1143                Script::Return
1144            ]
1145        );
1146        assert_eq!(
1147            graph
1148                .node(expression)
1149                .unwrap()
1150                .data
1151                .get_property("Value")
1152                .unwrap(),
1153            PropertyValue::new(&42i32).unwrap(),
1154        );
1155        graph
1156            .node_mut(expression)
1157            .unwrap()
1158            .data
1159            .set_property("Value", PropertyValue::new(&10i32).unwrap());
1160        assert_eq!(
1161            graph
1162                .node(expression)
1163                .unwrap()
1164                .data
1165                .get_property("Value")
1166                .unwrap(),
1167            PropertyValue::new(&10i32).unwrap(),
1168        );
1169    }
1170}