1use core::slice;
11use std::{
12 ffi::CStr,
13 fmt::Debug,
14 intrinsics::copy_nonoverlapping,
15 mem::{size_of_val, transmute},
16 ops::{Bound, RangeBounds},
17};
18
19#[cfg(feature = "gpu")]
20use cuda_driver_sys::{
21 cuMemAllocHost_v2, cuMemAlloc_v2, cuMemFreeHost, cuMemFree_v2, cuMemcpyDtoD_v2,
22 cuMemcpyDtoH_v2, cuMemcpyHtoD_v2, CUdeviceptr,
23};
24use libc::{c_void, calloc, free};
25
26use crate::{
27 error::{Error, ErrorCode, CSTR_CONVERT_ERROR_PLUG},
28 sys, to_cstring,
29};
30
/// Implements [`Sample`] (and the private `Sealed` supertrait) for a concrete
/// Rust type, binding it to its Triton [`DataType`] tag.
macro_rules! impl_sample {
    ($type:ty, $data:expr) => {
        impl private::Sealed for $type {}

        impl Sample for $type {
            const DATA_TYPE: DataType = $data;
        }
    };
}
40
mod private {
    /// Sealing trait: only the types registered via `impl_sample!` in the
    /// parent module may implement [`super::Sample`].
    pub trait Sealed: Clone + Copy {}
}
44
/// A scalar element type that can live in a [`Buffer`], tagged with the
/// matching Triton [`DataType`]. Sealed: implemented only via `impl_sample!`.
pub trait Sample: private::Sealed {
    /// The Triton data type corresponding to `Self`.
    const DATA_TYPE: DataType;
}
49
/// Triton tensor element types, mirroring `TRITONSERVER_datatype_enum`.
///
/// `#[repr(u32)]` with discriminants taken verbatim from the `sys` bindings,
/// so values can round-trip through the C API (see the `TryFrom<&str>` impl,
/// which relies on this for its `transmute`).
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum DataType {
    Invalid = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INVALID,
    Bool = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_BOOL,
    Uint8 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT8,
    Uint16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT16,
    Uint32 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT32,
    Uint64 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_UINT64,
    Int8 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT8,
    Int16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT16,
    Int32 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT32,
    Int64 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INT64,
    Fp16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_FP16,
    Fp32 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_FP32,
    Fp64 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_FP64,
    Bytes = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_BYTES,
    Bf16 = sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_BF16,
}
70
/// Newtype over `u8` marking raw bytes that should be tagged as
/// [`DataType::Bytes`] rather than [`DataType::Uint8`].
#[derive(Clone, Copy)]
pub struct Byte(pub u8);
73
// Register every supported Rust scalar type with its Triton data type.
impl_sample!(bool, DataType::Bool);
impl_sample!(u8, DataType::Uint8);
impl_sample!(Byte, DataType::Bytes);
impl_sample!(u16, DataType::Uint16);
impl_sample!(u32, DataType::Uint32);
impl_sample!(u64, DataType::Uint64);

impl_sample!(i8, DataType::Int8);
impl_sample!(i16, DataType::Int16);
impl_sample!(i32, DataType::Int32);
impl_sample!(i64, DataType::Int64);

