use std::collections::HashMap;
/// The role a swarm participant plays.
///
/// `SwarmConfig::default()` uses [`SwarmRole::Leader`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SwarmRole {
    /// Leader role.
    Leader,
    /// Worker role.
    Worker,
}
/// Settings for one swarm participant.
///
/// See the `Default` impl for the out-of-the-box values.
#[derive(Clone, Debug)]
pub struct SwarmConfig {
    /// Role this participant plays.
    pub role: SwarmRole,
    /// Shared token string; empty by default.
    /// NOTE(review): presumably used to authenticate peers — confirm.
    pub token: String,
    /// Listen address as `host:port` text (default `0.0.0.0:4001`).
    pub listen_addr: String,
    /// Leader address, if any (default `None`).
    /// NOTE(review): presumably only meaningful for workers — confirm.
    pub leader_addr: Option<String>,
    /// Whether mDNS discovery is enabled (default `true`).
    pub mdns_discovery: bool,
}
impl Default for SwarmConfig {
fn default() -> Self {
Self {
role: SwarmRole::Leader,
token: String::new(),
listen_addr: "0.0.0.0:4001".to_string(),
leader_addr: None,
mdns_discovery: true,
}
}
}
/// One member of the swarm as seen by `SwarmManager`.
#[derive(Clone, Debug)]
pub struct SwarmNode {
    /// Identifier; `SwarmManager` uses it as the node's map key.
    pub node_id: String,
    /// Network address as text (e.g. `10.0.0.1:4001`).
    pub address: String,
    /// Role of this node.
    pub role: SwarmRole,
    /// Number of GPUs on the node.
    pub gpu_count: usize,
    /// VRAM of each GPU.
    /// NOTE(review): units presumably bytes (the tests pass byte counts) — confirm.
    pub vram_per_gpu: Vec<u64>,
    /// Layer range `(start, end)` assigned to this node, if any.
    /// NOTE(review): nothing visible here writes this field — confirm who sets it.
    pub assigned_layers: Option<(usize, usize)>,
}
/// Tracks the nodes in a swarm and plans how model layers are split across them.
///
/// Derives `Debug` like the other public swarm types, so a manager can be logged or
/// inspected in tests.
#[derive(Debug)]
pub struct SwarmManager {
    /// Configuration this manager was created with.
    config: SwarmConfig,
    /// Registered nodes, keyed by `SwarmNode::node_id`.
    nodes: HashMap<String, SwarmNode>,
}
impl SwarmManager {
    /// Creates a manager for `config` with no registered nodes.
    pub fn new(config: SwarmConfig) -> Self {
        Self {
            config,
            nodes: HashMap::new(),
        }
    }

    /// Adds `node` to the swarm, keyed by its `node_id`, and logs the registration.
    ///
    /// A node already registered under the same id is replaced.
    pub fn register_node(&mut self, node: SwarmNode) {
        tracing::info!(
            node_id = %node.node_id,
            address = %node.address,
            gpus = node.gpu_count,
            role = ?node.role,
            "Node registered in swarm"
        );
        self.nodes.insert(node.node_id.clone(), node);
    }

    /// Removes the node with `node_id`.
    ///
    /// Logs only if the node was actually present; unknown ids are ignored.
    pub fn remove_node(&mut self, node_id: &str) {
        if self.nodes.remove(node_id).is_some() {
            tracing::info!(node_id = %node_id, "Node removed from swarm");
        }
    }

