Skip to main content

pyo3_dlpack/
ffi.rs

1//! DLPack FFI types following the DLPack specification.
2//!
3//! These types match the C definitions from the DLPack header:
4//! <https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h>
5
6use std::ffi::c_void;
7
/// Device kinds that a DLPack tensor can live on.
///
/// Mirrors the C enum `DLDeviceType`; each discriminant below is the numeric
/// code of the matching `kDL*` constant. Codes 5 and 6 have no variant here.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DLDeviceType {
    /// Host CPU memory (`kDLCPU`).
    Cpu = 1,
    /// CUDA GPU device memory (`kDLCUDA`).
    Cuda = 2,
    /// Pinned host memory allocated with `cudaMallocHost` (`kDLCUDAHost`).
    CudaHost = 3,
    /// OpenCL device (`kDLOpenCL`).
    OpenCL = 4,
    /// Vulkan device (`kDLVulkan`).
    Vulkan = 7,
    /// Apple Metal device (`kDLMetal`).
    Metal = 8,
    /// VPI device (`kDLVPI`).
    Vpi = 9,
    /// AMD ROCm device (`kDLROCM`).
    Rocm = 10,
    /// Pinned host memory for ROCm (`kDLROCMHost`).
    RocmHost = 11,
    /// Extension device (`kDLExtDev`): reserved by the DLPack header for
    /// implementation-defined extension devices.
    ExtDev = 12,
    /// CUDA managed (unified) memory (`kDLCUDAManaged`).
    CudaManaged = 13,
    /// Intel oneAPI device (`kDLOneAPI`).
    OneApi = 14,
    /// WebGPU device (`kDLWebGPU`).
    WebGpu = 15,
    /// Hexagon DSP device (`kDLHexagon`).
    Hexagon = 16,
    /// MAIA accelerator (`kDLMAIA`).
    Maia = 17,
}

impl DLDeviceType {
    /// Look up the variant whose discriminant equals `value`.
    ///
    /// Yields `None` when `value` is not one of the codes listed on the enum.
    pub fn from_raw(value: u32) -> Option<Self> {
        // Keep this list in sync with the enum: a variant missing here can
        // never be produced from its raw code.
        let known = [
            Self::Cpu,
            Self::Cuda,
            Self::CudaHost,
            Self::OpenCL,
            Self::Vulkan,
            Self::Metal,
            Self::Vpi,
            Self::Rocm,
            Self::RocmHost,
            Self::ExtDev,
            Self::CudaManaged,
            Self::OneApi,
            Self::WebGpu,
            Self::Hexagon,
            Self::Maia,
        ];
        known.iter().copied().find(|&kind| kind as u32 == value)
    }
}
71
72/// A device descriptor specifying where tensor data resides.
73///
74/// This corresponds to `DLDevice` in the DLPack specification.
75#[repr(C)]
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
77pub struct DLDevice {
78    /// The device type (CPU, CUDA, etc.)
79    pub device_type: u32,
80    /// The device ID (e.g., which GPU for multi-GPU systems)
81    pub device_id: i32,
82}
83
84impl DLDevice {
85    /// Create a new device descriptor.
86    pub fn new(device_type: DLDeviceType, device_id: i32) -> Self {
87        Self {
88            device_type: device_type as u32,
89            device_id,
90        }
91    }
92
93    /// Get the device type as an enum.
94    ///
95    /// Returns `None` for unknown device types.
96    pub fn device_type_enum(&self) -> Option<DLDeviceType> {
97        DLDeviceType::from_raw(self.device_type)
98    }
99
100    /// Check if this is a CUDA device.
101    pub fn is_cuda(&self) -> bool {
102        self.device_type == DLDeviceType::Cuda as u32
103    }
104
105    /// Check if this is a CPU device.
106    pub fn is_cpu(&self) -> bool {
107        self.device_type == DLDeviceType::Cpu as u32
108    }
109
110    /// Check if this is CUDA host (pinned) memory.
111    pub fn is_cuda_host(&self) -> bool {
112        self.device_type == DLDeviceType::CudaHost as u32
113    }
114
115    /// Check if this is a ROCm device.
116    pub fn is_rocm(&self) -> bool {
117        self.device_type == DLDeviceType::Rocm as u32
118    }
119
120    /// Check if this is a Metal device (Apple GPU).
121    pub fn is_metal(&self) -> bool {
122        self.device_type == DLDeviceType::Metal as u32
123    }
124}
125
/// Element type categories used by DLPack.
///
/// Mirrors `DLDataTypeCode` from the DLPack specification; the discriminants
/// are the raw codes stored in [`DLDataType`]'s `code` field.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DLDataTypeCode {
    /// Signed integer.
    Int = 0,
    /// Unsigned integer.
    UInt = 1,
    /// IEEE floating point.
    Float = 2,
    /// Opaque handle type (not for computation).
    OpaqueHandle = 3,
    /// Bfloat16 (Brain Floating Point).
    Bfloat = 4,
    /// Complex number.
    Complex = 5,
    /// Boolean.
    Bool = 6,
}

