use std::collections::BTreeMap;
use std::collections::HashMap;
use std::sync::Arc;
use safetensors::SafeTensors;
use safetensors::tensor::Dtype;
use wgpu::{Buffer, BufferDescriptor, BufferUsages};
use crate::backend::WgpuCtx;
use crate::error::{Result, RullamaError};
use crate::model::config::Gemma4Config;
use crate::reference::forward_chained::{GlobalLoraSlots, LayerLoraSlots, LoraSlot};
/// Identifies one LoRA-adapted projection: a layer index plus a projection name.
///
/// The derived ordering compares `layer` first and `projection` second, so keys
/// sort by layer when used in an ordered map.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct LoraKey {
    /// Index of the layer the projection belongs to.
    pub layer: u32,
    /// Name of the projection (for example `"attn_q"`).
    pub projection: String,
}

impl LoraKey {
    /// Creates a key from a layer index and anything convertible into a `String`.
    pub fn new(layer: u32, projection: impl Into<String>) -> Self {
        let projection: String = projection.into();
        LoraKey { layer, projection }
    }
}
/// GPU-resident state for a single LoRA-adapted projection.
///
/// Instances are created by `InferenceLoraLayer::alloc`, which sizes the three
/// buffers from `in_dim`, `rank` and `out_dim`.
pub struct InferenceLoraLayer {
/// Input dimension of the adapted projection.
pub in_dim: u32,
/// LoRA rank shared by the `a`, `b` and `z` buffers.
pub rank: u32,
/// Output dimension of the adapted projection.
pub out_dim: u32,
/// Scaling factor, computed by `alloc` as `alpha / rank`.
pub scale: f32,
/// Buffer for the `A` tensor: `in_dim * rank` elements at 4 bytes each.
pub a: Buffer,
/// Buffer for the `B` tensor: `out_dim * rank` elements, 2 bytes each when
/// `b_is_f16` is set and 4 bytes each otherwise.
pub b: Buffer,
/// Buffer of `rank * 4` bytes.
// NOTE(review): presumably scratch space for the rank-sized intermediate
// product; the consumer is outside this file, so confirm.
pub z: Buffer,
/// Whether `b` stores 2-byte (f16) elements instead of 4-byte (f32) ones.
pub b_is_f16: bool,
}
/// A set of LoRA layers loaded for inference, keyed by [`LoraKey`].
///
/// Built by `from_safetensors_bytes`. Per-layer projections are keyed by their
/// layer index; the global `lm_head` / `embed_tokens` entries are stored under
/// layer index 0.
pub struct InferenceAdapter {
// Ordered map from (layer, projection) to that projection's GPU buffers.
layers: BTreeMap<LoraKey, InferenceLoraLayer>,
}
impl InferenceAdapter {
pub fn from_safetensors_bytes(
ctx: Arc<WgpuCtx>,
cfg: &Gemma4Config,
bytes: &[u8],
) -> Result<Self> {
let (_n, header) = SafeTensors::read_metadata(bytes)
.map_err(|e| RullamaError::Inference(format!("safetensors header: {e}")))?;
let meta_opt: &Option<HashMap<String, String>> = header.metadata();
let m = meta_opt
.as_ref()
.ok_or_else(|| RullamaError::Inference("adapter has no metadata sidecar".into()))?;
let rank: u32 = m
.get("rank")
.ok_or_else(|| RullamaError::Inference("metadata missing 'rank'".into()))?
.parse()
.map_err(|e| RullamaError::Inference(format!("bad 'rank': {e}")))?;
let alpha: f32 = m
.get("alpha")
.ok_or_else(|| RullamaError::Inference("metadata missing 'alpha'".into()))?
.parse()
.map_err(|e| RullamaError::Inference(format!("bad 'alpha': {e}")))?;
let targets_csv = m
.get("target_modules")
.ok_or_else(|| RullamaError::Inference("metadata missing 'target_modules'".into()))?;
let target_modules: Vec<String> = targets_csv
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect();
if target_modules.is_empty() {
return Err(RullamaError::Inference(
"target_modules metadata is empty".into(),
));
}
let st = SafeTensors::deserialize(bytes)
.map_err(|e| RullamaError::Inference(format!("safetensors parse: {e}")))?;
let mut layers: BTreeMap<LoraKey, InferenceLoraLayer> = BTreeMap::new();
let d_model = cfg.d_model;
let vocab = cfg.vocab_size;
const GLOBAL_TARGETS: &[&str] = &["lm_head", "embed_tokens"];
for li in 0..cfg.n_layers {
let head_dim = cfg.head_dim(li);
let n_heads_dim = cfg.n_heads * head_dim;
let n_kv_dim = cfg.n_kv_heads(li) * head_dim;
let ffn_n = cfg.ffn(li);
for proj in &target_modules {
if GLOBAL_TARGETS.contains(&proj.as_str()) {
continue; }
let (in_dim, out_dim) = match proj.as_str() {
"attn_q" => (d_model, n_heads_dim),
"attn_k" => (d_model, n_kv_dim),
"attn_v" => (d_model, n_kv_dim),
"attn_o" => (n_heads_dim, d_model),
"ffn_gate" => (d_model, ffn_n),
"ffn_up" => (d_model, ffn_n),
"ffn_down" => (ffn_n, d_model),
other => {
return Err(RullamaError::Inference(format!(
"unsupported LoRA target '{other}'"
)));
}
};
let layer = InferenceLoraLayer::alloc(&ctx, in_dim, rank, out_dim, alpha, false);
layers.insert(LoraKey::new(li, proj.clone()), layer);
}
}
for proj in &target_modules {
if !GLOBAL_TARGETS.contains(&proj.as_str()) {
continue;
}
let (in_dim, out_dim) = match proj.as_str() {
"lm_head" => (d_model, vocab),
"embed_tokens" => (vocab, d_model),
_ => unreachable!("filter above admits only GLOBAL_TARGETS"),
};
let b_is_f16 = proj == "lm_head" && rank.is_multiple_of(2);
let layer = InferenceLoraLayer::alloc(&ctx, in_dim, rank, out_dim, alpha, b_is_f16);
layers.insert(LoraKey::new(0, proj.clone()), layer);
}
for (name, tensor) in st.tensors() {
if !name.starts_with("lora.blk.") {
continue;
}
let suffix = &name["lora.blk.".len()..];
let (layer_str, rest) = match suffix.split_once('.') {
Some(p) => p,
None => continue,
};
let layer_idx: u32 = match layer_str.parse() {
Ok(n) => n,
Err(_) => continue,
};
let (projection, ab) = match rest.rsplit_once('.') {
Some(p) => p,
None => continue,
};
let key = LoraKey::new(layer_idx, projection.to_string());
let layer = match layers.get(&key) {
Some(l) => l,
None => continue,
};
let (buf, target_packed_f16) = match ab {
"A" => (&layer.a, false),
"B" => (&layer.b, layer.b_is_f16),
_ => continue,
};
let data = tensor.data();
let n_elems = if target_packed_f16 {
(buf.size() / 2) as usize
} else {
(buf.size() / 4) as usize
};
let upload_bytes: Vec<u8> = match tensor.dtype() {
Dtype::F32 => {
if data.len() != n_elems * 4 {
return Err(RullamaError::Inference(format!(
"tensor {name} f32 size mismatch: file={} expected={}",
data.len(),
n_elems * 4
)));
}
if target_packed_f16 {
let src: &[f32] = bytemuck::cast_slice(data);
let packed = pack_f32_to_f16_pairs(src);
bytemuck::cast_slice::<u32, u8>(&packed).to_vec()
} else {
data.to_vec()
}
}
Dtype::F16 => {
if data.len() != n_elems * 2 {
return Err(RullamaError::Inference(format!(
"tensor {name} f16 size mismatch: file={} expected={}",
data.len(),
n_elems * 2
)));
}
if target_packed_f16 {
data.to_vec()
} else {
let h: &[half::f16] = bytemuck::cast_slice(data);
let f: Vec<f32> = h.iter().map(|&x| x.to_f32()).collect();
bytemuck::cast_slice::<f32, u8>(&f).to_vec()
}
}
other => {
return Err(RullamaError::Inference(format!(
"tensor {name} unsupported dtype {other:?}"
)));
}
};
ctx.queue.write_buffer(buf, 0, &upload_bytes);
}
Ok(Self { layers })
}
/// Returns one [`LayerLoraSlots`] per layer index in `0..n_layers`.
///
/// Each slot is `Some` when the adapter holds that projection for the layer
/// and `None` otherwise.
pub fn layer_slots(&self, n_layers: u32) -> Vec<LayerLoraSlots<'_>> {
    // Looks up one projection of one layer and borrows it as a slot view.
    let lookup = |li: u32, proj: &str| self.layers.get(&LoraKey::new(li, proj)).map(slot_view);

