1use crate::state::GraphState;
7use crate::{RGraphError, RGraphResult};
8use async_trait::async_trait;
9use petgraph::{Directed, Graph};
10use std::collections::HashMap;
11use std::sync::Arc;
12use uuid::Uuid;
13type NodeIndex = petgraph::graph::NodeIndex;
14#[allow(dead_code)]
15type EdgeIndex = petgraph::graph::EdgeIndex;
16use parking_lot::RwLock;
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21#[derive(Debug, Clone, PartialEq, Eq, Hash)]
23#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
24pub struct NodeId(pub String);
25
26impl NodeId {
27 pub fn new(id: impl Into<String>) -> Self {
29 Self(id.into())
30 }
31
32 pub fn generate() -> Self {
34 Self(Uuid::new_v4().to_string())
35 }
36
37 pub fn as_str(&self) -> &str {
39 &self.0
40 }
41}
42
43impl From<String> for NodeId {
44 fn from(id: String) -> Self {
45 NodeId(id)
46 }
47}
48
49impl From<&str> for NodeId {
50 fn from(id: &str) -> Self {
51 NodeId(id.to_string())
52 }
53}
54
55#[derive(Debug, Clone, PartialEq, Eq, Hash)]
57#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
58pub struct EdgeId(pub String);
59
60impl EdgeId {
61 pub fn new(id: impl Into<String>) -> Self {
63 Self(id.into())
64 }
65
66 pub fn generate() -> Self {
68 Self(Uuid::new_v4().to_string())
69 }
70}
71
72#[async_trait]
74pub trait Node: Send + Sync {
75 async fn execute(
77 &self,
78 state: &mut GraphState,
79 context: &ExecutionContext,
80 ) -> RGraphResult<ExecutionResult>;
81
82 fn id(&self) -> &NodeId;
84
85 fn name(&self) -> &str;
87
88 fn description(&self) -> Option<&str> {
90 None
91 }
92
93 fn input_keys(&self) -> Vec<&str> {
95 vec![]
96 }
97
98 fn output_keys(&self) -> Vec<&str> {
100 vec![]
101 }
102
103 fn validate(&self, _state: &GraphState) -> RGraphResult<()> {
105 Ok(())
106 }
107
108 fn metadata(&self) -> NodeMetadata {
110 NodeMetadata {
111 id: self.id().clone(),
112 name: self.name().to_string(),
113 description: self.description().map(|s| s.to_string()),
114 input_keys: self.input_keys().iter().map(|s| s.to_string()).collect(),
115 output_keys: self.output_keys().iter().map(|s| s.to_string()).collect(),
116 }
117 }
118}
119
120#[derive(Debug, Clone)]
122#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
123pub struct NodeMetadata {
124 pub id: NodeId,
125 pub name: String,
126 pub description: Option<String>,
127 pub input_keys: Vec<String>,
128 pub output_keys: Vec<String>,
129}
130
131#[derive(Debug, Clone)]
133#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
134pub struct Edge {
135 pub id: EdgeId,
136 pub from: NodeId,
137 pub to: NodeId,
138 pub condition: Option<EdgeCondition>,
139}
140
141#[derive(Debug, Clone)]
143#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
144pub enum EdgeCondition {
145 Always,
147 Conditional(String), StateCondition {
151 key: String,
152 expected_value: serde_json::Value,
153 },
154}
155
156#[derive(Debug, Clone)]
158#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
159pub enum ExecutionResult {
160 Continue,
162 Stop,
164 JumpTo(NodeId),
166 Route(String), }
169
170#[derive(Debug, Clone)]
172pub struct ExecutionContext {
173 pub graph_id: String,
174 pub execution_id: String,
175 pub current_node: NodeId,
176 pub execution_path: Vec<NodeId>,
177 pub start_time: chrono::DateTime<chrono::Utc>,
178 pub metadata: HashMap<String, serde_json::Value>,
179}
180
181impl ExecutionContext {
182 pub fn new(graph_id: String, current_node: NodeId) -> Self {
183 Self {
184 graph_id,
185 execution_id: Uuid::new_v4().to_string(),
186 current_node,
187 execution_path: Vec::new(),
188 start_time: chrono::Utc::now(),
189 metadata: HashMap::new(),
190 }
191 }
192
193 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
194 self.metadata.insert(key, value);
195 self
196 }
197}
198
199pub struct WorkflowGraph {
201 id: String,
202 name: String,
203 description: Option<String>,
204 graph: Arc<RwLock<Graph<Arc<dyn Node>, Edge, Directed>>>,
205 node_lookup: Arc<RwLock<HashMap<NodeId, NodeIndex>>>,
206 entry_points: Arc<RwLock<Vec<NodeId>>>,
207 exit_points: Arc<RwLock<Vec<NodeId>>>,
208}
209
210impl WorkflowGraph {
211 pub fn new(name: impl Into<String>) -> Self {
213 Self {
214 id: Uuid::new_v4().to_string(),
215 name: name.into(),
216 description: None,
217 graph: Arc::new(RwLock::new(Graph::new())),
218 node_lookup: Arc::new(RwLock::new(HashMap::new())),
219 entry_points: Arc::new(RwLock::new(Vec::new())),
220 exit_points: Arc::new(RwLock::new(Vec::new())),
221 }
222 }
223
224 pub fn with_description(mut self, description: impl Into<String>) -> Self {
226 self.description = Some(description.into());
227 self
228 }
229
230 pub async fn add_node(
232 &mut self,
233 node_id: impl Into<NodeId>,
234 node: Arc<dyn Node>,
235 ) -> RGraphResult<()> {
236 let node_id = node_id.into();
237
238 let dummy_state = GraphState::new();
240 node.validate(&dummy_state)?;
241
242 let mut graph = self.graph.write();
243 let mut lookup = self.node_lookup.write();
244
245 if lookup.contains_key(&node_id) {
247 return Err(RGraphError::validation(format!(
248 "Node '{}' already exists",
249 node_id.as_str()
250 )));
251 }
252
253 let node_index = graph.add_node(node);
255 lookup.insert(node_id.clone(), node_index);
256
257 if lookup.len() == 1 {
259 self.entry_points.write().push(node_id);
260 }
261
262 Ok(())
263 }
264
265 pub fn add_edge(
267 &mut self,
268 from: impl Into<NodeId>,
269 to: impl Into<NodeId>,
270 ) -> RGraphResult<EdgeId> {
271 self.add_edge_with_condition(from, to, EdgeCondition::Always)
272 }
273
274 pub fn add_edge_with_condition(
276 &mut self,
277 from: impl Into<NodeId>,
278 to: impl Into<NodeId>,
279 condition: EdgeCondition,
280 ) -> RGraphResult<EdgeId> {
281 let from_id = from.into();
282 let to_id = to.into();
283 let edge_id = EdgeId::generate();
284
285 let graph_lock = self.graph.clone();
286 let lookup_lock = self.node_lookup.clone();
287
288 let mut graph = graph_lock.write();
289 let lookup = lookup_lock.read();
290
291 let from_index = lookup.get(&from_id).ok_or_else(|| {
293 RGraphError::validation(format!("Node '{}' not found", from_id.as_str()))
294 })?;
295 let to_index = lookup.get(&to_id).ok_or_else(|| {
296 RGraphError::validation(format!("Node '{}' not found", to_id.as_str()))
297 })?;
298
299 let edge = Edge {
301 id: edge_id.clone(),
302 from: from_id,
303 to: to_id,
304 condition: Some(condition),
305 };
306
307 graph.add_edge(*from_index, *to_index, edge);
309
310 Ok(edge_id)
311 }
312
313 pub fn add_conditional_edge<F>(
315 &mut self,
316 from: impl Into<NodeId>,
317 _condition_fn: F,
318 ) -> RGraphResult<EdgeId>
319 where
320 F: Fn(&GraphState) -> RGraphResult<String> + Send + Sync + 'static,
321 {
322 let _from_id = from.into();
325 let edge_id = EdgeId::generate();
326
327 Ok(edge_id)
330 }
331
332 pub fn set_entry_points(&mut self, entry_points: Vec<NodeId>) {
334 *self.entry_points.write() = entry_points;
335 }
336
337 pub fn set_exit_points(&mut self, exit_points: Vec<NodeId>) {
339 *self.exit_points.write() = exit_points;
340 }
341
342 pub fn id(&self) -> &str {
344 &self.id
345 }
346
347 pub fn name(&self) -> &str {
349 &self.name
350 }
351
352 pub fn description(&self) -> Option<&str> {
354 self.description.as_deref()
355 }
356
357 pub fn node_ids(&self) -> Vec<NodeId> {
359 self.node_lookup.read().keys().cloned().collect()
360 }
361
362 pub fn entry_points(&self) -> Vec<NodeId> {
364 self.entry_points.read().clone()
365 }
366
367 pub fn entry_points_owned(&self) -> Vec<NodeId> {
369 self.entry_points.read().clone()
370 }
371
372 pub fn get_node(&self, node_id: &NodeId) -> Option<Arc<dyn Node>> {
374 let lookup = self.node_lookup.read();
375 let graph = self.graph.read();
376
377 if let Some(&node_index) = lookup.get(node_id) {
378 if let Some(node_weight) = graph.node_weight(node_index) {
379 return Some(node_weight.clone());
380 }
381 }
382 None
383 }
384
385 pub fn validate(&self) -> RGraphResult<()> {
387 let lookup = self.node_lookup.read();
388 let entry_points = self.entry_points.read();
389
390 if lookup.is_empty() {
392 return Err(RGraphError::validation("Graph has no nodes"));
393 }
394
395 if entry_points.is_empty() {
397 return Err(RGraphError::validation("Graph has no entry points"));
398 }
399
400 for entry_point in entry_points.iter() {
402 if !lookup.contains_key(entry_point) {
403 return Err(RGraphError::validation(format!(
404 "Entry point '{}' does not exist",
405 entry_point.as_str()
406 )));
407 }
408 }
409
410 Ok(())
411 }
412}
413
414pub struct GraphBuilder {
416 graph: WorkflowGraph,
417}
418
419impl GraphBuilder {
420 pub fn new(name: impl Into<String>) -> Self {
422 Self {
423 graph: WorkflowGraph::new(name),
424 }
425 }
426
427 pub fn description(mut self, description: impl Into<String>) -> Self {
429 self.graph = self.graph.with_description(description);
430 self
431 }
432
433 pub async fn add_node(
435 mut self,
436 node_id: impl Into<NodeId>,
437 node: Arc<dyn Node>,
438 ) -> RGraphResult<Self> {
439 self.graph.add_node(node_id, node).await?;
440 Ok(self)
441 }
442
443 pub fn add_edge(
445 mut self,
446 from: impl Into<NodeId>,
447 to: impl Into<NodeId>,
448 ) -> RGraphResult<Self> {
449 self.graph.add_edge(from, to)?;
450 Ok(self)
451 }
452
453 pub fn entry_points(mut self, entry_points: Vec<NodeId>) -> Self {
455 self.graph.set_entry_points(entry_points);
456 self
457 }
458
459 pub fn build(self) -> RGraphResult<WorkflowGraph> {
461 self.graph.validate()?;
462 Ok(self.graph)
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::StateValue;

    /// Minimal node for tests: writes its own name under `executed_nodes` and
    /// always returns `ExecutionResult::Continue`.
    struct TestNode {
        id: NodeId,
        name: String,
    }

    impl TestNode {
        fn new(id: impl Into<NodeId>, name: impl Into<String>) -> Arc<Self> {
            Arc::new(Self {
                id: id.into(),
                name: name.into(),
            })
        }
    }

    #[async_trait]
    impl Node for TestNode {
        async fn execute(
            &self,
            state: &mut GraphState,
            _context: &ExecutionContext,
        ) -> RGraphResult<ExecutionResult> {
            state.set(
                "executed_nodes",
                StateValue::Array(vec![StateValue::String(self.name.clone())]),
            );
            Ok(ExecutionResult::Continue)
        }

        fn id(&self) -> &NodeId {
            &self.id
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[tokio::test]
    async fn test_graph_creation() {
        let mut graph = WorkflowGraph::new("test_graph");
        assert_eq!(graph.name(), "test_graph");

        let node = TestNode::new("test_node", "Test Node");
        graph.add_node("test_node", node).await.unwrap();

        assert_eq!(graph.node_ids().len(), 1);
        assert!(graph.node_ids().contains(&NodeId::new("test_node")));
    }

    #[tokio::test]
    async fn test_graph_builder() {
        let node1 = TestNode::new("node1", "Node 1");
        let node2 = TestNode::new("node2", "Node 2");

        let graph = GraphBuilder::new("test_graph")
            .description("A test graph")
            .add_node("node1", node1)
            .await
            .unwrap()
            .add_node("node2", node2)
            .await
            .unwrap()
            .add_edge("node1", "node2")
            .unwrap()
            .build()
            .unwrap();

        assert_eq!(graph.name(), "test_graph");
        assert_eq!(graph.description(), Some("A test graph"));
        assert_eq!(graph.node_ids().len(), 2);
    }

    #[tokio::test]
    async fn test_duplicate_node_is_rejected() {
        let mut graph = WorkflowGraph::new("dup");
        graph.add_node("a", TestNode::new("a", "A")).await.unwrap();

        assert!(graph
            .add_node("a", TestNode::new("a", "A again"))
            .await
            .is_err());
        assert_eq!(graph.node_ids().len(), 1);
    }

    #[tokio::test]
    async fn test_add_edge_requires_existing_nodes() {
        let mut graph = WorkflowGraph::new("edges");
        graph.add_node("a", TestNode::new("a", "A")).await.unwrap();

        assert!(graph.add_edge("a", "missing").is_err());
        assert!(graph.add_edge("missing", "a").is_err());
    }

    #[tokio::test]
    async fn test_first_node_becomes_entry_point() {
        let mut graph = WorkflowGraph::new("entry");
        graph.add_node("a", TestNode::new("a", "A")).await.unwrap();
        graph.add_node("b", TestNode::new("b", "B")).await.unwrap();

        assert_eq!(graph.entry_points(), vec![NodeId::new("a")]);
    }

    #[tokio::test]
    async fn test_get_node() {
        let mut graph = WorkflowGraph::new("lookup");
        graph
            .add_node("a", TestNode::new("a", "Node A"))
            .await
            .unwrap();

        let node = graph
            .get_node(&NodeId::new("a"))
            .expect("node 'a' was just added");
        assert_eq!(node.name(), "Node A");
        assert!(graph.get_node(&NodeId::new("missing")).is_none());
    }

    #[test]
    fn test_validate_rejects_empty_graph() {
        assert!(WorkflowGraph::new("empty").validate().is_err());
    }

    #[tokio::test]
    async fn test_validate_rejects_unknown_entry_point() {
        let mut graph = WorkflowGraph::new("ghost");
        graph.add_node("a", TestNode::new("a", "A")).await.unwrap();
        graph.set_entry_points(vec![NodeId::new("ghost")]);

        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_node_id() {
        let id1 = NodeId::new("test");
        let id2 = NodeId::from("test");
        let id3: NodeId = "test".into();

        assert_eq!(id1, id2);
        assert_eq!(id2, id3);
        assert_eq!(id1.as_str(), "test");
    }

    #[test]
    fn test_execution_context() {
        let context = ExecutionContext::new("graph1".to_string(), NodeId::new("node1"))
            .with_metadata("key".to_string(), serde_json::json!("value"));

        assert_eq!(context.graph_id, "graph1");
        assert_eq!(context.current_node, NodeId::new("node1"));
        assert!(context.metadata.contains_key("key"));
    }
}