1use std::ffi::CStr;
34use std::ptr::NonNull;
35
36pub use llama_cpp_sys_4::ggml_type;
38
/// RAII wrapper around a raw `ggml_context`.
///
/// The context is created by [`GgmlContext::new`] via `ggml_init` and freed
/// with `ggml_free` when this wrapper is dropped. Tensors and graphs created
/// through its methods are allocated inside the context's arena.
#[derive(Debug)]
pub struct GgmlContext {
    // Invariant: non-null by construction — `new` panics if `ggml_init`
    // returns null, and no other code path sets this field.
    ctx: NonNull<llama_cpp_sys_4::ggml_context>,
}
44
45impl GgmlContext {
46 #[must_use]
56 pub fn new(mem_size: usize, no_alloc: bool) -> Self {
57 let params = llama_cpp_sys_4::ggml_init_params {
58 mem_size,
59 mem_buffer: std::ptr::null_mut(),
60 no_alloc,
61 };
62 let ctx = unsafe { llama_cpp_sys_4::ggml_init(params) };
63 Self {
64 ctx: NonNull::new(ctx).expect("ggml_init returned null"),
65 }
66 }
67
68 pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_context {
70 self.ctx.as_ptr()
71 }
72
73 #[must_use]
77 pub fn new_tensor_1d(&self, typ: ggml_type, ne0: i64) -> GgmlTensor {
78 let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_1d(self.ctx.as_ptr(), typ, ne0) };
79 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_1d returned null"))
80 }
81
82 #[must_use]
84 pub fn new_tensor_2d(&self, typ: ggml_type, ne0: i64, ne1: i64) -> GgmlTensor {
85 let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_2d(self.ctx.as_ptr(), typ, ne0, ne1) };
86 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_2d returned null"))
87 }
88
89 #[must_use]
91 pub fn new_tensor_3d(&self, typ: ggml_type, ne0: i64, ne1: i64, ne2: i64) -> GgmlTensor {
92 let t =
93 unsafe { llama_cpp_sys_4::ggml_new_tensor_3d(self.ctx.as_ptr(), typ, ne0, ne1, ne2) };
94 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_3d returned null"))
95 }
96
97 #[must_use]
99 pub fn new_tensor_4d(
100 &self,
101 typ: ggml_type,
102 ne0: i64,
103 ne1: i64,
104 ne2: i64,
105 ne3: i64,
106 ) -> GgmlTensor {
107 let t = unsafe {
108 llama_cpp_sys_4::ggml_new_tensor_4d(self.ctx.as_ptr(), typ, ne0, ne1, ne2, ne3)
109 };
110 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_4d returned null"))
111 }
112
113 #[must_use]
115 pub fn dup_tensor(&self, src: &GgmlTensor) -> GgmlTensor {
116 let t = unsafe { llama_cpp_sys_4::ggml_dup_tensor(self.ctx.as_ptr(), src.0.as_ptr()) };
117 GgmlTensor(NonNull::new(t).expect("ggml_dup_tensor returned null"))
118 }
119
120 #[must_use]
122 pub fn new_tensor(&self, typ: ggml_type, ne: &[i64]) -> GgmlTensor {
123 let t = unsafe {
124 llama_cpp_sys_4::ggml_new_tensor(self.ctx.as_ptr(), typ, ne.len() as i32, ne.as_ptr())
125 };
126 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor returned null"))
127 }
128
129 #[must_use]
133 pub fn add(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
134 let t = unsafe { llama_cpp_sys_4::ggml_add(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr()) };
135 GgmlTensor(NonNull::new(t).expect("ggml_add returned null"))
136 }
137
138 #[must_use]
140 pub fn mul_mat(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
141 let t =
142 unsafe { llama_cpp_sys_4::ggml_mul_mat(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr()) };
143 GgmlTensor(NonNull::new(t).expect("ggml_mul_mat returned null"))
144 }
145
146 #[must_use]
148 pub fn scale(&self, a: &GgmlTensor, s: f32) -> GgmlTensor {
149 let t = unsafe { llama_cpp_sys_4::ggml_scale(self.ctx.as_ptr(), a.0.as_ptr(), s) };
150 GgmlTensor(NonNull::new(t).expect("ggml_scale returned null"))
151 }
152
153 #[must_use]
155 pub fn cast(&self, a: &GgmlTensor, typ: ggml_type) -> GgmlTensor {
156 let t = unsafe { llama_cpp_sys_4::ggml_cast(self.ctx.as_ptr(), a.0.as_ptr(), typ) };
157 GgmlTensor(NonNull::new(t).expect("ggml_cast returned null"))
158 }
159
160 #[must_use]
162 pub fn cont(&self, a: &GgmlTensor) -> GgmlTensor {
163 let t = unsafe { llama_cpp_sys_4::ggml_cont(self.ctx.as_ptr(), a.0.as_ptr()) };
164 GgmlTensor(NonNull::new(t).expect("ggml_cont returned null"))
165 }
166
167 #[must_use]
169 pub fn transpose(&self, a: &GgmlTensor) -> GgmlTensor {
170 let t = unsafe { llama_cpp_sys_4::ggml_transpose(self.ctx.as_ptr(), a.0.as_ptr()) };
171 GgmlTensor(NonNull::new(t).expect("ggml_transpose returned null"))
172 }
173
174 #[must_use]
176 pub fn reshape_1d(&self, a: &GgmlTensor, ne0: i64) -> GgmlTensor {
177 let t = unsafe { llama_cpp_sys_4::ggml_reshape_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0) };
178 GgmlTensor(NonNull::new(t).expect("ggml_reshape_1d returned null"))
179 }
180
181 #[must_use]
183 pub fn reshape_2d(&self, a: &GgmlTensor, ne0: i64, ne1: i64) -> GgmlTensor {
184 let t =
185 unsafe { llama_cpp_sys_4::ggml_reshape_2d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, ne1) };
186 GgmlTensor(NonNull::new(t).expect("ggml_reshape_2d returned null"))
187 }
188
189 #[must_use]
191 pub fn view_1d(&self, a: &GgmlTensor, ne0: i64, offset: usize) -> GgmlTensor {
192 let t =
193 unsafe { llama_cpp_sys_4::ggml_view_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, offset) };
194 GgmlTensor(NonNull::new(t).expect("ggml_view_1d returned null"))
195 }
196
197 #[must_use]
201 pub fn new_graph(&self) -> GgmlGraph {
202 let g = unsafe { llama_cpp_sys_4::ggml_new_graph(self.ctx.as_ptr()) };
203 GgmlGraph(NonNull::new(g).expect("ggml_new_graph returned null"))
204 }
205
206 #[must_use]
210 pub fn first_tensor(&self) -> Option<GgmlTensor> {
211 let t = unsafe { llama_cpp_sys_4::ggml_get_first_tensor(self.ctx.as_ptr()) };
212 NonNull::new(t).map(GgmlTensor)
213 }
214
215 #[must_use]
217 pub fn next_tensor(&self, tensor: &GgmlTensor) -> Option<GgmlTensor> {
218 let t =
219 unsafe { llama_cpp_sys_4::ggml_get_next_tensor(self.ctx.as_ptr(), tensor.0.as_ptr()) };
220 NonNull::new(t).map(GgmlTensor)
221 }
222}
223
impl Drop for GgmlContext {
    fn drop(&mut self) {
        // SAFETY: `ctx` is non-null by construction and uniquely owned by
        // this wrapper, so it is freed exactly once here.
        unsafe { llama_cpp_sys_4::ggml_free(self.ctx.as_ptr()) }
    }
}
229
/// Lightweight `Copy` handle to a `ggml_tensor` allocated inside a
/// [`GgmlContext`].
///
/// NOTE(review): this handle carries no lifetime tying it to its context, so
/// a copy can outlive the context that owns the tensor memory; callers must
/// keep the `GgmlContext` alive while using it — confirm this is the
/// intended contract.
#[derive(Clone, Copy)]
pub struct GgmlTensor(pub(crate) NonNull<llama_cpp_sys_4::ggml_tensor>);
238
239impl GgmlTensor {
240 pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_tensor {
242 self.0.as_ptr()
243 }
244
245 pub fn set_name(&self, name: &str) {
247 let c_name = std::ffi::CString::new(name).expect("name contains null bytes");
248 unsafe { llama_cpp_sys_4::ggml_set_name(self.0.as_ptr(), c_name.as_ptr()) };
249 }
250
251 #[must_use]
253 pub fn nelements(&self) -> i64 {
254 unsafe { llama_cpp_sys_4::ggml_nelements(self.0.as_ptr()) }
255 }
256
257 #[must_use]
259 pub fn nbytes(&self) -> usize {
260 unsafe { llama_cpp_sys_4::ggml_nbytes(self.0.as_ptr()) }
261 }
262
263 #[must_use]
265 pub fn element_size(&self) -> usize {
266 unsafe { llama_cpp_sys_4::ggml_element_size(self.0.as_ptr()) }
267 }
268
269 #[must_use]
271 pub fn typ(&self) -> ggml_type {
272 unsafe { (*self.0.as_ptr()).type_ }
273 }
274
275 #[must_use]
277 pub fn ne(&self) -> [i64; 4] {
278 unsafe { (*self.0.as_ptr()).ne }
279 }
280
281 #[must_use]
283 pub fn name(&self) -> &str {
284 unsafe {
285 let ptr = (*self.0.as_ptr()).name.as_ptr();
286 CStr::from_ptr(ptr).to_str().unwrap_or("")
287 }
288 }
289}
290
291impl std::fmt::Debug for GgmlTensor {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 let ne = self.ne();
294 write!(
295 f,
296 "GgmlTensor({:?}, [{}, {}, {}, {}], {} bytes)",
297 self.name(),
298 ne[0],
299 ne[1],
300 ne[2],
301 ne[3],
302 self.nbytes()
303 )
304 }
305}
306
307pub struct GgmlGraph(NonNull<llama_cpp_sys_4::ggml_cgraph>);
313
314impl std::fmt::Debug for GgmlGraph {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 f.debug_struct("GgmlGraph").finish()
317 }
318}
319
320impl GgmlGraph {
321 pub fn as_ptr(&mut self) -> *mut llama_cpp_sys_4::ggml_cgraph {
323 self.0.as_ptr()
324 }
325
326 pub fn build_forward(&mut self, tensor: &GgmlTensor) {
328 unsafe { llama_cpp_sys_4::ggml_build_forward_expand(self.0.as_ptr(), tensor.0.as_ptr()) }
329 }
330
331 #[must_use]
333 pub fn node(&mut self, i: i32) -> GgmlTensor {
334 let t = unsafe { llama_cpp_sys_4::ggml_graph_node(self.0.as_ptr(), i) };
335 GgmlTensor(NonNull::new(t).expect("graph_node returned null"))
336 }
337}
338
/// RAII wrapper around a `ggml_backend_t`, freed with `ggml_backend_free`
/// on drop. The only constructor visible here is [`GgmlBackend::cpu`].
#[derive(Debug)]
pub struct GgmlBackend {
    // Invariant: non-null — `cpu()` asserts the init call succeeded.
    backend: llama_cpp_sys_4::ggml_backend_t,
}
346
347impl GgmlBackend {
348 #[must_use]
350 pub fn cpu() -> Self {
351 let backend = unsafe { llama_cpp_sys_4::ggml_backend_cpu_init() };
352 assert!(!backend.is_null(), "ggml_backend_cpu_init returned null");
353 Self { backend }
354 }
355
356 pub fn set_n_threads(&self, n_threads: i32) {
358 unsafe { llama_cpp_sys_4::ggml_backend_cpu_set_n_threads(self.backend, n_threads) }
359 }
360
361 pub fn alloc_ctx_tensors(
365 &self,
366 ctx: &GgmlContext,
367 ) -> *mut llama_cpp_sys_4::ggml_backend_buffer {
368 unsafe { llama_cpp_sys_4::ggml_backend_alloc_ctx_tensors(ctx.as_ptr(), self.backend) }
369 }
370
371 pub fn graph_compute(&self, graph: &mut GgmlGraph) {
373 unsafe { llama_cpp_sys_4::ggml_backend_graph_compute(self.backend, graph.as_ptr()) };
374 }
375
376 pub fn default_buffer_type(&self) -> llama_cpp_sys_4::ggml_backend_buffer_type_t {
378 unsafe { llama_cpp_sys_4::ggml_backend_get_default_buffer_type(self.backend) }
379 }
380
381 pub fn as_ptr(&self) -> llama_cpp_sys_4::ggml_backend_t {
383 self.backend
384 }
385}
386
impl Drop for GgmlBackend {
    fn drop(&mut self) {
        // SAFETY: `backend` is non-null (asserted in `cpu()`) and uniquely
        // owned by this wrapper, so it is freed exactly once here.
        unsafe { llama_cpp_sys_4::ggml_backend_free(self.backend) }
    }
}
392
/// RAII wrapper around a `ggml_gallocr_t` graph allocator, freed with
/// `ggml_gallocr_free` on drop.
#[derive(Debug)]
pub struct GgmlAllocr {
    // Invariant: non-null — `new()` asserts the allocation succeeded.
    alloc: llama_cpp_sys_4::ggml_gallocr_t,
}
400
401impl GgmlAllocr {
402 #[must_use]
404 pub fn new(backend: &GgmlBackend) -> Self {
405 let alloc = unsafe { llama_cpp_sys_4::ggml_gallocr_new(backend.default_buffer_type()) };
406 assert!(!alloc.is_null(), "ggml_gallocr_new returned null");
407 Self { alloc }
408 }
409
410 pub fn alloc_graph(&self, graph: &mut GgmlGraph) -> bool {
412 unsafe { llama_cpp_sys_4::ggml_gallocr_alloc_graph(self.alloc, graph.as_ptr()) }
413 }
414}
415
impl Drop for GgmlAllocr {
    fn drop(&mut self) {
        // SAFETY: `alloc` is non-null (asserted in `new()`) and uniquely
        // owned by this wrapper, so it is freed exactly once here.
        unsafe { llama_cpp_sys_4::ggml_gallocr_free(self.alloc) }
    }
}
421
/// Copies `data` into `tensor`'s backend storage, starting at byte offset 0.
///
/// # Safety
///
/// The caller must ensure the tensor has backend storage allocated (e.g. via
/// `GgmlBackend::alloc_ctx_tensors`) and that its owning context is still
/// alive. Presumably `data.len()` must not exceed `tensor.nbytes()` — confirm
/// against the `ggml_backend_tensor_set` documentation.
pub unsafe fn tensor_set(tensor: &GgmlTensor, data: &[u8]) {
    llama_cpp_sys_4::ggml_backend_tensor_set(
        tensor.0.as_ptr(),
        data.as_ptr().cast(),
        0,
        data.len(),
    );
}
437
/// Copies `data.len()` bytes out of `tensor`'s backend storage into `data`,
/// starting at byte offset 0.
///
/// # Safety
///
/// The caller must ensure the tensor has backend storage allocated and that
/// its owning context is still alive. Presumably `data.len()` must not exceed
/// `tensor.nbytes()` — confirm against the `ggml_backend_tensor_get`
/// documentation.
pub unsafe fn tensor_get(tensor: &GgmlTensor, data: &mut [u8]) {
    llama_cpp_sys_4::ggml_backend_tensor_get(
        tensor.0.as_ptr(),
        data.as_mut_ptr().cast(),
        0,
        data.len(),
    );
}
451
/// Frees a backend buffer, e.g. one returned by
/// [`GgmlBackend::alloc_ctx_tensors`].
///
/// # Safety
///
/// `buffer` must be a valid backend buffer pointer that has not already been
/// freed. Tensors whose data lives in this buffer must not be accessed
/// afterwards.
pub unsafe fn buffer_free(buffer: *mut llama_cpp_sys_4::ggml_backend_buffer) {
    llama_cpp_sys_4::ggml_backend_buffer_free(buffer);
}
460
/// Per-tensor bookkeeping overhead in bytes (`ggml_tensor_overhead`), useful
/// when sizing a context's `mem_size`.
#[must_use]
pub fn tensor_overhead() -> usize {
    unsafe { llama_cpp_sys_4::ggml_tensor_overhead() }
}
466
/// Per-graph bookkeeping overhead in bytes (`ggml_graph_overhead`), useful
/// when sizing a context's `mem_size`.
#[must_use]
pub fn graph_overhead() -> usize {
    unsafe { llama_cpp_sys_4::ggml_graph_overhead() }
}
472
/// Whether `typ` is a quantized element type, as reported by
/// `ggml_is_quantized`.
#[must_use]
pub fn is_quantized(typ: ggml_type) -> bool {
    unsafe { llama_cpp_sys_4::ggml_is_quantized(typ) }
}
478
479#[must_use]
481pub fn type_name(typ: ggml_type) -> &'static str {
482 unsafe {
483 let ptr = llama_cpp_sys_4::ggml_type_name(typ);
484 if ptr.is_null() {
485 "unknown"
486 } else {
487 CStr::from_ptr(ptr).to_str().unwrap_or("unknown")
488 }
489 }
490}