use std::ops::Range;
/// Parameters controlling how a token sequence is partitioned into shards.
#[derive(Debug, Clone)]
pub struct ShardConfig {
    /// Maximum number of tokens a single shard may hold.
    pub shard_size: usize,
    /// Number of devices the shards are distributed over.
    pub num_devices: usize,
    /// Upper bound on the total number of tokens across all shards.
    pub max_tokens: usize,
}

impl Default for ShardConfig {
    /// Single-device configuration: 100 000-token shards, 2 000 000 tokens in total.
    fn default() -> Self {
        Self {
            shard_size: 100_000,
            num_devices: 1,
            max_tokens: 2_000_000,
        }
    }
}

impl ShardConfig {
    /// Smallest shard size that [`ShardConfig::multi_gpu`] will pick.
    const MIN_MULTI_GPU_SHARD_SIZE: usize = 10_000;

    /// Builds a one-device configuration whose only shard spans all
    /// `max_tokens` tokens.
    pub fn single_gpu(max_tokens: usize) -> Self {
        Self {
            shard_size: max_tokens,
            num_devices: 1,
            max_tokens,
        }
    }

    /// Builds a configuration that splits `max_tokens` evenly over
    /// `num_devices` devices, never choosing a shard smaller than 10 000 tokens.
    ///
    /// A `num_devices` of zero is treated as one device; dividing by it
    /// would otherwise panic and the resulting config would be unusable.
    pub fn multi_gpu(num_devices: usize, max_tokens: usize) -> Self {
        let num_devices = num_devices.max(1);
        let shard_size = (max_tokens / num_devices).max(Self::MIN_MULTI_GPU_SHARD_SIZE);
        Self {
            shard_size,
            num_devices,
            max_tokens,
        }
    }

    /// Returns how many shards are needed to hold `num_tokens` tokens
    /// (`num_tokens / shard_size`, rounded up).
    ///
    /// Returns 0 when `shard_size` is 0, since no shard can hold anything.
    pub fn num_shards(&self, num_tokens: usize) -> usize {
        if self.shard_size == 0 {
            return 0;
        }
        // Quotient-plus-remainder form of ceiling division: unlike
        // `(n + size - 1) / size` it cannot overflow for large `n`.
        let full = num_tokens / self.shard_size;
        let partial = usize::from(num_tokens % self.shard_size != 0);
        full + partial
    }
}
/// Half-open token span `[token_start, token_end)` together with where it is placed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShardLocation {
    /// Index of the first token in the span (inclusive).
    pub token_start: usize,
    /// Index one past the last token in the span (exclusive).
    pub token_end: usize,
    /// Index of the device the span is assigned to.
    pub device_id: usize,
    /// Ordinal of the shard among the shards on the same device.
    pub local_shard_id: usize,
}

impl ShardLocation {
    /// Returns `true` when `token_idx` lies inside `[token_start, token_end)`.
    pub fn contains(&self, token_idx: usize) -> bool {
        self.range().contains(&token_idx)
    }

    /// Returns the span as a standard half-open range.
    pub fn range(&self) -> Range<usize> {
        self.token_start..self.token_end
    }

    /// Returns the number of tokens in the span.
    ///
    /// An inverted span (`token_end < token_start`) has length 0, matching
    /// what [`ShardLocation::is_empty`] reports, instead of underflowing.
    pub fn len(&self) -> usize {
        self.token_end.saturating_sub(self.token_start)
    }

    /// Returns `true` when the span holds no tokens.
    pub fn is_empty(&self) -> bool {
        self.token_start >= self.token_end
    }
}
/// Tracks which token spans have been allocated and where each one lives.
///
/// Shards are stored in ascending token order; `current_tokens` is the total
/// number of tokens handed out so far.
#[derive(Debug)]
pub struct ShardManager {
    config: ShardConfig,
    shards: Vec<ShardLocation>,
    current_tokens: usize,
}

impl ShardManager {
    /// Creates an empty manager governed by `config`.
    pub fn new(config: ShardConfig) -> Self {
        Self {
            config,
            shards: Vec::new(),
            current_tokens: 0,
        }
    }

    /// Returns the configuration this manager was created with.
    pub fn config(&self) -> &ShardConfig {
        &self.config
    }

