//! Safe wrappers around core ggml graph computation APIs.
//!
//! This module provides the building blocks for creating and executing tensor computation
//! graphs using ggml backends. It's used for operations like LoRA merging, importance
//! matrix computation, and control vector generation.
//!
//! # Example
//!
//! ```rust,ignore
//! use llama_cpp_4::ggml::*;
//!
//! // Create a backend and context
//! let backend = GgmlBackend::cpu();
//! let ctx = GgmlContext::new(1024 * 1024, true);
//!
//! // Create tensors
//! let a = ctx.new_tensor_2d(GgmlType::F32, 4, 4);
//! let b = ctx.new_tensor_2d(GgmlType::F32, 4, 4);
//!
//! // Build computation graph
//! let sum = ctx.add(&a, &b);
//! let mut graph = ctx.new_graph();
//! graph.build_forward(&sum);
//!
//! // Allocate and compute
//! let alloc = GgmlAllocr::new(&backend);
//! alloc.alloc_graph(&mut graph);
//! // ... set tensor data ...
//! backend.graph_compute(&mut graph);
//! // ... get results ...
//! ```

use std::ffi::CStr;
use std::ptr::NonNull;

/// Re-export the raw ggml types for advanced usage.
pub use llama_cpp_sys_4::ggml_type;

/// A safe wrapper around `ggml_context`.
#[derive(Debug)]
pub struct GgmlContext {
    ctx: NonNull<llama_cpp_sys_4::ggml_context>,
}

impl GgmlContext {
    /// Create a new ggml context.
    ///
    /// # Parameters
    /// - `mem_size`: Memory pool size in bytes for tensor metadata.
    /// - `no_alloc`: If true, tensor data is not allocated (use with backend allocation).
    ///
    /// # Panics
    ///
    /// Panics if ggml returns a null pointer.
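    ///
    /// # Example
    ///
    /// A minimal sketch of sizing a metadata-only context for a small graph; the
    /// tensor count of 16 below is an arbitrary illustration, not a requirement.
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::ggml::{graph_overhead, tensor_overhead, GgmlContext};
    ///
    /// // Reserve room for one graph plus up to 16 tensors' worth of metadata;
    /// // with `no_alloc = true`, tensor *data* is allocated later by a backend.
    /// let mem_size = graph_overhead() + 16 * tensor_overhead();
    /// let ctx = GgmlContext::new(mem_size, true);
    /// ```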
    #[must_use]
    pub fn new(mem_size: usize, no_alloc: bool) -> Self {
        let params = llama_cpp_sys_4::ggml_init_params {
            mem_size,
            mem_buffer: std::ptr::null_mut(),
            no_alloc,
        };
        let ctx = unsafe { llama_cpp_sys_4::ggml_init(params) };
        Self {
            ctx: NonNull::new(ctx).expect("ggml_init returned null"),
        }
    }

    /// Get the raw context pointer.
    pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_context {
        self.ctx.as_ptr()
    }

    // ── Tensor creation ──────────────────────────────────────

    /// Create a 1D tensor.
    #[must_use]
    pub fn new_tensor_1d(&self, typ: ggml_type, ne0: i64) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_1d(self.ctx.as_ptr(), typ, ne0) };
        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_1d returned null"))
    }

    /// Create a 2D tensor.
    #[must_use]
    pub fn new_tensor_2d(&self, typ: ggml_type, ne0: i64, ne1: i64) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_2d(self.ctx.as_ptr(), typ, ne0, ne1) };
        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_2d returned null"))
    }

    /// Create a 3D tensor.
    #[must_use]
    pub fn new_tensor_3d(&self, typ: ggml_type, ne0: i64, ne1: i64, ne2: i64) -> GgmlTensor {
        let t =
            unsafe { llama_cpp_sys_4::ggml_new_tensor_3d(self.ctx.as_ptr(), typ, ne0, ne1, ne2) };
        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_3d returned null"))
    }

    /// Create a 4D tensor.
    #[must_use]
    pub fn new_tensor_4d(
        &self,
        typ: ggml_type,
        ne0: i64,
        ne1: i64,
        ne2: i64,
        ne3: i64,
    ) -> GgmlTensor {
        let t = unsafe {
            llama_cpp_sys_4::ggml_new_tensor_4d(self.ctx.as_ptr(), typ, ne0, ne1, ne2, ne3)
        };
        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_4d returned null"))
    }

    /// Create a tensor with the same shape and type as another.
    #[must_use]
    pub fn dup_tensor(&self, src: &GgmlTensor) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_dup_tensor(self.ctx.as_ptr(), src.0.as_ptr()) };
        GgmlTensor(NonNull::new(t).expect("ggml_dup_tensor returned null"))
    }

    /// Create a new tensor with arbitrary dimensions.
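    ///
    /// # Example
    ///
    /// A small sketch, assuming `ne` holds at most four entries (ggml's maximum) and
    /// that `typ` is some valid `ggml_type` value obtained elsewhere; unspecified
    /// trailing dimensions are treated as 1 by ggml.
    ///
    /// ```rust,ignore
    /// // A 3D tensor described by an explicit shape slice.
    /// let t = ctx.new_tensor(typ, &[64, 32, 8]);
    /// assert_eq!(t.ne(), [64, 32, 8, 1]);
    /// ```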
    #[must_use]
    pub fn new_tensor(&self, typ: ggml_type, ne: &[i64]) -> GgmlTensor {
        let t = unsafe {
            llama_cpp_sys_4::ggml_new_tensor(self.ctx.as_ptr(), typ, ne.len() as i32, ne.as_ptr())
        };
        GgmlTensor(NonNull::new(t).expect("ggml_new_tensor returned null"))
    }

    // ── Tensor operations (build graph nodes) ────────────────

    /// Element-wise addition: `a + b`
    #[must_use]
    pub fn add(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_add(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr()) };
        GgmlTensor(NonNull::new(t).expect("ggml_add returned null"))
    }

    /// Matrix multiplication: `a @ b`
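    ///
    /// # Example
    ///
    /// A rough sketch of the shape convention used by upstream `ggml_mul_mat`: both
    /// operands share their first dimension (`ne[0]`), and the result has shape
    /// `[a.ne[1], b.ne[1], ...]`. Treat this as a reminder of ggml's layout rather
    /// than a formal specification.
    ///
    /// ```rust,ignore
    /// let a = ctx.new_tensor_2d(typ, 64, 16); // ne = [64, 16]
    /// let b = ctx.new_tensor_2d(typ, 64, 32); // ne = [64, 32]
    /// let c = ctx.mul_mat(&a, &b);            // ne = [16, 32]
    /// ```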
    #[must_use]
    pub fn mul_mat(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
        let t =
            unsafe { llama_cpp_sys_4::ggml_mul_mat(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr()) };
        GgmlTensor(NonNull::new(t).expect("ggml_mul_mat returned null"))
    }

    /// Scale tensor: `a * s`
    #[must_use]
    pub fn scale(&self, a: &GgmlTensor, s: f32) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_scale(self.ctx.as_ptr(), a.0.as_ptr(), s) };
        GgmlTensor(NonNull::new(t).expect("ggml_scale returned null"))
    }

    /// Cast tensor to a different type.
    #[must_use]
    pub fn cast(&self, a: &GgmlTensor, typ: ggml_type) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_cast(self.ctx.as_ptr(), a.0.as_ptr(), typ) };
        GgmlTensor(NonNull::new(t).expect("ggml_cast returned null"))
    }

    /// Make tensor contiguous in memory.
    #[must_use]
    pub fn cont(&self, a: &GgmlTensor) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_cont(self.ctx.as_ptr(), a.0.as_ptr()) };
        GgmlTensor(NonNull::new(t).expect("ggml_cont returned null"))
    }

    /// Transpose a tensor.
    #[must_use]
    pub fn transpose(&self, a: &GgmlTensor) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_transpose(self.ctx.as_ptr(), a.0.as_ptr()) };
        GgmlTensor(NonNull::new(t).expect("ggml_transpose returned null"))
    }

    /// Reshape to 1D.
    #[must_use]
    pub fn reshape_1d(&self, a: &GgmlTensor, ne0: i64) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_reshape_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0) };
        GgmlTensor(NonNull::new(t).expect("ggml_reshape_1d returned null"))
    }

    /// Reshape to 2D.
    #[must_use]
    pub fn reshape_2d(&self, a: &GgmlTensor, ne0: i64, ne1: i64) -> GgmlTensor {
        let t =
            unsafe { llama_cpp_sys_4::ggml_reshape_2d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, ne1) };
        GgmlTensor(NonNull::new(t).expect("ggml_reshape_2d returned null"))
    }

    /// Create a 1D view of a tensor.
    #[must_use]
    pub fn view_1d(&self, a: &GgmlTensor, ne0: i64, offset: usize) -> GgmlTensor {
        let t =
            unsafe { llama_cpp_sys_4::ggml_view_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, offset) };
        GgmlTensor(NonNull::new(t).expect("ggml_view_1d returned null"))
    }

    // ── Graph creation ───────────────────────────────────────

    /// Create a new computation graph.
    #[must_use]
    pub fn new_graph(&self) -> GgmlGraph {
        let g = unsafe { llama_cpp_sys_4::ggml_new_graph(self.ctx.as_ptr()) };
        GgmlGraph(NonNull::new(g).expect("ggml_new_graph returned null"))
    }

    // ── Tensor iteration ─────────────────────────────────────

    /// Get the first tensor in this context.
    #[must_use]
    pub fn first_tensor(&self) -> Option<GgmlTensor> {
        let t = unsafe { llama_cpp_sys_4::ggml_get_first_tensor(self.ctx.as_ptr()) };
        NonNull::new(t).map(GgmlTensor)
    }

    /// Get the next tensor after `tensor` in this context.
    #[must_use]
    pub fn next_tensor(&self, tensor: &GgmlTensor) -> Option<GgmlTensor> {
        let t =
            unsafe { llama_cpp_sys_4::ggml_get_next_tensor(self.ctx.as_ptr(), tensor.0.as_ptr()) };
        NonNull::new(t).map(GgmlTensor)
    }
}

impl Drop for GgmlContext {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_4::ggml_free(self.ctx.as_ptr()) }
    }
}

// ── Tensor ──────────────────────────────────────────────────

/// A wrapper around a `ggml_tensor` pointer.
///
/// Tensors are owned by their `GgmlContext` and must not outlive it.
/// This wrapper does NOT free the tensor on drop.
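///
/// # Example
///
/// A short sketch of walking a context's tensors and reading their metadata; it
/// assumes `ctx` is a `GgmlContext` that already contains some tensors.
///
/// ```rust,ignore
/// let mut cur = ctx.first_tensor();
/// while let Some(t) = cur {
///     println!("{} {:?} ({} bytes)", t.name(), t.ne(), t.nbytes());
///     cur = ctx.next_tensor(&t);
/// }
/// ```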
#[derive(Clone, Copy)]
pub struct GgmlTensor(pub(crate) NonNull<llama_cpp_sys_4::ggml_tensor>);

impl GgmlTensor {
    /// Get the raw tensor pointer.
    pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_tensor {
        self.0.as_ptr()
    }

    /// Set the tensor's name.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains interior null bytes.
    pub fn set_name(&self, name: &str) {
        let c_name = std::ffi::CString::new(name).expect("name contains null bytes");
        unsafe { llama_cpp_sys_4::ggml_set_name(self.0.as_ptr(), c_name.as_ptr()) };
    }

    /// Get the number of elements.
    #[must_use]
    pub fn nelements(&self) -> i64 {
        unsafe { llama_cpp_sys_4::ggml_nelements(self.0.as_ptr()) }
    }

    /// Get the total size in bytes.
    #[must_use]
    pub fn nbytes(&self) -> usize {
        unsafe { llama_cpp_sys_4::ggml_nbytes(self.0.as_ptr()) }
    }

    /// Get the element size in bytes.
    #[must_use]
    pub fn element_size(&self) -> usize {
        unsafe { llama_cpp_sys_4::ggml_element_size(self.0.as_ptr()) }
    }

    /// Get the tensor type.
    #[must_use]
    pub fn typ(&self) -> ggml_type {
        unsafe { (*self.0.as_ptr()).type_ }
    }

    /// Get the tensor dimensions (ne).
    #[must_use]
    pub fn ne(&self) -> [i64; 4] {
        unsafe { (*self.0.as_ptr()).ne }
    }

    /// Get the tensor name.
    ///
    /// Returns an empty string if the name is not valid UTF-8.
    #[must_use]
    pub fn name(&self) -> &str {
        unsafe {
            let ptr = (*self.0.as_ptr()).name.as_ptr();
            CStr::from_ptr(ptr).to_str().unwrap_or("")
        }
    }
}

impl std::fmt::Debug for GgmlTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let ne = self.ne();
        write!(
            f,
            "GgmlTensor({:?}, [{}, {}, {}, {}], {} bytes)",
            self.name(),
            ne[0],
            ne[1],
            ne[2],
            ne[3],
            self.nbytes()
        )
    }
}

// ── Graph ───────────────────────────────────────────────────

/// A wrapper around `ggml_cgraph`.
///
/// Graphs are owned by their `GgmlContext` and must not outlive it.
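///
/// # Example
///
/// A minimal sketch, assuming `ctx` already holds the tensors that `out` depends on:
/// the graph records `out` and its dependencies, and `node(-1)` retrieves the final
/// node after expansion.
///
/// ```rust,ignore
/// let mut graph = ctx.new_graph();
/// graph.build_forward(&out);
/// let last = graph.node(-1);
/// assert_eq!(last.nelements(), out.nelements());
/// ```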
pub struct GgmlGraph(NonNull<llama_cpp_sys_4::ggml_cgraph>);

impl std::fmt::Debug for GgmlGraph {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("GgmlGraph").finish()
    }
}

impl GgmlGraph {
    /// Get the raw graph pointer.
    pub fn as_ptr(&mut self) -> *mut llama_cpp_sys_4::ggml_cgraph {
        self.0.as_ptr()
    }

    /// Add a tensor and its dependencies to the forward computation graph.
    pub fn build_forward(&mut self, tensor: &GgmlTensor) {
        unsafe { llama_cpp_sys_4::ggml_build_forward_expand(self.0.as_ptr(), tensor.0.as_ptr()) }
    }

    /// Get a node (output tensor) by index. Use -1 for the last node.
    #[must_use]
    pub fn node(&mut self, i: i32) -> GgmlTensor {
        let t = unsafe { llama_cpp_sys_4::ggml_graph_node(self.0.as_ptr(), i) };
        GgmlTensor(NonNull::new(t).expect("graph_node returned null"))
    }
}

// ── Backend ─────────────────────────────────────────────────

/// A safe wrapper around `ggml_backend`.
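///
/// # Example
///
/// A rough sketch of running a graph on the CPU backend; the thread count of 4 is
/// only an illustration, and `graph` is assumed to have been built and allocated
/// elsewhere.
///
/// ```rust,ignore
/// let backend = GgmlBackend::cpu();
/// backend.set_n_threads(4);
/// backend.graph_compute(&mut graph);
/// ```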
#[derive(Debug)]
pub struct GgmlBackend {
    backend: llama_cpp_sys_4::ggml_backend_t,
}

impl GgmlBackend {
    /// Create a CPU backend.
    ///
    /// # Panics
    ///
    /// Panics if ggml fails to initialize the CPU backend.
    #[must_use]
    pub fn cpu() -> Self {
        let backend = unsafe { llama_cpp_sys_4::ggml_backend_cpu_init() };
        assert!(!backend.is_null(), "ggml_backend_cpu_init returned null");
        Self { backend }
    }

    /// Set the number of threads for the CPU backend.
    pub fn set_n_threads(&self, n_threads: i32) {
        unsafe { llama_cpp_sys_4::ggml_backend_cpu_set_n_threads(self.backend, n_threads) }
    }

    /// Allocate all tensors in a context on this backend.
    ///
    /// Returns the raw buffer handle, which must be kept alive for as long as the
    /// tensors are in use and eventually released with [`buffer_free`].
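    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the context was created with `no_alloc = true` so
    /// its tensors still need backing memory:
    ///
    /// ```rust,ignore
    /// let buffer = backend.alloc_ctx_tensors(&ctx);
    /// // ... use the tensors, compute graphs, copy data in and out ...
    /// unsafe { buffer_free(buffer) }; // release the backing memory when done
    /// ```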
    pub fn alloc_ctx_tensors(
        &self,
        ctx: &GgmlContext,
    ) -> *mut llama_cpp_sys_4::ggml_backend_buffer {
        unsafe { llama_cpp_sys_4::ggml_backend_alloc_ctx_tensors(ctx.as_ptr(), self.backend) }
    }

    /// Compute a graph.
    pub fn graph_compute(&self, graph: &mut GgmlGraph) {
        unsafe { llama_cpp_sys_4::ggml_backend_graph_compute(self.backend, graph.as_ptr()) };
    }

    /// Get the default buffer type for this backend.
    pub fn default_buffer_type(&self) -> llama_cpp_sys_4::ggml_backend_buffer_type_t {
        unsafe { llama_cpp_sys_4::ggml_backend_get_default_buffer_type(self.backend) }
    }

    /// Get the raw backend pointer.
    pub fn as_ptr(&self) -> llama_cpp_sys_4::ggml_backend_t {
        self.backend
    }
}

impl Drop for GgmlBackend {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_4::ggml_backend_free(self.backend) }
    }
}

// ── Graph allocator ─────────────────────────────────────────

/// A safe wrapper around `ggml_gallocr`.
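///
/// # Example
///
/// A short sketch of allocating a graph's intermediate tensors before computing it;
/// `graph` and `backend` are assumed to exist already.
///
/// ```rust,ignore
/// let alloc = GgmlAllocr::new(&backend);
/// assert!(alloc.alloc_graph(&mut graph));
/// backend.graph_compute(&mut graph);
/// ```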
#[derive(Debug)]
pub struct GgmlAllocr {
    alloc: llama_cpp_sys_4::ggml_gallocr_t,
}

impl GgmlAllocr {
    /// Create a new graph allocator for the given backend.
    #[must_use]
    pub fn new(backend: &GgmlBackend) -> Self {
        let alloc = unsafe { llama_cpp_sys_4::ggml_gallocr_new(backend.default_buffer_type()) };
        assert!(!alloc.is_null(), "ggml_gallocr_new returned null");
        Self { alloc }
    }

    /// Allocate all tensors in a graph.
    ///
    /// Returns `true` if allocation succeeded.
    pub fn alloc_graph(&self, graph: &mut GgmlGraph) -> bool {
        unsafe { llama_cpp_sys_4::ggml_gallocr_alloc_graph(self.alloc, graph.as_ptr()) }
    }
}

impl Drop for GgmlAllocr {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_4::ggml_gallocr_free(self.alloc) }
    }
}

// ── Utility functions ───────────────────────────────────────

/// Set tensor data from a byte slice.
///
/// # Safety
///
/// The tensor must be allocated and the data must be the correct size.
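///
/// # Example
///
/// A small sketch of uploading `f32` values; converting through `to_le_bytes` is just
/// one way to produce the raw buffer, and `tensor` is assumed to be an F32 tensor
/// already allocated on a backend.
///
/// ```rust,ignore
/// let values = vec![1.0_f32; tensor.nelements() as usize];
/// let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
/// assert_eq!(bytes.len(), tensor.nbytes());
/// unsafe { tensor_set(&tensor, &bytes) };
/// ```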
pub unsafe fn tensor_set(tensor: &GgmlTensor, data: &[u8]) {
    llama_cpp_sys_4::ggml_backend_tensor_set(
        tensor.0.as_ptr(),
        data.as_ptr().cast(),
        0,
        data.len(),
    );
}

/// Copy tensor data into a byte slice.
///
/// # Safety
///
/// The tensor must be allocated and the buffer must be large enough.
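///
/// # Example
///
/// A rough sketch of reading back `f32` results; it assumes `tensor` holds F32 data
/// and reconstructs the values from little-endian bytes.
///
/// ```rust,ignore
/// let mut bytes = vec![0u8; tensor.nbytes()];
/// unsafe { tensor_get(&tensor, &mut bytes) };
/// let values: Vec<f32> = bytes
///     .chunks_exact(4)
///     .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
///     .collect();
/// ```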
pub unsafe fn tensor_get(tensor: &GgmlTensor, data: &mut [u8]) {
    llama_cpp_sys_4::ggml_backend_tensor_get(
        tensor.0.as_ptr(),
        data.as_mut_ptr().cast(),
        0,
        data.len(),
    );
}

/// Free a backend buffer.
///
/// # Safety
///
/// The buffer must be valid and not already freed.
pub unsafe fn buffer_free(buffer: *mut llama_cpp_sys_4::ggml_backend_buffer) {
    llama_cpp_sys_4::ggml_backend_buffer_free(buffer);
}

/// Get the overhead in bytes for tensor metadata.
#[must_use]
pub fn tensor_overhead() -> usize {
    unsafe { llama_cpp_sys_4::ggml_tensor_overhead() }
}

/// Get the overhead in bytes for a computation graph.
#[must_use]
pub fn graph_overhead() -> usize {
    unsafe { llama_cpp_sys_4::ggml_graph_overhead() }
}

/// Check if a type is quantized.
#[must_use]
pub fn is_quantized(typ: ggml_type) -> bool {
    unsafe { llama_cpp_sys_4::ggml_is_quantized(typ) }
}

/// Get the name of a ggml type.
#[must_use]
pub fn type_name(typ: ggml_type) -> &'static str {
    unsafe {
        let ptr = llama_cpp_sys_4::ggml_type_name(typ);
        if ptr.is_null() {
            "unknown"
        } else {
            CStr::from_ptr(ptr).to_str().unwrap_or("unknown")
        }
    }
}