use std::{
    ffi::c_void,
    ops,
    ptr::NonNull,
    sync::{
        atomic::{self, AtomicBool},
        Arc, Mutex, MutexGuard,
    },
};

use anyhow::{anyhow, bail, ensure, Result};
use thiserror::Error;

use ggml_sys_bleedingedge as gg;

use crate::{dims::*, gtensor::GTensor, util::GType, validation::*};

#[derive(Debug, Error, Clone)]
pub enum GContextError {
    #[error("Attempt to set invalid scratch buffer id {0}")]
    InvalidScratchBufferId(usize),

    #[error("Failed to lock context mutex")]
    MutexFailure,

    #[error("Context is deceased: {0}")]
    DeadContext(Arc<anyhow::Error>),

    #[error("Unknown error (likely mutex acquisition failure)")]
    Unknown,

    #[error("Not enough memory for request {0:?}")]
    InsufficientMemory(GMemoryRequest),

    #[error("Could not create tensor")]
    TensorCreationFailed,

    #[error("Attempt to access data in or compute with a no_alloc context")]
    NoAlloc,

    #[error("General error: {0}")]
    General(Arc<anyhow::Error>),
}

pub(crate) struct IContext {
    /// Pointer to the underlying ggml context.
    pub(crate) gctx: NonNull<gg::ggml_context>,

    /// Total size in bytes of the context's memory arena.
    pub(crate) context_memory: usize,
    /// Bytes of context memory currently in use.
    pub(crate) context_used: usize,

    /// Scratch buffers registered with this context.
    pub(crate) scratch_buffers: Vec<ScratchBuffer>,

    /// Index of the currently active scratch buffer, if any.
    pub(crate) current_scratch_buffer: Option<usize>,

    /// Set when the context enters an unrecoverable failed state.
    pub(crate) failed: Option<Arc<anyhow::Error>>,
}

// Safety: the context is only ever accessed through the `Arc<Mutex<IContext>>`
// held by `GContext`, so the raw pointer is never used concurrently.
unsafe impl Send for IContext {}

impl Drop for IContext {
    fn drop(&mut self) {
        unsafe { gg::ggml_free(self.gctx.as_ptr()) }
    }
}

impl ops::Deref for IContext {
    type Target = NonNull<gg::ggml_context>;

    fn deref(&self) -> &Self::Target {
        &self.gctx
    }
}

impl ops::DerefMut for IContext {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.gctx
    }
}

impl IContext {
    pub(crate) unsafe fn gptr(&self) -> *mut gg::ggml_context {
        self.gctx.as_ptr()
    }

    /// Account for a memory request, failing if it does not fit in the
    /// context memory or the active scratch buffer.
    pub(crate) fn update_used_memory(&mut self, mr: &GMemoryRequest) -> Result<()> {
        let mut mr = *mr;
        ensure!(
            mr.required_scratch == 0 || self.current_scratch_buffer == mr.current_scratch_buffer,
            "Scratch buffer mismatch in IContext::update_used_memory. Current: {:?}, expected: {:?}",
            self.current_scratch_buffer,
            mr.current_scratch_buffer,
        );
        match mr.reqtype {
            GMemoryRequestType::Tensor { .. } => (),
            wut => bail!(
                "Request type {wut:?} currently not implemented in IContext::update_used_memory"
            ),
        }
        let new_ctx_used = self.context_used + mr.required_ctx;
        if new_ctx_used > self.context_memory {
            mr.fits = false;
            bail!(GContextError::InsufficientMemory(mr));
        }
        if let Some(bufid) = &mr.current_scratch_buffer {
            let buf = &mut self.scratch_buffers[*bufid];
            let new_scratch_used = buf.used + mr.required_scratch;
            if new_scratch_used > buf.buf.len() {
                println!(
                    "MEM(scratch): {new_scratch_used} > {} -- {mr:?}",
                    buf.buf.len()
                );
                mr.fits = false;
                bail!(GContextError::InsufficientMemory(mr));
            }
            buf.used = new_scratch_used;
        }
        self.context_used = new_ctx_used;
        Ok(())
    }
}

#[derive(Clone)]
pub struct GContext {
    /// Raw value of the context pointer, kept so the context can be
    /// identified without locking the mutex.
    pub(crate) ptrval: usize,

    #[allow(dead_code)]
    pub(crate) context_size: usize,

    #[allow(dead_code)]
    pub(crate) no_alloc: bool,

    /// Set once the context has failed and can no longer be used.
    pub(crate) dead: Arc<AtomicBool>,

    /// The internal context state, shared behind a mutex.
    pub(crate) ictx: Arc<Mutex<IContext>>,
}

pub struct ScratchBuffer {
    pub(crate) buf: Box<[u8]>,
    pub(crate) used: usize,
}

impl ScratchBuffer {
    pub fn new(size: usize) -> Self {
        // The buffer is deliberately left uninitialized: ggml writes tensor
        // data into it before anything reads from it.
        let mut data: Vec<u8> = Vec::with_capacity(size);
        #[allow(clippy::uninit_vec)]
        unsafe {
            data.set_len(size);
        }
        Self {
            buf: data.into_boxed_slice(),
            used: 0,
        }
    }
}

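// A minimal usage sketch (not part of the original source), assuming the
// `GContextBuilder` defined below: allocate a 1MiB scratch buffer, register
// it, and make it active so subsequent tensor data lands in it rather than
// in context memory.
#[allow(dead_code)]
fn _example_scratch_usage() -> Result<()> {
    let mut ctx = GContextBuilder::new().mem_size(1024 * 1024).build()?;
    let bufid = ctx.register_scratch_buffer(ScratchBuffer::new(1024 * 1024))?;
    ctx.set_scratch_buffer(Some(bufid))?;
    // ... create tensors here ...
    ctx.set_scratch_buffer(None)?; // switch back to plain context memory
    Ok(())
}
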
#[derive(Default)]
pub struct GContextBuilder {
    mem_size: usize,
    no_alloc: bool,
}

impl GContextBuilder {
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the size in bytes of the context's memory arena.
    pub fn mem_size(mut self, mem_size: usize) -> Self {
        self.mem_size = mem_size;
        self
    }

    /// When `no_alloc` is set, ggml will not allocate tensor data inside the
    /// context's memory arena; data access and computation are disabled.
    pub fn no_alloc(mut self, no_alloc: bool) -> Self {
        self.no_alloc = no_alloc;
        self
    }

    /// Build a [`GContext`] from this configuration.
    pub fn build(self) -> Result<GContext> {
        let ptr = unsafe {
            gg::ggml_init(gg::ggml_init_params {
                mem_size: self.mem_size,
                mem_buffer: std::ptr::null_mut(),
                no_alloc: self.no_alloc,
            })
        };
        ensure!(!ptr.is_null(), "GGML init failed");
        Ok(GContext {
            context_size: self.mem_size,
            no_alloc: self.no_alloc,
            ptrval: ptr as usize,
            ictx: Arc::new(Mutex::new(IContext {
                gctx: NonNull::new(ptr).unwrap(),
                context_used: 0,
                context_memory: self.mem_size,
                scratch_buffers: vec![],
                current_scratch_buffer: None,
                failed: None,
            })),
            dead: Arc::new(AtomicBool::new(false)),
        })
    }
}

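// A minimal usage sketch (not part of the original source): build a context,
// size-check and create an f32 tensor, then query arena usage. `GType::F32`
// is assumed to be a variant of this crate's `GType` enum.
#[allow(dead_code)]
fn _example_builder_usage() -> Result<()> {
    let ctx = GContextBuilder::new()
        .mem_size(16 * 1024 * 1024) // 16MiB arena managed by ggml
        .build()?;
    // Ask whether a 32x32 tensor would fit before actually creating it.
    let mr = ctx.estimate_tensor_size(GType::F32, [32, 32])?;
    mr.fit_or_die()?;
    let _t = ctx.tensor(GType::F32, [32, 32])?;
    assert!(ctx.used_mem()? > 0);
    Ok(())
}
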
impl GContext {
    /// Run a closure with the locked internal context, propagating any
    /// previously recorded failure.
    pub(crate) fn with_icontext<OUT, F>(&self, fun: F) -> Result<OUT>
    where
        F: FnOnce(&GContext, MutexGuard<IContext>) -> Result<OUT>,
    {
        let failed = self.dead.load(atomic::Ordering::SeqCst);
        let ictx = self
            .ictx
            .lock()
            .map_err(|_e| anyhow!(GContextError::MutexFailure))?;
        if let Some(e) = ictx.failed.clone() {
            bail!(GContextError::DeadContext(e));
        }
        if failed {
            bail!(GContextError::Unknown)
        } else {
            fun(self, ictx)
        }
    }

    /// Like [`Self::with_icontext`], but for closures that cannot fail.
    pub(crate) fn with_icontext_infallible<OUT, F>(&self, fun: F) -> Result<OUT>
    where
        F: FnOnce(MutexGuard<IContext>) -> OUT,
    {
        let failed = self.dead.load(atomic::Ordering::SeqCst);
        let mut ictx = self.ictx.lock().map_err(|_e| {
            self.dead.store(true, atomic::Ordering::SeqCst);
            GContextError::MutexFailure
        })?;
        if let Some(e) = ictx.failed.clone() {
            bail!(GContextError::DeadContext(e));
        }
        if failed {
            let e = GContextError::Unknown;
            ictx.failed = Some(Arc::new(anyhow::Error::new(e.clone())));
            Err(e)?;
        }
        Ok(fun(ictx))
    }

    /// Run a fallible closure with the locked internal context; on failure,
    /// mark the context dead and return the fallback produced by `dfun`.
    pub(crate) fn delay_failure_with_icontext<OUT, DF, F>(&self, dfun: DF, fun: F) -> OUT
    where
        DF: Fn() -> OUT,
        F: FnOnce(&mut IContext) -> Result<OUT>,
    {
        self.with_icontext_infallible(|mut ictx| {
            fun(&mut ictx).unwrap_or_else(|e| {
                self.dead.store(true, atomic::Ordering::SeqCst);
                ictx.failed = Some(Arc::new(e));
                dfun()
            })
        })
        .unwrap_or_else(|_e| {
            self.dead.store(true, atomic::Ordering::SeqCst);
            dfun()
        })
    }

    /// Estimate the memory required for a tensor of the given type and shape.
    pub fn estimate_tensor_size<const DIMS: usize>(
        &self,
        typ: GType,
        shape: [usize; DIMS],
    ) -> Result<GMemoryRequest> {
        self.with_icontext(|ctx, ictx| {
            Ok(GMemoryRequest::estimate_tensor_request_ictx(
                ctx, &ictx, typ, shape,
            ))
        })
    }

    /// Create a new tensor of the given type and shape (up to three
    /// dimensions).
    pub fn tensor<const DIMS: usize>(
        &self,
        typ: GType,
        shape: [usize; DIMS],
    ) -> Result<GTensor<DIMS>>
    where
        Dim<DIMS>: DimValid,
        DimPair<DIMS, 4>: DimLt,
    {
        self.with_icontext(|ctx, mut ictx| {
            let mr = GMemoryRequest::estimate_tensor_request_ictx(self, &ictx, typ, shape);
            mr.fit_or_die()?;

            unsafe {
                // Note: ggml's `ne` argument order differs from this crate's
                // `shape` order, hence the index shuffling below.
                let p = match DIMS {
                    1 => gg::ggml_new_tensor_1d(ictx.gptr(), typ as u32, shape[0] as i64),
                    2 => gg::ggml_new_tensor_2d(
                        ictx.gptr(),
                        typ as u32,
                        shape[1] as i64,
                        shape[0] as i64,
                    ),
                    3 => gg::ggml_new_tensor_3d(
                        ictx.gptr(),
                        typ as u32,
                        shape[1] as i64,
                        shape[0] as i64,
                        shape[2] as i64,
                    ),
                    _ => unreachable!(),
                };

                if p.is_null() {
                    Err(GContextError::TensorCreationFailed)?;
                }
                GTensor::new_from_ptr(ctx, &mut ictx, (mr, p))
            }
        })
    }

    /// Register a scratch buffer with this context and return its id.
    pub fn register_scratch_buffer(&mut self, buf: ScratchBuffer) -> Result<usize> {
        self.with_icontext_infallible(|mut ictx| {
            let bufid = ictx.scratch_buffers.len();
            ictx.scratch_buffers.push(buf);
            bufid
        })
    }

    /// Set the active scratch buffer by id, or disable scratch usage by
    /// passing `None`.
    pub fn set_scratch_buffer(&self, maybebufid: Option<usize>) -> Result<()> {
        self.with_icontext(|_ctx, mut ictx| {
            let (size, data) = if let Some(bufid) = maybebufid {
                if bufid >= ictx.scratch_buffers.len() {
                    Err(GContextError::InvalidScratchBufferId(bufid))?;
                }
                let buf = &mut ictx.scratch_buffers[bufid].buf;
                (buf.len(), buf.as_mut_ptr() as *mut c_void)
            } else {
                (0, std::ptr::null_mut())
            };
            // Track the active buffer (cleared when `None` is passed).
            ictx.current_scratch_buffer = maybebufid;
            unsafe {
                gg::ggml_set_scratch(
                    ictx.gptr(),
                    gg::ggml_scratch {
                        offs: 0,
                        size,
                        data,
                    },
                );
            }
            Ok(())
        })
    }

    /// Run the computations defined in the supplied graph. Fails if the
    /// context was created with `no_alloc`.
    pub fn compute(&self, graph: &mut GGraph) -> Result<()> {
        ensure!(!self.no_alloc, GContextError::NoAlloc);
        let n_threads = graph.n_threads;
        self.with_icontext_infallible(|ictx| unsafe {
            gg::ggml_graph_compute_with_ctx(ictx.gptr(), &mut graph.graph, n_threads as i32)
        })
    }

    /// Return the number of bytes of context memory ggml reports as used.
    pub fn used_mem(&self) -> Result<usize> {
        self.with_icontext_infallible(|ictx| unsafe { gg::ggml_used_mem(ictx.gptr()) })
    }
}

pub struct GGraph {
    n_threads: usize,
    graph: gg::ggml_cgraph,
}

impl GGraph {
    /// Create a new computation graph that will run with the given number of
    /// threads.
    pub fn new(n_threads: usize) -> Self {
        // `ggml_cgraph` is a plain C struct; an all-zero value (no nodes) is
        // its valid initial state.
        let graph = unsafe { std::mem::zeroed::<gg::ggml_cgraph>() };
        Self { n_threads, graph }
    }

    /// Register a tensor as an output of the graph, expanding the graph with
    /// the operations needed to produce it.
    pub fn build_forward_expand<const DIMS: usize, T: AsRef<GTensor<DIMS>>>(
        &mut self,
        tensor: T,
    ) -> Result<()>
    where
        Dim<DIMS>: DimValid,
    {
        tensor
            .as_ref()
            .with_tensor_infallible(|_ctx, _ictx, tptr| unsafe {
                gg::ggml_build_forward_expand(&mut self.graph, tptr)
            })
    }
}
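
// An end-to-end sketch (not part of the original source): create a small
// tensor, register it as a graph output, and run the graph. Assumes
// `GType::F32` exists and that `GTensor` implements `AsRef<Self>`, which
// `build_forward_expand` above already relies on.
#[allow(dead_code)]
fn _example_graph_compute() -> Result<()> {
    let ctx = GContextBuilder::new().mem_size(1024 * 1024).build()?;
    let t = ctx.tensor(GType::F32, [8])?;
    let mut graph = GGraph::new(2); // run the graph computation on 2 threads
    graph.build_forward_expand(&t)?;
    ctx.compute(&mut graph)?;
    Ok(())
}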