1use core::slice;
11use std::{
12 ffi::CStr,
13 fmt::Debug,
14 intrinsics::copy_nonoverlapping,
15 mem::{size_of_val, transmute},
16 ops::{Bound, RangeBounds},
17};
18
19#[cfg(feature = "gpu")]
20use cuda_driver_sys::{
21 cuMemAllocHost_v2, cuMemAlloc_v2, cuMemFreeHost, cuMemFree_v2, cuMemcpyDtoD_v2,
22 cuMemcpyDtoH_v2, cuMemcpyHtoD_v2, CUdeviceptr,
23};
24use libc::{c_void, calloc, free};
25
26use crate::{
27 error::{Error, ErrorCode, CSTR_CONVERT_ERROR_PLUG},
28 sys, to_cstring,
29};
30
/// Implements the sealed [`Sample`] trait for a concrete Rust type,
/// binding it to the matching Triton [`DataType`] variant.
macro_rules! impl_sample {
    ($type:ty, $data:expr) => {
        impl private::Sealed for $type {}

        impl Sample for $type {
            const DATA_TYPE: DataType = $data;
        }
    };
}
40
mod private {
    /// Sealing trait: keeps [`super::Sample`] implementable only inside this
    /// crate (via the `impl_sample!` macro).
    pub trait Sealed: Clone + Copy {}
}
44
/// Marker trait linking a Rust element type to its Triton [`DataType`].
///
/// Sealed — implemented only through `impl_sample!` within this crate.
pub trait Sample: private::Sealed {
    /// The Triton data type corresponding to `Self`.
    const DATA_TYPE: DataType;
}
49
/// Tensor element data types, mirroring Triton's `TRITONSERVER_DataType` enum.
///
/// Discriminants are the raw C enum values, which is what allows the checked
/// `transmute` from `u32` in the `TryFrom<&str>` impl below.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum DataType {
    Invalid = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INVALID,
    Bool = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_BOOL,
    Uint8 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT8,
    Uint16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT16,
    Uint32 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT32,
    Uint64 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT64,
    Int8 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT8,
    Int16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT16,
    Int32 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT32,
    Int64 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT64,
    Fp16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_FP16,
    Fp32 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_FP32,
    Fp64 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_FP64,
    Bytes = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_BYTES,
    Bf16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_BF16,
}
70
/// A single raw byte sample; the element type used for [`DataType::Bytes`].
#[derive(Clone, Copy)]
pub struct Byte(pub u8);
73
// Boolean, raw-byte and unsigned-integer samples.
impl_sample!(bool, DataType::Bool);
impl_sample!(u8, DataType::Uint8);
impl_sample!(Byte, DataType::Bytes);
impl_sample!(u16, DataType::Uint16);
impl_sample!(u32, DataType::Uint32);
impl_sample!(u64, DataType::Uint64);

// Signed-integer samples.
impl_sample!(i8, DataType::Int8);
impl_sample!(i16, DataType::Int16);
impl_sample!(i32, DataType::Int32);
impl_sample!(i64, DataType::Int64);

