rrag_graph/
core.rs

1//! # Core Graph Abstractions
2//!
3//! This module contains the fundamental types and traits that form the foundation
4//! of the RGraph system, including the workflow graph, nodes, edges, and execution context.
5
6use crate::state::GraphState;
7use crate::{RGraphError, RGraphResult};
8use async_trait::async_trait;
9use petgraph::{Directed, Graph};
10use std::collections::HashMap;
11use std::sync::Arc;
12use uuid::Uuid;
13type NodeIndex = petgraph::graph::NodeIndex;
14#[allow(dead_code)]
15type EdgeIndex = petgraph::graph::EdgeIndex;
16use parking_lot::RwLock;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
/// Strongly-typed identifier of a node in the workflow graph.
///
/// A newtype over `String`: equality and hashing are exactly those of the wrapped
/// text, which is what allows a `NodeId` to be used as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct NodeId(pub String);
25
26impl NodeId {
27    /// Create a new node ID
28    pub fn new(id: impl Into<String>) -> Self {
29        Self(id.into())
30    }
31
32    /// Generate a random node ID
33    pub fn generate() -> Self {
34        Self(Uuid::new_v4().to_string())
35    }
36
37    /// Get the string representation
38    pub fn as_str(&self) -> &str {
39        &self.0
40    }
41}
42
43impl From<String> for NodeId {
44    fn from(id: String) -> Self {
45        NodeId(id)
46    }
47}
48
49impl From<&str> for NodeId {
50    fn from(id: &str) -> Self {
51        NodeId(id.to_string())
52    }
53}
54
/// Strongly-typed identifier of an edge in the workflow graph.
///
/// A newtype over `String`; comparison and hashing are those of the wrapped text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EdgeId(pub String);
59
60impl EdgeId {
61    /// Create a new edge ID
62    pub fn new(id: impl Into<String>) -> Self {
63        Self(id.into())
64    }
65
66    /// Generate a random edge ID
67    pub fn generate() -> Self {
68        Self(Uuid::new_v4().to_string())
69    }
70}
71
72/// Core trait for all executable nodes in the workflow graph
73#[async_trait]
74pub trait Node: Send + Sync {
75    /// Execute the node with the given state and context
76    async fn execute(
77        &self,
78        state: &mut GraphState,
79        context: &ExecutionContext,
80    ) -> RGraphResult<ExecutionResult>;
81
82    /// Get the node's unique identifier
83    fn id(&self) -> &NodeId;
84
85    /// Get the node's display name
86    fn name(&self) -> &str;
87
88    /// Get the node's description
89    fn description(&self) -> Option<&str> {
90        None
91    }
92
93    /// Get the expected input keys from the state
94    fn input_keys(&self) -> Vec<&str> {
95        vec![]
96    }
97
98    /// Get the output keys that this node will write to the state
99    fn output_keys(&self) -> Vec<&str> {
100        vec![]
101    }
102
103    /// Validate that the node can execute with the current state
104    fn validate(&self, _state: &GraphState) -> RGraphResult<()> {
105        Ok(())
106    }
107
108    /// Get node metadata for observability
109    fn metadata(&self) -> NodeMetadata {
110        NodeMetadata {
111            id: self.id().clone(),
112            name: self.name().to_string(),
113            description: self.description().map(|s| s.to_string()),
114            input_keys: self.input_keys().iter().map(|s| s.to_string()).collect(),
115            output_keys: self.output_keys().iter().map(|s| s.to_string()).collect(),
116        }
117    }
118}
119
120/// Metadata about a node for introspection and observability
121#[derive(Debug, Clone)]
122#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
123pub struct NodeMetadata {
124    pub id: NodeId,
125    pub name: String,
126    pub description: Option<String>,
127    pub input_keys: Vec<String>,
128    pub output_keys: Vec<String>,
129}
130
131/// Represents an edge in the workflow graph
132#[derive(Debug, Clone)]
133#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
134pub struct Edge {
135    pub id: EdgeId,
136    pub from: NodeId,
137    pub to: NodeId,
138    pub condition: Option<EdgeCondition>,
139}
140
141/// Condition that must be met for an edge to be traversed
142#[derive(Debug, Clone)]
143#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
144pub enum EdgeCondition {
145    /// Always traverse the edge
146    Always,
147    /// Traverse if the condition function returns true
148    Conditional(String), // Serialized condition function
149    /// Traverse if the state contains a specific value
150    StateCondition {
151        key: String,
152        expected_value: serde_json::Value,
153    },
154}
155
156/// Result of executing a node
157#[derive(Debug, Clone)]
158#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
159pub enum ExecutionResult {
160    /// Continue to the next node
161    Continue,
162    /// Stop execution and return the current state
163    Stop,
164    /// Jump to a specific node
165    JumpTo(NodeId),
166    /// Conditional routing based on state
167    Route(String), // Next node ID based on routing logic
168}
169
170/// Context information available during node execution
171#[derive(Debug, Clone)]
172pub struct ExecutionContext {
173    pub graph_id: String,
174    pub execution_id: String,
175    pub current_node: NodeId,
176    pub execution_path: Vec<NodeId>,
177    pub start_time: chrono::DateTime<chrono::Utc>,
178    pub metadata: HashMap<String, serde_json::Value>,
179}
180
181impl ExecutionContext {
182    pub fn new(graph_id: String, current_node: NodeId) -> Self {
183        Self {
184            graph_id,
185            execution_id: Uuid::new_v4().to_string(),
186            current_node,
187            execution_path: Vec::new(),
188            start_time: chrono::Utc::now(),
189            metadata: HashMap::new(),
190        }
191    }
192
193    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
194        self.metadata.insert(key, value);
195        self
196    }
197}
198
199/// The main workflow graph that orchestrates node execution
200pub struct WorkflowGraph {
201    id: String,
202    name: String,
203    description: Option<String>,
204    graph: Arc<RwLock<Graph<Arc<dyn Node>, Edge, Directed>>>,
205    node_lookup: Arc<RwLock<HashMap<NodeId, NodeIndex>>>,
206    entry_points: Arc<RwLock<Vec<NodeId>>>,
207    exit_points: Arc<RwLock<Vec<NodeId>>>,
208}
209
210impl WorkflowGraph {
211    /// Create a new workflow graph
212    pub fn new(name: impl Into<String>) -> Self {
213        Self {
214            id: Uuid::new_v4().to_string(),
215            name: name.into(),
216            description: None,
217            graph: Arc::new(RwLock::new(Graph::new())),
218            node_lookup: Arc::new(RwLock::new(HashMap::new())),
219            entry_points: Arc::new(RwLock::new(Vec::new())),
220            exit_points: Arc::new(RwLock::new(Vec::new())),
221        }
222    }
223
224    /// Set the graph description
225    pub fn with_description(mut self, description: impl Into<String>) -> Self {
226        self.description = Some(description.into());
227        self
228    }
229
230    /// Add a node to the graph
231    pub async fn add_node(
232        &mut self,
233        node_id: impl Into<NodeId>,
234        node: Arc<dyn Node>,
235    ) -> RGraphResult<()> {
236        let node_id = node_id.into();
237
238        // Validate the node
239        let dummy_state = GraphState::new();
240        node.validate(&dummy_state)?;
241
242        let mut graph = self.graph.write();
243        let mut lookup = self.node_lookup.write();
244
245        // Check if node already exists
246        if lookup.contains_key(&node_id) {
247            return Err(RGraphError::validation(format!(
248                "Node '{}' already exists",
249                node_id.as_str()
250            )));
251        }
252
253        // Add node to the graph
254        let node_index = graph.add_node(node);
255        lookup.insert(node_id.clone(), node_index);
256
257        // If this is the first node, make it an entry point
258        if lookup.len() == 1 {
259            self.entry_points.write().push(node_id);
260        }
261
262        Ok(())
263    }
264
265    /// Add an edge between two nodes
266    pub fn add_edge(
267        &mut self,
268        from: impl Into<NodeId>,
269        to: impl Into<NodeId>,
270    ) -> RGraphResult<EdgeId> {
271        self.add_edge_with_condition(from, to, EdgeCondition::Always)
272    }
273
274    /// Add an edge with a condition
275    pub fn add_edge_with_condition(
276        &mut self,
277        from: impl Into<NodeId>,
278        to: impl Into<NodeId>,
279        condition: EdgeCondition,
280    ) -> RGraphResult<EdgeId> {
281        let from_id = from.into();
282        let to_id = to.into();
283        let edge_id = EdgeId::generate();
284
285        let graph_lock = self.graph.clone();
286        let lookup_lock = self.node_lookup.clone();
287
288        let mut graph = graph_lock.write();
289        let lookup = lookup_lock.read();
290
291        // Get node indices
292        let from_index = lookup.get(&from_id).ok_or_else(|| {
293            RGraphError::validation(format!("Node '{}' not found", from_id.as_str()))
294        })?;
295        let to_index = lookup.get(&to_id).ok_or_else(|| {
296            RGraphError::validation(format!("Node '{}' not found", to_id.as_str()))
297        })?;
298
299        // Create edge
300        let edge = Edge {
301            id: edge_id.clone(),
302            from: from_id,
303            to: to_id,
304            condition: Some(condition),
305        };
306
307        // Add edge to graph
308        graph.add_edge(*from_index, *to_index, edge);
309
310        Ok(edge_id)
311    }
312
313    /// Add a conditional edge with a custom routing function
314    pub fn add_conditional_edge<F>(
315        &mut self,
316        from: impl Into<NodeId>,
317        _condition_fn: F,
318    ) -> RGraphResult<EdgeId>
319    where
320        F: Fn(&GraphState) -> RGraphResult<String> + Send + Sync + 'static,
321    {
322        // In a real implementation, we'd store the condition function
323        // For now, we'll create a placeholder edge
324        let _from_id = from.into();
325        let edge_id = EdgeId::generate();
326
327        // This is a simplified implementation - in reality, we'd need to handle
328        // the conditional routing during execution
329        Ok(edge_id)
330    }
331
332    /// Set entry points for the graph
333    pub fn set_entry_points(&mut self, entry_points: Vec<NodeId>) {
334        *self.entry_points.write() = entry_points;
335    }
336
337    /// Set exit points for the graph
338    pub fn set_exit_points(&mut self, exit_points: Vec<NodeId>) {
339        *self.exit_points.write() = exit_points;
340    }
341
342    /// Get the graph ID
343    pub fn id(&self) -> &str {
344        &self.id
345    }
346
347    /// Get the graph name
348    pub fn name(&self) -> &str {
349        &self.name
350    }
351
352    /// Get the graph description
353    pub fn description(&self) -> Option<&str> {
354        self.description.as_deref()
355    }
356
357    /// Get all node IDs in the graph
358    pub fn node_ids(&self) -> Vec<NodeId> {
359        self.node_lookup.read().keys().cloned().collect()
360    }
361
362    /// Get entry points (returns owned values to avoid lifetime issues)
363    pub fn entry_points(&self) -> Vec<NodeId> {
364        self.entry_points.read().clone()
365    }
366
367    /// Get entry points as owned values
368    pub fn entry_points_owned(&self) -> Vec<NodeId> {
369        self.entry_points.read().clone()
370    }
371
372    /// Get a node by ID
373    pub fn get_node(&self, node_id: &NodeId) -> Option<Arc<dyn Node>> {
374        let lookup = self.node_lookup.read();
375        let graph = self.graph.read();
376
377        if let Some(&node_index) = lookup.get(node_id) {
378            if let Some(node_weight) = graph.node_weight(node_index) {
379                return Some(node_weight.clone());
380            }
381        }
382        None
383    }
384
385    /// Validate the graph structure
386    pub fn validate(&self) -> RGraphResult<()> {
387        let lookup = self.node_lookup.read();
388        let entry_points = self.entry_points.read();
389
390        // Check that we have nodes
391        if lookup.is_empty() {
392            return Err(RGraphError::validation("Graph has no nodes"));
393        }
394
395        // Check that we have entry points
396        if entry_points.is_empty() {
397            return Err(RGraphError::validation("Graph has no entry points"));
398        }
399
400        // Validate that all entry points exist
401        for entry_point in entry_points.iter() {
402            if !lookup.contains_key(entry_point) {
403                return Err(RGraphError::validation(format!(
404                    "Entry point '{}' does not exist",
405                    entry_point.as_str()
406                )));
407            }
408        }
409
410        Ok(())
411    }
412}
413
414/// Builder for creating workflow graphs with a fluent API
415pub struct GraphBuilder {
416    graph: WorkflowGraph,
417}
418
419impl GraphBuilder {
420    /// Create a new graph builder
421    pub fn new(name: impl Into<String>) -> Self {
422        Self {
423            graph: WorkflowGraph::new(name),
424        }
425    }
426
427    /// Set the graph description
428    pub fn description(mut self, description: impl Into<String>) -> Self {
429        self.graph = self.graph.with_description(description);
430        self
431    }
432
433    /// Add a node to the graph
434    pub async fn add_node(
435        mut self,
436        node_id: impl Into<NodeId>,
437        node: Arc<dyn Node>,
438    ) -> RGraphResult<Self> {
439        self.graph.add_node(node_id, node).await?;
440        Ok(self)
441    }
442
443    /// Add an edge between two nodes
444    pub fn add_edge(
445        mut self,
446        from: impl Into<NodeId>,
447        to: impl Into<NodeId>,
448    ) -> RGraphResult<Self> {
449        self.graph.add_edge(from, to)?;
450        Ok(self)
451    }
452
453    /// Set entry points
454    pub fn entry_points(mut self, entry_points: Vec<NodeId>) -> Self {
455        self.graph.set_entry_points(entry_points);
456        self
457    }
458
459    /// Build the workflow graph
460    pub fn build(self) -> RGraphResult<WorkflowGraph> {
461        self.graph.validate()?;
462        Ok(self.graph)
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateValue;

    /// Minimal `Node` implementation used by the tests below.
    struct TestNode {
        id: NodeId,
        name: String,
    }

    impl TestNode {
        /// Build a test node already wrapped in an `Arc`, ready to be added to a graph.
        fn new(id: impl Into<NodeId>, name: impl Into<String>) -> Arc<Self> {
            let node = Self {
                id: id.into(),
                name: name.into(),
            };
            Arc::new(node)
        }
    }

    #[async_trait]
    impl Node for TestNode {
        /// Records this node's name under `"executed_nodes"` and asks to continue.
        async fn execute(
            &self,
            state: &mut GraphState,
            _context: &ExecutionContext,
        ) -> RGraphResult<ExecutionResult> {
            let executed = vec![StateValue::String(self.name.clone())];
            state.set("executed_nodes", StateValue::Array(executed));
            Ok(ExecutionResult::Continue)
        }

        fn id(&self) -> &NodeId {
            &self.id
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    /// A new graph keeps its name and lists a node once it has been added.
    #[tokio::test]
    async fn test_graph_creation() {
        let mut graph = WorkflowGraph::new("test_graph");
        assert_eq!(graph.name(), "test_graph");

        let node = TestNode::new("test_node", "Test Node");
        graph.add_node("test_node", node).await.unwrap();

        let ids = graph.node_ids();
        assert_eq!(ids.len(), 1);
        assert!(ids.contains(&NodeId::new("test_node")));
    }

    /// The builder wires up description, nodes and an edge, then builds successfully.
    #[tokio::test]
    async fn test_graph_builder() {
        let node1 = TestNode::new("node1", "Node 1");
        let node2 = TestNode::new("node2", "Node 2");

        let builder = GraphBuilder::new("test_graph").description("A test graph");
        let builder = builder.add_node("node1", node1).await.unwrap();
        let builder = builder.add_node("node2", node2).await.unwrap();
        let builder = builder.add_edge("node1", "node2").unwrap();
        let graph = builder.build().unwrap();

        assert_eq!(graph.name(), "test_graph");
        assert_eq!(graph.description(), Some("A test graph"));
        assert_eq!(graph.node_ids().len(), 2);
    }

    /// All three ways of constructing a `NodeId` agree with each other.
    #[test]
    fn test_node_id() {
        let id1 = NodeId::new("test");
        let id2 = NodeId::from("test");
        let id3: NodeId = "test".into();

        assert_eq!(id1, id2);
        assert_eq!(id2, id3);
        assert_eq!(id1.as_str(), "test");
    }

    /// A context keeps its constructor arguments and any metadata added afterwards.
    #[test]
    fn test_execution_context() {
        let graph_id = "graph1".to_string();
        let context = ExecutionContext::new(graph_id, NodeId::new("node1"))
            .with_metadata("key".to_string(), serde_json::json!("value"));

        assert_eq!(context.graph_id, "graph1");
        assert_eq!(context.current_node, NodeId::new("node1"));
        assert!(context.metadata.contains_key("key"));
    }
}