//! Safe wrappers around core ggml graph computation APIs.
//!
//! This module provides the building blocks for creating and executing tensor computation
//! graphs using ggml backends. It's used for operations like LoRA merging, importance
//! matrix computation, and control vector generation.
//!
//! # Example
//!
//! ```rust,ignore
//! use llama_cpp_4::ggml::*;
//!
//! // Create a backend and context
//! let backend = GgmlBackend::cpu()?;
//! let mut ctx = GgmlContext::new(1024 * 1024, true)?;
//!
//! // Create tensors
//! let a = ctx.new_tensor_2d(GgmlType::F32, 4, 4);
//! let b = ctx.new_tensor_2d(GgmlType::F32, 4, 4);
//!
//! // Build computation graph
//! let sum = ctx.add(&a, &b);
//! let mut graph = ctx.new_graph();
//! graph.build_forward(&sum);
//!
//! // Allocate and compute
//! let mut alloc = GgmlAllocr::new(&backend);
//! alloc.alloc_graph(&mut graph);
//! // ... set tensor data ...
//! backend.graph_compute(&mut graph);
//! // ... get results ...
//! ```
32
use std::ffi::CStr;
use std::ptr::NonNull;

/// Re-export the raw ggml types for advanced usage.
pub use llama_cpp_sys_4::ggml_type;
38
39/// A safe wrapper around `ggml_context`.
40#[derive(Debug)]
41pub struct GgmlContext {
42    ctx: NonNull<llama_cpp_sys_4::ggml_context>,
43}
44
45impl GgmlContext {
46    /// Create a new ggml context.
47    ///
48    /// # Parameters
49    /// - `mem_size`: Memory pool size in bytes for tensor metadata.
50    /// - `no_alloc`: If true, tensor data is not allocated (use with backend allocation).
51    ///
52    /// # Panics
53    ///
54    /// Panics if ggml returns a null pointer.
55    #[must_use]
56    pub fn new(mem_size: usize, no_alloc: bool) -> Self {
57        let params = llama_cpp_sys_4::ggml_init_params {
58            mem_size,
59            mem_buffer: std::ptr::null_mut(),
60            no_alloc,
61        };
62        let ctx = unsafe { llama_cpp_sys_4::ggml_init(params) };
63        Self {
64            ctx: NonNull::new(ctx).expect("ggml_init returned null"),
65        }
66    }
67
68    /// Get the raw context pointer.
69    pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_context {
70        self.ctx.as_ptr()
71    }
72
73    // ── Tensor creation ──────────────────────────────────────
74
75    /// Create a 1D tensor.
76    #[must_use]
77    pub fn new_tensor_1d(&self, typ: ggml_type, ne0: i64) -> GgmlTensor {
78        let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_1d(self.ctx.as_ptr(), typ, ne0) };
79        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_1d returned null"))
80    }
81
82    /// Create a 2D tensor.
83    #[must_use]
84    pub fn new_tensor_2d(&self, typ: ggml_type, ne0: i64, ne1: i64) -> GgmlTensor {
85        let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_2d(self.ctx.as_ptr(), typ, ne0, ne1) };
86        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_2d returned null"))
87    }
88
89    /// Create a 3D tensor.
90    #[must_use]
91    pub fn new_tensor_3d(&self, typ: ggml_type, ne0: i64, ne1: i64, ne2: i64) -> GgmlTensor {
92        let t =
93            unsafe { llama_cpp_sys_4::ggml_new_tensor_3d(self.ctx.as_ptr(), typ, ne0, ne1, ne2) };
94        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_3d returned null"))
95    }
96
97    /// Create a 4D tensor.
98    #[must_use]
99    pub fn new_tensor_4d(
100        &self,
101        typ: ggml_type,
102        ne0: i64,
103        ne1: i64,
104        ne2: i64,
105        ne3: i64,
106    ) -> GgmlTensor {
107        let t = unsafe {
108            llama_cpp_sys_4::ggml_new_tensor_4d(self.ctx.as_ptr(), typ, ne0, ne1, ne2, ne3)
109        };
110        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_4d returned null"))
111    }
112
113    /// Create a tensor with the same shape and type as another.
114    #[must_use]
115    pub fn dup_tensor(&self, src: &GgmlTensor) -> GgmlTensor {
116        let t = unsafe { llama_cpp_sys_4::ggml_dup_tensor(self.ctx.as_ptr(), src.0.as_ptr()) };
117        GgmlTensor(NonNull::new(t).expect("ggml_dup_tensor returned null"))
118    }
119
120    /// Create a new tensor with arbitrary dimensions.
121    #[must_use]
122    pub fn new_tensor(&self, typ: ggml_type, ne: &[i64]) -> GgmlTensor {
123        let t = unsafe {
124            llama_cpp_sys_4::ggml_new_tensor(
125                self.ctx.as_ptr(),
126                typ,
127                ne.len() as i32,
128                ne.as_ptr(),
129            )
130        };
131        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor returned null"))
132    }
133
134    // ── Tensor operations (build graph nodes) ────────────────
135
136    /// Element-wise addition: `a + b`
137    #[must_use]
138    pub fn add(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
139        let t =
140            unsafe { llama_cpp_sys_4::ggml_add(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr()) };
141        GgmlTensor(NonNull::new(t).expect("ggml_add returned null"))
142    }
143
144    /// Matrix multiplication: `a @ b`
145    #[must_use]
146    pub fn mul_mat(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
147        let t = unsafe {
148            llama_cpp_sys_4::ggml_mul_mat(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr())
149        };
150        GgmlTensor(NonNull::new(t).expect("ggml_mul_mat returned null"))
151    }
152
153    /// Scale tensor: `a * s`
154    #[must_use]
155    pub fn scale(&self, a: &GgmlTensor, s: f32) -> GgmlTensor {
156        let t = unsafe { llama_cpp_sys_4::ggml_scale(self.ctx.as_ptr(), a.0.as_ptr(), s) };
157        GgmlTensor(NonNull::new(t).expect("ggml_scale returned null"))
158    }
159
160    /// Cast tensor to a different type.
161    #[must_use]
162    pub fn cast(&self, a: &GgmlTensor, typ: ggml_type) -> GgmlTensor {
163        let t = unsafe { llama_cpp_sys_4::ggml_cast(self.ctx.as_ptr(), a.0.as_ptr(), typ) };
164        GgmlTensor(NonNull::new(t).expect("ggml_cast returned null"))
165    }
166
167    /// Make tensor contiguous in memory.
168    #[must_use]
169    pub fn cont(&self, a: &GgmlTensor) -> GgmlTensor {
170        let t = unsafe { llama_cpp_sys_4::ggml_cont(self.ctx.as_ptr(), a.0.as_ptr()) };
171        GgmlTensor(NonNull::new(t).expect("ggml_cont returned null"))
172    }
173
174    /// Transpose a tensor.
175    #[must_use]
176    pub fn transpose(&self, a: &GgmlTensor) -> GgmlTensor {
177        let t = unsafe { llama_cpp_sys_4::ggml_transpose(self.ctx.as_ptr(), a.0.as_ptr()) };
178        GgmlTensor(NonNull::new(t).expect("ggml_transpose returned null"))
179    }
180
181    /// Reshape to 1D.
182    #[must_use]
183    pub fn reshape_1d(&self, a: &GgmlTensor, ne0: i64) -> GgmlTensor {
184        let t = unsafe { llama_cpp_sys_4::ggml_reshape_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0) };
185        GgmlTensor(NonNull::new(t).expect("ggml_reshape_1d returned null"))
186    }
187
188    /// Reshape to 2D.
189    #[must_use]
190    pub fn reshape_2d(&self, a: &GgmlTensor, ne0: i64, ne1: i64) -> GgmlTensor {
191        let t =
192            unsafe { llama_cpp_sys_4::ggml_reshape_2d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, ne1) };
193        GgmlTensor(NonNull::new(t).expect("ggml_reshape_2d returned null"))
194    }
195
196    /// Create a 1D view of a tensor.
197    #[must_use]
198    pub fn view_1d(&self, a: &GgmlTensor, ne0: i64, offset: usize) -> GgmlTensor {
199        let t = unsafe {
200            llama_cpp_sys_4::ggml_view_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, offset)
201        };
202        GgmlTensor(NonNull::new(t).expect("ggml_view_1d returned null"))
203    }
204
205    // ── Graph creation ───────────────────────────────────────
206
207    /// Create a new computation graph.
208    #[must_use]
209    pub fn new_graph(&self) -> GgmlGraph {
210        let g = unsafe { llama_cpp_sys_4::ggml_new_graph(self.ctx.as_ptr()) };
211        GgmlGraph(NonNull::new(g).expect("ggml_new_graph returned null"))
212    }
213
214    // ── Tensor iteration ─────────────────────────────────────
215
216    /// Get the first tensor in this context.
217    #[must_use]
218    pub fn first_tensor(&self) -> Option<GgmlTensor> {
219        let t = unsafe { llama_cpp_sys_4::ggml_get_first_tensor(self.ctx.as_ptr()) };
220        NonNull::new(t).map(GgmlTensor)
221    }
222
223    /// Get the next tensor after `tensor` in this context.
224    #[must_use]
225    pub fn next_tensor(&self, tensor: &GgmlTensor) -> Option<GgmlTensor> {
226        let t =
227            unsafe { llama_cpp_sys_4::ggml_get_next_tensor(self.ctx.as_ptr(), tensor.0.as_ptr()) };
228        NonNull::new(t).map(GgmlTensor)
229    }
230}
231
232impl Drop for GgmlContext {
233    fn drop(&mut self) {
234        unsafe { llama_cpp_sys_4::ggml_free(self.ctx.as_ptr()) }
235    }
236}
237
238// ── Tensor ──────────────────────────────────────────────────
239
240/// A wrapper around a `ggml_tensor` pointer.
241///
242/// Tensors are owned by their `GgmlContext` and must not outlive it.
243/// This wrapper does NOT free the tensor on drop.
244#[derive(Clone, Copy)]
245pub struct GgmlTensor(pub(crate) NonNull<llama_cpp_sys_4::ggml_tensor>);
246
247impl GgmlTensor {
248    /// Get the raw tensor pointer.
249    pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_tensor {
250        self.0.as_ptr()
251    }
252
253    /// Set the tensor's name.
254    pub fn set_name(&self, name: &str) {
255        let c_name = std::ffi::CString::new(name).expect("name contains null bytes");
256        unsafe { llama_cpp_sys_4::ggml_set_name(self.0.as_ptr(), c_name.as_ptr()) };
257    }
258
259    /// Get the number of elements.
260    #[must_use]
261    pub fn nelements(&self) -> i64 {
262        unsafe { llama_cpp_sys_4::ggml_nelements(self.0.as_ptr()) }
263    }
264
265    /// Get the total size in bytes.
266    #[must_use]
267    pub fn nbytes(&self) -> usize {
268        unsafe { llama_cpp_sys_4::ggml_nbytes(self.0.as_ptr()) }
269    }
270
271    /// Get the element size in bytes.
272    #[must_use]
273    pub fn element_size(&self) -> usize {
274        unsafe { llama_cpp_sys_4::ggml_element_size(self.0.as_ptr()) }
275    }
276
277    /// Get the tensor type.
278    #[must_use]
279    pub fn typ(&self) -> ggml_type {
280        unsafe { (*self.0.as_ptr()).type_ }
281    }
282
283    /// Get the tensor dimensions (ne).
284    #[must_use]
285    pub fn ne(&self) -> [i64; 4] {
286        unsafe { (*self.0.as_ptr()).ne }
287    }
288
289    /// Get the tensor name.
290    #[must_use]
291    pub fn name(&self) -> &str {
292        unsafe {
293            let ptr = (*self.0.as_ptr()).name.as_ptr();
294            CStr::from_ptr(ptr).to_str().unwrap_or("")
295        }
296    }
297}
298
299impl std::fmt::Debug for GgmlTensor {
300    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301        let ne = self.ne();
302        write!(
303            f,
304            "GgmlTensor({:?}, [{}, {}, {}, {}], {} bytes)",
305            self.name(),
306            ne[0],
307            ne[1],
308            ne[2],
309            ne[3],
310            self.nbytes()
311        )
312    }
313}
314
315// ── Graph ───────────────────────────────────────────────────
316
317/// A wrapper around `ggml_cgraph`.
318///
319/// Graphs are owned by their `GgmlContext` and must not outlive it.
320pub struct GgmlGraph(NonNull<llama_cpp_sys_4::ggml_cgraph>);
321
322impl std::fmt::Debug for GgmlGraph {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        f.debug_struct("GgmlGraph").finish()
325    }
326}
327
328impl GgmlGraph {
329    /// Get the raw graph pointer.
330    pub fn as_ptr(&mut self) -> *mut llama_cpp_sys_4::ggml_cgraph {
331        self.0.as_ptr()
332    }
333
334    /// Add a tensor and its dependencies to the forward computation graph.
335    pub fn build_forward(&mut self, tensor: &GgmlTensor) {
336        unsafe { llama_cpp_sys_4::ggml_build_forward_expand(self.0.as_ptr(), tensor.0.as_ptr()) }
337    }
338
339    /// Get a node (output tensor) by index. Use -1 for the last node.
340    #[must_use]
341    pub fn node(&mut self, i: i32) -> GgmlTensor {
342        let t = unsafe { llama_cpp_sys_4::ggml_graph_node(self.0.as_ptr(), i) };
343        GgmlTensor(NonNull::new(t).expect("graph_node returned null"))
344    }
345}
346
347// ── Backend ─────────────────────────────────────────────────
348
349/// A safe wrapper around `ggml_backend`.
350#[derive(Debug)]
351pub struct GgmlBackend {
352    backend: llama_cpp_sys_4::ggml_backend_t,
353}
354
355impl GgmlBackend {
356    /// Create a CPU backend.
357    #[must_use]
358    pub fn cpu() -> Self {
359        let backend = unsafe { llama_cpp_sys_4::ggml_backend_cpu_init() };
360        assert!(!backend.is_null(), "ggml_backend_cpu_init returned null");
361        Self { backend }
362    }
363
364    /// Set the number of threads for the CPU backend.
365    pub fn set_n_threads(&self, n_threads: i32) {
366        unsafe { llama_cpp_sys_4::ggml_backend_cpu_set_n_threads(self.backend, n_threads) }
367    }
368
369    /// Allocate all tensors in a context on this backend.
370    ///
371    /// Returns the buffer handle which must be kept alive.
372    pub fn alloc_ctx_tensors(
373        &self,
374        ctx: &GgmlContext,
375    ) -> *mut llama_cpp_sys_4::ggml_backend_buffer {
376        unsafe { llama_cpp_sys_4::ggml_backend_alloc_ctx_tensors(ctx.as_ptr(), self.backend) }
377    }
378
379    /// Compute a graph.
380    pub fn graph_compute(&self, graph: &mut GgmlGraph) {
381        unsafe { llama_cpp_sys_4::ggml_backend_graph_compute(self.backend, graph.as_ptr()) };
382    }
383
384    /// Get the default buffer type for this backend.
385    pub fn default_buffer_type(&self) -> llama_cpp_sys_4::ggml_backend_buffer_type_t {
386        unsafe { llama_cpp_sys_4::ggml_backend_get_default_buffer_type(self.backend) }
387    }
388
389    /// Get the raw backend pointer.
390    pub fn as_ptr(&self) -> llama_cpp_sys_4::ggml_backend_t {
391        self.backend
392    }
393}
394
395impl Drop for GgmlBackend {
396    fn drop(&mut self) {
397        unsafe { llama_cpp_sys_4::ggml_backend_free(self.backend) }
398    }
399}
400
401// ── Graph allocator ─────────────────────────────────────────
402
403/// A safe wrapper around `ggml_gallocr`.
404#[derive(Debug)]
405pub struct GgmlAllocr {
406    alloc: llama_cpp_sys_4::ggml_gallocr_t,
407}
408
409impl GgmlAllocr {
410    /// Create a new graph allocator for the given backend.
411    #[must_use]
412    pub fn new(backend: &GgmlBackend) -> Self {
413        let alloc = unsafe { llama_cpp_sys_4::ggml_gallocr_new(backend.default_buffer_type()) };
414        assert!(!alloc.is_null(), "ggml_gallocr_new returned null");
415        Self { alloc }
416    }
417
418    /// Allocate all tensors in a graph.
419    pub fn alloc_graph(&self, graph: &mut GgmlGraph) -> bool {
420        unsafe { llama_cpp_sys_4::ggml_gallocr_alloc_graph(self.alloc, graph.as_ptr()) }
421    }
422}
423
424impl Drop for GgmlAllocr {
425    fn drop(&mut self) {
426        unsafe { llama_cpp_sys_4::ggml_gallocr_free(self.alloc) }
427    }
428}
429
430// ── Utility functions ───────────────────────────────────────
431
432/// Set tensor data from a byte slice.
433///
434/// # Safety
435///
436/// The tensor must be allocated and the data must be the correct size.
437pub unsafe fn tensor_set(tensor: &GgmlTensor, data: &[u8]) {
438    llama_cpp_sys_4::ggml_backend_tensor_set(
439        tensor.0.as_ptr(),
440        data.as_ptr().cast(),
441        0,
442        data.len(),
443    );
444}
445
446/// Get tensor data into a byte slice.
447///
448/// # Safety
449///
450/// The tensor must be allocated and the buffer must be large enough.
451pub unsafe fn tensor_get(tensor: &GgmlTensor, data: &mut [u8]) {
452    llama_cpp_sys_4::ggml_backend_tensor_get(
453        tensor.0.as_ptr(),
454        data.as_mut_ptr().cast(),
455        0,
456        data.len(),
457    );
458}
459
460/// Free a backend buffer.
461///
462/// # Safety
463///
464/// The buffer must be valid and not already freed.
465pub unsafe fn buffer_free(buffer: *mut llama_cpp_sys_4::ggml_backend_buffer) {
466    llama_cpp_sys_4::ggml_backend_buffer_free(buffer);
467}
468
469/// Get the overhead in bytes for tensor metadata.
470#[must_use]
471pub fn tensor_overhead() -> usize {
472    unsafe { llama_cpp_sys_4::ggml_tensor_overhead() }
473}
474
475/// Get the overhead in bytes for a computation graph.
476#[must_use]
477pub fn graph_overhead() -> usize {
478    unsafe { llama_cpp_sys_4::ggml_graph_overhead() }
479}
480
481/// Check if a type is quantized.
482#[must_use]
483pub fn is_quantized(typ: ggml_type) -> bool {
484    unsafe { llama_cpp_sys_4::ggml_is_quantized(typ) }
485}
486
487/// Get the name of a ggml type.
488#[must_use]
489pub fn type_name(typ: ggml_type) -> &'static str {
490    unsafe {
491        let ptr = llama_cpp_sys_4::ggml_type_name(typ);
492        if ptr.is_null() {
493            "unknown"
494        } else {
495            CStr::from_ptr(ptr).to_str().unwrap_or("unknown")
496        }
497    }
498}