1use intuicio_core::{registry::Registry, types::TypeQuery};
2use rstar::{AABB, Envelope, Point, PointDistance, RTree, RTreeObject};
3use serde::{Deserialize, Serialize};
4use serde_intermediate::{
5 Intermediate, de::intermediate::DeserializeMode, error::Result as IntermediateResult,
6};
7use std::{
8 collections::{HashMap, HashSet},
9 error::Error,
10 fmt::Display,
11 hash::{Hash, Hasher},
12};
13use typid::ID;
14
/// Identifier of a [`Node`], typed by the node definition `T` so ids of different graphs'
/// node kinds cannot be mixed up.
pub type NodeId<T> = ID<Node<T>>;
/// Conversion mode used when reading a [`PropertyValue`] back (see [`PropertyValue::get`]);
/// alias of `serde_intermediate`'s `DeserializeMode`.
pub type PropertyCastMode = DeserializeMode;
17
18#[derive(Debug, Default, Clone, PartialEq)]
19pub struct PropertyValue {
20 value: Intermediate,
21}
22
23impl PropertyValue {
24 pub fn new<T: Serialize>(value: &T) -> IntermediateResult<Self> {
25 Ok(Self {
26 value: serde_intermediate::to_intermediate(value)?,
27 })
28 }
29
30 pub fn get<'a, T: Deserialize<'a>>(&'a self, mode: PropertyCastMode) -> IntermediateResult<T> {
31 serde_intermediate::from_intermediate_as(&self.value, mode)
32 }
33
34 pub fn get_exact<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
35 self.get(PropertyCastMode::Exact)
36 }
37
38 pub fn get_interpret<'a, T: Deserialize<'a>>(&'a self) -> IntermediateResult<T> {
39 self.get(PropertyCastMode::Interpret)
40 }
41
42 pub fn into_inner(self) -> Intermediate {
43 self.value
44 }
45}
46
/// Type descriptor attached to [`NodePin::Parameter`] pins.
///
/// Graph validation uses `Display` to render type-mismatch errors. The serde bounds let pins
/// carrying it be serialized.
pub trait NodeTypeInfo:
    Clone + std::fmt::Debug + Display + PartialEq + Serialize + for<'de> Deserialize<'de>
{
    /// Builds an `intuicio_core` type query describing this type.
    ///
    /// NOTE(review): not called in this file; presumably used by callers to resolve the type
    /// in a `Registry` — confirm.
    fn type_query(&'_ self) -> TypeQuery<'_>;

    /// Returns whether a value of type `self` (the source output pin) may flow into a pin of
    /// type `other` (the target input pin).
    ///
    /// `NodeGraph` validation reports `ConnectionError::MismatchTypes` when this returns `false`.
    fn are_compatible(&self, other: &Self) -> bool;
}
53
/// Behaviour of one kind of node stored in a [`NodeGraph`].
///
/// Implementors describe a node's label, pins and start-node status. They may also hook into
/// connection validation and named property access.
pub trait NodeDefinition: Sized {
    /// Type descriptor used by this definition's parameter pins.
    type TypeInfo: NodeTypeInfo;

    /// Label for this node. [`ResponseSuggestionNode::new`] uses it as the suggestion label.
    fn node_label(&self, registry: &Registry) -> String;
    /// Input pins. A connection's `to_pin` is looked up by name among these.
    fn node_pins_in(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
    /// Output pins. A connection's `from_pin` is looked up by name among these.
    fn node_pins_out(&self, registry: &Registry) -> Vec<NodePin<Self::TypeInfo>>;
    /// Whether this node is an entry point.
    ///
    /// [`NodeGraph::add_node`] rejects adding a second start node, and
    /// [`NodeGraph::remove_node`] refuses to remove one. [`NodeGraph::visit`] starts traversal
    /// from start nodes.
    fn node_is_start(&self, registry: &Registry) -> bool;
    /// Nodes that may be offered for insertion at `(x, y)`.
    ///
    /// `suggestion` tells whether the request has no context or is relative to a specific pin
    /// of an existing node.
    fn node_suggestions(
        x: i64,
        y: i64,
        suggestion: NodeSuggestion<Self>,
        registry: &Registry,
    ) -> Vec<ResponseSuggestionNode<Self>>;

    /// Definition-specific check run on the *target* node of a connection, with the source
    /// node's data passed as `source`.
    ///
    /// Graph validation calls it only after both pins were found and matched by kind/type.
    /// An `Err` is reported as `ConnectionError::Custom`. Accepts every connection by default.
    #[allow(unused_variables)]
    fn validate_connection(
        &self,
        source: &Self,
        registry: &Registry,
    ) -> Result<(), Box<dyn Error>> {
        Ok(())
    }

    /// Reads the property called `name`. Returns `None` by default.
    #[allow(unused_variables)]
    fn get_property(&self, name: &str) -> Option<PropertyValue> {
        None
    }

    /// Writes the property called `name`. Does nothing by default.
    #[allow(unused_variables)]
    fn set_property(&mut self, name: &str, value: PropertyValue) {}
}
85
86#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
87#[serde(bound = "TI: NodeTypeInfo")]
88pub enum NodePin<TI: NodeTypeInfo> {
89 Execute { name: String, subscope: bool },
90 Parameter { name: String, type_info: TI },
91 Property { name: String },
92}
93
94impl<TI: NodeTypeInfo> NodePin<TI> {
95 pub fn execute(name: impl ToString, subscope: bool) -> Self {
96 Self::Execute {
97 name: name.to_string(),
98 subscope,
99 }
100 }
101
102 pub fn parameter(name: impl ToString, type_info: TI) -> Self {
103 Self::Parameter {
104 name: name.to_string(),
105 type_info,
106 }
107 }
108
109 pub fn property(name: impl ToString) -> Self {
110 Self::Property {
111 name: name.to_string(),
112 }
113 }
114
115 pub fn is_execute(&self) -> bool {
116 matches!(self, Self::Execute { .. })
117 }
118
119 pub fn is_parameter(&self) -> bool {
120 matches!(self, Self::Parameter { .. })
121 }
122
123 pub fn is_property(&self) -> bool {
124 matches!(self, Self::Property { .. })
125 }
126
127 pub fn name(&self) -> &str {
128 match self {
129 Self::Execute { name, .. }
130 | Self::Parameter { name, .. }
131 | Self::Property { name, .. } => name,
132 }
133 }
134
135 pub fn has_subscope(&self) -> bool {
136 matches!(self, Self::Execute { subscope: true, .. })
137 }
138
139 pub fn type_info(&self) -> Option<&TI> {
140 match self {
141 Self::Parameter { type_info, .. } => Some(type_info),
142 _ => None,
143 }
144 }
145}
146
/// Context passed to [`NodeDefinition::node_suggestions`].
pub enum NodeSuggestion<'a, T: NodeDefinition> {
    /// No pin context (used by `NodeGraph::suggest_all_nodes`).
    All,
    /// Suggestions relative to the given input pin of the given node's data.
    NodeInputPin(&'a T, &'a NodePin<T::TypeInfo>),
    /// Suggestions relative to the given output pin of the given node's data.
    NodeOutputPin(&'a T, &'a NodePin<T::TypeInfo>),
}
152
153#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct ResponseSuggestionNode<T: NodeDefinition> {
155 pub category: String,
156 pub label: String,
157 pub node: Node<T>,
158}
159
160impl<T: NodeDefinition> ResponseSuggestionNode<T> {
161 pub fn new(category: impl ToString, node: Node<T>, registry: &Registry) -> Self {
162 Self {
163 category: category.to_string(),
164 label: node.data.node_label(registry),
165 node,
166 }
167 }
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct Node<T: NodeDefinition> {
172 id: NodeId<T>,
173 pub x: i64,
174 pub y: i64,
175 pub data: T,
176}
177
178impl<T: NodeDefinition> Node<T> {
179 pub fn new(x: i64, y: i64, data: T) -> Self {
180 Self {
181 id: Default::default(),
182 x,
183 y,
184 data,
185 }
186 }
187
188 pub fn id(&self) -> NodeId<T> {
189 self.id
190 }
191}
192
193#[derive(Clone, Serialize, Deserialize)]
194pub struct NodeConnection<T: NodeDefinition> {
195 pub from_node: NodeId<T>,
196 pub to_node: NodeId<T>,
197 pub from_pin: String,
198 pub to_pin: String,
199}
200
201impl<T: NodeDefinition> NodeConnection<T> {
202 pub fn new(from_node: NodeId<T>, to_node: NodeId<T>, from_pin: &str, to_pin: &str) -> Self {
203 Self {
204 from_node,
205 to_node,
206 from_pin: from_pin.to_owned(),
207 to_pin: to_pin.to_owned(),
208 }
209 }
210}
211
212impl<T: NodeDefinition> std::fmt::Debug for NodeConnection<T> {
213 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214 f.debug_struct("NodeConnection")
215 .field("from_node", &self.from_node)
216 .field("to_node", &self.to_node)
217 .field("from_pin", &self.from_pin)
218 .field("to_pin", &self.to_pin)
219 .finish()
220 }
221}
222
223impl<T: NodeDefinition> PartialEq for NodeConnection<T> {
224 fn eq(&self, other: &Self) -> bool {
225 self.from_node == other.from_node
226 && self.to_node == other.to_node
227 && self.from_pin == other.from_pin
228 && self.to_pin == other.to_pin
229 }
230}
231
232impl<T: NodeDefinition> Eq for NodeConnection<T> {}
233
234impl<T: NodeDefinition> Hash for NodeConnection<T> {
235 fn hash<H: Hasher>(&self, state: &mut H) {
236 self.from_node.hash(state);
237 self.to_node.hash(state);
238 self.from_pin.hash(state);
239 self.to_pin.hash(state);
240 }
241}
242
/// Why a single connection is invalid, or which node closes a cycle in the graph.
///
/// Node and pin identifiers are stored as strings.
#[derive(Debug)]
pub enum ConnectionError {
    /// A connection whose source and target are the same node.
    InternalConnection(String),
    /// The source node does not exist.
    SourceNodeNotFound(String),
    /// The target node does not exist.
    TargetNodeNotFound(String),
    /// Neither endpoint node exists.
    NodesNotFound {
        from: String,
        to: String,
    },
    /// The source node has no output pin with this name.
    SourcePinNotFound {
        node: String,
        pin: String,
    },
    /// The target node has no input pin with this name.
    TargetPinNotFound {
        node: String,
        pin: String,
    },
    /// Both pins are parameters, but their types are not compatible.
    MismatchTypes {
        from_node: String,
        from_pin: String,
        from_type_info: String,
        to_node: String,
        to_pin: String,
        to_type_info: String,
    },
    /// The two pins are of different kinds (execute / parameter / property).
    MismatchPins {
        from_node: String,
        from_pin: String,
        to_node: String,
        to_pin: String,
    },
    /// A node reached again while still on the current traversal path.
    CycleNodeFound(String),
    /// An error reported by a node definition's own validation hook.
    Custom(Box<dyn Error>),
}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InternalConnection(id) => write!(f, "Trying to connect node: {id} to itself"),
            Self::SourceNodeNotFound(id) => write!(f, "Source node: {id} not found"),
            Self::TargetNodeNotFound(id) => write!(f, "Target node: {id} not found"),
            Self::NodesNotFound { from: source, to: target } => {
                write!(f, "Source: {source} and target: {target} nodes not found")
            }
            Self::SourcePinNotFound { node: owner, pin: pin_name } => {
                write!(f, "Source pin: {pin_name} for node: {owner} not found")
            }
            Self::TargetPinNotFound { node: owner, pin: pin_name } => {
                write!(f, "Target pin: {pin_name} for node: {owner} not found")
            }
            Self::MismatchTypes {
                from_node: src_node,
                from_pin: src_pin,
                from_type_info: src_type,
                to_node: dst_node,
                to_pin: dst_pin,
                to_type_info: dst_type,
            } => write!(
                f,
                "Source type: {src_type} of pin: {src_pin} for node: {src_node} does not match target type: {dst_type} of pin: {dst_pin} for node: {dst_node}"
            ),
            Self::MismatchPins {
                from_node: src_node,
                from_pin: src_pin,
                to_node: dst_node,
                to_pin: dst_pin,
            } => write!(
                f,
                "Source pin: {src_pin} kind for node: {src_node} does not match target pin: {dst_pin} kind for node: {dst_node}"
            ),
            Self::CycleNodeFound(id) => write!(f, "Found cycle node: {id}"),
            // Forward to the wrapped error so formatter options pass through unchanged.
            Self::Custom(inner) => std::fmt::Display::fmt(inner, f),
        }
    }
}

