scirs2_numpy/dlpack.rs
//! DLPack protocol for zero-copy tensor exchange.
//!
//! Implements `__dlpack__` and `__dlpack_device__` for arrays managed by this crate.
//! DLPack is an open-standard ABI used by PyTorch, JAX, TensorFlow, and other
//! frameworks to exchange tensors without copying.
//!
//! Reference: <https://dmlc.github.io/dlpack/latest/>

use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use std::ffi::c_void;
use std::ffi::CStr;
use std::ptr::NonNull;

/// Device type codes used by DLPack.
///
/// These integer codes identify which physical device (CPU, CUDA, Metal, etc.)
/// holds the tensor data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DLDeviceType {
    /// Host CPU (device type 1).
    Cpu = 1,
    /// CUDA GPU (device type 2).
    Cuda = 2,
    /// CUDA pinned host memory (device type 3).
    CudaHost = 3,
    /// OpenCL device (device type 4).
    OpenCL = 4,
    /// Vulkan device (device type 7).
    Vulkan = 7,
    /// Apple Metal device (device type 8).
    Metal = 8,
    /// AMD ROCm/HIP GPU (device type 10).
    Rocm = 10,
}

/// DLPack data-type descriptor (ABI-compatible with the DLPack spec).
///
/// Encodes element type code, bit-width, and SIMD lane count.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDataType {
    /// Type code: 0 = int, 1 = uint, 2 = float, 4 = bfloat (3 is `kDLOpaqueHandle`).
    pub code: u8,
    /// Number of bits per element (e.g., 32 for f32).
    pub bits: u8,
    /// SIMD lane count; 1 for scalar elements.
    pub lanes: u16,
}

/// DLPack device descriptor.
///
/// Identifies the device and its zero-based index.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDevice {
    /// Device type code (see [`DLDeviceType`]).
    pub device_type: i32,
    /// Zero-based device index (e.g., 0 for the first GPU).
    pub device_id: i32,
}

/// The core DLPack tensor structure (ABI-compatible).
///
/// Describes a multi-dimensional array buffer.
#[derive(Debug)]
#[repr(C)]
pub struct DLTensor {
    /// Opaque pointer to the first element of the tensor.
    pub data: *mut c_void,
    /// Device on which this tensor resides.
    pub device: DLDevice,
    /// Number of dimensions.
    pub ndim: i32,
    /// Element data type.
    pub dtype: DLDataType,
    /// Pointer to an array of `ndim` shape values.
    pub shape: *mut i64,
    /// Pointer to an array of `ndim` stride values (in elements), or null for C-contiguous.
    pub strides: *mut i64,
    /// Byte offset from `data` to the first element.
    pub byte_offset: u64,
}

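// For orientation, a sketch of how these fields describe a 2×3 f64 CPU tensor
// (values illustrative only):
//
//   data        -> pointer to 48 bytes (6 × 8) of f64 storage
//   device      -> DLDevice { device_type: 1 (Cpu), device_id: 0 }
//   ndim        -> 2
//   dtype       -> DLDataType { code: 2 (float), bits: 64, lanes: 1 }
//   shape       -> pointer to [2, 3]
//   strides     -> pointer to [3, 1], or null for C-contiguous layout
//   byte_offset -> 0
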
/// Managed DLPack tensor with associated deleter callback.
///
/// This is the struct handed off via `PyCapsule` under the name `"dltensor"`.
#[repr(C)]
pub struct DLManagedTensor {
    /// The underlying tensor descriptor.
    pub dl_tensor: DLTensor,
    /// Opaque context pointer passed to `deleter`.
    pub manager_ctx: *mut c_void,
    /// Optional destructor; called by the consumer framework when it is done with the tensor.
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}

// SAFETY: The managed tensor is self-contained once constructed; we hold
// the backing data buffer in the capsule's memory and the pointer is valid
// until the capsule is destroyed.
unsafe impl Send for DLManagedTensor {}

// SAFETY: Access to the tensor is read-only after construction; no shared
// mutable state is exposed without synchronisation.
unsafe impl Sync for DLManagedTensor {}

/// Python class wrapping a DLPack-compatible tensor.
///
/// Exposes `__dlpack__` and `__dlpack_device__` so that any DLPack-aware
/// framework (PyTorch, JAX, CuPy, etc.) can consume the tensor without copying.
#[pyclass(name = "DLPackCapsule")]
pub struct DLPackCapsule {
    /// Logical shape of the tensor.
    shape: Vec<i64>,
    /// Row-major strides (in elements).
    strides: Vec<i64>,
    /// Owned backing data buffer (zeroed on construction).
    ///
    /// Kept for future zero-copy implementations where `DLTensor.data` points
    /// directly into this buffer. Currently unused in the test implementation.
    #[allow(dead_code)]
    data: Vec<u8>,
    /// Element type descriptor.
    dtype: DLDataType,
    /// Device descriptor (always CPU for capsules created from Rust).
    device: DLDevice,
}

#[pymethods]
impl DLPackCapsule {
    /// Create a new zero-filled DLPack capsule.
    ///
    /// # Arguments
    /// * `shape` – tensor dimensions
    /// * `dtype_code` – element type code (0=int, 1=uint, 2=float, 4=bfloat)
    /// * `dtype_bits` – element bit-width (e.g. 32 or 64)
    #[new]
    pub fn new(shape: Vec<i64>, dtype_code: u8, dtype_bits: u8) -> Self {
        let n: i64 = shape.iter().product();
        // Round the bit-width up to whole bytes (e.g. 64 bits -> 8 bytes);
        // `.max(1)` guards against a zero bit-width.
        let bytes_per_elem = (dtype_bits as usize).div_ceil(8).max(1);
        let n_bytes = (n as usize) * bytes_per_elem;
        let strides = compute_row_major_strides(&shape);
        Self {
            shape,
            strides,
            data: vec![0u8; n_bytes],
            dtype: DLDataType {
                code: dtype_code,
                bits: dtype_bits,
                lanes: 1,
            },
            device: DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
        }
    }

    /// Return `(device_type_int, device_id)` — the `__dlpack_device__` protocol.
    #[pyo3(name = "__dlpack_device__")]
    pub fn dlpack_device(&self) -> (i32, i32) {
        (self.device.device_type, self.device.device_id)
    }

    /// Return a Python `PyCapsule` named `"dltensor"` — the `__dlpack__` protocol.
    ///
    /// The capsule contains a `DLManagedTensor` with a destructor that frees the
    /// heap allocation created here.
    ///
    /// # Safety
    ///
    /// The capsule pointer is valid as long as the capsule is live. The `deleter`
    /// registered in `DLManagedTensor` ensures the allocation is freed.
    #[pyo3(name = "__dlpack__")]
    pub fn dlpack<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
        // Allocate shape and strides buffers on the heap so they outlive this call.
        let mut shape_buf = self.shape.clone().into_boxed_slice();
        let mut strides_buf = self.strides.clone().into_boxed_slice();

        // Build the managed tensor. The data pointer is a non-null placeholder
        // (it aliases the shape buffer) because this test implementation does not
        // yet export `self.data`; a real zero-copy implementation would point at
        // that buffer and keep it alive via `manager_ctx`.
        let managed = Box::new(DLManagedTensor {
            dl_tensor: DLTensor {
                data: shape_buf.as_mut_ptr() as *mut c_void, // placeholder; see above
                device: self.device,
                ndim: self.shape.len() as i32,
                dtype: self.dtype,
                shape: shape_buf.as_mut_ptr(),
                strides: strides_buf.as_mut_ptr(),
                byte_offset: 0,
            },
            manager_ctx: std::ptr::null_mut(),
            deleter: Some(dlpack_deleter),
        });

        // Intentionally leak the shape/strides buffers: the deleter only frees
        // the `DLManagedTensor` itself. A production implementation would stash
        // the buffers in `manager_ctx` and free them in the deleter.
        std::mem::forget(shape_buf);
        std::mem::forget(strides_buf);

        let raw_ptr = Box::into_raw(managed);
        let non_null = NonNull::new(raw_ptr as *mut c_void)
            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("null managed tensor ptr"))?;

        // SAFETY: non_null points to a valid DLManagedTensor allocation; the
        // capsule destructor runs the tensor's deleter when the capsule is
        // destroyed unconsumed.
        unsafe {
            PyCapsule::new_with_pointer_and_destructor(
                py,
                non_null,
                DLTENSOR_CAPSULE_NAME,
                Some(capsule_destructor),
            )
        }
    }

    /// Return the shape of this tensor.
    pub fn shape(&self) -> Vec<i64> {
        self.shape.clone()
    }

    /// Return the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Return the dtype type-code (0=int, 1=uint, 2=float, 4=bfloat).
    pub fn dtype_code(&self) -> u8 {
        self.dtype.code
    }

    /// Return the number of bits per element.
    pub fn dtype_bits(&self) -> u8 {
        self.dtype.bits
    }
}

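// A minimal sketch of exercising the protocol methods from Rust (commented out
// because it needs an initialised Python interpreter, e.g. pyo3's
// `auto-initialize` feature, which this file does not configure):
//
// Python::with_gil(|py| {
//     let caps = DLPackCapsule::new(vec![2, 3], 2, 64);
//     assert_eq!(caps.dlpack_device(), (1, 0)); // CPU, device 0
//     let capsule = caps.dlpack(py).expect("capsule creation failed");
//     assert!(!capsule.pointer().is_null());
// });
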
/// The capsule name required by the DLPack ABI.
const DLTENSOR_CAPSULE_NAME: &CStr = c"dltensor";

/// Destructor called by Python's capsule machinery when the capsule is collected.
///
/// Runs the tensor's own `deleter` unless a consumer has already taken ownership.
/// Per the DLPack protocol, a consumer renames the capsule to `"used_dltensor"`
/// after consuming it; in that case the consumer is responsible for calling the
/// deleter and this destructor must do nothing.
///
/// # Safety
///
/// `capsule` must be a valid `PyCapsule` whose pointer was set to a `DLManagedTensor`
/// heap allocation created via `Box::into_raw`.
unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
    // If the consumer renamed the capsule, ownership was transferred; skip.
    if unsafe { pyo3::ffi::PyCapsule_IsValid(capsule, c"used_dltensor".as_ptr()) } == 1 {
        return;
    }
    // SAFETY: The capsule was created with a DLManagedTensor raw pointer stored
    // under the name "dltensor"; retrieve that pointer back.
    let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_CAPSULE_NAME.as_ptr()) };
    if ptr.is_null() {
        // `PyCapsule_GetPointer` failed and raised; clear the exception, since a
        // destructor must not leave an error set.
        unsafe { pyo3::ffi::PyErr_Clear() };
        return;
    }
    let managed_ptr = ptr as *mut DLManagedTensor;
    // Call the tensor's own deleter if provided.
    if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
        unsafe { deleter(managed_ptr) };
    }
}

/// Deleter stored inside `DLManagedTensor.deleter`.
///
/// Frees the `DLManagedTensor` allocation itself. Note that the shape/strides
/// buffers leaked by `DLPackCapsule::dlpack` are not reclaimed here.
///
/// # Safety
///
/// `managed` must be a valid heap-allocated `DLManagedTensor` created by `Box::into_raw`.
unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
    if !managed.is_null() {
        // SAFETY: `managed` was created by `Box::into_raw(Box::new(...))`.
        let _ = unsafe { Box::from_raw(managed) };
    }
}

/// Compute C-order (row-major) strides for a given shape.
///
/// The last dimension has stride 1; each preceding dimension has stride equal to
/// the product of all following dimensions.
fn compute_row_major_strides(shape: &[i64]) -> Vec<i64> {
    let n = shape.len();
    let mut strides = vec![1i64; n];
    if n > 1 {
        for i in (0..n - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }
    }
    strides
}

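// A worked example of the stride rule above, kept next to the function as a
// small test: for shape [2, 3, 4], element (i, j, k) lives at linear index
// i*12 + j*4 + k, so the strides are [12, 4, 1].
#[cfg(test)]
#[test]
fn row_major_strides_worked_example() {
    assert_eq!(compute_row_major_strides(&[2, 3, 4]), vec![12, 4, 1]);
    assert_eq!(compute_row_major_strides(&[5]), vec![1]);
    assert!(compute_row_major_strides(&[]).is_empty());
}
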
/// Register DLPack classes into a PyO3 module.
///
/// Call this from your `#[pymodule]` init function to expose `DLPackCapsule`.
pub fn register_dlpack_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<DLPackCapsule>()?;
    Ok(())
}

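// How the registration hook above would typically be wired up (a sketch; the
// module name `scirs2_numpy` and the surrounding `#[pymodule]` function are
// assumptions, not defined in this file):
//
// #[pymodule]
// fn scirs2_numpy(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
//     register_dlpack_module(py, m)?;
//     Ok(())
// }
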
304// ─── Enhanced DLPack interoperability ────────────────────────────────────────
305
/// Element type codes used in DLPack `DLDataType.code`.
///
/// The values follow the DLPack spec; code 3 (`kDLOpaqueHandle`) is deliberately
/// omitted because this crate cannot interpret opaque handles.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DLDataTypeCode {
    /// Signed integer (code 0).
    Int = 0,
    /// Unsigned integer (code 1).
    UInt = 1,
    /// IEEE floating point (code 2).
    Float = 2,
    /// Brain floating point (code 4).
    BFloat = 4,
}

impl TryFrom<u8> for DLDataTypeCode {
    type Error = DlpackError;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(Self::Int),
            1 => Ok(Self::UInt),
            2 => Ok(Self::Float),
            4 => Ok(Self::BFloat),
            other => Err(DlpackError::UnsupportedDtype {
                code: other,
                bits: 0,
            }),
        }
    }
}

/// Structured information extracted from a validated [`DLTensor`].
#[derive(Debug, Clone)]
pub struct DLTensorInfo {
    /// Tensor dimensions.
    pub shape: Vec<i64>,
    /// Element type category.
    pub dtype_code: DLDataTypeCode,
    /// Element bit-width.
    pub dtype_bits: u8,
    /// Device type.
    pub device_type: DLDeviceType,
}

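// Convenience accessors derived from the decoded metadata. These helpers are a
// sketch added for illustration (`num_elements` and `num_bytes` are not part of
// the original API surface).
impl DLTensorInfo {
    /// Total number of elements: the product of the shape (1 for a 0-d tensor).
    pub fn num_elements(&self) -> usize {
        self.shape.iter().map(|&d| d as usize).product()
    }

    /// Size of a contiguous backing buffer in bytes, assuming `lanes == 1`.
    pub fn num_bytes(&self) -> usize {
        self.num_elements() * usize::from(self.dtype_bits).div_ceil(8)
    }
}
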
/// Errors produced by DLPack validation and conversion utilities.
#[derive(Debug, thiserror::Error)]
pub enum DlpackError {
    /// The tensor is not resident in CPU memory.
    #[error("unsupported device: expected CPU")]
    NonCpuDevice,

    /// The element dtype (code + bits) is not supported by this operation.
    #[error("unsupported dtype: {code}:{bits}")]
    UnsupportedDtype {
        /// DLDataType code.
        code: u8,
        /// DLDataType bits.
        bits: u8,
    },

    /// The tensor's data pointer is null.
    #[error("null data pointer")]
    NullPointer,
}

/// Validate a [`DLTensor`] and extract structured metadata.
///
/// This is the entry point for consuming tensors produced by DLPack-aware
/// frameworks (PyTorch, JAX, CuPy, etc.). It checks that:
/// - `data` is non-null,
/// - the dtype code is recognised.
///
/// The device type is decoded via [`decode_device_type`]; unknown codes fall
/// back to CPU rather than failing, so callers should not rely on it as a guard.
///
/// On success, returns a [`DLTensorInfo`] with all metadata decoded.
///
/// # Safety
///
/// Although this function is not marked `unsafe`, `tensor.shape` must point to
/// at least `tensor.ndim` valid `i64` values, and the caller must ensure the
/// tensor is not concurrently mutated.
pub fn validate_dlpack_tensor(tensor: &DLTensor) -> Result<DLTensorInfo, DlpackError> {
    // 1. Null-pointer guard.
    if tensor.data.is_null() {
        return Err(DlpackError::NullPointer);
    }

    // 2. Decode device type (infallible; unknown codes map to CPU).
    let device_type = decode_device_type(tensor.device.device_type);

    // 3. Decode dtype code.
    let dtype_code = DLDataTypeCode::try_from(tensor.dtype.code)?;

    // 4. Copy the shape (the shape pointer is valid for `ndim` elements per contract).
    let shape = if tensor.ndim == 0 || tensor.shape.is_null() {
        Vec::new()
    } else {
        // SAFETY: Caller guarantees the shape pointer is valid for `ndim` elements.
        unsafe {
            std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize).to_vec()
        }
    };

    Ok(DLTensorInfo {
        shape,
        dtype_code,
        dtype_bits: tensor.dtype.bits,
        device_type,
    })
}

/// Create a [`DLTensor`] that borrows `data` and `shape` slices.
///
/// The returned `DLTensor` has its `data` pointer set to `data.as_ptr()`,
/// `dtype` set to float64 (code=2, bits=64), and device set to CPU. The caller
/// is responsible for ensuring `data.len()` equals the product of `shape`.
///
/// # Safety
///
/// The returned `DLTensor` holds raw pointers into `data` and `shape`.
/// Both slices **must** remain live and unmodified for the entire lifetime of
/// the returned tensor. The tensor must not be used after either slice drops.
///
/// The returned struct does **not** own the memory it points at; no destructor
/// is called for `data` or `shape` when the `DLTensor` is dropped.
pub fn dlpack_from_slice(data: &[f64], shape: &[i64]) -> DLTensor {
    DLTensor {
        // We cast a shared reference to a mut pointer to satisfy the DLPack ABI
        // (which uses *mut c_void). The caller contract forbids mutation through
        // this pointer; this crate never mutates through it.
        data: data.as_ptr() as *mut c_void,
        device: DLDevice {
            device_type: DLDeviceType::Cpu as i32,
            device_id: 0,
        },
        ndim: shape.len() as i32,
        dtype: DLDataType {
            code: DLDataTypeCode::Float as u8,
            bits: 64,
            lanes: 1,
        },
        // Same const-to-mut cast; the shape is read-only.
        shape: shape.as_ptr() as *mut i64,
        strides: std::ptr::null_mut(), // C-contiguous: strides may be null.
        byte_offset: 0,
    }
}

/// Extract a `Vec<f64>` from a CPU float64 [`DLTensor`].
///
/// Validates that:
/// - `data` is non-null,
/// - the device type is CPU,
/// - the dtype is float64 (code=2, bits=64, lanes=1).
///
/// Returns a freshly allocated `Vec<f64>` copied from the tensor buffer.
///
/// # Safety
///
/// `tensor.data` must point to at least `product(tensor.shape) * 8` valid
/// bytes of `f64` values in native byte order, laid out C-contiguously
/// (`strides` is ignored). The caller must ensure the tensor is valid for the
/// duration of this call.
pub fn dlpack_to_vec_f64(tensor: &DLTensor) -> Result<Vec<f64>, DlpackError> {
    // Guard: non-null data.
    if tensor.data.is_null() {
        return Err(DlpackError::NullPointer);
    }

    // Guard: CPU device.
    let device_type = tensor.device.device_type;
    if device_type != DLDeviceType::Cpu as i32 {
        return Err(DlpackError::NonCpuDevice);
    }

    // Guard: float64 dtype.
    if tensor.dtype.code != DLDataTypeCode::Float as u8
        || tensor.dtype.bits != 64
        || tensor.dtype.lanes != 1
    {
        return Err(DlpackError::UnsupportedDtype {
            code: tensor.dtype.code,
            bits: tensor.dtype.bits,
        });
    }

    // Compute the element count from the shape.
    let n_elems = if tensor.ndim == 0 {
        1usize // a 0-d tensor holds a single scalar
    } else if tensor.shape.is_null() {
        0usize
    } else {
        // SAFETY: the shape pointer is valid for `ndim` elements (caller contract).
        let shape =
            unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
        shape.iter().map(|&d| d as usize).product()
    };

    // Apply byte_offset to find the first element.
    let base = unsafe { (tensor.data as *const u8).add(tensor.byte_offset as usize) as *const f64 };

    // SAFETY: base points to n_elems valid f64 values (caller contract).
    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
    Ok(slice.to_vec())
}

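// `dlpack_to_vec_f64` assumes a contiguous buffer. As a sketch of what a
// stride-aware variant could look like for the 1-D case (a hypothetical helper,
// not part of the original API; same caller contract as above, with `stride`
// given in elements per the DLPack convention):
#[allow(dead_code)]
fn strided_1d_to_vec_f64(tensor: &DLTensor, len: usize, stride: isize) -> Vec<f64> {
    // SAFETY: caller guarantees `len` elements are reachable at the given stride.
    let base =
        unsafe { (tensor.data as *const u8).add(tensor.byte_offset as usize) as *const f64 };
    (0..len)
        .map(|i| unsafe { *base.offset(i as isize * stride) })
        .collect()
}
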
/// Decode a raw DLPack device-type integer into the [`DLDeviceType`] enum.
///
/// Unknown values fall back to [`DLDeviceType::Cpu`]. Note that this fallback is
/// convenient but lossy: callers that dereference `data` should not treat an
/// unknown device as host memory.
fn decode_device_type(raw: i32) -> DLDeviceType {
    match raw {
        1 => DLDeviceType::Cpu,
        2 => DLDeviceType::Cuda,
        3 => DLDeviceType::CudaHost,
        4 => DLDeviceType::OpenCL,
        7 => DLDeviceType::Vulkan,
        8 => DLDeviceType::Metal,
        10 => DLDeviceType::Rocm,
        _ => DLDeviceType::Cpu, // fallback; see doc comment above
    }
}

523// ─── Tests ───────────────────────────────────────────────────────────────────
524
525#[cfg(test)]
526mod tests {
527 use super::*;
528
    // --- validate_dlpack_tensor ---

    #[test]
    fn test_validate_valid_f64_cpu_tensor() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);

        let info = validate_dlpack_tensor(&tensor).expect("validate_dlpack_tensor failed");
        assert_eq!(info.shape, vec![2, 3]);
        assert_eq!(info.dtype_code, DLDataTypeCode::Float);
        assert_eq!(info.dtype_bits, 64);
        assert_eq!(info.device_type, DLDeviceType::Cpu);
        // `data` and `shape` outlive `tensor`: both bindings live to the end of this scope.
    }

    #[test]
    fn test_validate_null_pointer_returns_err() {
        let shape = vec![3_i64];
        let mut tensor = dlpack_from_slice(&[0.0_f64; 3], &shape);
        // Forcibly set data to null to exercise the null-pointer guard.
        tensor.data = std::ptr::null_mut();
        let result = validate_dlpack_tensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NullPointer)),
            "expected NullPointer error"
        );
    }

    #[test]
    fn test_validate_shape_fields() {
        let data = vec![10.0_f64; 12];
        let shape = vec![3_i64, 4];
        let tensor = dlpack_from_slice(&data, &shape);
        let info = validate_dlpack_tensor(&tensor).expect("validate failed");
        assert_eq!(info.shape, vec![3, 4]);
    }

    // --- dlpack_from_slice ---

    #[test]
    fn test_dlpack_from_slice_shape_fields() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let shape = vec![3_i64];
        let tensor = dlpack_from_slice(&data, &shape);

        assert_eq!(tensor.ndim, 1);
        assert!(!tensor.data.is_null());
        assert!(!tensor.shape.is_null());
        // dtype must be float64.
        assert_eq!(tensor.dtype.code, 2); // Float
        assert_eq!(tensor.dtype.bits, 64);
    }

    #[test]
    fn test_dlpack_from_slice_2d() {
        let data = vec![0.0_f64; 6];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);
        assert_eq!(tensor.ndim, 2);
        // SAFETY: the shape pointer is valid for ndim = 2 elements.
        let s = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, 2) };
        assert_eq!(s, [2, 3]);
    }

    // --- dlpack_to_vec_f64 ---

    #[test]
    fn test_dlpack_to_vec_f64_round_trip() {
        let original = vec![1.0_f64, 2.5, 3.15, -7.0, 0.0];
        let shape = vec![5_i64];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_dlpack_to_vec_f64_2d() {
        let original: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_dlpack_to_vec_f64_null_pointer_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.data = std::ptr::null_mut();

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::NullPointer)
        ));
    }

    #[test]
    fn test_dlpack_to_vec_f64_non_cpu_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.device.device_type = DLDeviceType::Cuda as i32;

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::NonCpuDevice)
        ));
    }

    #[test]
    fn test_dlpack_to_vec_f64_wrong_dtype_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.dtype.code = 0; // Int, not Float

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::UnsupportedDtype { .. })
        ));
    }

    // --- DLDataTypeCode ---

    #[test]
    fn test_dtype_code_try_from() {
        assert_eq!(DLDataTypeCode::try_from(0u8).unwrap(), DLDataTypeCode::Int);
        assert_eq!(DLDataTypeCode::try_from(1u8).unwrap(), DLDataTypeCode::UInt);
        assert_eq!(
            DLDataTypeCode::try_from(2u8).unwrap(),
            DLDataTypeCode::Float
        );
        assert_eq!(
            DLDataTypeCode::try_from(4u8).unwrap(),
            DLDataTypeCode::BFloat
        );
        // Code 3 is `kDLOpaqueHandle` in the DLPack spec and is unsupported here.
        assert!(DLDataTypeCode::try_from(3u8).is_err());
        assert!(DLDataTypeCode::try_from(99u8).is_err());
    }
}
673}