use crate::error::{Result, RuvLLMError};
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Configuration describing how a LoRA adapter is shaped and applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterConfig {
    /// Adapter name; also the lookup key for `AdapterManager::get_by_name`.
    pub name: String,
    /// Low-rank decomposition rank (the shared inner dimension of A and B).
    pub rank: usize,
    /// LoRA scaling numerator; the effective scale applied to the delta is `alpha / rank`.
    pub alpha: f32,
    /// Dropout probability.
    // NOTE(review): not consumed anywhere in this file — presumably used by training code; confirm.
    pub dropout: f32,
    /// Names of the modules the adapter attaches to (e.g. "q_proj").
    pub target_modules: Vec<String>,
    /// Whether adapter weights should be merged into the base weights.
    // NOTE(review): not consumed in this file — confirm downstream usage.
    pub merge_weights: bool,
}
impl Default for AdapterConfig {
    /// Common LoRA defaults: rank 8, alpha 16 (effective scale 2.0), no
    /// dropout, attention q/v projections, weights kept separate.
    fn default() -> Self {
        Self {
            name: "default".to_string(),
            rank: 8,
            alpha: 16.0,
            dropout: 0.0,
            target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
            merge_weights: false,
        }
    }
}
/// Low-rank (LoRA) weight pair for one linear layer. The delta contribution
/// is `(alpha / rank) * (x · A) · B`, with A of shape `in_features x rank`
/// and B of shape `rank x out_features`, both stored row-major.
#[derive(Debug, Clone)]
pub struct LoraLayerWeights {
    /// Down-projection matrix A, row-major, `in_features * rank` entries.
    pub lora_a: Vec<f32>,
    /// Up-projection matrix B, row-major, `rank * out_features` entries.
    pub lora_b: Vec<f32>,
    pub in_features: usize,
    pub out_features: usize,
    pub rank: usize,
}
impl LoraLayerWeights {
    /// Allocates zero-initialized A and B matrices for the given shape.
    pub fn new(in_features: usize, out_features: usize, rank: usize) -> Self {
        Self {
            lora_a: vec![0.0; in_features * rank],
            lora_b: vec![0.0; rank * out_features],
            in_features,
            out_features,
            rank,
        }
    }

    /// Computes the LoRA delta `(alpha / rank) * (input · A) · B` for a
    /// row-major batch of input rows; returns `batch * out_features` values.
    /// Trailing input elements beyond a whole multiple of `in_features` are
    /// ignored (batch size is the floor division).
    pub fn apply(&self, input: &[f32], alpha: f32) -> Vec<f32> {
        let scale = alpha / self.rank as f32;
        let batch = input.len() / self.in_features;

        // Stage 1: project every input row through A, giving `rank`
        // activations per row. Column r of A sits at stride `rank`,
        // starting at offset r.
        let mut mid = Vec::with_capacity(batch * self.rank);
        for b in 0..batch {
            let row = &input[b * self.in_features..(b + 1) * self.in_features];
            for r in 0..self.rank {
                let col = self.lora_a.iter().skip(r).step_by(self.rank);
                let dot: f32 = row.iter().zip(col).map(|(x, a)| x * a).sum();
                mid.push(dot);
            }
        }

        // Stage 2: project the activations through B and apply the scale.
        let mut out = Vec::with_capacity(batch * self.out_features);
        for b in 0..batch {
            let row = &mid[b * self.rank..(b + 1) * self.rank];
            for o in 0..self.out_features {
                let col = self.lora_b.iter().skip(o).step_by(self.out_features);
                let dot: f32 = row.iter().zip(col).map(|(m, w)| m * w).sum();
                out.push(dot * scale);
            }
        }
        out
    }

