1use std::ffi::c_void;
2use std::ops::Range;
3use std::sync::Mutex;
4
5use anyhow::{anyhow, bail};
6use downcast_rs::{Downcast, impl_downcast};
7use tract_core::dyn_clone;
8use tract_core::internal::*;
9use tract_core::value::TValue;
10
11use crate::tensor::{DeviceTensor, OwnedDeviceTensor};
12
13pub trait DeviceContext: Downcast + dyn_clone::DynClone + Send + Sync {
14 fn tensor_to_device(&self, tensor: TValue) -> TractResult<Box<dyn OwnedDeviceTensor>>;
15 fn uninitialized_device_tensor(
16 &self,
17 shape: &[usize],
18 dt: DatumType,
19 ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
20 fn uninitialized_device_exotic_tensor(
21 &self,
22 exotic_fact: Box<dyn ExoticFact>,
23 ) -> TractResult<Box<dyn OwnedDeviceTensor>>;
24 fn synchronize(&self) -> TractResult<()>;
25 fn copy_nd(
26 &self,
27 input: &DeviceTensor,
28 input_offset: usize,
29 input_strides: &[isize],
30 output: &DeviceTensor,
31 output_offset: usize,
32 output_shape: &[usize],
33 output_strides: &[isize],
34 ) -> TractResult<()>;
35
36 fn assign_slice(
38 &self,
39 dst: &DeviceTensor,
40 dst_range: Range<usize>,
41 src: &DeviceTensor,
42 src_range: Range<usize>,
43 axis: usize,
44 ) -> TractResult<()> {
45 let mut zone_shape: TVec<usize> = src.shape().into();
46 zone_shape[axis] = src_range.len();
47 if zone_shape.iter().product::<usize>() == 0 {
48 return Ok(());
49 }
50 let src_offset =
51 src_range.start * src.strides()[axis] as usize * src.datum_type().size_of();
52 let dst_offset =
53 dst_range.start * dst.strides()[axis] as usize * dst.datum_type().size_of();
54 self.copy_nd(src, src_offset, src.strides(), dst, dst_offset, &zone_shape, dst.strides())
55 }
56
57 fn copy_with_origins(
59 &self,
60 zone_shape: &[usize],
61 dst: &DeviceTensor,
62 dst_origin: &[usize],
63 dst_strides: &[isize],
64 src: &DeviceTensor,
65 src_origin: &[usize],
66 src_strides: &[isize],
67 ) -> TractResult<()> {
68 if zone_shape.iter().product::<usize>() == 0 {
69 return Ok(());
70 }
71 let dt_size = src.datum_type().size_of();
72 let src_offset: usize =
73 src_origin.iter().zip(src_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
74 * dt_size;
75 let dst_offset: usize =
76 dst_origin.iter().zip(dst_strides).map(|(o, s)| o * *s as usize).sum::<usize>()
77 * dt_size;
78 self.copy_nd(src, src_offset, src_strides, dst, dst_offset, zone_shape, dst_strides)
79 }
80
81 fn flat_copy(
83 &self,
84 src: &DeviceTensor,
85 src_byte_offset: usize,
86 dst: &DeviceTensor,
87 dst_byte_offset: usize,
88 byte_len: usize,
89 ) -> TractResult<()> {
90 if byte_len == 0 {
91 return Ok(());
92 }
93 let elem_size = src.datum_type().size_of();
96 ensure!(
97 byte_len % elem_size == 0
98 && src_byte_offset % elem_size == 0
99 && dst_byte_offset % elem_size == 0,
100 "flat_copy: byte_len {byte_len}, src_offset {src_byte_offset}, \
101 dst_offset {dst_byte_offset} not aligned to element size {elem_size}"
102 );
103 self.copy_nd(
104 src,
105 src_byte_offset,
106 &[1],
107 dst,
108 dst_byte_offset,
109 &[byte_len / elem_size],
110 &[1],
111 )
112 }
113}
114
// Allow downcasting boxed `DeviceContext` trait objects to their concrete
// backend type, and cloning them through `dyn_clone`.
impl_downcast!(DeviceContext);
dyn_clone::clone_trait_object!(DeviceContext);
117
/// Opaque handle on a raw device memory allocation.
pub trait DeviceBuffer: Downcast + dyn_clone::DynClone + Send + Sync + std::fmt::Debug {
    /// Raw pointer to the underlying buffer.
    /// NOTE(review): whether this address is host-visible or device-only
    /// depends on the backend — confirm before dereferencing on the host.
    fn ptr(&self) -> *const c_void;
}
121
// Allow downcasting and cloning of boxed `DeviceBuffer` trait objects.
impl_downcast!(DeviceBuffer);
dyn_clone::clone_trait_object!(DeviceBuffer);
124
125pub static DEVICE_CONTEXT: Mutex<Option<Box<dyn DeviceContext>>> = Mutex::new(None);
126
127pub fn set_context(curr_context: Box<dyn DeviceContext>) -> TractResult<()> {
128 let mut context = DEVICE_CONTEXT.lock().unwrap();
129 if context.is_none() {
130 *context = Some(curr_context);
131 Ok(())
132 } else {
133 bail!("Context is already set")
134 }
135}
136
137pub fn get_context() -> TractResult<Box<dyn DeviceContext>> {
138 let guard = DEVICE_CONTEXT.lock().map_err(|_| anyhow!("Cannot read GPU Context"))?;
139 guard.as_ref().cloned().ok_or_else(|| anyhow!("GPU Context not initialized"))
140}