// Floating-point samples (half-precision types come from the `half` crate).
impl_sample!(half::f16, DataType::Fp16);
impl_sample!(half::bf16, DataType::Bf16);
impl_sample!(f32, DataType::Fp32);
impl_sample!(f64, DataType::Fp64);
90
91impl DataType {
92 pub fn as_str(self) -> &'static str {
94 let ptr = unsafe { sys::TRITONSERVER_DataTypeString(self as u32) };
95 unsafe { CStr::from_ptr(ptr) }
96 .to_str()
97 .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
98 }
99
100 pub fn size(self) -> u32 {
102 if self == Self::Bytes {
103 size_of::<Byte>() as u32
104 } else {
105 unsafe { sys::TRITONSERVER_DataTypeByteSize(self as u32) }
106 }
107 }
108}
109
110impl TryFrom<&str> for DataType {
111 type Error = Error;
112 fn try_from(name: &str) -> Result<Self, Self::Error> {
114 let name = to_cstring(name)?;
115 let data_type = unsafe { sys::TRITONSERVER_StringToDataType(name.as_ptr()) };
116 if data_type != sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INVALID {
117 Ok(unsafe { transmute::<u32, crate::memory::DataType>(data_type) })
118 } else {
119 Err(Error::new(ErrorCode::InvalidArg, ""))
120 }
121 }
122}
123
/// Where a [`Buffer`]'s allocation lives, mirroring Triton's
/// `TRITONSERVER_MemoryType` enum.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[repr(u32)]
pub enum MemoryType {
    /// Regular host heap memory (allocated with libc `calloc`).
    Cpu = sys::TRITONSERVER_memorytype_enum_TRITONSERVER_MEMORY_CPU,
    /// Page-locked host memory allocated through CUDA (`cuMemAllocHost_v2`).
    Pinned = sys::TRITONSERVER_memorytype_enum_TRITONSERVER_MEMORY_CPU_PINNED,
    /// CUDA device memory (`cuMemAlloc_v2`).
    Gpu = sys::TRITONSERVER_memorytype_enum_TRITONSERVER_MEMORY_GPU,
}
132
133impl MemoryType {
134 pub fn as_str(self) -> &'static str {
136 let ptr = unsafe { sys::TRITONSERVER_MemoryTypeString(self as u32) };
137 unsafe { CStr::from_ptr(ptr) }
138 .to_str()
139 .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
140 }
141}
142
/// A raw CUDA device allocation: device pointer plus length in bytes.
///
/// Carries no ownership information by itself; see the `From` conversions
/// with [`Buffer`] for which side frees the memory.
#[cfg(feature = "gpu")]
pub struct CudaArray {
    pub ptr: CUdeviceptr,
    pub len: usize,
}
151
/// A contiguous chunk of memory for tensor data, living on the CPU heap,
/// in CUDA pinned host memory, or on a CUDA device.
#[derive(Debug)]
pub struct Buffer {
    // Raw allocation. For `MemoryType::Gpu` this is a `CUdeviceptr` cast to a
    // host pointer and must not be dereferenced on the host.
    pub(crate) ptr: *mut c_void,
    // Length in BYTES, not elements.
    pub(crate) len: usize,
    pub(crate) data_type: DataType,
    pub(crate) memory_type: MemoryType,
    // When true, `Drop` frees `ptr` with the allocator matching `memory_type`.
    pub(crate) owned: bool,
}
170
// SAFETY: `Buffer` exposes mutation only through `&mut self`, so sharing
// `&Buffer` across threads cannot race. NOTE(review): soundness also assumes
// the raw `ptr` is not aliased by another owner — not provable from this file;
// confirm against how `owned = false` buffers are created.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}
173
174impl Buffer {
176 pub fn try_clone(&self) -> Result<Self, Error> {
180 self.check_mem_type_feature()?;
181
182 let sample_count = self.len / self.data_type.size() as usize;
183 let mut res = Buffer::alloc_with_data_type(sample_count, self.memory_type, self.data_type)?;
184
185 if self.memory_type == MemoryType::Gpu {
186 #[cfg(feature = "gpu")]
187 res.copy_from_cuda_array(0, unsafe { self.get_cuda_array() })?;
188 } else {
189 res.copy_from_slice(0, self.bytes())?;
190 }
191
192 Ok(res)
193 }
194
195 pub fn alloc<T: Sample>(count: usize, memory_type: MemoryType) -> Result<Self, Error> {
201 Self::alloc_with_data_type(count, memory_type, T::DATA_TYPE)
202 }
203
204 pub fn alloc_with_data_type(
211 count: usize,
212 memory_type: MemoryType,
213 data_type: DataType,
214 ) -> Result<Self, Error> {
215 let data_type_size = data_type.size() as usize;
216 let size = count * data_type_size;
217
218 let ptr = match memory_type {
219 MemoryType::Cpu => Ok::<_, Error>(unsafe { calloc(count as _, data_type_size) }),
220 MemoryType::Pinned => {
221 #[cfg(not(feature = "gpu"))]
222 return Err(Error::wrong_type(memory_type));
223 #[cfg(feature = "gpu")]
224 {
225 let mut data = std::ptr::null_mut::<c_void>();
226 cuda_call!(cuMemAllocHost_v2(&mut data, size))?;
227 Ok(data)
228 }
229 }
230 MemoryType::Gpu => {
231 #[cfg(not(feature = "gpu"))]
232 return Err(Error::wrong_type(memory_type));
233 #[cfg(feature = "gpu")]
234 {
235 let mut data = 0;
236 cuda_call!(cuMemAlloc_v2(&mut data, size))?;
237 Ok(data as *mut c_void)
238 }
239 }
240 }?;
241
242 if ptr.is_null() {
243 Err(Error::new(
244 ErrorCode::Internal,
245 format!("OutOfMemory. {memory_type:?}"),
246 ))
247 } else {
248 Ok(Buffer {
249 ptr,
250 len: size,
251 data_type,
252 memory_type,
253 owned: true,
254 })
255 }
256 }
257
258 pub fn from<T: Sample, S: AsRef<[T]>>(slice: S) -> Self {
260 let slice = slice.as_ref();
261 let ptr = unsafe {
262 let ptr = calloc(slice.len(), std::mem::size_of::<T>()) as *mut T;
263 copy_nonoverlapping(slice.as_ptr(), ptr, slice.len());
264 ptr
265 };
266
267 Buffer {
268 ptr: ptr as *mut _,
269 len: size_of_val(slice),
270 data_type: T::DATA_TYPE,
271 memory_type: MemoryType::Cpu,
272 owned: true,
273 }
274 }
275}
276
#[cfg(feature = "gpu")]
impl From<CudaArray> for Buffer {
    /// Takes ownership of a device allocation.
    ///
    /// The buffer is typed as raw bytes (`DataType::Uint8`) since the real
    /// element type is unknown here, and it will free the device memory on
    /// drop (`owned: true`).
    fn from(value: CudaArray) -> Self {
        Buffer {
            ptr: value.ptr as *mut c_void,
            len: value.len,
            data_type: DataType::Uint8,
            memory_type: MemoryType::Gpu,
            owned: true,
        }
    }
}
294
#[cfg(feature = "gpu")]
impl From<Buffer> for CudaArray {
    /// Releases the buffer's device allocation to the caller.
    ///
    /// `mem::forget` suppresses `Buffer::drop`, so the memory is not freed
    /// here and the returned `CudaArray` becomes its sole owner.
    fn from(value: Buffer) -> CudaArray {
        let res = CudaArray {
            ptr: value.ptr as _,
            len: value.len,
        };
        std::mem::forget(value);
        res
    }
}
309
impl Buffer {
    /// Memory type of the underlying allocation.
    pub fn memory_type(&self) -> MemoryType {
        self.memory_type
    }