impl DLDataTypeCode {
    /// Look up the variant whose discriminant equals `value`.
    ///
    /// Yields `None` when `value` is not one of the codes listed on the enum.
    pub fn from_raw(value: u8) -> Option<Self> {
        // Keep this list in sync with the enum: a variant missing here can
        // never be produced from its raw code.
        let known = [
            Self::Int,
            Self::UInt,
            Self::Float,
            Self::OpaqueHandle,
            Self::Bfloat,
            Self::Complex,
            Self::Bool,
        ];
        known.iter().copied().find(|&code| code as u8 == value)
    }
}
165
166/// Data type descriptor specifying the element type of a tensor.
167///
168/// This corresponds to `DLDataType` in the DLPack specification.
169#[repr(C)]
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
171pub struct DLDataType {
172    /// Type code (signed int, unsigned int, float, etc.)
173    pub code: u8,
174    /// Number of bits per element (e.g., 32 for float32)
175    pub bits: u8,
176    /// Number of lanes for vectorized types (usually 1)
177    pub lanes: u16,
178}
179
180impl DLDataType {
181    /// Create a new data type descriptor.
182    pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self {
183        Self {
184            code: code as u8,
185            bits,
186            lanes,
187        }
188    }
189
190    /// Get the type code as an enum.
191    ///
192    /// Returns `None` for unknown type codes.
193    pub fn code_enum(&self) -> Option<DLDataTypeCode> {
194        DLDataTypeCode::from_raw(self.code)
195    }
196
197    /// Check if this is f16 (half precision float).
198    pub fn is_f16(&self) -> bool {
199        self.code == DLDataTypeCode::Float as u8 && self.bits == 16 && self.lanes == 1
200    }
201
202    /// Check if this is f32 (single precision float).
203    pub fn is_f32(&self) -> bool {
204        self.code == DLDataTypeCode::Float as u8 && self.bits == 32 && self.lanes == 1
205    }
206
207    /// Check if this is f64 (double precision float).
208    pub fn is_f64(&self) -> bool {
209        self.code == DLDataTypeCode::Float as u8 && self.bits == 64 && self.lanes == 1
210    }
211
212    /// Check if this is bf16 (bfloat16).
213    pub fn is_bf16(&self) -> bool {
214        self.code == DLDataTypeCode::Bfloat as u8 && self.bits == 16 && self.lanes == 1
215    }
216
217    /// Check if this is i8 (signed 8-bit integer).
218    pub fn is_i8(&self) -> bool {
219        self.code == DLDataTypeCode::Int as u8 && self.bits == 8 && self.lanes == 1
220    }
221
222    /// Check if this is i16 (signed 16-bit integer).
223    pub fn is_i16(&self) -> bool {
224        self.code == DLDataTypeCode::Int as u8 && self.bits == 16 && self.lanes == 1
225    }
226
227    /// Check if this is i32 (signed 32-bit integer).
228    pub fn is_i32(&self) -> bool {
229        self.code == DLDataTypeCode::Int as u8 && self.bits == 32 && self.lanes == 1
230    }
231
232    /// Check if this is i64 (signed 64-bit integer).
233    pub fn is_i64(&self) -> bool {
234        self.code == DLDataTypeCode::Int as u8 && self.bits == 64 && self.lanes == 1
235    }
236
237    /// Check if this is u8 (unsigned 8-bit integer).
238    pub fn is_u8(&self) -> bool {
239        self.code == DLDataTypeCode::UInt as u8 && self.bits == 8 && self.lanes == 1
240    }
241
242    /// Check if this is u16 (unsigned 16-bit integer).
243    pub fn is_u16(&self) -> bool {
244        self.code == DLDataTypeCode::UInt as u8 && self.bits == 16 && self.lanes == 1
245    }
246
247    /// Check if this is u32 (unsigned 32-bit integer).
248    pub fn is_u32(&self) -> bool {
249        self.code == DLDataTypeCode::UInt as u8 && self.bits == 32 && self.lanes == 1
250    }
251
252    /// Check if this is u64 (unsigned 64-bit integer).
253    pub fn is_u64(&self) -> bool {
254        self.code == DLDataTypeCode::UInt as u8 && self.bits == 64 && self.lanes == 1
255    }
256
257    /// Check if this is bool.
258    pub fn is_bool(&self) -> bool {
259        self.code == DLDataTypeCode::Bool as u8 && self.bits == 8 && self.lanes == 1
260    }
261
262    /// Get the size of one element in bytes.
263    pub fn itemsize(&self) -> usize {
264        ((self.bits as usize) * (self.lanes as usize)).div_ceil(8)
265    }
266}
267
268/// The core DLTensor structure describing a tensor's data and layout.
269///
270/// This corresponds to `DLTensor` in the DLPack specification.
271#[repr(C)]
272pub struct DLTensor {
273    /// Pointer to the data buffer.
274    /// For GPU tensors, this is a device pointer.
275    pub data: *mut c_void,
276    /// Device descriptor specifying where the data resides.
277    pub device: DLDevice,
278    /// Number of dimensions.
279    pub ndim: i32,
280    /// Data type descriptor.
281    pub dtype: DLDataType,
282    /// Shape array (length = ndim).
283    /// Points to an array of dimension sizes.
284    pub shape: *mut i64,
285    /// Stride array in number of elements (length = ndim).
286    /// Can be null for compact row-major tensors.
287    pub strides: *mut i64,
288    /// Byte offset from the data pointer to the first element.
289    pub byte_offset: u64,
290}
291
292/// Deleter function signature for DLManagedTensor.
293///
294/// Called when the consumer is done with the tensor to free resources.
295pub type DLManagedTensorDeleter = unsafe extern "C" fn(*mut DLManagedTensor);
296
297/// A managed tensor with ownership semantics.
298///
299/// This corresponds to `DLManagedTensor` in the DLPack specification.
300/// It wraps a `DLTensor` and provides a deleter for cleanup.
301#[repr(C)]
302pub struct DLManagedTensor {
303    /// The underlying tensor descriptor.
304    pub dl_tensor: DLTensor,
305    /// Opaque manager context for the producer's use.
306    /// Typically used to store data needed by the deleter.
307    pub manager_ctx: *mut c_void,
308    /// Deleter function called when the consumer is done.
309    /// Can be null if no cleanup is needed.
310    pub deleter: Option<DLManagedTensorDeleter>,
311}
312
/// DLPack protocol version, as carried by `DLManagedTensorVersioned`.
///
/// Corresponds to `DLPackVersion` in the DLPack specification: two `u32`
/// fields, `major` first.
///
/// Besides equality, the type is hashable and totally ordered. The derived
/// ordering is lexicographic on `(major, minor)`, following the field
/// declaration order, so versions can be compared directly, for example when
/// checking a consumer's `max_version` against the version a producer
/// supports.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DLPackVersion {
    /// Major version. Incremented on ABI-breaking changes.
    pub major: u32,
    /// Minor version. Incremented on backward-compatible additions.
    pub minor: u32,
}
324
/// Major DLPack version that this crate both produces and accepts.
pub const DLPACK_MAJOR_VERSION: u32 = 1;
/// Minor DLPack version that this crate produces and advertises.
///
/// Deliberately 0: the crate implements the DLPack **1.0** feature set, namely
/// the versioned struct layout (`DLManagedTensorVersioned`), the read-only
/// flag, and the `max_version` negotiation kwarg. The Python array-API
/// negotiation example itself passes `max_version=(1, 0)`. Because minor
/// versions are backward-compatible, advertising `0` is the honest floor and
/// is accepted by every conforming 1.x consumer.
pub const DLPACK_MINOR_VERSION: u32 = 0;
336
/// Flag bit 0: the tensor data is read-only.
pub const DLPACK_FLAG_BITMASK_READ_ONLY: u64 = 0b001;
/// Flag bit 1: the tensor data was copied by the producer.
pub const DLPACK_FLAG_BITMASK_IS_COPIED: u64 = 0b010;
/// Flag bit 2: a tensor with a sub-byte element type is padded to a byte
/// boundary.
pub const DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED: u64 = 0b100;
343
344/// Deleter function signature for `DLManagedTensorVersioned`.
345///
346/// Note this takes a pointer to the *versioned* struct, which has a different
347/// layout from `DLManagedTensor`, so it is a distinct type from
348/// [`DLManagedTensorDeleter`].
349pub type DLManagedTensorVersionedDeleter = unsafe extern "C" fn(*mut DLManagedTensorVersioned);
350
351/// A versioned managed tensor (DLPack 1.0).
352///
353/// Corresponds to `DLManagedTensorVersioned` in the DLPack specification.
354/// The field order differs from [`DLManagedTensor`]: `version` is first and
355/// `dl_tensor` is last. This is an ABI contract — do not reorder.
356#[repr(C)]
357pub struct DLManagedTensorVersioned {
358    /// Protocol version of this struct.
359    pub version: DLPackVersion,
360    /// Opaque manager context for the producer's use.
361    pub manager_ctx: *mut c_void,
362    /// Deleter function called when the consumer is done. Can be null if no
363    /// cleanup is needed.
364    pub deleter: Option<DLManagedTensorVersionedDeleter>,
365    /// Bitmask of `DLPACK_FLAG_BITMASK_*` flags.
366    pub flags: u64,
367    /// The underlying tensor descriptor.
368    pub dl_tensor: DLTensor,
369}
370
371// ============================================================================
372// Convenience constructors
373// ============================================================================
374
375/// Create a DLDevice for CUDA with the specified device ID.
376pub fn cuda_device(device_id: i32) -> DLDevice {
377    DLDevice::new(DLDeviceType::Cuda, device_id)
378}
379
380/// Create a DLDevice for CPU.
381pub fn cpu_device() -> DLDevice {
382    DLDevice::new(DLDeviceType::Cpu, 0)
383}
384
385/// Create a DLDevice for Metal (Apple GPU) with the specified device ID.
386pub fn metal_device(device_id: i32) -> DLDevice {
387    DLDevice::new(DLDeviceType::Metal, device_id)
388}
389
390/// Create a DLDataType for f32 (single precision float).
391pub fn dtype_f32() -> DLDataType {
392    DLDataType::new(DLDataTypeCode::Float, 32, 1)
393}
394
395/// Create a DLDataType for f64 (double precision float).
396pub fn dtype_f64() -> DLDataType {
397    DLDataType::new(DLDataTypeCode::Float, 64, 1)
398}
399
400/// Create a DLDataType for f16 (half precision float).
401pub fn dtype_f16() -> DLDataType {
402    DLDataType::new(DLDataTypeCode::Float, 16, 1)
403}
404
405/// Create a DLDataType for bf16 (bfloat16).
406pub fn dtype_bf16() -> DLDataType {
407    DLDataType::new(DLDataTypeCode::Bfloat, 16, 1)
408}
409
410/// Create a DLDataType for i8 (signed 8-bit integer).
411pub fn dtype_i8() -> DLDataType {
412    DLDataType::new(DLDataTypeCode::Int, 8, 1)
413}
414
415/// Create a DLDataType for i16 (signed 16-bit integer).
416pub fn dtype_i16() -> DLDataType {
417    DLDataType::new(DLDataTypeCode::Int, 16, 1)
418}
419
420/// Create a DLDataType for i32 (signed 32-bit integer).
421pub fn dtype_i32() -> DLDataType {
422    DLDataType::new(DLDataTypeCode::Int, 32, 1)
423}
424
425/// Create a DLDataType for i64 (signed 64-bit integer).
426pub fn dtype_i64() -> DLDataType {
427    DLDataType::new(DLDataTypeCode::Int, 64, 1)
428}
429
430/// Create a DLDataType for u8 (unsigned 8-bit integer).
431pub fn dtype_u8() -> DLDataType {
432    DLDataType::new(DLDataTypeCode::UInt, 8, 1)
433}
434
435/// Create a DLDataType for u16 (unsigned 16-bit integer).
436pub fn dtype_u16() -> DLDataType {
437    DLDataType::new(DLDataTypeCode::UInt, 16, 1)
438}
439
440/// Create a DLDataType for u32 (unsigned 32-bit integer).
441pub fn dtype_u32() -> DLDataType {
442    DLDataType::new(DLDataTypeCode::UInt, 32, 1)
443}
444
445/// Create a DLDataType for u64 (unsigned 64-bit integer).
446pub fn dtype_u64() -> DLDataType {
447    DLDataType::new(DLDataTypeCode::UInt, 64, 1)
448}
449
450/// Create a DLDataType for bool.
451pub fn dtype_bool() -> DLDataType {
452    DLDataType::new(DLDataTypeCode::Bool, 8, 1)
453}
454
#[cfg(test)]
mod tests {
    //! Unit tests for the DLPack FFI types: raw-code <-> enum mapping,
    //! predicate helpers, convenience constructors, and `#[repr(C)]` layout.
    //!
    //! Layout tests check field *order* on every target and exact
    //! sizes/offsets on 64-bit targets, where the expected numbers are known.