// Half-precision types come from the `half` crate.
impl_sample!(half::f16, DataType::Fp16);
impl_sample!(half::bf16, DataType::Bf16);
impl_sample!(f32, DataType::Fp32);
impl_sample!(f64, DataType::Fp64);
90
91impl DataType {
92 pub fn as_str(self) -> &'static str {
94 let ptr = unsafe { sys::TRITONSERVER_DataTypeString(self as u32) };
95 unsafe { CStr::from_ptr(ptr) }
96 .to_str()
97 .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
98 }
99
100 pub fn size(self) -> u32 {
102 unsafe { sys::TRITONSERVER_DataTypeByteSize(self as u32) }
103 }
104}
105
106impl TryFrom<&str> for DataType {
107 type Error = Error;
108 fn try_from(name: &str) -> Result<Self, Self::Error> {
110 let name = to_cstring(name)?;
111 let data_type = unsafe { sys::TRITONSERVER_StringToDataType(name.as_ptr()) };
112 if data_type != sys::TRITONSERVER_datatype_enum_TRITONSERVER_TYPE_INVALID {
113 Ok(unsafe { transmute::<u32, crate::memory::DataType>(data_type) })
114 } else {
115 Err(Error::new(ErrorCode::InvalidArg, ""))
116 }
117 }
118}
119
/// Where a [`Buffer`]'s memory lives, mirroring `TRITONSERVER_memorytype_enum`.
///
/// `#[repr(u32)]` with discriminants taken verbatim from the `sys` bindings.
#[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[repr(u32)]
pub enum MemoryType {
    /// Plain host memory (libc allocator).
    Cpu = sys::TRITONSERVER_memorytype_enum_TRITONSERVER_MEMORY_CPU,
    /// Page-locked host memory (CUDA host allocator; requires "gpu" feature).
    Pinned = sys::TRITONSERVER_memorytype_enum_TRITONSERVER_MEMORY_CPU_PINNED,
    /// Device memory (requires "gpu" feature).
    Gpu = sys::TRITONSERVER_memorytype_enum_TRITONSERVER_MEMORY_GPU,
}
128
129impl MemoryType {
130 pub fn as_str(self) -> &'static str {
132 let ptr = unsafe { sys::TRITONSERVER_MemoryTypeString(self as u32) };
133 unsafe { CStr::from_ptr(ptr) }
134 .to_str()
135 .unwrap_or(CSTR_CONVERT_ERROR_PLUG)
136 }
137}
138
/// Raw view of a CUDA device allocation: device pointer plus length in bytes.
///
/// Carries no ownership information by itself; ownership semantics come from
/// the `From` conversions to/from [`Buffer`].
#[cfg(feature = "gpu")]
pub struct CudaArray {
    /// CUDA device pointer.
    pub ptr: CUdeviceptr,
    /// Allocation length in bytes.
    pub len: usize,
}
147
/// A typed, contiguous chunk of memory on CPU, pinned host, or GPU.
///
/// `len` is always a BYTE count (not a sample count). When `owned` is true,
/// `Drop` releases the allocation with the deallocator matching `memory_type`.
#[derive(Debug)]
pub struct Buffer {
    /// Base address; for GPU memory this holds the `CUdeviceptr` bits.
    pub(crate) ptr: *mut c_void,
    /// Length in bytes.
    pub(crate) len: usize,
    /// Logical element type of the contents.
    pub(crate) data_type: DataType,
    /// Where the allocation lives.
    pub(crate) memory_type: MemoryType,
    /// Whether `Drop` must free `ptr`.
    pub(crate) owned: bool,
}
166
// SAFETY(review): `Buffer` wraps a raw allocation with no interior
// mutability; presumed safe to move/share across threads because neither the
// libc heap nor CUDA allocations have thread affinity — TODO confirm no
// thread-local CUDA context assumptions in callers.
unsafe impl Send for Buffer {}
unsafe impl Sync for Buffer {}
169
170impl Buffer {
172 pub fn try_clone(&self) -> Result<Self, Error> {
176 self.check_mem_type_feature()?;
177
178 let sample_count = self.len / self.data_type.size() as usize;
179 let mut res = Buffer::alloc_with_data_type(sample_count, self.memory_type, self.data_type)?;
180
181 if self.memory_type == MemoryType::Gpu {
182 #[cfg(feature = "gpu")]
183 res.copy_from_cuda_array(0, unsafe { self.get_cuda_array() })?;
184 } else {
185 res.copy_from_slice(0, self.bytes())?;
186 }
187
188 Ok(res)
189 }
190
191 pub fn alloc<T: Sample>(count: usize, memory_type: MemoryType) -> Result<Self, Error> {
197 Self::alloc_with_data_type(count, memory_type, T::DATA_TYPE)
198 }
199
200 pub fn alloc_with_data_type(
207 count: usize,
208 memory_type: MemoryType,
209 data_type: DataType,
210 ) -> Result<Self, Error> {
211 let data_type_size = data_type.size() as usize;
212 let size = count * data_type_size;
213
214 let ptr = match memory_type {
215 MemoryType::Cpu => Ok::<_, Error>(unsafe { calloc(count as _, data_type_size) }),
216 MemoryType::Pinned => {
217 #[cfg(not(feature = "gpu"))]
218 return Err(Error::wrong_type(memory_type));
219 #[cfg(feature = "gpu")]
220 {
221 let mut data = std::ptr::null_mut::<c_void>();
222 cuda_call!(cuMemAllocHost_v2(&mut data, size))?;
223 Ok(data)
224 }
225 }
226 MemoryType::Gpu => {
227 #[cfg(not(feature = "gpu"))]
228 return Err(Error::wrong_type(memory_type));
229 #[cfg(feature = "gpu")]
230 {
231 let mut data = 0;
232 cuda_call!(cuMemAlloc_v2(&mut data, size))?;
233 Ok(data as *mut c_void)
234 }
235 }
236 }?;
237
238 if ptr.is_null() {
239 Err(Error::new(
240 ErrorCode::Internal,
241 format!("OutOfMemory. {memory_type:?}"),
242 ))
243 } else {
244 Ok(Buffer {
245 ptr,
246 len: size,
247 data_type,
248 memory_type,
249 owned: true,
250 })
251 }
252 }
253
254 pub fn from<T: Sample, S: AsRef<[T]>>(slice: S) -> Self {
256 let slice = slice.as_ref();
257 let ptr = unsafe {
258 let ptr = calloc(slice.len(), std::mem::size_of::<T>()) as *mut T;
259 copy_nonoverlapping(slice.as_ptr(), ptr, slice.len());
260 ptr
261 };
262
263 Buffer {
264 ptr: ptr as *mut _,
265 len: size_of_val(slice),
266 data_type: T::DATA_TYPE,
267 memory_type: MemoryType::Cpu,
268 owned: true,
269 }
270 }
271}
272
#[cfg(feature = "gpu")]
impl From<CudaArray> for Buffer {
    /// Takes ownership of a raw device allocation, wrapping it as a GPU
    /// `Buffer` of raw bytes (`DataType::Uint8`).
    ///
    /// The result is marked `owned`, so `Drop` will `cuMemFree` the pointer —
    /// the caller must not free it separately.
    fn from(value: CudaArray) -> Self {
        Buffer {
            ptr: value.ptr as *mut c_void,
            len: value.len,
            data_type: DataType::Uint8,
            memory_type: MemoryType::Gpu,
            owned: true,
        }
    }
}
290
#[cfg(feature = "gpu")]
impl From<Buffer> for CudaArray {
    /// Releases the buffer's allocation to the caller as a raw `CudaArray`.
    ///
    /// `mem::forget` suppresses the buffer's `Drop`, transferring ownership
    /// of the pointer to the returned `CudaArray`.
    /// NOTE(review): nothing checks `value.memory_type == Gpu`; converting a
    /// CPU buffer would yield a bogus device pointer (and leak the host
    /// allocation) — TODO confirm callers only do this for GPU buffers.
    fn from(value: Buffer) -> CudaArray {
        let res = CudaArray {
            ptr: value.ptr as _,
            len: value.len,
        };
        std::mem::forget(value);
        res
    }
}
305
306impl Buffer {
308 pub fn memory_type(&self) -> MemoryType {
310 self.memory_type
311 }
312
313 pub fn data_type(&self) -> DataType {
315 self.data_type
316 }
317
318 pub fn size(&self) -> usize {
320 self.len
321 }
322
323 pub fn is_empty(&self) -> bool {
325 self.len == 0
326 }
327}
328
impl Buffer {
    /// Copies a host slice into this buffer at `offset` (in bytes): a plain
    /// memcpy for CPU/pinned buffers, a host-to-device CUDA copy for GPU
    /// buffers.
    ///
    /// # Errors
    /// Fails when the buffer's memory type is unavailable in this build, when
    /// the source does not fit (`offset + byte size > len`), or when the CUDA
    /// copy fails.
    pub fn copy_from_slice<S: AsRef<[T]>, T: Sample>(
        &mut self,
        offset: usize,
        source: S,
    ) -> Result<(), Error> {
        self.check_mem_type_feature()?;

        let slice = source.as_ref();

        // Source size in bytes (not elements).
        let byte_size = size_of_val(slice);

        // NOTE(review): `T::DATA_TYPE` is never compared against
        // `self.data_type`; a mismatched sample type is copied silently.
        if self.len < byte_size + offset {
            return Err(Error::new(
                ErrorCode::Internal,
                format!(
                    "copy_from_slice error: size mismatch! (required {}, buffer len {})",
                    byte_size + offset,
                    self.len
                ),
            ));
        }

        match self.memory_type {
            MemoryType::Cpu | MemoryType::Pinned => unsafe {
                // `byte_add` keeps the destination offset in bytes while the
                // copy count is in elements of `T`.
                copy_nonoverlapping(slice.as_ptr(), self.ptr.byte_add(offset) as _, slice.len());
            },
            MemoryType::Gpu => {
                // Without the "gpu" feature this arm is unreachable:
                // check_mem_type_feature() already rejected non-CPU buffers.
                #[cfg(feature = "gpu")]
                cuda_call!(cuMemcpyHtoD_v2(
                    self.ptr as CUdeviceptr + offset as CUdeviceptr,
                    slice.as_ptr() as _,
                    byte_size
                ))?;
            }
        }
        Ok(())
    }

