use std::collections::HashMap;
use rig::completion::{CompletionError, GetTokenUsage, Usage};
use rig::message::AssistantContent;
use rig::one_or_many::OneOrMany;
use rig::streaming::{RawStreamingChoice, RawStreamingToolCall};
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};
/// Final, non-streaming response payload produced by the backend.
///
/// `#[non_exhaustive]` so fields can be added without breaking downstream
/// constructors/matches.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct RawResponse {
    /// The complete generated text.
    pub text: String,
}
/// One incremental piece of a streaming response.
///
/// Token counters are optional because most chunks carry only text; usage
/// numbers typically arrive on the final chunk (see the `GetTokenUsage`
/// impl below). Absent counters are omitted from the serialized form.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamChunk {
    /// Text delta for this chunk (may be empty on a usage-only chunk).
    pub text: String,
    /// Prompt-side token count, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens: Option<u64>,
    /// Generated token count, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens: Option<u64>,
    // NOTE(review): presumably the number of prompt tokens served from a
    // prefix/KV cache — confirm against the worker that fills this in.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_input_tokens: Option<u64>,
}
impl GetTokenUsage for StreamChunk {
    /// Converts this chunk's raw counters into a rig [`Usage`].
    ///
    /// Returns `None` unless BOTH `prompt_tokens` and `completion_tokens`
    /// are present — a chunk carrying only `cached_input_tokens` reports no
    /// usage. A missing `cached_input_tokens` is reported as 0.
    fn token_usage(&self) -> Option<Usage> {
        let (input, output) = self.prompt_tokens.zip(self.completion_tokens)?;
        Some(Usage {
            input_tokens: input,
            output_tokens: output,
            // u64 addition wraps silently in release builds; saturate so a
            // corrupt counter can't produce a nonsense wrapped total.
            total_tokens: input.saturating_add(output),
            cached_input_tokens: self.cached_input_tokens.unwrap_or(0),
            cache_creation_input_tokens: 0,
        })
    }
}
/// Channel half used by the worker to push streaming chunks — or a terminal
/// `CompletionError` — back to the requester.
pub(crate) type StreamSender =
    mpsc::UnboundedSender<Result<RawStreamingChoice<StreamChunk>, CompletionError>>;
