agent_stream_kit/
flow.rs

1use std::collections::HashMap;
2use std::sync::atomic::AtomicUsize;
3
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6
7use super::askit::ASKit;
8use super::config::AgentConfigs;
9use super::definition::AgentDefinition;
10use super::error::AgentError;
11
12pub type AgentFlows = HashMap<String, AgentFlow>;
13
14#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct AgentFlow {
16    #[serde(skip_serializing_if = "String::is_empty")]
17    id: String,
18
19    name: String,
20
21    nodes: Vec<AgentFlowNode>,
22
23    edges: Vec<AgentFlowEdge>,
24
25    #[serde(flatten)]
26    pub extensions: HashMap<String, Value>,
27}
28
29impl AgentFlow {
30    pub fn new(name: String) -> Self {
31        Self {
32            id: new_id(),
33            name,
34            nodes: Vec::new(),
35            edges: Vec::new(),
36            extensions: HashMap::new(),
37        }
38    }
39
40    pub fn id(&self) -> &str {
41        &self.id
42    }
43
44    pub fn name(&self) -> &str {
45        &self.name
46    }
47
48    pub fn set_name(&mut self, new_name: String) {
49        self.name = new_name;
50    }
51
52    pub fn nodes(&self) -> &Vec<AgentFlowNode> {
53        &self.nodes
54    }
55
56    pub fn add_node(&mut self, node: AgentFlowNode) {
57        self.nodes.push(node);
58    }
59
60    pub fn remove_node(&mut self, node_id: &str) {
61        self.nodes.retain(|node| node.id != node_id);
62    }
63
64    pub fn set_nodes(&mut self, nodes: Vec<AgentFlowNode>) {
65        self.nodes = nodes;
66    }
67
68    pub fn edges(&self) -> &Vec<AgentFlowEdge> {
69        &self.edges
70    }
71
72    pub fn add_edge(&mut self, edge: AgentFlowEdge) {
73        self.edges.push(edge);
74    }
75
76    pub fn remove_edge(&mut self, edge_id: &str) -> Option<AgentFlowEdge> {
77        if let Some(edge) = self.edges.iter().find(|edge| edge.id == edge_id).cloned() {
78            self.edges.retain(|e| e.id != edge_id);
79            Some(edge)
80        } else {
81            None
82        }
83    }
84
85    pub fn set_edges(&mut self, edges: Vec<AgentFlowEdge>) {
86        self.edges = edges;
87    }
88
89    pub async fn start(&self, askit: &ASKit) -> Result<(), AgentError> {
90        for agent in self.nodes.iter() {
91            if !agent.enabled {
92                continue;
93            }
94            askit.start_agent(&agent.id).await.unwrap_or_else(|e| {
95                log::error!("Failed to start agent {}: {}", agent.id, e);
96            });
97        }
98        Ok(())
99    }
100
101    pub async fn stop(&self, askit: &ASKit) -> Result<(), AgentError> {
102        for agent in self.nodes.iter() {
103            if !agent.enabled {
104                continue;
105            }
106            askit.stop_agent(&agent.id).await.unwrap_or_else(|e| {
107                log::error!("Failed to stop agent {}: {}", agent.id, e);
108            });
109        }
110        Ok(())
111    }
112
113    pub fn disable_all_nodes(&mut self) {
114        for node in self.nodes.iter_mut() {
115            node.enabled = false;
116        }
117    }
118
119    pub fn to_json(&self) -> Result<String, AgentError> {
120        let json = serde_json::to_string_pretty(self)
121            .map_err(|e| AgentError::SerializationError(e.to_string()))?;
122        Ok(json)
123    }
124
125    pub fn from_json(json_str: &str) -> Result<Self, AgentError> {
126        let mut flow: AgentFlow = serde_json::from_str(json_str)
127            .map_err(|e| AgentError::SerializationError(e.to_string()))?;
128        flow.id = new_id();
129        Ok(flow)
130    }
131}
132
133pub fn copy_sub_flow(
134    nodes: &Vec<AgentFlowNode>,
135    edges: &Vec<AgentFlowEdge>,
136) -> (Vec<AgentFlowNode>, Vec<AgentFlowEdge>) {
137    let mut new_nodes = Vec::new();
138    let mut node_id_map = HashMap::new();
139    for node in nodes {
140        let new_id = new_id();
141        node_id_map.insert(node.id.clone(), new_id.clone());
142        let mut new_node = node.clone();
143        new_node.id = new_id;
144        new_nodes.push(new_node);
145    }
146
147    let mut new_edges = Vec::new();
148    for edge in edges {
149        let Some(source) = node_id_map.get(&edge.source) else {
150            continue;
151        };
152        let Some(target) = node_id_map.get(&edge.target) else {
153            continue;
154        };
155        let mut new_edge = edge.clone();
156        new_edge.id = new_id();
157        new_edge.source = source.clone();
158        new_edge.target = target.clone();
159        new_edges.push(new_edge);
160    }
161
162    (new_nodes, new_edges)
163}
164
165// AgentFlowNode
166
167#[derive(Debug, Default, Serialize, Deserialize, Clone)]
168pub struct AgentFlowNode {
169    pub id: String,
170    pub def_name: String,
171    pub enabled: bool,
172
173    #[serde(skip_serializing_if = "Option::is_none")]
174    pub configs: Option<AgentConfigs>,
175
176    #[serde(flatten)]
177    pub extensions: HashMap<String, Value>,
178}
179
180impl AgentFlowNode {
181    pub fn new(def: &AgentDefinition) -> Result<Self, AgentError> {
182        let configs = if let Some(default_configs) = &def.default_configs {
183            let mut configs = AgentConfigs::new();
184            for (key, entry) in default_configs {
185                configs.set(key.clone(), entry.value.clone());
186            }
187            Some(configs)
188        } else {
189            None
190        };
191
192        Ok(Self {
193            id: new_id(),
194            def_name: def.name.clone(),
195            enabled: false,
196            configs,
197            extensions: HashMap::new(),
198        })
199    }
200}
201
/// Process-wide counter backing `new_id`; the first id handed out is `"1"`.
static NODE_ID_COUNTER: AtomicUsize = AtomicUsize::new(1);