    /// Copies `source.len` bytes from a device allocation into this buffer at
    /// `offset` (in bytes): device-to-host for CPU/pinned destinations,
    /// device-to-device for GPU destinations.
    ///
    /// # Errors
    /// Fails when the source does not fit (`offset + len > self.len`) or when
    /// the CUDA copy fails.
    #[cfg(feature = "gpu")]
    pub fn copy_from_cuda_array(&mut self, offset: usize, source: CudaArray) -> Result<(), Error> {
        let CudaArray { ptr, len } = source;

        if len + offset > self.len {
            return Err(Error::new(
                ErrorCode::Internal,
                format!(
                    "copy_from_cuda_array error: size mismatch (buffer len {}, required {})",
                    self.len,
                    len + offset
                ),
            ));
        }

        match self.memory_type {
            MemoryType::Pinned | MemoryType::Cpu => {
                cuda_call!(cuMemcpyDtoH_v2(
                    self.ptr.byte_add(offset),
                    ptr as CUdeviceptr,
                    len
                ))?;
            }
            MemoryType::Gpu => {
                cuda_call!(cuMemcpyDtoD_v2(
                    self.ptr as CUdeviceptr + offset as CUdeviceptr,
                    ptr as CUdeviceptr,
                    len
                ))?;
            }
        }
        Ok(())
    }