    use super::*;
    use std::collections::HashSet;
    use std::mem::{align_of, offset_of, size_of};

    /// Every `(raw code, variant)` pair `DLDeviceType` is expected to map.
    const DEVICE_TYPES: [(u32, DLDeviceType); 15] = [
        (1, DLDeviceType::Cpu),
        (2, DLDeviceType::Cuda),
        (3, DLDeviceType::CudaHost),
        (4, DLDeviceType::OpenCL),
        (7, DLDeviceType::Vulkan),
        (8, DLDeviceType::Metal),
        (9, DLDeviceType::Vpi),
        (10, DLDeviceType::Rocm),
        (11, DLDeviceType::RocmHost),
        (12, DLDeviceType::ExtDev),
        (13, DLDeviceType::CudaManaged),
        (14, DLDeviceType::OneApi),
        (15, DLDeviceType::WebGpu),
        (16, DLDeviceType::Hexagon),
        (17, DLDeviceType::Maia),
    ];

    /// Every `(raw code, variant)` pair `DLDataTypeCode` is expected to map.
    const DTYPE_CODES: [(u8, DLDataTypeCode); 7] = [
        (0, DLDataTypeCode::Int),
        (1, DLDataTypeCode::UInt),
        (2, DLDataTypeCode::Float),
        (3, DLDataTypeCode::OpaqueHandle),
        (4, DLDataTypeCode::Bfloat),
        (5, DLDataTypeCode::Complex),
        (6, DLDataTypeCode::Bool),
    ];