    /// Element data type stored in the buffer.
    pub fn data_type(&self) -> DataType {
        self.data_type
    }

    /// Total size of the buffer in bytes (not elements).
    pub fn size(&self) -> usize {
        self.len
    }

    /// True when the buffer holds zero bytes.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
332
impl Buffer {
    /// Copies a host slice into this buffer starting at byte `offset`.
    ///
    /// NOTE(review): `T::DATA_TYPE` is not checked against `self.data_type`
    /// here — the copy is purely byte-wise; confirm callers rely on that.
    ///
    /// # Errors
    /// Fails if the data does not fit (`offset` + slice byte size exceeds the
    /// buffer length) or the memory type requires the disabled "gpu" feature.
    pub fn copy_from_slice<S: AsRef<[T]>, T: Sample>(
        &mut self,
        offset: usize,
        source: S,
    ) -> Result<(), Error> {
        self.check_mem_type_feature()?;

        let slice = source.as_ref();

        let byte_size = size_of_val(slice);

        if self.len < byte_size + offset {
            return Err(Error::new(
                ErrorCode::Internal,
                format!(
                    "copy_from_slice error: size mismatch! (required {}, buffer len {})",
                    byte_size + offset,
                    self.len
                ),
            ));
        }

        match self.memory_type {
            // Host-addressable memory: plain memcpy at a byte offset.
            MemoryType::Cpu | MemoryType::Pinned => unsafe {
                copy_nonoverlapping(slice.as_ptr(), self.ptr.byte_add(offset) as _, slice.len());
            },
            // Device memory: host-to-device transfer. Without the "gpu"
            // feature this arm is unreachable (check_mem_type_feature errors).
            MemoryType::Gpu => {
                #[cfg(feature = "gpu")]
                cuda_call!(cuMemcpyHtoD_v2(
                    self.ptr as CUdeviceptr + offset as CUdeviceptr,
                    slice.as_ptr() as _,
                    byte_size
                ))?;
            }
        }
        Ok(())
    }

