agent_stream_kit/
flow.rs

1use std::collections::HashMap;
2use std::sync::atomic::AtomicUsize;
3
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use super::askit::ASKit;
8use super::config::AgentConfigs;
9use super::definition::AgentDefinition;
10use super::error::AgentError;
11
/// Map from a `String` key to an [`AgentFlow`] (presumably keyed by flow name — confirm against callers).
12pub type AgentFlows = HashMap<String, AgentFlow>;
13
14#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct AgentFlow {
16    name: String,
17
18    nodes: Vec<AgentFlowNode>,
19
20    edges: Vec<AgentFlowEdge>,
21
22    #[serde(flatten)]
23    pub extensions: HashMap<String, Value>,
24}
25
26impl AgentFlow {
27    pub fn new(name: String) -> Self {
28        Self {
29            name,
30            nodes: Vec::new(),
31            edges: Vec::new(),
32            extensions: HashMap::new(),
33        }
34    }
35
36    pub fn nodes(&self) -> &Vec<AgentFlowNode> {
37        &self.nodes
38    }
39
40    pub fn edges(&self) -> &Vec<AgentFlowEdge> {
41        &self.edges
42    }
43
44    pub fn name(&self) -> &str {
45        &self.name
46    }
47
48    pub fn set_name(&mut self, new_name: String) {
49        self.name = new_name;
50    }
51
52    pub fn add_node(&mut self, node: AgentFlowNode) {
53        self.nodes.push(node);
54    }
55
56    pub fn remove_node(&mut self, node_id: &str) {
57        self.nodes.retain(|node| node.id != node_id);
58    }
59
60    pub fn set_nodes(&mut self, nodes: Vec<AgentFlowNode>) {
61        self.nodes = nodes;
62    }
63
64    pub fn add_edge(&mut self, edge: AgentFlowEdge) {
65        self.edges.push(edge);
66    }
67
68    pub fn remove_edge(&mut self, edge_id: &str) -> Option<AgentFlowEdge> {
69        if let Some(edge) = self.edges.iter().find(|edge| edge.id == edge_id).cloned() {
70            self.edges.retain(|e| e.id != edge_id);
71            Some(edge)
72        } else {
73            None
74        }
75    }
76
77    pub fn set_edges(&mut self, edges: Vec<AgentFlowEdge>) {
78        self.edges = edges;
79    }
80
81    pub async fn start(&self, askit: &ASKit) -> Result<(), AgentError> {
82        for agent in self.nodes.iter() {
83            if !agent.enabled {
84                continue;
85            }
86            askit.start_agent(&agent.id).await.unwrap_or_else(|e| {
87                log::error!("Failed to start agent {}: {}", agent.id, e);
88            });
89        }
90        Ok(())
91    }
92
93    pub async fn stop(&self, askit: &ASKit) -> Result<(), AgentError> {
94        for agent in self.nodes.iter() {
95            if !agent.enabled {
96                continue;
97            }
98            askit.stop_agent(&agent.id).await.unwrap_or_else(|e| {
99                log::error!("Failed to stop agent {}: {}", agent.id, e);
100            });
101        }
102        Ok(())
103    }
104
105    pub fn disable_all_nodes(&mut self) {
106        for node in self.nodes.iter_mut() {
107            node.enabled = false;
108        }
109    }
110
111    pub fn to_json(&self) -> Result<String, AgentError> {
112        let json = serde_json::to_string_pretty(self)
113            .map_err(|e| AgentError::SerializationError(e.to_string()))?;
114        Ok(json)
115    }
116
117    pub fn from_json(json_str: &str) -> Result<Self, AgentError> {
118        let flow: AgentFlow = serde_json::from_str(json_str)
119            .map_err(|e| AgentError::SerializationError(e.to_string()))?;
120        Ok(flow)
121    }
122}
123
124pub fn copy_sub_flow(
125    nodes: &Vec<AgentFlowNode>,
126    edges: &Vec<AgentFlowEdge>,
127) -> (Vec<AgentFlowNode>, Vec<AgentFlowEdge>) {
128    let mut new_nodes = Vec::new();
129    let mut node_id_map = HashMap::new();
130    for node in nodes {
131        let new_id = new_id();
132        node_id_map.insert(node.id.clone(), new_id.clone());
133        let mut new_node = node.clone();
134        new_node.id = new_id;
135        new_nodes.push(new_node);
136    }
137
138    let mut new_edges = Vec::new();
139    for edge in edges {
140        let Some(source) = node_id_map.get(&edge.source) else {
141            continue;
142        };
143        let Some(target) = node_id_map.get(&edge.target) else {
144            continue;
145        };
146        let mut new_edge = edge.clone();
147        new_edge.id = new_id();
148        new_edge.source = source.clone();
149        new_edge.target = target.clone();
150        new_edges.push(new_edge);
151    }
152
153    (new_nodes, new_edges)
154}
155
156// AgentFlowNode
157
158#[derive(Debug, Default, Serialize, Deserialize, Clone)]
159pub struct AgentFlowNode {
160    pub id: String,
161    pub def_name: String,
162    pub enabled: bool,
163
164    #[serde(skip_serializing_if = "Option::is_none")]
165    pub configs: Option<AgentConfigs>,
166
167    #[serde(flatten)]
168    pub extensions: HashMap<String, Value>,
169}
170
171impl AgentFlowNode {
172    pub fn new(def: &AgentDefinition) -> Result<Self, AgentError> {
173        let configs = if let Some(default_configs) = &def.default_configs {
174            let mut configs = AgentConfigs::new();
175            for (key, entry) in default_configs {
176                configs.set(key.clone(), entry.value.clone());
177            }
178            Some(configs)
179        } else {
180            None
181        };
182
183        Ok(Self {
184            id: new_id(),
185            def_name: def.name.clone(),
186            enabled: false,
187            configs,
188            extensions: HashMap::new(),
189        })
190    }
191}
192
/// Process-wide counter that backs `new_id`; the first id handed out is `1`.
static NODE_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Returns the next counter value as a decimal string and advances the counter.
///
/// The same counter is used for node ids and edge ids.
///
/// NOTE(review): the counter restarts at 1 in every process, while ids can also
/// arrive through deserialized flows — confirm that callers guard against
/// collisions with ids loaded from JSON.
fn new_id() -> String {
    // `Relaxed` is enough here: only the atomicity of the increment matters for
    // handing out distinct values; no other memory access is ordered by it.
    NODE_ID_COUNTER
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
        .to_string()
}
200
201// AgentFlowEdge
202
203#[derive(Debug, Default, Serialize, Deserialize, Clone)]
204pub struct AgentFlowEdge {
205    pub id: String,
206    pub source: String,
207    pub source_handle: String,
208    pub target: String,
209    pub target_handle: String,
210}