Skip to main content

pyo3_dlpack/
ffi.rs

1//! DLPack FFI types following the DLPack specification.
2//!
3//! These types match the C definitions from the DLPack header:
4//! <https://github.com/dmlc/dlpack/blob/main/include/dlpack/dlpack.h>
5
6use std::ffi::c_void;
7
/// Device type codes as defined by DLPack.
///
/// These correspond to `DLDeviceType` in the DLPack specification. Each
/// discriminant is the exact integer code used by the C ABI, so a variant can
/// be converted with `as u32` when filling a `DLDevice`.
#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DLDeviceType {
    /// CPU device
    Cpu = 1,
    /// CUDA GPU device
    Cuda = 2,
    /// Pinned CUDA CPU memory (allocated with cudaMallocHost)
    CudaHost = 3,
    /// OpenCL device
    OpenCL = 4,
    /// Vulkan device
    Vulkan = 7,
    /// Metal device (Apple)
    Metal = 8,
    /// Verilog simulator buffer (`kDLVPI`)
    Vpi = 9,
    /// ROCm device (AMD)
    Rocm = 10,
    /// Pinned ROCm CPU memory (allocated with hipMallocHost)
    RocmHost = 11,
    /// Reserved extension device type (`kDLExtDev`), intended for quickly
    /// testing extension devices; its semantics are implementation-defined.
    ExtDev = 12,
    /// CUDA managed/unified memory (allocated with cudaMallocManaged)
    CudaManaged = 13,
    /// Unified shared memory on an Intel oneAPI device
    OneApi = 14,
    /// WebGPU device
    WebGpu = 15,
    /// Qualcomm Hexagon DSP device
    Hexagon = 16,
    /// Microsoft MAIA accelerator
    Maia = 17,
}

impl DLDeviceType {
    /// Convert from a raw `DLDeviceType` code.
    ///
    /// Returns `None` for codes with no variant here (for example 0, 5, 6, or
    /// anything above 17). Decoding through this function, rather than
    /// transmuting, means an unexpected code from a foreign tensor can never
    /// produce an invalid enum value.
    ///
    /// This is a `const fn`, so it can also be evaluated in `const` items.
    pub const fn from_raw(value: u32) -> Option<Self> {
        match value {
            1 => Some(Self::Cpu),
            2 => Some(Self::Cuda),
            3 => Some(Self::CudaHost),
            4 => Some(Self::OpenCL),
            7 => Some(Self::Vulkan),
            8 => Some(Self::Metal),
            9 => Some(Self::Vpi),
            10 => Some(Self::Rocm),
            11 => Some(Self::RocmHost),
            12 => Some(Self::ExtDev),
            13 => Some(Self::CudaManaged),
            14 => Some(Self::OneApi),
            15 => Some(Self::WebGpu),
            16 => Some(Self::Hexagon),
            17 => Some(Self::Maia),
            _ => None,
        }
    }
}
71
72/// A device descriptor specifying where tensor data resides.
73///
74/// This corresponds to `DLDevice` in the DLPack specification.
75#[repr(C)]
76#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
77pub struct DLDevice {
78    /// The device type (CPU, CUDA, etc.)
79    pub device_type: u32,
80    /// The device ID (e.g., which GPU for multi-GPU systems)
81    pub device_id: i32,
82}
83
84impl DLDevice {
85    /// Create a new device descriptor.
86    pub fn new(device_type: DLDeviceType, device_id: i32) -> Self {
87        Self {
88            device_type: device_type as u32,
89            device_id,
90        }
91    }
92
93    /// Get the device type as an enum.
94    ///
95    /// Returns `None` for unknown device types.
96    pub fn device_type_enum(&self) -> Option<DLDeviceType> {
97        DLDeviceType::from_raw(self.device_type)
98    }
99
100    /// Check if this is a CUDA device.
101    pub fn is_cuda(&self) -> bool {
102        self.device_type == DLDeviceType::Cuda as u32
103    }
104
105    /// Check if this is a CPU device.
106    pub fn is_cpu(&self) -> bool {
107        self.device_type == DLDeviceType::Cpu as u32
108    }
109
110    /// Check if this is CUDA host (pinned) memory.
111    pub fn is_cuda_host(&self) -> bool {
112        self.device_type == DLDeviceType::CudaHost as u32
113    }
114
115    /// Check if this is a ROCm device.
116    pub fn is_rocm(&self) -> bool {
117        self.device_type == DLDeviceType::Rocm as u32
118    }
119
120    /// Check if this is a Metal device (Apple GPU).
121    pub fn is_metal(&self) -> bool {
122        self.device_type == DLDeviceType::Metal as u32
123    }
124}
125
/// Element type category codes from DLPack.
///
/// Mirrors `DLDataTypeCode` in the DLPack specification. Each discriminant is
/// the raw `u8` code stored in `DLDataType::code`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DLDataTypeCode {
    /// Signed integer
    Int = 0,
    /// Unsigned integer
    UInt = 1,
    /// IEEE floating point
    Float = 2,
    /// Opaque handle type (not for computation)
    OpaqueHandle = 3,
    /// Bfloat16 (Brain Floating Point)
    Bfloat = 4,
    /// Complex numbers
    Complex = 5,
    /// Boolean
    Bool = 6,
}