    /// True when `offsets` is strictly increasing, i.e. the fields they were
    /// taken from appear in memory in the listed order.
    fn strictly_increasing(offsets: &[usize]) -> bool {
        offsets.windows(2).all(|pair| pair[0] < pair[1])
    }

    // ========================================================================
    // DLDeviceType tests
    // ========================================================================

    #[test]
    fn test_device_type_from_raw_all_variants() {
        for (raw, variant) in DEVICE_TYPES {
            assert_eq!(
                DLDeviceType::from_raw(raw),
                Some(variant),
                "raw code {raw}"
            );
            // `DLDevice::new` stores `variant as u32`, so the discriminant
            // must round-trip through `from_raw`.
            assert_eq!(variant as u32, raw);
        }
    }

    #[test]
    fn test_device_type_from_raw_unknown() {
        for raw in [0, 5, 6, 18, 100, u32::MAX] {
            assert_eq!(DLDeviceType::from_raw(raw), None, "raw code {raw}");
        }
    }

    #[test]
    fn test_device_type_debug() {
        assert_eq!(format!("{:?}", DLDeviceType::Cpu), "Cpu");
        assert_eq!(format!("{:?}", DLDeviceType::Cuda), "Cuda");
    }

    #[test]
    fn test_device_type_clone_copy() {
        // Using `dt` after two moves-by-copy only compiles if the type is `Copy`.
        let dt = DLDeviceType::Cuda;
        let dt2 = dt;
        let dt3 = dt;
        assert_eq!(dt, dt2);
        assert_eq!(dt, dt3);
    }

