enki_local/
mesh.rs

1use async_trait::async_trait;
2use dashmap::DashMap;
3use enki_core::agent::{Agent, AgentContext};
4use enki_core::error::{Error, Result};
5use enki_core::mesh::Mesh;
6use enki_core::message::Message;
7use enki_core::request_queue::{MeshRequest, MeshResult, RequestQueue};
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::sync::mpsc::{self, Sender};
11use tokio::task::JoinHandle;
12use tracing::{error, info, trace};
13
14pub struct LocalMesh {
15    name: String,
16    agents: DashMap<String, Sender<Message>>,
17    tasks: DashMap<String, JoinHandle<()>>,
18    request_queue: Arc<RequestQueue>,
19}
20
21impl LocalMesh {
22    pub fn new(name: impl Into<String>) -> Self {
23        Self {
24            name: name.into(),
25            agents: DashMap::new(),
26            tasks: DashMap::new(),
27            request_queue: Arc::new(RequestQueue::new()),
28        }
29    }
30
31    /// Get the request queue (for sharing with agent wrappers)
32    pub fn request_queue(&self) -> Arc<RequestQueue> {
33        self.request_queue.clone()
34    }
35}
36
37#[async_trait]
38impl Mesh for LocalMesh {
39    async fn start(&self) -> Result<()> {
40        // For local mesh, start might just be a signal,
41        // but agents are started when added for now.
42        Ok(())
43    }
44
45    async fn stop(&self) -> Result<()> {
46        for entry in self.tasks.iter() {
47            entry.value().abort();
48        }
49        self.tasks.clear();
50        self.agents.clear();
51        Ok(())
52    }
53
54    async fn add_agent(&self, mut agent: Box<dyn Agent + 'static>) -> Result<()> {
55        let name = agent.name();
56        if self.agents.contains_key(&name) {
57            return Err(Error::MeshError(format!("Agent {} already exists", name)));
58        }
59
60        let (tx, mut rx) = mpsc::channel(100);
61        self.agents.insert(name.clone(), tx);
62
63        let mesh_name = self.name.clone();
64        let agent_name = name.clone();
65        let request_queue = self.request_queue.clone();
66
67        // Spawn agent loop
68        let handle = tokio::spawn(async move {
69            let mut ctx = AgentContext::new(mesh_name, Some(request_queue));
70            if let Err(e) = agent.on_start(&mut ctx).await {
71                error!(agent = %agent_name, error = ?e, "Error starting agent");
72                return;
73            }
74
75            while let Some(msg) = rx.recv().await {
76                let now = chrono::Utc::now().timestamp_micros();
77                let latency = (now - msg.created_at).max(0) as u64;
78                trace!(
79                    agent = %agent_name,
80                    latency_us = latency,
81                    msg_id = %msg.id,
82                    from = %msg.from,
83                    to = %msg.to,
84                    content_type = %msg.content.type_name(),
85                    correlation_id = ?msg.correlation_id,
86                    "processing message"
87                );
88
89                let start = std::time::Instant::now();
90                if let Err(e) = agent.on_message(msg, &mut ctx).await {
91                    error!(agent = %agent_name, error = ?e, "Error processing message");
92                }
93                let duration = start.elapsed().as_micros() as u64;
94                trace!(agent = %agent_name, duration_us = duration, "message processed");
95            }
96
97            if let Err(e) = agent.on_stop(&mut ctx).await {
98                error!(agent = %agent_name, error = ?e, "Error stopping agent");
99            }
100        });
101
102        self.tasks.insert(name, handle);
103        Ok(())
104    }
105
106    async fn send(&self, message: Message) -> Result<()> {
107        // Route based on message target
108        // Clone target to avoid borrow checker issues when moving message
109        let target = message.to.clone();
110
111        match target {
112            enki_core::message::MessageTarget::Agent(target_agent) => {
113                if let Some(sender) = self.agents.get(&target_agent) {
114                    info!(
115                        mesh = %self.name,
116                        from = %message.from,
117                        to = %target_agent,
118                        content_type = %message.content.type_name(),
119                        msg_id = %message.id,
120                        correlation_id = ?message.correlation_id,
121                        "Sending message between agents"
122                    );
123
124                    sender.send(message).await.map_err(|_| {
125                        Error::MeshError(format!("Failed to send to agent {}", target_agent))
126                    })?;
127                    Ok(())
128                } else {
129                    Err(Error::AgentNotFound(target_agent))
130                }
131            }
132            enki_core::message::MessageTarget::Broadcast => {
133                // Broadcast to all except sender
134                for entry in self.agents.iter() {
135                    if entry.key() != &message.from {
136                        let _ = entry.value().send(message.clone()).await;
137                    }
138                }
139                Ok(())
140            }
141            _ => Err(Error::MeshError(
142                "Unsupported message target (Topic and Node routing not yet implemented)"
143                    .to_string(),
144            )),
145        }
146    }
147}
148
149impl LocalMesh {
150    /// Broadcast a message to all agents in the mesh
151    pub async fn broadcast(
152        &self,
153        message: Message,
154        exclude: Option<&str>,
155    ) -> Result<Vec<Result<()>>> {
156        let mut results = Vec::new();
157        let recipient_count = self.agents.len() - if exclude.is_some() { 1 } else { 0 };
158
159        info!(
160            mesh = %self.name,
161            from = %message.from,
162            recipients = recipient_count,
163            excluded = ?exclude,
164            "Broadcasting message to agents"
165        );
166
167        for entry in self.agents.iter() {
168            let agent_name = entry.key();
169
170            // Skip excluded agent if specified
171            if let Some(excluded) = exclude {
172                if agent_name == excluded {
173                    continue;
174                }
175            }
176
177            // Clone message for each agent
178            let mut msg = message.clone();
179            msg.to = enki_core::message::MessageTarget::Agent(agent_name.clone());
180
181            trace!(
182                mesh = %self.name,
183                from = %message.from,
184                to = %agent_name,
185                msg_id = %msg.id,
186                "Broadcasting to individual agent"
187            );
188
189            let result = entry.value().send(msg).await.map_err(|_| {
190                Error::MeshError(format!("Failed to broadcast to agent {}", agent_name))
191            });
192
193            results.push(result);
194        }
195
196        Ok(results)
197    }
198
199    /// Submit a request (fire-and-forget). Returns request ID.
200    pub async fn submit(&self, target: &str, payload: String) -> Result<String> {
201        let request_id = self.request_queue.submit(target, payload.clone());
202
203        info!(
204            mesh = %self.name,
205            target = %target,
206            request_id = %request_id,
207            payload_size = payload.len(),
208            "Submitting request to agent"
209        );
210
211        let msg = Message::text("mesh", payload)
212            .to(enki_core::message::MessageTarget::Agent(target.to_string()))
213            .correlation_id(&request_id)
214            .build();
215
216        self.send(msg).await?;
217        Ok(request_id)
218    }
219
220    /// Get pending requests
221    pub fn get_pending(&self) -> Vec<MeshRequest> {
222        self.request_queue.get_pending()
223    }
224
225    /// Check if there are pending requests
226    pub fn has_pending(&self) -> bool {
227        self.request_queue.has_pending()
228    }
229
230    /// Get available results (removes them from queue)
231    pub fn get_results(&self) -> Vec<MeshResult> {
232        self.request_queue.take_results()
233    }
234
235    /// Peek at results without removing
236    pub fn peek_results(&self) -> Vec<MeshResult> {
237        self.request_queue.peek_results()
238    }
239
240    /// Send reminders for stale requests
241    pub async fn send_reminders(&self, older_than_secs: f64) -> Result<Vec<String>> {
242        let stale = self
243            .request_queue
244            .get_stale(Duration::from_secs_f64(older_than_secs));
245        let mut reminded = Vec::new();
246
247        for req in stale {
248            let reminder_msg = Message::text(
249                "mesh",
250                format!("REMINDER: Please complete request {}", req.id),
251            )
252            .to(enki_core::message::MessageTarget::Agent(req.target.clone()))
253            .correlation_id(&req.id)
254            .build();
255
256            if self.send(reminder_msg).await.is_ok() {
257                self.request_queue.increment_reminder(&req.id);
258                reminded.push(req.id);
259            }
260        }
261
262        Ok(reminded)
263    }
264
265    /// Wait for a specific result with auto-reminders
266    pub async fn wait_for(
267        &self,
268        request_id: &str,
269        timeout_secs: f64,
270        reminder_interval_secs: f64,
271    ) -> Result<MeshResult> {
272        let timeout = Duration::from_secs_f64(timeout_secs);
273        let reminder_interval = Duration::from_secs_f64(reminder_interval_secs);
274        let deadline = std::time::Instant::now() + timeout;
275        let mut last_reminder = std::time::Instant::now();
276
277        loop {
278            // Check if result is available
279            if let Some(result) = self.request_queue.take_result(request_id) {
280                return Ok(result);
281            }
282
283            // Check if request still exists
284            let pending = self.request_queue.get_pending();
285            let request = pending.iter().find(|r| r.id == request_id);
286            if request.is_none() {
287                return Err(Error::MeshError(format!(
288                    "Request {} not found",
289                    request_id
290                )));
291            }
292
293            // Check timeout
294            if std::time::Instant::now() >= deadline {
295                return Err(Error::MeshError(format!(
296                    "Request {} timed out",
297                    request_id
298                )));
299            }
300
301            // Send reminder if interval passed
302            if last_reminder.elapsed() >= reminder_interval {
303                let _ = self.send_reminders(reminder_interval_secs).await;
304                last_reminder = std::time::Instant::now();
305            }
306
307            // Wait briefly
308            tokio::time::sleep(Duration::from_millis(100)).await;
309        }
310    }
311
312    /// Collect all results, blocking until all pending complete
313    pub async fn collect_results(&self, reminder_interval_secs: f64) -> Vec<MeshResult> {
314        let reminder_interval = Duration::from_secs_f64(reminder_interval_secs);
315        let mut last_reminder = std::time::Instant::now();
316        let mut all_results = Vec::new();
317
318        while self.request_queue.has_pending() {
319            // Collect any available results
320            let results = self.request_queue.take_results();
321            all_results.extend(results);
322
323            // Send reminders if interval passed
324            if last_reminder.elapsed() >= reminder_interval {
325                let _ = self.send_reminders(reminder_interval_secs).await;
326                last_reminder = std::time::Instant::now();
327            }
328
329            // Wait briefly before checking again
330            tokio::time::sleep(Duration::from_millis(100)).await;
331        }
332
333        // Get any final results
334        let results = self.request_queue.take_results();
335        all_results.extend(results);
336
337        all_results
338    }
339}