Skip to main content

phago_distributed/
runner.rs

1//! Async distributed tick runner.
2//!
3//! Coordinates tick execution across multiple shards with proper
4//! phase synchronization and cross-shard edge resolution.
5
6use crate::coordinator::Coordinator;
7use crate::shard::ShardedColony;
8use crate::types::*;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
/// Tunable settings for a distributed runner.
///
/// NOTE(review): within this module only `resolve_ghosts` is read by the
/// runner itself; `phase_timeout_ms` and `max_parallelism` are presumably
/// consumed elsewhere — confirm before relying on them.
#[derive(Debug, Clone)]
pub struct RunnerConfig {
    /// Per-phase timeout, expressed in milliseconds.
    pub phase_timeout_ms: u64,
    /// When `true`, ghost nodes are resolved at the end of every tick that
    /// produced at least one cross-shard edge.
    pub resolve_ghosts: bool,
    /// Upper bound on the number of operations run in parallel.
    pub max_parallelism: usize,
}

impl Default for RunnerConfig {
    /// Defaults: 30 second phase timeout, ghost resolution enabled,
    /// parallelism of 8.
    fn default() -> Self {
        const DEFAULT_PHASE_TIMEOUT_MS: u64 = 30_000;
        const DEFAULT_RESOLVE_GHOSTS: bool = true;
        const DEFAULT_MAX_PARALLELISM: usize = 8;

        RunnerConfig {
            phase_timeout_ms: DEFAULT_PHASE_TIMEOUT_MS,
            resolve_ghosts: DEFAULT_RESOLVE_GHOSTS,
            max_parallelism: DEFAULT_MAX_PARALLELISM,
        }
    }
}
32
33/// Orchestrates distributed tick execution.
34///
35/// The `DistributedRunner` coordinates tick execution across multiple shards,
36/// ensuring proper phase synchronization via barriers. Each tick consists of
37/// four phases (Sense, Act, Decay, Advance) that are executed in order across
38/// all shards before moving to the next phase.
39///
40/// # Architecture
41///
42/// The runner uses a coordinator for global synchronization and maintains
43/// references to all shard instances. During each tick:
44///
45/// 1. **Sense Phase**: All shards prepare agent decisions (read-only)
46/// 2. **Act Phase**: All shards execute agent actions (write operations)
47/// 3. **Decay Phase**: All shards decay signals, traces, and edges
48/// 4. **Advance Phase**: Coordinator advances the global tick counter
49///
50/// After the Act phase, any cross-shard edges are collected and ghost nodes
51/// are resolved if configured.
52///
53/// # Example
54///
55/// ```ignore
56/// use phago_distributed::runner::{DistributedRunner, RunnerConfig};
57///
58/// let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());
59///
60/// // Run a single tick
61/// let result = runner.tick().await?;
62/// println!("Completed tick {}", result.tick);
63///
64/// // Run multiple ticks
65/// let results = runner.run(10).await?;
66/// ```
67pub struct DistributedRunner {
68    coordinator: Arc<Coordinator>,
69    shards: Vec<Arc<RwLock<ShardedColony>>>,
70    config: RunnerConfig,
71}
72
73impl DistributedRunner {
74    /// Create a new distributed runner.
75    ///
76    /// # Arguments
77    ///
78    /// * `coordinator` - The coordinator for global synchronization
79    /// * `shards` - Vector of shard instances wrapped in Arc<RwLock<_>>
80    /// * `config` - Runner configuration
81    pub fn new(
82        coordinator: Arc<Coordinator>,
83        shards: Vec<Arc<RwLock<ShardedColony>>>,
84        config: RunnerConfig,
85    ) -> Self {
86        Self {
87            coordinator,
88            shards,
89            config,
90        }
91    }
92
93    /// Run a single distributed tick.
94    ///
95    /// Executes all four phases (Sense, Act, Decay, Advance) with
96    /// barrier synchronization between each phase.
97    ///
98    /// # Returns
99    ///
100    /// A `DistributedTickResult` containing the new tick number, phase results,
101    /// and any cross-shard edges that were created.
102    ///
103    /// # Errors
104    ///
105    /// Returns a `DistributedError` if:
106    /// - Phase synchronization times out
107    /// - Cross-shard edge resolution fails
108    pub async fn tick(&self) -> DistributedResult<DistributedTickResult> {
109        let tick = self.coordinator.current_tick();
110        let mut phase_results = Vec::new();
111        let mut all_cross_edges = Vec::new();
112
113        // Phase 1: Sense
114        let sense_results = self.run_phase(TickPhase::Sense, tick).await?;
115        phase_results.extend(sense_results);
116
117        // Phase 2: Act
118        let act_results = self.run_phase(TickPhase::Act, tick).await?;
119        for result in &act_results {
120            all_cross_edges.extend(result.cross_shard_edges.clone());
121        }
122        phase_results.extend(act_results);
123
124        // Phase 3: Decay
125        let decay_results = self.run_phase(TickPhase::Decay, tick).await?;
126        phase_results.extend(decay_results);
127
128        // Phase 4: Advance
129        let new_tick = self.coordinator.advance_tick().await;
130
131        // Resolve ghost nodes if configured
132        if self.config.resolve_ghosts && !all_cross_edges.is_empty() {
133            self.resolve_cross_shard_edges(&all_cross_edges).await?;
134        }
135
136        Ok(DistributedTickResult {
137            tick: new_tick,
138            phase_results,
139            cross_shard_edges: all_cross_edges,
140        })
141    }
142
143    /// Run multiple ticks.
144    ///
145    /// # Arguments
146    ///
147    /// * `num_ticks` - Number of ticks to execute
148    ///
149    /// # Returns
150    ///
151    /// A vector of `DistributedTickResult` for each tick executed.
152    pub async fn run(&self, num_ticks: u64) -> DistributedResult<Vec<DistributedTickResult>> {
153        let mut results = Vec::with_capacity(num_ticks as usize);
154        for _ in 0..num_ticks {
155            results.push(self.tick().await?);
156        }
157        Ok(results)
158    }
159
160    /// Execute a single phase across all shards.
161    ///
162    /// Runs the specified phase on all shards in parallel, then waits
163    /// for all shards to complete before returning.
164    async fn run_phase(&self, phase: TickPhase, tick: u64) -> DistributedResult<Vec<PhaseResult>> {
165        use futures::future::join_all;
166
167        // Execute phase on all shards in parallel
168        let futures: Vec<_> = self
169            .shards
170            .iter()
171            .map(|shard| {
172                let shard = shard.clone();
173                async move {
174                    let mut s = shard.write().await;
175                    s.tick_phase(phase)
176                }
177            })
178            .collect();
179
180        let results = join_all(futures).await;
181
182        // Signal phase completion to coordinator
183        for result in &results {
184            self.coordinator
185                .phase_complete(result.shard_id, phase, tick)
186                .await?;
187        }
188
189        // Wait for all shards to complete
190        self.coordinator.wait_for_phase(phase, tick).await?;
191
192        Ok(results)
193    }
194
195    /// Resolve cross-shard edges by fetching ghost nodes.
196    ///
197    /// For each cross-shard edge, fetches the target node's data from
198    /// the owning shard and caches it as a ghost node in the requesting shard.
199    async fn resolve_cross_shard_edges(&self, edges: &[CrossShardEdge]) -> DistributedResult<()> {
200        use std::collections::HashMap;
201
202        // Group edges by target shard
203        let mut by_shard: HashMap<ShardId, Vec<&CrossShardEdge>> = HashMap::new();
204        for edge in edges {
205            by_shard.entry(edge.to_shard).or_default().push(edge);
206        }
207
208        // Fetch ghost nodes from each shard
209        for (shard_id, shard_edges) in by_shard {
210            let node_ids: Vec<_> = shard_edges.iter().map(|e| e.to_node).collect();
211
212            // Find the shard and fetch nodes
213            for shard in &self.shards {
214                let s = shard.read().await;
215                if s.shard_id() == shard_id {
216                    for node_id in &node_ids {
217                        if let Some(node_data) = s.get_node(node_id) {
218                            // Cache the ghost node in requesting shards
219                            for requesting_edge in
220                                shard_edges.iter().filter(|e| e.to_node == *node_id)
221                            {
222                                // Find requesting shard and update its ghost cache
223                                for req_shard in &self.shards {
224                                    let mut req = req_shard.write().await;
225                                    // Check if this shard has the from_node
226                                    if req.get_node(&requesting_edge.from_node).is_some() {
227                                        let ghost = GhostNode::new(
228                                            *node_id,
229                                            shard_id,
230                                            node_data.label.clone(),
231                                        );
232                                        req.ghost_cache_mut().insert(ghost);
233                                    }
234                                }
235                            }
236                        }
237                    }
238                    break;
239                }
240            }
241        }
242
243        Ok(())
244    }
245
246    /// Get the coordinator.
247    pub fn coordinator(&self) -> &Arc<Coordinator> {
248        &self.coordinator
249    }
250
251    /// Get shard count.
252    pub fn shard_count(&self) -> usize {
253        self.shards.len()
254    }
255
256    /// Get a reference to all shards.
257    pub fn shards(&self) -> &[Arc<RwLock<ShardedColony>>] {
258        &self.shards
259    }
260
261    /// Get runner configuration.
262    pub fn config(&self) -> &RunnerConfig {
263        &self.config
264    }
265}
266
267/// Result of a distributed tick.
268#[derive(Debug, Clone)]
269pub struct DistributedTickResult {
270    /// The tick number after completion.
271    pub tick: u64,
272    /// Results from each phase.
273    pub phase_results: Vec<PhaseResult>,
274    /// Cross-shard edges created this tick.
275    pub cross_shard_edges: Vec<CrossShardEdge>,
276}
277
278impl DistributedTickResult {
279    /// Get the total number of nodes across all shards after this tick.
280    pub fn total_nodes(&self) -> usize {
281        // Get the node count from the last phase result for each shard
282        let mut shard_counts: std::collections::HashMap<ShardId, usize> =
283            std::collections::HashMap::new();
284        for result in &self.phase_results {
285            shard_counts.insert(result.shard_id, result.node_count);
286        }
287        shard_counts.values().sum()
288    }
289
290    /// Get the total number of edges across all shards after this tick.
291    pub fn total_edges(&self) -> usize {
292        let mut shard_counts: std::collections::HashMap<ShardId, usize> =
293            std::collections::HashMap::new();
294        for result in &self.phase_results {
295            shard_counts.insert(result.shard_id, result.edge_count);
296        }
297        shard_counts.values().sum()
298    }
299
300    /// Check if any cross-shard communication occurred this tick.
301    pub fn has_cross_shard_activity(&self) -> bool {
302        !self.cross_shard_edges.is_empty()
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashing::ConsistentHashRing;
    use phago_runtime::colony::ColonyConfig;

    /// Build a coordinator plus `num_shards` shards that share one hash ring.
    fn create_test_cluster(num_shards: u32) -> (Arc<Coordinator>, Vec<Arc<RwLock<ShardedColony>>>) {
        let coordinator = Arc::new(Coordinator::new(num_shards));
        let hash_ring = Arc::new(RwLock::new(ConsistentHashRing::new(num_shards)));

        let shards: Vec<_> = (0..num_shards)
            .map(|i| {
                Arc::new(RwLock::new(ShardedColony::new(
                    ShardId::new(i),
                    ColonyConfig::default(),
                    hash_ring.clone(),
                )))
            })
            .collect();

        (coordinator, shards)
    }

    /// A freshly built runner exposes the shard count and default config.
    #[tokio::test]
    async fn test_runner_creation() {
        let (coordinator, shards) = create_test_cluster(3);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());

        assert_eq!(runner.shard_count(), 3);
        assert_eq!(runner.config().phase_timeout_ms, 30_000);
        assert!(runner.config().resolve_ghosts);
    }

    /// One tick advances the counter to 1 and yields one result per phase per shard.
    #[tokio::test]
    async fn test_single_tick() {
        const NUM_SHARDS: u32 = 3;
        // Sense, Act and Decay each produce one PhaseResult per shard;
        // the Advance phase does not produce a PhaseResult.
        const PHASES_WITH_RESULTS: usize = 3;

        let (coordinator, shards) = create_test_cluster(NUM_SHARDS);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());

        let result = runner.tick().await.unwrap();
        assert_eq!(result.tick, 1);
        assert_eq!(
            result.phase_results.len(),
            PHASES_WITH_RESULTS * NUM_SHARDS as usize
        );
    }

    /// `run(n)` returns `n` results and ends on tick `n`.
    #[tokio::test]
    async fn test_multiple_ticks() {
        let (coordinator, shards) = create_test_cluster(2);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());

        let results = runner.run(5).await.unwrap();
        assert_eq!(results.len(), 5);
        assert_eq!(results.last().unwrap().tick, 5);
    }

    /// The aggregate helpers on the tick result are callable on a basic cluster.
    #[tokio::test]
    async fn test_tick_result_methods() {
        let (coordinator, shards) = create_test_cluster(2);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());

        let result = runner.tick().await.unwrap();

        // These should work even with empty graphs
        let _ = result.total_nodes();
        let _ = result.total_edges();
        assert!(!result.has_cross_shard_activity()); // No cross-shard edges in basic test
    }

    /// A custom configuration is stored and returned unchanged.
    #[tokio::test]
    async fn test_config_custom() {
        let config = RunnerConfig {
            phase_timeout_ms: 5_000,
            resolve_ghosts: false,
            max_parallelism: 4,
        };

        let (coordinator, shards) = create_test_cluster(2);
        let runner = DistributedRunner::new(coordinator, shards, config);

        assert_eq!(runner.config().phase_timeout_ms, 5_000);
        assert!(!runner.config().resolve_ghosts);
        assert_eq!(runner.config().max_parallelism, 4);
    }

    /// Tick numbers increase by one per tick over a longer run.
    ///
    /// Despite its name this test runs the ticks sequentially: concurrent
    /// tick execution would require additional synchronization which the
    /// runner doesn't currently support.
    #[tokio::test]
    async fn test_concurrent_ticks() {
        let (coordinator, shards) = create_test_cluster(4);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());

        let results = runner.run(10).await.unwrap();
        assert_eq!(results.len(), 10);

        // Verify tick numbers are sequential
        for (i, result) in results.iter().enumerate() {
            assert_eq!(result.tick, (i + 1) as u64);
        }
    }
}