    let mut per_layer = Vec::with_capacity(n_layers as usize);
    for li in 0..n_layers {
        per_layer.push(LayerLoraSlots {
            q: lookup(li, "attn_q"),
            k: lookup(li, "attn_k"),
            v: lookup(li, "attn_v"),
            o: lookup(li, "attn_o"),
            ffn_gate: lookup(li, "ffn_gate"),
            ffn_up: lookup(li, "ffn_up"),
            ffn_down: lookup(li, "ffn_down"),
        });
    }
    per_layer
}
/// Returns slot views for the global targets, which are stored under layer 0.
///
/// A field is `None` when the adapter does not hold that target.
pub fn global_slots(&self) -> GlobalLoraSlots<'_> {
    let global = |proj: &str| self.layers.get(&LoraKey::new(0, proj)).map(slot_view);
    let embed_tokens = global("embed_tokens");
    let lm_head = global("lm_head");
    GlobalLoraSlots {
        embed_tokens,
        lm_head,
    }
}
pub fn len(&self) -> usize {
self.layers.len()
}
pub fn is_empty(&self) -> bool {
self.layers.is_empty()
}
}
impl InferenceLoraLayer {
    /// Allocates the GPU buffers for one LoRA layer without uploading any data.
    ///
    /// - `a` holds `in_dim * rank` 4-byte elements.
    /// - `b` holds `out_dim * rank` elements, 2 bytes each when `b_is_f16` is
    ///   set and 4 bytes each otherwise.
    /// - `z` holds `rank` 4-byte elements.
    ///
    /// The stored `scale` is `alpha / rank`. Callers should pass a non-zero
    /// `rank`; a zero rank gives a non-finite scale and empty buffers.
    fn alloc(
        ctx: &WgpuCtx,
        in_dim: u32,
        rank: u32,
        out_dim: u32,
        alpha: f32,
        b_is_f16: bool,
    ) -> Self {
        /// Byte length of a `rows x cols` matrix, computed in `u64` so the
        /// product cannot wrap on targets where `usize` is 32 bits.
        fn byte_len(rows: u32, cols: u32, elem_bytes: u64) -> u64 {
            // u32 * u32 always fits in u64; only the final factor can overflow.
            (u64::from(rows) * u64::from(cols)).saturating_mul(elem_bytes)
        }

        let scale = alpha / rank as f32;
        let device = &ctx.device;
        let usage = BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC;
        let make_buffer = |label: &'static str, size: u64| {
            device.create_buffer(&BufferDescriptor {
                label: Some(label),
                size,
                usage,
                mapped_at_creation: false,
            })
        };

