// ai_agents/pipeline_net.rs
pipeline_net.rs1use std::{collections::HashMap, sync::Arc};
2
3use crate::{
4 error::Error,
5 sync::RwLock,
6 traits::{Adapter, UnitProcess},
7 ModuleParam,
8};
9
10struct Edge {
11 to: String,
12 adapter: Option<Arc<dyn Adapter>>,
13}
14
15pub struct PipelineNet {
16 nodes: HashMap<String, Arc<RwLock<dyn UnitProcess + Send + Sync>>>,
17 edges: HashMap<String, Vec<Edge>>,
18 groups: HashMap<String, String>, }
20
21impl PipelineNet {
22 pub fn new() -> Self {
23 Self {
24 nodes: HashMap::new(),
25 edges: HashMap::new(),
26 groups: HashMap::new(),
27 }
28 }
29
30 pub fn add_node(&mut self, name: &str, node: Arc<RwLock<dyn UnitProcess + Send + Sync>>) {
32 self.nodes.insert(name.into(), node);
33 }
34 pub fn add_edge(&mut self, from: &str, to: &str) {
36 let edge = Edge {
37 to: to.to_string(),
38 adapter: None,
39 };
40 self.edges.entry(from.to_string()).or_default().push(edge);
41 }
42
43 pub fn add_edge_with_adapter<A: Adapter + 'static>(
45 &mut self,
46 from: &str,
47 to: &str,
48 adapter: A,
49 ) {
50 let edge = Edge {
51 to: to.to_string(),
52 adapter: Some(Arc::new(adapter)),
53 };
54 self.edges.entry(from.to_string()).or_default().push(edge);
55 }
56
57 pub fn set_group_input(&mut self, group_name: &str, input_node_name: &str) {
59 self.groups
60 .insert(group_name.into(), input_node_name.into());
61 }
62
63 pub async fn process_group(
65 &self,
66 group_name: &str,
67 initial_input: ModuleParam,
68 ) -> Result<HashMap<String, ModuleParam>, Error> {
69 let input_node_name = self
70 .groups
71 .get(group_name)
72 .ok_or_else(|| Error::NotFound(group_name.to_string()))?;
73
74 let mut results = HashMap::new();
75 let mut stack = vec![(input_node_name.as_str(), initial_input)];
76
77 while let Some((current_node_name, input)) = stack.pop() {
79 if results.contains_key(current_node_name) {
80 continue; }
82
83 let node = self
84 .nodes
85 .get(current_node_name)
86 .ok_or_else(|| Error::NotFound(current_node_name.to_string()))?;
87
88 let processed_input = node.read().await.process(input).await?;
89
90 results.insert(current_node_name.to_string(), processed_input.clone());
93
94 if let Some(edges) = self.edges.get(current_node_name) {
95 for edge in edges {
96 let adapted_input = edge
97 .adapter
98 .as_ref()
99 .map(|adapter| adapter.adapt(processed_input.clone()))
100 .unwrap_or_else(|| processed_input.clone());
101 stack.push((&edge.to, adapted_input));
102 }
103 }
104 }
105
106 Ok(results)
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sync::block_on;
    use async_trait::async_trait;

    /// Identity processing unit: returns its input unchanged.
    #[derive(Default)]
    struct MockUnitProcess;

    #[async_trait]
    impl UnitProcess for MockUnitProcess {
        fn get_name(&self) -> &str {
            "MockUnit"
        }
        async fn process(&self, input: ModuleParam) -> Result<ModuleParam, Error> {
            Ok(input)
        }
    }

    /// Identity adapter: forwards the payload unchanged.
    struct MockAdapter;

    impl Adapter for MockAdapter {
        fn adapt(&self, input: ModuleParam) -> ModuleParam {
            input
        }
    }

    #[test]
    fn test_pipeline_net() {
        let mut pipeline = PipelineNet::new();

        let mock_input: &str = "TestInput";
        let initial_input = ModuleParam::Str(mock_input.into());

        let node1 = Arc::new(RwLock::new(MockUnitProcess::default()));
        let node2 = Arc::new(RwLock::new(MockUnitProcess::default()));

        pipeline.add_node("node1", node1);
        pipeline.add_node("node2", node2);

        pipeline.add_edge_with_adapter("node1", "node2", |v: ModuleParam| {
            if let ModuleParam::Str(param) = v.clone() {
                assert_eq!(param, "TestInput");
            }
            v
        });

        pipeline.set_group_input("group1", "node1");

        block_on(async move {
            let results = pipeline
                .process_group("group1", initial_input)
                .await
                .expect("Failed to process group");

            assert!(results.contains_key("node1"));
            assert!(results.contains_key("node2"));
            assert_eq!(
                results.get("node1").unwrap().as_string().unwrap(),
                mock_input
            );
            assert_eq!(
                results.get("node2").unwrap().as_string().unwrap(),
                mock_input
            );
        });
    }

    /// A cycle node1 -> node2 -> node1 must terminate, processing each node once.
    #[test]
    fn test_cycle_is_processed_once_per_node() {
        let mut pipeline = PipelineNet::new();

        let payload: &str = "Cycle";

        let node1 = Arc::new(RwLock::new(MockUnitProcess::default()));
        let node2 = Arc::new(RwLock::new(MockUnitProcess::default()));
        pipeline.add_node("node1", node1);
        pipeline.add_node("node2", node2);

        pipeline.add_edge("node1", "node2");
        pipeline.add_edge_with_adapter("node2", "node1", MockAdapter);
        pipeline.set_group_input("loop", "node1");

        block_on(async move {
            let results = pipeline
                .process_group("loop", ModuleParam::Str(payload.into()))
                .await
                .expect("cyclic graph must still terminate");

            assert_eq!(results.len(), 2);
            assert_eq!(
                results.get("node2").unwrap().as_string().unwrap(),
                payload
            );
        });
    }

    /// Running a group that was never configured reports the group name.
    #[test]
    fn test_unknown_group_is_not_found() {
        let pipeline = PipelineNet::new();

        block_on(async move {
            let outcome = pipeline
                .process_group("missing", ModuleParam::Str("TestInput".into()))
                .await;
            match outcome {
                Err(Error::NotFound(name)) => assert_eq!(name, "missing"),
                Err(_) => panic!("expected Error::NotFound for an unregistered group"),
                Ok(_) => panic!("processing an unregistered group must fail"),
            }
        });
    }

    /// Following an edge to a node that was never registered reports that node's name.
    #[test]
    fn test_edge_to_unregistered_node_is_not_found() {
        let mut pipeline = PipelineNet::new();

        let node1 = Arc::new(RwLock::new(MockUnitProcess::default()));
        pipeline.add_node("node1", node1);
        pipeline.add_edge("node1", "ghost");
        pipeline.set_group_input("group1", "node1");

        block_on(async move {
            let outcome = pipeline
                .process_group("group1", ModuleParam::Str("TestInput".into()))
                .await;
            match outcome {
                Err(Error::NotFound(name)) => assert_eq!(name, "ghost"),
                Err(_) => panic!("expected Error::NotFound for an unregistered node"),
                Ok(_) => panic!("an edge to an unregistered node must fail"),
            }
        });
    }
}