    /// Appends `num_new_tokens` tokens after the tokens allocated so far.
    ///
    /// Tokens fill the currently open shard first and then spill into new
    /// shards of `shard_size` tokens, which are assigned to devices round-robin.
    ///
    /// Returns every shard touched by this call, in token order. A shard that
    /// was already partially filled is extended in place and reported with its
    /// full, updated span.
    ///
    /// The request is truncated at `max_tokens`; use
    /// [`ShardManager::has_capacity`] beforehand or compare
    /// [`ShardManager::current_tokens`] afterwards to detect this. Nothing is
    /// allocated when `shard_size` or `num_devices` is 0.
    pub fn allocate(&mut self, num_new_tokens: usize) -> Vec<ShardLocation> {
        let shard_size = self.config.shard_size;
        let num_devices = self.config.num_devices;
        let max_tokens = self.config.max_tokens;
        let mut touched = Vec::new();

        // A zero shard size or device count leaves nowhere to place tokens
        // (and would divide by zero below).
        if shard_size == 0 || num_devices == 0 {
            return touched;
        }

        // Clamp to the remaining capacity so every iteration is guaranteed to
        // place at least one token; without this the loop never terminates
        // once `max_tokens` is reached.
        let capacity_left = max_tokens.saturating_sub(self.current_tokens);
        let mut remaining = num_new_tokens.min(capacity_left);
        let mut token_pos = self.current_tokens;

        while remaining > 0 {
            // Global index of the shard slot that `token_pos` falls into.
            let slot = token_pos / shard_size;
            let slot_start = slot * shard_size;
            let slot_end = slot_start.saturating_add(shard_size).min(max_tokens);
            let take = remaining.min(slot_end - token_pos);
            let chunk_end = token_pos + take;

            match self.shards.last_mut() {
                // The last shard lives in the same slot: grow it rather than
                // recording a second fragment for the same shard.
                Some(open) if open.token_start >= slot_start => {
                    open.token_end = chunk_end;
                    touched.push(*open);
                }
                _ => {
                    let location = ShardLocation {
                        token_start: token_pos,
                        token_end: chunk_end,
                        device_id: slot % num_devices,
                        // Same value as token_pos / (shard_size * num_devices),
                        // without the overflow-prone product.
                        local_shard_id: slot / num_devices,
                    };
                    self.shards.push(location);
                    touched.push(location);
                }
            }

            token_pos = chunk_end;
            remaining -= take;
        }

        self.current_tokens = token_pos;
        touched
    }

    /// Returns the shard containing `token_idx`, or `None` when that token
    /// has not been allocated.
    pub fn find_shard(&self, token_idx: usize) -> Option<&ShardLocation> {
        // Binary search: shards are kept in ascending token order, so the
        // first shard ending after `token_idx` is the only candidate.
        let idx = self.shards.partition_point(|s| s.token_end <= token_idx);
        self.shards.get(idx).filter(|s| s.contains(token_idx))
    }

    /// Returns all shards in token order.
    pub fn shards(&self) -> &[ShardLocation] {
        &self.shards
    }

    /// Returns every shard whose span overlaps `range`.
    pub fn shards_for_range(&self, range: Range<usize>) -> Vec<&ShardLocation> {
        self.shards
            .iter()
            .filter(|s| s.token_start < range.end && s.token_end > range.start)
            .collect()
    }

    /// Returns the number of tokens allocated so far.
    pub fn current_tokens(&self) -> usize {
        self.current_tokens
    }

    /// Returns the number of shards currently tracked.
    pub fn num_shards(&self) -> usize {
        self.shards.len()
    }

    /// Returns `true` when `additional` more tokens fit under `max_tokens`.
    pub fn has_capacity(&self, additional: usize) -> bool {
        // checked_add: a sum that overflows `usize` certainly does not fit.
        matches!(
            self.current_tokens.checked_add(additional),
            Some(total) if total <= self.config.max_tokens
        )
    }

    /// Forgets all shards and resets the token count to zero.
    pub fn reset(&mut self) {
        self.shards.clear();
        self.current_tokens = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh manager from the three raw configuration values.
    fn manager_with(shard_size: usize, num_devices: usize, max_tokens: usize) -> ShardManager {
        ShardManager::new(ShardConfig {
            shard_size,
            num_devices,
            max_tokens,
        })
    }

    /// The default configuration is one device with 100k-token shards.
    #[test]
    fn test_shard_config_defaults() {
        let ShardConfig {
            shard_size,
            num_devices,
            max_tokens,
        } = ShardConfig::default();
        assert_eq!(shard_size, 100_000);
        assert_eq!(num_devices, 1);
        assert_eq!(max_tokens, 2_000_000);
    }

    /// Successive allocations fill the first shard, then move to the next device.
    #[test]
    fn test_shard_manager_allocate() {
        let mut manager = manager_with(100, 2, 1000);

        // First half of the first shard.
        let first = manager.allocate(50);
        assert_eq!(first.len(), 1);
        assert_eq!(first[0].token_start, 0);
        assert_eq!(first[0].token_end, 50);
        assert_eq!(first[0].device_id, 0);

        // Second half of the first shard.
        let second = manager.allocate(50);
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].token_end, 100);

        // Spills into the next shard, which lives on the second device.
        let third = manager.allocate(50);
        assert_eq!(third.len(), 1);
        assert_eq!(third[0].token_start, 100);
        assert_eq!(third[0].device_id, 1);
    }

    /// Lookups resolve allocated tokens and reject unallocated ones.
    #[test]
    fn test_shard_manager_find() {
        let mut manager = manager_with(100, 1, 1000);
        manager.allocate(250);

        let start_of = |token_idx: usize| manager.find_shard(token_idx).map(|s| s.token_start);
        assert_eq!(start_of(50), Some(0));
        assert_eq!(start_of(150), Some(100));
        assert_eq!(start_of(500), None);
    }

    /// Four full shards land on four distinct devices in order.
    #[test]
    fn test_multi_device_distribution() {
        let mut manager = manager_with(100, 4, 1000);
        manager.allocate(400);
        assert_eq!(manager.num_shards(), 4);

        let mut devices = Vec::new();
        for shard in manager.shards() {
            devices.push(shard.device_id);
        }
        assert_eq!(devices, vec![0, 1, 2, 3]);
    }
}