use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};

use crate::ir::{WorkflowExecutionId, WorkflowExecution, ExecutionStatus};
use kotoba_errors::WorkflowError;

/// Metadata for a worker node registered with the coordinator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeInfo {
    pub node_id: String,
    pub address: String,
    pub capacity: usize,
    pub active_workflows: usize,
    pub last_heartbeat: chrono::DateTime<chrono::Utc>,
}

/// A unit of work tracking one workflow execution as it moves through the cluster.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DistributedTask {
    pub task_id: String,
    pub execution_id: WorkflowExecutionId,
    pub node_id: Option<String>,
    pub status: TaskStatus,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub assigned_at: Option<chrono::DateTime<chrono::Utc>>,
    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TaskStatus {
    Pending,
    Assigned,
    Running,
    Completed,
    Failed,
}

/// Tracks cluster nodes and tasks, delegating node selection to a `LoadBalancer`.
pub struct DistributedCoordinator {
    nodes: RwLock<HashMap<String, NodeInfo>>,
    tasks: RwLock<HashMap<String, DistributedTask>>,
    load_balancer: Arc<dyn LoadBalancer>,
}

/// Strategy for choosing which node should receive the next task.
#[async_trait::async_trait]
pub trait LoadBalancer: Send + Sync {
    async fn select_node(&self, nodes: &HashMap<String, NodeInfo>) -> Option<String>;
}

pub struct RoundRobinBalancer {
    current_index: std::sync::atomic::AtomicUsize,
}

impl RoundRobinBalancer {
    pub fn new() -> Self {
        Self {
            current_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }
}

#[async_trait::async_trait]
impl LoadBalancer for RoundRobinBalancer {
    async fn select_node(&self, nodes: &HashMap<String, NodeInfo>) -> Option<String> {
        // Note: `HashMap` iteration order is unspecified, so the rotation is
        // approximate rather than a strict round robin over a fixed node order.
        let available_nodes: Vec<_> = nodes.values()
            .filter(|node| node.active_workflows < node.capacity)
            .collect();

        if available_nodes.is_empty() {
            return None;
        }

        let index = self.current_index.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let selected = available_nodes[index % available_nodes.len()];
        Some(selected.node_id.clone())
    }
}

pub struct LeastLoadedBalancer;

impl LeastLoadedBalancer {
    pub fn new() -> Self {
        Self
    }
}

#[async_trait::async_trait]
impl LoadBalancer for LeastLoadedBalancer {
    async fn select_node(&self, nodes: &HashMap<String, NodeInfo>) -> Option<String> {
        nodes.values()
            .filter(|node| node.active_workflows < node.capacity)
            .min_by_key(|node| node.active_workflows)
            .map(|node| node.node_id.clone())
    }
}
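
// Illustrative sketch: any strategy can be plugged in by implementing `LoadBalancer`
// and passing it to `DistributedCoordinator::new`. The balancer below is a
// hypothetical example (not used elsewhere in this module) that prefers the node
// with the most free slots.
pub struct MostFreeSlotsBalancer;

#[async_trait::async_trait]
impl LoadBalancer for MostFreeSlotsBalancer {
    async fn select_node(&self, nodes: &HashMap<String, NodeInfo>) -> Option<String> {
        nodes.values()
            .filter(|node| node.active_workflows < node.capacity)
            // Rank nodes by remaining headroom and take the largest.
            .max_by_key(|node| node.capacity - node.active_workflows)
            .map(|node| node.node_id.clone())
    }
}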

impl DistributedCoordinator {
    pub fn new(load_balancer: Arc<dyn LoadBalancer>) -> Self {
        Self {
            nodes: RwLock::new(HashMap::new()),
            tasks: RwLock::new(HashMap::new()),
            load_balancer,
        }
    }

    pub async fn register_node(&self, node: NodeInfo) {
        let mut nodes = self.nodes.write().await;
        nodes.insert(node.node_id.clone(), node);
    }

    pub async fn unregister_node(&self, node_id: &str) {
        let mut nodes = self.nodes.write().await;
        nodes.remove(node_id);
    }

    pub async fn update_heartbeat(&self, node_id: &str) {
        let mut nodes = self.nodes.write().await;
        if let Some(node) = nodes.get_mut(node_id) {
            node.last_heartbeat = chrono::Utc::now();
        }
    }

    pub async fn submit_workflow(&self, execution_id: WorkflowExecutionId) -> Result<String, WorkflowError> {
        let task = DistributedTask {
            task_id: uuid::Uuid::new_v4().to_string(),
            execution_id,
            node_id: None,
            status: TaskStatus::Pending,
            created_at: chrono::Utc::now(),
            assigned_at: None,
            completed_at: None,
        };

        let mut tasks = self.tasks.write().await;
        let task_id = task.task_id.clone();
        tasks.insert(task_id.clone(), task);

        Ok(task_id)
    }

    pub async fn assign_task(&self, task_id: &str) -> Result<Option<String>, WorkflowError> {
        let mut tasks = self.tasks.write().await;
        let nodes = self.nodes.read().await;

        if let Some(task) = tasks.get_mut(task_id) {
            if task.status != TaskStatus::Pending {
                return Ok(None);
            }

            if let Some(node_id) = self.load_balancer.select_node(&nodes).await {
                task.node_id = Some(node_id.clone());
                task.status = TaskStatus::Running;
                task.assigned_at = Some(chrono::Utc::now());

                // Release both guards before re-locking `nodes` for writing;
                // taking the write lock while this task still holds the read
                // guard would deadlock.
                drop(tasks);
                drop(nodes);
                let mut nodes = self.nodes.write().await;
                if let Some(node) = nodes.get_mut(&node_id) {
                    node.active_workflows += 1;
                }

                return Ok(Some(node_id));
            }
        }

        Ok(None)
    }

    pub async fn complete_task(&self, task_id: &str, success: bool) -> Result<(), WorkflowError> {
        let node_id = {
            let mut tasks = self.tasks.write().await;
            if let Some(task) = tasks.get_mut(task_id) {
                task.status = if success { TaskStatus::Completed } else { TaskStatus::Failed };
                task.completed_at = Some(chrono::Utc::now());
                task.node_id.clone()
            } else {
                None
            }
        };

        if let Some(node_id) = node_id {
            let mut nodes = self.nodes.write().await;
            if let Some(node) = nodes.get_mut(&node_id) {
                node.active_workflows = node.active_workflows.saturating_sub(1);
            }
        }

        Ok(())
    }

    pub async fn get_running_tasks(&self) -> Vec<DistributedTask> {
        let tasks = self.tasks.read().await;
        tasks.values()
            .filter(|task| matches!(task.status, TaskStatus::Running | TaskStatus::Assigned))
            .cloned()
            .collect()
    }

    pub async fn get_node_load(&self, node_id: &str) -> Option<f64> {
        let nodes = self.nodes.read().await;
        nodes.get(node_id).map(|node| {
            if node.capacity == 0 {
                0.0
            } else {
                node.active_workflows as f64 / node.capacity as f64
            }
        })
    }

    pub async fn get_cluster_load(&self) -> f64 {
        let nodes = self.nodes.read().await;
        if nodes.is_empty() {
            return 0.0;
        }

        let total_active: usize = nodes.values().map(|n| n.active_workflows).sum();
        let total_capacity: usize = nodes.values().map(|n| n.capacity).sum();

        if total_capacity == 0 {
            0.0
        } else {
            total_active as f64 / total_capacity as f64
        }
    }

    pub async fn cleanup_dead_nodes(&self, timeout: std::time::Duration) {
        let mut nodes = self.nodes.write().await;
        let now = chrono::Utc::now();

        let dead_nodes: Vec<String> = nodes.values()
            .filter(|node| {
                let duration = now.signed_duration_since(node.last_heartbeat);
                duration.to_std().unwrap_or(std::time::Duration::from_secs(0)) > timeout
            })
            .map(|node| node.node_id.clone())
            .collect();

        for node_id in dead_nodes {
            println!("Removing dead node: {}", node_id);
            nodes.remove(&node_id);
        }
    }

    /// Placeholder for failover handling; tasks are not yet reassigned to
    /// another node.
    pub async fn failover_task(&self, _task_id: &str) -> Result<Option<String>, WorkflowError> {
        Ok(None)
    }
}
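
// A minimal test sketch for the types above: balancer selection and the
// coordinator's load accounting. It builds a current-thread runtime by hand,
// assuming only tokio's `rt` feature (already required by `tokio::spawn` below)
// rather than the `macros` feature needed for `#[tokio::test]`.
#[cfg(test)]
mod coordinator_tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    // Convenience constructor for a test node with the given capacity and load.
    fn node(id: &str, capacity: usize, active: usize) -> NodeInfo {
        NodeInfo {
            node_id: id.to_string(),
            address: format!("{}:8080", id),
            capacity,
            active_workflows: active,
            last_heartbeat: chrono::Utc::now(),
        }
    }

    fn block_on<F: std::future::Future>(fut: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("failed to build runtime")
            .block_on(fut)
    }

    #[test]
    fn least_loaded_picks_node_with_fewest_active_workflows() {
        block_on(async {
            let mut nodes = HashMap::new();
            nodes.insert("a".to_string(), node("a", 4, 3));
            nodes.insert("b".to_string(), node("b", 4, 1));

            let balancer = LeastLoadedBalancer::new();
            assert_eq!(balancer.select_node(&nodes).await, Some("b".to_string()));
        });
    }

    #[test]
    fn round_robin_skips_fully_loaded_nodes() {
        block_on(async {
            let mut nodes = HashMap::new();
            nodes.insert("full".to_string(), node("full", 2, 2));

            let balancer = RoundRobinBalancer::new();
            // The only node is at capacity, so nothing can be selected.
            assert_eq!(balancer.select_node(&nodes).await, None);

            nodes.insert("free".to_string(), node("free", 2, 0));
            assert_eq!(balancer.select_node(&nodes).await, Some("free".to_string()));
        });
    }

    #[test]
    fn cluster_load_is_active_over_capacity() {
        block_on(async {
            let coordinator = DistributedCoordinator::new(Arc::new(LeastLoadedBalancer::new()));
            coordinator.register_node(node("a", 4, 1)).await;
            coordinator.register_node(node("b", 6, 2)).await;

            // 3 active workflows over a total capacity of 10.
            let load = coordinator.get_cluster_load().await;
            assert!((load - 0.3).abs() < 1e-9);
        });
    }
}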

pub struct DistributedExecutionManager {
    coordinator: Arc<DistributedCoordinator>,
    local_node_id: String,
}

impl DistributedExecutionManager {
    pub fn new(local_node_id: String, load_balancer: Arc<dyn LoadBalancer>) -> Self {
        Self {
            coordinator: Arc::new(DistributedCoordinator::new(load_balancer)),
            local_node_id,
        }
    }

    pub fn coordinator(&self) -> &Arc<DistributedCoordinator> {
        &self.coordinator
    }

    pub async fn submit_execution(&self, execution_id: WorkflowExecutionId) -> Result<String, WorkflowError> {
        self.coordinator.submit_workflow(execution_id).await
    }

    pub async fn register_local_node(&self, capacity: usize) {
        let node = NodeInfo {
            node_id: self.local_node_id.clone(),
            // Placeholder address; the bind address is not configurable yet.
            address: "localhost:8080".to_string(),
            capacity,
            active_workflows: 0,
            last_heartbeat: chrono::Utc::now(),
        };

        self.coordinator.register_node(node).await;
    }

    pub fn start_cleanup_task(&self) -> tokio::task::JoinHandle<()> {
        let coordinator = Arc::clone(&self.coordinator);

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));

            loop {
                interval.tick().await;
                coordinator.cleanup_dead_nodes(std::time::Duration::from_secs(60)).await;
            }
        })
    }
}

pub struct DistributedWorkflowExecutor {
    pub execution_manager: Arc<DistributedExecutionManager>,
    execution_stats: RwLock<HashMap<String, NodeExecutionStats>>,
}

#[derive(Debug, Clone)]
pub struct NodeExecutionStats {
    pub total_tasks: usize,
    pub successful_tasks: usize,
    pub failed_tasks: usize,
    pub avg_execution_time: std::time::Duration,
}

impl DistributedWorkflowExecutor {
    pub fn new(execution_manager: Arc<DistributedExecutionManager>) -> Self {
        Self {
            execution_manager,
            execution_stats: RwLock::new(HashMap::new()),
        }
    }

    pub async fn get_execution_stats(&self) -> HashMap<String, NodeExecutionStats> {
        let stats = self.execution_stats.read().await;
        stats.clone()
    }

    pub async fn cluster_health_check(&self) -> ClusterHealth {
        let cluster_load = self.execution_manager.coordinator.get_cluster_load().await;
        let running_tasks = self.execution_manager.coordinator.get_running_tasks().await;

        ClusterHealth {
            cluster_load,
            active_tasks: running_tasks.len(),
            // Per-node health counts are not computed yet.
            healthy_nodes: 0,
            unhealthy_nodes: 0,
        }
    }
}

#[derive(Debug)]
pub struct ClusterHealth {
    pub cluster_load: f64,
    pub active_tasks: usize,
    pub healthy_nodes: usize,
    pub unhealthy_nodes: usize,