ai_agents/
pipeline_net.rs

1use std::{collections::HashMap, sync::Arc};
2
3use crate::{
4    error::Error,
5    sync::RwLock,
6    traits::{Adapter, UnitProcess},
7    ModuleParam,
8};
9
10struct Edge {
11    to: String,
12    adapter: Option<Arc<dyn Adapter>>,
13}
14
15pub struct PipelineNet {
16    nodes: HashMap<String, Arc<RwLock<dyn UnitProcess + Send + Sync>>>,
17    edges: HashMap<String, Vec<Edge>>,
18    groups: HashMap<String, String>, // Maps group names to input node names for each group.
19}
20
21impl PipelineNet {
22    pub fn new() -> Self {
23        Self {
24            nodes: HashMap::new(),
25            edges: HashMap::new(),
26            groups: HashMap::new(),
27        }
28    }
29
30    // Add a node that implements `UnitProcess`
31    pub fn add_node(&mut self, name: &str, node: Arc<RwLock<dyn UnitProcess + Send + Sync>>) {
32        self.nodes.insert(name.into(), node);
33    }
34    // Add an edge between nodes
35    pub fn add_edge(&mut self, from: &str, to: &str) {
36        let edge = Edge {
37            to: to.to_string(),
38            adapter: None,
39        };
40        self.edges.entry(from.to_string()).or_default().push(edge);
41    }
42
43    // Add an edge between nodes with an adapter
44    pub fn add_edge_with_adapter<A: Adapter + 'static>(
45        &mut self,
46        from: &str,
47        to: &str,
48        adapter: A,
49    ) {
50        let edge = Edge {
51            to: to.to_string(),
52            adapter: Some(Arc::new(adapter)),
53        };
54        self.edges.entry(from.to_string()).or_default().push(edge);
55    }
56
57    // Set group with input node.
58    pub fn set_group_input(&mut self, group_name: &str, input_node_name: &str) {
59        self.groups
60            .insert(group_name.into(), input_node_name.into());
61    }
62
63    // Process a group starting from the group's input node, collecting the each results.
64    pub async fn process_group(
65        &self,
66        group_name: &str,
67        initial_input: ModuleParam,
68    ) -> Result<HashMap<String, ModuleParam>, Error> {
69        let input_node_name = self
70            .groups
71            .get(group_name)
72            .ok_or_else(|| Error::NotFound(group_name.to_string()))?;
73
74        let mut results = HashMap::new();
75        let mut stack = vec![(input_node_name.as_str(), initial_input)];
76
77        // bfs
78        while let Some((current_node_name, input)) = stack.pop() {
79            if results.contains_key(current_node_name) {
80                continue; // Skip if visited
81            }
82
83            let node = self
84                .nodes
85                .get(current_node_name)
86                .ok_or_else(|| Error::NotFound(current_node_name.to_string()))?;
87
88            let processed_input = node.read().await.process(input).await?;
89
90            // let processed_input = node.process(input).await?;
91
92            results.insert(current_node_name.to_string(), processed_input.clone());
93
94            if let Some(edges) = self.edges.get(current_node_name) {
95                for edge in edges {
96                    let adapted_input = edge
97                        .adapter
98                        .as_ref()
99                        .map(|adapter| adapter.adapt(processed_input.clone()))
100                        .unwrap_or_else(|| processed_input.clone());
101                    stack.push((&edge.to, adapted_input));
102                }
103            }
104        }
105
106        Ok(results)
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sync::block_on;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Identity processing unit: returns its input unchanged.
    #[derive(Default)]
    struct MockUnitProcess;

    #[async_trait]
    impl UnitProcess for MockUnitProcess {
        fn get_name(&self) -> &str {
            "MockUnit"
        }
        async fn process(&self, input: ModuleParam) -> Result<ModuleParam, Error> {
            Ok(input)
        }
    }

    /// Identity adapter: forwards its input unchanged.
    struct MockAdapter;

    impl Adapter for MockAdapter {
        fn adapt(&self, input: ModuleParam) -> ModuleParam {
            input
        }
    }

    /// Builds a fresh identity node, ready to be passed to `add_node`.
    fn mock_node() -> Arc<RwLock<MockUnitProcess>> {
        Arc::new(RwLock::new(MockUnitProcess::default()))
    }

    #[test]
    fn test_pipeline_net() {
        let mut pipeline = PipelineNet::new();

        // Mock input for processing
        let mock_input: &str = "TestInput";
        let initial_input = ModuleParam::Str(mock_input.into());

        pipeline.add_node("node1", mock_node());
        pipeline.add_node("node2", mock_node());

        pipeline.add_edge_with_adapter("node1", "node2", |v: ModuleParam| {
            if let ModuleParam::Str(param) = v.clone() {
                assert_eq!(param, "TestInput");
            }
            v
        });

        pipeline.set_group_input("group1", "node1");

        block_on(async move {
            let results = pipeline
                .process_group("group1", initial_input)
                .await
                .expect("Failed to process group");

            assert!(results.contains_key("node1"));
            assert!(results.contains_key("node2"));
            assert_eq!(
                results.get("node1").unwrap().as_string().unwrap(),
                mock_input
            );
            assert_eq!(
                results.get("node2").unwrap().as_string().unwrap(),
                mock_input
            );
        });
    }

    #[test]
    fn test_unknown_group_is_not_found() {
        let pipeline = PipelineNet::new();

        block_on(async move {
            let result = pipeline
                .process_group("missing", ModuleParam::Str("x".into()))
                .await;
            assert!(matches!(result, Err(Error::NotFound(ref name)) if name == "missing"));
        });
    }

    #[test]
    fn test_edge_to_unregistered_node_is_not_found() {
        let mut pipeline = PipelineNet::new();
        pipeline.add_node("a", mock_node());
        pipeline.add_edge("a", "ghost");
        pipeline.set_group_input("g", "a");

        block_on(async move {
            let result = pipeline
                .process_group("g", ModuleParam::Str("x".into()))
                .await;
            assert!(matches!(result, Err(Error::NotFound(ref name)) if name == "ghost"));
        });
    }

    #[test]
    fn test_cycle_terminates_and_runs_each_node_once() {
        let mut pipeline = PipelineNet::new();
        pipeline.add_node("a", mock_node());
        pipeline.add_node("b", mock_node());
        pipeline.add_edge("a", "b");
        pipeline.add_edge_with_adapter("b", "a", MockAdapter);
        pipeline.set_group_input("loop", "a");

        block_on(async move {
            let results = pipeline
                .process_group("loop", ModuleParam::Str("cycle".into()))
                .await
                .expect("a cyclic group should still terminate");

            assert_eq!(results.len(), 2);
            assert_eq!(results.get("a").unwrap().as_string().unwrap(), "cycle");
            assert_eq!(results.get("b").unwrap().as_string().unwrap(), "cycle");
        });
    }
}