use std::sync::atomic::{AtomicUsize, Ordering};
use super::shard_manager::ShardLocation;
/// Static settings for a [`Coordinator`].
///
/// Build one with struct-update syntax over `Default::default()` to override
/// only the fields you care about.
#[derive(Debug, Clone)]
pub struct CoordinatorConfig {
    /// Number of devices appended tokens are spread over; used as the
    /// round-robin modulus when routing appends.
    pub num_devices: usize,
    /// When `true`, attention plans visit shards sorted by `token_start`;
    /// otherwise shards are visited in the order they were supplied.
    pub strict_ordering: bool,
    /// Timeout value; presumably milliseconds given the name — it is not read
    /// anywhere in this module (NOTE(review): confirm its consumer).
    pub timeout_ms: u64,
}
impl Default for CoordinatorConfig {
fn default() -> Self {
Self {
num_devices: 1,
strict_ordering: true,
timeout_ms: 30_000,
}
}
}
/// Activity counters for a coordinator.
///
/// Every field is atomic, so the counters can be updated through a shared
/// reference; `Default` starts them all at zero.
#[derive(Debug, Default)]
pub struct CoordinatorStats {
    /// Count of recorded append operations.
    pub appends: AtomicUsize,
    /// Count of recorded attention operations.
    pub attentions: AtomicUsize,
    /// Count of recorded transfers.
    pub transfers: AtomicUsize,
    /// Number of sequences currently registered.
    pub active_sequences: AtomicUsize,
}
impl CoordinatorStats {
    /// Creates a stats block with every counter at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Counts one append operation.
    pub fn record_append(&self) {
        self.appends.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one attention operation.
    pub fn record_attention(&self) {
        self.attentions.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one transfer.
    pub fn record_transfer(&self) {
        self.transfers.fetch_add(1, Ordering::Relaxed);
    }

    /// Marks one more sequence as active.
    pub fn increment_sequences(&self) {
        self.active_sequences.fetch_add(1, Ordering::Relaxed);
    }

    /// Marks one sequence as no longer active.
    ///
    /// The counter saturates at zero. An unmatched release (more releases than
    /// registrations) leaves it at `0` instead of wrapping around to
    /// `usize::MAX`, which a plain `fetch_sub` would do.
    pub fn decrement_sequences(&self) {
        // `checked_sub` yields `None` at zero, which makes `fetch_update` leave
        // the value untouched and return `Err`; that case is intentionally ignored.
        let _ = self
            .active_sequences
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |n| n.checked_sub(1));
    }

    /// Copies the current counter values into a plain snapshot.
    ///
    /// Each counter is loaded independently with relaxed ordering. Under
    /// concurrent updates, the snapshot therefore is not a single consistent
    /// instant across all fields.
    pub fn snapshot(&self) -> CoordinatorStatsSnapshot {
        CoordinatorStatsSnapshot {
            appends: self.appends.load(Ordering::Relaxed),
            attentions: self.attentions.load(Ordering::Relaxed),
            transfers: self.transfers.load(Ordering::Relaxed),
            active_sequences: self.active_sequences.load(Ordering::Relaxed),
        }
    }
}
/// Plain, non-atomic copy of the coordinator counters at one point in time.
///
/// All fields are `usize`, so the snapshot is cheap to copy. It can be compared
/// with `==`, which is handy for tests and for change detection between two
/// reads.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CoordinatorStatsSnapshot {
    /// Number of recorded appends.
    pub appends: usize,
    /// Number of recorded attention operations.
    pub attentions: usize,
    /// Number of recorded transfers.
    pub transfers: usize,
    /// Number of sequences active when the snapshot was taken.
    pub active_sequences: usize,
}
/// Routes token appends to devices and decides the order in which shards are
/// visited for attention.
pub struct Coordinator {
    /// Settings supplied at construction; only exposed by shared reference.
    config: CoordinatorConfig,
    /// Atomic activity counters, updatable through `&self`.
    stats: CoordinatorStats,
}
impl Coordinator {
    /// Builds a coordinator from `config` with all statistics zeroed.
    pub fn new(config: CoordinatorConfig) -> Self {
        Self {
            config,
            stats: CoordinatorStats::new(),
        }
    }

    /// Returns the configuration this coordinator was built with.
    pub fn config(&self) -> &CoordinatorConfig {
        &self.config
    }

    /// Returns the live statistics counters.
    pub fn stats(&self) -> &CoordinatorStats {
        &self.stats
    }