impl Error for ConnectionError {}
326
327#[derive(Debug)]
328pub enum NodeGraphError {
329 Connection(ConnectionError),
330 DuplicateFunctionInputNames(String),
331 DuplicateFunctionOutputNames(String),
332}
333
334impl std::fmt::Display for NodeGraphError {
335 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
336 match self {
337 Self::Connection(connection) => connection.fmt(f),
338 Self::DuplicateFunctionInputNames(name) => {
339 write!(
340 f,
341 "Found duplicate `{name}` function input with different types"
342 )
343 }
344 Self::DuplicateFunctionOutputNames(name) => {
345 write!(
346 f,
347 "Found duplicate `{name}` function output with different types"
348 )
349 }
350 }
351 }
352}
353
354impl Error for NodeGraphError {}
355
356#[derive(Clone)]
357struct SpatialNode<T: NodeDefinition> {
358 id: NodeId<T>,
359 x: i64,
360 y: i64,
361}
362
363impl<T: NodeDefinition> RTreeObject for SpatialNode<T> {
364 type Envelope = AABB<[i64; 2]>;
365
366 fn envelope(&self) -> Self::Envelope {
367 AABB::from_point([self.x, self.y])
368 }
369}
370
371impl<T: NodeDefinition> PointDistance for SpatialNode<T> {
372 fn distance_2(
373 &self,
374 point: &<Self::Envelope as Envelope>::Point,
375 ) -> <<Self::Envelope as Envelope>::Point as Point>::Scalar {
376 let dx = self.x - point[0];
377 let dy = self.y - point[1];
378 dx * dx + dy * dy
379 }
380}
381
382#[derive(Clone, Serialize, Deserialize)]
383pub struct NodeGraph<T: NodeDefinition> {
384 nodes: Vec<Node<T>>,
385 connections: Vec<NodeConnection<T>>,
386 #[serde(skip, default)]
387 rtree: RTree<SpatialNode<T>>,
388}
389
390impl<T: NodeDefinition> Default for NodeGraph<T> {
391 fn default() -> Self {
392 Self {
393 nodes: vec![],
394 connections: vec![],
395 rtree: Default::default(),
396 }
397 }
398}
399
400impl<T: NodeDefinition> NodeGraph<T> {
401 pub fn clear(&mut self) {
402 self.nodes.clear();
403 self.connections.clear();
404 }
405
406 pub fn refresh_spatial_cache(&mut self) {
407 self.rtree = RTree::bulk_load(
408 self.nodes
409 .iter()
410 .map(|node| SpatialNode {
411 id: node.id,
412 x: node.x,
413 y: node.y,
414 })
415 .collect(),
416 );
417 }
418
419 pub fn query_nearest_nodes(&self, x: i64, y: i64) -> impl Iterator<Item = NodeId<T>> + '_ {
420 self.rtree
421 .nearest_neighbor_iter(&[x, y])
422 .map(|node| node.id)
423 }
424
425 pub fn query_region_nodes(
426 &self,
427 fx: i64,
428 fy: i64,
429 tx: i64,
430 ty: i64,
431 extrude: i64,
432 ) -> impl Iterator<Item = NodeId<T>> + '_ {
433 self.rtree
434 .locate_in_envelope(&AABB::from_corners(
435 [fx - extrude, fy - extrude],
436 [tx - extrude, ty - extrude],
437 ))
438 .map(|node| node.id)
439 }
440
441 pub fn suggest_all_nodes(
442 x: i64,
443 y: i64,
444 registry: &Registry,
445 ) -> Vec<ResponseSuggestionNode<T>> {
446 T::node_suggestions(x, y, NodeSuggestion::All, registry)
447 }
448
449 pub fn suggest_node_input_pin(
450 &self,
451 x: i64,
452 y: i64,
453 id: NodeId<T>,
454 name: &str,
455 registry: &Registry,
456 ) -> Vec<ResponseSuggestionNode<T>> {
457 if let Some(node) = self.node(id)
458 && let Some(pin) = node
459 .data
460 .node_pins_in(registry)
461 .into_iter()
462 .find(|pin| pin.name() == name)
463 {
464 return T::node_suggestions(
465 x,
466 y,
467 NodeSuggestion::NodeInputPin(&node.data, &pin),
468 registry,
469 );
470 }
471 vec![]
472 }
473
474 pub fn suggest_node_output_pin(
475 &self,
476 x: i64,
477 y: i64,
478 id: NodeId<T>,
479 name: &str,
480 registry: &Registry,
481 ) -> Vec<ResponseSuggestionNode<T>> {
482 if let Some(node) = self.node(id)
483 && let Some(pin) = node
484 .data
485 .node_pins_out(registry)
486 .into_iter()
487 .find(|pin| pin.name() == name)
488 {
489 return T::node_suggestions(
490 x,
491 y,
492 NodeSuggestion::NodeOutputPin(&node.data, &pin),
493 registry,
494 );
495 }
496 vec![]
497 }
498
499 pub fn node(&self, id: NodeId<T>) -> Option<&Node<T>> {
500 self.nodes.iter().find(|node| node.id == id)
501 }
502
503 pub fn node_mut(&mut self, id: NodeId<T>) -> Option<&mut Node<T>> {
504 self.nodes.iter_mut().find(|node| node.id == id)
505 }
506
507 pub fn nodes(&self) -> impl Iterator<Item = &Node<T>> {
508 self.nodes.iter()
509 }
510
511 pub fn nodes_mut(&mut self) -> impl Iterator<Item = &mut Node<T>> {
512 self.nodes.iter_mut()
513 }
514
515 pub fn add_node(&mut self, node: Node<T>, registry: &Registry) -> Option<NodeId<T>> {
516 if node.data.node_is_start(registry)
517 && self
518 .nodes
519 .iter()
520 .any(|node| node.data.node_is_start(registry))
521 {
522 return None;
523 }
524 let id = node.id;
525 if let Some(index) = self.nodes.iter().position(|node| node.id == id) {
526 self.nodes.swap_remove(index);
527 }
528 self.nodes.push(node);
529 Some(id)
530 }
531
532 pub fn remove_node(&mut self, id: NodeId<T>, registry: &Registry) -> Option<Node<T>> {
533 if let Some(index) = self
534 .nodes
535 .iter()
536 .position(|node| node.id == id && !node.data.node_is_start(registry))
537 {
538 self.disconnect_node(id, None);
539 Some(self.nodes.swap_remove(index))
540 } else {
541 None
542 }
543 }
544
545 pub fn connect_nodes(&mut self, connection: NodeConnection<T>) {
546 if !self.connections.iter().any(|other| &connection == other) {
547 self.disconnect_node(connection.from_node, Some(&connection.from_pin));
548 self.disconnect_node(connection.to_node, Some(&connection.to_pin));
549 self.connections.push(connection);
550 }
551 }
552
553 pub fn disconnect_nodes(
554 &mut self,
555 from_node: NodeId<T>,
556 to_node: NodeId<T>,
557 from_pin: &str,
558 to_pin: &str,
559 ) {
560 if let Some(index) = self.connections.iter().position(|connection| {
561 connection.from_node == from_node
562 && connection.to_node == to_node
563 && connection.from_pin == from_pin
564 && connection.to_pin == to_pin
565 }) {
566 self.connections.swap_remove(index);
567 }
568 }
569
570 pub fn disconnect_node(&mut self, node: NodeId<T>, pin: Option<&str>) {
571 let to_remove = self
572 .connections
573 .iter()
574 .enumerate()
575 .filter_map(|(index, connection)| {
576 if let Some(pin) = pin {
577 if connection.from_node == node && connection.from_pin == pin {
578 return Some(index);
579 }
580 if connection.to_node == node && connection.to_pin == pin {
581 return Some(index);
582 }
583 } else if connection.from_node == node || connection.to_node == node {
584 return Some(index);
585 }
586 None
587 })
588 .collect::<Vec<_>>();
589 for index in to_remove.into_iter().rev() {
590 self.connections.swap_remove(index);
591 }
592 }
593
594 pub fn connections(&self) -> impl Iterator<Item = &NodeConnection<T>> {
595 self.connections.iter()
596 }
597
598 pub fn node_connections(&self, id: NodeId<T>) -> impl Iterator<Item = &NodeConnection<T>> {
599 self.connections
600 .iter()
601 .filter(move |connection| connection.from_node == id || connection.to_node == id)
602 }
603
604 pub fn node_connections_in<'a>(
605 &'a self,
606 id: NodeId<T>,
607 pin: Option<&'a str>,
608 ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
609 self.connections.iter().filter(move |connection| {
610 connection.to_node == id && pin.map(|pin| connection.to_pin == pin).unwrap_or(true)
611 })
612 }
613
614 pub fn node_connections_out<'a>(
615 &'a self,
616 id: NodeId<T>,
617 pin: Option<&'a str>,
618 ) -> impl Iterator<Item = &'a NodeConnection<T>> + 'a {
619 self.connections.iter().filter(move |connection| {
620 connection.from_node == id && pin.map(|pin| connection.from_pin == pin).unwrap_or(true)
621 })
622 }
623
624 pub fn node_neighbors_in<'a>(
625 &'a self,
626 id: NodeId<T>,
627 pin: Option<&'a str>,
628 ) -> impl Iterator<Item = NodeId<T>> + 'a {
629 self.node_connections_in(id, pin)
630 .map(move |connection| connection.from_node)
631 }
632
633 pub fn node_neighbors_out<'a>(
634 &'a self,
635 id: NodeId<T>,
636 pin: Option<&'a str>,
637 ) -> impl Iterator<Item = NodeId<T>> + 'a {
638 self.node_connections_out(id, pin)
639 .map(move |connection| connection.to_node)
640 }
641
642 pub fn validate(&self, registry: &Registry) -> Result<(), Vec<NodeGraphError>> {
643 let mut errors = self
644 .connections
645 .iter()
646 .filter_map(|connection| self.validate_connection(connection, registry))
647 .map(NodeGraphError::Connection)
648 .collect::<Vec<_>>();
649 if let Some(error) = self.detect_cycles() {
650 errors.push(NodeGraphError::Connection(error));
651 }
652 if errors.is_empty() {
653 Ok(())
654 } else {
655 Err(errors)
656 }
657 }
658
659 fn validate_connection(
660 &self,
661 connection: &NodeConnection<T>,
662 registry: &Registry,
663 ) -> Option<ConnectionError> {
664 if connection.from_node == connection.to_node {
665 return Some(ConnectionError::InternalConnection(
666 connection.from_node.to_string(),
667 ));
668 }
669 let from = self
670 .nodes
671 .iter()
672 .find(|node| node.id == connection.from_node);
673 let to = self.nodes.iter().find(|node| node.id == connection.to_node);
674 let (from_node, to_node) = match (from, to) {
675 (Some(from), Some(to)) => (from, to),
676 (Some(_), None) => {
677 return Some(ConnectionError::TargetNodeNotFound(
678 connection.to_node.to_string(),
679 ));
680 }
681 (None, Some(_)) => {
682 return Some(ConnectionError::SourceNodeNotFound(
683 connection.from_node.to_string(),
684 ));
685 }
686 (None, None) => {
687 return Some(ConnectionError::NodesNotFound {
688 from: connection.from_node.to_string(),
689 to: connection.to_node.to_string(),
690 });
691 }
692 };
693 let from_pins_out = from_node.data.node_pins_out(registry);
694 let from_pin = match from_pins_out
695 .iter()
696 .find(|pin| pin.name() == connection.from_pin)
697 {
698 Some(pin) => pin,
699 None => {
700 return Some(ConnectionError::SourcePinNotFound {
701 node: connection.from_node.to_string(),
702 pin: connection.from_pin.to_owned(),
703 });
704 }
705 };
706 let to_pins_in = to_node.data.node_pins_in(registry);
707 let to_pin = match to_pins_in
708 .iter()
709 .find(|pin| pin.name() == connection.to_pin)
710 {
711 Some(pin) => pin,
712 None => {
713 return Some(ConnectionError::TargetPinNotFound {
714 node: connection.to_node.to_string(),
715 pin: connection.to_pin.to_owned(),
716 });
717 }
718 };
719 match (from_pin, to_pin) {
720 (NodePin::Execute { .. }, NodePin::Execute { .. }) => {}
721 (NodePin::Parameter { type_info: a, .. }, NodePin::Parameter { type_info: b, .. }) => {
722 if !a.are_compatible(b) {
723 return Some(ConnectionError::MismatchTypes {
724 from_node: connection.from_node.to_string(),
725 from_pin: connection.from_pin.to_owned(),
726 to_node: connection.to_node.to_string(),
727 to_pin: connection.to_pin.to_owned(),
728 from_type_info: a.to_string(),
729 to_type_info: b.to_string(),
730 });
731 }
732 }
733 (NodePin::Property { .. }, NodePin::Property { .. }) => {}
734 _ => {
735 return Some(ConnectionError::MismatchPins {
736 from_node: connection.from_node.to_string(),
737 from_pin: connection.from_pin.to_owned(),
738 to_node: connection.to_node.to_string(),
739 to_pin: connection.to_pin.to_owned(),
740 });
741 }
742 }
743 if let Err(error) = to_node.data.validate_connection(&from_node.data, registry) {
744 return Some(ConnectionError::Custom(error));
745 }
746 None
747 }
748
749 fn detect_cycles(&self) -> Option<ConnectionError> {
750 let mut visited = HashSet::with_capacity(self.nodes.len());
751 let mut available = self.nodes.iter().map(|node| node.id).collect::<Vec<_>>();
752 while let Some(id) = available.first() {
753 if let Some(error) = self.detect_cycle(*id, &mut available, &mut visited) {
754 return Some(error);
755 }
756 available.swap_remove(0);
757 }
758 None
759 }
760
761 fn detect_cycle(
762 &self,
763 id: NodeId<T>,
764 available: &mut Vec<NodeId<T>>,
765 visited: &mut HashSet<NodeId<T>>,
766 ) -> Option<ConnectionError> {
767 if visited.contains(&id) {
768 return Some(ConnectionError::CycleNodeFound(id.to_string()));
769 }
770 visited.insert(id);
771 for id in self.node_neighbors_out(id, None) {
772 if let Some(index) = available.iter().position(|item| item == &id) {
773 available.swap_remove(index);
774 if let Some(error) = self.detect_cycle(id, available, visited) {
775 return Some(error);
776 }
777 }
778 }
779 None
780 }
781
782 pub fn visit<V: NodeGraphVisitor<T>>(
783 &self,
784 visitor: &mut V,
785 registry: &Registry,
786 ) -> Vec<V::Output> {
787 let starts = self
788 .nodes
789 .iter()
790 .filter(|node| node.data.node_is_start(registry))
791 .map(|node| node.id)
792 .collect::<HashSet<_>>();
793 let mut result = Vec::with_capacity(self.nodes.len());
794 for id in starts {
795 self.visit_statement(id, &mut result, visitor, registry);
796 }
797 result
798 }
799
800 fn visit_statement<V: NodeGraphVisitor<T>>(
801 &self,
802 id: NodeId<T>,
803 result: &mut Vec<V::Output>,
804 visitor: &mut V,
805 registry: &Registry,
806 ) {
807 if let Some(node) = self.node(id) {
808 let inputs = node
809 .data
810 .node_pins_in(registry)
811 .into_iter()
812 .filter(|pin| pin.is_parameter())
813 .filter_map(|pin| {
814 self.node_neighbors_in(id, Some(pin.name()))
815 .next()
816 .map(|id| (pin.name().to_owned(), id))
817 })
818 .filter_map(|(name, id)| {
819 self.visit_expression(id, visitor, registry)
820 .map(|input| (name, input))
821 })
822 .collect();
823 let pins_out = node.data.node_pins_out(registry);
824 let scopes = pins_out
825 .iter()
826 .filter(|pin| pin.has_subscope())
827 .filter_map(|pin| {
828 let id = self.node_neighbors_out(id, Some(pin.name())).next()?;
829 Some((id, pin.name().to_owned()))
830 })
831 .map(|(id, name)| {
832 let mut result = Vec::with_capacity(self.nodes.len());
833 self.visit_statement(id, &mut result, visitor, registry);
834 (name, result)
835 })
836 .collect();
837 if visitor.visit_statement(node, inputs, scopes, result) {
838 for pin in pins_out {
839 if pin.is_execute() && !pin.has_subscope() {
840 for id in self.node_neighbors_out(id, Some(pin.name())) {
841 self.visit_statement(id, result, visitor, registry);
842 }
843 }
844 }
845 }
846 }
847 }
848
849 fn visit_expression<V: NodeGraphVisitor<T>>(
850 &self,
851 id: NodeId<T>,
852 visitor: &mut V,
853 registry: &Registry,
854 ) -> Option<V::Input> {
855 if let Some(node) = self.node(id) {
856 let inputs = node
857 .data
858 .node_pins_in(registry)
859 .into_iter()
860 .filter(|pin| pin.is_parameter())
861 .filter_map(|pin| {
862 self.node_neighbors_in(id, Some(pin.name()))
863 .next()
864 .map(|id| (pin.name().to_owned(), id))
865 })
866 .filter_map(|(name, id)| {
867 self.visit_expression(id, visitor, registry)
868 .map(|input| (name, input))
869 })
870 .collect();
871 return visitor.visit_expression(node, inputs);
872 }
873 None
874 }
875}
876
877impl<T: NodeDefinition + std::fmt::Debug> std::fmt::Debug for NodeGraph<T> {
878 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
879 f.debug_struct("NodeGraph")
880 .field("nodes", &self.nodes)
881 .field("connections", &self.connections)
882 .finish()
883 }
884}
885
/// Callbacks used by [`NodeGraph::visit`] to turn a graph into a list of outputs.
pub trait NodeGraphVisitor<T: NodeDefinition> {
    /// Value produced by an expression node and handed to the node consuming it.
    type Input;
    /// Value produced by statement nodes and collected into result lists.
    type Output;

