use std::fmt::Debug;
use std::num::NonZeroU32;
/// RoPE (rotary position embedding) scaling mode.
///
/// The discriminants deliberately match the raw integer values used by the
/// C-side `rope_scaling_type` field, so the two `From` impls below are a
/// lossless round trip for the known values.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// Raw value was not one of the known codes (encoded as `-1`).
    Unspecified = -1,
    /// No RoPE scaling.
    None = 0,
    /// Linear scaling.
    Linear = 1,
    /// YaRN scaling.
    Yarn = 2,
}

impl From<i32> for RopeScalingType {
    /// Decodes a raw integer; any unrecognised code becomes `Unspecified`.
    fn from(value: i32) -> Self {
        match value {
            1 => Self::Linear,
            2 => Self::Yarn,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}

impl From<RopeScalingType> for i32 {
    /// Encodes back to the raw integer code.
    fn from(value: RopeScalingType) -> Self {
        // Discriminants are the raw codes, so a plain cast is sufficient.
        value as i32
    }
}
/// Pooling mode for embedding extraction.
///
/// The discriminants deliberately match the raw integer values used by the
/// C-side `pooling_type` field, so the two `From` impls below are a lossless
/// round trip for the known values.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// Raw value was not one of the known codes (encoded as `-1`).
    Unspecified = -1,
    /// No pooling.
    None = 0,
    /// Mean pooling over tokens.
    Mean = 1,
    /// Use the CLS token.
    Cls = 2,
    /// Use the last token.
    Last = 3,
}

impl From<i32> for LlamaPoolingType {
    /// Decodes a raw integer; any unrecognised code becomes `Unspecified`.
    fn from(value: i32) -> Self {
        match value {
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}

impl From<LlamaPoolingType> for i32 {
    /// Encodes back to the raw integer code.
    fn from(value: LlamaPoolingType) -> Self {
        // Discriminants are the raw codes, so a plain cast is sufficient.
        value as i32
    }
}
/// Builder-style wrapper over the raw `llama_cpp_sys_4::llama_context_params`
/// FFI struct, plus extra Rust-side state that has no C counterpart.
///
/// Construct via `Default` (which pulls llama.cpp's own defaults) and
/// customise with the `with_*` methods, which consume and return `Self`.
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    // The raw C parameter struct; the with_* builder methods mutate its
    // fields in place.
    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
    // Rust-side flag, not part of the C struct. NOTE(review): presumably
    // disables attention rotation in the consumer — its semantics are not
    // visible from this file; confirm at the usage site.
    pub(crate) attn_rot_disabled: bool,
}
// SAFETY: NOTE(review) — `llama_context_params` contains raw pointers
// (e.g. `cb_eval` / `cb_eval_user_data` set below), so these impls assert
// that moving/sharing the params across threads is sound. That cannot be
// verified from this file alone; confirm against llama.cpp's threading
// contract before relying on it.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
impl LlamaContextParams {
#[must_use]
pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
self
}
#[must_use]
pub fn n_ctx(&self) -> Option<NonZeroU32> {
NonZeroU32::new(self.context_params.n_ctx)
}
#[must_use]
pub fn with_n_batch(mut self, n_batch: u32) -> Self {
self.context_params.n_batch = n_batch;
self
}
#[must_use]
pub fn n_batch(&self) -> u32 {
self.context_params.n_batch
}
#[must_use]
pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
self.context_params.n_ubatch = n_ubatch;
self
}
#[must_use]
pub fn n_ubatch(&self) -> u32 {
self.context_params.n_ubatch
}
#[must_use]
pub fn with_flash_attention(mut self, enabled: bool) -> Self {
self.context_params.flash_attn_type = if enabled {
llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
} else {
llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
};
self
}
#[must_use]
pub fn flash_attention(&self) -> bool {
self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
}
#[must_use]
pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
self.context_params.offload_kqv = enabled;
self
}
#[must_use]
pub fn offload_kqv(&self) -> bool {
self.context_params.offload_kqv
}
#[must_use]
pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
self
}
#[must_use]
pub fn rope_scaling_type(&self) -> RopeScalingType {
RopeScalingType::from(self.context_params.rope_scaling_type)
}
#[must_use]
pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
self.context_params.rope_freq_base = rope_freq_base;
self
}
#[must_use]
pub fn rope_freq_base(&self) -> f32 {
self.context_params.rope_freq_base
}
#[must_use]
pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
self.context_params.rope_freq_scale = rope_freq_scale;
self
}
#[must_use]
pub fn rope_freq_scale(&self) -> f32 {
self.context_params.rope_freq_scale
}
#[must_use]
pub fn n_threads(&self) -> i32 {
self.context_params.n_threads
}
#[must_use]
pub fn n_threads_batch(&self) -> i32 {
self.context_params.n_threads_batch
}
#[must_use]
pub fn with_n_threads(mut self, n_threads: i32) -> Self {
self.context_params.n_threads = n_threads;
self
}
#[must_use]
pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
self.context_params.n_threads_batch = n_threads;
self
}
#[must_use]
pub fn embeddings(&self) -> bool {
self.context_params.embeddings
}
#[must_use]
pub fn with_embeddings(mut self, embedding: bool) -> Self {
self.context_params.embeddings = embedding;
self
}
#[must_use]
pub fn with_cb_eval(
mut self,
cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
) -> Self {
self.context_params.cb_eval = cb_eval;
self
}
#[must_use]
pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
self.context_params.cb_eval_user_data = cb_eval_user_data;
self
}
#[must_use]
pub fn with_tensor_capture(
self,
capture: &mut super::tensor_capture::TensorCapture,
) -> Self {
self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
.with_cb_eval_user_data(
capture as *mut super::tensor_capture::TensorCapture as *mut std::ffi::c_void,
)
}
#[must_use]
pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
self
}
#[must_use]
pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
self.context_params.type_k
}
#[must_use]
pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
self
}
#[must_use]
pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
self.context_params.type_v
}
#[must_use]
pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
self.attn_rot_disabled = disabled;
self
}
#[must_use]
pub fn attn_rot_disabled(&self) -> bool {
self.attn_rot_disabled
}
#[must_use]
pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
self.context_params.pooling_type = i32::from(pooling_type);
self
}
#[must_use]
pub fn pooling_type(&self) -> LlamaPoolingType {
LlamaPoolingType::from(self.context_params.pooling_type)
}
}
impl Default for LlamaContextParams {
    /// Builds params from llama.cpp's own defaults, with the Rust-side
    /// `attn_rot_disabled` flag cleared.
    fn default() -> Self {
        // SAFETY: `llama_context_default_params` takes no arguments and
        // returns the parameter struct by value. NOTE(review): assumes the
        // call has no hidden initialisation preconditions — confirm against
        // the llama.cpp API docs.
        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
        Self {
            context_params,
            attn_rot_disabled: false,
        }
    }
}