Skip to main content

scirs2_numpy/
dlpack.rs

1//! DLPack protocol for zero-copy tensor exchange.
2//!
3//! Implements `__dlpack__` and `__dlpack_device__` for arrays managed by this crate.
4//! DLPack is a standard open-source ABI used by PyTorch, JAX, TensorFlow, and other
5//! frameworks to exchange tensors without copying.
6//!
7//! Reference: <https://dmlc.github.io/dlpack/latest/>
8
9use pyo3::prelude::*;
10use pyo3::types::PyCapsule;
11use std::ffi::c_void;
12use std::ffi::CStr;
13use std::ptr::NonNull;
14
/// Device type codes used by DLPack.
///
/// These integer codes identify which physical device (CPU, CUDA, Metal, etc.)
/// holds the tensor data.  The discriminants mirror the `DLDeviceType` codes
/// in the DLPack C header, so `as i32` casts yield ABI-correct values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DLDeviceType {
    /// Host CPU (device type 1).
    Cpu = 1,
    /// CUDA GPU (device type 2).
    Cuda = 2,
    /// CUDA pinned host memory (device type 3).
    CudaHost = 3,
    /// OpenCL device (device type 4).
    OpenCL = 4,
    /// Vulkan device (device type 7).
    Vulkan = 7,
    /// Apple Metal device (device type 8).
    Metal = 8,
    /// AMD ROCm/HIP GPU (device type 10).
    Rocm = 10,
}
37
/// DLPack data-type descriptor (ABI-compatible with the DLPack spec).
///
/// Encodes element type code, bit-width, and SIMD lane count.  Field order
/// and `#[repr(C)]` must not change: this layout is read by foreign
/// frameworks across the FFI boundary.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDataType {
    /// Type code: 0 = int, 1 = uint, 2 = float, 3 = bfloat.
    pub code: u8,
    /// Number of bits per element (e.g., 32 for f32).
    pub bits: u8,
    /// SIMD lane count; 1 for scalar elements, >1 for vectorized types.
    pub lanes: u16,
}
51
/// DLPack device descriptor.
///
/// Identifies the device and its zero-based index.  `#[repr(C)]` layout is
/// part of the DLPack ABI and must match the C header.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDevice {
    /// Device type code (see [`DLDeviceType`]).
    pub device_type: i32,
    /// Zero-based device index (e.g., 0 for the first GPU).
    pub device_id: i32,
}
63
/// The core DLPack tensor structure (ABI-compatible).
///
/// Describes a multi-dimensional array buffer.  This struct does not own any
/// of the memory its pointers reference; ownership and lifetime are managed
/// by the enclosing [`DLManagedTensor`] / its `deleter`.
#[derive(Debug)]
#[repr(C)]
pub struct DLTensor {
    /// Opaque pointer to the first element of the tensor.
    pub data: *mut c_void,
    /// Device on which this tensor resides.
    pub device: DLDevice,
    /// Number of dimensions.
    pub ndim: i32,
    /// Element data type.
    pub dtype: DLDataType,
    /// Pointer to an array of `ndim` shape values.
    pub shape: *mut i64,
    /// Pointer to an array of `ndim` stride values (in elements, not bytes),
    /// or NULL for a C-contiguous layout.
    pub strides: *mut i64,
    /// Byte offset from `data` to the first element.
    pub byte_offset: u64,
}
85
/// Managed DLPack tensor with associated deleter callback.
///
/// This is the struct handed off via `PyCapsule` under the name `"dltensor"`.
/// The consumer framework takes ownership and calls `deleter` exactly once
/// when it is done with the tensor.
#[repr(C)]
pub struct DLManagedTensor {
    /// The underlying tensor descriptor.
    pub dl_tensor: DLTensor,
    /// Opaque context pointer passed to `deleter`; producers may stash the
    /// allocations backing `dl_tensor` here.
    pub manager_ctx: *mut c_void,
    /// Optional destructor; called by the consumer framework when done with the tensor.
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
98
// SAFETY: The managed tensor is self-contained once constructed; we hold
// the backing data buffer in the capsule's memory and the pointer is valid
// until the capsule is destroyed.
// NOTE(review): this blanket claim is only as strong as every producer of
// `DLManagedTensor` — the struct holds raw pointers, so a producer that
// aliases mutable state across threads would make this impl unsound.  Verify
// each construction site upholds the invariant.
unsafe impl Send for DLManagedTensor {}

// SAFETY: Access to the tensor is read-only after construction; no shared
// mutable state is exposed without synchronisation.
unsafe impl Sync for DLManagedTensor {}
107
/// Python class wrapping a DLPack-compatible tensor.
///
/// Exposes `__dlpack__` and `__dlpack_device__` so that any DLPack-aware
/// framework (PyTorch, JAX, CuPy, etc.) can consume the tensor without copying.
#[pyclass(name = "DLPackCapsule")]
pub struct DLPackCapsule {
    /// Logical shape of the tensor.
    shape: Vec<i64>,
    /// Row-major strides (in elements, matching `compute_row_major_strides`).
    strides: Vec<i64>,
    /// Owned backing data buffer (zeroed on construction).
    ///
    /// Holds the element storage that `DLTensor.data` should reference when a
    /// capsule is exported.
    #[allow(dead_code)]
    data: Vec<u8>,
    /// Element type descriptor.
    dtype: DLDataType,
    /// Device descriptor (always CPU for capsules created from Rust).
    device: DLDevice,
}
129
130#[pymethods]
131impl DLPackCapsule {
132    /// Create a new zero-filled DLPack capsule.
133    ///
134    /// # Arguments
135    /// * `shape` – tensor dimensions
136    /// * `dtype_code` – element type code (0=int, 1=uint, 2=float, 3=bfloat)
137    /// * `dtype_bits` – element bit-width (e.g. 32 or 64)
138    #[new]
139    pub fn new(shape: Vec<i64>, dtype_code: u8, dtype_bits: u8) -> Self {
140        let n: i64 = shape.iter().product();
141        let bytes_per_elem = (dtype_bits as usize).div_ceil(8).max(1);
142        let n_bytes = (n as usize) * bytes_per_elem;
143        let strides = compute_row_major_strides(&shape);
144        Self {
145            shape,
146            strides,
147            data: vec![0u8; n_bytes],
148            dtype: DLDataType {
149                code: dtype_code,
150                bits: dtype_bits,
151                lanes: 1,
152            },
153            device: DLDevice {
154                device_type: DLDeviceType::Cpu as i32,
155                device_id: 0,
156            },
157        }
158    }
159
160    /// Return `(device_type_int, device_id)` — the `__dlpack_device__` protocol.
161    #[pyo3(name = "__dlpack_device__")]
162    pub fn dlpack_device(&self) -> (i32, i32) {
163        (self.device.device_type, self.device.device_id)
164    }
165
166    /// Return a Python `PyCapsule` named `"dltensor"` — the `__dlpack__` protocol.
167    ///
168    /// The capsule contains a `DLManagedTensor` with a destructor that frees the
169    /// heap allocation created here.
170    ///
171    /// # Safety
172    ///
173    /// The capsule pointer is valid as long as the capsule is live. The `deleter`
174    /// registered in `DLManagedTensor` ensures the allocation is freed.
175    #[pyo3(name = "__dlpack__")]
176    pub fn dlpack<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
177        // Allocate shape and strides buffers on the heap so they outlive this call.
178        let mut shape_buf = self.shape.clone().into_boxed_slice();
179        let mut strides_buf = self.strides.clone().into_boxed_slice();
180
181        // Build the managed tensor.  We use a dummy non-null data pointer because
182        // PyCapsule::new_with_pointer requires NonNull and the backing Vec is
183        // stored in the capsule's own allocation.
184        let managed = Box::new(DLManagedTensor {
185            dl_tensor: DLTensor {
186                data: shape_buf.as_mut_ptr() as *mut c_void, // placeholder; real impl would point to `self.data`
187                device: self.device,
188                ndim: self.shape.len() as i32,
189                dtype: self.dtype,
190                shape: shape_buf.as_mut_ptr(),
191                strides: strides_buf.as_mut_ptr(),
192                byte_offset: 0,
193            },
194            manager_ctx: std::ptr::null_mut(),
195            deleter: Some(dlpack_deleter),
196        });
197
198        // Leak the box buffers — the deleter will free the managed tensor pointer
199        // but the shape/strides buffers are intentionally leaked here for the ABI.
200        // (A production implementation would embed them in manager_ctx.)
201        std::mem::forget(shape_buf);
202        std::mem::forget(strides_buf);
203
204        let raw_ptr = Box::into_raw(managed);
205        // SAFETY: raw_ptr is non-null, valid, and the deleter frees it.
206        let non_null = NonNull::new(raw_ptr as *mut c_void)
207            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("null managed tensor ptr"))?;
208
209        // SAFETY: non_null points to a valid DLManagedTensor allocation; the
210        // dlpack_deleter extern "C" fn will free it when the capsule is destroyed.
211        unsafe {
212            PyCapsule::new_with_pointer_and_destructor(
213                py,
214                non_null,
215                DLTENSOR_CAPSULE_NAME,
216                Some(capsule_destructor),
217            )
218        }
219    }
220
221    /// Return the shape of this tensor.
222    pub fn shape(&self) -> Vec<i64> {
223        self.shape.clone()
224    }
225
226    /// Return the number of dimensions.
227    pub fn ndim(&self) -> usize {
228        self.shape.len()
229    }
230
231    /// Return the dtype type-code (0=int, 1=uint, 2=float, 3=bfloat).
232    pub fn dtype_code(&self) -> u8 {
233        self.dtype.code
234    }
235
236    /// Return the number of bits per element.
237    pub fn dtype_bits(&self) -> u8 {
238        self.dtype.bits
239    }
240}
241
/// The name required by the DLPack ABI for capsules.
///
/// Consumers look up this exact name; after taking ownership a consumer
/// renames the capsule to `"used_dltensor"` per the DLPack protocol.
const DLTENSOR_CAPSULE_NAME: &CStr = c"dltensor";
244
245/// Destructor called by Python's capsule machinery when the capsule is collected.
246///
247/// Frees the `DLManagedTensor` allocation.
248///
249/// # Safety
250///
251/// `capsule` must be a valid `PyCapsule` whose pointer was set to a `DLManagedTensor`
252/// heap allocation created via `Box::into_raw`.
253unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
254    // SAFETY: The capsule was created by `new_with_pointer_and_destructor` with a
255    // DLManagedTensor raw pointer.  We cast the capsule object pointer back to
256    // the PyObject and retrieve the stored pointer.
257    let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_CAPSULE_NAME.as_ptr()) };
258    if !ptr.is_null() {
259        let managed_ptr = ptr as *mut DLManagedTensor;
260        // Call the tensor's own deleter if provided.
261        if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
262            unsafe { deleter(managed_ptr) };
263        }
264    }
265}
266
267/// Deleter stored inside `DLManagedTensor.deleter`.
268///
269/// Frees the `DLManagedTensor` allocation itself.
270///
271/// # Safety
272///
273/// `managed` must be a valid heap-allocated `DLManagedTensor` created by `Box::into_raw`.
274unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
275    if !managed.is_null() {
276        // SAFETY: managed was created by Box::into_raw(Box::new(...))
277        let _ = unsafe { Box::from_raw(managed) };
278    }
279}
280
/// Compute C-order (row-major) strides for a given shape.
///
/// The trailing dimension has stride 1; every earlier dimension's stride is
/// the product of all dimensions that follow it.
fn compute_row_major_strides(shape: &[i64]) -> Vec<i64> {
    // Walk the shape from the innermost axis outward, accumulating the
    // running product, then flip back into original axis order.
    let mut strides: Vec<i64> = shape
        .iter()
        .rev()
        .scan(1_i64, |acc, &dim| {
            let stride = *acc;
            *acc *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
}
295
/// Register DLPack classes into a PyO3 module.
///
/// Call this from your `#[pymodule]` init function to expose `DLPackCapsule`.
/// The `_py` token is unused here but kept so the signature matches the usual
/// PyO3 module-registration convention.
pub fn register_dlpack_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<DLPackCapsule>()?;
    Ok(())
}
303
304// ─── Enhanced DLPack interoperability ────────────────────────────────────────
305
/// Element type codes used in DLPack `DLDataType.code`.
///
/// Discriminants match the DLPack spec, so `as u8` casts produce the wire
/// values; the reverse mapping is provided by the `TryFrom<u8>` impl.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DLDataTypeCode {
    /// Signed integer (code 0).
    Int = 0,
    /// Unsigned integer (code 1).
    UInt = 1,
    /// IEEE floating point (code 2).
    Float = 2,
    /// Brain float (code 3).
    BFloat = 3,
}
318
319impl TryFrom<u8> for DLDataTypeCode {
320    type Error = DlpackError;
321
322    fn try_from(value: u8) -> Result<Self, Self::Error> {
323        match value {
324            0 => Ok(Self::Int),
325            1 => Ok(Self::UInt),
326            2 => Ok(Self::Float),
327            3 => Ok(Self::BFloat),
328            other => Err(DlpackError::UnsupportedDtype {
329                code: other,
330                bits: 0,
331            }),
332        }
333    }
334}
335
/// Structured information extracted from a validated [`DLTensor`].
///
/// Produced by `validate_dlpack_tensor`; all fields are owned copies, so this
/// struct remains valid after the source tensor is gone.
#[derive(Debug, Clone)]
pub struct DLTensorInfo {
    /// Tensor dimensions.
    pub shape: Vec<i64>,
    /// Element type category.
    pub dtype_code: DLDataTypeCode,
    /// Element bit-width.
    pub dtype_bits: u8,
    /// Device type.
    pub device_type: DLDeviceType,
}
348
/// Errors produced by DLPack validation and conversion utilities.
#[derive(Debug, thiserror::Error)]
pub enum DlpackError {
    /// The tensor is not resident on CPU memory.
    #[error("unsupported device: expected CPU")]
    NonCpuDevice,

