1use intuicio_core::{registry::Registry, types::TypeQuery};
2use rstar::{AABB, Envelope, Point, PointDistance, RTree, RTreeObject};
3use serde::{Deserialize, Serialize};
4use serde_intermediate::{
5 Intermediate, de::intermediate::DeserializeMode, error::Result as IntermediateResult,
6};
7use std::{
8 collections::{HashMap, HashSet},
9 error::Error,
10 fmt::Display,
11 hash::{Hash, Hasher},
12};
13use typid::ID;
14
/// Identifier of a [`Node`] whose data is a `T`.
pub type NodeId<T> = ID<Node<T>>;
/// Mode used when reading a [`PropertyValue`] back as a concrete type
/// (re-export of `serde_intermediate`'s `DeserializeMode`).
pub type PropertyCastMode = DeserializeMode;
17
18#[derive(Debug, Default, Clone, PartialEq)]
19pub struct PropertyValue {
20 value: Intermediate,
21}
22
23impl PropertyValue {
24 pub fn new<T: Serialize>(value: &T) -> IntermediateResult<Self> {
25 Ok(Self {
26 value: serde_intermediate::to_intermediate(value)?,
27 })
28 }
29
30 pub fn get<'a, T: Deserialize<'a>>(&'a self, mode: PropertyCastMode) -> IntermediateResult<T> {
31 serde_intermediate::from_intermediate_as(&self.value, mode)
32 }
33
34 pub fn get_exact<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
35 self.get(PropertyCastMode::Exact)
36 }
37
38 pub fn get_interpret<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
39 self.get(PropertyCastMode::Interpret)
40 }
41
42 pub fn into_inner(self) -> Intermediate {
43 self.value
44 }
45}
46
/// Type descriptor attached to [`NodePin::Parameter`] pins.
///
/// The `Display` form is used in `ConnectionError::MismatchTypes` messages.
pub trait NodeTypeInfo:
    Clone + std::fmt::Debug + Display + PartialEq + Serialize + for<'de> Deserialize<'de>
{
    /// Returns an intuicio [`TypeQuery`] describing this type
    /// (not called in this module; presumably used for registry lookups).
    fn type_query(&self) -> TypeQuery;
    /// Returns whether a parameter *output* of type `self` may feed a parameter
    /// *input* of type `other`; `NodeGraph::validate` reports a type mismatch
    /// when this returns `false`.
    fn are_compatible(&self, other: &Self) -> bool;
}
53
/// Describes one kind of node that can live in a [`NodeGraph`]: its label, its
/// input and output pins, whether it is the start node, how suggestions are
/// produced, plus optional connection checks and named properties.
pub trait NodeDefinition: Sized {
    /// Type descriptor used by this definition's parameter pins.
    type TypeInfo: NodeTypeInfo;

