llama_cpp_2/
llama_backend.rs1use crate::LLamaCppError;
4use llama_cpp_sys_2::ggml_log_level;
5use std::sync::atomic::AtomicBool;
6use std::sync::atomic::Ordering::SeqCst;
7
/// Handle representing an initialized llama backend.
///
/// Obtained only through [`LlamaBackend::init`] or [`LlamaBackend::init_numa`];
/// while one exists, further initialization attempts fail.
#[derive(Eq, PartialEq, Debug)]
// `#[non_exhaustive]` stops code outside this crate from writing `LlamaBackend {}`
// directly, which would bypass the global initialization flag and break the
// invariant the `Drop` impl relies on. In-crate construction is unaffected.
#[non_exhaustive]
pub struct LlamaBackend {}
13
/// Global flag recording whether a [`LlamaBackend`] is currently alive.
///
/// Flipped from `false` to `true` by `LlamaBackend::mark_init` and reset when the
/// backend is dropped; starts out unset.
static LLAMA_BACKEND_INITIALIZED: AtomicBool = std::sync::atomic::AtomicBool::new(false);
15
16impl LlamaBackend {
17 fn mark_init() -> crate::Result<()> {
19 match LLAMA_BACKEND_INITIALIZED.compare_exchange(false, true, SeqCst, SeqCst) {
20 Ok(_) => Ok(()),
21 Err(_) => Err(LLamaCppError::BackendAlreadyInitialized),
22 }
23 }
24
25 #[tracing::instrument(skip_all)]
45 pub fn init() -> crate::Result<LlamaBackend> {
46 Self::mark_init()?;
47 unsafe { llama_cpp_sys_2::llama_backend_init() }
48 Ok(LlamaBackend {})
49 }
50
51 #[tracing::instrument(skip_all)]
65 pub fn init_numa(strategy: NumaStrategy) -> crate::Result<LlamaBackend> {
66 Self::mark_init()?;
67 unsafe {
68 llama_cpp_sys_2::llama_numa_init(llama_cpp_sys_2::ggml_numa_strategy::from(strategy));
69 }
70 Ok(LlamaBackend {})
71 }
72
73 pub fn supports_gpu_offload(&self) -> bool {
75 unsafe { llama_cpp_sys_2::llama_supports_gpu_offload() }
76 }
77
78 pub fn supports_mmap(&self) -> bool {
80 unsafe { llama_cpp_sys_2::llama_supports_mmap() }
81 }
82
83 pub fn supports_mlock(&self) -> bool {
85 unsafe { llama_cpp_sys_2::llama_supports_mlock() }
86 }
87
88 pub fn void_logs(&mut self) {
90 unsafe extern "C" fn void_log(
91 _level: ggml_log_level,
92 _text: *const ::std::os::raw::c_char,
93 _user_data: *mut ::std::os::raw::c_void,
94 ) {
95 }
96
97 unsafe {
98 llama_cpp_sys_2::llama_log_set(Some(void_log), std::ptr::null_mut());
99 }
100 }
101}
102
/// NUMA (non-uniform memory access) strategy passed to [`LlamaBackend::init_numa`].
///
/// Each variant maps one-to-one onto a `GGML_NUMA_STRATEGY_*` constant from
/// `llama_cpp_sys_2`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NumaStrategy {
    /// Corresponds to `GGML_NUMA_STRATEGY_DISABLED`.
    DISABLED,
    /// Corresponds to `GGML_NUMA_STRATEGY_DISTRIBUTE`.
    DISTRIBUTE,
    /// Corresponds to `GGML_NUMA_STRATEGY_ISOLATE`.
    ISOLATE,
    /// Corresponds to `GGML_NUMA_STRATEGY_NUMACTL`.
    NUMACTL,
    /// Corresponds to `GGML_NUMA_STRATEGY_MIRROR`.
    MIRROR,
    /// Corresponds to `GGML_NUMA_STRATEGY_COUNT`.
    COUNT,
}
119
120#[derive(Debug, Eq, PartialEq, Copy, Clone)]
122pub struct InvalidNumaStrategy(
123 pub llama_cpp_sys_2::ggml_numa_strategy,
125);
126
127impl TryFrom<llama_cpp_sys_2::ggml_numa_strategy> for NumaStrategy {
128 type Error = InvalidNumaStrategy;
129
130 fn try_from(value: llama_cpp_sys_2::ggml_numa_strategy) -> Result<Self, Self::Error> {
131 match value {
132 llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED => Ok(Self::DISABLED),
133 llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE => Ok(Self::DISTRIBUTE),
134 llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE => Ok(Self::ISOLATE),
135 llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL => Ok(Self::NUMACTL),
136 llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR => Ok(Self::MIRROR),
137 llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT => Ok(Self::COUNT),
138 value => Err(InvalidNumaStrategy(value)),
139 }
140 }
141}
142
143impl From<NumaStrategy> for llama_cpp_sys_2::ggml_numa_strategy {
144 fn from(value: NumaStrategy) -> Self {
145 match value {
146 NumaStrategy::DISABLED => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISABLED,
147 NumaStrategy::DISTRIBUTE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_DISTRIBUTE,
148 NumaStrategy::ISOLATE => llama_cpp_sys_2::GGML_NUMA_STRATEGY_ISOLATE,
149 NumaStrategy::NUMACTL => llama_cpp_sys_2::GGML_NUMA_STRATEGY_NUMACTL,
150 NumaStrategy::MIRROR => llama_cpp_sys_2::GGML_NUMA_STRATEGY_MIRROR,
151 NumaStrategy::COUNT => llama_cpp_sys_2::GGML_NUMA_STRATEGY_COUNT,
152 }
153 }
154}
155
156impl Drop for LlamaBackend {
172 fn drop(&mut self) {
173 match LLAMA_BACKEND_INITIALIZED.compare_exchange(true, false, SeqCst, SeqCst) {
174 Ok(_) => {}
175 Err(_) => {
176 unreachable!("This should not be reachable as the only ways to obtain a llama backend involve marking the backend as initialized.")
177 }
178 }
179 unsafe { llama_cpp_sys_2::llama_backend_free() }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant survives a round trip through the raw ggml representation.
    #[test]
    fn numa_from_and_to() {
        let numas = [
            NumaStrategy::DISABLED,
            NumaStrategy::DISTRIBUTE,
            NumaStrategy::ISOLATE,
            NumaStrategy::NUMACTL,
            NumaStrategy::MIRROR,
            NumaStrategy::COUNT,
        ];

        for &numa in &numas {
            let raw = llama_cpp_sys_2::ggml_numa_strategy::from(numa);
            let round_tripped =
                NumaStrategy::try_from(raw).expect("Failed to convert from and to");
            assert_eq!(numa, round_tripped);
        }
    }

    /// An unknown raw value is rejected and reported back unchanged.
    #[test]
    fn check_invalid_numa() {
        // The previous version compared against `invalid.unwrap_err().0`, which only
        // checked that an error occurred. Pin the exact carried value instead.
        let invalid: llama_cpp_sys_2::ggml_numa_strategy = 800;
        assert_eq!(
            NumaStrategy::try_from(invalid),
            Err(InvalidNumaStrategy(invalid))
        );
    }
}