use crate::coordinator::Coordinator;
use crate::shard::ShardedColony;
use crate::types::*;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Tuning knobs for [`DistributedRunner`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RunnerConfig {
    /// Intended timeout for a single tick phase, in milliseconds.
    ///
    /// NOTE(review): `DistributedRunner` does not currently read this value.
    pub phase_timeout_ms: u64,
    /// When `true`, cross-shard edges collected from the Act phase are
    /// resolved into ghost nodes on the requesting shards after each tick.
    pub resolve_ghosts: bool,
    /// Intended upper bound on concurrent shard work.
    ///
    /// NOTE(review): `DistributedRunner` does not currently read this value.
    pub max_parallelism: usize,
}
impl Default for RunnerConfig {
fn default() -> Self {
Self {
phase_timeout_ms: 30_000,
resolve_ghosts: true,
max_parallelism: 8,
}
}
}
/// Drives a set of shard colonies through the tick phases, using a shared
/// [`Coordinator`] for tick numbering and per-phase completion tracking.
pub struct DistributedRunner {
    // Source of the current tick; also receives per-shard phase completions
    // and is awaited for each phase.
    coordinator: Arc<Coordinator>,
    // Shard colonies; each one is write-locked while a phase runs on it.
    shards: Vec<Arc<RwLock<ShardedColony>>>,
    // Runtime options (see `RunnerConfig`).
    config: RunnerConfig,
}
impl DistributedRunner {
pub fn new(
coordinator: Arc<Coordinator>,
shards: Vec<Arc<RwLock<ShardedColony>>>,
config: RunnerConfig,
) -> Self {
Self {
coordinator,
shards,
config,
}
}
pub async fn tick(&self) -> DistributedResult<DistributedTickResult> {
let tick = self.coordinator.current_tick();
let mut phase_results = Vec::new();
let mut all_cross_edges = Vec::new();
let sense_results = self.run_phase(TickPhase::Sense, tick).await?;
phase_results.extend(sense_results);
let act_results = self.run_phase(TickPhase::Act, tick).await?;
for result in &act_results {
all_cross_edges.extend(result.cross_shard_edges.clone());
}
phase_results.extend(act_results);
let decay_results = self.run_phase(TickPhase::Decay, tick).await?;
phase_results.extend(decay_results);
let new_tick = self.coordinator.advance_tick().await;
if self.config.resolve_ghosts && !all_cross_edges.is_empty() {
self.resolve_cross_shard_edges(&all_cross_edges).await?;
}
Ok(DistributedTickResult {
tick: new_tick,
phase_results,
cross_shard_edges: all_cross_edges,
})
}
pub async fn run(&self, num_ticks: u64) -> DistributedResult<Vec<DistributedTickResult>> {
let mut results = Vec::with_capacity(num_ticks as usize);
for _ in 0..num_ticks {
results.push(self.tick().await?);
}
Ok(results)
}
async fn run_phase(&self, phase: TickPhase, tick: u64) -> DistributedResult<Vec<PhaseResult>> {
use futures::future::join_all;
let futures: Vec<_> = self
.shards
.iter()
.map(|shard| {
let shard = shard.clone();
async move {
let mut s = shard.write().await;
s.tick_phase(phase)
}
})
.collect();
let results = join_all(futures).await;
for result in &results {
self.coordinator
.phase_complete(result.shard_id, phase, tick)
.await?;
}
self.coordinator.wait_for_phase(phase, tick).await?;
Ok(results)
}
async fn resolve_cross_shard_edges(&self, edges: &[CrossShardEdge]) -> DistributedResult<()> {
use std::collections::HashMap;
let mut by_shard: HashMap<ShardId, Vec<&CrossShardEdge>> = HashMap::new();
for edge in edges {
by_shard.entry(edge.to_shard).or_default().push(edge);
}
for (shard_id, shard_edges) in by_shard {
let node_ids: Vec<_> = shard_edges.iter().map(|e| e.to_node).collect();
for shard in &self.shards {
let s = shard.read().await;
if s.shard_id() == shard_id {
for node_id in &node_ids {
if let Some(node_data) = s.get_node(node_id) {
for requesting_edge in
shard_edges.iter().filter(|e| e.to_node == *node_id)
{
for req_shard in &self.shards {
let mut req = req_shard.write().await;
if req.get_node(&requesting_edge.from_node).is_some() {
let ghost = GhostNode::new(
*node_id,
shard_id,
node_data.label.clone(),
);
req.ghost_cache_mut().insert(ghost);
}
}
}
}
}
break;
}
}
}
Ok(())
}
pub fn coordinator(&self) -> &Arc<Coordinator> {
&self.coordinator
}
pub fn shard_count(&self) -> usize {
self.shards.len()
}
pub fn shards(&self) -> &[Arc<RwLock<ShardedColony>>] {
&self.shards
}
pub fn config(&self) -> &RunnerConfig {
&self.config
}
}
/// Outcome of one `DistributedRunner::tick` call.
#[derive(Debug, Clone)]
pub struct DistributedTickResult {
    /// Tick value returned by the coordinator after advancing.
    pub tick: u64,
    /// As filled by `DistributedRunner::tick`: one entry per shard per
    /// phase, appended in phase order (Sense, Act, Decay).
    pub phase_results: Vec<PhaseResult>,
    /// As filled by `DistributedRunner::tick`: the cross-shard edges
    /// collected from the Act phase.
    pub cross_shard_edges: Vec<CrossShardEdge>,
}
impl DistributedTickResult {
pub fn total_nodes(&self) -> usize {
let mut shard_counts: std::collections::HashMap<ShardId, usize> =
std::collections::HashMap::new();
for result in &self.phase_results {
shard_counts.insert(result.shard_id, result.node_count);
}
shard_counts.values().sum()
}
pub fn total_edges(&self) -> usize {
let mut shard_counts: std::collections::HashMap<ShardId, usize> =
std::collections::HashMap::new();
for result in &self.phase_results {
shard_counts.insert(result.shard_id, result.edge_count);
}
shard_counts.values().sum()
}
pub fn has_cross_shard_activity(&self) -> bool {
!self.cross_shard_edges.is_empty()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hashing::ConsistentHashRing;
    use phago_runtime::colony::ColonyConfig;

    /// Builds a coordinator and `num_shards` default colonies that all share
    /// one consistent-hash ring.
    fn create_test_cluster(num_shards: u32) -> (Arc<Coordinator>, Vec<Arc<RwLock<ShardedColony>>>) {
        let coordinator = Arc::new(Coordinator::new(num_shards));
        let ring = Arc::new(RwLock::new(ConsistentHashRing::new(num_shards)));
        let mut shards = Vec::new();
        for index in 0..num_shards {
            let colony = ShardedColony::new(
                ShardId::new(index),
                ColonyConfig::default(),
                Arc::clone(&ring),
            );
            shards.push(Arc::new(RwLock::new(colony)));
        }
        (coordinator, shards)
    }

    #[tokio::test]
    async fn test_runner_creation() {
        let (coordinator, shards) = create_test_cluster(3);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());
        let config = runner.config();
        assert_eq!(runner.shard_count(), 3);
        assert_eq!(config.phase_timeout_ms, 30_000);
        assert!(config.resolve_ghosts);
    }

    #[tokio::test]
    async fn test_single_tick() {
        let (coordinator, shards) = create_test_cluster(3);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());
        let outcome = runner.tick().await.expect("tick should succeed");
        assert_eq!(outcome.tick, 1);
        assert!(!outcome.phase_results.is_empty());
    }

    #[tokio::test]
    async fn test_multiple_ticks() {
        let (coordinator, shards) = create_test_cluster(2);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());
        let history = runner.run(5).await.expect("five ticks should succeed");
        assert_eq!(history.len(), 5);
        assert_eq!(history.last().map(|result| result.tick), Some(5));
    }

    #[tokio::test]
    async fn test_tick_result_methods() {
        let (coordinator, shards) = create_test_cluster(2);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());
        let outcome = runner.tick().await.expect("tick should succeed");
        // Smoke-check that the aggregate helpers run. Their exact values
        // depend on what a default colony contains.
        let _ = outcome.total_nodes();
        let _ = outcome.total_edges();
        assert!(!outcome.has_cross_shard_activity());
    }

    #[tokio::test]
    async fn test_config_custom() {
        let (coordinator, shards) = create_test_cluster(2);
        let custom = RunnerConfig {
            phase_timeout_ms: 5_000,
            resolve_ghosts: false,
            max_parallelism: 4,
        };
        let runner = DistributedRunner::new(coordinator, shards, custom);
        let config = runner.config();
        assert_eq!(config.phase_timeout_ms, 5_000);
        assert!(!config.resolve_ghosts);
        assert_eq!(config.max_parallelism, 4);
    }

    #[tokio::test]
    async fn test_concurrent_ticks() {
        // Ticks themselves run one after another via `run`. Within each
        // phase, the four shards' work is joined concurrently by
        // `run_phase`. Tick numbers must come out as 1..=10.
        let (coordinator, shards) = create_test_cluster(4);
        let runner = DistributedRunner::new(coordinator, shards, RunnerConfig::default());
        let history = runner.run(10).await.expect("ten ticks should succeed");
        for (expected, result) in (1u64..).zip(&history) {
            assert_eq!(result.tick, expected);
        }
    }
}