    #[test]
    fn test_device_type_hash() {
        let mut set = HashSet::new();
        set.insert(DLDeviceType::Cpu);
        set.insert(DLDeviceType::Cuda);
        set.insert(DLDeviceType::Cpu);
        assert_eq!(set.len(), 2);
    }

    // ========================================================================
    // DLDevice tests
    // ========================================================================

    #[test]
    fn test_device_new() {
        let dev = DLDevice::new(DLDeviceType::Cuda, 3);
        assert_eq!(dev.device_type, 2);
        assert_eq!(dev.device_id, 3);
    }

    #[test]
    fn test_device_type_enum() {
        let dev = DLDevice::new(DLDeviceType::Rocm, 1);
        assert_eq!(dev.device_type_enum(), Some(DLDeviceType::Rocm));

        let unknown = DLDevice {
            device_type: 99,
            device_id: 0,
        };
        assert_eq!(unknown.device_type_enum(), None);
    }

    #[test]
    fn test_device_is_cuda() {
        assert!(cuda_device(0).is_cuda());
        assert!(!cpu_device().is_cuda());
        assert!(!DLDevice::new(DLDeviceType::CudaHost, 0).is_cuda());
    }

    #[test]
    fn test_device_is_cpu() {
        assert!(cpu_device().is_cpu());
        assert!(!cuda_device(0).is_cpu());
    }

    #[test]
    fn test_device_is_cuda_host() {
        assert!(DLDevice::new(DLDeviceType::CudaHost, 0).is_cuda_host());
        assert!(!cpu_device().is_cuda_host());
        assert!(!cuda_device(0).is_cuda_host());
    }

    #[test]
    fn test_device_is_rocm() {
        assert!(DLDevice::new(DLDeviceType::Rocm, 0).is_rocm());
        assert!(!cpu_device().is_rocm());
        assert!(!cuda_device(0).is_rocm());
    }

    #[test]
    fn test_device_is_metal() {
        assert!(DLDevice::new(DLDeviceType::Metal, 0).is_metal());
        assert!(metal_device(0).is_metal());
        assert!(!cpu_device().is_metal());
        assert!(!cuda_device(0).is_metal());
    }

    #[test]
    fn test_device_debug() {
        let dev = cuda_device(2);
        let debug = format!("{:?}", dev);
        assert!(debug.contains("device_type"));
        assert!(debug.contains("device_id"));
    }

    #[test]
    fn test_device_clone_copy() {
        let dev = cuda_device(1);
        let dev2 = dev;
        let dev3 = dev;
        assert_eq!(dev, dev2);
        assert_eq!(dev, dev3);
    }

