mod advanced;
mod types;
pub use types::*;
use std::num::NonZeroU32;
use thiserror::Error;
use crate::sampling::LlamaSampler;
#[derive(Debug, Error, PartialEq, Eq)]
pub enum ParamsCloneError {
#[error("cannot clone params that own per-sequence sampler chains")]
SamplerChains,
}
/// Builder-style configuration used when creating a llama.cpp context.
///
/// Wraps the raw `llama_cpp_sys_4::llama_context_params` struct together with
/// extra settings and owned resources that live alongside it. Configure it via
/// the consuming `with_*` methods, starting from [`Default::default`].
#[derive(Debug)]
#[allow(
missing_docs,
clippy::struct_excessive_bools,
clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
// Raw llama.cpp parameter struct, copied out of `llama_context_default_params`
// in `Default` and mutated in place by the `with_*` setters.
pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
// Extra flag kept on the wrapper rather than inside `context_params`.
// NOTE(review): where this is consumed is not visible here — presumably at
// context creation; confirm.
pub(crate) attn_rot_disabled: bool,
// Sampler objects owned by these params. Presumably kept alive because
// `sampler_configs` / `context_params.samplers` point into them — TODO confirm
// against the `advanced` module.
owned_samplers: Vec<LlamaSampler>,
// Per-sequence sampler configurations. `Clone` drops these, and `try_clone`
// returns `ParamsCloneError::SamplerChains` when this is non-empty.
sampler_configs: Vec<llama_cpp_sys_4::llama_sampler_seq_config>,
}
// SAFETY(review): `llama_context_params` holds raw pointers (`cb_eval_user_data`,
// `samplers`), so the compiler does not derive `Send`/`Sync` automatically.
// `Sync`: the visible `&self` methods only read fields, so shared access does
// not mutate anything through those pointers.
// `Send`: this additionally assumes that whatever `cb_eval_user_data` points to
// (e.g. a `TensorCapture` installed via `with_tensor_capture`) and the
// `LlamaSampler`s in `owned_samplers` may be used from another thread.
// NOTE(review): that is not enforced by the type system — confirm it holds.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
impl LlamaContextParams {
    /// Sets the context window size in tokens (llama.cpp's `n_ctx`).
    ///
    /// `None` is stored as `0`. In llama.cpp a zero `n_ctx` presumably means
    /// "take the size from the model"; confirm against the pinned llama.cpp
    /// version.
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
        self
    }

    /// Returns the configured context size, or `None` when the stored value is `0`.
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Sets llama.cpp's `n_seq_max` (maximum number of sequences).
    ///
    /// Values below `1` are clamped to `1`.
    #[must_use]
    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
        self.context_params.n_seq_max = n_seq_max.max(1);
        self
    }

    /// Returns the configured `n_seq_max`.
    #[must_use]
    pub fn n_seq_max(&self) -> u32 {
        self.context_params.n_seq_max
    }

    /// Sets llama.cpp's `n_batch` (logical batch size). The value is stored as-is.
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Returns the configured `n_batch`.
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Sets llama.cpp's `n_ubatch` (physical micro-batch size). The value is stored as-is.
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Returns the configured `n_ubatch`.
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Sets the context type (`ctx_type`), converted into the raw FFI representation.
    #[must_use]
    pub fn with_ctx_type(mut self, ctx_type: LlamaContextType) -> Self {
        self.context_params.ctx_type = ctx_type.into();
        self
    }

    /// Returns the configured context type, converted from the raw FFI value.
    #[must_use]
    pub fn ctx_type(&self) -> LlamaContextType {
        self.context_params.ctx_type.into()
    }

    /// Sets the raw `n_rs_seq` field of the underlying llama.cpp params.
    ///
    /// NOTE(review): the semantics of this field are not visible here; see the
    /// llama.cpp `llama_context_params` definition.
    #[must_use]
    pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self {
        self.context_params.n_rs_seq = n_rs_seq;
        self
    }

    /// Returns the raw `n_rs_seq` field.
    #[must_use]
    pub fn n_rs_seq(&self) -> u32 {
        self.context_params.n_rs_seq
    }

    /// Sets llama.cpp's `offload_kqv` flag.
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Returns llama.cpp's `offload_kqv` flag.
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Sets the RoPE scaling type, stored as its `i32` FFI value.
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Returns the RoPE scaling type, converted from its stored `i32` value.
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Sets the RoPE base frequency (`rope_freq_base`).
    ///
    /// Presumably `0.0` means "use the model's value" in llama.cpp; confirm.
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Returns the configured `rope_freq_base`.
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Sets the RoPE frequency scaling factor (`rope_freq_scale`).
    ///
    /// Presumably `0.0` means "use the model's value" in llama.cpp; confirm.
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Returns the configured `rope_freq_scale`.
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Returns llama.cpp's `n_threads` (threads used for generation).
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Returns llama.cpp's `n_threads_batch` (threads used for batch/prompt processing).
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Sets llama.cpp's `n_threads` (threads used for generation). The value is stored as-is.
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Sets llama.cpp's `n_threads_batch` (threads used for batch/prompt processing).
    /// The value is stored as-is.
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads_batch: i32) -> Self {
        self.context_params.n_threads_batch = n_threads_batch;
        self
    }

    /// Returns whether embedding extraction (`embeddings`) is enabled.
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enables or disables embedding extraction (`embeddings`).
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Sets the ggml backend-scheduler evaluation callback (`cb_eval`).
    ///
    /// Pass `None` to clear it. The callback receives whatever pointer was set
    /// with [`Self::with_cb_eval_user_data`]. Presumably llama.cpp passes it as
    /// the user-data argument; confirm.
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Sets the opaque user-data pointer handed to `cb_eval`.
    ///
    /// The pointer is stored verbatim and is not lifetime-tracked. The caller
    /// must keep the pointee valid for as long as any context built from these
    /// params (or from clones of them, which copy this pointer) can invoke the
    /// callback.
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Installs the tensor-capture evaluation callback, writing into `capture`.
    ///
    /// This sets `cb_eval` to `tensor_capture_callback` and `cb_eval_user_data`
    /// to a raw pointer to `capture`.
    ///
    /// **Lifetime hazard:** the borrow of `capture` ends when this call returns,
    /// but the params keep the raw pointer. `capture` must not be moved or
    /// dropped while any context created from these params, or from a clone of
    /// them, may still evaluate. Otherwise the callback dereferences a dangling
    /// pointer.
    #[must_use]
    pub fn with_tensor_capture(self, capture: &mut super::tensor_capture::TensorCapture) -> Self {
        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
            .with_cb_eval_user_data(
                std::ptr::from_mut::<super::tensor_capture::TensorCapture>(capture)
                    .cast::<std::ffi::c_void>(),
            )
    }

    /// Sets the data type of the K cache (`type_k`) by casting the `GgmlType`
    /// discriminant to the raw `ggml_type`.
    #[must_use]
    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Returns the raw `ggml_type` of the K cache.
    #[must_use]
    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_k
    }

    /// Sets the data type of the V cache (`type_v`) by casting the `GgmlType`
    /// discriminant to the raw `ggml_type`.
    #[must_use]
    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Returns the raw `ggml_type` of the V cache.
    #[must_use]
    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_v
    }

    /// Sets the `attn_rot_disabled` flag.
    ///
    /// The flag lives on this wrapper, not in the llama.cpp params struct.
    /// NOTE(review): how it is applied is not visible here.
    #[must_use]
    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
        self.attn_rot_disabled = disabled;
        self
    }

    /// Returns the `attn_rot_disabled` flag.
    #[must_use]
    pub fn attn_rot_disabled(&self) -> bool {
        self.attn_rot_disabled
    }

    /// Sets the embedding pooling type, stored as its `i32` FFI value.
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Returns the embedding pooling type, converted from its stored `i32` value.
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }

    /// Clones these params, refusing when they carry per-sequence sampler chains.
    ///
    /// Plain `clone()` silently drops sampler chains in the copy. Use this
    /// method when losing them would be a bug.
    ///
    /// # Errors
    ///
    /// Returns [`ParamsCloneError::SamplerChains`] if any per-sequence sampler
    /// configuration is attached.
    pub fn try_clone(&self) -> Result<Self, ParamsCloneError> {
        if !self.sampler_configs.is_empty() {
            return Err(ParamsCloneError::SamplerChains);
        }
        Ok(self.clone())
    }
}
impl Default for LlamaContextParams {
    /// Starts from llama.cpp's built-in context defaults. No sampler chains
    /// are attached and `attn_rot_disabled` is `false`.
    fn default() -> Self {
        // SAFETY: `llama_context_default_params` takes no arguments and returns
        // the params struct by value, so the caller has no pointer or lifetime
        // preconditions to uphold.
        let defaults = unsafe { llama_cpp_sys_4::llama_context_default_params() };

        Self {
            context_params: defaults,
            attn_rot_disabled: false,
            owned_samplers: Vec::default(),
            sampler_configs: Vec::default(),
        }
    }
}
impl Clone for LlamaContextParams {
fn clone(&self) -> Self {
let mut context_params = self.context_params;
context_params.samplers = std::ptr::null_mut();
context_params.n_samplers = 0;
Self {
context_params,
attn_rot_disabled: self.attn_rot_disabled,
owned_samplers: Vec::new(),
sampler_configs: Vec::new(),
}
}
}