use thiserror::Error;
#[derive(Debug, Error)]
pub enum ModError {
#[error("hidden_dim mismatch: router expects {expected}, got {actual}")]
DimMismatch { expected: usize, actual: usize },
#[error("empty token sequence")]
EmptySequence,
#[error("capacity_factor {0} must be in (0, 1]")]
InvalidCapacity(f32),
}
/// Configuration for a Mixture-of-Depths router.
#[derive(Debug, Clone)]
pub struct ModConfig {
    /// Fraction of the sequence routed through the layer.
    /// `mixture_of_depths_forward` rejects values `<= 0` or `> 1`.
    pub capacity_factor: f32,
    /// Width of each token vector; also the number of router weights.
    pub hidden_dim: usize,
    /// When `true`, `ModRouter::score_tokens` applies a softmax over the sequence.
    pub normalize_router: bool,
}

impl ModConfig {
    /// Creates a configuration with router normalisation disabled.
    pub fn new(capacity_factor: f32, hidden_dim: usize) -> Self {
        let normalize_router = false;
        ModConfig {
            hidden_dim,
            capacity_factor,
            normalize_router,
        }
    }

    /// Builder-style setter for `normalize_router`; all other fields are kept.
    pub fn with_normalize(self, norm: bool) -> Self {
        ModConfig {
            normalize_router: norm,
            ..self
        }
    }
}

impl Default for ModConfig {
    /// Capacity factor `0.5`, hidden size `128`, no normalisation.
    fn default() -> Self {
        ModConfig::new(0.5, 128)
    }
}
/// Small deterministic 64-bit linear congruential generator, used to initialise
/// router weights reproducibly from a seed. Not suitable for cryptographic use.
struct Lcg64 {
    state: u64,
}

impl Lcg64 {
    /// Multiplier `A` of the recurrence `state = state * A + C (mod 2^64)`.
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    /// Increment `C` of the recurrence.
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    /// Seeds the generator; the seed is offset by one (wrapping).
    fn new(seed: u64) -> Self {
        let state = seed.wrapping_add(1);
        Lcg64 { state }
    }

    /// Advances the state one step and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let advanced = self
            .state
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.state = advanced;
        advanced
    }

    /// Returns a float built from the top 31 bits of the next state, divided by 2^31.
    ///
    /// The result lies in `[0, 1]`. It can be exactly `1.0`, because a 31-bit value
    /// close to 2^31 rounds up to 2^31 when converted to `f32`.
    fn next_f32(&mut self) -> f32 {
        let top_bits = (self.next_u64() >> 33) as f32;
        top_bits / 2_147_483_648.0_f32
    }
}
/// Scalar token router for Mixture-of-Depths.
///
/// Each token is scored by a dot product with a single weight vector. The
/// highest-scoring `capacity(seq_len)` tokens are selected for processing.
pub struct ModRouter {
    config: ModConfig,
    /// Router projection: one weight per hidden feature (`config.hidden_dim` entries).
    weights: Vec<f32>,
}

impl ModRouter {
    /// Creates a router whose weights are drawn deterministically from `seed`.
    ///
    /// Each weight is `(u * 2 - 1) / sqrt(hidden_dim)` for a generator sample `u` in
    /// `[0, 1]`, so weights lie in `[-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)]`. The
    /// same seed always yields the same router.
    pub fn new(config: ModConfig, seed: u64) -> Self {
        let hidden_dim = config.hidden_dim;
        let mut rng = Lcg64::new(seed);
        let scale = (1.0_f32 / hidden_dim as f32).sqrt();
        let weights: Vec<f32> = (0..hidden_dim)
            .map(|_| (rng.next_f32() * 2.0 - 1.0) * scale)
            .collect();
        Self { config, weights }
    }