impl DLDataTypeCode {
    /// Every variant, stored at the index equal to its raw code.
    ///
    /// The codes are contiguous from 0, so a raw value can index this table
    /// directly.
    const BY_CODE: [DLDataTypeCode; 7] = [
        Self::Int,
        Self::UInt,
        Self::Float,
        Self::OpaqueHandle,
        Self::Bfloat,
        Self::Complex,
        Self::Bool,
    ];

    /// Decode a raw `u8` type code.
    ///
    /// Any code past the end of the known range yields `None`.
    pub fn from_raw(value: u8) -> Option<Self> {
        Self::BY_CODE.get(usize::from(value)).copied()
    }
}
165
166/// Data type descriptor specifying the element type of a tensor.
167///
168/// This corresponds to `DLDataType` in the DLPack specification.
169#[repr(C)]
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
171pub struct DLDataType {
172    /// Type code (signed int, unsigned int, float, etc.)
173    pub code: u8,
174    /// Number of bits per element (e.g., 32 for float32)
175    pub bits: u8,
176    /// Number of lanes for vectorized types (usually 1)
177    pub lanes: u16,
178}
179
180impl DLDataType {
181    /// Create a new data type descriptor.
182    pub fn new(code: DLDataTypeCode, bits: u8, lanes: u16) -> Self {
183        Self {
184            code: code as u8,
185            bits,
186            lanes,
187        }
188    }
189
190    /// Get the type code as an enum.
191    ///
192    /// Returns `None` for unknown type codes.
193    pub fn code_enum(&self) -> Option<DLDataTypeCode> {
194        DLDataTypeCode::from_raw(self.code)
195    }
196
197    /// Check if this is f16 (half precision float).
198    pub fn is_f16(&self) -> bool {
199        self.code == DLDataTypeCode::Float as u8 && self.bits == 16 && self.lanes == 1
200    }
201
202    /// Check if this is f32 (single precision float).
203    pub fn is_f32(&self) -> bool {
204        self.code == DLDataTypeCode::Float as u8 && self.bits == 32 && self.lanes == 1
205    }
206
207    /// Check if this is f64 (double precision float).
208    pub fn is_f64(&self) -> bool {
209        self.code == DLDataTypeCode::Float as u8 && self.bits == 64 && self.lanes == 1
210    }
211
212    /// Check if this is bf16 (bfloat16).
213    pub fn is_bf16(&self) -> bool {
214        self.code == DLDataTypeCode::Bfloat as u8 && self.bits == 16 && self.lanes == 1
215    }
216
217    /// Check if this is i8 (signed 8-bit integer).
218    pub fn is_i8(&self) -> bool {
219        self.code == DLDataTypeCode::Int as u8 && self.bits == 8 && self.lanes == 1
220    }
221
222    /// Check if this is i16 (signed 16-bit integer).
223    pub fn is_i16(&self) -> bool {
224        self.code == DLDataTypeCode::Int as u8 && self.bits == 16 && self.lanes == 1
225    }
226
227    /// Check if this is i32 (signed 32-bit integer).
228    pub fn is_i32(&self) -> bool {
229        self.code == DLDataTypeCode::Int as u8 && self.bits == 32 && self.lanes == 1
230    }
231
232    /// Check if this is i64 (signed 64-bit integer).
233    pub fn is_i64(&self) -> bool {
234        self.code == DLDataTypeCode::Int as u8 && self.bits == 64 && self.lanes == 1
235    }
236
237    /// Check if this is u8 (unsigned 8-bit integer).
238    pub fn is_u8(&self) -> bool {
239        self.code == DLDataTypeCode::UInt as u8 && self.bits == 8 && self.lanes == 1
240    }
241
242    /// Check if this is u16 (unsigned 16-bit integer).
243    pub fn is_u16(&self) -> bool {
244        self.code == DLDataTypeCode::UInt as u8 && self.bits == 16 && self.lanes == 1
245    }
246
247    /// Check if this is u32 (unsigned 32-bit integer).
248    pub fn is_u32(&self) -> bool {
249        self.code == DLDataTypeCode::UInt as u8 && self.bits == 32 && self.lanes == 1
250    }
251
252    /// Check if this is u64 (unsigned 64-bit integer).
253    pub fn is_u64(&self) -> bool {
254        self.code == DLDataTypeCode::UInt as u8 && self.bits == 64 && self.lanes == 1
255    }
256
257    /// Check if this is bool.
258    pub fn is_bool(&self) -> bool {
259        self.code == DLDataTypeCode::Bool as u8 && self.bits == 8 && self.lanes == 1
260    }
261
262    /// Get the size of one element in bytes.
263    pub fn itemsize(&self) -> usize {
264        ((self.bits as usize) * (self.lanes as usize) + 7) / 8
265    }
266}
267
/// The core DLTensor structure describing a tensor's data and layout.
///
/// This corresponds to `DLTensor` in the DLPack specification. The struct is
/// `#[repr(C)]` and shared across the FFI boundary, so the order and types of
/// its fields must match the C header exactly; do not reorder them.
///
/// It holds only raw pointers. Releasing the memory they point to is the job
/// of the deleter carried by `DLManagedTensor`, not of this struct.
#[repr(C)]
pub struct DLTensor {
    /// Pointer to the data buffer.
    /// For GPU tensors, this is a device pointer and must not be dereferenced
    /// on the host.
    pub data: *mut c_void,
    /// Device descriptor specifying where the data resides.
    pub device: DLDevice,
    /// Number of dimensions (the length of `shape`, and of `strides` when it
    /// is non-null).
    pub ndim: i32,
    /// Data type descriptor.
    pub dtype: DLDataType,
    /// Shape array (length = ndim).
    /// Points to an array of dimension sizes.
    pub shape: *mut i64,
    /// Stride array in number of elements, not bytes (length = ndim).
    /// Can be null for compact row-major tensors.
    pub strides: *mut i64,
    /// Byte offset from the data pointer to the first element, i.e. the first
    /// element lives at `(data as *mut u8).add(byte_offset)`.
    pub byte_offset: u64,
}
291
/// Deleter function signature for DLManagedTensor.
///
/// Called when the consumer is done with the tensor to free resources. Per
/// the DLPack header, the argument is the managed tensor that carries this
/// deleter.
pub type DLManagedTensorDeleter = unsafe extern "C" fn(*mut DLManagedTensor);

