#[cfg(feature = "python")]
use candle_core::{DType, IndexOp, Tensor, Var};
#[cfg(feature = "python")]
use pyo3::prelude::*;
#[cfg(feature = "python")]
use crate::model::{BitLlama, BitLlamaConfig};
#[cfg(feature = "python")]
use crate::optim::schedule_free::{ParamsScheduleFree, ScheduleFreeOptimizer};
#[cfg(feature = "python")]
use candle_nn::VarMap;
#[cfg(feature = "python")]
use tokenizers::Tokenizer;
#[cfg(feature = "python")]
#[pyclass(name = "BitLlama")]
/// Python-facing inference wrapper around a loaded `BitLlama` model.
/// Exposed to Python as `BitLlama`.
pub struct PyBitLlama {
    /// The underlying model, loaded from a safetensors checkpoint.
    inner: BitLlama,
    /// Per-layer recurrent state matrices, mutated in place by forward passes.
    w_states: Vec<Tensor>,
    /// Optional tokenizer; required only by the text-level `generate` method.
    tokenizer: Option<Tokenizer>,
}
#[cfg(feature = "python")]
#[pymethods]
impl PyBitLlama {
    /// Load a `BitLlama` from a safetensors checkpoint.
    ///
    /// * `device` — `"cpu"` (default) or `"cuda"`; anything else is a `ValueError`.
    /// * `tokenizer_path` — optional; required later only for `generate`.
    #[new]
    #[pyo3(signature = (config, checkpoint_path, device=None, tokenizer_path=None))]
    pub fn new(
        config: BitLlamaConfig,
        checkpoint_path: &str,
        device: Option<&str>,
        tokenizer_path: Option<&str>,
    ) -> PyResult<Self> {
        let device = match device {
            Some("cuda") => candle_core::Device::new_cuda(0).map_err(|e| {
                pyo3::exceptions::PyValueError::new_err(format!("CUDA error: {}", e))
            })?,
            Some("cpu") | None => candle_core::Device::Cpu,
            Some(unknown) => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Unsupported device: {}. Use 'cpu' or 'cuda'",
                    unknown
                )))
            }
        };
        // Fix: the checkpoint was previously always loaded onto the CPU even
        // when device="cuda" was requested (the selected device was computed
        // and then discarded). Load directly onto the selected device.
        //
        // SAFETY: from_mmaped_safetensors memory-maps the file; the mapping is
        // only sound if the checkpoint file is not modified while mapped.
        let vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(&[checkpoint_path], DType::F32, &device)
                .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))?
        };
        let mut model = BitLlama::load(config.clone(), vb)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        // Pre-pack weights once so per-token inference is fast.
        model
            .precompute_packed()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        // One zeroed (d_small x d_small) recurrent state per layer, allocated
        // on that layer's device. d_small = hidden_dim / 4 matches the model's
        // compressed state dimension.
        let d_small = config.hidden_dim / 4;
        let mut w_states = Vec::with_capacity(model.layers.len());
        for layer in &model.layers {
            let w = Tensor::zeros((d_small, d_small), DType::F32, layer.device())
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            w_states.push(w);
        }
        let tokenizer = tokenizer_path
            .map(|tok_path| {
                Tokenizer::from_file(tok_path).map_err(|e| {
                    pyo3::exceptions::PyValueError::new_err(format!(
                        "Failed to load tokenizer: {}",
                        e
                    ))
                })
            })
            .transpose()?;
        Ok(Self {
            inner: model,
            w_states,
            tokenizer,
        })
    }

    /// Run a single token through the model, updating the recurrent state in
    /// place, and return the raw next-token logits as a flat `Vec<f32>`.
    pub fn forward(&mut self, token_id: u32) -> PyResult<Vec<f32>> {
        let device = self.inner.embedding.embeddings().device();
        let input = Tensor::new(&[token_id], device)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        let logits = self
            .inner
            .forward_one(&input, &mut self.w_states)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        logits
            .flatten_all()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
            .to_vec1::<f32>()
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))
    }

    /// Encode `prompt`, greedily generate up to `max_tokens` new tokens, and
    /// decode only the newly generated tokens back to text.
    ///
    /// Raises `ValueError` if no tokenizer was provided at construction time
    /// or the prompt encodes to an empty token list.
    #[pyo3(signature = (prompt, max_tokens))]
    pub fn generate(&mut self, py: Python, prompt: &str, max_tokens: usize) -> PyResult<String> {
        let prompt_tokens = {
            let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(
                    "Tokenizer not available. Please provide tokenizer_path during initialization.",
                )
            })?;
            let encoding = tokenizer.encode(prompt, true).map_err(|e| {
                pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to encode prompt: {}", e))
            })?;
            let tokens: Vec<u32> = encoding.get_ids().to_vec();
            if tokens.is_empty() {
                return Err(pyo3::exceptions::PyValueError::new_err(
                    "Prompt encoded to empty token list",
                ));
            }
            tokens
        };
        let generated_tokens = self.generate_tokens(py, prompt_tokens.clone(), max_tokens)?;
        // The tokenizer borrow must end before generate_tokens (&mut self), so
        // it is re-acquired here for decoding.
        let tokenizer = self.tokenizer.as_ref().ok_or_else(|| {
            pyo3::exceptions::PyValueError::new_err(
                "Tokenizer not available. Please provide tokenizer_path during initialization.",
            )
        })?;
        // Keep only tokens beyond the prompt.
        let new_tokens: Vec<u32> = generated_tokens
            .iter()
            .skip(prompt_tokens.len())
            .copied()
            .collect();
        let decoded = tokenizer.decode(&new_tokens, true).map_err(|e| {
            pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to decode tokens: {}", e))
        })?;
        Ok(decoded)
    }

    /// Greedy decoding: process the full prompt once, then feed the argmax
    /// token back one step at a time. Returns prompt tokens followed by the
    /// generated tokens. The GIL is released for the whole generation.
    pub fn generate_tokens(
        &mut self,
        py: Python,
        start_tokens: Vec<u32>,
        max_new_tokens: usize,
    ) -> PyResult<Vec<u32>> {
        py.allow_threads(move || {
            let device = self.inner.embedding.embeddings().device().clone();
            let mut current_tokens = start_tokens.clone();
            // Fix: previously one token was always generated even when
            // max_new_tokens == 0; honor the requested count exactly.
            if max_new_tokens == 0 {
                return Ok(current_tokens);
            }
            // Prime the recurrent state with the entire prompt in one pass.
            let input = Tensor::new(start_tokens.as_slice(), &device)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
                .unsqueeze(0)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let logits = self
                .inner
                .forward(&input, &mut self.w_states)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let (_b, seq_len, _v) = logits
                .dims3()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            // First new token: argmax over the last position's logits.
            let last_logits = logits
                .i((0, seq_len - 1))
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let next_token = last_logits
                .argmax(0)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
                .to_scalar::<u32>()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            current_tokens.push(next_token);
            // Remaining tokens: single-step decode feeding back the last token.
            for _ in 1..max_new_tokens {
                let last_token = *current_tokens.last().unwrap();
                let input = Tensor::new(&[last_token], &device)
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
                let logits = self
                    .inner
                    .forward_one(&input, &mut self.w_states)
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
                let logits_v = logits
                    .flatten_all()
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
                let next_token = logits_v
                    .argmax(0)
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
                    .to_scalar::<u32>()
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
                current_tokens.push(next_token);
            }
            Ok(current_tokens)
        })
    }
}
#[cfg(feature = "python")]
#[pyclass(name = "PyTrainer")]
/// Python-facing training wrapper: a `BitLlama`, its trainable variables, and
/// a schedule-free optimizer.
pub struct PyTrainer {
    /// The model whose parameters are registered in `varmap`.
    model: BitLlama,
    /// Registry of trainable variables; used for checkpoint save/load.
    varmap: VarMap,
    /// Schedule-free optimizer; holds per-variable state (`z`) indexed in the
    /// same order as `sorted_vars`.
    optimizer: ScheduleFreeOptimizer,
    /// Variables in name-sorted order, matching the optimizer's slot order so
    /// gradients can be collected positionally.
    sorted_vars: Vec<Var>,
}
#[cfg(feature = "python")]
#[pymethods]
impl PyTrainer {
    /// Build a trainable `BitLlama` and a schedule-free optimizer.
    ///
    /// Variables are registered in a `VarMap` (optionally restored from
    /// `checkpoint_path`) and handed to the optimizer in name-sorted order, so
    /// gradient collection in `train_step` and z-state export in
    /// `save_checkpoint` agree on indexing.
    #[new]
    #[pyo3(signature = (config, checkpoint_path=None, device=None))]
    pub fn new(
        config: BitLlamaConfig,
        checkpoint_path: Option<&str>,
        device: Option<&str>,
    ) -> PyResult<Self> {
        let device = match device {
            Some("cuda") => candle_core::Device::new_cuda(0).map_err(|e| {
                pyo3::exceptions::PyValueError::new_err(format!("CUDA error: {}", e))
            })?,
            Some("cpu") | None => candle_core::Device::Cpu,
            Some(unknown) => {
                return Err(pyo3::exceptions::PyValueError::new_err(format!(
                    "Unsupported device: {}. Use 'cpu' or 'cuda'",
                    unknown
                )))
            }
        };
        let mut varmap = VarMap::new();
        let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
        let model = BitLlama::load(config, vb)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        if let Some(path) = checkpoint_path {
            varmap.load(path).map_err(|e| {
                pyo3::exceptions::PyRuntimeError::new_err(format!("Failed to load params: {}", e))
            })?;
        }
        // Snapshot the variables under the lock, then sort by name for a
        // deterministic order shared with train_step and save_checkpoint.
        let mut named_vars: Vec<(String, Var)> = {
            let data = varmap.data().lock().unwrap();
            data.iter().map(|(n, v)| (n.clone(), v.clone())).collect()
        };
        named_vars.sort_by(|a, b| a.0.cmp(&b.0));
        let sorted_vars: Vec<Var> = named_vars.into_iter().map(|(_, v)| v).collect();
        let params = ParamsScheduleFree {
            lr: 0.002,
            ..Default::default()
        };
        let optimizer = ScheduleFreeOptimizer::new(sorted_vars.clone(), params)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(Self {
            model,
            varmap,
            optimizer,
            sorted_vars,
        })
    }

