intuicio_nodes/
nodes.rs

1use intuicio_core::{registry::Registry, types::TypeQuery};
2use rstar::{AABB, Envelope, Point, PointDistance, RTree, RTreeObject};
3use serde::{Deserialize, Serialize};
4use serde_intermediate::{
5    Intermediate, de::intermediate::DeserializeMode, error::Result as IntermediateResult,
6};
7use std::{
8    collections::{HashMap, HashSet},
9    error::Error,
10    fmt::Display,
11    hash::{Hash, Hasher},
12};
13use typid::ID;
14
15pub type NodeId<T> = ID<Node<T>>;
16pub type PropertyCastMode = DeserializeMode;
17
18#[derive(Debug, Default, Clone, PartialEq)]
19pub struct PropertyValue {
20    value: Intermediate,
21}
22
23impl PropertyValue {
24    pub fn new<T: Serialize>(value: &T) -> IntermediateResult<Self> {
25        Ok(Self {
26            value: serde_intermediate::to_intermediate(value)?,
27        })
28    }
29
30    pub fn get<'a, T: Deserialize<'a>>(&'a self, mode: PropertyCastMode) -> IntermediateResult<T> {
31        serde_intermediate::from_intermediate_as(&self.value, mode)
32    }
33
34    pub fn get_exact<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
35        self.get(PropertyCastMode::Exact)
36    }
37
38    pub fn get_interpret<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
39        self.get(PropertyCastMode::Interpret)
40    }
41
42    pub fn into_inner(self) -> Intermediate {
43        self.value
44    }
45}
46
47pub trait NodeTypeInfo:
48    Clone + std::fmt::Debug + Display + PartialEq + Serialize + for<'de> Deserialize<'de>
49{
50    fn type_query(&self) -> TypeQuery;
51    fn are_compatible(&self, other: &Self) -> bool;
52}
53
54pub trait NodeDefinition: Sized {
55    type TypeInfo: NodeTypeInfo;
56
57    fn node_label(&self, registry: &Registry) -> String;
58    fn node_pins_in(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
59    fn node_pins_out(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
60    fn node_is_start(&self, registry: &Registry) -> bool;
61    fn node_suggestions(
62        x: i64,
63        y: i64,
64        suggestion: NodeSuggestion<Self>,
65        registry: &Registry,
66    ) -> Vec<ResponseSuggestionNode<Self>>;
67
68    #[allow(unused_variables)]
69    fn validate_connection(
70        &self,
71        source: &Self,
72        registry: &Registry,
73    ) -> Result<(), Box<dyn Error>> {
74        Ok(())
75    }
76
77    #[allow(unused_variables)]
78    fn get_property(&self, name: &str) -> Option<PropertyValue> {
79        None
80    }
81
82    #[allow(unused_variables)]
83    fn set_property(&mut self, name: &str, value: PropertyValue) {}
84}
85
86#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
87#[serde(bound = "TI: NodeTypeInfo")]
88pub enum NodePin<TI: NodeTypeInfo> {
89    Execute { name: String, subscope: bool },
90    Parameter { name: String, type_info: TI },
91    Property { name: String },
92}
93
94impl<TI: NodeTypeInfo> NodePin<TI> {
95    pub fn execute(name: impl ToString, subscope: bool) -> Self {
96        Self::Execute {
97            name: name.to_string(),
98            subscope,
99        }
100    }
101
102    pub fn parameter(name: impl ToString, type_info: TI) -> Self {
103        Self::Parameter {
104            name: name.to_string(),
105            type_info,
106        }
107    }
108
109    pub fn property(name: impl ToString) -> Self {
110        Self::Property {
111            name: name.to_string(),
112        }
113    }
114
115    pub fn is_execute(&self) -> bool {
116        matches!(self, Self::Execute { .. })
117    }
118
119    pub fn is_parameter(&self) -> bool {
120        matches!(self, Self::Parameter { .. })
121    }
122
123    pub fn is_property(&self) -> bool {
124        matches!(self, Self::Property { .. })
125    }
126
127    pub fn name(&self) -> &str {
128        match self {
129            Self::Execute { name, .. }
130            | Self::Parameter { name, .. }
131            | Self::Property { name, .. } => name,
132        }
133    }
134
135    pub fn has_subscope(&self) -> bool {
136        matches!(self, Self::Execute { subscope: true, .. })
137    }
138
139    pub fn type_info(&self) -> Option<&TI> {
140        match self {
141            Self::Parameter { type_info, .. } => Some(type_info),
142            _ => None,
143        }
144    }
145}
146
147pub enum NodeSuggestion<'a, T: NodeDefinition> {
148    All,
149    NodeInputPin(&'a T, &'a NodePin<T::TypeInfo>),
150    NodeOutputPin(&'a T, &'a NodePin<T::TypeInfo>),
151}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ResponseSuggestionNode<T: NodeDefinition> {
155    pub category: String,
156    pub label: String,
157    pub node: Node<T>,
158}
159
160impl<T: NodeDefinition> ResponseSuggestionNode<T> {
161    pub fn new(category: impl ToString, node: Node<T>, registry: &Registry) -> Self {
162        Self {
163            category: category.to_string(),
164            label: node.data.node_label(registry),
165            node,
166        }
167    }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct Node<T: NodeDefinition> {
172    id: NodeId<T>,
173    pub x: i64,
174    pub y: i64,
175    pub data: T,
176}
177
178impl<T: NodeDefinition> Node<T> {
179    pub fn new(x: i64, y: i64, data: T) -> Self {
180        Self {
181            id: Default::default(),
182            x,
183            y,
184            data,
185        }
186    }
187
188    pub fn id(&self) -> NodeId<T> {
189        self.id
190    }
191}
192
193#[derive(Clone, Serialize, Deserialize)]
194pub struct NodeConnection<T: NodeDefinition> {
195    pub from_node: NodeId<T>,
196    pub to_node: NodeId<T>,
197    pub from_pin: String,
198    pub to_pin: String,
199}
200
201impl<T: NodeDefinition> NodeConnection<T> {
202    pub fn new(from_node: NodeId<T>, to_node: NodeId<T>, from_pin: &str, to_pin: &str) -> Self {
203        Self {
204            from_node,
205            to_node,
206            from_pin: from_pin.to_owned(),
207            to_pin: to_pin.to_owned(),
208        }
209    }
210}
211
212impl<T: NodeDefinition> std::fmt::Debug for NodeConnection<T> {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        f.debug_struct("NodeConnection")
215            .field("from_node", &self.from_node)
216            .field("to_node", &self.to_node)
217            .field("from_pin", &self.from_pin)
218            .field("to_pin", &self.to_pin)
219            .finish()
220    }
221}
222
223impl<T: NodeDefinition> PartialEq for NodeConnection<T> {
224    fn eq(&self, other: &Self) -> bool {
225        self.from_node == other.from_node
226            && self.to_node == other.to_node
227            && self.from_pin == other.from_pin
228            && self.to_pin == other.to_pin
229    }
230}
231
232impl<T: NodeDefinition> Eq for NodeConnection<T> {}
233
234impl<T: NodeDefinition> Hash for NodeConnection<T> {
235    fn hash<H: Hasher>(&self, state: &mut H) {
236        self.from_node.hash(state);
237        self.to_node.hash(state);
238        self.from_pin.hash(state);
239        self.to_pin.hash(state);
240    }
241}
242
/// Reason a connection, or the graph's connections as a whole, failed validation.
///
/// Node ids and pin names are stored as already-formatted strings.
#[derive(Debug)]
pub enum ConnectionError {
    /// Source and target are the same node (payload: node id).
    InternalConnection(String),
    /// The source node does not exist (payload: node id).
    SourceNodeNotFound(String),
    /// The target node does not exist (payload: node id).
    TargetNodeNotFound(String),
    /// Neither endpoint node exists.
    NodesNotFound {
        from: String,
        to: String,
    },
    /// The source node has no output pin with this name.
    SourcePinNotFound {
        node: String,
        pin: String,
    },
    /// The target node has no input pin with this name.
    TargetPinNotFound {
        node: String,
        pin: String,
    },
    /// Both pins are parameters, but their type infos are not compatible.
    MismatchTypes {
        from_node: String,
        from_pin: String,
        from_type_info: String,
        to_node: String,
        to_pin: String,
        to_type_info: String,
    },
    /// The two pins are of different kinds (execute / parameter / property).
    MismatchPins {
        from_node: String,
        from_pin: String,
        to_node: String,
        to_pin: String,
    },
    /// A cycle was detected involving this node (payload: node id).
    CycleNodeFound(String),
    /// Any other error, e.g. one returned by `NodeDefinition::validate_connection`.
    Custom(Box<dyn Error>),
}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalConnection(id) => write!(f, "Trying to connect node: {id} to itself"),
            Self::SourceNodeNotFound(id) => write!(f, "Source node: {id} not found"),
            Self::TargetNodeNotFound(id) => write!(f, "Target node: {id} not found"),
            Self::NodesNotFound {
                from: source,
                to: target,
            } => write!(f, "Source: {source} and target: {target} nodes not found"),
            Self::SourcePinNotFound {
                node: owner,
                pin: name,
            } => write!(f, "Source pin: {name} for node: {owner} not found"),
            Self::TargetPinNotFound {
                node: owner,
                pin: name,
            } => write!(f, "Target pin: {name} for node: {owner} not found"),
            Self::MismatchTypes {
                from_node: source_node,
                from_pin: source_pin,
                from_type_info: source_type,
                to_node: target_node,
                to_pin: target_pin,
                to_type_info: target_type,
            } => write!(
                f,
                "Source type: {source_type} of pin: {source_pin} for node: {source_node} does not match target type: {target_type} of pin: {target_pin} for node: {target_node}"
            ),
            Self::MismatchPins {
                from_node: source_node,
                from_pin: source_pin,
                to_node: target_node,
                to_pin: target_pin,
            } => write!(
                f,
                "Source pin: {source_pin} kind for node: {source_node} does not match target pin: {target_pin} kind for node: {target_node}"
            ),
            Self::CycleNodeFound(id) => write!(f, "Found cycle node: {id}"),
            // Forward the wrapped error's own message unchanged.
            Self::Custom(inner) => Display::fmt(inner, f),
        }
    }
}