    /// Heap bytes consumed by both weight matrices.
    pub fn memory_bytes(&self) -> usize {
        std::mem::size_of::<f32>() * (self.lora_a.len() + self.lora_b.len())
    }
}
/// A loaded LoRA adapter: per-module low-rank weights plus identity metadata.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Unique identity, generated at construction.
    pub id: Uuid,
    /// The configuration this adapter was built from.
    pub config: AdapterConfig,
    /// Low-rank weights keyed by target module name (e.g. "q_proj").
    pub layers: HashMap<String, LoraLayerWeights>,
    /// Version counter; starts at 1.
    pub version: u64,
    /// Creation timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    // Shared via Arc, so clones of this adapter observe and mutate the same
    // count rather than starting a fresh one.
    ref_count: Arc<std::sync::atomic::AtomicUsize>,
}
impl LoraAdapter {
pub fn new(config: AdapterConfig) -> Self {
Self {
id: Uuid::new_v4(),
config,
layers: HashMap::new(),
version: 1,
created_at: chrono::Utc::now(),
ref_count: Arc::new(std::sync::atomic::AtomicUsize::new(1)),
}
}
pub fn add_layer(&mut self, module_name: String, weights: LoraLayerWeights) {
self.layers.insert(module_name, weights);
}
pub fn memory_bytes(&self) -> usize {
self.layers.values().map(|l| l.memory_bytes()).sum()
}
pub fn apply(&self, module_name: &str, input: &[f32], base_output: &mut [f32]) -> Result<()> {
if let Some(layer) = self.layers.get(module_name) {
let delta = layer.apply(input, self.config.alpha);
if delta.len() != base_output.len() {
return Err(RuvLLMError::Adapter(format!(
"Output size mismatch: expected {}, got {}",
base_output.len(),
delta.len()
)));
}
for (out, d) in base_output.iter_mut().zip(delta.iter()) {
*out += d;
}
}
Ok(())
}
pub fn inc_ref(&self) {
self.ref_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
pub fn dec_ref(&self) -> bool {
self.ref_count
.fetch_sub(1, std::sync::atomic::Ordering::SeqCst)
== 1
}
pub fn ref_count(&self) -> usize {
self.ref_count.load(std::sync::atomic::Ordering::SeqCst)
}
}
/// Internal LRU bookkeeping record: one per loaded adapter.
struct CacheEntry {
    adapter: Arc<LoraAdapter>,
    // Refreshed on every `AdapterManager::get`; drives LRU eviction order.
    last_accessed: chrono::DateTime<chrono::Utc>,
}
/// Thread-safe registry of loaded LoRA adapters with LRU eviction under a
/// byte budget.
pub struct AdapterManager {
    /// Primary store, keyed by adapter id.
    adapters: DashMap<Uuid, Arc<LoraAdapter>>,
    /// Secondary index: adapter name -> adapter id.
    name_to_id: DashMap<String, Uuid>,
    /// LRU list used to pick eviction victims.
    cache: RwLock<Vec<CacheEntry>>,
    /// Maximum number of simultaneously loaded adapters.
    max_loaded: usize,
    /// Memory budget in bytes across all loaded adapters.
    max_memory_bytes: usize,
    /// Bytes currently attributed to loaded adapters.
    current_memory: std::sync::atomic::AtomicUsize,
}
impl AdapterManager {
pub fn new() -> Self {
Self {
adapters: DashMap::new(),
name_to_id: DashMap::new(),
cache: RwLock::new(Vec::new()),
max_loaded: 16,
max_memory_bytes: 512 * 1024 * 1024, current_memory: std::sync::atomic::AtomicUsize::new(0),
}
}
pub fn with_limits(max_loaded: usize, max_memory_bytes: usize) -> Self {
Self {
adapters: DashMap::new(),
name_to_id: DashMap::new(),
cache: RwLock::new(Vec::new()),
max_loaded,
max_memory_bytes,
current_memory: std::sync::atomic::AtomicUsize::new(0),
}
}
pub fn load(&self, adapter: LoraAdapter) -> Result<Uuid> {
let memory_needed = adapter.memory_bytes();
self.ensure_memory(memory_needed)?;
let id = adapter.id;
let name = adapter.config.name.clone();
let adapter = Arc::new(adapter);
self.adapters.insert(id, adapter.clone());
self.name_to_id.insert(name, id);
let mut cache = self.cache.write();
cache.push(CacheEntry {
adapter,
last_accessed: chrono::Utc::now(),
});
self.current_memory
.fetch_add(memory_needed, std::sync::atomic::Ordering::SeqCst);
Ok(id)
}
fn ensure_memory(&self, needed: usize) -> Result<()> {
let current = self
.current_memory
.load(std::sync::atomic::Ordering::SeqCst);
if current + needed <= self.max_memory_bytes {
return Ok(());
}
let mut cache = self.cache.write();
cache.sort_by(|a, b| a.last_accessed.cmp(&b.last_accessed));
let mut freed = 0;
while freed < needed && !cache.is_empty() {
if let Some(entry) = cache.first() {
if entry.adapter.ref_count() <= 1 {
let id = entry.adapter.id;
let size = entry.adapter.memory_bytes();
self.adapters.remove(&id);
self.name_to_id.remove(&entry.adapter.config.name);
cache.remove(0);
freed += size;
self.current_memory
.fetch_sub(size, std::sync::atomic::Ordering::SeqCst);
} else {
let entry = cache.remove(0);
cache.push(entry);
}
}
}
if freed < needed {
return Err(RuvLLMError::OutOfMemory(
"Cannot free enough memory for new adapter".to_string(),
));
}
Ok(())
}
pub fn get(&self, id: &Uuid) -> Option<Arc<LoraAdapter>> {
if let Some(adapter) = self.adapters.get(id) {
let mut cache = self.cache.write();
if let Some(entry) = cache.iter_mut().find(|e| e.adapter.id == *id) {
entry.last_accessed = chrono::Utc::now();
}
Some(adapter.clone())
} else {
None
}
}
pub fn get_by_name(&self, name: &str) -> Option<Arc<LoraAdapter>> {
self.name_to_id.get(name).and_then(|id| self.get(&id))
}
pub fn unload(&self, id: &Uuid) -> Result<()> {
if let Some((_, adapter)) = self.adapters.remove(id) {
self.name_to_id.remove(&adapter.config.name);
let mut cache = self.cache.write();
cache.retain(|e| e.adapter.id != *id);
self.current_memory
.fetch_sub(adapter.memory_bytes(), std::sync::atomic::Ordering::SeqCst);
}
Ok(())
}
pub fn list(&self) -> Vec<AdapterInfo> {
self.adapters
.iter()
.map(|entry| {
let adapter = entry.value();
AdapterInfo {
id: adapter.id,
name: adapter.config.name.clone(),
rank: adapter.config.rank,
version: adapter.version,
memory_bytes: adapter.memory_bytes(),
ref_count: adapter.ref_count(),
}
})
.collect()
}
pub fn memory_stats(&self) -> AdapterMemoryStats {
AdapterMemoryStats {
total_budget: self.max_memory_bytes,
used_bytes: self
.current_memory
.load(std::sync::atomic::Ordering::SeqCst),
adapter_count: self.adapters.len(),
max_adapters: self.max_loaded,
}
}
}
impl Default for AdapterManager {
    /// Equivalent to [`AdapterManager::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Read-only summary of one loaded adapter, as returned by
/// `AdapterManager::list`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdapterInfo {
    pub id: Uuid,
    pub name: String,
    pub rank: usize,
    pub version: u64,
    /// Bytes held by the adapter's weight matrices.
    pub memory_bytes: usize,
    /// Reference count at snapshot time.
    pub ref_count: usize,
}
/// Aggregate memory accounting for an `AdapterManager`, as returned by
/// `AdapterManager::memory_stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct AdapterMemoryStats {
    /// Configured budget (`max_memory_bytes`).
    pub total_budget: usize,
    /// Bytes currently attributed to loaded adapters.
    pub used_bytes: usize,
    /// Number of adapters currently loaded.
    pub adapter_count: usize,
    /// Configured adapter-count limit (`max_loaded`).
    pub max_adapters: usize,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_layer_weights() {
        // A 4-in / 4-out layer at rank 2: A is 4x2 and B is 2x4.
        let weights = LoraLayerWeights::new(4, 4, 2);
        assert_eq!(weights.lora_a.len(), 8);
        assert_eq!(weights.lora_b.len(), 8);
    }

    #[test]
    fn test_lora_adapter() {
        let config = AdapterConfig {
            name: "test".to_string(),
            rank: 4,
            ..Default::default()
        };
        let mut adapter = LoraAdapter::new(config);
        adapter.add_layer("q_proj".to_string(), LoraLayerWeights::new(64, 64, 4));
        assert_eq!(adapter.layers.len(), 1);
        assert!(adapter.memory_bytes() > 0);
    }

    #[test]
    fn test_adapter_manager() {
        // Round-trip: load -> fetch by id and by name -> unload -> gone.
        let manager = AdapterManager::new();
        let id = manager
            .load(LoraAdapter::new(AdapterConfig::default()))
            .unwrap();
        assert!(manager.get(&id).is_some());
        assert!(manager.get_by_name("default").is_some());
        manager.unload(&id).unwrap();
        assert!(manager.get(&id).is_none());
    }
}