rusty_ggml/context.rs

use std::{
    ffi::c_void,
    ops,
    ptr::NonNull,
    sync::{
        atomic::{self, AtomicBool},
        Arc, Mutex, MutexGuard,
    },
};

use anyhow::{anyhow, bail, ensure, Result};
use thiserror::Error;

use ggml_sys_bleedingedge as gg;

use crate::{dims::*, gtensor::GTensor, util::GType, validation::*};

#[derive(Debug, Error, Clone)]
pub enum GContextError {
    #[error("Attempt to set invalid scratch buffer id {0}")]
    InvalidScratchBufferId(usize),

    #[error("Failed to lock context mutex")]
    MutexFailure,

    #[error("Context is deceased: {0}")]
    DeadContext(Arc<anyhow::Error>),

    #[error("Unknown error (likely mutex acquisition failure)")]
    Unknown,

    // FIXME: Add the other fields to the message.
    #[error("Not enough memory for request {0:?}")]
    InsufficientMemory(GMemoryRequest),

    // FIXME: Allow including more detail about what went wrong.
    #[error("Could not create tensor")]
    TensorCreationFailed,

    #[error("Attempt to access data in or compute with a no_alloc context")]
    NoAlloc,

    #[error("General error: {0}")]
    General(Arc<anyhow::Error>),
}

pub(crate) struct IContext {
    // Pointer to the GGML context.
    pub(crate) gctx: NonNull<gg::ggml_context>,

    pub(crate) context_memory: usize,
    // Amount of context memory currently used.
    pub(crate) context_used: usize,

    // List of scratch buffers. Only dropped when the `IContext` is
    // finally freed.
    pub(crate) scratch_buffers: Vec<ScratchBuffer>,

    // The current scratch buffer if set.
    pub(crate) current_scratch_buffer: Option<usize>,

    // Populated if an error occurred during some previous
    // operation.
    pub(crate) failed: Option<Arc<anyhow::Error>>,
}

// FIXME: YOLO? It's an internal struct and only lives in an Arc.
unsafe impl Send for IContext {}

impl Drop for IContext {
    // Since `IContext` lives inside an `Arc` this will only happen
    // when the very last instance of the `Arc` is dropped.
    fn drop(&mut self) {
        unsafe { gg::ggml_free(self.gctx.as_ptr()) }
    }
}

impl ops::Deref for IContext {
    type Target = NonNull<gg::ggml_context>;

    fn deref(&self) -> &Self::Target {
        &self.gctx
    }
}

impl ops::DerefMut for IContext {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.gctx
    }
}

impl IContext {
    pub(crate) unsafe fn gptr(&self) -> *mut gg::ggml_context {
        self.gctx.as_ptr()
    }

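    // Account a memory request against context memory and the current scratch
    // buffer (if any), failing with `InsufficientMemory` if either budget
    // would be exceeded.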
    pub(crate) fn update_used_memory(&mut self, mr: &GMemoryRequest) -> Result<()> {
        let mut mr = *mr;
        ensure!(
            mr.required_scratch == 0 || self.current_scratch_buffer == mr.current_scratch_buffer,
            "Scratch buffer mismatch in IContext::update_used_memory. Current: {:?}, expected: {:?}",
            self.current_scratch_buffer,
            mr.current_scratch_buffer,
        );
        match mr.reqtype {
            GMemoryRequestType::Tensor { .. } => (),
            wut => bail!(
                "Request type {wut:?} currently not implemented in IContext::update_used_memory"
            ),
        }
        let new_ctx_used = self.context_used + mr.required_ctx;
        if new_ctx_used > self.context_memory {
            mr.fits = false;
            bail!(GContextError::InsufficientMemory(mr));
        }
        if let Some(bufid) = &mr.current_scratch_buffer {
            let buf = &mut self.scratch_buffers[*bufid];
            let new_scratch_used = buf.used + mr.required_scratch;
            if new_scratch_used > buf.buf.len() {
                println!(
                    "MEM(scratch): {new_scratch_used} > {} -- {mr:?}",
                    buf.buf.len()
                );
                mr.fits = false;
                bail!(GContextError::InsufficientMemory(mr));
            }
            buf.used = new_scratch_used;
        }
        self.context_used = new_ctx_used;
        Ok(())
    }
}

#[derive(Clone)]
pub struct GContext {
    // This is just used to validate that operations for objects containing
    // a context (i.e. tensors) use the same context. It is never actually
    // used as a pointer or updated after the context is created.
    pub(crate) ptrval: usize,

    // Amount of context memory allocated (in bytes)
    #[allow(dead_code)]
    pub(crate) context_size: usize,

    #[allow(dead_code)]
    pub(crate) no_alloc: bool,

    // This atomic is used to mark the context as dead. Ideally we could
    // mark it in the `ictx` field, but one failure condition is failing to
    // acquire the mutex: in that case all we can do is mark the context as
    // dead using this field.
    pub(crate) dead: Arc<AtomicBool>,

    // The real context structure which contains a pointer to the actual
    // GGML context.
    pub(crate) ictx: Arc<Mutex<IContext>>,
}

/// GGML scratch buffer structure used for temporary data storage.
pub struct ScratchBuffer {
    pub(crate) buf: Box<[u8]>,
    pub(crate) used: usize,
}

impl ScratchBuffer {
    /// Create a new scratch buffer with the specified size (in bytes).
    pub fn new(size: usize) -> Self {
        let mut data: Vec<u8> = Vec::with_capacity(size);
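        // Deliberately skipped zero-initialization: scratch memory is handed
        // to GGML, which is expected to write to it before any reads.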
        #[allow(clippy::uninit_vec)]
        unsafe {
            data.set_len(size);
        }
        Self {
            buf: data.into_boxed_slice(),
            used: 0,
        }
    }
}

#[derive(Default)]
/// GGML context builder structure used to build a
/// [GContext].
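///
/// A minimal usage sketch (the import path and memory size here are
/// illustrative assumptions):
///
/// ```ignore
/// use rusty_ggml::context::GContextBuilder;
///
/// // 16MiB of context memory for GGML objects and tensor data.
/// let ctx = GContextBuilder::new().mem_size(16 * 1024 * 1024).build()?;
/// ```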
pub struct GContextBuilder {
    mem_size: usize,
    no_alloc: bool,
}

// FIXME: We probably should use the typestate pattern in here to make sure
// people don't do something silly like build an alloc context with 0 memory.
impl GContextBuilder {
    /// Create a new [GContextBuilder].
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the GGML context size.
    pub fn mem_size(mut self, mem_size: usize) -> Self {
        self.mem_size = mem_size;
        self
    }

    /// When a context is in no_alloc mode, data for tensors is not
    /// allocated at all, so all allocation requests will succeed
    /// (although the context still needs enough memory for the GGML
    /// object headers themselves).
    ///
    /// Naturally, you can't actually execute the graph when this is
    /// turned on, so the main use is to probe how much memory building
    /// a graph will require, then create another context with the
    /// expected memory size and populate that one with the actual
    /// tensor data.
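    ///
    /// A minimal sketch of that probe-then-allocate workflow (tensor and
    /// graph setup are elided; the sizes here are illustrative assumptions):
    ///
    /// ```ignore
    /// // Pass 1: a no_alloc context just to measure memory requirements.
    /// let probe = GContextBuilder::new()
    ///     .mem_size(1024 * 1024)
    ///     .no_alloc(true)
    ///     .build()?;
    /// // ... create tensors and build the graph against `probe` ...
    /// let needed = probe.used_mem()?;
    /// // Pass 2: a real context sized from the probe, plus some headroom.
    /// let ctx = GContextBuilder::new()
    ///     .mem_size(needed + 1024 * 1024)
    ///     .build()?;
    /// ```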
    pub fn no_alloc(mut self, no_alloc: bool) -> Self {
        self.no_alloc = no_alloc;
        self
    }

    /// Build a GGML context ([GContext]) based on the
    /// builder's configuration.
    pub fn build(self) -> Result<GContext> {
        let ptr = unsafe {
            gg::ggml_init(gg::ggml_init_params {
                mem_size: self.mem_size,
                mem_buffer: std::ptr::null_mut(),
                no_alloc: self.no_alloc,
            })
        };
        ensure!(!ptr.is_null(), "GGML init failed");
        Ok(GContext {
            context_size: self.mem_size,
            no_alloc: self.no_alloc,
            ptrval: ptr as usize,
            ictx: Arc::new(Mutex::new(IContext {
                gctx: NonNull::new(ptr).unwrap(),
                context_used: 0,
                context_memory: self.mem_size,
                scratch_buffers: vec![],
                current_scratch_buffer: None,
                failed: None,
            })),
            dead: Arc::new(AtomicBool::new(false)),
        })
    }
}

impl GContext {
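    // Run `fun` with the locked internal context. Bails out if the context
    // was previously marked dead, either via the `failed` field or via the
    // `dead` atomic (which is all that can be set when the mutex itself
    // couldn't be acquired).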
    pub(crate) fn with_icontext<OUT, F>(&self, fun: F) -> Result<OUT>
    where
        F: FnOnce(&GContext, MutexGuard<IContext>) -> Result<OUT>,
    {
        let failed = self.dead.load(atomic::Ordering::SeqCst);
        let ictx = self
            .ictx
            .lock()
            .map_err(|_e| anyhow!(GContextError::MutexFailure))?;
        if let Some(e) = ictx.failed.clone() {
            bail!(GContextError::DeadContext(e));
        }
        if failed {
            bail!(GContextError::Unknown)
        } else {
            fun(self, ictx)
        }
    }

    // FIXME: This logic seems kind of weird. Same problem in `Tensor::with_tensor_infallible`.
    pub(crate) fn with_icontext_infallible<OUT, F>(&self, fun: F) -> Result<OUT>
    where
        F: FnOnce(MutexGuard<IContext>) -> OUT,
    {
        let failed = self.dead.load(atomic::Ordering::SeqCst);
        let mut ctx = self.ictx.lock().map_err(|_e| {
            self.dead.store(true, atomic::Ordering::SeqCst);
            GContextError::MutexFailure
        })?;
        if let Some(e) = ctx.failed.clone() {
            bail!(GContextError::DeadContext(e));
        }
        // This might look weird but the idea is that we might have failed previously
        // due to being unable to acquire the mutex. Since we didn't have the mutex,
        // naturally it was impossible to set the `failed` field inside the `IContext`
        // structure.
        // There probably still is a race condition here but it should be very unlikely.
        if failed {
            let e = GContextError::Unknown;
            ctx.failed = Some(Arc::new(anyhow::Error::new(e.clone())));
            Err(e)?;
        }
        Ok(fun(ctx))
    }

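    // Run `fun` with the internal context. If anything fails, store the error
    // in the context, mark it dead, and return the fallback value from `dfun`
    // instead of propagating an error to the caller.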
    pub(crate) fn delay_failure_with_icontext<OUT, DF, F>(&self, dfun: DF, fun: F) -> OUT
    where
        DF: Fn() -> OUT,
        F: FnOnce(&mut IContext) -> Result<OUT>,
    {
        self.with_icontext_infallible(|mut ictx| {
            fun(&mut ictx).unwrap_or_else(|e| {
                // We have the context mutex but the handler function returned
                // an error condition. So store the error in the context and mark it as dead.
                self.dead.store(true, atomic::Ordering::SeqCst);
                ictx.failed = Some(Arc::new(e));
                dfun()
            })
        })
        .unwrap_or_else(|_e| {
            // We couldn't get the context mutex, but we can still mark the context as dead.
            self.dead.store(true, atomic::Ordering::SeqCst);
            dfun()
        })
    }

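    /// Estimate the [memory request](GMemoryRequest) that creating a tensor
    /// of the specified [type](GType) and shape would incur, without actually
    /// creating it.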
    pub fn estimate_tensor_size<const DIMS: usize>(
        &self,
        typ: GType,
        shape: [usize; DIMS],
    ) -> Result<GMemoryRequest> {
        self.with_icontext(|ctx, ictx| {
            Ok(GMemoryRequest::estimate_tensor_request_ictx(
                ctx, &ictx, typ, shape,
            ))
        })
    }

    /// Create a new tensor with the specified [type](GType) and shape.
    ///
    /// This uses const generics to determine the new tensor's dimensions. The tensor dimensions
    /// will be equal to the number of items in the `shape` array.
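    ///
    /// A minimal sketch (context setup elided; the `GType::F32` variant is
    /// assumed to exist):
    ///
    /// ```ignore
    /// // A 2x3 tensor of 32-bit floats; DIMS is inferred as 2 from the array.
    /// let t = ctx.tensor(GType::F32, [2, 3])?;
    /// ```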
    pub fn tensor<const DIMS: usize>(
        &self,
        typ: GType,
        shape: [usize; DIMS],
    ) -> Result<GTensor<DIMS>>
    where
        Dim<DIMS>: DimValid,
        DimPair<DIMS, 4>: DimLt,
    {
        self.with_icontext(|ctx, mut ictx| {
            let mr = GMemoryRequest::estimate_tensor_request_ictx(self, &ictx, typ, shape);
            mr.fit_or_die()?;

            unsafe {
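                // Note the argument order: GGML's ne[0] is the innermost
                // (fastest-varying) dimension, so shape indices are rearranged
                // here instead of being passed through in order.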
                let p = match DIMS {
                    1 => gg::ggml_new_tensor_1d(ictx.gptr(), typ as u32, shape[0] as i64),
                    2 => gg::ggml_new_tensor_2d(
                        ictx.gptr(),
                        typ as u32,
                        shape[1] as i64,
                        shape[0] as i64,
                    ),
                    3 => gg::ggml_new_tensor_3d(
                        ictx.gptr(),
                        typ as u32,
                        shape[1] as i64,
                        shape[0] as i64,
                        shape[2] as i64,
                    ),
                    _ => unreachable!(),
                };

                if p.is_null() {
                    Err(GContextError::TensorCreationFailed)?;
                }
                GTensor::new_from_ptr(ctx, &mut ictx, (mr, p))
            }
        })
    }

    /// Register a scratch buffer. The return value is the scratch buffer id
    /// which can be used with [Self::set_scratch_buffer].
    pub fn register_scratch_buffer(&mut self, buf: ScratchBuffer) -> Result<usize> {
        self.with_icontext_infallible(|mut ictx| {
            let bufid = ictx.scratch_buffers.len();
            ictx.scratch_buffers.push(buf);
            bufid
        })
    }

    /// Set or clear the current scratch buffer. To set, pass `Some(id)` with
    /// a valid id as returned by [Self::register_scratch_buffer]; pass `None`
    /// to clear.
    ///
    /// **Note**: Scratch buffers cannot be removed directly and are only freed
    /// when the [GContext] structure is dropped.
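    ///
    /// A minimal sketch (assumes a mutable `ctx` and an arbitrary 1MiB
    /// buffer size):
    ///
    /// ```ignore
    /// let bufid = ctx.register_scratch_buffer(ScratchBuffer::new(1024 * 1024))?;
    /// ctx.set_scratch_buffer(Some(bufid))?; // temporary results now land in the buffer
    /// ctx.set_scratch_buffer(None)?; // back to plain context memory
    /// ```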
    pub fn set_scratch_buffer(&self, maybebufid: Option<usize>) -> Result<()> {
        self.with_icontext(|_ctx, mut ictx| {
            let (size, data) = if let Some(bufid) = maybebufid {
                if bufid >= ictx.scratch_buffers.len() {
                    Err(GContextError::InvalidScratchBufferId(bufid))?;
                }
                let buf = &mut ictx.scratch_buffers[bufid].buf;
                (buf.len(), buf.as_mut_ptr() as *mut c_void)
            } else {
                (0, std::ptr::null_mut())
            };
            // Track the active buffer (or lack of one) so the accounting in
            // `IContext::update_used_memory` stays in sync with GGML's state.
            ictx.current_scratch_buffer = maybebufid;
            unsafe {
                gg::ggml_set_scratch(
                    ictx.gptr(),
                    gg::ggml_scratch {
                        offs: 0,
                        size,
                        data,
                    },
                );
            }
            Ok(())
        })
    }

    /// Runs the supplied graph using this context.
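    ///
    /// A minimal sketch of the build-then-compute flow (tensor setup elided;
    /// the thread count is an arbitrary example):
    ///
    /// ```ignore
    /// let mut graph = GGraph::new(4);
    /// graph.build_forward_expand(&some_tensor)?;
    /// ctx.compute(&mut graph)?;
    /// ```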
    pub fn compute(&self, graph: &mut GGraph) -> Result<()> {
        ensure!(!self.no_alloc, GContextError::NoAlloc);
        let n_threads = graph.n_threads;
        self.with_icontext_infallible(|ictx| unsafe {
            gg::ggml_graph_compute_with_ctx(ictx.gptr(), &mut graph.graph, n_threads as i32)
        })
    }

    /// Returns the amount of memory GGML is currently using.
    pub fn used_mem(&self) -> Result<usize> {
        self.with_icontext_infallible(|ictx| unsafe { gg::ggml_used_mem(ictx.gptr()) })
    }
}

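/// A GGML computation graph along with the number of threads
/// that will be used to compute it.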
pub struct GGraph {
    n_threads: usize,
    graph: gg::ggml_cgraph,
}

impl GGraph {
    /// Create a new computation graph with the specified number of threads.
    pub fn new(n_threads: usize) -> Self {
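        // An all-zeroes `ggml_cgraph` is treated as an empty graph (zero
        // nodes), so zeroed initialization stands in for a C-style `= {0}`.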
        let graph = unsafe { std::mem::zeroed::<gg::ggml_cgraph>() };
        Self { n_threads, graph }
    }

    /// Register a tensor to be processed when the graph is computed.
    pub fn build_forward_expand<const DIMS: usize, T: AsRef<GTensor<DIMS>>>(
        &mut self,
        tensor: T,
    ) -> Result<()>
    where
        Dim<DIMS>: DimValid,
    {
        // FIXME: Should we bail out here if no_alloc?
        tensor
            .as_ref()
            .with_tensor_infallible(|_ctx, _ictx, tptr| unsafe {
                gg::ggml_build_forward_expand(&mut self.graph, tptr)
            })
    }
}