// llama_cpp_2/context/params.rs

//! A safe wrapper around `llama_context_params`.

// Generated getter/setter helpers for the wrapped params struct.
mod get_set;

/// A rusty wrapper around `rope_scaling_type`.
///
/// The explicit discriminants mirror the integer values used on the C side,
/// which is what the `From<i32>` conversions for this type rely on.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

18/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
19/// the value is not recognized.
20impl From<i32> for RopeScalingType {
21    fn from(value: i32) -> Self {
22        match value {
23            0 => Self::None,
24            1 => Self::Linear,
25            2 => Self::Yarn,
26            _ => Self::Unspecified,
27        }
28    }
29}
30
31/// Create a `c_int` from a `RopeScalingType`.
32impl From<RopeScalingType> for i32 {
33    fn from(value: RopeScalingType) -> Self {
34        match value {
35            RopeScalingType::None => 0,
36            RopeScalingType::Linear => 1,
37            RopeScalingType::Yarn => 2,
38            RopeScalingType::Unspecified => -1,
39        }
40    }
41}
42
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
///
/// The explicit discriminants mirror the integer values used on the C side,
/// which is what the `From<i32>` conversions for this type rely on.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

61/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
62/// the value is not recognized.
63impl From<i32> for LlamaPoolingType {
64    fn from(value: i32) -> Self {
65        match value {
66            0 => Self::None,
67            1 => Self::Mean,
68            2 => Self::Cls,
69            3 => Self::Last,
70            4 => Self::Rank,
71            _ => Self::Unspecified,
72        }
73    }
74}
75
76/// Create a `c_int` from a `LlamaPoolingType`.
77impl From<LlamaPoolingType> for i32 {
78    fn from(value: LlamaPoolingType) -> Self {
79        match value {
80            LlamaPoolingType::None => 0,
81            LlamaPoolingType::Mean => 1,
82            LlamaPoolingType::Cls => 2,
83            LlamaPoolingType::Last => 3,
84            LlamaPoolingType::Rank => 4,
85            LlamaPoolingType::Unspecified => -1,
86        }
87    }
88}
89
/// A rusty wrapper around `LLAMA_ATTENTION_TYPE`.
///
/// The explicit discriminants mirror the integer values used on the C side,
/// which is what the `From<i32>` conversions for this type rely on.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaAttentionType {
    /// The attention type is unspecified
    Unspecified = -1,
    /// Causal attention
    Causal = 0,
    /// Non-causal attention
    NonCausal = 1,
}

102/// Create a `LlamaAttentionType` from a `c_int` - returns `LlamaAttentionType::Unspecified` if
103/// the value is not recognized.
104impl From<i32> for LlamaAttentionType {
105    fn from(value: i32) -> Self {
106        match value {
107            0 => Self::Causal,
108            1 => Self::NonCausal,
109            _ => Self::Unspecified,
110        }
111    }
112}
113
114/// Create a `c_int` from a `LlamaAttentionType`.
115impl From<LlamaAttentionType> for i32 {
116    fn from(value: LlamaAttentionType) -> Self {
117        match value {
118            LlamaAttentionType::Causal => 0,
119            LlamaAttentionType::NonCausal => 1,
120            LlamaAttentionType::Unspecified => -1,
121        }
122    }
123}
124
/// A rusty wrapper around `ggml_type` for KV cache types.
///
/// Each variant other than [`KvCacheType::Unknown`] corresponds 1:1 to the
/// `llama_cpp_sys_2::GGML_TYPE_*` constant of the same name; the paired
/// `From` impls below make the round trip lossless.
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
    /// the runtime will operate with that type).
    /// This variant preserves API compatibility when new `ggml_type` values are
    /// introduced in the future.
    Unknown(llama_cpp_sys_2::ggml_type),
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}