/// Where the worker should deliver the result of an inference request.
pub(crate) enum ResponseChannel {
    /// One-shot reply carrying the full completion (or an error string).
    Completion(oneshot::Sender<Result<InferenceResult, String>>),
    /// Unbounded channel fed incrementally as tokens are produced.
    Streaming(StreamSender),
}
/// Commands accepted by the inference worker loop.
pub(crate) enum InferenceCommand {
    /// Run a completion (streaming or one-shot, per the request's channel).
    Request(InferenceRequest),
    /// Replace the currently loaded model/context with new settings.
    Reload(ReloadRequest),
    /// Terminate the worker.
    Shutdown,
}
/// Everything the worker needs to (re)load a model.
pub(crate) struct ReloadRequest {
    /// Filesystem path of the model file to load.
    pub model_path: String,
    // NOTE(review): presumably a multimodal projector (mmproj) path used for
    // image inputs alongside the "mtmd" feature — confirm at the load site.
    pub mmproj_path: Option<String>,
    /// Requested context length in tokens.
    pub n_ctx: u32,
    /// Context-fitting settings (see [`FitParams`]).
    pub fit_params: FitParams,
    /// KV-cache quantization settings (see [`KvCacheParams`]).
    pub kv_cache_params: KvCacheParams,
    /// Checkpointing settings (see [`CheckpointParams`]).
    pub checkpoint_params: CheckpointParams,
    /// Synchronous channel on which the worker reports load success/failure.
    pub result_tx: std::sync::mpsc::Sender<Result<(), crate::error::LoadError>>,
}
/// A single inference job: its parameters plus where to send the result.
pub(crate) struct InferenceRequest {
    /// Generation settings and prepared prompt data.
    pub params: InferenceParams,
    /// One-shot or streaming reply channel.
    pub response_channel: ResponseChannel,
}
/// Per-request generation settings handed to the worker.
///
/// The sampling fields mirror [`SamplingParams`], flattened per request.
pub(crate) struct InferenceParams {
    /// Prompt/tools/schema already serialized for the backend.
    pub prepared_request: PreparedRequest,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    pub temperature: f32,
    pub top_p: f32,
    pub top_k: i32,
    pub min_p: f32,
    pub presence_penalty: f32,
    pub repetition_penalty: f32,
}
/// Completed (non-streaming) inference outcome returned to the caller.
pub(crate) struct InferenceResult {
    /// Raw generated text.
    pub text: String,
    /// Parsed assistant content (text and/or tool calls) in rig's shape.
    pub choice: OneOrMany<AssistantContent>,
    /// Number of prompt-side tokens consumed.
    pub prompt_tokens: u64,
    /// Number of tokens generated.
    pub completion_tokens: u64,
    // NOTE(review): presumably prompt tokens reused from a prefix/KV cache —
    // confirm against the worker that populates this.
    pub cached_input_tokens: u64,
}
/// Request payload pre-serialized into the JSON forms the backend consumes.
pub(crate) struct PreparedRequest {
    /// Chat messages, serialized as JSON.
    pub messages_json: String,
    /// Tool definitions as JSON, if the request declares any.
    pub tools_json: Option<String>,
    /// Tool-choice directive, if any.
    pub tool_choice: Option<String>,
    /// JSON schema constraining the output, if structured output is requested.
    pub json_schema: Option<String>,
    /// Whether "thinking"/reasoning output is enabled for this request.
    pub enable_thinking: bool,
    /// Decoded images accompanying the prompt (multimodal builds only).
    #[cfg(feature = "mtmd")]
    pub images: Vec<PreparedImage>,
}
/// An image attached to a multimodal request.
#[cfg(feature = "mtmd")]
#[derive(Clone, Debug)]
pub(crate) struct PreparedImage {
    /// Raw image bytes.
    pub bytes: Vec<u8>,
    // NOTE(review): presumably a content hash used to dedupe/cache image
    // embeddings — confirm where it is computed and consumed.
    pub hash: u64,
}
/// Output of rendering a request into a prompt string.
pub(crate) struct PromptBuildResult {
    /// The fully rendered prompt text.
    pub prompt: String,
    /// Upstream chat-template result, when the model's template was used
    /// (None presumably means a fallback/manual rendering — confirm).
    pub template_result: Option<llama_cpp_2::model::ChatTemplateResult>,
}
/// Token-sampling settings forwarded to the backend's sampler chain.
///
/// Obtain a baseline with [`Default`] and refine it through the `with_*`
/// builder methods, e.g. `SamplingParams::default().with_top_k(50)`.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct SamplingParams {
    /// Nucleus-sampling cutoff: smallest probability mass kept.
    pub top_p: f32,
    /// Number of highest-probability candidate tokens retained.
    pub top_k: i32,
    /// Minimum probability a candidate must have (min-p sampling).
    pub min_p: f32,
    /// Penalty for tokens that have already appeared in the output.
    pub presence_penalty: f32,
    /// Repetition penalty factor (1.0 disables it).
    pub repetition_penalty: f32,
}

