1use std::ffi::CStr;
34use std::ptr::NonNull;
35
36pub use llama_cpp_sys_4::ggml_type;
38
/// Safe owning wrapper around a `ggml_context` — an arena that holds tensor
/// metadata (and tensor data, unless created with `no_alloc = true`).
///
/// The context is freed via `ggml_free` when this value is dropped.
#[derive(Debug)]
pub struct GgmlContext {
    // Owning handle returned by `ggml_init`; guaranteed non-null.
    ctx: NonNull<llama_cpp_sys_4::ggml_context>,
}
44
45impl GgmlContext {
46 #[must_use]
56 pub fn new(mem_size: usize, no_alloc: bool) -> Self {
57 let params = llama_cpp_sys_4::ggml_init_params {
58 mem_size,
59 mem_buffer: std::ptr::null_mut(),
60 no_alloc,
61 };
62 let ctx = unsafe { llama_cpp_sys_4::ggml_init(params) };
63 Self {
64 ctx: NonNull::new(ctx).expect("ggml_init returned null"),
65 }
66 }
67
68 pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_context {
70 self.ctx.as_ptr()
71 }
72
73 #[must_use]
77 pub fn new_tensor_1d(&self, typ: ggml_type, ne0: i64) -> GgmlTensor {
78 let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_1d(self.ctx.as_ptr(), typ, ne0) };
79 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_1d returned null"))
80 }
81
82 #[must_use]
84 pub fn new_tensor_2d(&self, typ: ggml_type, ne0: i64, ne1: i64) -> GgmlTensor {
85 let t = unsafe { llama_cpp_sys_4::ggml_new_tensor_2d(self.ctx.as_ptr(), typ, ne0, ne1) };
86 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_2d returned null"))
87 }
88
89 #[must_use]
91 pub fn new_tensor_3d(&self, typ: ggml_type, ne0: i64, ne1: i64, ne2: i64) -> GgmlTensor {
92 let t =
93 unsafe { llama_cpp_sys_4::ggml_new_tensor_3d(self.ctx.as_ptr(), typ, ne0, ne1, ne2) };
94 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_3d returned null"))
95 }
96
97 #[must_use]
99 pub fn new_tensor_4d(
100 &self,
101 typ: ggml_type,
102 ne0: i64,
103 ne1: i64,
104 ne2: i64,
105 ne3: i64,
106 ) -> GgmlTensor {
107 let t = unsafe {
108 llama_cpp_sys_4::ggml_new_tensor_4d(self.ctx.as_ptr(), typ, ne0, ne1, ne2, ne3)
109 };
110 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor_4d returned null"))
111 }
112
113 #[must_use]
115 pub fn dup_tensor(&self, src: &GgmlTensor) -> GgmlTensor {
116 let t = unsafe { llama_cpp_sys_4::ggml_dup_tensor(self.ctx.as_ptr(), src.0.as_ptr()) };
117 GgmlTensor(NonNull::new(t).expect("ggml_dup_tensor returned null"))
118 }
119
120 #[must_use]
122 pub fn new_tensor(&self, typ: ggml_type, ne: &[i64]) -> GgmlTensor {
123 let t = unsafe {
124 llama_cpp_sys_4::ggml_new_tensor(
125 self.ctx.as_ptr(),
126 typ,
127 ne.len() as i32,
128 ne.as_ptr(),
129 )
130 };
131 GgmlTensor(NonNull::new(t).expect("ggml_new_tensor returned null"))
132 }
133
134 #[must_use]
138 pub fn add(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
139 let t =
140 unsafe { llama_cpp_sys_4::ggml_add(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr()) };
141 GgmlTensor(NonNull::new(t).expect("ggml_add returned null"))
142 }
143
144 #[must_use]
146 pub fn mul_mat(&self, a: &GgmlTensor, b: &GgmlTensor) -> GgmlTensor {
147 let t = unsafe {
148 llama_cpp_sys_4::ggml_mul_mat(self.ctx.as_ptr(), a.0.as_ptr(), b.0.as_ptr())
149 };
150 GgmlTensor(NonNull::new(t).expect("ggml_mul_mat returned null"))
151 }
152
153 #[must_use]
155 pub fn scale(&self, a: &GgmlTensor, s: f32) -> GgmlTensor {
156 let t = unsafe { llama_cpp_sys_4::ggml_scale(self.ctx.as_ptr(), a.0.as_ptr(), s) };
157 GgmlTensor(NonNull::new(t).expect("ggml_scale returned null"))
158 }
159
160 #[must_use]
162 pub fn cast(&self, a: &GgmlTensor, typ: ggml_type) -> GgmlTensor {
163 let t = unsafe { llama_cpp_sys_4::ggml_cast(self.ctx.as_ptr(), a.0.as_ptr(), typ) };
164 GgmlTensor(NonNull::new(t).expect("ggml_cast returned null"))
165 }
166
167 #[must_use]
169 pub fn cont(&self, a: &GgmlTensor) -> GgmlTensor {
170 let t = unsafe { llama_cpp_sys_4::ggml_cont(self.ctx.as_ptr(), a.0.as_ptr()) };
171 GgmlTensor(NonNull::new(t).expect("ggml_cont returned null"))
172 }
173
174 #[must_use]
176 pub fn transpose(&self, a: &GgmlTensor) -> GgmlTensor {
177 let t = unsafe { llama_cpp_sys_4::ggml_transpose(self.ctx.as_ptr(), a.0.as_ptr()) };
178 GgmlTensor(NonNull::new(t).expect("ggml_transpose returned null"))
179 }
180
181 #[must_use]
183 pub fn reshape_1d(&self, a: &GgmlTensor, ne0: i64) -> GgmlTensor {
184 let t = unsafe { llama_cpp_sys_4::ggml_reshape_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0) };
185 GgmlTensor(NonNull::new(t).expect("ggml_reshape_1d returned null"))
186 }
187
188 #[must_use]
190 pub fn reshape_2d(&self, a: &GgmlTensor, ne0: i64, ne1: i64) -> GgmlTensor {
191 let t =
192 unsafe { llama_cpp_sys_4::ggml_reshape_2d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, ne1) };
193 GgmlTensor(NonNull::new(t).expect("ggml_reshape_2d returned null"))
194 }
195
196 #[must_use]
198 pub fn view_1d(&self, a: &GgmlTensor, ne0: i64, offset: usize) -> GgmlTensor {
199 let t = unsafe {
200 llama_cpp_sys_4::ggml_view_1d(self.ctx.as_ptr(), a.0.as_ptr(), ne0, offset)
201 };
202 GgmlTensor(NonNull::new(t).expect("ggml_view_1d returned null"))
203 }
204
205 #[must_use]
209 pub fn new_graph(&self) -> GgmlGraph {
210 let g = unsafe { llama_cpp_sys_4::ggml_new_graph(self.ctx.as_ptr()) };
211 GgmlGraph(NonNull::new(g).expect("ggml_new_graph returned null"))
212 }
213
214 #[must_use]
218 pub fn first_tensor(&self) -> Option<GgmlTensor> {
219 let t = unsafe { llama_cpp_sys_4::ggml_get_first_tensor(self.ctx.as_ptr()) };
220 NonNull::new(t).map(GgmlTensor)
221 }
222
223 #[must_use]
225 pub fn next_tensor(&self, tensor: &GgmlTensor) -> Option<GgmlTensor> {
226 let t =
227 unsafe { llama_cpp_sys_4::ggml_get_next_tensor(self.ctx.as_ptr(), tensor.0.as_ptr()) };
228 NonNull::new(t).map(GgmlTensor)
229 }
230}
231
impl Drop for GgmlContext {
    fn drop(&mut self) {
        // SAFETY: `self.ctx` came from `ggml_init`, was verified non-null,
        // and is freed exactly once here.
        unsafe { llama_cpp_sys_4::ggml_free(self.ctx.as_ptr()) }
    }
}
237
/// Non-owning handle to a tensor allocated inside a [`GgmlContext`].
///
/// NOTE(review): this is a `Copy` wrapper over a raw pointer with no
/// lifetime tie to its context — a `GgmlTensor` must not be used after the
/// `GgmlContext` that created it is dropped.
#[derive(Clone, Copy)]
pub struct GgmlTensor(pub(crate) NonNull<llama_cpp_sys_4::ggml_tensor>);
246
247impl GgmlTensor {
248 pub fn as_ptr(&self) -> *mut llama_cpp_sys_4::ggml_tensor {
250 self.0.as_ptr()
251 }
252
253 pub fn set_name(&self, name: &str) {
255 let c_name = std::ffi::CString::new(name).expect("name contains null bytes");
256 unsafe { llama_cpp_sys_4::ggml_set_name(self.0.as_ptr(), c_name.as_ptr()) };
257 }
258
259 #[must_use]
261 pub fn nelements(&self) -> i64 {
262 unsafe { llama_cpp_sys_4::ggml_nelements(self.0.as_ptr()) }
263 }
264
265 #[must_use]
267 pub fn nbytes(&self) -> usize {
268 unsafe { llama_cpp_sys_4::ggml_nbytes(self.0.as_ptr()) }
269 }
270
271 #[must_use]
273 pub fn element_size(&self) -> usize {
274 unsafe { llama_cpp_sys_4::ggml_element_size(self.0.as_ptr()) }
275 }
276
277 #[must_use]
279 pub fn typ(&self) -> ggml_type {
280 unsafe { (*self.0.as_ptr()).type_ }
281 }
282
283 #[must_use]
285 pub fn ne(&self) -> [i64; 4] {
286 unsafe { (*self.0.as_ptr()).ne }
287 }
288
289 #[must_use]
291 pub fn name(&self) -> &str {
292 unsafe {
293 let ptr = (*self.0.as_ptr()).name.as_ptr();
294 CStr::from_ptr(ptr).to_str().unwrap_or("")
295 }
296 }
297}
298
299impl std::fmt::Debug for GgmlTensor {
300 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
301 let ne = self.ne();
302 write!(
303 f,
304 "GgmlTensor({:?}, [{}, {}, {}, {}], {} bytes)",
305 self.name(),
306 ne[0],
307 ne[1],
308 ne[2],
309 ne[3],
310 self.nbytes()
311 )
312 }
313}
314
315pub struct GgmlGraph(NonNull<llama_cpp_sys_4::ggml_cgraph>);
321
322impl std::fmt::Debug for GgmlGraph {
323 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324 f.debug_struct("GgmlGraph").finish()
325 }
326}
327
328impl GgmlGraph {
329 pub fn as_ptr(&mut self) -> *mut llama_cpp_sys_4::ggml_cgraph {
331 self.0.as_ptr()
332 }
333
334 pub fn build_forward(&mut self, tensor: &GgmlTensor) {
336 unsafe { llama_cpp_sys_4::ggml_build_forward_expand(self.0.as_ptr(), tensor.0.as_ptr()) }
337 }
338
339 #[must_use]
341 pub fn node(&mut self, i: i32) -> GgmlTensor {
342 let t = unsafe { llama_cpp_sys_4::ggml_graph_node(self.0.as_ptr(), i) };
343 GgmlTensor(NonNull::new(t).expect("graph_node returned null"))
344 }
345}
346
/// Owning wrapper around a ggml backend handle (currently CPU only, via
/// [`GgmlBackend::cpu`]); freed with `ggml_backend_free` on drop.
#[derive(Debug)]
pub struct GgmlBackend {
    // Non-null handle from `ggml_backend_cpu_init` (asserted in `cpu()`).
    backend: llama_cpp_sys_4::ggml_backend_t,
}
354
355impl GgmlBackend {
356 #[must_use]
358 pub fn cpu() -> Self {
359 let backend = unsafe { llama_cpp_sys_4::ggml_backend_cpu_init() };
360 assert!(!backend.is_null(), "ggml_backend_cpu_init returned null");
361 Self { backend }
362 }
363
364 pub fn set_n_threads(&self, n_threads: i32) {
366 unsafe { llama_cpp_sys_4::ggml_backend_cpu_set_n_threads(self.backend, n_threads) }
367 }
368
369 pub fn alloc_ctx_tensors(
373 &self,
374 ctx: &GgmlContext,
375 ) -> *mut llama_cpp_sys_4::ggml_backend_buffer {
376 unsafe { llama_cpp_sys_4::ggml_backend_alloc_ctx_tensors(ctx.as_ptr(), self.backend) }
377 }
378
379 pub fn graph_compute(&self, graph: &mut GgmlGraph) {
381 unsafe { llama_cpp_sys_4::ggml_backend_graph_compute(self.backend, graph.as_ptr()) };
382 }
383
384 pub fn default_buffer_type(&self) -> llama_cpp_sys_4::ggml_backend_buffer_type_t {
386 unsafe { llama_cpp_sys_4::ggml_backend_get_default_buffer_type(self.backend) }
387 }
388
389 pub fn as_ptr(&self) -> llama_cpp_sys_4::ggml_backend_t {
391 self.backend
392 }
393}
394
impl Drop for GgmlBackend {
    fn drop(&mut self) {
        // SAFETY: `self.backend` came from `ggml_backend_cpu_init`, was
        // verified non-null, and is freed exactly once here.
        unsafe { llama_cpp_sys_4::ggml_backend_free(self.backend) }
    }
}
400
/// Owning wrapper around a ggml graph allocator (`ggml_gallocr_t`); freed
/// with `ggml_gallocr_free` on drop.
#[derive(Debug)]
pub struct GgmlAllocr {
    // Non-null handle from `ggml_gallocr_new` (asserted in `new()`).
    alloc: llama_cpp_sys_4::ggml_gallocr_t,
}
408
409impl GgmlAllocr {
410 #[must_use]
412 pub fn new(backend: &GgmlBackend) -> Self {
413 let alloc = unsafe { llama_cpp_sys_4::ggml_gallocr_new(backend.default_buffer_type()) };
414 assert!(!alloc.is_null(), "ggml_gallocr_new returned null");
415 Self { alloc }
416 }
417
418 pub fn alloc_graph(&self, graph: &mut GgmlGraph) -> bool {
420 unsafe { llama_cpp_sys_4::ggml_gallocr_alloc_graph(self.alloc, graph.as_ptr()) }
421 }
422}
423
impl Drop for GgmlAllocr {
    fn drop(&mut self) {
        // SAFETY: `self.alloc` came from `ggml_gallocr_new`, was verified
        // non-null, and is freed exactly once here.
        unsafe { llama_cpp_sys_4::ggml_gallocr_free(self.alloc) }
    }
}
429
430pub unsafe fn tensor_set(tensor: &GgmlTensor, data: &[u8]) {
438 llama_cpp_sys_4::ggml_backend_tensor_set(
439 tensor.0.as_ptr(),
440 data.as_ptr().cast(),
441 0,
442 data.len(),
443 );
444}
445
446pub unsafe fn tensor_get(tensor: &GgmlTensor, data: &mut [u8]) {
452 llama_cpp_sys_4::ggml_backend_tensor_get(
453 tensor.0.as_ptr(),
454 data.as_mut_ptr().cast(),
455 0,
456 data.len(),
457 );
458}
459
/// Frees a backend buffer returned by [`GgmlBackend::alloc_ctx_tensors`].
///
/// # Safety
/// `buffer` must be a valid backend-buffer pointer that has not already been
/// freed, and no tensor allocated in it may be accessed afterwards.
pub unsafe fn buffer_free(buffer: *mut llama_cpp_sys_4::ggml_backend_buffer) {
    llama_cpp_sys_4::ggml_backend_buffer_free(buffer);
}
468
/// Per-tensor bookkeeping overhead in bytes; useful for sizing a context's
/// `mem_size` (see [`GgmlContext::new`]).
#[must_use]
pub fn tensor_overhead() -> usize {
    // SAFETY: takes no arguments and reads no caller-provided state.
    unsafe { llama_cpp_sys_4::ggml_tensor_overhead() }
}
474
/// Bookkeeping overhead of a compute graph in bytes; useful for sizing a
/// context's `mem_size` when graphs are built in it.
#[must_use]
pub fn graph_overhead() -> usize {
    // SAFETY: takes no arguments and reads no caller-provided state.
    unsafe { llama_cpp_sys_4::ggml_graph_overhead() }
}
480
/// Returns `true` when `typ` is one of ggml's quantized element types.
#[must_use]
pub fn is_quantized(typ: ggml_type) -> bool {
    // SAFETY: pure query on a plain enum value; no pointers involved.
    unsafe { llama_cpp_sys_4::ggml_is_quantized(typ) }
}
486
487#[must_use]
489pub fn type_name(typ: ggml_type) -> &'static str {
490 unsafe {
491 let ptr = llama_cpp_sys_4::ggml_type_name(typ);
492 if ptr.is_null() {
493 "unknown"
494 } else {
495 CStr::from_ptr(ptr).to_str().unwrap_or("unknown")
496 }
497 }
498}