    /// Picks the device that should receive a token appended at position
    /// `current_tokens`.
    ///
    /// Tokens are grouped into shards of `shard_size` tokens. Shard
    /// `current_tokens / shard_size` goes to device `shard % num_devices`
    /// (round-robin).
    ///
    /// Returns `0` when `num_devices` is `0` or `1`, or when `shard_size` is
    /// `0`. These degenerate cases would otherwise divide or take a remainder
    /// by zero and panic.
    pub fn route_append(&self, current_tokens: usize, shard_size: usize) -> usize {
        let devices = self.config.num_devices;
        if devices <= 1 || shard_size == 0 {
            return 0;
        }
        (current_tokens / shard_size) % devices
    }

    /// Returns references to `shards` in the order attention should visit them.
    ///
    /// With `strict_ordering` enabled, shards are sorted by `token_start`. The
    /// sort is stable, so shards with equal starts keep their input order.
    /// Otherwise the input order is returned unchanged.
    pub fn plan_attention<'a>(&self, shards: &'a [ShardLocation]) -> Vec<&'a ShardLocation> {
        let mut ordered: Vec<&'a ShardLocation> = shards.iter().collect();
        if self.config.strict_ordering {
            ordered.sort_by_key(|s| s.token_start);
        }
        ordered
    }

    /// Reports whether `_shard` must be moved to `_local_device` before use.
    ///
    /// Currently a placeholder: always returns `false`.
    pub fn needs_transfer(&self, _shard: &ShardLocation, _local_device: usize) -> bool {
        false
    }

    /// Registers a new active sequence and returns its id.
    ///
    /// Only the active-sequence counter is tracked; the returned id is
    /// currently always `0`.
    pub fn register_sequence(&self) -> usize {
        self.stats.increment_sequences();
        0
    }

    /// Releases a sequence previously returned by `register_sequence`.
    ///
    /// `_seq_id` is currently ignored; only the active-sequence counter is
    /// decremented.
    pub fn release_sequence(&self, _seq_id: usize) {
        self.stats.decrement_sequences();
    }
}
/// Strategy recorded on an aggregation plan.
///
/// This module only stores and returns the variant. How each strategy is
/// carried out is up to the code that consumes the plan.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum AggregationStrategy {
    /// The default strategy.
    #[default]
    Sequential,
    Parallel,
    Hierarchical,
}
/// A set of shards together with the strategy and ordering used to combine them.
pub struct AggregationPlan {
    /// Shards covered by the plan, stored in the order the caller supplied.
    shards: Vec<ShardLocation>,
    /// Strategy reported back by the plan.
    strategy: AggregationStrategy,
    /// When `true`, ordered views of the shards are sorted by `token_start`.
    strict_order: bool,
}
impl AggregationPlan {
pub fn new(shards: Vec<ShardLocation>) -> Self {
Self {
shards,
strategy: AggregationStrategy::default(),
strict_order: true,
}
}
pub fn with_strategy(mut self, strategy: AggregationStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn with_strict_order(mut self, strict: bool) -> Self {
self.strict_order = strict;
self
}
pub fn ordered_shards(&self) -> Vec<&ShardLocation> {
let mut shards: Vec<_> = self.shards.iter().collect();
if self.strict_order {
shards.sort_by_key(|s| s.token_start);
}
shards
}
pub fn strategy(&self) -> AggregationStrategy {
self.strategy
}
pub fn num_shards(&self) -> usize {
self.shards.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_coordinator_routing() {
        let coord = Coordinator::new(CoordinatorConfig {
            num_devices: 4,
            ..Default::default()
        });
        // 100-token shards are dealt round-robin across 4 devices, so token
        // 400 (shard 4) wraps back to device 0.
        let cases = [(0, 0), (100, 1), (200, 2), (400, 0)];
        for &(tokens, device) in cases.iter() {
            assert_eq!(coord.route_append(tokens, 100), device);
        }
    }

    #[test]
    fn test_coordinator_stats() {
        let coord = Coordinator::new(CoordinatorConfig::default());
        let stats = coord.stats();
        for _ in 0..2 {
            stats.record_append();
        }
        stats.record_attention();

        let snapshot = stats.snapshot();
        assert_eq!(snapshot.appends, 2);
        assert_eq!(snapshot.attentions, 1);
    }

    #[test]
    fn test_aggregation_plan() {
        // Deliberately out of order so the default strict ordering must sort.
        let shards = vec![
            ShardLocation {
                token_start: 200,
                token_end: 300,
                device_id: 1,
                local_shard_id: 1,
            },
            ShardLocation {
                token_start: 0,
                token_end: 100,
                device_id: 0,
                local_shard_id: 0,
            },
            ShardLocation {
                token_start: 100,
                token_end: 200,
                device_id: 0,
                local_shard_id: 1,
            },
        ];
        let plan = AggregationPlan::new(shards);
        let starts: Vec<_> = plan
            .ordered_shards()
            .iter()
            .map(|shard| shard.token_start)
            .collect();
        assert_eq!(starts, vec![0, 100, 200]);
    }
}