#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use clap::{Parser, ValueEnum};
use std::io::Write;
use std::path::{Path, PathBuf};
use tokenizers::Tokenizer;
use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_lfm2::ModelWeights;
const DEFAULT_PROMPT: &str = "Explain how Rotary Position Embeddings work in transformers.";
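
/// The quantized LFM2 GGUF variants this example can fetch from the Hugging Face Hub.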
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "lfm2-350m-q4_k_m")]
    Lfm2_350MQ4KM,
    #[value(name = "lfm2-350m-q8_0")]
    Lfm2_350MQ8_0,
    #[value(name = "lfm2-2.6b-q4_k_m")]
    Lfm2_2_6BQ4KM,
    #[value(name = "lfm2-2.6b-q8_0")]
    Lfm2_2_6BQ8_0,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// Path to a local GGUF model file, downloaded from the hub if not set.
    #[arg(long)]
    model: Option<String>,

    /// The model variant to run.
    #[arg(long, default_value = "lfm2-2.6b-q4_k_m")]
    which: Which,

    /// The hub revision to download from.
    #[arg(long, default_value = "main")]
    revision: String,

    /// Path to a local tokenizer.json, resolved automatically if not set.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The initial prompt.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 512)]
    sample_len: usize,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process the prompt one token at a time instead of in a single batch.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,
}

impl Args {
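    /// Resolve the GGUF file: use `--model` when given, otherwise download the
    /// selected variant from the Hugging Face Hub.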
    fn model_path(&self) -> Result<PathBuf> {
        if let Some(model) = &self.model {
            return Ok(PathBuf::from(model));
        }
        let (repo, filename) = match self.which {
            Which::Lfm2_350MQ4KM => ("LiquidAI/LFM2-350M-GGUF", "LFM2-350M-Q4_K_M.gguf"),
            Which::Lfm2_350MQ8_0 => ("LiquidAI/LFM2-350M-GGUF", "LFM2-350M-Q8_0.gguf"),
            Which::Lfm2_2_6BQ4KM => ("LiquidAI/LFM2-2.6B-GGUF", "LFM2-2.6B-Q4_K_M.gguf"),
            Which::Lfm2_2_6BQ8_0 => ("LiquidAI/LFM2-2.6B-GGUF", "LFM2-2.6B-Q8_0.gguf"),
        };
        let api = hf_hub::api::sync::Api::new()?;
        api.repo(hf_hub::Repo::with_revision(
            repo.to_string(),
            hf_hub::RepoType::Model,
            self.revision.clone(),
        ))
        .get(filename)
        .map_err(Into::into)
    }
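
    /// Resolve the tokenizer: an explicit `--tokenizer` path, a tokenizer.json
    /// sitting next to the model file, or the matching repo on the hub.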
    fn tokenizer(&self, model_path: &Path) -> Result<Tokenizer> {
        if let Some(path) = &self.tokenizer {
            return Tokenizer::from_file(path).map_err(anyhow::Error::msg);
        }
        if let Some(dir) = model_path.parent() {
            let candidate = dir.join("tokenizer.json");
            if candidate.exists() {
                return Tokenizer::from_file(candidate).map_err(anyhow::Error::msg);
            }
        }
        let tokenizer_repo = match self.which {
            Which::Lfm2_350MQ4KM | Which::Lfm2_350MQ8_0 => "LiquidAI/LFM2-350M",
            Which::Lfm2_2_6BQ4KM | Which::Lfm2_2_6BQ8_0 => "LiquidAI/LFM2-2.6B",
        };
        let api = hf_hub::api::sync::Api::new()?;
        let tokenizer_path = api
            .repo(hf_hub::Repo::with_revision(
                tokenizer_repo.to_string(),
                hf_hub::RepoType::Model,
                self.revision.clone(),
            ))
            .get("tokenizer.json")?;
        Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
    }
}
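
/// Format a byte count using decimal units.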
fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{size_in_bytes}B")
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}
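
/// Guess the end-of-sequence token id by scanning the vocabulary for the markers
/// commonly used by chat and base models.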
fn guess_eos_id(tokenizer: &Tokenizer) -> Option<u32> {
    let vocab = tokenizer.get_vocab(true);
    let candidates = [
        "</s>",
        "<|im_end|>",
        "<|eot_id|>",
        "<|end|>",
        "<|end_of_text|>",
        "<|endoftext|>",
    ];
    candidates
        .iter()
        .find_map(|token| vocab.get(*token).copied())
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };
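
    // Report the SIMD features this build can use; they matter for quantized
    // CPU inference speed.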
    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );
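
    // Read the GGUF header and tally the on-disk tensor sizes; for quantized
    // dtypes this is elem_count / block_size blocks of type_size bytes each.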
    let model_path = args.model_path()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let gguf = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path.clone()))?;
    let mut total_size_in_bytes = 0;
    for (_, tensor) in gguf.tensor_infos.iter() {
        let elem_count = tensor.shape.elem_count();
        total_size_in_bytes +=
            elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
    }
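
    // The `lfm2.context_length` metadata, when present, gives the model's context
    // window; use it to clamp the prompt and to stop generation once the window is full.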
    let context_length = gguf
        .metadata
        .get("lfm2.context_length")
        .and_then(|v| v.to_u32().ok().map(|v| v as usize));
    println!(
        "loaded {:?} tensors ({}) in {:.2}s",
        gguf.tensor_infos.len(),
        format_size(total_size_in_bytes),
        start.elapsed().as_secs_f32()
    );
    let mut model = ModelWeights::from_gguf(gguf, &mut file, &device)?;
    println!("model ready");

    let tokenizer = args.tokenizer(&model_path)?;
    let mut tos = TokenOutputStream::new(tokenizer);
    let prompt_str = args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT);
    let mut tokens = tos
        .tokenizer()
        .encode(prompt_str, true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();
    if let Some(max_ctx) = context_length {
        if tokens.len() >= max_ctx {
            // Keep the most recent tokens, leaving room for at least one generated token.
            let trim = tokens.len() - max_ctx + 1;
            tokens.drain(0..trim);
            println!(
                "prompt trimmed to its last {} tokens to fit the context window",
                tokens.len()
            );
        }
    }
    let mut all_tokens = tokens.clone();
    let to_sample = args.sample_len.saturating_sub(1);
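
    // A temperature of 0 (or below) means greedy decoding; otherwise combine
    // temperature scaling with optional top-k and/or top-p truncation.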
    let mut logits_processor = {
        let temperature = args.temperature;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };
println!("Starting the inference loop:");
let prompt_str = args.prompt.as_deref().unwrap_or(DEFAULT_PROMPT);
print!("{prompt_str}");
std::io::stdout().flush()?;
let start_prompt_processing = std::time::Instant::now();
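    // Process the prompt either in one batched forward pass, or token by token
    // when --split-prompt is set (slower, but useful for debugging).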
    let mut next_token = if !args.split_prompt {
        let input = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, 0)?;
        let logits = logits.squeeze(0)?;
        logits_processor.sample(&logits)?
    } else {
        let mut next_token = 0;
        for (pos, token) in tokens.iter().enumerate() {
            let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, pos)?;
            let logits = logits.squeeze(0)?;
            next_token = logits_processor.sample(&logits)?
        }
        next_token
    };
    let mut index_pos = tokens.len();
    let prompt_dt = start_prompt_processing.elapsed();
    all_tokens.push(next_token);
    if let Some(t) = tos.next_token(next_token)? {
        print!("{t}");
        std::io::stdout().flush()?;
    }

    let eos_token = guess_eos_id(tos.tokenizer());
    let mut sampled = 0;
    let start_post_prompt = std::time::Instant::now();
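    // Autoregressive decoding: feed each sampled token back in, stopping at the
    // context window, on EOS, or after the requested number of tokens.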
    for _ in 0..to_sample {
        if let Some(max_ctx) = context_length {
            if index_pos + 1 > max_ctx {
                println!("\n\ncontext window of {max_ctx} reached, stopping generation");
                break;
            }
        }
        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
        let logits = model.forward(&input, index_pos)?;
        let logits = logits.squeeze(0)?;
        // Penalize tokens that already appeared in the last `repeat_last_n` positions.
        let logits = if args.repeat_penalty == 1. {
            logits
        } else {
            let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
            candle_transformers::utils::apply_repeat_penalty(
                &logits,
                args.repeat_penalty,
                &all_tokens[start_at..],
            )?
        };
        next_token = logits_processor.sample(&logits)?;
        index_pos += 1;
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }
        sampled += 1;
        if let Some(eos) = eos_token {
            if next_token == eos {
                break;
            }
        }
    }
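
    // Flush any pending bytes that have not yet formed a complete UTF-8 sequence.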
    if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
        print!("{rest}");
    }
    std::io::stdout().flush()?;

    let dt = start_post_prompt.elapsed();
    println!(
        "\n\n{:4} prompt tokens processed: {:.2} token/s",
        tokens.len(),
        tokens.len() as f64 / prompt_dt.as_secs_f64(),
    );
    println!(
        "{sampled:4} tokens generated: {:.2} token/s",
        sampled as f64 / dt.as_secs_f64(),
    );
    Ok(())
}