    /// Copies `len` bytes from a device allocation into this buffer starting
    /// at byte `offset` (device-to-host for CPU/pinned, device-to-device for
    /// GPU destinations).
    ///
    /// # Errors
    /// Fails if the source does not fit or a CUDA call fails.
    #[cfg(feature = "gpu")]
    pub fn copy_from_cuda_array(&mut self, offset: usize, source: CudaArray) -> Result<(), Error> {
        let CudaArray { ptr, len } = source;

        if len + offset > self.len {
            return Err(Error::new(
                ErrorCode::Internal,
                format!(
                    "copy_from_cuda_array error: size mismatch (buffer len {}, required {})",
                    self.len,
                    len + offset
                ),
            ));
        }

        match self.memory_type {
            MemoryType::Pinned | MemoryType::Cpu => {
                cuda_call!(cuMemcpyDtoH_v2(
                    self.ptr.byte_add(offset),
                    ptr as CUdeviceptr,
                    len
                ))?;
            }
            MemoryType::Gpu => {
                cuda_call!(cuMemcpyDtoD_v2(
                    self.ptr as CUdeviceptr + offset as CUdeviceptr,
                    ptr as CUdeviceptr,
                    len
                ))?;
            }
        }
        Ok(())
    }

    /// Moves the buffer contents into regular CPU heap memory.
    pub fn into_cpu(self) -> Result<Self, Error> {
        self.into_mem_type(MemoryType::Cpu)
    }

    /// Moves the buffer contents into CUDA pinned host memory.
    #[cfg(feature = "gpu")]
    pub fn into_pinned(self) -> Result<Self, Error> {
        self.into_mem_type(MemoryType::Pinned)
    }

    /// Moves the buffer contents into CUDA device memory.
    #[cfg(feature = "gpu")]
    pub fn into_gpu(self) -> Result<Self, Error> {
        self.into_mem_type(MemoryType::Gpu)
    }