/// Convert a [`KvCacheType`] into the raw `ggml_type` value passed over FFI.
///
/// `Unknown` passes its stored raw value through untouched; every other
/// variant maps to the `GGML_TYPE_*` constant of the same name.
impl From<KvCacheType> for llama_cpp_sys_2::ggml_type {
    fn from(value: KvCacheType) -> Self {
        match value {
            KvCacheType::Unknown(raw) => raw,
            KvCacheType::F32 => llama_cpp_sys_2::GGML_TYPE_F32,
            KvCacheType::F16 => llama_cpp_sys_2::GGML_TYPE_F16,
            KvCacheType::Q4_0 => llama_cpp_sys_2::GGML_TYPE_Q4_0,
            KvCacheType::Q4_1 => llama_cpp_sys_2::GGML_TYPE_Q4_1,
            KvCacheType::Q5_0 => llama_cpp_sys_2::GGML_TYPE_Q5_0,
            KvCacheType::Q5_1 => llama_cpp_sys_2::GGML_TYPE_Q5_1,
            KvCacheType::Q8_0 => llama_cpp_sys_2::GGML_TYPE_Q8_0,
            KvCacheType::Q8_1 => llama_cpp_sys_2::GGML_TYPE_Q8_1,
            KvCacheType::Q2_K => llama_cpp_sys_2::GGML_TYPE_Q2_K,
            KvCacheType::Q3_K => llama_cpp_sys_2::GGML_TYPE_Q3_K,
            KvCacheType::Q4_K => llama_cpp_sys_2::GGML_TYPE_Q4_K,
            KvCacheType::Q5_K => llama_cpp_sys_2::GGML_TYPE_Q5_K,
            KvCacheType::Q6_K => llama_cpp_sys_2::GGML_TYPE_Q6_K,
            KvCacheType::Q8_K => llama_cpp_sys_2::GGML_TYPE_Q8_K,
            KvCacheType::IQ2_XXS => llama_cpp_sys_2::GGML_TYPE_IQ2_XXS,
            KvCacheType::IQ2_XS => llama_cpp_sys_2::GGML_TYPE_IQ2_XS,
            KvCacheType::IQ3_XXS => llama_cpp_sys_2::GGML_TYPE_IQ3_XXS,
            KvCacheType::IQ1_S => llama_cpp_sys_2::GGML_TYPE_IQ1_S,
            KvCacheType::IQ4_NL => llama_cpp_sys_2::GGML_TYPE_IQ4_NL,
            KvCacheType::IQ3_S => llama_cpp_sys_2::GGML_TYPE_IQ3_S,
            KvCacheType::IQ2_S => llama_cpp_sys_2::GGML_TYPE_IQ2_S,
            KvCacheType::IQ4_XS => llama_cpp_sys_2::GGML_TYPE_IQ4_XS,
            KvCacheType::I8 => llama_cpp_sys_2::GGML_TYPE_I8,
            KvCacheType::I16 => llama_cpp_sys_2::GGML_TYPE_I16,
            KvCacheType::I32 => llama_cpp_sys_2::GGML_TYPE_I32,
            KvCacheType::I64 => llama_cpp_sys_2::GGML_TYPE_I64,
            KvCacheType::F64 => llama_cpp_sys_2::GGML_TYPE_F64,
            KvCacheType::IQ1_M => llama_cpp_sys_2::GGML_TYPE_IQ1_M,
            KvCacheType::BF16 => llama_cpp_sys_2::GGML_TYPE_BF16,
            KvCacheType::TQ1_0 => llama_cpp_sys_2::GGML_TYPE_TQ1_0,
            KvCacheType::TQ2_0 => llama_cpp_sys_2::GGML_TYPE_TQ2_0,
            KvCacheType::MXFP4 => llama_cpp_sys_2::GGML_TYPE_MXFP4,
        }
    }
}

/// Convert a raw `ggml_type` into a [`KvCacheType`].
///
/// Values without a dedicated variant are preserved via
/// [`KvCacheType::Unknown`], so no information is lost on the round trip.
///
/// NOTE(review): the `x if x == CONST` guard form is used because the
/// `GGML_TYPE_*` bindings are presumably plain `const` integers (bindgen
/// output) rather than enum variants, so they cannot appear directly as
/// match patterns — confirm against `llama_cpp_sys_2`.
impl From<llama_cpp_sys_2::ggml_type> for KvCacheType {
    fn from(value: llama_cpp_sys_2::ggml_type) -> Self {
        match value {
            x if x == llama_cpp_sys_2::GGML_TYPE_F32 => KvCacheType::F32,
            x if x == llama_cpp_sys_2::GGML_TYPE_F16 => KvCacheType::F16,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_0 => KvCacheType::Q4_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_1 => KvCacheType::Q4_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_0 => KvCacheType::Q5_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_1 => KvCacheType::Q5_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_0 => KvCacheType::Q8_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_1 => KvCacheType::Q8_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q2_K => KvCacheType::Q2_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q3_K => KvCacheType::Q3_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_K => KvCacheType::Q4_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_K => KvCacheType::Q5_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q6_K => KvCacheType::Q6_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_K => KvCacheType::Q8_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XXS => KvCacheType::IQ2_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XS => KvCacheType::IQ2_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_XXS => KvCacheType::IQ3_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_S => KvCacheType::IQ1_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_NL => KvCacheType::IQ4_NL,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_S => KvCacheType::IQ3_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_S => KvCacheType::IQ2_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_XS => KvCacheType::IQ4_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_I8 => KvCacheType::I8,
            x if x == llama_cpp_sys_2::GGML_TYPE_I16 => KvCacheType::I16,
            x if x == llama_cpp_sys_2::GGML_TYPE_I32 => KvCacheType::I32,
            x if x == llama_cpp_sys_2::GGML_TYPE_I64 => KvCacheType::I64,
            x if x == llama_cpp_sys_2::GGML_TYPE_F64 => KvCacheType::F64,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_M => KvCacheType::IQ1_M,
            x if x == llama_cpp_sys_2::GGML_TYPE_BF16 => KvCacheType::BF16,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ1_0 => KvCacheType::TQ1_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ2_0 => KvCacheType::TQ2_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_MXFP4 => KvCacheType::MXFP4,
            // Unrecognized values keep their raw payload instead of erroring,
            // so newer ggml types still round-trip through this wrapper.
            _ => KvCacheType::Unknown(value),
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// # use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    /// The raw FFI params struct. Kept `pub(crate)` so outside callers must go
    /// through the safe accessors instead of touching FFI fields directly.
    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

// SAFETY: `llama_context_params` is not automatically `Send`/`Sync` because it
// holds raw pointers, but this wrapper does not currently allow setting or
// reading those pointers, so moving it across threads is sound.
unsafe impl Send for LlamaContextParams {}
// SAFETY: same reasoning as the `Send` impl — no pointer access is exposed.
unsafe impl Sync for LlamaContextParams {}

278/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
279/// ```
280/// # use std::num::NonZeroU32;
281/// # use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
282/// let params = LlamaContextParams::default();
283/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
284/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
285/// ```
286impl Default for LlamaContextParams {
287    fn default() -> Self {
288        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
289        Self { context_params }
290    }
291}