    #[test]
    fn test_device_hash() {
        let mut set = HashSet::new();
        set.insert(cpu_device());
        set.insert(cuda_device(0));
        set.insert(cuda_device(1));
        set.insert(cpu_device());
        assert_eq!(set.len(), 3);
    }

    // ========================================================================
    // DLDataTypeCode tests
    // ========================================================================

    #[test]
    fn test_dtype_code_from_raw_all_variants() {
        for (raw, variant) in DTYPE_CODES {
            assert_eq!(
                DLDataTypeCode::from_raw(raw),
                Some(variant),
                "raw code {raw}"
            );
            // `DLDataType::new` stores `variant as u8`, so the discriminant
            // must round-trip through `from_raw`.
            assert_eq!(variant as u8, raw);
        }
    }

    #[test]
    fn test_dtype_code_from_raw_unknown() {
        for raw in [7, 100, u8::MAX] {
            assert_eq!(DLDataTypeCode::from_raw(raw), None, "raw code {raw}");
        }
    }

    #[test]
    fn test_dtype_code_debug() {
        assert_eq!(format!("{:?}", DLDataTypeCode::Float), "Float");
        assert_eq!(format!("{:?}", DLDataTypeCode::Int), "Int");
    }

    #[test]
    fn test_dtype_code_clone_copy() {
        let code = DLDataTypeCode::Float;
        let code2 = code;
        let code3 = code;
        assert_eq!(code, code2);
        assert_eq!(code, code3);
    }

    #[test]
    fn test_dtype_code_hash() {
        let mut set = HashSet::new();
        set.insert(DLDataTypeCode::Float);
        set.insert(DLDataTypeCode::Int);
        set.insert(DLDataTypeCode::Float);
        assert_eq!(set.len(), 2);
    }

    // ========================================================================
    // DLDataType tests
    // ========================================================================

    #[test]
    fn test_dtype_new() {
        let dt = DLDataType::new(DLDataTypeCode::Float, 32, 1);
        assert_eq!(dt.code, 2);
        assert_eq!(dt.bits, 32);
        assert_eq!(dt.lanes, 1);
    }

    #[test]
    fn test_dtype_code_enum() {
        let dt = dtype_f32();
        assert_eq!(dt.code_enum(), Some(DLDataTypeCode::Float));

        let unknown = DLDataType {
            code: 99,
            bits: 32,
            lanes: 1,
        };
        assert_eq!(unknown.code_enum(), None);
    }

    #[test]
    fn test_dtype_is_f16() {
        assert!(dtype_f16().is_f16());
        assert!(!dtype_f32().is_f16());
        assert!(!dtype_bf16().is_f16());
        // Wrong lanes
        let wrong = DLDataType::new(DLDataTypeCode::Float, 16, 2);
        assert!(!wrong.is_f16());
    }

    #[test]
    fn test_dtype_is_f32() {
        assert!(dtype_f32().is_f32());
        assert!(!dtype_f64().is_f32());
        assert!(!dtype_f16().is_f32());
    }

    #[test]
    fn test_dtype_is_f64() {
        assert!(dtype_f64().is_f64());
        assert!(!dtype_f32().is_f64());
    }

    #[test]
    fn test_dtype_is_bf16() {
        assert!(dtype_bf16().is_bf16());
        assert!(!dtype_f16().is_bf16());
        assert!(!dtype_f32().is_bf16());
    }

    #[test]
    fn test_dtype_is_i8() {
        assert!(dtype_i8().is_i8());
        assert!(!dtype_i16().is_i8());
        assert!(!dtype_u8().is_i8());
    }

    #[test]
    fn test_dtype_is_i16() {
        assert!(dtype_i16().is_i16());
        assert!(!dtype_i8().is_i16());
        assert!(!dtype_i32().is_i16());
    }

    #[test]
    fn test_dtype_is_i32() {
        assert!(dtype_i32().is_i32());
        assert!(!dtype_i64().is_i32());
        assert!(!dtype_u32().is_i32());
    }

    #[test]
    fn test_dtype_is_i64() {
        assert!(dtype_i64().is_i64());
        assert!(!dtype_i32().is_i64());
    }

    #[test]
    fn test_dtype_is_u8() {
        assert!(dtype_u8().is_u8());
        assert!(!dtype_i8().is_u8());
        assert!(!dtype_u16().is_u8());
    }

    #[test]
    fn test_dtype_is_u16() {
        assert!(dtype_u16().is_u16());
        assert!(!dtype_u8().is_u16());
    }

    #[test]
    fn test_dtype_is_u32() {
        assert!(dtype_u32().is_u32());
        assert!(!dtype_i32().is_u32());
    }

    #[test]
    fn test_dtype_is_u64() {
        assert!(dtype_u64().is_u64());
        assert!(!dtype_u32().is_u64());
    }

