use std::{
io::IsTerminal,
marker::PhantomData,
num::NonZeroU32,
path::{Path, PathBuf},
str::FromStr,
sync::LazyLock,
};
use encoding_rs::{Decoder, UTF_8};
use hf_hub::api::tokio::ApiBuilder;
use llama_cpp_2::{
LlamaContextLoadError,
context::{LlamaContext, params::LlamaContextParams},
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{AddBos, LlamaChatTemplate, LlamaModel},
mtmd::{self, MtmdBitmap, MtmdContext, MtmdInputText},
sampling::LlamaSampler,
token::LlamaToken,
};
use strum::Display;
use crate::{
error::{CreateLlamaCppRunnerError, RunnerError},
sample::{LlguidanceSamplingParams, SimpleSamplingParams},
template::{ChatTemplate, ModelChatTemplate},
};
pub const QWEN_3D5_4B_GUFF_MODEL_ID: &str = "unsloth/Qwen3.5-4B-GGUF";
pub const QWEN_3D5_4B_GUFF_MODDEL_FILENAME: &str = "Qwen3.5-4B-Q4_K_M.gguf";
pub const QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME: &str = "mmproj-F16.gguf";
pub const GEMMA_3_1B_GUFF_MODEL_ID: &str = "google/gemma-3-1b-it-qat-q4_0-gguf";
pub const GEMMA_3_1B_GUFF_MODEL_FILENAME: &str = "gemma-3-1b-it-q4_0.gguf";
pub trait TextLmRunner<'s, 'req> {
type Response: Iterator<
Item = Result<String, RunnerError<<Self::Template as ChatTemplate>::Error>>,
>;
type Template: ChatTemplate;
fn stream_lm_response(
&'s self,
request: GenericTextLmRequest<'req, Self::Template>,
) -> Self::Response;
}
pub trait VisionLmRunner<'s, 'req> {
type Response: Iterator<
Item = Result<String, RunnerError<<Self::Template as ChatTemplate>::Error>>,
>;
type Template: ChatTemplate;
fn stream_vlm_response(
&'s self,
request: GenericVisionLmRequest<'req, Self::Template>,
) -> Self::Response;
}
#[derive(Debug, Clone)]
pub struct RunnerRequest<MsgCt, Tmpl> {
pub messages: Vec<(MessageRole, MsgCt)>,
pub sampling: SimpleSamplingParams,
pub llguidance: Option<LlguidanceSamplingParams>,
pub max_seq: usize,
pub prefill: Option<String>,
pub tmpl: Tmpl,
}
impl<M, T> Default for RunnerRequest<M, T>
where
T: Default,
{
fn default() -> Self {
Self {
messages: vec![],
sampling: Default::default(),
llguidance: None,
max_seq: usize::MAX,
prefill: None,
tmpl: Default::default(),
}
}
}
pub type GenericTextLmRequest<'a, Tmpl> = RunnerRequest<&'a str, Tmpl>;
pub type GenericVisionLmRequest<'a, Tmpl> = RunnerRequest<ImageOrText<'a>, Tmpl>;
pub type TextLmRequest<'a> = RunnerRequest<&'a str, ModelChatTemplate>;
pub type VisionLmRequest<'a> = RunnerRequest<ImageOrText<'a>, ModelChatTemplate>;
pub trait TextLmRunnerExt<'s, 'req, Tmpl, TmplErr> {
fn get_lm_response(
&'s self,
request: GenericTextLmRequest<'req, Tmpl>,
) -> Result<String, RunnerError<TmplErr>>;
}
pub trait VisionLmRunnerExt<'s, 'req, Tmpl, TmplErr> {
fn get_vlm_response(
&'s self,
request: GenericVisionLmRequest<'req, Tmpl>,
) -> Result<String, RunnerError<TmplErr>>;
}
impl<'s, 'req, TextRunner, Tmpl> TextLmRunnerExt<'s, 'req, Tmpl, Tmpl::Error> for TextRunner
where
Tmpl: ChatTemplate,
TextRunner: TextLmRunner<'s, 'req, Template = Tmpl>,
{
fn get_lm_response(
&'s self,
request: GenericTextLmRequest<'req, Tmpl>,
) -> Result<String, RunnerError<Tmpl::Error>> {
self.stream_lm_response(request)
.collect::<Result<String, _>>()
}
}
impl<'s, 'req, VisionRunner, Tmpl> VisionLmRunnerExt<'s, 'req, Tmpl, Tmpl::Error> for VisionRunner
where
Tmpl: ChatTemplate,
VisionRunner: VisionLmRunner<'s, 'req, Template = Tmpl>,
{
fn get_vlm_response(
&'s self,
request: GenericVisionLmRequest<'req, Tmpl>,
) -> Result<String, RunnerError<Tmpl::Error>> {
self.stream_vlm_response(request)
.collect::<Result<String, _>>()
}
}
#[derive(Debug, Clone, Display, PartialEq, Eq)]
pub enum MessageRole {
#[strum(to_string = "assistant")]
Assistant,
#[strum(to_string = "user")]
User,
#[strum(to_string = "system")]
System,
#[strum(to_string = "{0}")]
Custom(&'static str),
}
#[derive(Debug, Clone)]
pub enum ImageOrText<'a> {
Text(&'a str),
Image(&'a image::DynamicImage),
}
pub struct Gemma3TextRunner<Tmpl> {
model: LlamaModel,
llama_template: LlamaChatTemplate,
ctx_size: NonZeroU32,
_tmpl: PhantomData<Tmpl>,
}
impl<Tmpl> Gemma3TextRunner<Tmpl> {
pub async fn new(
model_id: impl ToString,
model_file: impl AsRef<str>,
ctx_size: NonZeroU32,
) -> Result<Self, CreateLlamaCppRunnerError> {
let repo = build_hf_api()?.model(model_id.to_string());
Self::from_file(repo.get(model_file.as_ref()).await?, ctx_size)
}
pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
{
let inner = Self::new(
GEMMA_3_1B_GUFF_MODEL_ID,
GEMMA_3_1B_GUFF_MODEL_FILENAME,
32_000.try_into().unwrap(),
)
.await?;
Ok(RunnerWithRecommendedSampling {
inner,
default_sampling: Self::recommend_sampling(),
})
}
pub fn recommend_sampling() -> SimpleSamplingParams {
SimpleSamplingParams {
top_p: Some(0.95f32),
top_k: Some(64),
temperature: Some(1f32),
..Default::default()
}
}
pub fn from_file(
model_file: impl AsRef<Path>,
ctx_size: NonZeroU32,
) -> Result<Self, CreateLlamaCppRunnerError> {
let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
let chat_template = model.chat_template(None)?;
Ok(Self {
model,
llama_template: chat_template,
ctx_size,
_tmpl: PhantomData,
})
}
}
impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req> for Gemma3TextRunner<Tmpl>
where
Tmpl: ChatTemplate + 's,
{
type Response = Gemma3Stream<'s, &'req str, Self, Tmpl>;
type Template = Tmpl;
fn stream_lm_response(&'s self, request: GenericTextLmRequest<'req, Tmpl>) -> Self::Response {
let ctx = self
.model
.new_context(
&LLAMA_BACKEND,
LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
)
.map_err(|err| RunnerError::from(err));
Gemma3Stream::new(ctx, request, self, &self.model)
}
}
pub struct Gemma3VisionRunner<Tmpl> {
model: LlamaModel,
chat_template: LlamaChatTemplate,
mtmd_ctx: MtmdContext,
ctx_size: NonZeroU32,
_tmpl: PhantomData<Tmpl>,
}
static LLAMA_BACKEND: LazyLock<LlamaBackend> = LazyLock::new(|| {
llama_cpp_2::send_logs_to_tracing(llama_cpp_2::LogOptions::default());
LlamaBackend::init().unwrap()
});
impl<Tmpl> Gemma3VisionRunner<Tmpl> {
pub async fn new(
repo_id: impl ToString,
model_file: impl AsRef<str>,
multimodel_file: impl AsRef<str>,
ctx_size: NonZeroU32,
) -> Result<Self, CreateLlamaCppRunnerError> {
let repo = build_hf_api()?.model(repo_id.to_string());
let model = LlamaModel::load_from_file(
&LLAMA_BACKEND,
repo.get(model_file.as_ref()).await?,
&Default::default(),
)?;
let mtmd_ctx = MtmdContext::init_from_file(
repo.get(multimodel_file.as_ref()).await?.to_str().unwrap(),
&model,
&Default::default(),
)?;
let chat_template = model.chat_template(None)?;
Ok(Self {
model,
mtmd_ctx,
chat_template,
ctx_size,
_tmpl: PhantomData,
})
}
pub async fn default() -> Result<RunnerWithRecommendedSampling<Self>, CreateLlamaCppRunnerError>
{
let inner = Self::new(
QWEN_3D5_4B_GUFF_MODEL_ID,
QWEN_3D5_4B_GUFF_MODDEL_FILENAME,
QWEN_3D5_4B_GUFF_MULTIMODEL_FILENAME,
16384u32.try_into().unwrap(),
)
.await?;
Ok(RunnerWithRecommendedSampling {
inner: inner,
default_sampling: SimpleSamplingParams {
top_p: Some(0.8f32),
top_k: Some(20),
temperature: Some(0.7f32),
presence_penalty: Some(1.5),
repetition_penalty: Some(1.0),
seed: None,
},
})
}
pub fn from_files(
model_file: impl AsRef<Path>,
multimodel_file: impl AsRef<Path>,
ctx_size: NonZeroU32,
) -> Result<Self, CreateLlamaCppRunnerError> {
let model = LlamaModel::load_from_file(&LLAMA_BACKEND, model_file, &Default::default())?;
let mtmd_ctx = MtmdContext::init_from_file(
multimodel_file.as_ref().as_os_str().to_str().unwrap(),
&model,
&Default::default(),
)?;
let chat_template = model.chat_template(None)?;
Ok(Self {
model,
mtmd_ctx,
chat_template,
ctx_size,
_tmpl: PhantomData,
})
}
fn new_context_window(&self) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
self.model.new_context(
&LLAMA_BACKEND,
LlamaContextParams::default().with_n_ctx(Some(self.ctx_size)),
)
}
}
impl<'s, 'req, Tmpl> VisionLmRunner<'s, 'req> for Gemma3VisionRunner<Tmpl>
where
Tmpl: ChatTemplate + 's,
{
type Response = Gemma3Stream<'s, ImageOrText<'req>, Self, Tmpl>;
type Template = Tmpl;
fn stream_vlm_response(
&'s self,
request: GenericVisionLmRequest<'req, Tmpl>,
) -> Self::Response {
let ctx = self
.new_context_window()
.map_err(|err| RunnerError::from(err));
Gemma3Stream::new(ctx, request, self, &self.model)
}
}
impl<'s, 'req, Tmpl> TextLmRunner<'s, 'req> for Gemma3VisionRunner<Tmpl>
where
Tmpl: ChatTemplate + 's,
{
type Response = <Self as VisionLmRunner<'s, 'req>>::Response;
type Template = <Self as VisionLmRunner<'s, 'req>>::Template;
fn stream_lm_response(&'s self, request: GenericTextLmRequest<'req, Tmpl>) -> Self::Response {
self.stream_vlm_response(request.into())
}
}
impl<'a, Tmpl> From<GenericTextLmRequest<'a, Tmpl>> for GenericVisionLmRequest<'a, Tmpl> {
fn from(value: GenericTextLmRequest<'a, Tmpl>) -> Self {
Self {
messages: value
.messages
.into_iter()
.map(|(role, text)| (role, ImageOrText::Text(text)))
.collect(),
sampling: value.sampling,
llguidance: value.llguidance,
max_seq: value.max_seq,
prefill: value.prefill,
tmpl: value.tmpl,
}
}
}
pub struct Gemma3Stream<'a, Message, Runner, Tmpl: ChatTemplate> {
ctx_source: Option<Result<LlamaContext<'a>, RunnerError<Tmpl::Error>>>,
ctx: Option<LlamaContext<'a>>,
req: RunnerRequest<Message, Tmpl>,
runner: &'a Runner,
model: &'a LlamaModel,
runtime: Option<Runtime<'a>>,
done: bool,
}
struct Runtime<'a> {
sampler: LlamaSampler,
decoder: Decoder,
batch: LlamaBatch<'a>,
n_past: i32,
step: usize,
}
trait PrepareRun<TmplErr> {
fn prepare(&mut self) -> Result<(), RunnerError<TmplErr>>;
}
impl<Tmpl> PrepareRun<Tmpl::Error>
for Gemma3Stream<'_, ImageOrText<'_>, Gemma3VisionRunner<Tmpl>, Tmpl>
where
Tmpl: ChatTemplate,
{
fn prepare(&mut self) -> Result<(), RunnerError<Tmpl::Error>> {
let media_marker = mtmd::mtmd_default_marker();
let messages = self
.req
.messages
.iter()
.fold(
Vec::<(MessageRole, String)>::new(),
|mut acc, (role, message)| {
let text = match message {
ImageOrText::Text(text) => text,
ImageOrText::Image(_) => media_marker,
};
if let Some(last) = acc.last()
&& last.0 == *role
{
let (_, adj) = acc.remove(acc.len() - 1);
acc.push((role.clone(), format!("{0}\n{text}", adj)));
acc
} else {
acc.push((role.clone(), text.to_string()));
acc
}
},
)
.into_iter()
.collect::<Vec<_>>();
log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
let formatted_prompt = self
.req
.tmpl
.apply_template(self.model, &self.runner.chat_template, &messages)
.map_err(RunnerError::ApplyChatTemplate)?;
let bitmaps = self
.req
.messages
.iter()
.filter_map(|msg| match &msg.1 {
ImageOrText::Image(image) => Some(image),
_ => None,
})
.enumerate()
.map(|(idx, im)| {
MtmdBitmap::from_image_data(
im.width(),
im.height(),
im.to_rgb8().to_vec().as_slice(),
)
.expect(format!("image#{} has corrupted RGB data", idx).as_str())
})
.collect::<Vec<_>>();
let bitmap_refs = bitmaps.iter().collect::<Vec<_>>();
let chunks = self.runner.mtmd_ctx.tokenize(
MtmdInputText {
text: formatted_prompt,
add_special: true,
parse_special: true,
},
&bitmap_refs,
)?;
log::debug!(target: "gemma", "tokenization resulted in {} chunks", chunks.len());
let n_past = chunks.eval_chunks(
&self.runner.mtmd_ctx,
self.ctx.as_ref().unwrap(),
0,
0,
1,
true,
)?;
let mut preparation = Runtime {
sampler: self.req.sampling.to_llama(),
decoder: UTF_8.new_decoder(),
batch: LlamaBatch::new(self.runner.ctx_size.get() as usize, 1),
n_past,
step: 0,
};
if let Some(llguidance) = &self.req.llguidance {
let llg_sampler = llguidance.to_llama(&self.runner.model)?;
preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
}
self.runtime = Some(preparation);
Ok(())
}
}
impl<S: AsRef<str>, Tmpl> PrepareRun<Tmpl::Error>
for Gemma3Stream<'_, S, Gemma3TextRunner<Tmpl>, Tmpl>
where
Tmpl: ChatTemplate,
{
fn prepare(&mut self) -> Result<(), RunnerError<Tmpl::Error>> {
let messages = self
.req
.messages
.iter()
.fold(
Vec::<(MessageRole, String)>::new(),
|mut acc, (role, message)| {
if let Some(last) = acc.last()
&& last.0 == *role
{
let (_, adj) = acc.remove(acc.len() - 1);
acc.push((role.clone(), format!("{0}\n{1}", adj, message.as_ref())));
acc
} else {
acc.push((role.clone(), message.as_ref().to_string()));
acc
}
},
)
.into_iter()
.collect::<Vec<_>>();
log::debug!(target: "gemma", "preprocessed messages: {messages:?}");
let formatted_prompt = self
.req
.tmpl
.apply_template(self.model, &self.runner.llama_template, &messages)
.map_err(RunnerError::ApplyChatTemplate)?;
let token_list = self.model.str_to_token(&formatted_prompt, AddBos::Always)?;
let mut batch = LlamaBatch::new(self.runner.ctx_size.get() as usize, 1);
let token_list_len = token_list.len();
for (i, token) in token_list.into_iter().enumerate() {
batch.add(token, i as i32, &[0], i == token_list_len - 1)?;
}
self.ctx.as_mut().unwrap().decode(&mut batch)?;
let mut preparation = Runtime {
sampler: self.req.sampling.to_llama(),
decoder: UTF_8.new_decoder(),
batch,
n_past: token_list_len as i32,
step: 0,
};
if let Some(llguidance) = &self.req.llguidance {
let llg_sampler = llguidance.to_llama(&self.runner.model)?;
preparation.sampler = LlamaSampler::chain_simple([llg_sampler, preparation.sampler]);
}
self.runtime = Some(preparation);
Ok(())
}
}
impl<'a, Message, Runner, Tmpl> Iterator for Gemma3Stream<'a, Message, Runner, Tmpl>
where
Tmpl: ChatTemplate,
Self: PrepareRun<Tmpl::Error>,
{
type Item = Result<String, RunnerError<Tmpl::Error>>;
fn next(&mut self) -> Option<Self::Item> {
if self.done {
return None;
}
if let Some(result) = self.ctx_source.take() {
match result {
Ok(ctx) => self.ctx = Some(ctx),
Err(err) => {
self.done = true;
return Some(Err(err));
}
}
}
if self.runtime.is_none()
&& let Err(err) = self.prepare()
{
self.done = true;
return Some(Err(err));
}
let Runtime {
sampler,
decoder,
batch,
n_past,
step,
} = self.runtime.as_mut().unwrap();
if *step >= self.req.max_seq {
self.done = true;
return None;
}
let ctx = self.ctx.as_mut().unwrap();
let model = self.model;
let sample_idx = batch.n_tokens() - 1;
let mut sample = |token: LlamaToken,
sampler: &mut LlamaSampler,
ctx: &mut LlamaContext<'a>,
step: usize|
-> Result<Option<String>, RunnerError<Tmpl::Error>> {
sampler.accept(token);
if model.is_eog_token(token) {
return Ok(None);
}
batch.clear();
batch.add(token, *n_past + (step as i32), &[0], true)?;
ctx.decode(batch)?;
let piece = model.token_to_piece(token, decoder, true, None)?;
Ok(Some(piece))
};
if let Some(prefill) = self.req.prefill.take() {
log::debug!(target: "gemma", "prefill: {}", prefill);
let tokens = match model.str_to_token(&prefill, AddBos::Never) {
Ok(tokens) => tokens,
Err(err) => {
return Some(Err(err.into()));
}
};
log::debug!(target: "gemma", "prefill tokens: {:?}", tokens.iter().map(|t| t.0).collect::<Vec<_>>());
for token in tokens {
match sample(token, sampler, ctx, *step) {
Ok(_) => {}
Err(err) => return Some(Err(err.into())),
}
*step += 1;
}
Some(Ok(prefill))
} else {
let token = sampler.sample(ctx, sample_idx);
match sample(token, sampler, ctx, *step) {
Ok(Some(piece)) => {
*step += 1;
return Some(Ok(piece));
}
Ok(None) => {
self.done = true;
return None;
}
Err(err) => {
self.done = true;
return Some(Err(err));
}
}
}
}
}
impl<'s, Message, Runner, Tmpl> Gemma3Stream<'s, Message, Runner, Tmpl>
where
Tmpl: ChatTemplate,
{
fn new(
source: Result<LlamaContext<'s>, RunnerError<Tmpl::Error>>,
req: RunnerRequest<Message, Tmpl>,
runner: &'s Runner,
model: &'s LlamaModel,
) -> Self {
Self {
ctx_source: Some(source),
ctx: None,
req,
runner,
model,
runtime: None,
done: false,
}
}
}
pub struct RunnerWithRecommendedSampling<Inner> {
pub inner: Inner,
pub default_sampling: SimpleSamplingParams,
}
impl<'a, Inner> RunnerWithRecommendedSampling<Inner> {
fn get_preprocessed_simple_sampling(
&self,
sampling: SimpleSamplingParams,
) -> SimpleSamplingParams {
let mut sampling = sampling;
if sampling.top_k.is_none() {
sampling.top_k = self.default_sampling.top_k;
}
if sampling.top_p.is_none() {
sampling.top_p = self.default_sampling.top_p;
}
if sampling.temperature.is_none() {
sampling.temperature = self.default_sampling.temperature;
}
sampling
}
}
impl<'s, 'req, Inner, Tmpl> VisionLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
where
Tmpl: ChatTemplate,
Inner: VisionLmRunner<'s, 'req, Template = Tmpl>,
{
type Response = <Inner as VisionLmRunner<'s, 'req>>::Response;
type Template = Tmpl;
fn stream_vlm_response(
&'s self,
mut request: GenericVisionLmRequest<'req, Tmpl>,
) -> Self::Response {
request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
self.inner.stream_vlm_response(request)
}
}
impl<'s, 'req, Inner, Tmpl> TextLmRunner<'s, 'req> for RunnerWithRecommendedSampling<Inner>
where
Tmpl: ChatTemplate,
Inner: TextLmRunner<'s, 'req, Template = Tmpl>,
{
type Response = <Inner as TextLmRunner<'s, 'req>>::Response;
type Template = Tmpl;
fn stream_lm_response(
&'s self,
mut request: GenericTextLmRequest<'req, Tmpl>,
) -> Self::Response {
request.sampling = self.get_preprocessed_simple_sampling(request.sampling);
self.inner.stream_lm_response(request)
}
}
impl<Inner> From<Inner> for RunnerWithRecommendedSampling<Inner> {
fn from(value: Inner) -> Self {
Self {
inner: value,
default_sampling: SimpleSamplingParams::default(),
}
}
}
fn build_hf_api() -> Result<hf_hub::api::tokio::Api, hf_hub::api::tokio::ApiError> {
let mut api = ApiBuilder::new()
.with_progress(std::io::stdin().is_terminal())
.with_token(std::env::var("HF_TOKEN").ok())
.with_chunk_size(Some(2 << 28));
if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
api = api.with_endpoint(endpoint);
}
if let Ok(cache) = std::env::var("HF_HOME") {
api = api.with_cache_dir(
PathBuf::from_str(&cache).expect("HF_HOME env var is not a valid path"),
);
}
api.build()
}