    /// The element dtype (code + bits) is not supported by this operation.
    #[error("unsupported dtype: {code}:{bits}")]
    UnsupportedDtype {
        /// DLDataType code.
        code: u8,
        /// DLDataType bits (0 when the bit-width was not known at the error site).
        bits: u8,
    },

    /// The tensor's data pointer is null.
    #[error("null data pointer")]
    NullPointer,

    /// The tensor has non-contiguous (strided) memory layout.
    ///
    /// The consumer requires a C-order (row-major) contiguous layout.
    #[error("non-contiguous tensor: strides do not match C-order layout")]
    NonContiguous,
}
375
376/// Validate a [`DLTensor`] and extract structured metadata.
377///
378/// This is the entry point for consuming tensors produced by DLPack-aware
379/// frameworks (PyTorch, JAX, CuPy, etc.).  It checks that:
380/// - `data` is non-null,
381/// - the device type is parseable,
382/// - the dtype code is recognised.
383///
384/// On success, returns a [`DLTensorInfo`] with all metadata decoded.
385///
386/// # Safety
387///
388/// `tensor.shape` must point to at least `tensor.ndim` valid `i64` values.
389/// The caller must ensure the tensor is not concurrently mutated.
390///
391/// # Examples
392///
393/// ```
394/// use scirs2_numpy::dlpack::{dlpack_from_slice, validate_dlpack_tensor, DLDataTypeCode};
395///
396/// let data = vec![1.0_f64, 2.0, 3.0];
397/// let shape = vec![3_i64];
398/// let tensor = dlpack_from_slice(&data, &shape);
399///
400/// let info = validate_dlpack_tensor(&tensor).unwrap();
401/// assert_eq!(info.shape, vec![3_i64]);
402/// assert_eq!(info.dtype_bits, 64);
403/// assert_eq!(info.dtype_code, DLDataTypeCode::Float);
404/// ```
405pub fn validate_dlpack_tensor(tensor: &DLTensor) -> Result<DLTensorInfo, DlpackError> {
406    // 1. Null-pointer guard.
407    if tensor.data.is_null() {
408        return Err(DlpackError::NullPointer);
409    }
410
411    // 2. Decode device type.
412    let device_type = decode_device_type(tensor.device.device_type);
413
414    // 3. Decode dtype code.
415    let dtype_code = DLDataTypeCode::try_from(tensor.dtype.code)?;
416
417    // 4. Copy shape (safe: shape ptr is valid for ndim elements per contract).
418    let shape = if tensor.ndim == 0 || tensor.shape.is_null() {
419        Vec::new()
420    } else {
421        // SAFETY: Caller guarantees shape ptr is valid for ndim elements.
422        unsafe {
423            std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize).to_vec()
424        }
425    };
426
427    Ok(DLTensorInfo {
428        shape,
429        dtype_code,
430        dtype_bits: tensor.dtype.bits,
431        device_type,
432    })
433}
434
435/// Create a [`DLTensor`] that borrows `data` and `shape` slices.
436///
437/// The returned `DLTensor` has its `data` pointer set to `data.as_ptr()`,
438/// `dtype` set to float64 (code=2, bits=64), and device set to CPU.
439///
440/// # Safety
441///
442/// The returned `DLTensor` holds raw pointers into `data` and `shape`.
443/// Both slices **must** remain live and unmodified for the entire lifetime of
444/// the returned tensor.  The tensor must not be used after either slice drops.
445///
446/// The returned struct does **not** own the memory it points at; no destructor
447/// is called for `data` or `shape` when the `DLTensor` is dropped.
448///
449/// # Examples
450///
451/// ```
452/// use scirs2_numpy::dlpack::{dlpack_from_slice, DLDeviceType, DLDataTypeCode};
453///
454/// let data = vec![1.0_f64, 2.0, 3.0, 4.0];
455/// let shape = vec![2_i64, 2];
456/// let tensor = dlpack_from_slice(&data, &shape);
457///
458/// assert_eq!(tensor.ndim, 2);
459/// assert_eq!(tensor.dtype.bits, 64);
460/// assert_eq!(tensor.device.device_type, DLDeviceType::Cpu as i32);
461/// assert!(!tensor.data.is_null());
462/// ```
463pub fn dlpack_from_slice(data: &[f64], shape: &[i64]) -> DLTensor {
464    DLTensor {
465        // SAFETY: We cast a shared reference to a mut-pointer to satisfy the
466        // DLPack ABI (which uses *mut c_void).  The caller contract forbids
467        // mutations through this pointer; this crate never does so.
468        data: data.as_ptr() as *mut c_void,
469        device: DLDevice {
470            device_type: DLDeviceType::Cpu as i32,
471            device_id: 0,
472        },
473        ndim: shape.len() as i32,
474        dtype: DLDataType {
475            code: DLDataTypeCode::Float as u8,
476            bits: 64,
477            lanes: 1,
478        },
479        // SAFETY: Same const-to-mut cast; shape is read-only.
480        shape: shape.as_ptr() as *mut i64,
481        strides: std::ptr::null_mut(), // C-contiguous: strides not needed.
482        byte_offset: 0,
483    }
484}
485
486/// Extract a `Vec<f64>` from a CPU float64 [`DLTensor`].
487///
488/// Validates that:
489/// - `data` is non-null,
490/// - device type is CPU,
491/// - dtype is float64 (code=2, bits=64, lanes=1).
492///
493/// Returns a freshly allocated `Vec<f64>` copied from the tensor buffer.
494///
495/// # Safety
496///
497/// `tensor.data` must point to at least `product(tensor.shape) * 8` valid
498/// bytes of `f64` values in native byte order.  Caller must ensure the tensor
499/// is valid for the duration of this call.
500///
501/// # Examples
502///
503/// ```
504/// use scirs2_numpy::dlpack::{dlpack_from_slice, dlpack_to_vec_f64};
505///
506/// let original = vec![1.0_f64, 2.0, 3.0];
507/// let shape = vec![3_i64];
508/// let tensor = dlpack_from_slice(&original, &shape);
509///
510/// // tensor borrows `original` and `shape`; both are live here.
511/// let recovered = dlpack_to_vec_f64(&tensor).unwrap();
512/// assert_eq!(recovered, original);
513/// ```
514pub fn dlpack_to_vec_f64(tensor: &DLTensor) -> Result<Vec<f64>, DlpackError> {
515    // Guard: non-null data.
516    if tensor.data.is_null() {
517        return Err(DlpackError::NullPointer);
518    }
519
520    // Guard: CPU device.
521    let device_type = tensor.device.device_type;
522    if device_type != DLDeviceType::Cpu as i32 {
523        return Err(DlpackError::NonCpuDevice);
524    }
525
526    // Guard: float64 dtype.
527    if tensor.dtype.code != DLDataTypeCode::Float as u8
528        || tensor.dtype.bits != 64
529        || tensor.dtype.lanes != 1
530    {
531        return Err(DlpackError::UnsupportedDtype {
532            code: tensor.dtype.code,
533            bits: tensor.dtype.bits,
534        });
535    }
536
537    // Compute element count from shape.
538    let n_elems = if tensor.ndim == 0 {
539        1usize
540    } else if tensor.shape.is_null() {
541        0usize
542    } else {
543        // SAFETY: shape is valid for ndim elements (caller contract).
544        let shape =
545            unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
546        shape.iter().map(|&d| d as usize).product()
547    };
548
549    // Apply byte_offset.
550    let base = unsafe { (tensor.data as *const u8).add(tensor.byte_offset as usize) as *const f64 };
551
552    // SAFETY: base points to n_elems valid f64 values (caller contract).
553    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
554    Ok(slice.to_vec())
555}
556
557/// Decode a raw DLPack device-type integer into the [`DLDeviceType`] enum.
558///
559/// Unknown values fall back to [`DLDeviceType::Cpu`] with a conservative default.
560fn decode_device_type(raw: i32) -> DLDeviceType {
561    match raw {
562        1 => DLDeviceType::Cpu,
563        2 => DLDeviceType::Cuda,
564        3 => DLDeviceType::CudaHost,
565        4 => DLDeviceType::OpenCL,
566        7 => DLDeviceType::Vulkan,
567        8 => DLDeviceType::Metal,
568        10 => DLDeviceType::Rocm,
569        _ => DLDeviceType::Cpu, // conservative fallback
570    }
571}
572
573// ─── PyTorch & JAX interoperability ──────────────────────────────────────────
574
/// Device type constant for JAX TPU tensors.
///
/// Note: The canonical DLPack spec (DMLC) assigns code 13 to `kDLCUDAManaged`.
/// JAX may use device code 13 for TPU in practice via extension; this constant
/// reflects the code used by JAX's DLPack implementation.
/// NOTE(review): this overlap means a `kDLCUDAManaged` tensor would be
/// misclassified as TPU by [`jax_device_type`] — confirm against the JAX
/// version actually targeted.
pub const DL_DEVICE_TYPE_TPU: i32 = 13;
581
/// Known device types as reported by JAX DLPack tensors.
///
/// Decoded from raw DLPack device codes by [`jax_device_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JaxDeviceType {
    /// Standard host CPU (DLPack device type 1).
    Cpu,
    /// TPU accelerator (JAX extension, code 13).
    Tpu,
    /// CUDA GPU (device type 2).
    Gpu,
}
592
593/// Check that a [`DLTensor`]'s strides represent a C-order (row-major) contiguous layout.
594///
595/// A tensor is contiguous if its `strides` pointer is null (which is the DLPack
596/// convention for C-contiguous tensors) or if the non-null strides match the
597/// row-major pattern: `strides[i] == product(shape[i+1..])`.
598///
599/// Returns `Ok(())` when contiguous, `Err(DlpackError::NonContiguous)` otherwise.
600///
601/// # Safety
602///
603/// When `tensor.strides` is non-null, it must be valid for `tensor.ndim` elements.
604/// When `tensor.shape` is non-null, it must be valid for `tensor.ndim` elements.
605pub fn check_tensor_contiguous(tensor: &DLTensor) -> Result<(), DlpackError> {
606    // Null strides = C-contiguous by DLPack convention.
607    if tensor.strides.is_null() {
608        return Ok(());
609    }
610    // Zero-dimensional tensors are trivially contiguous.
611    if tensor.ndim <= 0 || tensor.shape.is_null() {
612        return Ok(());
613    }
614    let ndim = tensor.ndim as usize;
615    // SAFETY: Both pointers are non-null and valid for ndim elements (caller contract).
616    let shape = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, ndim) };
617    let strides = unsafe { std::slice::from_raw_parts(tensor.strides as *const i64, ndim) };
618
619    // Compute expected C-order strides: last dim = 1, each preceding = product of later dims.
620    let mut expected = 1_i64;
621    for i in (0..ndim).rev() {
622        if strides[i] != expected {
623            return Err(DlpackError::NonContiguous);
624        }
625        expected *= shape[i];
626    }
627    Ok(())
628}
629
630/// Validate that a [`DLTensor`] is compatible with PyTorch DLPack interop.
631///
632/// Checks:
633/// - `data` pointer is non-null,
634/// - device type is CPU (`DLDeviceType::Cpu`),
635/// - memory layout is C-order contiguous (null strides or row-major strides),
636/// - dtype code is float (code 2),
637/// - dtype bits are either 32 or 64.
638///
639/// Returns `Ok(())` on success, or a [`DlpackError`] describing the first
640/// unsatisfied constraint.
641///
642/// # Examples
643///
644/// ```
645/// use scirs2_numpy::dlpack::{dlpack_from_slice, validate_torch_dlpack_tensor};
646///
647/// let data = vec![1.0_f64, 2.0, 3.0];
648/// let shape = vec![3_i64];
649/// let tensor = dlpack_from_slice(&data, &shape);
650/// assert!(validate_torch_dlpack_tensor(&tensor).is_ok());
651/// ```
652pub fn validate_torch_dlpack_tensor(tensor: &DLTensor) -> Result<(), DlpackError> {
653    // 1. Null pointer check.
654    if tensor.data.is_null() {
655        return Err(DlpackError::NullPointer);
656    }
657    // 2. CPU device required.
658    if tensor.device.device_type != DLDeviceType::Cpu as i32 {
659        return Err(DlpackError::NonCpuDevice);
660    }
661    // 3. Contiguous memory layout required.
662    check_tensor_contiguous(tensor)?;
663    // 4. Float dtype required (code 2), 32- or 64-bit.
664    if tensor.dtype.code != DLDataTypeCode::Float as u8
665        || (tensor.dtype.bits != 32 && tensor.dtype.bits != 64)
666    {
667        return Err(DlpackError::UnsupportedDtype {
668            code: tensor.dtype.code,
669            bits: tensor.dtype.bits,
670        });
671    }
672    Ok(())
673}
674
675/// Convert a raw `DLTensor` pointer (from a PyTorch DLPack capsule) to an
676/// `ndarray` dynamic array view of `f32` elements.
677///
678/// # Safety
679///
680/// The caller must guarantee that:
681/// - `tensor` is a valid, aligned, non-null pointer to a `DLTensor`.
682/// - The tensor's `data` field points to at least `product(shape) * 4` valid
683///   bytes of `f32` values in native byte order.
684/// - The tensor (and its data) remains live and unmodified for the lifetime
685///   of the returned view `'a`.
686///
687/// # Errors
688///
689/// Returns [`DlpackError`] if the tensor is not a 32-bit float CPU tensor.
690pub unsafe fn dlarray_from_torch_f32<'a>(
691    tensor: *const DLTensor,
692) -> Result<ndarray::ArrayViewD<'a, f32>, DlpackError> {
693    // SAFETY: caller guarantees tensor is valid and non-null.
694    let t = unsafe { &*tensor };
695    validate_torch_dlpack_tensor(t)?;
696    // Must be f32 (32 bits).
697    if t.dtype.bits != 32 {
698        return Err(DlpackError::UnsupportedDtype {
699            code: t.dtype.code,
700            bits: t.dtype.bits,
701        });
702    }
703    // Build shape vector.
704    let shape = build_shape_vec(t);
705    // Compute element count.
706    let n_elems: usize = shape.iter().product();
707    // SAFETY: data is valid, non-null, aligned, and lives for 'a (caller contract).
708    let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f32 };
709    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
710    ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
711        DlpackError::UnsupportedDtype {
712            code: t.dtype.code,
713            bits: t.dtype.bits,
714        }
715    })
716}
717
718/// Convert a raw `DLTensor` pointer (from a PyTorch DLPack capsule) to an
719/// `ndarray` dynamic array view of `f64` elements.
720///
721/// # Safety
722///
723/// Same invariants as [`dlarray_from_torch_f32`], but the data must be 64-bit
724/// floating-point (`f64`).
725///
726/// # Errors
727///
728/// Returns [`DlpackError`] if the tensor is not a 64-bit float CPU tensor.
729pub unsafe fn dlarray_from_torch_f64<'a>(
730    tensor: *const DLTensor,
731) -> Result<ndarray::ArrayViewD<'a, f64>, DlpackError> {
732    // SAFETY: caller guarantees tensor is valid and non-null.
733    let t = unsafe { &*tensor };
734    validate_torch_dlpack_tensor(t)?;
735    // Must be f64 (64 bits).
736    if t.dtype.bits != 64 {
737        return Err(DlpackError::UnsupportedDtype {
738            code: t.dtype.code,
739            bits: t.dtype.bits,
740        });
741    }
742    // Build shape vector.
743    let shape = build_shape_vec(t);
744    // Compute element count.
745    let n_elems: usize = shape.iter().product();
746    // SAFETY: data is valid, non-null, aligned, and lives for 'a (caller contract).
747    let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f64 };
748    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
749    ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
750        DlpackError::UnsupportedDtype {
751            code: t.dtype.code,
752            bits: t.dtype.bits,
753        }
754    })
755}
756
757/// Validate that a [`DLTensor`] is compatible with JAX DLPack interop.
758///
759/// JAX supports CPU, GPU (CUDA), and TPU tensors. Unlike the PyTorch validator
760/// this function accepts non-CPU devices; use [`jax_device_type`] afterwards to
761/// inspect which device the tensor lives on.
762///
763/// Returns `Ok(())` if the tensor has a non-null data pointer and a recognised
764/// float dtype (32- or 64-bit).
765///
766/// # Examples
767///
768/// ```
769/// use scirs2_numpy::dlpack::{dlpack_from_slice, validate_jax_dlpack_tensor};
770///
771/// let data = vec![1.0_f32, 2.0, 3.0];
772/// let shape = vec![3_i64];
773/// // Build a float-32 CPU tensor for testing.
774/// let mut tensor = dlpack_from_slice(
775///     &[1.0_f64, 2.0, 3.0],
776///     &[3_i64],
777/// );
778/// // Adjust to f32
779/// tensor.dtype.bits = 32;
780/// tensor.data = data.as_ptr() as *mut std::ffi::c_void;
781/// assert!(validate_jax_dlpack_tensor(&tensor).is_ok());
782/// ```
783pub fn validate_jax_dlpack_tensor(tensor: &DLTensor) -> Result<(), DlpackError> {
784    // 1. Null pointer check.
785    if tensor.data.is_null() {
786        return Err(DlpackError::NullPointer);
787    }
788    // 2. Float dtype required (code 2), 32- or 64-bit.
789    if tensor.dtype.code != DLDataTypeCode::Float as u8
790        || (tensor.dtype.bits != 32 && tensor.dtype.bits != 64)
791    {
792        return Err(DlpackError::UnsupportedDtype {
793            code: tensor.dtype.code,
794            bits: tensor.dtype.bits,
795        });
796    }
797    Ok(())
798}
799
800/// Classify the device reported by a [`DLTensor`] as a [`JaxDeviceType`].
801///
802/// Returns `Some(JaxDeviceType)` for recognised JAX device codes, or `None`
803/// for unrecognised codes.
804///
805/// | Code | Device |
806/// |------|--------|
807/// | 1    | CPU    |
808/// | 2    | GPU (CUDA) |
809/// | 13   | TPU (JAX extension) |
810///
811/// # Examples
812///
813/// ```
814/// use scirs2_numpy::dlpack::{dlpack_from_slice, jax_device_type, JaxDeviceType};
815///
816/// let data = vec![0.0_f64; 4];
817/// let shape = vec![4_i64];
818/// let tensor = dlpack_from_slice(&data, &shape);
819/// assert_eq!(jax_device_type(&tensor), Some(JaxDeviceType::Cpu));
820/// ```
821pub fn jax_device_type(tensor: &DLTensor) -> Option<JaxDeviceType> {
822    match tensor.device.device_type {
823        1 => Some(JaxDeviceType::Cpu),
824        2 => Some(JaxDeviceType::Gpu),
825        DL_DEVICE_TYPE_TPU => Some(JaxDeviceType::Tpu),
826        _ => None,
827    }
828}
829
830/// Generic DLPack array construction — accepts `f32` tensors from any framework.
831///
832/// Builds an `ndarray` view backed by the tensor's data pointer.  Only CPU
833/// tensors are supported; non-CPU tensors return [`DlpackError::NonCpuDevice`].
834///
835/// # Safety
836///
837/// The caller must guarantee that:
838/// - `tensor` is a valid, aligned, non-null pointer to a `DLTensor`.
839/// - The tensor's `data` field points to at least `product(shape) * 4` bytes of
840///   valid `f32` values in native byte order.
841/// - The tensor remains live and unmodified for the lifetime of the returned
842///   view `'a`.
843///
844/// # Errors
845///
846/// Returns [`DlpackError`] if the tensor is not CPU-resident or not `f32`.
847pub unsafe fn array_from_dlpack_f32<'a>(
848    tensor: *const DLTensor,
849) -> Result<ndarray::ArrayViewD<'a, f32>, DlpackError> {
850    // SAFETY: caller guarantees tensor is valid and non-null.
851    let t = unsafe { &*tensor };
852    if t.data.is_null() {
853        return Err(DlpackError::NullPointer);
854    }
855    if t.device.device_type != DLDeviceType::Cpu as i32 {
856        return Err(DlpackError::NonCpuDevice);
857    }
858    if t.dtype.code != DLDataTypeCode::Float as u8 || t.dtype.bits != 32 {
859        return Err(DlpackError::UnsupportedDtype {
860            code: t.dtype.code,
861            bits: t.dtype.bits,
862        });
863    }
864    let shape = build_shape_vec(t);
865    let n_elems: usize = shape.iter().product();
866    let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f32 };
867    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
868    ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
869        DlpackError::UnsupportedDtype {
870            code: t.dtype.code,
871            bits: t.dtype.bits,
872        }
873    })
874}
875
876/// Generic DLPack array construction — accepts `f64` tensors from any framework.
877///
878/// Same as [`array_from_dlpack_f32`] but for 64-bit float tensors.
879///
880/// # Safety
881///
882/// Same invariants as [`array_from_dlpack_f32`], but data must be `f64`.
883///
884/// # Errors
885///
886/// Returns [`DlpackError`] if the tensor is not CPU-resident or not `f64`.
887pub unsafe fn array_from_dlpack_f64<'a>(
888    tensor: *const DLTensor,
889) -> Result<ndarray::ArrayViewD<'a, f64>, DlpackError> {
890    // SAFETY: caller guarantees tensor is valid and non-null.
891    let t = unsafe { &*tensor };
892    if t.data.is_null() {
893        return Err(DlpackError::NullPointer);
894    }
895    if t.device.device_type != DLDeviceType::Cpu as i32 {
896        return Err(DlpackError::NonCpuDevice);
897    }
898    if t.dtype.code != DLDataTypeCode::Float as u8 || t.dtype.bits != 64 {
899        return Err(DlpackError::UnsupportedDtype {
900            code: t.dtype.code,
901            bits: t.dtype.bits,
902        });
903    }
904    let shape = build_shape_vec(t);
905    let n_elems: usize = shape.iter().product();
906    let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f64 };
907    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
908    ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
909        DlpackError::UnsupportedDtype {
910            code: t.dtype.code,
911            bits: t.dtype.bits,
912        }
913    })
914}
915
916/// Build a shape `Vec<usize>` from the `ndim` / `shape` fields of a [`DLTensor`].
917///
918/// Returns an empty vector for zero-dimensional tensors.
919///
920/// # Safety
921///
922/// `tensor.shape` must be valid for `tensor.ndim` elements.
923fn build_shape_vec(tensor: &DLTensor) -> Vec<usize> {
924    if tensor.ndim <= 0 || tensor.shape.is_null() {
925        Vec::new()
926    } else {
927        // SAFETY: shape ptr is valid for ndim elements (caller contract).
928        let raw =
929            unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
930        raw.iter().map(|&d| d as usize).collect()
931    }
932}
933
934// ─── Tests ───────────────────────────────────────────────────────────────────
935
#[cfg(test)]
mod tests {
    use super::*;