    /// Display label; [`ResponseSuggestionNode::new`] uses it as the suggestion label.
    fn node_label(&self, registry: &Registry) -> String;
    /// Input pins; connections refer to them by their `to_pin` name.
    fn node_pins_in(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
    /// Output pins; connections refer to them by their `from_pin` name.
    fn node_pins_out(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
    /// Whether this is a start node. `NodeGraph::add_node` rejects a second start
    /// node, `NodeGraph::remove_node` never removes one, and `NodeGraph::visit`
    /// begins walking from start nodes.
    fn node_is_start(&self, registry: &Registry) -> bool;
    /// Produces nodes to suggest at position `(x, y)`.
    ///
    /// `suggestion` says whether the request has no pin context or was made from
    /// a specific input/output pin of an existing node.
    fn node_suggestions(
        x: i64,
        y: i64,
        suggestion: NodeSuggestion<Self>,
        registry: &Registry,
    ) -> Vec<ResponseSuggestionNode<Self>>;

    /// Extra check run by `NodeGraph::validate` on the *target* node of a
    /// connection (`self`) with the *source* node passed as `source`. It runs only
    /// after pin existence and pin kind/type checks passed. An `Err` is reported
    /// as `ConnectionError::Custom`.
    ///
    /// The default implementation accepts every connection.
    // The default bodies below ignore their parameters.
    #[allow(unused_variables)]
    fn validate_connection(
        &self,
        source: &Self,
        registry: &Registry,
    ) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Reads the property called `name`; the default has no properties and
    /// always returns `None`.
    #[allow(unused_variables)]
    fn get_property(&self, name: &str) -> Option<PropertyValue> {
        None
    }

    /// Writes the property called `name`; the default ignores the call.
    #[allow(unused_variables)]
    fn set_property(&mut self, name: &str, value: PropertyValue) {}
}
85
/// A named connection point on a node.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
// Explicit bound replaces the bounds serde would otherwise infer for `TI`.
#[serde(bound = "TI: NodeTypeInfo")]
pub enum NodePin<TI: NodeTypeInfo> {
    /// Control-flow pin. On an output, `subscope: true` makes `NodeGraph::visit`
    /// collect the attached chain as a nested scope instead of continuing into it.
    Execute { name: String, subscope: bool },
    /// Data pin of type `type_info`; validation requires the connected pins'
    /// types to satisfy [`NodeTypeInfo::are_compatible`].
    Parameter { name: String, type_info: TI },
    /// Pin tied to a named node property; validation only accepts connections
    /// between two property pins.
    Property { name: String },
}
93
94impl<TI: NodeTypeInfo> NodePin<TI> {
95 pub fn execute(name: impl ToString, subscope: bool) -> Self {
96 Self::Execute {
97 name: name.to_string(),
98 subscope,
99 }
100 }
101
102 pub fn parameter(name: impl ToString, type_info: TI) -> Self {
103 Self::Parameter {
104 name: name.to_string(),
105 type_info,
106 }
107 }
108
109 pub fn property(name: impl ToString) -> Self {
110 Self::Property {
111 name: name.to_string(),
112 }
113 }
114
115 pub fn is_execute(&self) -> bool {
116 matches!(self, Self::Execute { .. })
117 }
118
119 pub fn is_parameter(&self) -> bool {
120 matches!(self, Self::Parameter { .. })
121 }
122
123 pub fn is_property(&self) -> bool {
124 matches!(self, Self::Property { .. })
125 }
126
127 pub fn name(&self) -> &str {
128 match self {
129 Self::Execute { name, .. }
130 | Self::Parameter { name, .. }
131 | Self::Property { name, .. } => name,
132 }
133 }
134
135 pub fn has_subscope(&self) -> bool {
136 matches!(self, Self::Execute { subscope: true, .. })
137 }
138
139 pub fn type_info(&self) -> Option<&TI> {
140 match self {
141 Self::Parameter { type_info, .. } => Some(type_info),
142 _ => None,
143 }
144 }
145}
146
/// Context handed to [`NodeDefinition::node_suggestions`].
pub enum NodeSuggestion<'a, T: NodeDefinition> {
    /// No pin context (used by `NodeGraph::suggest_all_nodes`).
    All,
    /// Request made from an input pin of an existing node: the node's data and that pin.
    NodeInputPin(&'a T, &'a NodePin<T::TypeInfo>),
    /// Request made from an output pin of an existing node: the node's data and that pin.
    NodeOutputPin(&'a T, &'a NodePin<T::TypeInfo>),
}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ResponseSuggestionNode<T: NodeDefinition> {
155 pub category: String,
156 pub label: String,
157 pub node: Node<T>,
158}
159
160impl<T: NodeDefinition> ResponseSuggestionNode<T> {
161 pub fn new(category: impl ToString, node: Node<T>, registry: &Registry) -> Self {
162 Self {
163 category: category.to_string(),
164 label: node.data.node_label(registry),
165 node,
166 }
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct Node<T: NodeDefinition> {
172 id: NodeId<T>,
173 pub x: i64,
174 pub y: i64,
175 pub data: T,
176}
177
178impl<T: NodeDefinition> Node<T> {
179 pub fn new(x: i64, y: i64, data: T) -> Self {
180 Self {
181 id: Default::default(),
182 x,
183 y,
184 data,
185 }
186 }
187
188 pub fn id(&self) -> NodeId<T> {
189 self.id
190 }
191}
192
193#[derive(Clone, Serialize, Deserialize)]
194pub struct NodeConnection<T: NodeDefinition> {
195 pub from_node: NodeId<T>,
196 pub to_node: NodeId<T>,
197 pub from_pin: String,
198 pub to_pin: String,
199}
200
201impl<T: NodeDefinition> NodeConnection<T> {
202 pub fn new(from_node: NodeId<T>, to_node: NodeId<T>, from_pin: &str, to_pin: &str) -> Self {
203 Self {
204 from_node,
205 to_node,
206 from_pin: from_pin.to_owned(),
207 to_pin: to_pin.to_owned(),
208 }
209 }
210}
211
212impl<T: NodeDefinition> std::fmt::Debug for NodeConnection<T> {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("NodeConnection")
215 .field("from_node", &self.from_node)
216 .field("to_node", &self.to_node)
217 .field("from_pin", &self.from_pin)
218 .field("to_pin", &self.to_pin)
219 .finish()
220 }
221}
222
223impl<T: NodeDefinition> PartialEq for NodeConnection<T> {
224 fn eq(&self, other: &Self) -> bool {
225 self.from_node == other.from_node
226 && self.to_node == other.to_node
227 && self.from_pin == other.from_pin
228 && self.to_pin == other.to_pin
229 }
230}
231
232impl<T: NodeDefinition> Eq for NodeConnection<T> {}
233
234impl<T: NodeDefinition> Hash for NodeConnection<T> {
235 fn hash<H: Hasher>(&self, state: &mut H) {
236 self.from_node.hash(state);
237 self.to_node.hash(state);
238 self.from_pin.hash(state);
239 self.to_pin.hash(state);
240 }
241}
242
/// Problem found while validating a single connection (or the graph's cycles).
///
/// Node ids are stored in their `to_string()` form.
#[derive(Debug)]
pub enum ConnectionError {
    /// A connection whose source and target are the same node.
    InternalConnection(String),
    /// The source node does not exist (the target does).
    SourceNodeNotFound(String),
    /// The target node does not exist (the source does).
    TargetNodeNotFound(String),
    /// Neither endpoint node exists.
    NodesNotFound {
        from: String,
        to: String,
    },
    /// The source node has no output pin with that name.
    SourcePinNotFound {
        node: String,
        pin: String,
    },
    /// The target node has no input pin with that name.
    TargetPinNotFound {
        node: String,
        pin: String,
    },
    /// Both pins are parameters but their types are not compatible.
    MismatchTypes {
        from_node: String,
        from_pin: String,
        from_type_info: String,
        to_node: String,
        to_pin: String,
        to_type_info: String,
    },
    /// The pins are of different kinds (execute / parameter / property).
    MismatchPins {
        from_node: String,
        from_pin: String,
        to_node: String,
        to_pin: String,
    },
    /// A cycle was detected; holds one node that lies on it.
    CycleNodeFound(String),
    /// Error returned by a custom `NodeDefinition::validate_connection`.
    Custom(Box<dyn Error>),
}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConnectionError::InternalConnection(node) => {
                write!(f, "Trying to connect node: {node} to itself")
            }
            ConnectionError::SourceNodeNotFound(node) => {
                write!(f, "Source node: {node} not found")
            }
            ConnectionError::TargetNodeNotFound(node) => {
                write!(f, "Target node: {node} not found")
            }
            ConnectionError::NodesNotFound { to, from } => {
                write!(f, "Source: {from} and target: {to} nodes not found")
            }
            ConnectionError::SourcePinNotFound { pin, node } => {
                write!(f, "Source pin: {pin} for node: {node} not found")
            }
            ConnectionError::TargetPinNotFound { pin, node } => {
                write!(f, "Target pin: {pin} for node: {node} not found")
            }
            ConnectionError::MismatchTypes {
                from_type_info,
                from_pin,
                from_node,
                to_type_info,
                to_pin,
                to_node,
            } => write!(
                f,
                "Source type: {from_type_info} of pin: {from_pin} for node: {from_node} does not match target type: {to_type_info} of pin: {to_pin} for node: {to_node}"
            ),
            ConnectionError::MismatchPins {
                from_pin,
                from_node,
                to_pin,
                to_node,
            } => write!(
                f,
                "Source pin: {from_pin} kind for node: {from_node} does not match target pin: {to_pin} kind for node: {to_node}"
            ),
            ConnectionError::CycleNodeFound(node) => write!(f, "Found cycle node: {node}"),
            // Forward to the wrapped error's own Display.
            ConnectionError::Custom(error) => Display::fmt(error, f),
        }
    }
}

