use std::collections::HashMap;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LoraConfig {
pub rank: usize,
pub alpha: f32,
pub dropout: f32,
pub target_modules: Vec<String>,
}
impl Default for LoraConfig {
fn default() -> Self {
Self {
rank: 8,
alpha: 8.0,
dropout: 0.0,
target_modules: vec![
"q_proj".to_string(),
"k_proj".to_string(),
"v_proj".to_string(),
"o_proj".to_string(),
"gate_proj".to_string(),
"up_proj".to_string(),
"down_proj".to_string(),
],
}
}
}
pub struct LoraAdapter {
pub config: LoraConfig,
pub a_matrix: Vec<f32>,
pub b_matrix: Vec<f32>,
pub d_in: usize,
pub d_out: usize,
pub scaling: f32,
}
impl LoraAdapter {
pub fn new(d_in: usize, d_out: usize, config: LoraConfig) -> Self {
let rank = config.rank;
let scaling = if rank > 0 {
config.alpha / rank as f32
} else {
0.0
};
let a_size = rank * d_in;
let b_size = d_out * rank;
let bound = if d_in > 0 {
(2.0_f32 / d_in as f32).sqrt()
} else {
0.0
};
let mut rng_state: u64 = (d_in as u64)
.wrapping_mul(6364136223846793005)
.wrapping_add(d_out as u64)
.wrapping_add(rank as u64)
.wrapping_add(1);
let mut next_f32 = move || -> f32 {
rng_state ^= rng_state << 13;
rng_state ^= rng_state >> 7;
rng_state ^= rng_state << 17;
let t = (rng_state as f32) / (u64::MAX as f32);
t * 2.0 * bound - bound
};
let a_matrix: Vec<f32> = (0..a_size).map(|_| next_f32()).collect();
let b_matrix = vec![0.0f32; b_size];
Self {
config,
a_matrix,
b_matrix,
d_in,
d_out,
scaling,
}
}
pub fn apply(&self, x: &[f32]) -> Vec<f32> {
let rank = self.config.rank;
let mut intermediate = vec![0.0f32; rank];
for (r, slot) in intermediate.iter_mut().enumerate().take(rank) {
let row_offset = r * self.d_in;
let mut sum = 0.0f32;
for (j, &xj) in x.iter().enumerate().take(self.d_in) {
sum += self.a_matrix[row_offset + j] * xj;
}
*slot = sum;
}
let mut output = vec![0.0f32; self.d_out];
for (i, slot) in output.iter_mut().enumerate().take(self.d_out) {
let row_offset = i * rank;
let mut sum = 0.0f32;
for (r, &inter_r) in intermediate.iter().enumerate().take(rank) {
sum += self.b_matrix[row_offset + r] * inter_r;
}
*slot = sum * self.scaling;
}
output
}
pub fn param_count(&self) -> usize {
self.d_in * self.config.rank + self.config.rank * self.d_out
}
pub fn memory_bytes(&self) -> usize {
self.param_count() * std::mem::size_of::<f32>()
}
pub fn merge_into_weights(&self, weights: &mut [f32]) {
let rank = self.config.rank;
for i in 0..self.d_out {
for j in 0..self.d_in {
let mut delta = 0.0f32;
for r in 0..rank {
delta += self.b_matrix[i * rank + r] * self.a_matrix[r * self.d_in + j];
}
weights[i * self.d_in + j] += self.scaling * delta;
}
}
}
}
pub struct LoraRegistry {
adapters: HashMap<String, LoraAdapter>,
config: LoraConfig,
}
impl LoraRegistry {
pub fn new(config: LoraConfig) -> Self {
Self {
adapters: HashMap::new(),
config,
}
}
pub fn add(&mut self, module_name: &str, adapter: LoraAdapter) {
self.adapters.insert(module_name.to_string(), adapter);
}
pub fn get(&self, module_name: &str) -> Option<&LoraAdapter> {
self.adapters.get(module_name)
}
pub fn apply_adapter(&self, module_name: &str, x: &[f32]) -> Option<Vec<f32>> {
self.adapters.get(module_name).map(|a| a.apply(x))
}
pub fn adapter_count(&self) -> usize {
self.adapters.len()
}
pub fn total_param_count(&self) -> usize {
self.adapters.values().map(|a| a.param_count()).sum()
}
pub fn total_memory_bytes(&self) -> usize {
self.adapters.values().map(|a| a.memory_bytes()).sum()
}
pub fn is_empty(&self) -> bool {
self.adapters.is_empty()
}
pub fn module_names(&self) -> Vec<&str> {
self.adapters.keys().map(|s| s.as_str()).collect()
}
pub fn config(&self) -> &LoraConfig {
&self.config
}
pub fn estimate_adapters_for_budget(
d_in: usize,
d_out: usize,
rank: usize,
budget_bytes: usize,
) -> usize {
let params_per_adapter = d_in * rank + rank * d_out;
let bytes_per_adapter = params_per_adapter * std::mem::size_of::<f32>();
if bytes_per_adapter == 0 {
return 0;
}
budget_bytes / bytes_per_adapter
}
}
pub struct BonsaiLoraSet {
pub registry: LoraRegistry,
pub base_model_name: String,
pub adapter_name: String,
}
impl BonsaiLoraSet {
pub fn new(config: LoraConfig, model_name: &str, adapter_name: &str) -> Self {
Self {
registry: LoraRegistry::new(config),
base_model_name: model_name.to_string(),
adapter_name: adapter_name.to_string(),
}
}
pub fn add_attention_adapters(&mut self, hidden_size: usize, num_heads: usize) {
let head_dim = hidden_size.checked_div(num_heads).unwrap_or(hidden_size);
let kv_dim = head_dim * num_heads;
let config = self.registry.config.clone();
self.registry.add(
"q_proj",
LoraAdapter::new(hidden_size, hidden_size, config.clone()),
);
self.registry.add(
"k_proj",
LoraAdapter::new(hidden_size, kv_dim, config.clone()),
);
self.registry.add(
"v_proj",
LoraAdapter::new(hidden_size, kv_dim, config.clone()),
);
self.registry
.add("o_proj", LoraAdapter::new(hidden_size, hidden_size, config));
}
pub fn add_mlp_adapters(&mut self, hidden_size: usize, intermediate_size: usize) {
let config = self.registry.config.clone();
self.registry.add(
"gate_proj",
LoraAdapter::new(hidden_size, intermediate_size, config.clone()),
);
self.registry.add(
"up_proj",
LoraAdapter::new(hidden_size, intermediate_size, config.clone()),
);
self.registry.add(
"down_proj",
LoraAdapter::new(intermediate_size, hidden_size, config),
);
}
pub fn total_trainable_params(&self) -> usize {
self.registry.total_param_count()
}
pub fn efficiency_ratio(&self, base_param_count: u64) -> f32 {
if base_param_count == 0 {
return 0.0;
}
self.total_trainable_params() as f32 / base_param_count as f32
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_config(rank: usize) -> LoraConfig {
LoraConfig {
rank,
alpha: rank as f32,
dropout: 0.0,
target_modules: vec![],
}
}
#[test]
fn test_lora_adapter_apply_zero_b_matrix() {
let config = make_config(4);
let adapter = LoraAdapter::new(8, 16, config);
let x = vec![1.0f32; 8];
let out = adapter.apply(&x);
assert_eq!(out.len(), 16, "output must have d_out elements");
for (i, v) in out.iter().enumerate() {
assert!(
v.abs() < 1e-6,
"output[{i}] = {v} but B is zero so output must be zero"
);
}
}
#[test]
fn test_lora_adapter_param_count() {
let adapter = LoraAdapter::new(64, 128, make_config(8));
assert_eq!(adapter.param_count(), 64 * 8 + 8 * 128);
}
#[test]
fn test_lora_adapter_memory_bytes() {
let adapter = LoraAdapter::new(64, 128, make_config(8));
assert_eq!(adapter.memory_bytes(), adapter.param_count() * 4);
}
#[test]
fn test_lora_registry_add_and_get() {
let mut registry = LoraRegistry::new(make_config(4));
let adapter = LoraAdapter::new(32, 64, make_config(4));
registry.add("q_proj", adapter);
assert_eq!(
registry.adapter_count(),
1,
"one adapter should be registered"
);
assert!(
registry.get("q_proj").is_some(),
"q_proj should be retrievable"
);
assert!(registry.get("v_proj").is_none(), "v_proj should not exist");
assert!(!registry.is_empty());
let names = registry.module_names();
assert!(names.contains(&"q_proj"));
}
#[test]
fn test_lora_registry_apply_adapter() {
let mut registry = LoraRegistry::new(make_config(4));
let adapter = LoraAdapter::new(8, 16, make_config(4));
registry.add("q_proj", adapter);
let x = vec![0.5f32; 8];
let out = registry.apply_adapter("q_proj", &x);
assert!(
out.is_some(),
"apply should return Some for registered module"
);
assert_eq!(out.expect("output must be Some").len(), 16);
assert!(registry.apply_adapter("missing", &x).is_none());
}
#[test]
fn test_bonsai_lora_set_creates_adapters() {
let config = make_config(8);
let mut lora_set = BonsaiLoraSet::new(config, "bonsai-8b", "my-adapter");
lora_set.add_attention_adapters(256, 8);
lora_set.add_mlp_adapters(256, 512);
assert_eq!(
lora_set.registry.adapter_count(),
7,
"should have 7 adapters (4 attn + 3 mlp)"
);
assert_eq!(lora_set.base_model_name, "bonsai-8b");
assert_eq!(lora_set.adapter_name, "my-adapter");
}
#[test]
fn test_lora_merge_into_weights() {
let config = make_config(4);
let adapter = LoraAdapter::new(4, 4, config);
let mut weights = vec![1.0f32; 16]; adapter.merge_into_weights(&mut weights);
for (i, w) in weights.iter().enumerate() {
assert!(
(w - 1.0).abs() < 1e-6,
"weights[{i}] = {w}, expected 1.0 (B is zero so no change)"
);
}
}
#[test]
fn test_lora_efficiency_ratio() {
let config = make_config(8);
let mut lora_set = BonsaiLoraSet::new(config, "bonsai-8b", "test");
lora_set.add_attention_adapters(256, 8);
let base_params = 8_000_000_000u64; let ratio = lora_set.efficiency_ratio(base_params);
assert!(ratio > 0.0, "ratio must be positive");
assert!(
ratio < 0.01,
"LoRA should be <1% of base params, got ratio={ratio}"
);
assert_eq!(lora_set.efficiency_ratio(0), 0.0);
}
#[test]
fn test_lora_estimate_adapters_for_budget() {
let bytes_per = (64 * 8 + 8 * 128) * 4;
let budget = bytes_per * 5;
let n = LoraRegistry::estimate_adapters_for_budget(64, 128, 8, budget);
assert_eq!(
n, 5,
"should fit exactly 5 adapters in budget of 5×per-adapter"
);
assert_eq!(LoraRegistry::estimate_adapters_for_budget(64, 128, 8, 0), 0);
assert_eq!(
LoraRegistry::estimate_adapters_for_budget(0, 0, 0, 1_000_000),
0
);
}
}