/// A managed tensor with ownership semantics.
///
/// This corresponds to `DLManagedTensor` in the DLPack specification.
/// It wraps a `DLTensor` and provides a deleter for cleanup. Like `DLTensor`
/// it is `#[repr(C)]`, so the field order must match the C header.
#[repr(C)]
pub struct DLManagedTensor {
    /// The underlying tensor descriptor.
    pub dl_tensor: DLTensor,
    /// Opaque manager context for the producer's use.
    /// Typically used to store data needed by the deleter.
    pub manager_ctx: *mut c_void,
    /// Deleter function called when the consumer is done.
    /// Can be null if no cleanup is needed. `Option` of a function pointer
    /// has the same size as the pointer itself (null-pointer niche), so
    /// `None` is exactly a C `NULL` deleter.
    pub deleter: Option<DLManagedTensorDeleter>,
}
312
313// ============================================================================
314// Convenience constructors
315// ============================================================================
316
317/// Create a DLDevice for CUDA with the specified device ID.
318pub fn cuda_device(device_id: i32) -> DLDevice {
319    DLDevice::new(DLDeviceType::Cuda, device_id)
320}
321
322/// Create a DLDevice for CPU.
323pub fn cpu_device() -> DLDevice {
324    DLDevice::new(DLDeviceType::Cpu, 0)
325}
326
327/// Create a DLDevice for Metal (Apple GPU) with the specified device ID.
328pub fn metal_device(device_id: i32) -> DLDevice {
329    DLDevice::new(DLDeviceType::Metal, device_id)
330}
331
332/// Create a DLDataType for f32 (single precision float).
333pub fn dtype_f32() -> DLDataType {
334    DLDataType::new(DLDataTypeCode::Float, 32, 1)
335}
336
337/// Create a DLDataType for f64 (double precision float).
338pub fn dtype_f64() -> DLDataType {
339    DLDataType::new(DLDataTypeCode::Float, 64, 1)
340}
341
342/// Create a DLDataType for f16 (half precision float).
343pub fn dtype_f16() -> DLDataType {
344    DLDataType::new(DLDataTypeCode::Float, 16, 1)
345}
346
347/// Create a DLDataType for bf16 (bfloat16).
348pub fn dtype_bf16() -> DLDataType {
349    DLDataType::new(DLDataTypeCode::Bfloat, 16, 1)
350}
351
352/// Create a DLDataType for i8 (signed 8-bit integer).
353pub fn dtype_i8() -> DLDataType {
354    DLDataType::new(DLDataTypeCode::Int, 8, 1)
355}
356
357/// Create a DLDataType for i16 (signed 16-bit integer).
358pub fn dtype_i16() -> DLDataType {
359    DLDataType::new(DLDataTypeCode::Int, 16, 1)
360}
361
362/// Create a DLDataType for i32 (signed 32-bit integer).
363pub fn dtype_i32() -> DLDataType {
364    DLDataType::new(DLDataTypeCode::Int, 32, 1)
365}
366
367/// Create a DLDataType for i64 (signed 64-bit integer).
368pub fn dtype_i64() -> DLDataType {
369    DLDataType::new(DLDataTypeCode::Int, 64, 1)
370}
371
372/// Create a DLDataType for u8 (unsigned 8-bit integer).
373pub fn dtype_u8() -> DLDataType {
374    DLDataType::new(DLDataTypeCode::UInt, 8, 1)
375}
376
377/// Create a DLDataType for u16 (unsigned 16-bit integer).
378pub fn dtype_u16() -> DLDataType {
379    DLDataType::new(DLDataTypeCode::UInt, 16, 1)
380}
381
382/// Create a DLDataType for u32 (unsigned 32-bit integer).
383pub fn dtype_u32() -> DLDataType {
384    DLDataType::new(DLDataTypeCode::UInt, 32, 1)
385}
386
387/// Create a DLDataType for u64 (unsigned 64-bit integer).
388pub fn dtype_u64() -> DLDataType {
389    DLDataType::new(DLDataTypeCode::UInt, 64, 1)
390}
391
392/// Create a DLDataType for bool.
393pub fn dtype_bool() -> DLDataType {
394    DLDataType::new(DLDataTypeCode::Bool, 8, 1)
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    // ========================================================================
    // DLDeviceType tests
    // ========================================================================

    #[test]
    fn test_device_type_from_raw_all_variants() {
        assert_eq!(DLDeviceType::from_raw(1), Some(DLDeviceType::Cpu));
        assert_eq!(DLDeviceType::from_raw(2), Some(DLDeviceType::Cuda));
        assert_eq!(DLDeviceType::from_raw(3), Some(DLDeviceType::CudaHost));
        assert_eq!(DLDeviceType::from_raw(4), Some(DLDeviceType::OpenCL));
        assert_eq!(DLDeviceType::from_raw(7), Some(DLDeviceType::Vulkan));
        assert_eq!(DLDeviceType::from_raw(8), Some(DLDeviceType::Metal));
        assert_eq!(DLDeviceType::from_raw(9), Some(DLDeviceType::Vpi));
        assert_eq!(DLDeviceType::from_raw(10), Some(DLDeviceType::Rocm));
        assert_eq!(DLDeviceType::from_raw(11), Some(DLDeviceType::RocmHost));
        assert_eq!(DLDeviceType::from_raw(12), Some(DLDeviceType::ExtDev));
        assert_eq!(DLDeviceType::from_raw(13), Some(DLDeviceType::CudaManaged));
        assert_eq!(DLDeviceType::from_raw(14), Some(DLDeviceType::OneApi));
        assert_eq!(DLDeviceType::from_raw(15), Some(DLDeviceType::WebGpu));
        assert_eq!(DLDeviceType::from_raw(16), Some(DLDeviceType::Hexagon));
        assert_eq!(DLDeviceType::from_raw(17), Some(DLDeviceType::Maia));
    }

    #[test]
    fn test_device_type_from_raw_unknown() {
        assert_eq!(DLDeviceType::from_raw(0), None);
        assert_eq!(DLDeviceType::from_raw(5), None);
        assert_eq!(DLDeviceType::from_raw(6), None);
        assert_eq!(DLDeviceType::from_raw(18), None);
        assert_eq!(DLDeviceType::from_raw(100), None);
        assert_eq!(DLDeviceType::from_raw(u32::MAX), None);
    }

    #[test]
    fn test_device_type_debug() {
        assert_eq!(format!("{:?}", DLDeviceType::Cpu), "Cpu");
        assert_eq!(format!("{:?}", DLDeviceType::Cuda), "Cuda");
    }

    #[test]
    fn test_device_type_clone_copy() {
        let dt = DLDeviceType::Cuda;
        let dt2 = dt;
        let dt3 = dt;
        assert_eq!(dt, dt2);
        assert_eq!(dt, dt3);
    }

    #[test]
    fn test_device_type_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(DLDeviceType::Cpu);
        set.insert(DLDeviceType::Cuda);
        set.insert(DLDeviceType::Cpu);
        assert_eq!(set.len(), 2);
    }

    // ========================================================================
    // DLDevice tests
    // ========================================================================

    #[test]
    fn test_device_new() {
        let dev = DLDevice::new(DLDeviceType::Cuda, 3);
        assert_eq!(dev.device_type, 2);
        assert_eq!(dev.device_id, 3);
    }

    #[test]
    fn test_device_type_enum() {
        let dev = DLDevice::new(DLDeviceType::Rocm, 1);
        assert_eq!(dev.device_type_enum(), Some(DLDeviceType::Rocm));

        let unknown = DLDevice {
            device_type: 99,
            device_id: 0,
        };
        assert_eq!(unknown.device_type_enum(), None);
    }

    #[test]
    fn test_device_is_cuda() {
        assert!(cuda_device(0).is_cuda());
        assert!(!cpu_device().is_cuda());
        assert!(!DLDevice::new(DLDeviceType::CudaHost, 0).is_cuda());
    }

    #[test]
    fn test_device_is_cpu() {
        assert!(cpu_device().is_cpu());
        assert!(!cuda_device(0).is_cpu());
    }

    #[test]
    fn test_device_is_cuda_host() {
        assert!(DLDevice::new(DLDeviceType::CudaHost, 0).is_cuda_host());
        assert!(!cpu_device().is_cuda_host());
        assert!(!cuda_device(0).is_cuda_host());
    }

    #[test]
    fn test_device_is_rocm() {
        assert!(DLDevice::new(DLDeviceType::Rocm, 0).is_rocm());
        assert!(!cpu_device().is_rocm());
        assert!(!cuda_device(0).is_rocm());
    }

    #[test]
    fn test_device_is_metal() {
        assert!(DLDevice::new(DLDeviceType::Metal, 0).is_metal());
        assert!(metal_device(0).is_metal());
        assert!(!cpu_device().is_metal());
        assert!(!cuda_device(0).is_metal());
    }

    #[test]
    fn test_device_debug() {
        let dev = cuda_device(2);
        let debug = format!("{:?}", dev);
        assert!(debug.contains("device_type"));
        assert!(debug.contains("device_id"));
    }

    #[test]
    fn test_device_clone_copy() {
        let dev = cuda_device(1);
        let dev2 = dev;
        let dev3 = dev;
        assert_eq!(dev, dev2);
        assert_eq!(dev, dev3);
    }

    #[test]
    fn test_device_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(cpu_device());
        set.insert(cuda_device(0));
        set.insert(cuda_device(1));
        set.insert(cpu_device());
        assert_eq!(set.len(), 3);
    }

    // ========================================================================
    // DLDataTypeCode tests
    // ========================================================================

    #[test]
    fn test_dtype_code_from_raw_all_variants() {
        assert_eq!(DLDataTypeCode::from_raw(0), Some(DLDataTypeCode::Int));
        assert_eq!(DLDataTypeCode::from_raw(1), Some(DLDataTypeCode::UInt));
        assert_eq!(DLDataTypeCode::from_raw(2), Some(DLDataTypeCode::Float));
        assert_eq!(
            DLDataTypeCode::from_raw(3),
            Some(DLDataTypeCode::OpaqueHandle)
        );
        assert_eq!(DLDataTypeCode::from_raw(4), Some(DLDataTypeCode::Bfloat));
        assert_eq!(DLDataTypeCode::from_raw(5), Some(DLDataTypeCode::Complex));
        assert_eq!(DLDataTypeCode::from_raw(6), Some(DLDataTypeCode::Bool));
    }

    #[test]
    fn test_dtype_code_from_raw_unknown() {
        assert_eq!(DLDataTypeCode::from_raw(7), None);
        assert_eq!(DLDataTypeCode::from_raw(100), None);
        assert_eq!(DLDataTypeCode::from_raw(u8::MAX), None);
    }

    #[test]
    fn test_dtype_code_debug() {
        assert_eq!(format!("{:?}", DLDataTypeCode::Float), "Float");
        assert_eq!(format!("{:?}", DLDataTypeCode::Int), "Int");
    }

    #[test]
    fn test_dtype_code_clone_copy() {
        let code = DLDataTypeCode::Float;
        let code2 = code;
        let code3 = code;
        assert_eq!(code, code2);
        assert_eq!(code, code3);
    }

    #[test]
    fn test_dtype_code_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(DLDataTypeCode::Float);
        set.insert(DLDataTypeCode::Int);
        set.insert(DLDataTypeCode::Float);
        assert_eq!(set.len(), 2);
    }

    // ========================================================================
    // DLDataType tests
    // ========================================================================

    #[test]
    fn test_dtype_new() {
        let dt = DLDataType::new(DLDataTypeCode::Float, 32, 1);
        assert_eq!(dt.code, 2);
        assert_eq!(dt.bits, 32);
        assert_eq!(dt.lanes, 1);
    }

    #[test]
    fn test_dtype_code_enum() {
        let dt = dtype_f32();
        assert_eq!(dt.code_enum(), Some(DLDataTypeCode::Float));

        let unknown = DLDataType {
            code: 99,
            bits: 32,
            lanes: 1,
        };
        assert_eq!(unknown.code_enum(), None);
    }

    #[test]
    fn test_dtype_is_f16() {
        assert!(dtype_f16().is_f16());
        assert!(!dtype_f32().is_f16());
        assert!(!dtype_bf16().is_f16());
        // Wrong lanes
        let wrong = DLDataType::new(DLDataTypeCode::Float, 16, 2);
        assert!(!wrong.is_f16());
    }

    #[test]
    fn test_dtype_is_f32() {
        assert!(dtype_f32().is_f32());
        assert!(!dtype_f64().is_f32());
        assert!(!dtype_f16().is_f32());
    }

    #[test]
    fn test_dtype_is_f64() {
        assert!(dtype_f64().is_f64());
        assert!(!dtype_f32().is_f64());
    }

    #[test]
    fn test_dtype_is_bf16() {
        assert!(dtype_bf16().is_bf16());
        assert!(!dtype_f16().is_bf16());
        assert!(!dtype_f32().is_bf16());
    }

    #[test]
    fn test_dtype_is_i8() {
        assert!(dtype_i8().is_i8());
        assert!(!dtype_i16().is_i8());
        assert!(!dtype_u8().is_i8());
    }

    #[test]
    fn test_dtype_is_i16() {
        assert!(dtype_i16().is_i16());
        assert!(!dtype_i8().is_i16());
        assert!(!dtype_i32().is_i16());
    }

    #[test]
    fn test_dtype_is_i32() {
        assert!(dtype_i32().is_i32());
        assert!(!dtype_i64().is_i32());
        assert!(!dtype_u32().is_i32());
    }

    #[test]
    fn test_dtype_is_i64() {
        assert!(dtype_i64().is_i64());
        assert!(!dtype_i32().is_i64());
    }

    #[test]
    fn test_dtype_is_u8() {
        assert!(dtype_u8().is_u8());
        assert!(!dtype_i8().is_u8());
        assert!(!dtype_u16().is_u8());
    }

    #[test]
    fn test_dtype_is_u16() {
        assert!(dtype_u16().is_u16());
        assert!(!dtype_u8().is_u16());
    }

    #[test]
    fn test_dtype_is_u32() {
        assert!(dtype_u32().is_u32());
        assert!(!dtype_i32().is_u32());
    }

    #[test]
    fn test_dtype_is_u64() {
        assert!(dtype_u64().is_u64());
        assert!(!dtype_u32().is_u64());
    }

    #[test]
    fn test_dtype_is_bool() {
        assert!(dtype_bool().is_bool());
        assert!(!dtype_u8().is_bool());
        assert!(!dtype_i8().is_bool());
    }

    #[test]
    fn test_dtype_itemsize() {
        assert_eq!(dtype_f16().itemsize(), 2);
        assert_eq!(dtype_f32().itemsize(), 4);
        assert_eq!(dtype_f64().itemsize(), 8);
        assert_eq!(dtype_bf16().itemsize(), 2);
        assert_eq!(dtype_i8().itemsize(), 1);
        assert_eq!(dtype_i16().itemsize(), 2);
        assert_eq!(dtype_i32().itemsize(), 4);
        assert_eq!(dtype_i64().itemsize(), 8);
        assert_eq!(dtype_u8().itemsize(), 1);
        assert_eq!(dtype_u16().itemsize(), 2);
        assert_eq!(dtype_u32().itemsize(), 4);
        assert_eq!(dtype_u64().itemsize(), 8);
        assert_eq!(dtype_bool().itemsize(), 1);
    }

    #[test]
    fn test_dtype_itemsize_vectorized() {
        // Vectorized type with 4 lanes of f32
        let vec_f32 = DLDataType::new(DLDataTypeCode::Float, 32, 4);
        assert_eq!(vec_f32.itemsize(), 16); // 4 * 4 bytes

        // 8 lanes of i16
        let vec_i16 = DLDataType::new(DLDataTypeCode::Int, 16, 8);
        assert_eq!(vec_i16.itemsize(), 16); // 8 * 2 bytes
    }

    #[test]
    fn test_dtype_itemsize_rounding() {
        // Test rounding up for non-byte-aligned types
        let one_bit = DLDataType {
            code: 0,
            bits: 1,
            lanes: 1,
        };
        assert_eq!(one_bit.itemsize(), 1);

        let seven_bits = DLDataType {
            code: 0,
            bits: 7,
            lanes: 1,
        };
        assert_eq!(seven_bits.itemsize(), 1);

        let nine_bits = DLDataType {
            code: 0,
            bits: 9,
            lanes: 1,
        };
        assert_eq!(nine_bits.itemsize(), 2);
    }

    #[test]
    fn test_dtype_debug() {
        let dt = dtype_f32();
        let debug = format!("{:?}", dt);
        assert!(debug.contains("code"));
        assert!(debug.contains("bits"));
        assert!(debug.contains("lanes"));
    }

    #[test]
    fn test_dtype_clone_copy() {
        let dt = dtype_f32();
        let dt2 = dt;
        let dt3 = dt;
        assert_eq!(dt, dt2);
        assert_eq!(dt, dt3);
    }

    #[test]
    fn test_dtype_hash() {
        use std::collections::HashSet;
        let mut set = HashSet::new();
        set.insert(dtype_f32());
        set.insert(dtype_f64());
        set.insert(dtype_f32());
        assert_eq!(set.len(), 2);
    }

    // ========================================================================
    // Convenience constructor tests
    // ========================================================================

    #[test]
    fn test_cuda_device() {
        let dev = cuda_device(0);
        assert!(dev.is_cuda());
        assert_eq!(dev.device_id, 0);

        let dev1 = cuda_device(1);
        assert!(dev1.is_cuda());
        assert_eq!(dev1.device_id, 1);
    }

    #[test]
    fn test_cpu_device() {
        let dev = cpu_device();
        assert!(dev.is_cpu());
        assert_eq!(dev.device_id, 0);
    }

    #[test]
    fn test_metal_device() {
        let dev = metal_device(0);
        assert!(dev.is_metal());
        assert_eq!(dev.device_id, 0);

        let dev1 = metal_device(1);
        assert!(dev1.is_metal());
        assert_eq!(dev1.device_id, 1);
    }

    #[test]
    fn test_all_dtype_constructors() {
        // Float types
        assert!(dtype_f16().is_f16());
        assert!(dtype_f32().is_f32());
        assert!(dtype_f64().is_f64());
        assert!(dtype_bf16().is_bf16());

        // Signed integer types
        assert!(dtype_i8().is_i8());
        assert!(dtype_i16().is_i16());
        assert!(dtype_i32().is_i32());
        assert!(dtype_i64().is_i64());

        // Unsigned integer types
        assert!(dtype_u8().is_u8());
        assert!(dtype_u16().is_u16());
        assert!(dtype_u32().is_u32());
        assert!(dtype_u64().is_u64());

        // Boolean
        assert!(dtype_bool().is_bool());
    }

    // ========================================================================
    // DLTensor and DLManagedTensor struct layout tests
    // ========================================================================

    #[test]
    fn test_dl_tensor_size() {
        use std::mem::{align_of, size_of};
        // DLTensor crosses the FFI boundary, so its layout must match dlpack.h.
        let size = size_of::<DLTensor>();
        assert!(size > 0);
        // On 64-bit targets: data(8) + device(8) + ndim(4) + dtype(4)
        // + shape(8) + strides(8) + byte_offset(8) = 48 bytes, no padding.
        if cfg!(target_pointer_width = "64") {
            assert_eq!(size, 48);
            assert_eq!(align_of::<DLTensor>(), 8);
        }
    }

    #[test]
    fn test_dl_managed_tensor_size() {
        use std::mem::size_of;
        // `Option` of a function pointer uses the null-pointer niche: it takes
        // no extra space, and `None` is represented exactly like C `NULL`.
        assert_eq!(
            size_of::<Option<DLManagedTensorDeleter>>(),
            size_of::<DLManagedTensorDeleter>()
        );
        let size = size_of::<DLManagedTensor>();
        assert!(size > size_of::<DLTensor>());
        // On 64-bit targets: DLTensor(48) + manager_ctx(8) + deleter(8) = 64.
        if cfg!(target_pointer_width = "64") {
            assert_eq!(size, 64);
            assert_eq!(
                size,
                size_of::<DLTensor>()
                    + size_of::<*mut std::ffi::c_void>()
                    + size_of::<Option<DLManagedTensorDeleter>>()
            );
        }
    }

    #[test]
    fn test_dl_device_repr_c() {
        // Matches C `{ DLDeviceType device_type; int32_t device_id; }`.
        assert_eq!(std::mem::align_of::<DLDevice>(), 4);
        assert_eq!(std::mem::size_of::<DLDevice>(), 8);
    }

    #[test]
    fn test_dl_data_type_repr_c() {
        // Matches C `{ uint8_t code; uint8_t bits; uint16_t lanes; }`.
        assert_eq!(std::mem::size_of::<DLDataType>(), 4);
        assert_eq!(std::mem::align_of::<DLDataType>(), 2);
    }
}