impl Error for ConnectionError {}

/// Graph-level error reported by `NodeGraph::validate`.
#[derive(Debug)]
pub enum NodeGraphError {
    /// A connection-level problem.
    Connection(ConnectionError),
    /// A function input name used more than once with different types.
    DuplicateFunctionInputNames(String),
    /// A function output name used more than once with different types.
    DuplicateFunctionOutputNames(String),
}

impl std::fmt::Display for NodeGraphError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NodeGraphError::Connection(connection) => Display::fmt(connection, f),
            NodeGraphError::DuplicateFunctionInputNames(name) => write!(
                f,
                "Found duplicate `{name}` function input with different types"
            ),
            NodeGraphError::DuplicateFunctionOutputNames(name) => write!(
                f,
                "Found duplicate `{name}` function output with different types"
            ),
        }
    }
}

impl Error for NodeGraphError {}
355
356#[derive(Clone)]
357struct SpatialNode<T: NodeDefinition> {
358 id: NodeId<T>,
359 x: i64,
360 y: i64,
361}
362
363impl<T: NodeDefinition> RTreeObject for SpatialNode<T> {
364 type Envelope = AABB<[i64; 2]>;
365
366 fn envelope(&self) -> Self::Envelope {
367 AABB::from_point([self.x, self.y])
368 }
369}
370
371impl<T: NodeDefinition> PointDistance for SpatialNode<T> {
372 fn distance_2(
373 &self,
374 point: &<Self::Envelope as Envelope>::Point,
375 ) -> <<Self::Envelope as Envelope>::Point as Point>::Scalar {
376 let dx = self.x - point[0];
377 let dy = self.y - point[1];
378 dx * dx + dy * dy
379 }
380}
381
382#[derive(Clone, Serialize, Deserialize)]
383pub struct NodeGraph<T: NodeDefinition> {
384 nodes: Vec<Node<T>>,
385 connections: Vec<NodeConnection<T>>,
386 #[serde(skip, default)]
387 rtree: RTree<SpatialNode<T>>,
388}
389
390impl<T: NodeDefinition> Default for NodeGraph<T> {
391 fn default() -> Self {
392 Self {
393 nodes: vec![],
394 connections: vec![],
395 rtree: Default::default(),
396 }
397 }
398}
399
400impl<T: NodeDefinition> NodeGraph<T> {
401 pub fn clear(&mut self) {
402 self.nodes.clear();
403 self.connections.clear();
404 }
405
406 pub fn refresh_spatial_cache(&mut self) {
407 self.rtree = RTree::bulk_load(
408 self.nodes
409 .iter()
410 .map(|node| SpatialNode {
411 id: node.id,
412 x: node.x,
413 y: node.y,
414 })
415 .collect(),
416 );
417 }
418
419 pub fn query_nearest_nodes(&self, x: i64, y: i64) -> impl Iterator<Item = NodeId<T>> + '_ {
420 self.rtree
421 .nearest_neighbor_iter(&[x, y])
422 .map(|node| node.id)
423 }
424
425 pub fn query_region_nodes(
426 &self,
427 fx: i64,
428 fy: i64,
429 tx: i64,
430 ty: i64,
431 extrude: i64,
432 ) -> impl Iterator<Item = NodeId<T>> + '_ {
433 self.rtree
434 .locate_in_envelope(&AABB::from_corners(
435 [fx - extrude, fy - extrude],
436 [tx - extrude, ty - extrude],
437 ))
438 .map(|node| node.id)
439 }
440
441 pub fn suggest_all_nodes(
442 x: i64,
443 y: i64,
444 registry: &Registry,
445 ) -> Vec<ResponseSuggestionNode<T>> {
446 T::node_suggestions(x, y, NodeSuggestion::All, registry)
447 }
448
449 pub fn suggest_node_input_pin(
450 &self,
451 x: i64,
452 y: i64,
453 id: NodeId<T>,
454 name: &str,
455 registry: &Registry,
456 ) -> Vec<ResponseSuggestionNode<T>> {
457 if let Some(node) = self.node(id) {
458 if let Some(pin) = node
459 .data
460 .node_pins_in(registry)
461 .into_iter()
462 .find(|pin| pin.name() == name)
463 {
464 return T::node_suggestions(
465 x,
466 y,
467 NodeSuggestion::NodeInputPin(&node.data, &pin),
468 registry,
469 );
470 }
471 }
472 vec![]
473 }
474
475 pub fn suggest_node_output_pin(
476 &self,
477 x: i64,
478 y: i64,
479 id: NodeId<T>,
480 name: &str,
481 registry: &Registry,
482 ) -> Vec<ResponseSuggestionNode<T>> {
483 if let Some(node) = self.node(id) {
484 if let Some(pin) = node
485 .data
486 .node_pins_out(registry)
487 .into_iter()
488 .find(|pin| pin.name() == name)
489 {
490 return T::node_suggestions(
491 x,
492 y,
493 NodeSuggestion::NodeOutputPin(&node.data, &pin),
494 registry,
495 );
496 }
497 }
498 vec![]
499 }
500
501 pub fn node(&self, id: NodeId<T>) -> Option<&Node<T>> {
502 self.nodes.iter().find(|node| node.id == id)
503 }
504
505 pub fn node_mut(&mut self, id: NodeId<T>) -> Option<&mut Node<T>> {
506 self.nodes.iter_mut().find(|node| node.id == id)
507 }
508
509 pub fn nodes(&self) -> impl Iterator<Item = &Node<T>> {
510 self.nodes.iter()
511 }
512
513 pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<T>> {
514 self.nodes.iter_mut()
515 }
516
517 pub fn add_node(&mut self, node: Node<T>, registry: &Registry) -> Option<NodeId<T>> {
518 if node.data.node_is_start(registry)
519 && self
520 .nodes
521 .iter()
522 .any(|node| node.data.node_is_start(registry))
523 {
524 return None;
525 }
526 let id = node.id;
527 if let Some(index) = self.nodes.iter().position(|node| node.id == id) {
528 self.nodes.swap_remove(index);
529 }
530 self.nodes.push(node);
531 Some(id)
532 }
533
534 pub fn remove_node(&mut self, id: NodeId<T>, registry: &Registry) -> Option<Node<T>> {
535 if let Some(index) = self
536 .nodes
537 .iter()
538 .position(|node| node.id == id && !node.data.node_is_start(registry))
539 {
540 self.disconnect_node(id, None);
541 Some(self.nodes.swap_remove(index))
542 } else {
543 None
544 }
545 }
546
547 pub fn connect_nodes(&mut self, connection: NodeConnection<T>) {
548 if !self.connections.iter().any(|other| &connection == other) {
549 self.disconnect_node(connection.from_node, Some(&connection.from_pin));
550 self.disconnect_node(connection.to_node, Some(&connection.to_pin));
551 self.connections.push(connection);
552 }
553 }
554
555 pub fn disconnect_nodes(
556 &mut self,
557 from_node: NodeId<T>,
558 to_node: NodeId<T>,
559 from_pin: &str,
560 to_pin: &str,
561 ) {
562 if let Some(index) = self.connections.iter().position(|connection| {
563 connection.from_node == from_node
564 && connection.to_node == to_node
565 && connection.from_pin == from_pin
566 && connection.to_pin == to_pin
567 }) {
568 self.connections.swap_remove(index);
569 }
570 }
571
572 pub fn disconnect_node(&mut self, node: NodeId<T>, pin: Option<&str>) {
573 let to_remove = self
574 .connections
575 .iter()
576 .enumerate()
577 .filter_map(|(index, connection)| {
578 if let Some(pin) = pin {
579 if connection.from_node == node && connection.from_pin == pin {
580 return Some(index);
581 }
582 if connection.to_node == node && connection.to_pin == pin {
583 return Some(index);
584 }
585 } else if connection.from_node == node || connection.to_node == node {
586 return Some(index);
587 }
588 None
589 })
590 .collect::<Vec<_>>();
591 for index in to_remove.into_iter().rev() {
592 self.connections.swap_remove(index);
593 }
594 }
595
596 pub fn connections(&self) -> impl Iterator<Item = &NodeConnection<T>> {
597 self.connections.iter()
598 }
599
600 pub fn node_connections(&self, id: NodeId<T>) -> impl Iterator<Item = &NodeConnection<T>> {
601 self.connections
602 .iter()
603 .filter(move |connection| connection.from_node == id || connection.to_node == id)
604 }
605
606 pub fn node_connections_in<'a>(
607 &'a self,
608 id: NodeId<T>,
609 pin: Option<&'a str>,
610 ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
611 self.connections.iter().filter(move |connection| {
612 connection.to_node == id && pin.map(|pin| connection.to_pin == pin).unwrap_or(true)
613 })
614 }
615
616 pub fn node_connections_out<'a>(
617 &'a self,
618 id: NodeId<T>,
619 pin: Option<&'a str>,
620 ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
621 self.connections.iter().filter(move |connection| {
622 connection.from_node == id && pin.map(|pin| connection.from_pin == pin).unwrap_or(true)
623 })
624 }
625
626 pub fn node_neighbors_in<'a>(
627 &'a self,
628 id: NodeId<T>,
629 pin: Option<&'a str>,
630 ) -> impl Iterator<Item = NodeId<T>> + 'a {
631 self.node_connections_in(id, pin)
632 .map(move |connection| connection.from_node)
633 }
634
635 pub fn node_neighbors_out<'a>(
636 &'a self,
637 id: NodeId<T>,
638 pin: Option<&'a str>,
639 ) -> impl Iterator<Item = NodeId<T>> + 'a {
640 self.node_connections_out(id, pin)
641 .map(move |connection| connection.to_node)
642 }
643
644 pub fn validate(&self, registry: &Registry) -> Result<(), Vec<NodeGraphError>> {
645 let mut errors = self
646 .connections
647 .iter()
648 .filter_map(|connection| self.validate_connection(connection, registry))
649 .map(NodeGraphError::Connection)
650 .collect::<Vec<_>>();
651 if let Some(error) = self.detect_cycles() {
652 errors.push(NodeGraphError::Connection(error));
653 }
654 if errors.is_empty() {
655 Ok(())
656 } else {
657 Err(errors)
658 }
659 }
660
661 fn validate_connection(
662 &self,
663 connection: &NodeConnection<T>,
664 registry: &Registry,
665 ) -> Option<ConnectionError> {
666 if connection.from_node == connection.to_node {
667 return Some(ConnectionError::InternalConnection(
668 connection.from_node.to_string(),
669 ));
670 }
671 let from = self
672 .nodes
673 .iter()
674 .find(|node| node.id == connection.from_node);
675 let to = self.nodes.iter().find(|node| node.id == connection.to_node);
676 let (from_node, to_node) = match (from, to) {
677 (Some(from), Some(to)) => (from, to),
678 (Some(_), None) => {
679 return Some(ConnectionError::TargetNodeNotFound(
680 connection.to_node.to_string(),
681 ));
682 }
683 (None, Some(_)) => {
684 return Some(ConnectionError::SourceNodeNotFound(
685 connection.from_node.to_string(),
686 ));
687 }
688 (None, None) => {
689 return Some(ConnectionError::NodesNotFound {
690 from: connection.from_node.to_string(),
691 to: connection.to_node.to_string(),
692 });
693 }
694 };
695 let from_pins_out = from_node.data.node_pins_out(registry);
696 let from_pin = match from_pins_out
697 .iter()
698 .find(|pin| pin.name() == connection.from_pin)
699 {
700 Some(pin) => pin,
701 None => {
702 return Some(ConnectionError::SourcePinNotFound {
703 node: connection.from_node.to_string(),
704 pin: connection.from_pin.to_owned(),
705 });
706 }
707 };
708 let to_pins_in = to_node.data.node_pins_in(registry);
709 let to_pin = match to_pins_in
710 .iter()
711 .find(|pin| pin.name() == connection.to_pin)
712 {
713 Some(pin) => pin,
714 None => {
715 return Some(ConnectionError::TargetPinNotFound {
716 node: connection.to_node.to_string(),
717 pin: connection.to_pin.to_owned(),
718 });
719 }
720 };
721 match (from_pin, to_pin) {
722 (NodePin::Execute { .. }, NodePin::Execute { .. }) => {}
723 (NodePin::Parameter { type_info: a, .. }, NodePin::Parameter { type_info: b, .. }) => {
724 if !a.are_compatible(b) {
725 return Some(ConnectionError::MismatchTypes {
726 from_node: connection.from_node.to_string(),
727 from_pin: connection.from_pin.to_owned(),
728 to_node: connection.to_node.to_string(),
729 to_pin: connection.to_pin.to_owned(),
730 from_type_info: a.to_string(),
731 to_type_info: b.to_string(),
732 });
733 }
734 }
735 (NodePin::Property { .. }, NodePin::Property { .. }) => {}
736 _ => {
737 return Some(ConnectionError::MismatchPins {
738 from_node: connection.from_node.to_string(),
739 from_pin: connection.from_pin.to_owned(),
740 to_node: connection.to_node.to_string(),
741 to_pin: connection.to_pin.to_owned(),
742 });
743 }
744 }
745 if let Err(error) = to_node.data.validate_connection(&from_node.data, registry) {
746 return Some(ConnectionError::Custom(error));
747 }
748 None
749 }
750
751 fn detect_cycles(&self) -> Option<ConnectionError> {
752 let mut visited = HashSet::with_capacity(self.nodes.len());
753 let mut available = self.nodes.iter().map(|node| node.id).collect::<Vec<_>>();
754 while let Some(id) = available.first() {
755 if let Some(error) = self.detect_cycle(*id, &mut available, &mut visited) {
756 return Some(error);
757 }
758 available.swap_remove(0);
759 }
760 None
761 }
762
763 fn detect_cycle(
764 &self,
765 id: NodeId<T>,
766 available: &mut Vec<NodeId<T>>,
767 visited: &mut HashSet<NodeId<T>>,
768 ) -> Option<ConnectionError> {
769 if visited.contains(&id) {
770 return Some(ConnectionError::CycleNodeFound(id.to_string()));
771 }
772 visited.insert(id);
773 for id in self.node_neighbors_out(id, None) {
774 if let Some(index) = available.iter().position(|item| item == &id) {
775 available.swap_remove(index);
776 if let Some(error) = self.detect_cycle(id, available, visited) {
777 return Some(error);
778 }
779 }
780 }
781 None
782 }
783
784 pub fn visit<V: NodeGraphVisitor<T>>(
785 &self,
786 visitor: &mut V,
787 registry: &Registry,
788 ) -> Vec<V::Output> {
789 let starts = self
790 .nodes
791 .iter()
792 .filter(|node| node.data.node_is_start(registry))
793 .map(|node| node.id)
794 .collect::<HashSet<_>>();
795 let mut result = Vec::with_capacity(self.nodes.len());
796 for id in starts {
797 self.visit_statement(id, &mut result, visitor, registry);
798 }
799 result
800 }
801
802 fn visit_statement<V: NodeGraphVisitor<T>>(
803 &self,
804 id: NodeId<T>,
805 result: &mut Vec<V::Output>,
806 visitor: &mut V,
807 registry: &Registry,
808 ) {
809 if let Some(node) = self.node(id) {
810 let inputs = node
811 .data
812 .node_pins_in(registry)
813 .into_iter()
814 .filter(|pin| pin.is_parameter())
815 .filter_map(|pin| {
816 self.node_neighbors_in(id, Some(pin.name()))
817 .next()
818 .map(|id| (pin.name().to_owned(), id))
819 })
820 .filter_map(|(name, id)| {
821 self.visit_expression(id, visitor, registry)
822 .map(|input| (name, input))
823 })
824 .collect();
825 let pins_out = node.data.node_pins_out(registry);
826 let scopes = pins_out
827 .iter()
828 .filter(|pin| pin.has_subscope())
829 .filter_map(|pin| {
830 let id = self.node_neighbors_out(id, Some(pin.name())).next()?;
831 Some((id, pin.name().to_owned()))
832 })
833 .map(|(id, name)| {
834 let mut result = Vec::with_capacity(self.nodes.len());
835 self.visit_statement(id, &mut result, visitor, registry);
836 (name, result)
837 })
838 .collect();
839 if visitor.visit_statement(node, inputs, scopes, result) {
840 for pin in pins_out {
841 if pin.is_execute() && !pin.has_subscope() {
842 for id in self.node_neighbors_out(id, Some(pin.name())) {
843 self.visit_statement(id, result, visitor, registry);
844 }
845 }
846 }
847 }
848 }
849 }
850
851 fn visit_expression<V: NodeGraphVisitor<T>>(
852 &self,
853 id: NodeId<T>,
854 visitor: &mut V,
855 registry: &Registry,
856 ) -> Option<V::Input> {
857 if let Some(node) = self.node(id) {
858 let inputs = node
859 .data
860 .node_pins_in(registry)
861 .into_iter()
862 .filter(|pin| pin.is_parameter())
863 .filter_map(|pin| {
864 self.node_neighbors_in(id, Some(pin.name()))
865 .next()
866 .map(|id| (pin.name().to_owned(), id))
867 })
868 .filter_map(|(name, id)| {
869 self.visit_expression(id, visitor, registry)
870 .map(|input| (name, input))
871 })
872 .collect();
873 return visitor.visit_expression(node, inputs);
874 }
875 None
876 }
877}
878
879impl<T: NodeDefinition + std::fmt::Debug> std::fmt::Debug for NodeGraph<T> {
880 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
881 f.debug_struct("NodeGraph")
882 .field("nodes", &self.nodes)
883 .field("connections", &self.connections)
884 .finish()
885 }
886}
887
/// Callbacks used by `NodeGraph::visit` to turn a graph into some output.
pub trait NodeGraphVisitor<T: NodeDefinition> {
    /// Value produced by expression nodes and passed into parameter inputs.
    type Input;
    /// Value produced by statement nodes and collected into result lists.
    type Output;

