use std::ptr::null_mut;
use llama_cpp_sys::{
ggml_type, llama_context_default_params, llama_context_params, llama_pooling_type,
llama_rope_scaling_type,
};
/// Pooling selector that mirrors the `llama_pooling_type` binding.
///
/// Values convert to and from `llama_pooling_type` through the `From`
/// implementations in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PoolingType {
    /// Counterpart of `LLAMA_POOLING_TYPE_UNSPECIFIED`.
    Unspecified,
    /// Counterpart of `LLAMA_POOLING_TYPE_NONE`.
    None,
    /// Counterpart of `LLAMA_POOLING_TYPE_MEAN`.
    Mean,
    /// Counterpart of `LLAMA_POOLING_TYPE_CLS`.
    Cls,
}
impl From<PoolingType> for llama_pooling_type {
    /// Translates the Rust-side pooling selector into the matching
    /// `LLAMA_POOLING_TYPE_*` value. Every variant has a counterpart, so this
    /// conversion cannot fail.
    fn from(pooling: PoolingType) -> Self {
        match pooling {
            PoolingType::Cls => Self::LLAMA_POOLING_TYPE_CLS,
            PoolingType::Mean => Self::LLAMA_POOLING_TYPE_MEAN,
            PoolingType::None => Self::LLAMA_POOLING_TYPE_NONE,
            PoolingType::Unspecified => Self::LLAMA_POOLING_TYPE_UNSPECIFIED,
        }
    }
}
impl From<llama_pooling_type> for PoolingType {
    /// Translates a raw `llama_pooling_type` into its Rust-side counterpart.
    ///
    /// # Panics
    ///
    /// Panics through `unimplemented!()` when `raw` is not one of the four
    /// `LLAMA_POOLING_TYPE_*` values listed below.
    // NOTE(review): the lint allowance is presumably needed because the binding
    // constants appear in pattern position — confirm against the generated bindings.
    #[allow(non_upper_case_globals)]
    fn from(raw: llama_pooling_type) -> Self {
        match raw {
            llama_pooling_type::LLAMA_POOLING_TYPE_UNSPECIFIED => Self::Unspecified,
            llama_pooling_type::LLAMA_POOLING_TYPE_NONE => Self::None,
            llama_pooling_type::LLAMA_POOLING_TYPE_MEAN => Self::Mean,
            llama_pooling_type::LLAMA_POOLING_TYPE_CLS => Self::Cls,
            _ => unimplemented!(),
        }
    }
}
/// RoPE scaling selector that mirrors the `llama_rope_scaling_type` binding.
///
/// Values convert to and from `llama_rope_scaling_type` through the `From`
/// implementations in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RopeScaling {
    /// Counterpart of `LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED`.
    Unspecified,
    /// Counterpart of `LLAMA_ROPE_SCALING_TYPE_NONE`.
    None,
    /// Counterpart of `LLAMA_ROPE_SCALING_TYPE_LINEAR`.
    Linear,
    /// Counterpart of `LLAMA_ROPE_SCALING_TYPE_YARN`.
    Yarn,
}
impl From<RopeScaling> for llama_rope_scaling_type {
    /// Translates the Rust-side RoPE scaling selector into the matching
    /// `LLAMA_ROPE_SCALING_TYPE_*` value. Every variant has a counterpart, so
    /// this conversion cannot fail.
    fn from(scaling: RopeScaling) -> Self {
        match scaling {
            RopeScaling::Yarn => Self::LLAMA_ROPE_SCALING_TYPE_YARN,
            RopeScaling::Linear => Self::LLAMA_ROPE_SCALING_TYPE_LINEAR,
            RopeScaling::None => Self::LLAMA_ROPE_SCALING_TYPE_NONE,
            RopeScaling::Unspecified => Self::LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
        }
    }
}
impl From<llama_rope_scaling_type> for RopeScaling {
    /// Translates a raw `llama_rope_scaling_type` into its Rust-side counterpart.
    ///
    /// # Panics
    ///
    /// Panics through `unimplemented!()` when `raw` is not one of the four
    /// `LLAMA_ROPE_SCALING_TYPE_*` values listed below.
    fn from(raw: llama_rope_scaling_type) -> Self {
        match raw {
            llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED => Self::Unspecified,
            llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_NONE => Self::None,
            llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_LINEAR => Self::Linear,
            llama_rope_scaling_type::LLAMA_ROPE_SCALING_TYPE_YARN => Self::Yarn,
            _ => unimplemented!(),
        }
    }
}
/// Element type selector that mirrors the `ggml_type` binding.
///
/// Each variant corresponds to one `GGML_TYPE_*` constant; the exact pairing
/// is spelled out in the `From` implementations in this file. `SessionParams`
/// uses this type for its `type_k` and `type_v` fields.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CacheType {
    // Floating point, 32 and 16 bit.
    F32,
    F16,
    // `Q<bits>_<n>` family.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // `Q<bits>_K` family.
    Q2K,
    Q3K,
    Q4K,
    Q5K,
    Q6K,
    Q8K,
    // `IQ*` family.
    IQ2XXS,
    IQ2XS,
    IQ3XXS,
    IQ1S,
    IQ4NL,
    IQ3S,
    IQ2S,
    IQ4XS,
    // Plain integers.
    I8,
    I16,
    I32,
    I64,
    // Floating point, 64 bit.
    F64,
    /// Counterpart of `GGML_TYPE_COUNT`.
    // NOTE(review): presumably a "number of types" sentinel rather than a usable
    // element type — confirm against ggml.h before passing it to the library.
    Count,
}
impl From<CacheType> for ggml_type {
fn from(value: CacheType) -> Self {
match value {
CacheType::F32 => ggml_type::GGML_TYPE_F32,
CacheType::F16 => ggml_type::GGML_TYPE_F16,
CacheType::Q4_0 => ggml_type::GGML_TYPE_Q4_0,
CacheType::Q4_1 => ggml_type::GGML_TYPE_Q4_1,
CacheType::Q5_0 => ggml_type::GGML_TYPE_Q5_0,
CacheType::Q5_1 => ggml_type::GGML_TYPE_Q5_1,
CacheType::Q8_0 => ggml_type::GGML_TYPE_Q8_0,
CacheType::Q8_1 => ggml_type::GGML_TYPE_Q8_1,
CacheType::Q2K => ggml_type::GGML_TYPE_Q2_K,
CacheType::Q3K => ggml_type::GGML_TYPE_Q3_K,
CacheType::Q4K => ggml_type::GGML_TYPE_Q4_K,
CacheType::Q5K => ggml_type::GGML_TYPE_Q5_K,
CacheType::Q6K => ggml_type::GGML_TYPE_Q6_K,
CacheType::Q8K => ggml_type::GGML_TYPE_Q8_K,
CacheType::IQ2XXS => ggml_type::GGML_TYPE_IQ2_XXS,
CacheType::IQ2XS => ggml_type::GGML_TYPE_IQ2_XS,
CacheType::IQ3XXS => ggml_type::GGML_TYPE_IQ3_XXS,
CacheType::IQ1S => ggml_type::GGML_TYPE_IQ1_S,
CacheType::IQ4NL => ggml_type::GGML_TYPE_IQ4_NL,
CacheType::IQ3S => ggml_type::GGML_TYPE_IQ3_S,
CacheType::IQ2S => ggml_type::GGML_TYPE_IQ2_S,
CacheType::IQ4XS => ggml_type::GGML_TYPE_IQ4_XS,
CacheType::I8 => ggml_type::GGML_TYPE_I8,
CacheType::I16 => ggml_type::GGML_TYPE_I16,
CacheType::I32 => ggml_type::GGML_TYPE_I32,
CacheType::I64 => ggml_type::GGML_TYPE_I64,
CacheType::F64 => ggml_type::GGML_TYPE_F64,
CacheType::Count => ggml_type::GGML_TYPE_COUNT,
}
}
}
impl From<ggml_type> for CacheType {
    /// Translates a raw `ggml_type` into its Rust-side counterpart.
    ///
    /// # Panics
    ///
    /// Panics through `unimplemented!()` when `raw` is not one of the
    /// `GGML_TYPE_*` values listed below.
    fn from(raw: ggml_type) -> Self {
        match raw {
            // Floating point.
            ggml_type::GGML_TYPE_F32 => Self::F32,
            ggml_type::GGML_TYPE_F16 => Self::F16,
            ggml_type::GGML_TYPE_F64 => Self::F64,
            // `Q<bits>_<n>` family.
            ggml_type::GGML_TYPE_Q4_0 => Self::Q4_0,
            ggml_type::GGML_TYPE_Q4_1 => Self::Q4_1,
            ggml_type::GGML_TYPE_Q5_0 => Self::Q5_0,
            ggml_type::GGML_TYPE_Q5_1 => Self::Q5_1,
            ggml_type::GGML_TYPE_Q8_0 => Self::Q8_0,
            ggml_type::GGML_TYPE_Q8_1 => Self::Q8_1,
            // `Q<bits>_K` family.
            ggml_type::GGML_TYPE_Q2_K => Self::Q2K,
            ggml_type::GGML_TYPE_Q3_K => Self::Q3K,
            ggml_type::GGML_TYPE_Q4_K => Self::Q4K,
            ggml_type::GGML_TYPE_Q5_K => Self::Q5K,
            ggml_type::GGML_TYPE_Q6_K => Self::Q6K,
            ggml_type::GGML_TYPE_Q8_K => Self::Q8K,
            // `IQ*` family.
            ggml_type::GGML_TYPE_IQ2_XXS => Self::IQ2XXS,
            ggml_type::GGML_TYPE_IQ2_XS => Self::IQ2XS,
            ggml_type::GGML_TYPE_IQ3_XXS => Self::IQ3XXS,
            ggml_type::GGML_TYPE_IQ1_S => Self::IQ1S,
            ggml_type::GGML_TYPE_IQ4_NL => Self::IQ4NL,
            ggml_type::GGML_TYPE_IQ3_S => Self::IQ3S,
            ggml_type::GGML_TYPE_IQ2_S => Self::IQ2S,
            ggml_type::GGML_TYPE_IQ4_XS => Self::IQ4XS,
            // Plain integers.
            ggml_type::GGML_TYPE_I8 => Self::I8,
            ggml_type::GGML_TYPE_I16 => Self::I16,
            ggml_type::GGML_TYPE_I32 => Self::I32,
            ggml_type::GGML_TYPE_I64 => Self::I64,
            // Counterpart of `GGML_TYPE_COUNT`.
            ggml_type::GGML_TYPE_COUNT => Self::Count,
            // Anything without a `CacheType` variant is rejected.
            _ => unimplemented!(),
        }
    }
}
/// Rust-side view of the session-relevant fields of `llama_context_params`.
///
/// Converts to and from `llama_context_params` through the `From`
/// implementations in this file. Unless noted otherwise, a field is copied
/// verbatim into the `llama_context_params` field of the same name.
#[derive(Clone)]
pub struct SessionParams {
/// Copied into `llama_context_params::seed`.
pub seed: u32,
/// Copied into `llama_context_params::n_ctx`.
pub n_ctx: u32,
/// Copied into `llama_context_params::n_batch`.
pub n_batch: u32,
/// Copied into `llama_context_params::n_ubatch`.
pub n_ubatch: u32,
/// Copied into `llama_context_params::n_seq_max`.
pub n_seq_max: u32,
/// Copied into `llama_context_params::n_threads`.
pub n_threads: u32,
/// Copied into `llama_context_params::n_threads_batch`.
pub n_threads_batch: u32,
/// Converted to `llama_rope_scaling_type` and stored in
/// `llama_context_params::rope_scaling_type`.
pub rope_scaling_type: RopeScaling,
/// Copied into `llama_context_params::rope_freq_base`.
pub rope_freq_base: f32,
/// Copied into `llama_context_params::rope_freq_scale`.
pub rope_freq_scale: f32,
/// Copied into `llama_context_params::yarn_ext_factor`.
pub yarn_ext_factor: f32,
/// Copied into `llama_context_params::yarn_attn_factor`.
pub yarn_attn_factor: f32,
/// Copied into `llama_context_params::yarn_beta_fast`.
pub yarn_beta_fast: f32,
/// Copied into `llama_context_params::yarn_beta_slow`.
pub yarn_beta_slow: f32,
/// Copied into `llama_context_params::yarn_orig_ctx`.
pub yarn_orig_ctx: u32,
/// Converted to `ggml_type` and stored in `llama_context_params::type_k`.
pub type_k: CacheType,
/// Converted to `ggml_type` and stored in `llama_context_params::type_v`.
pub type_v: CacheType,
/// Stored in `llama_context_params::embeddings` (note the plural on the C side).
pub embedding: bool,
/// Copied into `llama_context_params::offload_kqv`.
pub offload_kqv: bool,
/// Converted to `llama_pooling_type` and stored in
/// `llama_context_params::pooling_type`.
pub pooling: PoolingType,
/// Stored in `llama_context_params::defrag_thold`.
pub defrag_threshold: f32,
}
impl Default for SessionParams {
    /// Builds the default session parameters.
    ///
    /// Every field except the two thread counts is taken from the value
    /// returned by `llama_context_default_params()`. Both thread counts are set
    /// to the number of physical cores minus one, clamped so that they are
    /// never below one.
    ///
    /// # Panics
    ///
    /// The `rope_scaling_type`, `type_k`, `type_v` and `pooling_type` defaults
    /// go through the `From` conversions in this file, which panic through
    /// `unimplemented!()` on a value that has no Rust-side variant.
    fn default() -> Self {
        // SAFETY: the call passes no arguments, so there are no pointer or
        // lifetime preconditions to uphold here; the result is received by value.
        // NOTE(review): assumes the linked library matches the generated bindings.
        let c_defaults = unsafe { llama_context_default_params() };

        // Leave one physical core free, but never ask for zero threads: the
        // unclamped `cores - 1` is 0 on a single-core host and would underflow
        // if the count were ever reported as 0.
        let threads = num_cpus::get_physical().saturating_sub(1).max(1) as u32;

        Self {
            seed: c_defaults.seed,
            n_ctx: c_defaults.n_ctx,
            n_batch: c_defaults.n_batch,
            n_ubatch: c_defaults.n_ubatch,
            n_seq_max: c_defaults.n_seq_max,
            n_threads: threads,
            n_threads_batch: threads,
            rope_scaling_type: c_defaults.rope_scaling_type.into(),
            rope_freq_base: c_defaults.rope_freq_base,
            rope_freq_scale: c_defaults.rope_freq_scale,
            yarn_ext_factor: c_defaults.yarn_ext_factor,
            yarn_attn_factor: c_defaults.yarn_attn_factor,
            yarn_beta_fast: c_defaults.yarn_beta_fast,
            yarn_beta_slow: c_defaults.yarn_beta_slow,
            yarn_orig_ctx: c_defaults.yarn_orig_ctx,
            type_k: c_defaults.type_k.into(),
            type_v: c_defaults.type_v.into(),
            embedding: c_defaults.embeddings,
            offload_kqv: c_defaults.offload_kqv,
            pooling: c_defaults.pooling_type.into(),
            defrag_threshold: c_defaults.defrag_thold,
        }
    }
}
impl From<SessionParams> for llama_context_params {
fn from(value: SessionParams) -> Self {
Self {
seed: value.seed,
n_ctx: value.n_ctx,
n_batch: value.n_batch,
n_ubatch: value.n_ubatch,
n_seq_max: value.n_seq_max,
n_threads: value.n_threads,
n_threads_batch: value.n_threads_batch,
rope_scaling_type: value.rope_scaling_type.into(),
rope_freq_base: value.rope_freq_base,
rope_freq_scale: value.rope_freq_scale,
yarn_ext_factor: value.yarn_ext_factor,
yarn_attn_factor: value.yarn_attn_factor,
yarn_beta_fast: value.yarn_beta_fast,
yarn_beta_slow: value.yarn_beta_slow,
yarn_orig_ctx: value.yarn_orig_ctx,
defrag_thold: value.defrag_threshold,
cb_eval: None,
cb_eval_user_data: null_mut(),
type_k: value.type_k.into(),
type_v: value.type_v.into(),
logits_all: false, embeddings: value.embedding,
offload_kqv: value.offload_kqv,
pooling_type: value.pooling.into(),
abort_callback: None,
abort_callback_data: null_mut(),
}
}
}
impl From<llama_context_params> for SessionParams {
fn from(value: llama_context_params) -> Self {
Self {
seed: value.seed,
n_ctx: value.n_ctx,
n_batch: value.n_batch,
n_ubatch: value.n_ubatch,
n_seq_max: value.n_seq_max,
n_threads: value.n_threads,
n_threads_batch: value.n_threads_batch,
rope_scaling_type: value.rope_scaling_type.into(),
rope_freq_base: value.rope_freq_base,
rope_freq_scale: value.rope_freq_scale,
yarn_ext_factor: value.yarn_ext_factor,
yarn_attn_factor: value.yarn_attn_factor,
yarn_beta_fast: value.yarn_beta_fast,
yarn_beta_slow: value.yarn_beta_slow,
yarn_orig_ctx: value.yarn_orig_ctx,
type_k: value.type_k.into(),
type_v: value.type_v.into(),
embedding: value.embeddings,
offload_kqv: value.offload_kqv,
pooling: value.pooling_type.into(),
defrag_threshold: value.defrag_thold,
}
}
}