    /// Splits `total_layers` layers into contiguous `(start, end)` ranges, one per node.
    ///
    /// `start` is inclusive and `end` is exclusive. Each range is sized in proportion to the
    /// node's total VRAM (the sum of `vram_per_gpu`).
    ///
    /// - **Ordering:** nodes are laid out largest-VRAM first. Ties are broken by node id, so the
    ///   plan is deterministic regardless of `HashMap` iteration order.
    /// - **Rounding:** each range boundary is the rounded cumulative VRAM share. Every node's
    ///   layer count therefore differs from its exact proportional share by less than one layer.
    ///   The ranges never overlap and together cover `0..total_layers` exactly. A node whose
    ///   share rounds to nothing gets an empty range (`start == end`).
    /// - **Zero VRAM:** if every node reports zero VRAM, layers are split evenly instead of
    ///   dividing by zero.
    ///
    /// Returns an empty map when no nodes are registered. This method does not modify the
    /// nodes' `assigned_layers` fields.
    pub fn compute_layer_assignment(&self, total_layers: usize) -> HashMap<String, (usize, usize)> {
        let mut ranked: Vec<(&String, u64)> = self
            .nodes
            .iter()
            .map(|(node_id, node)| (node_id, node.vram_per_gpu.iter().sum::<u64>()))
            .collect();
        if ranked.is_empty() {
            return HashMap::new();
        }
        // Node ids are unique map keys, so (VRAM desc, id asc) is a total order.
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));

        // u128 keeps `cumulative * total_layers` exact for any realistic VRAM/layer counts.
        let total_vram: u128 = ranked.iter().map(|&(_, vram)| u128::from(vram)).sum();
        let uniform = total_vram == 0;
        let total_weight: u128 = if uniform {
            ranked.len() as u128
        } else {
            total_vram
        };
        let layers = total_layers as u128;

        let mut assignments = HashMap::with_capacity(ranked.len());
        let mut cumulative: u128 = 0;
        let mut start = 0usize;
        for (node_id, vram) in ranked {
            cumulative += if uniform { 1 } else { u128::from(vram) };
            // end = round(cumulative / total_weight * total_layers), with halves rounded up.
            // Since cumulative <= total_weight, end <= total_layers, so the cast is lossless.
            // The last node has cumulative == total_weight, so its end is exactly total_layers.
            let end = ((cumulative * layers + total_weight / 2) / total_weight) as usize;
            assignments.insert(node_id.clone(), (start, end));
            start = end;
        }
        assignments
    }

    /// Returns the configuration this manager was created with.
    pub fn config(&self) -> &SwarmConfig {
        &self.config
    }

    /// Returns the number of registered nodes.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the sum of `gpu_count` over all registered nodes.
    pub fn total_gpus(&self) -> usize {
        self.nodes.values().map(|n| n.gpu_count).sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a test node whose `gpu_count` equals the length of `vram_per_gpu`.
    fn node(id: &str, address: &str, role: SwarmRole, vram_per_gpu: Vec<u64>) -> SwarmNode {
        SwarmNode {
            node_id: id.into(),
            address: address.into(),
            role,
            gpu_count: vram_per_gpu.len(),
            vram_per_gpu,
            assigned_layers: None,
        }
    }

    /// Asserts that the ranges in `assignments` tile `[0, total_layers)` with no gaps or overlaps.
    fn assert_tiles(assignments: &HashMap<String, (usize, usize)>, total_layers: usize) {
        let mut ranges: Vec<(usize, usize)> = assignments.values().copied().collect();
        ranges.sort_unstable();
        let mut next = 0;
        for (start, end) in ranges {
            assert_eq!(start, next, "gap or overlap in {:?}", assignments);
            assert!(start <= end, "inverted range in {:?}", assignments);
            next = end;
        }
        assert_eq!(next, total_layers, "layers left unassigned in {:?}", assignments);
    }

    #[test]
    fn test_layer_assignment_single_node() {
        let mut mgr = SwarmManager::new(SwarmConfig::default());
        mgr.register_node(node(
            "node-0",
            "127.0.0.1:4001",
            SwarmRole::Leader,
            vec![24 * 1024 * 1024 * 1024],
        ));
        let assignments = mgr.compute_layer_assignment(32);
        assert_eq!(assignments.get("node-0"), Some(&(0, 32)));
    }

    #[test]
    fn test_layer_assignment_two_nodes() {
        let mut mgr = SwarmManager::new(SwarmConfig::default());
        mgr.register_node(node("node-0", "10.0.0.1:4001", SwarmRole::Leader, vec![24_000_000_000]));
        mgr.register_node(node("node-1", "10.0.0.2:4001", SwarmRole::Worker, vec![24_000_000_000]));
        let assignments = mgr.compute_layer_assignment(32);
        let total_assigned: usize = assignments.values().map(|(s, e)| e - s).sum();
        assert_eq!(total_assigned, 32);
        assert_eq!(assignments.len(), 2);
        assert_tiles(&assignments, 32);
    }

    #[test]
    fn test_layer_assignment_no_nodes() {
        let mgr = SwarmManager::new(SwarmConfig::default());
        assert!(mgr.compute_layer_assignment(32).is_empty());
    }

    #[test]
    fn test_layer_assignment_uneven_vram_tiles_all_layers() {
        let mut mgr = SwarmManager::new(SwarmConfig::default());
        for (i, gb) in [3u64, 3, 3, 1].iter().enumerate() {
            let id = format!("node-{}", i);
            mgr.register_node(node(&id, "10.0.0.1:4001", SwarmRole::Worker, vec![gb * 1_000_000_000]));
        }
        let assignments = mgr.compute_layer_assignment(5);
        assert_eq!(assignments.len(), 4);
        assert_tiles(&assignments, 5);
    }

    #[test]
    fn test_layer_assignment_zero_vram_tiles_all_layers() {
        let mut mgr = SwarmManager::new(SwarmConfig::default());
        for i in 0..3 {
            let id = format!("node-{}", i);
            mgr.register_node(node(&id, "10.0.0.1:4001", SwarmRole::Worker, vec![0]));
        }
        let assignments = mgr.compute_layer_assignment(8);
        assert_eq!(assignments.len(), 3);
        assert_tiles(&assignments, 8);
    }

    #[test]
    fn test_node_bookkeeping() {
        let mut mgr = SwarmManager::new(SwarmConfig::default());
        mgr.register_node(node("node-0", "10.0.0.1:4001", SwarmRole::Leader, vec![1, 1]));
        mgr.register_node(node("node-1", "10.0.0.2:4001", SwarmRole::Worker, vec![1]));
        assert_eq!(mgr.node_count(), 2);
        assert_eq!(mgr.total_gpus(), 3);
        mgr.remove_node("node-0");
        // Removing an unknown id is a no-op.
        mgr.remove_node("missing");
        assert_eq!(mgr.node_count(), 1);
        assert_eq!(mgr.total_gpus(), 1);
    }
}