    /// Override the optimizer's learning rate for subsequent steps.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.optimizer.set_learning_rate(lr);
    }

    /// Run one training step on a single (input, target) sequence pair and
    /// return the scalar cross-entropy loss. The GIL is released while the
    /// forward/backward/optimizer work runs.
    #[pyo3(signature = (py_input_ids, py_targets))]
    pub fn train_step(
        &mut self,
        py: Python,
        py_input_ids: Vec<u32>,
        py_targets: Vec<u32>,
    ) -> PyResult<f64> {
        // Fail fast with clear messages instead of opaque tensor-shape errors.
        if py_input_ids.is_empty() {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "py_input_ids must not be empty",
            ));
        }
        if py_input_ids.len() != py_targets.len() {
            return Err(pyo3::exceptions::PyValueError::new_err(format!(
                "input/target length mismatch: {} vs {}",
                py_input_ids.len(),
                py_targets.len()
            )));
        }
        py.allow_threads(move || {
            let device = self.model.embedding.embeddings().device().clone();
            // Shape both sequences as (batch=1, seq).
            let input_tensor = Tensor::new(py_input_ids.as_slice(), &device)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
                .unsqueeze(0)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let target_tensor = Tensor::new(py_targets.as_slice(), &device)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?
                .unsqueeze(0)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            self.optimizer
                .pre_step()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            // Fresh zeroed recurrent state for every step (one per layer).
            let d_small = self.model.config.hidden_dim / 4;
            let mut w_states = Vec::with_capacity(self.model.config.num_layers);
            for _ in 0..self.model.config.num_layers {
                let w = Tensor::zeros((d_small, d_small), DType::F32, &device)
                    .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
                w_states.push(w);
            }
            let seq_len = py_input_ids.len();
            let logits = self
                .model
                .forward_chunkwise(&input_tensor, &mut w_states, seq_len)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            // Flatten (batch=1, seq, vocab) -> (seq, vocab) for cross-entropy.
            let b_sz = 1;
            // Fix: propagate a Python exception instead of panicking via
            // unwrap() if the logits tensor has an unexpected rank.
            let vocab = logits
                .dim(2)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let logits = logits
                .reshape((b_sz * seq_len, vocab))
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let targets = target_tensor
                .reshape((b_sz * seq_len,))
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let loss = candle_nn::loss::cross_entropy(&logits, &targets)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let grads_store = loss
                .backward()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            // Collect gradients in the same name-sorted order the optimizer
            // received its variables in (see `new`).
            let mut grad_tensors = Vec::with_capacity(self.sorted_vars.len());
            for var in &self.sorted_vars {
                match grads_store.get(var) {
                    Some(g) => grad_tensors.push(g.clone()),
                    None => {
                        return Err(pyo3::exceptions::PyRuntimeError::new_err(
                            "Missing gradient for a variable. Graph disconnected?",
                        ))
                    }
                }
            }
            self.optimizer
                .step(&grad_tensors)
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            let loss_val = loss
                .to_scalar::<f32>()
                .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
            Ok(loss_val as f64)
        })
    }

    /// Save model parameters to `path` and the optimizer's z-state to
    /// `{path}.optim`, keyed as `{var_name}.z`.
    #[pyo3(signature = (path))]
    pub fn save_checkpoint(&self, path: &str) -> PyResult<()> {
        self.varmap
            .save(path)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        // Pair z-state with variables using the same name-sorted order as in
        // `new`, so positional indices line up on reload.
        let mut z_map = std::collections::HashMap::new();
        let data = self.varmap.data().lock().unwrap();
        let mut named_vars: Vec<_> = data.iter().collect();
        named_vars.sort_by(|a, b| a.0.cmp(b.0));
        for (i, (name, _var)) in named_vars.iter().enumerate() {
            if i < self.optimizer.z.len() {
                z_map.insert(format!("{}.z", name), self.optimizer.z[i].clone());
            }
        }
        let optim_path = format!("{}.optim", path);
        candle_core::safetensors::save(&z_map, &optim_path)
            .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
        Ok(())
    }
}