use std::fmt::Debug;
use std::num::NonZeroU32;
/// RoPE (rotary position embedding) frequency-scaling strategy.
///
/// The `#[repr(i8)]` discriminants are declared to match the integer codes
/// used by `llama.cpp`, so conversion in either direction is lossless for
/// the named variants.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// No explicit choice; the backend picks its own default.
    Unspecified = -1,
    /// Scaling disabled.
    None = 0,
    /// Linear scaling.
    Linear = 1,
    /// YaRN scaling.
    Yarn = 2,
}
impl From<i32> for RopeScalingType {
    /// Decodes a raw code; anything outside the known range collapses to
    /// [`RopeScalingType::Unspecified`].
    fn from(raw: i32) -> Self {
        match raw {
            2 => Self::Yarn,
            1 => Self::Linear,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}
impl From<RopeScalingType> for i32 {
    /// Encodes the variant back into its raw integer code.
    fn from(kind: RopeScalingType) -> Self {
        // The declared discriminants equal the raw codes (-1..=2), so a
        // plain numeric cast is equivalent to an exhaustive match.
        kind as i32
    }
}
/// Pooling strategy applied when extracting embeddings.
///
/// The `#[repr(i8)]` discriminants are declared to match the integer codes
/// used by `llama.cpp`, so conversion in either direction is lossless for
/// the named variants.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// No explicit choice; the backend picks its own default.
    Unspecified = -1,
    /// No pooling.
    None = 0,
    /// Mean pooling.
    Mean = 1,
    /// CLS-token pooling.
    Cls = 2,
    /// Last-token pooling.
    Last = 3,
    /// Rank pooling.
    Rank = 4,
}
impl From<i32> for LlamaPoolingType {
    /// Decodes a raw code; anything outside the known range collapses to
    /// [`LlamaPoolingType::Unspecified`].
    fn from(raw: i32) -> Self {
        match raw {
            4 => Self::Rank,
            3 => Self::Last,
            2 => Self::Cls,
            1 => Self::Mean,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}
impl From<LlamaPoolingType> for i32 {
    /// Encodes the variant back into its raw integer code.
    fn from(kind: LlamaPoolingType) -> Self {
        // The declared discriminants equal the raw codes (-1..=4), so a
        // plain numeric cast is equivalent to an exhaustive match.
        kind as i32
    }
}
/// ggml data type used for the KV-cache tensors (`type_k` / `type_v`).
///
/// Each named variant mirrors one `GGML_TYPE_*` constant from the sys
/// crate (see the `From` impls below); [`KvCacheType::Unknown`] carries
/// any raw code that has no named variant here, so conversions round-trip
/// losslessly even when the sys crate gains new types.
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// A raw `ggml_type` code with no named variant in this wrapper.
    Unknown(llama_cpp_sys_2::ggml_type),
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}
impl From<KvCacheType> for llama_cpp_sys_2::ggml_type {
    /// Maps each named variant to its `GGML_TYPE_*` constant; `Unknown`
    /// passes its raw code through untouched, so the conversion never
    /// loses information.
    fn from(value: KvCacheType) -> Self {
        match value {
            KvCacheType::Unknown(raw) => raw,
            KvCacheType::F32 => llama_cpp_sys_2::GGML_TYPE_F32,
            KvCacheType::F16 => llama_cpp_sys_2::GGML_TYPE_F16,
            KvCacheType::Q4_0 => llama_cpp_sys_2::GGML_TYPE_Q4_0,
            KvCacheType::Q4_1 => llama_cpp_sys_2::GGML_TYPE_Q4_1,
            KvCacheType::Q5_0 => llama_cpp_sys_2::GGML_TYPE_Q5_0,
            KvCacheType::Q5_1 => llama_cpp_sys_2::GGML_TYPE_Q5_1,
            KvCacheType::Q8_0 => llama_cpp_sys_2::GGML_TYPE_Q8_0,
            KvCacheType::Q8_1 => llama_cpp_sys_2::GGML_TYPE_Q8_1,
            KvCacheType::Q2_K => llama_cpp_sys_2::GGML_TYPE_Q2_K,
            KvCacheType::Q3_K => llama_cpp_sys_2::GGML_TYPE_Q3_K,
            KvCacheType::Q4_K => llama_cpp_sys_2::GGML_TYPE_Q4_K,
            KvCacheType::Q5_K => llama_cpp_sys_2::GGML_TYPE_Q5_K,
            KvCacheType::Q6_K => llama_cpp_sys_2::GGML_TYPE_Q6_K,
            KvCacheType::Q8_K => llama_cpp_sys_2::GGML_TYPE_Q8_K,
            KvCacheType::IQ2_XXS => llama_cpp_sys_2::GGML_TYPE_IQ2_XXS,
            KvCacheType::IQ2_XS => llama_cpp_sys_2::GGML_TYPE_IQ2_XS,
            KvCacheType::IQ3_XXS => llama_cpp_sys_2::GGML_TYPE_IQ3_XXS,
            KvCacheType::IQ1_S => llama_cpp_sys_2::GGML_TYPE_IQ1_S,
            KvCacheType::IQ4_NL => llama_cpp_sys_2::GGML_TYPE_IQ4_NL,
            KvCacheType::IQ3_S => llama_cpp_sys_2::GGML_TYPE_IQ3_S,
            KvCacheType::IQ2_S => llama_cpp_sys_2::GGML_TYPE_IQ2_S,
            KvCacheType::IQ4_XS => llama_cpp_sys_2::GGML_TYPE_IQ4_XS,
            KvCacheType::I8 => llama_cpp_sys_2::GGML_TYPE_I8,
            KvCacheType::I16 => llama_cpp_sys_2::GGML_TYPE_I16,
            KvCacheType::I32 => llama_cpp_sys_2::GGML_TYPE_I32,
            KvCacheType::I64 => llama_cpp_sys_2::GGML_TYPE_I64,
            KvCacheType::F64 => llama_cpp_sys_2::GGML_TYPE_F64,
            KvCacheType::IQ1_M => llama_cpp_sys_2::GGML_TYPE_IQ1_M,
            KvCacheType::BF16 => llama_cpp_sys_2::GGML_TYPE_BF16,
            KvCacheType::TQ1_0 => llama_cpp_sys_2::GGML_TYPE_TQ1_0,
            KvCacheType::TQ2_0 => llama_cpp_sys_2::GGML_TYPE_TQ2_0,
            KvCacheType::MXFP4 => llama_cpp_sys_2::GGML_TYPE_MXFP4,
        }
    }
}
impl From<llama_cpp_sys_2::ggml_type> for KvCacheType {
    /// Maps a raw `ggml_type` code to the matching named variant.
    ///
    /// The `x if x == CONST` guards compare against the sys-crate
    /// constants at runtime rather than using them as patterns directly;
    /// any code with no named variant is preserved losslessly in
    /// [`KvCacheType::Unknown`] instead of being rejected.
    fn from(value: llama_cpp_sys_2::ggml_type) -> Self {
        match value {
            x if x == llama_cpp_sys_2::GGML_TYPE_F32 => KvCacheType::F32,
            x if x == llama_cpp_sys_2::GGML_TYPE_F16 => KvCacheType::F16,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_0 => KvCacheType::Q4_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_1 => KvCacheType::Q4_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_0 => KvCacheType::Q5_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_1 => KvCacheType::Q5_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_0 => KvCacheType::Q8_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_1 => KvCacheType::Q8_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q2_K => KvCacheType::Q2_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q3_K => KvCacheType::Q3_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_K => KvCacheType::Q4_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_K => KvCacheType::Q5_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q6_K => KvCacheType::Q6_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_K => KvCacheType::Q8_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XXS => KvCacheType::IQ2_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XS => KvCacheType::IQ2_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_XXS => KvCacheType::IQ3_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_S => KvCacheType::IQ1_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_NL => KvCacheType::IQ4_NL,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_S => KvCacheType::IQ3_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_S => KvCacheType::IQ2_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_XS => KvCacheType::IQ4_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_I8 => KvCacheType::I8,
            x if x == llama_cpp_sys_2::GGML_TYPE_I16 => KvCacheType::I16,
            x if x == llama_cpp_sys_2::GGML_TYPE_I32 => KvCacheType::I32,
            x if x == llama_cpp_sys_2::GGML_TYPE_I64 => KvCacheType::I64,
            x if x == llama_cpp_sys_2::GGML_TYPE_F64 => KvCacheType::F64,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_M => KvCacheType::IQ1_M,
            x if x == llama_cpp_sys_2::GGML_TYPE_BF16 => KvCacheType::BF16,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ1_0 => KvCacheType::TQ1_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ2_0 => KvCacheType::TQ2_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_MXFP4 => KvCacheType::MXFP4,
            _ => KvCacheType::Unknown(value),
        }
    }
}
/// Builder-style safe wrapper around `llama_cpp_sys_2::llama_context_params`.
///
/// Obtain llama.cpp's defaults via [`Default`] and customize with the
/// `with_*` methods.
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    // Raw FFI parameter struct; kept crate-private so outside callers must
    // go through the typed accessors.
    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}
// SAFETY(review): the inner C struct holds raw pointers (e.g. the
// `cb_eval_user_data` pointer set via `with_cb_eval_user_data`). These impls
// assert that whatever those pointers reference is safe to move and share
// across threads — TODO confirm that invariant is actually upheld by all
// callers rather than assumed.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
impl LlamaContextParams {
    /// Sets the context window size (`n_ctx`). `None` is stored as 0 —
    /// NOTE(review): llama.cpp appears to treat 0 as "use the model's
    /// default"; confirm against the linked version's `llama.h`.
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }
    /// Returns the configured context size; a stored 0 reads back as `None`.
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }
    /// Sets the batch size (`n_batch`).
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }
    /// Returns the batch size (`n_batch`).
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }
    /// Sets the micro-batch size (`n_ubatch`).
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }
    /// Returns the micro-batch size (`n_ubatch`).
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }
    /// Sets the flash-attention policy (`flash_attn_type`), passed through
    /// as the raw sys-crate value.
    #[must_use]
    pub fn with_flash_attention_policy(
        mut self,
        policy: llama_cpp_sys_2::llama_flash_attn_type,
    ) -> Self {
        self.context_params.flash_attn_type = policy;
        self
    }
    /// Returns the raw flash-attention policy value.
    #[must_use]
    pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
        self.context_params.flash_attn_type
    }
    /// Enables or disables KQV offloading (`offload_kqv`).
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }
    /// Returns whether KQV offloading is enabled.
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }
    /// Sets the RoPE scaling strategy (stored as its raw integer code).
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }
    /// Returns the RoPE scaling strategy; unrecognized raw codes decode as
    /// `RopeScalingType::Unspecified`.
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }
    /// Sets the RoPE base frequency (`rope_freq_base`).
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }
    /// Returns the RoPE base frequency.
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }
    /// Sets the RoPE frequency scale factor (`rope_freq_scale`).
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }
    /// Returns the RoPE frequency scale factor.
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }
    /// Returns the thread count (`n_threads`).
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }
    /// Returns the batch-processing thread count (`n_threads_batch`).
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }
    /// Sets the thread count (`n_threads`).
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }
    /// Sets the batch-processing thread count (`n_threads_batch`).
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }
    /// Returns whether embedding output is enabled (`embeddings`).
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }
    /// Enables or disables embedding output (`embeddings`).
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }
    /// Installs a backend-scheduler eval callback (`cb_eval`); the
    /// sys-crate type is an `Option`-like nullable function pointer.
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }
    /// Sets the opaque user-data pointer handed to the eval callback.
    /// Nothing here keeps the pointee alive — the caller must ensure it
    /// remains valid for as long as the callback can fire.
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }
    /// Sets the embedding pooling strategy (stored as its raw integer code).
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }
    /// Returns the pooling strategy; unrecognized raw codes decode as
    /// `LlamaPoolingType::Unspecified`.
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
    /// Sets the `swa_full` flag — NOTE(review): presumably controls full
    /// sliding-window-attention cache; confirm semantics in upstream docs.
    #[must_use]
    pub fn with_swa_full(mut self, enabled: bool) -> Self {
        self.context_params.swa_full = enabled;
        self
    }
    /// Returns the `swa_full` flag.
    #[must_use]
    pub fn swa_full(&self) -> bool {
        self.context_params.swa_full
    }
    /// Sets the maximum number of sequences (`n_seq_max`).
    #[must_use]
    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
        self.context_params.n_seq_max = n_seq_max;
        self
    }
    /// Returns the maximum number of sequences (`n_seq_max`).
    #[must_use]
    pub fn n_seq_max(&self) -> u32 {
        self.context_params.n_seq_max
    }
    /// Sets the ggml data type of the K cache (`type_k`).
    #[must_use]
    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
        self.context_params.type_k = type_k.into();
        self
    }
    /// Returns the ggml data type of the K cache.
    #[must_use]
    pub fn type_k(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_k)
    }
    /// Sets the ggml data type of the V cache (`type_v`).
    #[must_use]
    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
        self.context_params.type_v = type_v.into();
        self
    }
    /// Returns the ggml data type of the V cache.
    #[must_use]
    pub fn type_v(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_v)
    }
}
impl Default for LlamaContextParams {
    /// Starts from llama.cpp's own defaults by calling
    /// `llama_context_default_params` through FFI.
    fn default() -> Self {
        // SAFETY: the FFI call takes no arguments and returns the parameter
        // struct by value; it is assumed to have no preconditions per the
        // llama.cpp API.
        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
        Self { context_params }
    }
}