flow_rs_core/graph/
validation.rs

1//! Graph validation and constraint checking
2
3use crate::error::{FlowError, Result};
4use crate::handle::HandleId;
5use crate::types::NodeId;
6
7use super::{Edge, Graph};
8
9impl<N, E> Graph<N, E> {
10    /// Add an edge with handle validation
11    pub fn add_handle_edge(&mut self, edge: Edge<E>) -> Result<()> {
12        // Validate that source and target nodes exist
13        let source_node = self
14            .get_node(&edge.source)
15            .ok_or_else(|| FlowError::node_not_found(edge.source.as_str()))?;
16        let target_node = self
17            .get_node(&edge.target)
18            .ok_or_else(|| FlowError::node_not_found(edge.target.as_str()))?;
19
20        // Validate handle references if specified
21        if let Some(source_handle_id) = &edge.source_handle {
22            let source_handle_id = HandleId::new(source_handle_id.clone());
23            let source_handle = source_node
24                .get_handle(&source_handle_id)
25                .ok_or_else(|| FlowError::handle_not_found(source_handle_id.as_str()))?;
26
27            // Check connection limit
28            if !self.can_handle_accept_connection(&edge.source, &source_handle_id) {
29                let current_count =
30                    self.get_handle_connection_count(&edge.source, &source_handle_id);
31                let limit = source_handle.connection_limit.unwrap_or(usize::MAX);
32                return Err(FlowError::connection_limit_exceeded(
33                    source_handle_id.as_str(),
34                    current_count,
35                    limit,
36                ));
37            }
38
39            if let Some(target_handle_id) = &edge.target_handle {
40                let target_handle_id = HandleId::new(target_handle_id.clone());
41                let target_handle = target_node
42                    .get_handle(&target_handle_id)
43                    .ok_or_else(|| FlowError::handle_not_found(target_handle_id.as_str()))?;
44
45                // Check handle compatibility
46                if !source_handle.can_connect_to(target_handle) {
47                    return Err(FlowError::invalid_connection(
48                        "Handle types or connection types are incompatible",
49                    ));
50                }
51            }
52        }
53
54        // If validation passes, add the edge normally
55        self.add_edge(edge)
56    }
57
58    /// Get connection count for a specific handle
59    fn get_handle_connection_count(&self, node_id: &NodeId, handle_id: &HandleId) -> usize {
60        let handle_id_str = handle_id.as_str();
61        self.edges
62            .values()
63            .filter(|edge| {
64                (&edge.source == node_id && edge.source_handle.as_deref() == Some(handle_id_str))
65                    || (&edge.target == node_id
66                        && edge.target_handle.as_deref() == Some(handle_id_str))
67            })
68            .count()
69    }
70
71    /// Get all edges connected to a specific handle
72    ///
73    /// This method provides accurate connection counting by examining all edges
74    /// in the graph that reference the specified handle.
75    pub fn get_handle_connections(&self, node_id: &NodeId, handle_id: &HandleId) -> Vec<&Edge<E>> {
76        let handle_id_str = handle_id.as_str();
77        self.edges
78            .values()
79            .filter(|edge| {
80                (&edge.source == node_id && edge.source_handle.as_deref() == Some(handle_id_str))
81                    || (&edge.target == node_id
82                        && edge.target_handle.as_deref() == Some(handle_id_str))
83            })
84            .collect()
85    }
86
87    /// Check if a handle can accept new connections (respects connection limits)
88    ///
89    /// This method provides accurate connection limit validation by counting
90    /// current connections and comparing against the handle's limit.
91    pub fn can_handle_accept_connection(&self, node_id: &NodeId, handle_id: &HandleId) -> bool {
92        if let Some(node) = self.get_node(node_id) {
93            if let Some(handle) = node.get_handle(handle_id) {
94                if let Some(limit) = handle.connection_limit {
95                    let current_connections = self.get_handle_connections(node_id, handle_id).len();
96                    return current_connections < limit;
97                }
98            }
99        }
100        true // No limit or handle doesn't exist - allow connection
101    }
102}