1use std::ffi::c_void;
2use std::sync::Mutex;
3
4use anyhow::{anyhow, bail};
5use downcast_rs::{Downcast, impl_downcast};
6use tract_core::dyn_clone;
7use tract_core::prelude::{DatumType, TractResult};
8use tract_core::value::TValue;
9
10use crate::tensor::OwnedDeviceTensor;
11
/// Backend-agnostic handle to a compute device (e.g. a GPU runtime).
///
/// Implementations are stored as `Box<dyn DeviceContext>` in the global
/// [`DEVICE_CONTEXT`] slot. `Send` is required for that `Mutex`-guarded static
/// to be valid. `DynClone` (plus the `clone_trait_object!` below) makes the
/// boxed trait object cloneable, which [`get_context`] relies on. `Downcast`
/// (plus `impl_downcast!`) lets callers recover the concrete backend type.
pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
    /// Presumably moves/copies the host `tensor` onto the device and returns the
    /// owning device-side tensor — exact copy semantics are backend-defined
    /// (NOTE(review): confirm against implementations).
    fn tensor_to_device(&self, tensor: TValue) -> TractResult<Box<dyn OwnedDeviceTensor>>;
    /// Creates a device tensor with the given `shape` and datum type `dt`.
    /// As the name indicates, its contents are not initialized; callers are
    /// expected to write it before reading (TODO confirm per backend).
    fn uninitialized_device_tensor(
        &self,
        shape: &[usize],
        dt: DatumType,
    ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
    /// Presumably blocks until work previously submitted to the device has
    /// completed (NOTE(review): confirm the exact guarantee per backend).
    fn synchronize(&self) -> TractResult<()>;
}

// Enables downcasting `dyn DeviceContext` to its concrete implementation type.
impl_downcast!(DeviceContext);
// Implements `Clone` for `Box<dyn DeviceContext>` via `DynClone`.
dyn_clone::clone_trait_object!(DeviceContext);
24
/// A buffer of device memory owned or referenced by a backend.
///
/// Like [`DeviceContext`], trait objects of this type are cloneable (through
/// `DynClone` and `clone_trait_object!`) and downcastable (through `Downcast`
/// and `impl_downcast!`). They also require `Debug`.
pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync + std::fmt::Debug {
    /// Raw pointer to the underlying buffer. Presumably a backend-native handle
    /// or device address; its validity and lifetime are backend-defined
    /// (NOTE(review): confirm whether it may be dereferenced on the host).
    fn ptr(&self) -> *const c_void;
}

// Enables downcasting `dyn DeviceBuffer` to its concrete implementation type.
impl_downcast!(DeviceBuffer);
// Implements `Clone` for `Box<dyn DeviceBuffer>` via `DynClone`.
dyn_clone::clone_trait_object!(DeviceBuffer);
31
/// Process-wide slot holding the active device context.
///
/// It starts as `None`. [`set_context`] fills it at most once, and
/// [`get_context`] returns clones of the stored context. It is `pub`, so code
/// elsewhere could also lock and modify it directly. Prefer the two accessor
/// functions, which enforce the set-once rule.
pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
33
34pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
35 let mut context = DEVICE_CONTEXT.lock().unwrap();
36 if context.is_none() {
37 *context = Some(curr_context);
38 Ok(())
39 } else {
40 bail!("Context is already set")
41 }
42}
43
44pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
45 let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
46 guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
47}