    /// Re-allocates the buffer in `mem_type` and copies the contents over;
    /// returns `self` unchanged when it is already in the target memory type.
    fn into_mem_type(self, mem_type: MemoryType) -> Result<Self, Error> {
        self.check_mem_type_feature()?;

        if self.memory_type == mem_type {
            return Ok(self);
        }

        let sample_count = self.len / self.data_type.size() as usize;
        let mut res = Buffer::alloc_with_data_type(sample_count, mem_type, self.data_type)?;

        if self.memory_type == MemoryType::Gpu {
            #[cfg(feature = "gpu")]
            res.copy_from_cuda_array(0, unsafe { self.get_cuda_array() })?;
        } else {
            res.copy_from_slice(0, self.bytes())?;
        }
        Ok(res)
    }
}
463
464impl Buffer {
466 pub fn bytes(&self) -> &[u8] {
470 if self.memory_type == MemoryType::Gpu {
471 log::warn!("Use bytes() on Gpu Buffer. empty slice will be returned");
472 return &[];
473 }
474
475 unsafe { slice::from_raw_parts(self.ptr as *const u8, self.len) }
476 }
477
478 pub fn bytes_mut(&mut self) -> &mut [u8] {
482 if self.memory_type == MemoryType::Gpu {
483 log::warn!("Use bytes_mut() on Gpu Buffer. empty slice will be returned");
484 return &mut [];
485 }
486
487 unsafe { slice::from_raw_parts_mut(self.ptr as *mut u8, self.len) }
488 }
489
490 #[allow(clippy::uninit_vec)]
491 pub fn get_owned_slice<Range: RangeBounds<usize> + Debug>(
494 &self,
495 range: Range,
496 ) -> Result<Vec<u8>, Error> {
497 self.check_mem_type_feature()?;
498
499 let left = match range.start_bound() {
500 Bound::Unbounded => 0,
501 Bound::Included(pos) => *pos,
502 Bound::Excluded(pos) => *pos + 1,
503 };
504 let right = match range.end_bound() {
505 Bound::Unbounded => self.len,
506 Bound::Included(pos) => *pos + 1,
507 Bound::Excluded(pos) => *pos,
508 };
509
510 if right > self.len {
511 return Err(Error::new(
512 ErrorCode::InvalidArg,
513 format!(
514 "get_slice invalid range: {range:?}, buffer len is: {}",
515 self.len
516 ),
517 ));
518 }
519
520 if self.memory_type != MemoryType::Gpu {
521 Ok(self.bytes()[left..right].to_vec())
522 } else {
523 let mut res = Vec::with_capacity(right - left);
524 #[cfg(feature = "gpu")]
525 cuda_call!(cuMemcpyDtoH_v2(
526 res.as_mut_ptr() as _,
527 self.ptr as CUdeviceptr + left as CUdeviceptr,
528 right - left
529 ))?;
530
531 unsafe { res.set_len(self.len) };
532 Ok(res)
533 }
534 }
535
536 #[cfg(feature = "gpu")]
544 pub unsafe fn get_cuda_array(&self) -> CudaArray {
545 if self.memory_type != MemoryType::Gpu {
546 panic!("Invoking get_cuda_array for non GPU-based buffer");
547 }
548
549 CudaArray {
550 ptr: self.ptr as _,
551 len: self.len,
552 }
553 }
554
555 fn check_mem_type_feature(&self) -> Result<(), Error> {
556 #[cfg(not(feature = "gpu"))]
557 if self.memory_type != MemoryType::Cpu {
558 return Err(Error::wrong_type(self.memory_type));
559 }
560 Ok(())
561 }
562}
563
564impl<T: Sample> AsRef<[T]> for Buffer {
565 fn as_ref(&self) -> &[T] {
571 if T::DATA_TYPE != self.data_type {
572 panic!(
573 "Buffer data_type {:?} != target slice data_type: {:?}",
574 self.data_type,
575 T::DATA_TYPE
576 )
577 }
578
579 if self.memory_type == MemoryType::Gpu {
580 log::warn!("Use as_ref() on Gpu Buffer. empty slice will be returned");
581 return &[];
582 }
583
584 unsafe { slice::from_raw_parts(self.ptr as *const T, self.len) }
585 }
586}
587
588impl<T: Sample> AsMut<[T]> for Buffer {
589 fn as_mut(&mut self) -> &mut [T] {
595 if T::DATA_TYPE != self.data_type {
596 panic!(
597 "Buffer data_type {:?} != target slice data_type: {:?}",
598 self.data_type,
599 T::DATA_TYPE
600 )
601 }
602
603 if self.memory_type == MemoryType::Gpu {
604 log::warn!("Use as_mut() on Gpu Buffer. empty slice will be returned");
605 return &mut [];
606 }
607
608 unsafe { slice::from_raw_parts_mut(self.ptr as *mut T, self.len) }
609 }
610}
611
impl Drop for Buffer {
    /// Frees the allocation with the allocator matching its memory type, but
    /// only when this buffer owns it (`owned`) and the pointer is non-null.
    fn drop(&mut self) {
        if self.owned && !self.ptr.is_null() {
            unsafe {
                match self.memory_type {
                    MemoryType::Cpu => {
                        // Allocated with libc `calloc`.
                        free(self.ptr);
                    }
                    MemoryType::Pinned => {
                        // Allocated with `cuMemAllocHost_v2`; without the
                        // "gpu" feature this arm is a no-op (such buffers
                        // cannot be allocated then).
                        #[cfg(feature = "gpu")]
                        cuMemFreeHost(self.ptr);
                    }
                    MemoryType::Gpu => {
                        // Allocated with `cuMemAlloc_v2`.
                        #[cfg(feature = "gpu")]
                        cuMemFree_v2(self.ptr as CUdeviceptr);
                    }
                }
            }
        }
    }
}