impl Default for SamplingParams {
    /// Baseline: top_p 0.95, top_k 40, min_p off, no penalties.
    fn default() -> Self {
        SamplingParams {
            top_p: 0.95,
            top_k: 40,
            min_p: 0.0,
            presence_penalty: 0.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingParams {
    /// Returns a copy with the nucleus-sampling cutoff replaced.
    #[must_use]
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self { top_p, ..self }
    }

    /// Returns a copy with the top-k candidate count replaced.
    #[must_use]
    pub fn with_top_k(self, top_k: i32) -> Self {
        Self { top_k, ..self }
    }

    /// Returns a copy with the min-p threshold replaced.
    #[must_use]
    pub fn with_min_p(self, min_p: f32) -> Self {
        Self { min_p, ..self }
    }

    /// Returns a copy with the presence penalty replaced.
    #[must_use]
    pub fn with_presence_penalty(self, presence_penalty: f32) -> Self {
        Self { presence_penalty, ..self }
    }

    /// Returns a copy with the repetition penalty replaced.
    #[must_use]
    pub fn with_repetition_penalty(self, repetition_penalty: f32) -> Self {
        Self { repetition_penalty, ..self }
    }
}
/// Settings controlling how the context size is fitted at model-load time.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct FitParams {
    // NOTE(review): the unit/meaning of each entry is not visible here —
    // presumably per-device memory margins; confirm at the load site.
    pub margins: Option<Vec<usize>>,
    /// Smallest context length (tokens) the fit is allowed to shrink to.
    pub n_ctx_min: u32,
}

impl Default for FitParams {
    /// Baseline: no margins, minimum context of 4096 tokens.
    fn default() -> Self {
        FitParams { margins: None, n_ctx_min: 4096 }
    }
}

impl FitParams {
    /// Returns a copy with the margins replaced.
    #[must_use]
    pub fn with_margins(self, margins: Option<Vec<usize>>) -> Self {
        Self { margins, ..self }
    }

    /// Returns a copy with the minimum context length replaced.
    #[must_use]
    pub fn with_n_ctx_min(self, n_ctx_min: u32) -> Self {
        Self { n_ctx_min, ..self }
    }
}
/// Settings governing KV/state checkpoint creation during generation.
///
/// All thresholds are in tokens. Exact checkpoint semantics live in the
/// worker; this struct only carries the knobs.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct CheckpointParams {
    /// Maximum number of checkpoints kept at once.
    pub max_checkpoints: u32,
    // NOTE(review): i32 while its siblings are u32 — presumably to allow a
    // negative "disabled" sentinel; confirm at the consuming site.
    pub every_n_tokens: i32,
    /// Minimum tokens processed before the first checkpoint.
    pub min_tokens: u32,
    /// Minimum token gap between consecutive checkpoints.
    pub min_gap: u32,
}

impl Default for CheckpointParams {
    /// Baseline: up to 8 checkpoints, one every 8192 tokens, with a 64-token
    /// warm-up and a 64-token minimum gap.
    fn default() -> Self {
        CheckpointParams {
            max_checkpoints: 8,
            every_n_tokens: 8192,
            min_tokens: 64,
            min_gap: 64,
        }
    }
}

impl CheckpointParams {
    /// Returns a copy with the checkpoint cap replaced.
    #[must_use]
    pub fn with_max_checkpoints(self, max_checkpoints: u32) -> Self {
        Self { max_checkpoints, ..self }
    }

    /// Returns a copy with the checkpoint interval replaced.
    #[must_use]
    pub fn with_every_n_tokens(self, every_n_tokens: i32) -> Self {
        Self { every_n_tokens, ..self }
    }

    /// Returns a copy with the warm-up threshold replaced.
    #[must_use]
    pub fn with_min_tokens(self, min_tokens: u32) -> Self {
        Self { min_tokens, ..self }
    }