    #[test]
    fn test_dtype_is_bool() {
        assert!(dtype_bool().is_bool());
        assert!(!dtype_u8().is_bool());
        assert!(!dtype_i8().is_bool());
    }

    #[test]
    fn test_dtype_itemsize() {
        assert_eq!(dtype_f16().itemsize(), 2);
        assert_eq!(dtype_f32().itemsize(), 4);
        assert_eq!(dtype_f64().itemsize(), 8);
        assert_eq!(dtype_bf16().itemsize(), 2);
        assert_eq!(dtype_i8().itemsize(), 1);
        assert_eq!(dtype_i16().itemsize(), 2);
        assert_eq!(dtype_i32().itemsize(), 4);
        assert_eq!(dtype_i64().itemsize(), 8);
        assert_eq!(dtype_u8().itemsize(), 1);
        assert_eq!(dtype_u16().itemsize(), 2);
        assert_eq!(dtype_u32().itemsize(), 4);
        assert_eq!(dtype_u64().itemsize(), 8);
        assert_eq!(dtype_bool().itemsize(), 1);
    }

    #[test]
    fn test_dtype_itemsize_vectorized() {
        // Vectorized type with 4 lanes of f32
        let vec_f32 = DLDataType::new(DLDataTypeCode::Float, 32, 4);
        assert_eq!(vec_f32.itemsize(), 16); // 4 * 4 bytes

        // 8 lanes of i16
        let vec_i16 = DLDataType::new(DLDataTypeCode::Int, 16, 8);
        assert_eq!(vec_i16.itemsize(), 16); // 8 * 2 bytes
    }

    #[test]
    fn test_dtype_itemsize_rounding() {
        // Bit widths that are not a multiple of 8 round up to whole bytes.
        for (bits, expected) in [(1u8, 1usize), (7, 1), (9, 2)] {
            let dt = DLDataType {
                code: 0,
                bits,
                lanes: 1,
            };
            assert_eq!(dt.itemsize(), expected, "bits = {bits}");
        }
    }

    #[test]
    fn test_dtype_debug() {
        let dt = dtype_f32();
        let debug = format!("{:?}", dt);
        assert!(debug.contains("code"));
        assert!(debug.contains("bits"));
        assert!(debug.contains("lanes"));
    }

    #[test]
    fn test_dtype_clone_copy() {
        let dt = dtype_f32();
        let dt2 = dt;
        let dt3 = dt;
        assert_eq!(dt, dt2);
        assert_eq!(dt, dt3);
    }

    #[test]
    fn test_dtype_hash() {
        let mut set = HashSet::new();
        set.insert(dtype_f32());
        set.insert(dtype_f64());
        set.insert(dtype_f32());
        assert_eq!(set.len(), 2);
    }

    // ========================================================================
    // Convenience constructor tests
    // ========================================================================

    #[test]
    fn test_cuda_device() {
        let dev = cuda_device(0);
        assert!(dev.is_cuda());
        assert_eq!(dev.device_id, 0);

        let dev1 = cuda_device(1);
        assert!(dev1.is_cuda());
        assert_eq!(dev1.device_id, 1);
    }

    #[test]
    fn test_cpu_device() {
        let dev = cpu_device();
        assert!(dev.is_cpu());
        assert_eq!(dev.device_id, 0);
    }

    #[test]
    fn test_metal_device() {
        let dev = metal_device(0);
        assert!(dev.is_metal());
        assert_eq!(dev.device_id, 0);

        let dev1 = metal_device(1);
        assert!(dev1.is_metal());
        assert_eq!(dev1.device_id, 1);
    }

    #[test]
    fn test_all_dtype_constructors() {
        // Float types
        assert!(dtype_f16().is_f16());
        assert!(dtype_f32().is_f32());
        assert!(dtype_f64().is_f64());
        assert!(dtype_bf16().is_bf16());

        // Signed integer types
        assert!(dtype_i8().is_i8());
        assert!(dtype_i16().is_i16());
        assert!(dtype_i32().is_i32());
        assert!(dtype_i64().is_i64());

        // Unsigned integer types
        assert!(dtype_u8().is_u8());
        assert!(dtype_u16().is_u16());
        assert!(dtype_u32().is_u32());
        assert!(dtype_u64().is_u64());

        // Boolean
        assert!(dtype_bool().is_bool());
    }

    // ========================================================================
    // DLTensor and DLManagedTensor struct layout tests
    // ========================================================================

