use std::path::{Path, PathBuf};
use crate::Device;
/// Common interface implemented by language-model backends.
///
/// Implementors supply [`family`](LmRunner::family),
/// [`vocab_size`](LmRunner::vocab_size) and
/// [`predict_logits`](LmRunner::predict_logits). A greedy
/// [`generate`](LmRunner::generate) and a multimodal stub that always errors
/// are provided as defaults.
pub trait LmRunner: Send {
    /// Identifier of the model family this runner implements.
    fn family(&self) -> &'static str;

    /// Vocabulary size reported by the implementation.
    fn vocab_size(&self) -> usize;

    /// Scores candidate next tokens given the token ids in `prompt_ids`.
    ///
    /// The default [`generate`](LmRunner::generate) treats index `i` of the
    /// returned vector as the score of token id `i`, and picks the largest.
    fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>>;

    /// Greedily decodes up to `n_new` tokens following `prompt_ids`.
    ///
    /// Each step re-runs [`predict_logits`](LmRunner::predict_logits) over the
    /// whole context (the prompt plus every token produced so far) and picks
    /// the highest-scoring token id. Ties go to the lowest id, and NaN scores
    /// never win.
    ///
    /// `on_token` is called with each new token after it has been appended
    /// to the output. Returning `false` stops generation, but the token just
    /// passed to it is still part of the returned vector.
    ///
    /// # Errors
    ///
    /// Propagates any error from `predict_logits`. Also fails if
    /// `predict_logits` returns an empty vector, because there is no token to
    /// pick.
    fn generate(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        on_token: &mut dyn FnMut(u32) -> bool,
    ) -> anyhow::Result<Vec<u32>> {
        // Cap the up-front reservation. A caller may pass a very large `n_new`
        // and rely on `on_token` to stop early, and `Vec::with_capacity`
        // panics when the requested capacity exceeds `isize::MAX` bytes.
        const MAX_RESERVE: usize = 4096;

        let mut context: Vec<u32> = prompt_ids.to_vec();
        let mut produced: Vec<u32> = Vec::with_capacity(n_new.min(MAX_RESERVE));
        for _ in 0..n_new {
            let logits = self.predict_logits(&context)?;
            // An empty score vector has no argmax. Without this check
            // `argmax_u32` would report token 0 and decoding would silently
            // emit garbage.
            if logits.is_empty() {
                return Err(anyhow::anyhow!(
                    "{} runner returned empty logits for a context of {} tokens",
                    self.family(),
                    context.len()
                ));
            }
            let next = argmax_u32(&logits);
            produced.push(next);
            let cont = on_token(next);
            context.push(next);
            if !cont {
                break;
            }
        }
        Ok(produced)
    }

    /// Whether this runner overrides
    /// [`generate_multimodal`](LmRunner::generate_multimodal).
    /// Returns `false` by default.
    fn supports_multimodal(&self) -> bool {
        false
    }

    /// Generates text conditioned on a text prompt and an image.
    ///
    /// The default implementation always returns an error.
    /// NOTE(review): `_rgb` presumably holds `img_w * img_h` packed RGB8
    /// pixels, and `_tokenizer` an optional tokenizer file path. Confirm
    /// against the overriding backends.
    ///
    /// # Errors
    ///
    /// Always fails unless overridden.
    fn generate_multimodal(
        &mut self,
        _prompt: &str,
        _rgb: &[u8],
        _img_w: usize,
        _img_h: usize,
        _tokenizer: Option<&Path>,
        _n_new: usize,
        _on_token: &mut dyn FnMut(u32) -> bool,
    ) -> anyhow::Result<Vec<u32>> {
        Err(anyhow::anyhow!(
            "this LmRunner does not support multimodal generation"
        ))
    }
}
/// Index (as `u32`) of the first element of `logits` that is strictly greater
/// than everything before it. Comparison starts from `f32::NEG_INFINITY`.
///
/// Consequences:
/// - ties resolve to the earliest index;
/// - NaN entries never win, because comparisons with NaN are false;
/// - an empty slice, or one containing only NaN / `-inf`, yields `0`.
///
/// The final `as u32` would truncate indices above `u32::MAX`. A vocabulary
/// that large is not expected.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_idx, best_val), (idx, val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    winner as u32
}
/// On-disk format of a model's weights file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightFormat {
    /// A `.safetensors` file.
    Safetensors,
    /// A `.gguf` file.
    Gguf,
}
impl WeightFormat {
    /// Infers the format from `path`'s file extension, ignoring ASCII case
    /// (`model.GGUF` is accepted as well as `model.gguf`).
    ///
    /// # Errors
    ///
    /// Fails when the path has no extension, the extension is not valid
    /// UTF-8, or it names neither `safetensors` nor `gguf`.
    pub fn from_path(path: &Path) -> anyhow::Result<Self> {
        let ext = path.extension().and_then(|s| s.to_str());
        ext.and_then(Self::from_name_ignore_case).ok_or_else(|| {
            anyhow::anyhow!(
                "cannot autodetect weight format from extension {:?} on {:?}",
                ext,
                path
            )
        })
    }