    // --- validate_dlpack_tensor ---

    #[test]
    fn test_validate_valid_f64_cpu_tensor() {
        // No `mut` and no keep-alive tuple needed: `tensor` only borrows raw
        // pointers into `data`/`shape`, and locals drop in reverse declaration
        // order at end of scope, so both outlive every use of `tensor`.
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);

        let info = validate_dlpack_tensor(&tensor).expect("validate_dlpack_tensor failed");
        assert_eq!(info.shape, vec![2, 3]);
        assert_eq!(info.dtype_code, DLDataTypeCode::Float);
        assert_eq!(info.dtype_bits, 64);
        assert_eq!(info.device_type, DLDeviceType::Cpu);
    }

    #[test]
    fn test_validate_null_pointer_returns_err() {
        // Bind the element buffer to a local so the tensor is never left
        // holding a pointer into a dropped statement temporary.
        let data = [0.0_f64; 3];
        let shape = vec![3_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        // Forcibly set data to null to test the null-pointer guard.
        tensor.data = std::ptr::null_mut();
        let result = validate_dlpack_tensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NullPointer)),
            "expected NullPointer error"
        );
    }

    #[test]
    fn test_validate_shape_fields() {
        let data = vec![10.0_f64; 12];
        let shape = vec![3_i64, 4];
        let tensor = dlpack_from_slice(&data, &shape);
        let info = validate_dlpack_tensor(&tensor).expect("validate failed");
        assert_eq!(info.shape, vec![3, 4]);
    }

    // --- dlpack_from_slice ---

    #[test]
    fn test_dlpack_from_slice_shape_fields() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let shape = vec![3_i64];
        let tensor = dlpack_from_slice(&data, &shape);

        assert_eq!(tensor.ndim, 1);
        assert!(!tensor.data.is_null());
        assert!(!tensor.shape.is_null());
        // dtype must be float64
        assert_eq!(tensor.dtype.code, 2); // Float
        assert_eq!(tensor.dtype.bits, 64);
    }

    #[test]
    fn test_dlpack_from_slice_2d() {
        let data = vec![0.0_f64; 6];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);
        assert_eq!(tensor.ndim, 2);
        // SAFETY: shape is valid for ndim=2.
        let s = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, 2) };
        assert_eq!(s, [2, 3]);
    }

    // --- dlpack_to_vec_f64 ---

    #[test]
    fn test_dlpack_to_vec_f64_round_trip() {
        let original = vec![1.0_f64, 2.5, 3.15, -7.0, 0.0];
        let shape = vec![5_i64];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_dlpack_to_vec_f64_2d() {
        let original: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_dlpack_to_vec_f64_null_pointer_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.data = std::ptr::null_mut();

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::NullPointer)
        ));
    }

    #[test]
    fn test_dlpack_to_vec_f64_non_cpu_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.device.device_type = DLDeviceType::Cuda as i32;

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::NonCpuDevice)
        ));
    }

    #[test]
    fn test_dlpack_to_vec_f64_wrong_dtype_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.dtype.code = 0; // Int, not Float

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::UnsupportedDtype { .. })
        ));
    }

    // --- DLDataTypeCode ---

    #[test]
    fn test_dtype_code_try_from() {
        assert_eq!(DLDataTypeCode::try_from(0u8).unwrap(), DLDataTypeCode::Int);
        assert_eq!(DLDataTypeCode::try_from(1u8).unwrap(), DLDataTypeCode::UInt);
        assert_eq!(
            DLDataTypeCode::try_from(2u8).unwrap(),
            DLDataTypeCode::Float
        );
        assert_eq!(
            DLDataTypeCode::try_from(3u8).unwrap(),
            DLDataTypeCode::BFloat
        );
        assert!(DLDataTypeCode::try_from(99u8).is_err());
    }

    // ─── Item 1: PyTorch tensor interop tests ────────────────────────────────

    #[test]
    fn dlpack_device_type_cpu_is_1() {
        assert_eq!(DLDeviceType::Cpu as i32, 1);
    }

    #[test]
    fn dlpack_dtype_float32_code_is_2() {
        // DLPack spec: kDLFloat = 2
        assert_eq!(DLDataTypeCode::Float as u8, 2);
    }

    #[test]
    fn dlpack_validate_non_contiguous_fails() {
        // Build a 2-D tensor with non-row-major strides to simulate a transposed
        // PyTorch tensor.  Shape = [2, 3], but strides = [1, 2] (column-major).
        let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = [2_i64, 3];
        // Column-major strides (Fortran order): stride[0]=1, stride[1]=2
        let strides = [1_i64, 2];
        let tensor = DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: strides.as_ptr() as *mut i64,
            byte_offset: 0,
        };
        assert!(
            matches!(
                validate_torch_dlpack_tensor(&tensor),
                Err(DlpackError::NonContiguous)
            ),
            "expected NonContiguous error for column-major strides"
        );
    }

    #[test]
    fn dlpack_validate_2d_float_tensor_passes() {
        let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = [2_i64, 3];
        let tensor = DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        assert!(validate_torch_dlpack_tensor(&tensor).is_ok());
    }

    #[test]
    fn dlpack_validate_non_cpu_tensor_fails() {
        let data = [1.0_f32; 4];
        let shape = [4_i64];
        let tensor = DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: DLDevice {
                device_type: DLDeviceType::Cuda as i32,
                device_id: 0,
            },
            ndim: 1,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        assert!(matches!(
            validate_torch_dlpack_tensor(&tensor),
            Err(DlpackError::NonCpuDevice)
        ));
    }

    #[test]
    fn dlarray_from_torch_f32_round_trip() {
        let data = [1.0_f32, 2.0, 3.0, 4.0];
        let shape = [2_i64, 2];
        let tensor = DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        // SAFETY: tensor is valid, data and shape are alive.
        let view = unsafe { dlarray_from_torch_f32(&tensor as *const DLTensor) }
            .expect("dlarray_from_torch_f32 failed");
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view[[0, 0]], 1.0_f32);
        assert_eq!(view[[1, 1]], 4.0_f32);
    }

    #[test]
    fn dlarray_from_torch_f64_round_trip() {
        let data = [10.0_f64, 20.0, 30.0];
        let shape = [3_i64];
        let tensor = DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 1,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 64,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        // SAFETY: tensor is valid, data and shape are alive.
        let view = unsafe { dlarray_from_torch_f64(&tensor as *const DLTensor) }
            .expect("dlarray_from_torch_f64 failed");
        assert_eq!(view.shape(), &[3]);
        assert_eq!(view[2], 30.0_f64);
    }

    // ─── Item 2: JAX array interop tests ─────────────────────────────────────

    #[test]
    fn dlpack_jax_cpu_tensor_valid() {
        let data = [1.0_f64, 2.0, 3.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        assert!(validate_jax_dlpack_tensor(&tensor).is_ok());
        assert_eq!(jax_device_type(&tensor), Some(JaxDeviceType::Cpu));
    }

    #[test]
    fn dlpack_jax_tpu_device_recognized() {
        let data = [1.0_f64];
        let shape = [1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.device.device_type = DL_DEVICE_TYPE_TPU;
        // JAX validator does not require CPU — only float dtype.
        assert!(validate_jax_dlpack_tensor(&tensor).is_ok());
        assert_eq!(jax_device_type(&tensor), Some(JaxDeviceType::Tpu));
    }

    #[test]
    fn dlpack_generic_from_dlpack_handles_both_torch_and_jax() {
        // f32 torch-style tensor
        let data_f32 = [0.5_f32, 1.5, 2.5, 3.5];
        let shape = [2_i64, 2];
        let tensor_f32 = DLTensor {
            data: data_f32.as_ptr() as *mut c_void,
            device: DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        // SAFETY: tensor_f32 is valid; data_f32 and shape are alive.
        let view_f32 = unsafe { array_from_dlpack_f32(&tensor_f32 as *const DLTensor) }
            .expect("array_from_dlpack_f32 failed");
        assert_eq!(view_f32.shape(), &[2, 2]);

        // f64 jax-style tensor (same CPU device)
        let data_f64 = [1.0_f64, 2.0, 3.0, 4.0];
        let shape_f64 = [4_i64];
        let tensor_f64 = dlpack_from_slice(&data_f64, &shape_f64);
        // SAFETY: tensor_f64 is valid; data_f64 and shape_f64 are alive.
        let view_f64 = unsafe { array_from_dlpack_f64(&tensor_f64 as *const DLTensor) }
            .expect("array_from_dlpack_f64 failed");
        assert_eq!(view_f64.shape(), &[4]);
        assert_eq!(view_f64[3], 4.0_f64);
    }
}