    /// Called for each statement node reached through execute connections.
    ///
    /// `inputs` maps parameter input pin names to the values of their upstream
    /// expressions. `scopes` maps subscope execute output pin names to the
    /// outputs collected from the chain attached to them. The visitor may push
    /// into `result`. Returning `true` continues into the nodes attached to this
    /// node's non-subscope execute outputs; `false` stops there.
    fn visit_statement(
        &mut self,
        node: &Node<T>,
        inputs: HashMap<String, Self::Input>,
        scopes: HashMap<String, Vec<Self::Output>>,
        result: &mut Vec<Self::Output>,
    ) -> bool;

    /// Called for each node feeding a parameter input, after its own inputs were
    /// evaluated. Returning `None` leaves that input out of the consumer's map.
    fn visit_expression(
        &mut self,
        node: &Node<T>,
        inputs: HashMap<String, Self::Input>,
    ) -> Option<Self::Input>;
}
906
#[cfg(test)]
mod tests {
    use crate::prelude::*;
    use intuicio_core::prelude::*;
    use std::collections::HashMap;

    /// Tiny script representation the test visitor compiles graphs into.
    #[derive(Debug, Clone, PartialEq)]
    enum Script {
        Literal(i32),
        Return,
        Call(String),
        Scope(Vec<Script>),
    }