    /// Moves the contents into a CPU buffer (no-op if already on CPU).
    pub fn into_cpu(self) -> Result<Self, Error> {
        self.into_mem_type(MemoryType::Cpu)
    }

    /// Moves the contents into pinned host memory (no-op if already pinned).
    #[cfg(feature = "gpu")]
    pub fn into_pinned(self) -> Result<Self, Error> {
        self.into_mem_type(MemoryType::Pinned)
    }

    /// Moves the contents into GPU memory (no-op if already on GPU).
    #[cfg(feature = "gpu")]
    pub fn into_gpu(self) -> Result<Self, Error> {
        self.into_mem_type(MemoryType::Gpu)
    }

    /// Re-allocates in `mem_type` and deep-copies the contents; returns
    /// `self` unchanged when the memory type already matches.
    fn into_mem_type(self, mem_type: MemoryType) -> Result<Self, Error> {
        self.check_mem_type_feature()?;

        if self.memory_type == mem_type {
            return Ok(self);
        }

        // NOTE(review): assumes `data_type.size() != 0` — TODO confirm for
        // variable-size types such as `Bytes`.
        let sample_count = self.len / self.data_type.size() as usize;
        let mut res = Buffer::alloc_with_data_type(sample_count, mem_type, self.data_type)?;

        if self.memory_type == MemoryType::Gpu {
            #[cfg(feature = "gpu")]
            res.copy_from_cuda_array(0, unsafe { self.get_cuda_array() })?;
        } else {
            res.copy_from_slice(0, self.bytes())?;
        }
        Ok(res)
    }
}
459
460impl Buffer {
462 pub fn bytes(&self) -> &[u8] {
466 if self.memory_type == MemoryType::Gpu {
467 log::warn!("Use bytes() on Gpu Buffer. empty slice will be returned");
468 return &[];
469 }
470
471 unsafe { slice::from_raw_parts(self.ptr as *const u8, self.len) }
472 }
473
474 pub fn bytes_mut(&mut self) -> &mut [u8] {
478 if self.memory_type == MemoryType::Gpu {
479 log::warn!("Use bytes_mut() on Gpu Buffer. empty slice will be returned");
480 return &mut [];
481 }
482
483 unsafe { slice::from_raw_parts_mut(self.ptr as *mut u8, self.len) }
484 }
485
486 pub fn get_owned_slice<Range: RangeBounds<usize> + Debug>(
489 &self,
490 range: Range,
491 ) -> Result<Vec<u8>, Error> {
492 self.check_mem_type_feature()?;
493
494 let left = match range.start_bound() {
495 Bound::Unbounded => 0,
496 Bound::Included(pos) => *pos,
497 Bound::Excluded(pos) => *pos + 1,
498 };
499 let right = match range.end_bound() {
500 Bound::Unbounded => self.len,
501 Bound::Included(pos) => *pos + 1,
502 Bound::Excluded(pos) => *pos,
503 };
504
505 if right > self.len {
506 return Err(Error::new(
507 ErrorCode::InvalidArg,
508 format!(
509 "get_slice invalid range: {range:?}, buffer len is: {}",
510 self.len
511 ),
512 ));
513 }
514
515 if self.memory_type != MemoryType::Gpu {
516 Ok(self.bytes()[left..right].to_vec())
517 } else {
518 let mut res = Vec::with_capacity(right - left);
519 #[cfg(feature = "gpu")]
520 cuda_call!(cuMemcpyDtoH_v2(
521 res.as_mut_ptr() as _,
522 self.ptr as CUdeviceptr + left as CUdeviceptr,
523 right - left
524 ))?;
525
526 unsafe { res.set_len(self.len) };
527 Ok(res)
528 }
529 }
530
531 #[cfg(feature = "gpu")]
539 pub unsafe fn get_cuda_array(&self) -> CudaArray {
540 if self.memory_type != MemoryType::Gpu {
541 panic!("Invoking get_cuda_array for non GPU-based buffer");
542 }
543
544 CudaArray {
545 ptr: self.ptr as _,
546 len: self.len,
547 }
548 }
549
550 fn check_mem_type_feature(&self) -> Result<(), Error> {
551 #[cfg(not(feature = "gpu"))]
552 if self.memory_type != MemoryType::Cpu {
553 return Err(Error::wrong_type(self.memory_type));
554 }
555 Ok(())
556 }
557}
558
559impl<T: Sample> AsRef<[T]> for Buffer {
560 fn as_ref(&self) -> &[T] {
566 if T::DATA_TYPE != self.data_type {
567 panic!(
568 "Buffer data_type {:?} != target slice data_type: {:?}",
569 self.data_type,
570 T::DATA_TYPE
571 )
572 }
573
574 if self.memory_type == MemoryType::Gpu {
575 log::warn!("Use as_ref() on Gpu Buffer. empty slice will be returned");
576 return &[];
577 }
578
579 unsafe { slice::from_raw_parts(self.ptr as *const T, self.len) }
580 }
581}
582
583impl<T: Sample> AsMut<[T]> for Buffer {
584 fn as_mut(&mut self) -> &mut [T] {
590 if T::DATA_TYPE != self.data_type {
591 panic!(
592 "Buffer data_type {:?} != target slice data_type: {:?}",
593 self.data_type,
594 T::DATA_TYPE
595 )
596 }
597
598 if self.memory_type == MemoryType::Gpu {
599 log::warn!("Use as_mut() on Gpu Buffer. empty slice will be returned");
600 return &mut [];
601 }
602
603 unsafe { slice::from_raw_parts_mut(self.ptr as *mut T, self.len) }
604 }
605}
606
impl Drop for Buffer {
    /// Frees the allocation with the deallocator matching its memory type,
    /// but only for owned, non-null pointers (borrowed or `mem::forget`-ten
    /// buffers are left alone).
    fn drop(&mut self) {
        if self.owned && !self.ptr.is_null() {
            unsafe {
                match self.memory_type {
                    MemoryType::Cpu => {
                        // Allocated with libc::calloc.
                        free(self.ptr);
                    }
                    MemoryType::Pinned => {
                        // NOTE(review): without the "gpu" feature this arm is
                        // a silent leak — presumably unreachable, since pinned
                        // buffers cannot be allocated in that build.
                        #[cfg(feature = "gpu")]
                        cuMemFreeHost(self.ptr);
                    }
                    MemoryType::Gpu => {
                        #[cfg(feature = "gpu")]
                        cuMemFree_v2(self.ptr as CUdeviceptr);
                    }
                }
            }
        }
    }
}