        let b_elem_bytes: u64 = if b_is_f16 { 2 } else { 4 };
        let b_label = if b_is_f16 {
            "infer.lora.B.f16"
        } else {
            "infer.lora.B"
        };

        let a = make_buffer("infer.lora.A", byte_len(in_dim, rank, 4));
        let b = make_buffer(b_label, byte_len(out_dim, rank, b_elem_bytes));
        let z = make_buffer("infer.lora.z", u64::from(rank) * 4);

        Self {
            in_dim,
            rank,
            out_dim,
            scale,
            a,
            b,
            z,
            b_is_f16,
        }
    }
}
fn slot_view(l: &InferenceLoraLayer) -> LoraSlot<'_> {
LoraSlot {
a: &l.a,
b: &l.b,
z: &l.z,
rank: l.rank,
scale: l.scale,
b_is_f16: l.b_is_f16,
}
}
/// Converts consecutive pairs of `f32` values to `f16` and packs each pair into a `u32`.
///
/// The first value of a pair occupies the low 16 bits and the second the high
/// 16 bits. `src` is expected to have an even length (checked in debug builds
/// only); a trailing unpaired element is not emitted.
fn pack_f32_to_f16_pairs(src: &[f32]) -> Vec<u32> {
    debug_assert!(src.len().is_multiple_of(2));
    src.chunks_exact(2)
        .map(|pair| {
            let low = u32::from(half::f16::from_f32(pair[0]).to_bits());
            let high = u32::from(half::f16::from_f32(pair[1]).to_bits());
            low | (high << 16)
        })
        .collect()
}