    // Types are identified by plain names; two pins are compatible when the
    // names are equal.
    impl NodeTypeInfo for String {
        fn type_query(&self) -> TypeQuery {
            TypeQuery {
                name: Some(self.into()),
                ..Default::default()
            }
        }

        fn are_compatible(&self, other: &Self) -> bool {
            self == other
        }
    }

    /// Node kinds used by the tests.
    #[derive(Debug, Clone)]
    enum Nodes {
        Start,
        Expression(i32),
        Result,
        Convert(String),
        Child,
    }

    impl NodeDefinition for Nodes {
        type TypeInfo = String;

        fn node_label(&self, _: &Registry) -> String {
            format!("{self:?}")
        }

        fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![],
                Nodes::Expression(_) => {
                    vec![NodePin::execute("In", false), NodePin::property("Value")]
                }
                Nodes::Result => vec![
                    NodePin::execute("In", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Convert(_) => vec![
                    NodePin::execute("In", false),
                    NodePin::property("Name"),
                    NodePin::parameter("Data in", "i32".to_owned()),
                ],
                Nodes::Child => vec![NodePin::execute("In", false)],
            }
        }

        fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
            match self {
                Nodes::Start => vec![NodePin::execute("Out", false)],
                Nodes::Expression(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data", "i32".to_owned()),
                ],
                Nodes::Result => vec![],
                Nodes::Convert(_) => vec![
                    NodePin::execute("Out", false),
                    NodePin::parameter("Data out", "i32".to_owned()),
                ],
                Nodes::Child => vec![
                    NodePin::execute("Out", false),
                    NodePin::execute("Body", true),
                ],
            }
        }