    /// Parses a format name (`safetensors` or `gguf`), ignoring ASCII case.
    ///
    /// # Errors
    ///
    /// Fails for any other name. The error message echoes the input.
    pub fn parse(s: &str) -> anyhow::Result<Self> {
        Self::from_name_ignore_case(s)
            .ok_or_else(|| anyhow::anyhow!("expected safetensors|gguf, got {s}"))
    }

    /// Case-insensitive name lookup shared by `from_path` and `parse`.
    fn from_name_ignore_case(name: &str) -> Option<Self> {
        if name.eq_ignore_ascii_case("safetensors") {
            Some(Self::Safetensors)
        } else if name.eq_ignore_ascii_case("gguf") {
            Some(Self::Gguf)
        } else {
            None
        }
    }
}
/// Where a runner's model configuration comes from.
///
/// The default is [`ConfigSource::Embedded`].
#[derive(Debug, Clone, Default)]
pub enum ConfigSource<T> {
/// Let the runner supply its own configuration.
/// NOTE(review): presumably read from the weights file or compiled in. Confirm per runner.
#[default]
Embedded,
/// Path to a JSON configuration file. Nothing in this file reads or parses it.
JsonFile(PathBuf),
/// An already-constructed configuration value (see `LmRunnerBuilder::config_value`).
Explicit(T),
}
/// Token-sampling settings carried by `LmRunnerBuilder`.
///
/// Use [`SampleOpts::greedy`] (also the [`Default`]) or
/// [`SampleOpts::nucleus`] as starting points.
#[derive(Debug, Clone, Copy)]
pub struct SampleOpts {
    /// Sampling temperature. Values `<= 0.0`, or NaN, mean greedy decoding
    /// (see [`SampleOpts::is_greedy`]).
    pub temperature: f32,
    /// Nucleus (top-p) probability mass. Both presets start from the value
    /// listed on each constructor.
    pub top_p: f32,
    /// Optional top-k cutoff. `None` in both presets.
    pub top_k: Option<u32>,
    /// Repetition penalty. `1.0` in both presets
    /// (NOTE(review): presumably "no penalty"; confirm with the sampler).
    pub repetition_penalty: f32,
}
impl Default for SampleOpts {
    /// Same as [`SampleOpts::greedy`].
    fn default() -> Self {
        Self::greedy()
    }
}
impl SampleOpts {
    /// Greedy preset: temperature `0.0`, `top_p` `1.0`, no top-k,
    /// repetition penalty `1.0`.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_p: 1.0,
            top_k: None,
            repetition_penalty: 1.0,
        }
    }

    /// Nucleus-sampling preset with the given `temperature` and `top_p`.
    /// There is no top-k and the repetition penalty is `1.0`.
    /// Inputs are stored as given, without validation.
    pub fn nucleus(temperature: f32, top_p: f32) -> Self {
        Self {
            temperature,
            top_p,
            top_k: None,
            repetition_penalty: 1.0,
        }
    }

    /// `true` when these settings request greedy decoding: the temperature is
    /// non-positive or NaN.
    ///
    /// A NaN temperature is treated as greedy. Scaling logits by NaN would
    /// otherwise poison every probability in the sampling path.
    pub fn is_greedy(&self) -> bool {
        self.temperature.is_nan() || self.temperature <= 0.0
    }
}
/// Size cutoff (256 MiB) for automatic packed-weights mode.
///
/// When the builder's `packed_weights` is unset,
/// `LmRunnerBuilder::resolved_packed` reports `true` for GGUF weight files
/// whose on-disk size is at least this many bytes.
pub const PACKED_GGUF_AUTO_THRESHOLD_BYTES: u64 = 256 << 20;
/// Builder collecting the settings used to construct a runner.
///
/// `Cfg` is the model-specific configuration type, used by [`ConfigSource::Explicit`].
#[derive(Debug, Clone)]
pub struct LmRunnerBuilder<Cfg> {
/// Path to the weights file. It is also used to infer the format and packed mode.
pub weights: Option<PathBuf>,
/// Source of the model configuration (`ConfigSource::Embedded` by default).
pub config: ConfigSource<Cfg>,
/// Device to run on (`Device::Cpu` by default).
pub device: Device,
/// Maximum-sequence-length setting (`128` by default). How it is enforced is up to the runner.
pub max_seq: usize,
/// Optional memory budget in GB. It is not interpreted in this file.
pub max_memory_gb: Option<f32>,
/// Streaming flag (`true` by default). It is not interpreted in this file.
pub stream: bool,
/// Sampling settings (greedy by default).
pub sample: SampleOpts,
/// Explicit weight format. When `None`, `resolved_format` infers it from `weights`.
pub format: Option<WeightFormat>,
/// Explicit packed-weights choice. When `None`, `resolved_packed` decides from format and file size.
pub packed_weights: Option<bool>,
/// Free-form GGUF preference string (NOTE(review): presumably a quantization name; confirm).
pub prefer_gguf: Option<String>,
}
impl<Cfg> Default for LmRunnerBuilder<Cfg> {
    /// Default preset:
    /// - CPU device, 128-token `max_seq`, streaming on;
    /// - greedy sampling and embedded config;
    /// - no weights path, no memory budget and no GGUF preference;
    /// - format and packed mode left to be resolved later.
    fn default() -> Self {
        let sample = SampleOpts::greedy();
        Self {
            device: Device::Cpu,
            config: ConfigSource::Embedded,
            max_seq: 128,
            stream: true,
            sample,
            weights: None,
            format: None,
            packed_weights: None,
            max_memory_gb: None,
            prefer_gguf: None,
        }
    }
}
impl<Cfg> LmRunnerBuilder<Cfg> {
    /// Creates a builder with the [`Default`] settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the weights file path.
    pub fn weights<P: Into<PathBuf>>(self, p: P) -> Self {
        Self {
            weights: Some(p.into()),
            ..self
        }
    }