    /// Returns a copy with the minimum checkpoint gap replaced.
    #[must_use]
    pub fn with_min_gap(self, min_gap: u32) -> Self {
        Self { min_gap, ..self }
    }
}
/// Storage/quantization type for KV-cache tensors.
///
/// Mirrors `llama_cpp_2::context::params::KvCacheType` variant-for-variant
/// (see the `From` impl below); re-exported here so callers don't need a
/// direct dependency on the upstream crate. Variant names follow the
/// upstream GGML type names, hence `non_camel_case_types`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum KvCacheType {
    F32,
    F16,
    BF16,
    F64,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ2_S,
    IQ3_XXS,
    IQ3_S,
    IQ1_S,
    IQ1_M,
    IQ4_XS,
    IQ4_NL,
    I8,
    I16,
    I32,
    I64,
    TQ1_0,
    TQ2_0,
    MXFP4,
}
impl From<KvCacheType> for llama_cpp_2::context::params::KvCacheType {
fn from(value: KvCacheType) -> Self {
use llama_cpp_2::context::params::KvCacheType as Upstream;
match value {
KvCacheType::F32 => Upstream::F32,
KvCacheType::F16 => Upstream::F16,
KvCacheType::BF16 => Upstream::BF16,
KvCacheType::F64 => Upstream::F64,
KvCacheType::Q4_0 => Upstream::Q4_0,
KvCacheType::Q4_1 => Upstream::Q4_1,
KvCacheType::Q5_0 => Upstream::Q5_0,
KvCacheType::Q5_1 => Upstream::Q5_1,
KvCacheType::Q8_0 => Upstream::Q8_0,
KvCacheType::Q8_1 => Upstream::Q8_1,
KvCacheType::Q2_K => Upstream::Q2_K,
KvCacheType::Q3_K => Upstream::Q3_K,
KvCacheType::Q4_K => Upstream::Q4_K,
KvCacheType::Q5_K => Upstream::Q5_K,
KvCacheType::Q6_K => Upstream::Q6_K,
KvCacheType::Q8_K => Upstream::Q8_K,
KvCacheType::IQ2_XXS => Upstream::IQ2_XXS,
KvCacheType::IQ2_XS => Upstream::IQ2_XS,
KvCacheType::IQ2_S => Upstream::IQ2_S,
KvCacheType::IQ3_XXS => Upstream::IQ3_XXS,
KvCacheType::IQ3_S => Upstream::IQ3_S,
KvCacheType::IQ1_S => Upstream::IQ1_S,
KvCacheType::IQ1_M => Upstream::IQ1_M,
KvCacheType::IQ4_XS => Upstream::IQ4_XS,
KvCacheType::IQ4_NL => Upstream::IQ4_NL,
KvCacheType::I8 => Upstream::I8,
KvCacheType::I16 => Upstream::I16,
KvCacheType::I32 => Upstream::I32,
KvCacheType::I64 => Upstream::I64,
KvCacheType::TQ1_0 => Upstream::TQ1_0,
KvCacheType::TQ2_0 => Upstream::TQ2_0,
KvCacheType::MXFP4 => Upstream::MXFP4,
}
}
}
/// Quantization types for the two halves of the KV cache.
///
/// `type_k` and `type_v` can be set independently, e.g. to quantize values
/// more aggressively than keys.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct KvCacheParams {
    /// Storage type for the key tensors.
    pub type_k: KvCacheType,
    /// Storage type for the value tensors.
    pub type_v: KvCacheType,
}

impl Default for KvCacheParams {
    /// Baseline: F16 for both keys and values.
    fn default() -> Self {
        KvCacheParams {
            type_k: KvCacheType::F16,
            type_v: KvCacheType::F16,
        }
    }
}

impl KvCacheParams {
    /// Returns a copy with the key-tensor type replaced.
    #[must_use]
    pub fn with_type_k(self, type_k: KvCacheType) -> Self {
        Self { type_k, ..self }
    }

    /// Returns a copy with the value-tensor type replaced.
    #[must_use]
    pub fn with_type_v(self, type_v: KvCacheType) -> Self {
        Self { type_v, ..self }
    }
}
/// A built llama.cpp sampler plus a flag recording how it was constructed.
pub(crate) struct SamplerChain {
    /// The composed upstream sampler.
    pub sampler: llama_cpp_2::sampling::LlamaSampler,
    /// Whether a grammar constraint was included in the chain (presumably
    /// set when a JSON schema / tool grammar is active — confirm at build
    /// site).
    pub has_grammar: bool,
}
/// Mutable state accumulated while translating raw stream deltas.
pub(crate) struct StreamDeltaState {
    /// In-progress tool calls, keyed by u64 id (presumably the tool-call
    /// index within the response — confirm at the site that inserts).
    pub tool_calls: HashMap<u64, RawStreamingToolCall>,
}