/// Returns a fresh id as a decimal string.
///
/// Ids are unique among values returned by this function within the current
/// process. The same counter is used for flows, nodes and edges. `Relaxed`
/// ordering is sufficient: `fetch_add` is atomic, so concurrent callers still
/// receive distinct values, and no other memory is published through this
/// counter.
///
/// NOTE(review): the counter restarts at 1 on every process start, while
/// `AgentFlow::from_json` keeps node/edge ids from the JSON as-is. Confirm that
/// callers re-key loaded nodes (e.g. via `copy_sub_flow`); otherwise freshly
/// generated ids may collide with loaded ones.
fn new_id() -> String {
    NODE_ID_COUNTER
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
        .to_string()
}
209
// AgentFlowEdge

/// A connection in a flow from a handle on the `source` node to a handle on
/// the `target` node.
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct AgentFlowEdge {
    /// Identifier of this edge; `AgentFlow::remove_edge` looks edges up by it.
    pub id: String,
    /// Expected to hold the id of the node the edge leaves from
    /// (`copy_sub_flow` remaps it through node ids).
    pub source: String,
    /// Name of the handle on the source node.
    /// NOTE(review): handle naming scheme is not visible here — confirm.
    pub source_handle: String,
    /// Expected to hold the id of the node the edge enters
    /// (`copy_sub_flow` remaps it through node ids).
    pub target: String,
    /// Name of the handle on the target node.
    /// NOTE(review): handle naming scheme is not visible here — confirm.
    pub target_handle: String,
}