    /// Sets where the model configuration comes from.
    pub fn config(self, src: ConfigSource<Cfg>) -> Self {
        Self { config: src, ..self }
    }

    /// Shorthand for `config(ConfigSource::Explicit(cfg))`.
    pub fn config_value(self, cfg: Cfg) -> Self {
        Self {
            config: ConfigSource::Explicit(cfg),
            ..self
        }
    }

    /// Sets the target device.
    pub fn device(self, d: Device) -> Self {
        Self { device: d, ..self }
    }

    /// Sets the maximum-sequence-length setting.
    pub fn max_seq(self, n: usize) -> Self {
        Self { max_seq: n, ..self }
    }

    /// Sets the memory budget in GB.
    pub fn max_memory_gb(self, gb: f32) -> Self {
        Self {
            max_memory_gb: Some(gb),
            ..self
        }
    }

    /// Turns streaming on or off.
    pub fn stream(self, on: bool) -> Self {
        Self { stream: on, ..self }
    }

    /// Sets the sampling options.
    pub fn sample(self, s: SampleOpts) -> Self {
        Self { sample: s, ..self }
    }

    /// Forces a weight format instead of inferring it from the path.
    pub fn format(self, fmt: WeightFormat) -> Self {
        Self {
            format: Some(fmt),
            ..self
        }
    }

    /// Forces packed-weights mode on or off instead of the size heuristic.
    pub fn packed_weights(self, on: bool) -> Self {
        Self {
            packed_weights: Some(on),
            ..self
        }
    }

    /// Records a GGUF preference string.
    pub fn prefer_gguf<S: Into<String>>(self, q: S) -> Self {
        Self {
            prefer_gguf: Some(q.into()),
            ..self
        }
    }

    /// Returns the explicit format if one was set. Otherwise it infers the
    /// format from the weights path's extension.
    ///
    /// # Errors
    ///
    /// Fails if no format was set and no weights path is present, or if the
    /// extension is not recognised.
    pub fn resolved_format(&self) -> anyhow::Result<WeightFormat> {
        if let Some(explicit) = self.format {
            return Ok(explicit);
        }
        match self.weights.as_deref() {
            Some(path) => WeightFormat::from_path(path),
            None => Err(anyhow::anyhow!("weights path required")),
        }
    }