    /// Called for each statement node reached through execute connections.
    ///
    /// `inputs` maps each connected parameter input pin name to the value its source
    /// expression produced; pins whose expression returned `None` are absent. `scopes` maps
    /// each connected subscope execute pin name to the outputs produced by visiting the chain
    /// attached to it. Append this node's output(s) to `result`. Return `true` to continue into
    /// nodes connected to the non-subscope execute output pins, or `false` to stop this branch.
    fn visit_statement(
        &mut self,
        node: &Node<T>,
        inputs: HashMap<String, Self::Input>,
        scopes: HashMap<String, Vec<Self::Output>>,
        result: &mut Vec<Self::Output>,
    ) -> bool;

    /// Called for a node feeding a parameter input, after its own parameter inputs were
    /// evaluated the same way. Returning `None` leaves the consuming pin out of `inputs`.
    fn visit_expression(
        &mut self,
        node: &Node<T>,
        inputs: HashMap<String, Self::Input>,
    ) -> Option<Self::Input>;
}
904
905#[cfg(test)]
906mod tests {
907 use crate::nodes::{
908 Node, NodeConnection, NodeDefinition, NodeGraph, NodeGraphVisitor, NodePin, NodeSuggestion,
909 NodeTypeInfo, PropertyValue, ResponseSuggestionNode,
910 };
911 use intuicio_core::{registry::Registry, types::TypeQuery};
912 use std::collections::HashMap;
913
    /// Flat script produced by `CompileNodesToScript` and compared against in `test_nodes`.
    #[derive(Debug, Clone, PartialEq)]
    enum Script {
        /// Integer literal, emitted for an `Expression` node.
        Literal(i32),
        /// Emitted for a `Result` node.
        Return,
        /// Call of the named function, emitted for a `Convert` node.
        Call(String),
        /// Statements of a `Child` node's `Body` subscope.
        Scope(Vec<Script>),
    }
921
922 impl NodeTypeInfo for String {
923 fn type_query(&'_ self) -> TypeQuery<'_> {
924 TypeQuery {
925 name: Some(self.into()),
926 ..Default::default()
927 }
928 }
929
930 fn are_compatible(&self, other: &Self) -> bool {
931 self == other
932 }
933 }
934
    /// Node kinds used by the test graph.
    #[derive(Debug, Clone)]
    enum Nodes {
        /// Entry point with a single execute output `Out`.
        Start,
        /// Statement with an `i32` `Value` property and an `i32` `Data` output.
        Expression(i32),
        /// Statement consuming an `i32` `Data` input; it has no outputs.
        Result,
        /// Statement with a `Name` property, an `i32` `Data in` input and a `Data out` output.
        Convert(String),
        /// Statement with a `Body` subscope execute output and a regular `Out` output.
        Child,
    }
943
944 impl NodeDefinition for Nodes {
945 type TypeInfo = String;
946
947 fn node_label(&self, _: &Registry) -> String {
948 format!("{self:?}")
949 }
950
951 fn node_pins_in(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
952 match self {
953 Nodes::Start => vec![],
954 Nodes::Expression(_) => {
955 vec![NodePin::execute("In", false), NodePin::property("Value")]
956 }
957 Nodes::Result => vec![
958 NodePin::execute("In", false),
959 NodePin::parameter("Data", "i32".to_owned()),
960 ],
961 Nodes::Convert(_) => vec![
962 NodePin::execute("In", false),
963 NodePin::property("Name"),
964 NodePin::parameter("Data in", "i32".to_owned()),
965 ],
966 Nodes::Child => vec![NodePin::execute("In", false)],
967 }
968 }
969
970 fn node_pins_out(&self, _: &Registry) -> Vec<NodePin<Self::TypeInfo>> {
971 match self {
972 Nodes::Start => vec![NodePin::execute("Out", false)],
973 Nodes::Expression(_) => vec![
974 NodePin::execute("Out", false),
975 NodePin::parameter("Data", "i32".to_owned()),
976 ],
977 Nodes::Result => vec![],
978 Nodes::Convert(_) => vec![
979 NodePin::execute("Out", false),
980 NodePin::parameter("Data out", "i32".to_owned()),
981 ],
982 Nodes::Child => vec![
983 NodePin::execute("Out", false),
984 NodePin::execute("Body", true),
985 ],
986 }
987 }
988
989 fn node_is_start(&self, _: &Registry) -> bool {
990 matches!(self, Self::Start)
991 }
992
993 fn node_suggestions(
994 _: i64,
995 _: i64,
996 _: NodeSuggestion<Self>,
997 _: &Registry,
998 ) -> Vec<ResponseSuggestionNode<Self>> {
999 vec![]
1000 }
1001
1002 fn get_property(&self, property_name: &str) -> Option<PropertyValue> {
1003 match self {
1004 Nodes::Expression(value) => match property_name {
1005 "Value" => PropertyValue::new(value).ok(),
1006 _ => None,
1007 },
1008 Nodes::Convert(name) => match property_name {
1009 "Name" => PropertyValue::new(name).ok(),
1010 _ => None,
1011 },
1012 _ => None,
1013 }
1014 }
1015
1016 fn set_property(&mut self, property_name: &str, property_value: PropertyValue) {
1017 #[allow(clippy::single_match)]
1018 match self {
1019 Nodes::Expression(value) => match property_name {
1020 "Value" => {
1021 if let Ok(v) = property_value.get_exact::<i32>() {
1022 *value = v;
1023 }
1024 }
1025 _ => {}
1026 },
1027 Nodes::Convert(name) => {
1028 if let Ok(v) = property_value.get_exact::<String>() {
1029 *name = v;
1030 }
1031 }
1032 _ => {}
1033 }
1034 }
1035 }
1036
1037 struct CompileNodesToScript;
1038
1039 impl NodeGraphVisitor<Nodes> for CompileNodesToScript {
1040 type Input = ();
1041 type Output = Script;
1042
1043 fn visit_statement(
1044 &mut self,
1045 node: &Node<Nodes>,
1046 _: HashMap<String, Self::Input>,
1047 mut scopes: HashMap<String, Vec<Self::Output>>,
1048 result: &mut Vec<Self::Output>,
1049 ) -> bool {
1050 match &node.data {
1051 Nodes::Result => result.push(Script::Return),
1052 Nodes::Convert(name) => result.push(Script::Call(name.to_owned())),
1053 Nodes::Child => {
1054 if let Some(body) = scopes.remove("Body") {
1055 result.push(Script::Scope(body));
1056 }
1057 }
1058 Nodes::Expression(value) => result.push(Script::Literal(*value)),
1059 _ => {}
1060 }
1061 true
1062 }
1063
1064 fn visit_expression(
1065 &mut self,
1066 _: &Node<Nodes>,
1067 _: HashMap<String, Self::Input>,
1068 ) -> Option<Self::Input> {
1069 None
1070 }
1071 }
1072
1073 #[test]
1074 fn test_nodes() {
1075 let registry = Registry::default().with_basic_types();
1076 let mut graph = NodeGraph::default();
1077 let start = graph
1078 .add_node(Node::new(0, 0, Nodes::Start), ®istry)
1079 .unwrap();
1080 let expression_child = graph
1081 .add_node(Node::new(0, 0, Nodes::Expression(42)), ®istry)
1082 .unwrap();
1083 let convert_child = graph
1084 .add_node(Node::new(0, 0, Nodes::Convert("foo".to_owned())), ®istry)
1085 .unwrap();
1086 let result_child = graph
1087 .add_node(Node::new(0, 0, Nodes::Result), ®istry)
1088 .unwrap();
1089 let child = graph
1090 .add_node(Node::new(0, 0, Nodes::Child), ®istry)
1091 .unwrap();
1092 let expression = graph
1093 .add_node(Node::new(0, 0, Nodes::Expression(42)), ®istry)
1094 .unwrap();
1095 let convert = graph
1096 .add_node(Node::new(0, 0, Nodes::Convert("bar".to_owned())), ®istry)
1097 .unwrap();
1098 let result = graph
1099 .add_node(Node::new(0, 0, Nodes::Result), ®istry)
1100 .unwrap();
1101 graph.connect_nodes(NodeConnection::new(start, child, "Out", "In"));
1102 graph.connect_nodes(NodeConnection::new(child, expression_child, "Body", "In"));
1103 graph.connect_nodes(NodeConnection::new(
1104 expression_child,
1105 convert_child,
1106 "Out",
1107 "In",
1108 ));
1109 graph.connect_nodes(NodeConnection::new(
1110 expression_child,
1111 convert_child,
1112 "Data",
1113 "Data in",
1114 ));
1115 graph.connect_nodes(NodeConnection::new(
1116 convert_child,
1117 result_child,
1118 "Out",
1119 "In",
1120 ));
1121 graph.connect_nodes(NodeConnection::new(
1122 convert_child,
1123 result_child,
1124 "Data out",
1125 "Data",
1126 ));
1127 graph.connect_nodes(NodeConnection::new(child, expression, "Out", "In"));
1128 graph.connect_nodes(NodeConnection::new(expression, convert, "Out", "In"));
1129 graph.connect_nodes(NodeConnection::new(expression, convert, "Data", "Data in"));
1130 graph.connect_nodes(NodeConnection::new(convert, result, "Out", "In"));
1131 graph.connect_nodes(NodeConnection::new(convert, result, "Data out", "Data"));
1132 graph.validate(®istry).unwrap();
1133 assert_eq!(
1134 graph.visit(&mut CompileNodesToScript, ®istry),
1135 vec![
1136 Script::Scope(vec![
1137 Script::Literal(42),
1138 Script::Call("foo".to_owned()),
1139 Script::Return
1140 ]),
1141 Script::Literal(42),
1142 Script::Call("bar".to_owned()),
1143 Script::Return
1144 ]
1145 );
1146 assert_eq!(
1147 graph
1148 .node(expression)
1149 .unwrap()
1150 .data
1151 .get_property("Value")
1152 .unwrap(),
1153 PropertyValue::new(&42i32).unwrap(),
1154 );
1155 graph
1156 .node_mut(expression)
1157 .unwrap()
1158 .data
1159 .set_property("Value", PropertyValue::new(&10i32).unwrap());
1160 assert_eq!(
1161 graph
1162 .node(expression)
1163 .unwrap()
1164 .data
1165 .get_property("Value")
1166 .unwrap(),
1167 PropertyValue::new(&10i32).unwrap(),
1168 );
1169 }
1170}