        fn node_is_start(&self, _: &Registry) -> bool {
            matches!(self, Self::Start)
        }

        fn node_suggestions(
            _: i64,
            _: i64,
            _: NodeSuggestion<Self>,
            _: &Registry,
        ) -> Vec<ResponseSuggestionNode<Self>> {
            vec![]
        }

        fn get_property(&self, property_name: &str) -> Option<PropertyValue> {
            match (self, property_name) {
                (Nodes::Expression(value), "Value") => PropertyValue::new(value).ok(),
                (Nodes::Convert(name), "Name") => PropertyValue::new(name).ok(),
                _ => None,
            }
        }

        fn set_property(&mut self, property_name: &str, property_value: PropertyValue) {
            // Only the names exposed by `get_property` are writable. Values that
            // do not deserialize to the exact field type are ignored.
            match (self, property_name) {
                (Nodes::Expression(value), "Value") => {
                    if let Ok(v) = property_value.get_exact::<i32>() {
                        *value = v;
                    }
                }
                (Nodes::Convert(name), "Name") => {
                    if let Ok(v) = property_value.get_exact::<String>() {
                        *name = v;
                    }
                }
                _ => {}
            }
        }
    }

    /// Visitor that turns statement nodes into `Script` items.
    struct CompileNodesToScript;

    impl NodeGraphVisitor<Nodes> for CompileNodesToScript {
        type Input = ();
        type Output = Script;

        fn visit_statement(
            &mut self,
            node: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
            mut scopes: HashMap<String, Vec<Self::Output>>,
            result: &mut Vec<Self::Output>,
        ) -> bool {
            match &node.data {
                Nodes::Result => result.push(Script::Return),
                Nodes::Convert(name) => result.push(Script::Call(name.to_owned())),
                Nodes::Child => {
                    if let Some(body) = scopes.remove("Body") {
                        result.push(Script::Scope(body));
                    }
                }
                Nodes::Expression(value) => result.push(Script::Literal(*value)),
                _ => {}
            }
            true
        }

        fn visit_expression(
            &mut self,
            _: &Node<Nodes>,
            _: HashMap<String, Self::Input>,
        ) -> Option<Self::Input> {
            None
        }
    }

    #[test]
    fn test_nodes() {
        let registry = Registry::default().with_basic_types();
        let mut graph = NodeGraph::default();
        let start = graph
            .add_node(Node::new(0, 0, Nodes::Start), &registry)
            .unwrap();
        let expression_child = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert_child = graph
            .add_node(Node::new(0, 0, Nodes::Convert("foo".to_owned())), &registry)
            .unwrap();
        let result_child = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        let child = graph
            .add_node(Node::new(0, 0, Nodes::Child), &registry)
            .unwrap();
        let expression = graph
            .add_node(Node::new(0, 0, Nodes::Expression(42)), &registry)
            .unwrap();
        let convert = graph
            .add_node(Node::new(0, 0, Nodes::Convert("bar".to_owned())), &registry)
            .unwrap();
        let result = graph
            .add_node(Node::new(0, 0, Nodes::Result), &registry)
            .unwrap();
        graph.connect_nodes(NodeConnection::new(start, child, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(child, expression_child, "Body", "In"));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            expression_child,
            convert_child,
            "Data",
            "Data in",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Out",
            "In",
        ));
        graph.connect_nodes(NodeConnection::new(
            convert_child,
            result_child,
            "Data out",
            "Data",
        ));
        graph.connect_nodes(NodeConnection::new(child, expression, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(expression, convert, "Data", "Data in"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Out", "In"));
        graph.connect_nodes(NodeConnection::new(convert, result, "Data out", "Data"));
        graph.validate(&registry).unwrap();
        assert_eq!(
            graph.visit(&mut CompileNodesToScript, &registry),
            vec![
                Script::Scope(vec![
                    Script::Literal(42),
                    Script::Call("foo".to_owned()),
                    Script::Return
                ]),
                Script::Literal(42),
                Script::Call("bar".to_owned()),
                Script::Return
            ]
        );
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&42i32).unwrap(),
        );
        graph
            .node_mut(expression)
            .unwrap()
            .data
            .set_property("Value", PropertyValue::new(&10i32).unwrap());
        assert_eq!(
            graph
                .node(expression)
                .unwrap()
                .data
                .get_property("Value")
                .unwrap(),
            PropertyValue::new(&10i32).unwrap(),
        );
    }

    #[test]
    fn test_property_names_are_respected() {
        let mut node = Nodes::Convert("foo".to_owned());
        // Writing an unknown property must not touch the `Name` property.
        node.set_property("Value", PropertyValue::new(&"bar".to_owned()).unwrap());
        assert_eq!(
            node.get_property("Name").unwrap(),
            PropertyValue::new(&"foo".to_owned()).unwrap(),
        );
        node.set_property("Name", PropertyValue::new(&"bar".to_owned()).unwrap());
        assert_eq!(
            node.get_property("Name").unwrap(),
            PropertyValue::new(&"bar".to_owned()).unwrap(),
        );
        assert!(node.get_property("Value").is_none());
    }
}