    /// Decides whether to use packed weights for `fmt`.
    ///
    /// Order of precedence:
    /// 1. An explicit `packed_weights` setting always wins.
    /// 2. Otherwise only GGUF can be packed automatically.
    /// 3. For GGUF, packed mode is chosen when the weights file exists and is
    ///    at least `PACKED_GGUF_AUTO_THRESHOLD_BYTES` long.
    ///
    /// A missing path or unreadable metadata yields `false`.
    pub fn resolved_packed(&self, fmt: WeightFormat) -> bool {
        if let Some(explicit) = self.packed_weights {
            return explicit;
        }
        if fmt != WeightFormat::Gguf {
            return false;
        }
        match self.weights.as_deref().map(std::fs::metadata) {
            Some(Ok(meta)) => meta.len() >= PACKED_GGUF_AUTO_THRESHOLD_BYTES,
            _ => false,
        }
    }
}
/// Static descriptor for a model family that [`auto_runner_name`] can select.
pub struct ModelRegistration {
/// Family name returned by [`auto_runner_name`] when `matches` accepts.
pub family: &'static str,
/// Free-form description. It is not read in this file.
pub description: &'static str,
/// Predicate over an architecture string and a weights path.
/// [`auto_runner_name`] passes `arch` already ASCII-lowercased.
pub matches: fn(arch: &str, path: &Path) -> bool,
}
// Declares `ModelRegistration` as an `inventory` collection. `registered_models` iterates its entries.
inventory::collect!(ModelRegistration);
// Re-exports `inventory` through this crate.
// NOTE(review): presumably so backends can call `inventory::submit!` without their own dependency. Confirm.
pub extern crate inventory;
/// Returns an iterator over every [`ModelRegistration`] collected by `inventory`.
pub fn registered_models() -> impl Iterator<Item = &'static ModelRegistration> {
inventory::iter::<ModelRegistration>.into_iter()
}
/// Picks a runner family for an architecture string and weights path.
///
/// `arch` is ASCII-lowercased before being handed to each registration's
/// `matches` predicate. The first registration that accepts determines the
/// result, and `None` means no registration matched.
///
/// NOTE(review): if several registrations can match, the winner depends on
/// the order in which `inventory` yields them, and `inventory` does not
/// promise a stable order.
pub fn auto_runner_name(arch: &str, path: &Path) -> Option<&'static str> {
    let normalized = arch.to_ascii_lowercase();
    for reg in registered_models() {
        if (reg.matches)(&normalized, path) {
            return Some(reg.family);
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_source_default_is_embedded() {
        let s: ConfigSource<()> = ConfigSource::default();
        assert!(matches!(s, ConfigSource::Embedded));
    }

    #[test]
    fn builder_defaults_match_legacy_runners() {
        let b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert_eq!(b.device, Device::Cpu);
        assert_eq!(b.max_seq, 128);
        assert!(b.stream);
        assert!(b.sample.is_greedy());
        assert!(b.packed_weights.is_none());
    }

    #[test]
    fn packed_auto_size_threshold() {
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        b.weights = Some("/nonexistent/file.gguf".into());
        assert!(!b.resolved_packed(WeightFormat::Gguf));
        b.packed_weights = Some(true);
        assert!(b.resolved_packed(WeightFormat::Gguf));
        b.packed_weights = None;
        assert!(!b.resolved_packed(WeightFormat::Safetensors));
    }

    // Greedy decoding depends on these exact tie/NaN/empty rules.
    #[test]
    fn argmax_prefers_first_max_and_skips_nan() {
        assert_eq!(argmax_u32(&[]), 0);
        assert_eq!(argmax_u32(&[0.1, 2.5, -1.0]), 1);
        assert_eq!(argmax_u32(&[3.0, 3.0, 1.0]), 0);
        assert_eq!(argmax_u32(&[f32::NAN, -2.0]), 1);
        assert_eq!(argmax_u32(&[f32::NEG_INFINITY, f32::NEG_INFINITY]), 0);
    }

    #[test]
    fn weight_format_from_known_extensions() {
        assert_eq!(
            WeightFormat::from_path(Path::new("model.gguf")).unwrap(),
            WeightFormat::Gguf
        );
        assert_eq!(
            WeightFormat::from_path(Path::new("/a/b/model.safetensors")).unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::from_path(Path::new("model.bin")).is_err());
        assert!(WeightFormat::from_path(Path::new("model")).is_err());
    }

    #[test]
    fn weight_format_parse_names() {
        assert_eq!(WeightFormat::parse("gguf").unwrap(), WeightFormat::Gguf);
        assert_eq!(
            WeightFormat::parse("safetensors").unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::parse("onnx").is_err());
    }

    #[test]
    fn resolved_format_prefers_explicit_then_extension() {
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert!(b.resolved_format().is_err());
        b.weights = Some("weights/x.gguf".into());
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Gguf);
        b = b.format(WeightFormat::Safetensors);
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Safetensors);
    }
}