    /// Computes one routing score per token.
    ///
    /// `tokens` is a row-major `seq_len x hidden_dim` buffer: row `i` is
    /// `tokens[i * hidden_dim..(i + 1) * hidden_dim]`. When
    /// `config.normalize_router` is set, the scores go through a max-shifted
    /// softmax over the sequence. Normalisation is skipped if the exponential sum
    /// is not positive (e.g. NaN).
    ///
    /// # Errors
    /// - [`ModError::EmptySequence`] if `seq_len == 0`.
    /// - [`ModError::DimMismatch`] if `tokens.len() != seq_len * hidden_dim`.
    pub fn score_tokens(&self, tokens: &[f32], seq_len: usize) -> Result<Vec<f32>, ModError> {
        if seq_len == 0 {
            return Err(ModError::EmptySequence);
        }
        let hd = self.config.hidden_dim;
        if tokens.len() != seq_len * hd {
            return Err(ModError::DimMismatch {
                expected: seq_len * hd,
                actual: tokens.len(),
            });
        }
        let mut scores: Vec<f32> = (0..seq_len)
            .map(|i| {
                let row = &tokens[i * hd..(i + 1) * hd];
                row.iter()
                    .zip(self.weights.iter())
                    .map(|(x, w)| x * w)
                    .sum()
            })
            .collect();
        if self.config.normalize_router {
            // Subtract the max before exponentiating so large scores cannot overflow.
            let max_s = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = scores.iter().map(|s| (s - max_s).exp()).sum();
            if sum_exp > 0.0 {
                for s in &mut scores {
                    *s = (*s - max_s).exp() / sum_exp;
                }
            }
        }
        Ok(scores)
    }

    /// Returns the indices of the `capacity(seq_len)` highest-scoring tokens, in
    /// ascending index order.
    ///
    /// Only the first `min(seq_len, scores.len())` scores are considered, and the
    /// count is capped accordingly, so a short `scores` slice never panics.
    ///
    /// Ordering rules:
    /// - Higher scores win.
    /// - NaN ranks below every real score.
    /// - Equal scores are broken in favour of the lower index, so the selection is
    ///   fully deterministic.
    pub fn select_tokens(&self, scores: &[f32], seq_len: usize) -> Vec<usize> {
        let candidates = &scores[..seq_len.min(scores.len())];
        let k = self.capacity(seq_len).min(candidates.len());
        if k == 0 {
            return Vec::new();
        }
        // NaN would otherwise compare as neither greater nor smaller than anything.
        let key = |i: usize| -> f32 {
            let s = candidates[i];
            if s.is_nan() {
                f32::NEG_INFINITY
            } else {
                s
            }
        };
        let mut order: Vec<usize> = (0..candidates.len()).collect();
        // Partial selection (average O(n)): afterwards order[..k] holds exactly the
        // k best tokens under the total order "score desc, then index asc".
        order.select_nth_unstable_by(k - 1, |&a, &b| {
            key(b)
                .partial_cmp(&key(a))
                .unwrap_or(std::cmp::Ordering::Equal)
                .then(a.cmp(&b))
        });
        order.truncate(k);
        order.sort_unstable();
        order
    }