impl Error for ConnectionError {}

/// Error reported when validating a whole node graph.
#[derive(Debug)]
pub enum NodeGraphError {
    /// A connection-level problem.
    Connection(ConnectionError),
    /// A function input name occurs more than once with different types.
    DuplicateFunctionInputNames(String),
    /// A function output name occurs more than once with different types.
    DuplicateFunctionOutputNames(String),
}

impl std::fmt::Display for NodeGraphError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Connection(inner) => Display::fmt(inner, f),
            Self::DuplicateFunctionInputNames(duplicate) => write!(
                f,
                "Found duplicate `{duplicate}` function input with different types"
            ),
            Self::DuplicateFunctionOutputNames(duplicate) => write!(
                f,
                "Found duplicate `{duplicate}` function output with different types"
            ),
        }
    }
}

impl Error for NodeGraphError {}
355
356#[derive(Clone)]
357struct SpatialNode<T: NodeDefinition> {
358    id: NodeId<T>,
359    x: i64,
360    y: i64,
361}
362
363impl<T: NodeDefinition> RTreeObject for SpatialNode<T> {
364    type Envelope = AABB<[i64; 2]>;
365
366    fn envelope(&self) -> Self::Envelope {
367        AABB::from_point([self.x, self.y])
368    }
369}
370
371impl<T: NodeDefinition> PointDistance for SpatialNode<T> {
372    fn distance_2(
373        &self,
374        point: &<Self::Envelope as Envelope>::Point,
375    ) -> <<Self::Envelope as Envelope>::Point as Point>::Scalar {
376        let dx = self.x - point[0];
377        let dy = self.y - point[1];
378        dx * dx + dy * dy
379    }
380}
381
382#[derive(Clone, Serialize, Deserialize)]
383pub struct NodeGraph<T: NodeDefinition> {
384    nodes: Vec<Node<T>>,
385    connections: Vec<NodeConnection<T>>,
386    #[serde(skip, default)]
387    rtree: RTree<SpatialNode<T>>,
388}
389
390impl<T: NodeDefinition> Default for NodeGraph<T> {
391    fn default() -> Self {
392        Self {
393            nodes: vec![],
394            connections: vec![],
395            rtree: Default::default(),
396        }
397    }
398}
399
400impl<T: NodeDefinition> NodeGraph<T> {
401    pub fn clear(&mut self) {
402        self.nodes.clear();
403        self.connections.clear();
404    }
405
406    pub fn refresh_spatial_cache(&mut self) {
407        self.rtree = RTree::bulk_load(
408            self.nodes
409                .iter()
410                .map(|node| SpatialNode {
411                    id: node.id,
412                    x: node.x,
413                    y: node.y,
414                })
415                .collect(),
416        );
417    }
418
419    pub fn query_nearest_nodes(&self, x: i64, y: i64) -> impl Iterator<Item = NodeId<T>> + '_ {
420        self.rtree
421            .nearest_neighbor_iter(&[x, y])
422            .map(|node| node.id)
423    }
424
425    pub fn query_region_nodes(
426        &self,
427        fx: i64,
428        fy: i64,
429        tx: i64,
430        ty: i64,
431        extrude: i64,
432    ) -> impl Iterator<Item = NodeId<T>> + '_ {
433        self.rtree
434            .locate_in_envelope(&AABB::from_corners(
435                [fx - extrude, fy - extrude],
436                [tx - extrude, ty - extrude],
437            ))
438            .map(|node| node.id)
439    }
440
441    pub fn suggest_all_nodes(
442        x: i64,
443        y: i64,
444        registry: &Registry,
445    ) -> Vec<ResponseSuggestionNode<T>> {
446        T::node_suggestions(x, y, NodeSuggestion::All, registry)
447    }
448
449    pub fn suggest_node_input_pin(
450        &self,
451        x: i64,
452        y: i64,
453        id: NodeId<T>,
454        name: &str,
455        registry: &Registry,
456    ) -> Vec<ResponseSuggestionNode<T>> {
457        if let Some(node) = self.node(id) {
458            if let Some(pin) = node
459                .data
460                .node_pins_in(registry)
461                .into_iter()
462                .find(|pin| pin.name() == name)
463            {
464                return T::node_suggestions(
465                    x,
466                    y,
467                    NodeSuggestion::NodeInputPin(&node.data, &pin),
468                    registry,
469                );
470            }
471        }
472        vec![]
473    }
474
475    pub fn suggest_node_output_pin(
476        &self,
477        x: i64,
478        y: i64,
479        id: NodeId<T>,
480        name: &str,
481        registry: &Registry,
482    ) -> Vec<ResponseSuggestionNode<T>> {
483        if let Some(node) = self.node(id) {
484            if let Some(pin) = node
485                .data
486                .node_pins_out(registry)
487                .into_iter()
488                .find(|pin| pin.name() == name)
489            {
490                return T::node_suggestions(
491                    x,
492                    y,
493                    NodeSuggestion::NodeOutputPin(&node.data, &pin),
494                    registry,
495                );
496            }
497        }
498        vec![]
499    }
500
501    pub fn node(&self, id: NodeId<T>) -> Option<&Node<T>> {
502        self.nodes.iter().find(|node| node.id == id)
503    }
504
505    pub fn node_mut(&mut self, id: NodeId<T>) -> Option<&mut Node<T>> {
506        self.nodes.iter_mut().find(|node| node.id == id)
507    }
508
509    pub fn nodes(&self) -> impl Iterator<Item = &Node<T>> {
510        self.nodes.iter()
511    }
512
513    pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<T>> {
514        self.nodes.iter_mut()
515    }
516
517    pub fn add_node(&mut self, node: Node<T>, registry: &Registry) -> Option<NodeId<T>> {
518        if node.data.node_is_start(registry)
519            && self
520                .nodes
521                .iter()
522                .any(|node| node.data.node_is_start(registry))
523        {
524            return None;
525        }
526        let id = node.id;
527        if let Some(index) = self.nodes.iter().position(|node| node.id == id) {
528            self.nodes.swap_remove(index);
529        }
530        self.nodes.push(node);
531        Some(id)
532    }
533
534    pub fn remove_node(&mut self, id: NodeId<T>, registry: &Registry) -> Option<Node<T>> {
535        if let Some(index) = self
536            .nodes
537            .iter()
538            .position(|node| node.id == id && !node.data.node_is_start(registry))
539        {
540            self.disconnect_node(id, None);
541            Some(self.nodes.swap_remove(index))
542        } else {
543            None
544        }
545    }
546
547    pub fn connect_nodes(&mut self, connection: NodeConnection<T>) {
548        if !self.connections.iter().any(|other| &connection == other) {
549            self.disconnect_node(connection.from_node, Some(&connection.from_pin));
550            self.disconnect_node(connection.to_node, Some(&connection.to_pin));
551            self.connections.push(connection);
552        }
553    }
554
555    pub fn disconnect_nodes(
556        &mut self,
557        from_node: NodeId<T>,
558        to_node: NodeId<T>,
559        from_pin: &str,
560        to_pin: &str,
561    ) {
562        if let Some(index) = self.connections.iter().position(|connection| {
563            connection.from_node == from_node
564                && connection.to_node == to_node
565                && connection.from_pin == from_pin
566                && connection.to_pin == to_pin
567        }) {
568            self.connections.swap_remove(index);
569        }
570    }
571
572    pub fn disconnect_node(&mut self, node: NodeId<T>, pin: Option<&str>) {
573        let to_remove = self
574            .connections
575            .iter()
576            .enumerate()
577            .filter_map(|(index, connection)| {
578                if let Some(pin) = pin {
579                    if connection.from_node == node && connection.from_pin == pin {
580                        return Some(index);
581                    }
582                    if connection.to_node == node && connection.to_pin == pin {
583                        return Some(index);
584                    }
585                } else if connection.from_node == node || connection.to_node == node {
586                    return Some(index);
587                }
588                None
589            })
590            .collect::<Vec<_>>();
591        for index in to_remove.into_iter().rev() {
592            self.connections.swap_remove(index);
593        }
594    }
595
596    pub fn connections(&self) -> impl Iterator<Item = &NodeConnection<T>> {
597        self.connections.iter()
598    }
599
600    pub fn node_connections(&self, id: NodeId<T>) -> impl Iterator<Item = &NodeConnection<T>> {
601        self.connections
602            .iter()
603            .filter(move |connection| connection.from_node == id || connection.to_node == id)
604    }
605
606    pub fn node_connections_in<'a>(
607        &'a self,
608        id: NodeId<T>,
609        pin: Option<&'a str>,
610    ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
611        self.connections.iter().filter(move |connection| {
612            connection.to_node == id && pin.map(|pin| connection.to_pin == pin).unwrap_or(true)
613        })
614    }
615
616    pub fn node_connections_out<'a>(
617        &'a self,
618        id: NodeId<T>,
619        pin: Option<&'a str>,
620    ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
621        self.connections.iter().filter(move |connection| {
622            connection.from_node == id && pin.map(|pin| connection.from_pin == pin).unwrap_or(true)
623        })
624    }
625
626    pub fn node_neighbors_in<'a>(
627        &'a self,
628        id: NodeId<T>,
629        pin: Option<&'a str>,
630    ) -> impl Iterator<Item = NodeId<T>> + 'a {
631        self.node_connections_in(id, pin)
632            .map(move |connection| connection.from_node)
633    }
634
635    pub fn node_neighbors_out<'a>(
636        &'a self,
637        id: NodeId<T>,
638        pin: Option<&'a str>,
639    ) -> impl Iterator<Item = NodeId<T>> + 'a {
640        self.node_connections_out(id, pin)
641            .map(move |connection| connection.to_node)
642    }
643
644    pub fn validate(&self, registry: &Registry) -> Result<(), Vec<NodeGraphError>> {
645        let mut errors = self
646            .connections
647            .iter()
648            .filter_map(|connection| self.validate_connection(connection, registry))
649            .map(NodeGraphError::Connection)
650            .collect::<Vec<_>>();
651        if let Some(error) = self.detect_cycles() {
652            errors.push(NodeGraphError::Connection(error));
653        }
654        if errors.is_empty() {
655            Ok(())
656        } else {
657            Err(errors)
658        }
659    }
660
661    fn validate_connection(
662        &self,
663        connection: &NodeConnection<T>,
664        registry: &Registry,
665    ) -> Option<ConnectionError> {
666        if connection.from_node == connection.to_node {
667            return Some(ConnectionError::InternalConnection(
668                connection.from_node.to_string(),
669            ));
670        }
671        let from = self
672            .nodes
673            .iter()
674            .find(|node| node.id == connection.from_node);
675        let to = self.nodes.iter().find(|node| node.id == connection.to_node);
676        let (from_node, to_node) = match (from, to) {
677            (Some(from), Some(to)) => (from, to),
678            (Some(_), None) => {
679                return Some(ConnectionError::TargetNodeNotFound(
680                    connection.to_node.to_string(),
681                ));
682            }
683            (None, Some(_)) => {
684                return Some(ConnectionError::SourceNodeNotFound(
685                    connection.from_node.to_string(),
686                ));
687            }
688            (None, None) => {
689                return Some(ConnectionError::NodesNotFound {
690                    from: connection.from_node.to_string(),
691                    to: connection.to_node.to_string(),
692                });
693            }
694        };
695        let from_pins_out = from_node.data.node_pins_out(registry);
696        let from_pin = match from_pins_out
697            .iter()
698            .find(|pin| pin.name() == connection.from_pin)
699        {
700            Some(pin) => pin,
701            None => {
702                return Some(ConnectionError::SourcePinNotFound {
703                    node: connection.from_node.to_string(),
704                    pin: connection.from_pin.to_owned(),
705                });
706            }
707        };
708        let to_pins_in = to_node.data.node_pins_in(registry);
709        let to_pin = match to_pins_in
710            .iter()
711            .find(|pin| pin.name() == connection.to_pin)
712        {
713            Some(pin) => pin,
714            None => {
715                return Some(ConnectionError::TargetPinNotFound {
716                    node: connection.to_node.to_string(),
717                    pin: connection.to_pin.to_owned(),
718                });
719            }
720        };
721        match (from_pin, to_pin) {
722            (NodePin::Execute { .. }, NodePin::Execute { .. }) => {}
723            (NodePin::Parameter { type_info: a, .. }, NodePin::Parameter { type_info: b, .. }) => {
724                if !a.are_compatible(b) {
725                    return Some(ConnectionError::MismatchTypes {
726                        from_node: connection.from_node.to_string(),
727                        from_pin: connection.from_pin.to_owned(),
728                        to_node: connection.to_node.to_string(),
729                        to_pin: connection.to_pin.to_owned(),
730                        from_type_info: a.to_string(),
731                        to_type_info: b.to_string(),
732                    });
733                }
734            }
735            (NodePin::Property { .. }, NodePin::Property { .. }) => {}
736            _ => {
737                return Some(ConnectionError::MismatchPins {
738                    from_node: connection.from_node.to_string(),
739                    from_pin: connection.from_pin.to_owned(),
740                    to_node: connection.to_node.to_string(),
741                    to_pin: connection.to_pin.to_owned(),
742                });
743            }
744        }
745        if let Err(error) = to_node.data.validate_connection(&from_node.data, registry) {
746            return Some(ConnectionError::Custom(error));
747        }
748        None
749    }
750
751    fn detect_cycles(&self) -> Option<ConnectionError> {
752        let mut visited = HashSet::with_capacity(self.nodes.len());
753        let mut available = self.nodes.iter().map(|node| node.id).collect::<Vec<_>>();
754        while let Some(id) = available.first() {
755            if let Some(error) = self.detect_cycle(*id, &mut available, &mut visited) {
756                return Some(error);
757            }
758            available.swap_remove(0);
759        }
760        None
761    }
762
763    fn detect_cycle(
764        &self,
765        id: NodeId<T>,
766        available: &mut Vec<NodeId<T>>,
767        visited: &mut HashSet<NodeId<T>>,
768    ) -> Option<ConnectionError> {
769        if visited.contains(&id) {
770            return Some(ConnectionError::CycleNodeFound(id.to_string()));
771        }
772        visited.insert(id);
773        for id in self.node_neighbors_out(id, None) {
774            if let Some(index) = available.iter().position(|item| item == &id) {
775                available.swap_remove(index);
776                if let Some(error) = self.detect_cycle(id, available, visited) {
777                    return Some(error);
778                }
779            }
780        }
781        None
782    }
783
784    pub fn visit<V: NodeGraphVisitor<T>>(
785        &self,
786        visitor: &mut V,
787        registry: &Registry,
788    ) -> Vec<V::Output> {
789        let starts = self
790            .nodes
791            .iter()
792            .filter(|node| node.data.node_is_start(registry))
793            .map(|node| node.id)
794            .collect::<HashSet<_>>();
795        let mut result = Vec::with_capacity(self.nodes.len());
796        for id in starts {
797            self.visit_statement(id, &mut result, visitor, registry);
798        }
799        result
800    }
801
802    fn visit_statement<V: NodeGraphVisitor<T>>(
803        &self,
804        id: NodeId<T>,
805        result: &mut Vec<V::Output>,
806        visitor: &mut V,
807        registry: &Registry,
808    ) {
809        if let Some(node) = self.node(id) {
810            let inputs = node
811                .data
812                .node_pins_in(registry)
813                .into_iter()
814                .filter(|pin| pin.is_parameter())
815                .filter_map(|pin| {
816                    self.node_neighbors_in(id, Some(pin.name()))
817                        .next()
818                        .map(|id| (pin.name().to_owned(), id))
819                })
820                .filter_map(|(name, id)| {
821                    self.visit_expression(id, visitor, registry)
822                        .map(|input| (name, input))
823                })
824                .collect();
825            let pins_out = node.data.node_pins_out(registry);
826            let scopes = pins_out
827                .iter()
828                .filter(|pin| pin.has_subscope())
829                .filter_map(|pin| {
830                    let id = self.node_neighbors_out(id, Some(pin.name())).next()?;
831                    Some((id, pin.name().to_owned()))
832                })
833                .map(|(id, name)| {
834                    let mut result = Vec::with_capacity(self.nodes.len());
835                    self.visit_statement(id, &mut result, visitor, registry);
836                    (name, result)
837                })
838                .collect();
839            if visitor.visit_statement(node, inputs, scopes, result) {
840                for pin in pins_out {
841                    if pin.is_execute() && !pin.has_subscope() {
842                        for id in self.node_neighbors_out(id, Some(pin.name())) {
843                            self.visit_statement(id, result, visitor, registry);
844                        }
845                    }
846                }
847            }
848        }
849    }
850
851    fn visit_expression<V: NodeGraphVisitor<T>>(
852        &self,
853        id: NodeId<T>,
854        visitor: &mut V,
855        registry: &Registry,
856    ) -> Option<V::Input> {
857        if let Some(node) = self.node(id) {
858            let inputs = node
859                .data
860                .node_pins_in(registry)
861                .into_iter()
862                .filter(|pin| pin.is_parameter())
863                .filter_map(|pin| {
864                    self.node_neighbors_in(id, Some(pin.name()))
865                        .next()
866                        .map(|id| (pin.name().to_owned(), id))
867                })
868                .filter_map(|(name, id)| {
869                    self.visit_expression(id, visitor, registry)
870                        .map(|input| (name, input))
871                })
872                .collect();
873            return visitor.visit_expression(node, inputs);
874        }
875        None
876    }
877}
878
879impl<T: NodeDefinition + std::fmt::Debug> std::fmt::Debug for NodeGraph<T> {
880    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881        f.debug_struct("NodeGraph")
882            .field("nodes", &self.nodes)
883            .field("connections", &self.connections)
884            .finish()
885    }
886}
887
888pub trait NodeGraphVisitor<T: NodeDefinition> {
889    type Input;
890    type Output;
891
892    fn visit_statement(
893        &mut self,
894        node: &Node<T>,
895        inputs: HashMap<String, Self::Input>,
896        scopes: HashMap<String, Vec<Self::Output>>,
897        result: &mut Vec<Self::Output>,
898    ) -> bool;
899
900    fn visit_expression(
901        &mut self,
902        node: &Node<T>,
903        inputs: HashMap<String, Self::Input>,
904    ) -> Option<Self::Input>;
905}
906
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use intuicio_core::prelude::*;
    use std::collections::HashMap;

    /// Flat script representation that `CompileNodesToScript` lowers the test graph into.
    #[derive(Debug, Clone, PartialEq)]
    enum Script {
        /// Integer literal, emitted for a `Nodes::Expression` node.
        Literal(i32),
        /// Emitted for a `Nodes::Result` node.
        Return,
        /// Named call, emitted for a `Nodes::Convert` node.
        Call(String),
        /// Statements collected from the `Body` scope of a `Nodes::Child` node.
        Scope(Vec<Script>),
    }

    // Test-only pin type info: a pin type is just its type name, and two pin
    // types are compatible only when the names are equal.
    impl NodeTypeInfo for String {
        fn type_query(&self) -> TypeQuery {
            TypeQuery {
                name: Some(self.into()),
                ..Default::default()
            }
        }

        fn are_compatible(&self, other: &Self) -> bool {
            self == other
        }
    }

    /// Node kinds used by the test graph.
    #[derive(Debug, Clone)]
    enum Nodes {
        /// Graph entry point; only exposes an `Out` execute pin.
        Start,
        /// Integer literal exposed through the `Value` property and the `Data` output.
        Expression(i32),
        /// Sink that consumes an `i32` through its `Data` input.
        Result,
        /// Named conversion (`Name` property) from `Data in` to `Data out`.
        Convert(String),
        /// Node with a nested `Body` execute scope in addition to `Out`.
        Child,
    }

    impl NodeDefinition for Nodes {
        type TypeInfo = String;

        fn node_label(&self, _: &Registry) -> String {
            format!("{self:?}")
        }

        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![],
                Nodes::Expression(_) => {
                    vec![NodePin::execute("In", false), NodePin::property("Value")]
                }
                Nodes::Result => vec![
                    NodePin::execute("In", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Convert(_) => vec![
                    NodePin::execute("In", false),
                    NodePin::property("Name"),
                    NodePin::parameter("Data in", "i32".to_owned()),
                ],
                Nodes::Child => vec![NodePin::execute("In", false)],
            }
        }

        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![NodePin::execute("Out", false)],
                Nodes::Expression(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Result => vec![],
                Nodes::Convert(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data out", "i32".to_owned()),
                ],
                Nodes::Child => vec![
                    NodePin::execute("Out", false),
                    NodePin::execute("Body", true),
                ],
            }
        }

        fn node_is_start(&self, _: &Registry) -> bool {
            matches!(self, Self::Start)
        }

        fn node_suggestions(
            _: i64,
            _: i64,
            _: NodeSuggestion<Self>,
            _: &Registry,
        ) -> Vec<ResponseSuggestionNode<Self>> {
            vec![]
        }

        /// Reads a property value: `Expression` exposes `Value`, `Convert` exposes `Name`.
        /// Any other variant/name combination yields `None`.
        fn get_property(&self, property_name: &str) -> Option<PropertyValue> {
            match (self, property_name) {
                (Nodes::Expression(value), "Value") => PropertyValue::new(value).ok(),
                (Nodes::Convert(name), "Name") => PropertyValue::new(name).ok(),
                _ => None,
            }
        }

        /// Writes a property value, mirroring `get_property`.
        ///
        /// Unknown property names, and values that do not deserialize exactly to
        /// the field's type, are silently ignored.
        fn set_property(&mut self, property_name: &str, property_value: PropertyValue) {
            match (self, property_name) {
                (Nodes::Expression(value), "Value") => {
                    if let Ok(v) = property_value.get_exact::<i32>() {
                        *value = v;
                    }
                }
                // Match on the property name too, so an unrelated property that
                // happens to hold a string never overwrites `Name`.
                (Nodes::Convert(name), "Name") => {
                    if let Ok(v) = property_value.get_exact::<String>() {
                        *name = v;
                    }
                }
                _ => {}
            }
        }
    }

    /// Visitor that lowers the test graph into a list of `Script` statements.
    struct CompileNodesToScript;

    impl NodeGraphVisitor<Nodes> for CompileNodesToScript {
        type Input = ();
        type Output = Script;

        fn visit_statement(
            &mut self,
            node: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
            mut scopes: HashMap<String, Vec<Self::Output>>,
            result: &mut Vec<Self::Output>,
        ) -> bool {
            // Exhaustive on purpose: adding a `Nodes` variant must be handled here.
            match &node.data {
                // The start node emits no statement of its own.
                Nodes::Start => {}
                Nodes::Expression(value) => result.push(Script::Literal(*value)),
                Nodes::Result => result.push(Script::Return),
                Nodes::Convert(name) => result.push(Script::Call(name.to_owned())),
                Nodes::Child => {
                    // Only the `Body` scope is lowered; if it is absent nothing is emitted.
                    if let Some(body) = scopes.remove("Body") {
                        result.push(Script::Scope(body));
                    }
                }
            }
            // NOTE(review): the meaning of this flag is defined by
            // `NodeGraphVisitor::visit_statement`; this visitor always returns `true`.
            true
        }

        fn visit_expression(
            &mut self,
            _: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
        ) -> Option<Self::Input> {
            None
        }
    }

    #[test]
    fn test_nodes() {
        let registry = Registry::default().with_basic_types();
        let mut graph = NodeGraph::default();
        let start = graph
            .add_node(Node::new(0, 0, Nodes::Start), &registry)
            .unwrap();
        // Chain placed inside the `Body` scope of `child`.
        let expression_child = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert_child = graph
            .add_node(Node::new(0, 0, Nodes::Convert("foo".to_owned())), &registry)
            .unwrap();
        let result_child = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        let child = graph
            .add_node(Node::new(0, 0, Nodes::Child), &registry)
            .unwrap();
        // Chain that follows `child` through its `Out` pin.
        let expression = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert = graph
            .add_node(Node::new(0, 0, Nodes::Convert("bar".to_owned())), &registry)
            .unwrap();
        let result = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        graph.connect_nodes(NodeConnection::new(start, child, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(child, expression_child, "Body", "In"));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Data",
            "Data in",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Data out",
            "Data",
        ));
        graph.connect_nodes(NodeConnection::new(child, expression, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Data", "Data in"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Data out", "Data"));
        graph.validate(&registry).unwrap();
        assert_eq!(
            graph.visit(&mut CompileNodesToScript, &registry),
            vec![
                Script::Scope(vec![
                    Script::Literal(42),
                    Script::Call("foo".to_owned()),
                    Script::Return
                ]),
                Script::Literal(42),
                Script::Call("bar".to_owned()),
                Script::Return
            ]
        );

        // `Expression` property round-trip.
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&42i32).unwrap(),
        );
        graph
            .node_mut(expression)
            .unwrap()
            .data
            .set_property("Value", PropertyValue::new(&10i32).unwrap());
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&10i32).unwrap(),
        );

        // `Convert` only responds to its own `Name` property.
        assert!(
            graph
                .node(convert)
                .unwrap()
                .data
                .get_property("Value")
                .is_none()
        );
        graph
            .node_mut(convert)
            .unwrap()
            .data
            .set_property("Value", PropertyValue::new(&"baz".to_owned()).unwrap());
        assert_eq!(
            graph
                .node(convert)
                .unwrap()
                .data
                .get_property("Name")
                .unwrap(),
            PropertyValue::new(&"bar".to_owned()).unwrap(),
        );
        graph
            .node_mut(convert)
            .unwrap()
            .data
            .set_property("Name", PropertyValue::new(&"baz".to_owned()).unwrap());
        assert_eq!(
            graph
                .node(convert)
                .unwrap()
                .data
                .get_property("Name")
                .unwrap(),
            PropertyValue::new(&"baz".to_owned()).unwrap(),
        );
    }
}