    #[test]
    fn test_dl_tensor_size() {
        // Field order is the ABI contract with the C `DLTensor`; check it on
        // every target.
        let offsets = [
            offset_of!(DLTensor, data),
            offset_of!(DLTensor, device),
            offset_of!(DLTensor, ndim),
            offset_of!(DLTensor, dtype),
            offset_of!(DLTensor, shape),
            offset_of!(DLTensor, strides),
            offset_of!(DLTensor, byte_offset),
        ];
        assert_eq!(offsets[0], 0);
        assert!(strictly_increasing(&offsets), "offsets: {offsets:?}");

        if cfg!(target_pointer_width = "64") {
            // data(8) + device(8) + ndim(4) + dtype(4) + shape(8) + strides(8)
            // + byte_offset(8) = 48 bytes, with no padding.
            assert_eq!(offsets, [0, 8, 16, 20, 24, 32, 40]);
            assert_eq!(size_of::<DLTensor>(), 48);
            assert_eq!(align_of::<DLTensor>(), 8);
        }
    }

    #[test]
    fn test_dl_managed_tensor_size() {
        // Legacy layout: tensor first, then manager_ctx, then deleter.
        let offsets = [
            offset_of!(DLManagedTensor, dl_tensor),
            offset_of!(DLManagedTensor, manager_ctx),
            offset_of!(DLManagedTensor, deleter),
        ];
        assert_eq!(offsets[0], 0);
        assert!(strictly_increasing(&offsets), "offsets: {offsets:?}");

        // `Option<fn>` uses the null niche, so it is exactly pointer-sized and
        // matches a nullable C function pointer.
        assert_eq!(
            size_of::<Option<DLManagedTensorDeleter>>(),
            size_of::<DLManagedTensorDeleter>()
        );

        if cfg!(target_pointer_width = "64") {
            // DLTensor(48) + manager_ctx(8) + deleter(8) = 64 bytes.
            assert_eq!(offsets, [0, 48, 56]);
            assert_eq!(size_of::<DLManagedTensor>(), 64);
        }
    }

    #[test]
    fn test_dl_pack_version_layout() {
        // DLPackVersion is two u32 fields, no padding.
        assert_eq!(size_of::<DLPackVersion>(), 8);
        assert_eq!(align_of::<DLPackVersion>(), 4);
        assert_eq!(offset_of!(DLPackVersion, major), 0);
        assert_eq!(offset_of!(DLPackVersion, minor), 4);
    }

    #[test]
    fn test_dl_managed_tensor_versioned_layout() {
        // DLPack 1.0 field order: version, manager_ctx, deleter, flags, dl_tensor.
        // Assert the full strict ordering (portable across pointer widths) so any
        // accidental field reordering is caught, not just version-first / tensor-last.
        let offsets = [
            offset_of!(DLManagedTensorVersioned, version),
            offset_of!(DLManagedTensorVersioned, manager_ctx),
            offset_of!(DLManagedTensorVersioned, deleter),
            offset_of!(DLManagedTensorVersioned, flags),
            offset_of!(DLManagedTensorVersioned, dl_tensor),
        ];
        assert_eq!(offsets[0], 0);
        assert!(strictly_increasing(&offsets), "offsets: {offsets:?}");

        // The versioned deleter must also be a nullable, pointer-sized slot.
        assert_eq!(
            size_of::<Option<DLManagedTensorVersionedDeleter>>(),
            size_of::<DLManagedTensorVersionedDeleter>()
        );

        // The versioned struct embeds a full DLTensor plus header fields.
        assert!(size_of::<DLManagedTensorVersioned>() > size_of::<DLTensor>());

        if cfg!(target_pointer_width = "64") {
            // version(8) + manager_ctx(8) + deleter(8) + flags(8) + DLTensor(48).
            assert_eq!(offsets, [0, 8, 16, 24, 32]);
            assert_eq!(size_of::<DLManagedTensorVersioned>(), 80);
        }
    }

    #[test]
    fn test_read_only_flag_value() {
        assert_eq!(DLPACK_FLAG_BITMASK_READ_ONLY, 1);
        assert_eq!(DLPACK_FLAG_BITMASK_IS_COPIED, 2);
        assert_eq!(DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED, 4);
        assert_eq!(DLPACK_MAJOR_VERSION, 1);
        assert_eq!(DLPACK_MINOR_VERSION, 0);
    }

    #[test]
    fn test_dl_device_repr_c() {
        // Verify the struct has expected alignment and field placement for FFI
        assert_eq!(align_of::<DLDevice>(), 4);
        assert_eq!(size_of::<DLDevice>(), 8);
        assert_eq!(offset_of!(DLDevice, device_type), 0);
        assert_eq!(offset_of!(DLDevice, device_id), 4);
    }

    #[test]
    fn test_dl_data_type_repr_c() {
        // Verify the struct has expected size and field placement for FFI
        assert_eq!(size_of::<DLDataType>(), 4);
        assert_eq!(align_of::<DLDataType>(), 2);
        assert_eq!(offset_of!(DLDataType, code), 0);
        assert_eq!(offset_of!(DLDataType, bits), 1);
        assert_eq!(offset_of!(DLDataType, lanes), 2);
    }
}
981}