    /// Number of tokens to process for a sequence of `seq_len` tokens.
    ///
    /// Returns `round(capacity_factor * seq_len)` clamped to `[1, seq_len]`, or `0`
    /// for an empty sequence. The factor is not validated here. A non-positive or
    /// NaN factor saturates to `0` in the cast and is then clamped up to `1`.
    pub fn capacity(&self, seq_len: usize) -> usize {
        if seq_len == 0 {
            return 0;
        }
        let k = (self.config.capacity_factor * seq_len as f32).round() as usize;
        k.clamp(1, seq_len)
    }
}
/// Runs one Mixture-of-Depths layer over a sequence.
///
/// The steps are:
/// 1. The router scores every token.
/// 2. The top `router.capacity(seq_len)` tokens are gathered, in ascending index
///    order, into a contiguous buffer.
/// 3. The buffer is passed to `layer_fn(buffer, selected_count)`.
/// 4. The returned rows are written back into a copy of `hidden`.
///
/// Unselected tokens pass through unchanged.
///
/// `hidden` is a row-major `seq_len x hidden_dim` buffer. `layer_fn` must return
/// exactly `selected_count * hidden_dim` values; its row `r` replaces the `r`-th
/// selected token.
///
/// # Errors
/// - [`ModError::EmptySequence`] if `seq_len == 0`.
/// - [`ModError::InvalidCapacity`] if the router's `capacity_factor` is not in
///   `(0, 1]` (NaN included).
/// - [`ModError::DimMismatch`] in any of these cases:
///   - `hidden_dim` differs from the router's hidden size;
///   - `hidden.len() != seq_len * hidden_dim`;
///   - `layer_fn` returns the wrong number of values.
pub fn mixture_of_depths_forward<F>(
    hidden: &[f32],
    seq_len: usize,
    hidden_dim: usize,
    router: &ModRouter,
    layer_fn: F,
) -> Result<Vec<f32>, ModError>
where
    F: Fn(&[f32], usize) -> Vec<f32>,
{
    if seq_len == 0 {
        return Err(ModError::EmptySequence);
    }
    let capacity_factor = router.config.capacity_factor;
    // Phrased as a positive range test so NaN (which fails every comparison) is
    // rejected rather than slipping past `<= 0.0 || > 1.0`.
    let capacity_ok = capacity_factor > 0.0 && capacity_factor <= 1.0;
    if !capacity_ok {
        return Err(ModError::InvalidCapacity(capacity_factor));
    }
    let router_dim = router.config.hidden_dim;
    if hidden_dim != router_dim {
        return Err(ModError::DimMismatch {
            expected: router_dim,
            actual: hidden_dim,
        });
    }
    let expected_len = seq_len * hidden_dim;
    if hidden.len() != expected_len {
        return Err(ModError::DimMismatch {
            expected: expected_len,
            actual: hidden.len(),
        });
    }

    let scores = router.score_tokens(hidden, seq_len)?;
    let selected_indices = router.select_tokens(&scores, seq_len);
    let selected_count = selected_indices.len();

    // Element range of row `idx` in a row-major buffer of width `hidden_dim`.
    let row = |idx: usize| {
        let start = idx * hidden_dim;
        start..start + hidden_dim
    };

    let mut selected_buf: Vec<f32> = Vec::with_capacity(selected_count * hidden_dim);
    for &idx in &selected_indices {
        selected_buf.extend_from_slice(&hidden[row(idx)]);
    }

    let processed = layer_fn(&selected_buf, selected_count);
    // Without this check a short result panics on slicing below and a long one is
    // silently truncated.
    let expected_out = selected_count * hidden_dim;
    if processed.len() != expected_out {
        return Err(ModError::DimMismatch {
            expected: expected_out,
            actual: processed.len(),
        });
    }

    let mut output = hidden.to_vec();
    for (rank, &idx) in selected_indices.iter().enumerate() {
        output[row(idx)].copy_from_slice(&processed[row(rank)]);
    }
    Ok(output)
}
/// Summary statistics for one Mixture-of-Depths forward pass.
#[derive(Debug, Clone)]
pub struct ModStats {
    /// Number of tokens in the sequence.
    pub seq_len: usize,
    /// Number of tokens routed through the layer.
    pub tokens_processed: usize,
    /// Tokens that bypassed the layer: `seq_len - tokens_processed`, floored at 0.
    pub tokens_skipped: usize,
    /// `1.0` if any token was processed, otherwise `0.0`.
    ///
    /// NOTE(review): not derived from an actual capacity value. It assumes the
    /// router always fills its capacity exactly, as top-k selection does.
    pub capacity_utilization: f32,
    /// Fraction of tokens that skipped the layer, always within `[0, 1]`.
    pub compute_reduction: f32,
}

impl ModStats {
    /// Builds statistics from the sequence length and the number of processed tokens.
    ///
    /// `compute_reduction` is derived from `tokens_skipped`, so the two always agree.
    /// If `tokens_processed` exceeds `seq_len`, both report no savings instead of a
    /// negative reduction. An empty sequence yields a reduction of `0.0`.
    pub fn compute(seq_len: usize, tokens_processed: usize) -> Self {
        let tokens_skipped = seq_len.saturating_sub(tokens_processed);
        let compute_reduction = if seq_len == 0 {
            0.0
        } else {
            tokens_skipped as f32 / seq_len as f32
        };
        let capacity_utilization = if tokens_processed == 0 { 0.0 } else { 1.0_f32 };
        Self {
            seq_len,
            tokens_processed,
            tokens_skipped,
            capacity_utilization,
            compute_reduction,
        }
    }

    /// One-line human-readable summary; percentages are printed to one decimal place.
    pub fn summary(&self) -> String {
        format!(
            "MoD: seq={} processed={} skipped={} reduction={:.1}% utilization={:.1}%",
            self.seq_len,
            self.tokens_processed,
            self.tokens_skipped,
            self.compute_reduction * 100.0,
            self.capacity_utilization * 100.0,
        )
    }
}