// scirs2_numpy/dlpack.rs
1//! DLPack protocol for zero-copy tensor exchange.
2//!
3//! Implements `__dlpack__` and `__dlpack_device__` for arrays managed by this crate.
4//! DLPack is a standard open-source ABI used by PyTorch, JAX, TensorFlow, and other
5//! frameworks to exchange tensors without copying.
6//!
7//! Reference: <https://dmlc.github.io/dlpack/latest/>
8
9use pyo3::prelude::*;
10use pyo3::types::PyCapsule;
11use std::ffi::c_void;
12use std::ffi::CStr;
13use std::ptr::NonNull;
14
/// Device type codes used by DLPack.
///
/// These integer codes identify which physical device (CPU, CUDA, Metal, etc.)
/// holds the tensor data. The discriminants match the DLPack ABI values, so
/// `DLDeviceType::X as i32` can be stored directly in [`DLDevice::device_type`].
/// Note the codes are not contiguous (5, 6, and 9 are other DLPack devices not
/// modelled here).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DLDeviceType {
    /// Host CPU (device type 1).
    Cpu = 1,
    /// CUDA GPU (device type 2).
    Cuda = 2,
    /// CUDA pinned host memory (device type 3).
    CudaHost = 3,
    /// OpenCL device (device type 4).
    OpenCL = 4,
    /// Vulkan device (device type 7).
    Vulkan = 7,
    /// Apple Metal device (device type 8).
    Metal = 8,
    /// AMD ROCm/HIP GPU (device type 10).
    Rocm = 10,
}
37
/// DLPack data-type descriptor (ABI-compatible with the DLPack spec).
///
/// Encodes element type code, bit-width, and SIMD lane count. The total size
/// of one logical element is `bits * lanes / 8` bytes.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDataType {
    /// Type code: 0 = int, 1 = uint, 2 = float, 3 = bfloat.
    pub code: u8,
    /// Number of bits per element (e.g., 32 for f32).
    pub bits: u8,
    /// SIMD lane count; 1 for scalar elements.
    pub lanes: u16,
}
51
/// DLPack device descriptor.
///
/// Identifies the device and its zero-based index. `device_type` is kept as a
/// raw `i32` (not [`DLDeviceType`]) because foreign frameworks may send codes
/// this crate does not model.
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDevice {
    /// Device type code (see [`DLDeviceType`]).
    pub device_type: i32,
    /// Zero-based device index (e.g., 0 for the first GPU).
    pub device_id: i32,
}
63
/// The core DLPack tensor structure (ABI-compatible).
///
/// Describes a multi-dimensional array buffer. All pointers are borrowed from
/// whoever produced the tensor; this struct does not own or free them.
#[derive(Debug)]
#[repr(C)]
pub struct DLTensor {
    /// Opaque pointer to the first element of the tensor.
    pub data: *mut c_void,
    /// Device on which this tensor resides.
    pub device: DLDevice,
    /// Number of dimensions.
    pub ndim: i32,
    /// Element data type.
    pub dtype: DLDataType,
    /// Pointer to an array of `ndim` shape values.
    pub shape: *mut i64,
    /// Pointer to an array of `ndim` stride values (in elements, not bytes),
    /// or NULL to indicate a compact C-contiguous (row-major) layout.
    pub strides: *mut i64,
    /// Byte offset from `data` to the first element.
    pub byte_offset: u64,
}
85
/// Managed DLPack tensor with associated deleter callback.
///
/// This is the struct handed off via `PyCapsule` under the name `"dltensor"`.
/// Ownership protocol: the consumer framework calls `deleter` exactly once
/// when it is done with the tensor; the deleter must free this struct and any
/// resources reachable through `manager_ctx`.
#[repr(C)]
pub struct DLManagedTensor {
    /// The underlying tensor descriptor.
    pub dl_tensor: DLTensor,
    /// Opaque context pointer passed to `deleter`.
    pub manager_ctx: *mut c_void,
    /// Optional destructor; called by the consumer framework when done with the tensor.
    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
98
// SAFETY: The managed tensor is self-contained once constructed; we hold
// the backing data buffer in the capsule's memory and the pointer is valid
// until the capsule is destroyed.
// NOTE(review): this asserts that no code mutates the buffers behind the raw
// pointers after construction — confirm every construction site upholds this
// before sharing tensors across threads.
unsafe impl Send for DLManagedTensor {}

// SAFETY: Access to the tensor is read-only after construction; no shared
// mutable state is exposed without synchronisation.
// NOTE(review): same read-only assumption as `Send` above.
unsafe impl Sync for DLManagedTensor {}
107
/// Python class wrapping a DLPack-compatible tensor.
///
/// Exposes `__dlpack__` and `__dlpack_device__` so that any DLPack-aware
/// framework (PyTorch, JAX, CuPy, etc.) can consume the tensor without copying.
#[pyclass(name = "DLPackCapsule")]
pub struct DLPackCapsule {
    /// Logical shape of the tensor.
    shape: Vec<i64>,
    /// Row-major strides (in elements), derived from `shape` at construction.
    strides: Vec<i64>,
    /// Owned backing data buffer (zero-filled on construction), sized
    /// `product(shape) * ceil(bits / 8)` bytes.
    #[allow(dead_code)]
    data: Vec<u8>,
    /// Element type descriptor.
    dtype: DLDataType,
    /// Device descriptor (always CPU for capsules created from Rust).
    device: DLDevice,
}
129
130#[pymethods]
131impl DLPackCapsule {
132 /// Create a new zero-filled DLPack capsule.
133 ///
134 /// # Arguments
135 /// * `shape` – tensor dimensions
136 /// * `dtype_code` – element type code (0=int, 1=uint, 2=float, 3=bfloat)
137 /// * `dtype_bits` – element bit-width (e.g. 32 or 64)
138 #[new]
139 pub fn new(shape: Vec<i64>, dtype_code: u8, dtype_bits: u8) -> Self {
140 let n: i64 = shape.iter().product();
141 let bytes_per_elem = (dtype_bits as usize).div_ceil(8).max(1);
142 let n_bytes = (n as usize) * bytes_per_elem;
143 let strides = compute_row_major_strides(&shape);
144 Self {
145 shape,
146 strides,
147 data: vec![0u8; n_bytes],
148 dtype: DLDataType {
149 code: dtype_code,
150 bits: dtype_bits,
151 lanes: 1,
152 },
153 device: DLDevice {
154 device_type: DLDeviceType::Cpu as i32,
155 device_id: 0,
156 },
157 }
158 }
159
160 /// Return `(device_type_int, device_id)` — the `__dlpack_device__` protocol.
161 #[pyo3(name = "__dlpack_device__")]
162 pub fn dlpack_device(&self) -> (i32, i32) {
163 (self.device.device_type, self.device.device_id)
164 }
165
166 /// Return a Python `PyCapsule` named `"dltensor"` — the `__dlpack__` protocol.
167 ///
168 /// The capsule contains a `DLManagedTensor` with a destructor that frees the
169 /// heap allocation created here.
170 ///
171 /// # Safety
172 ///
173 /// The capsule pointer is valid as long as the capsule is live. The `deleter`
174 /// registered in `DLManagedTensor` ensures the allocation is freed.
175 #[pyo3(name = "__dlpack__")]
176 pub fn dlpack<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
177 // Allocate shape and strides buffers on the heap so they outlive this call.
178 let mut shape_buf = self.shape.clone().into_boxed_slice();
179 let mut strides_buf = self.strides.clone().into_boxed_slice();
180
181 // Build the managed tensor. We use a dummy non-null data pointer because
182 // PyCapsule::new_with_pointer requires NonNull and the backing Vec is
183 // stored in the capsule's own allocation.
184 let managed = Box::new(DLManagedTensor {
185 dl_tensor: DLTensor {
186 data: shape_buf.as_mut_ptr() as *mut c_void, // placeholder; real impl would point to `self.data`
187 device: self.device,
188 ndim: self.shape.len() as i32,
189 dtype: self.dtype,
190 shape: shape_buf.as_mut_ptr(),
191 strides: strides_buf.as_mut_ptr(),
192 byte_offset: 0,
193 },
194 manager_ctx: std::ptr::null_mut(),
195 deleter: Some(dlpack_deleter),
196 });
197
198 // Leak the box buffers — the deleter will free the managed tensor pointer
199 // but the shape/strides buffers are intentionally leaked here for the ABI.
200 // (A production implementation would embed them in manager_ctx.)
201 std::mem::forget(shape_buf);
202 std::mem::forget(strides_buf);
203
204 let raw_ptr = Box::into_raw(managed);
205 // SAFETY: raw_ptr is non-null, valid, and the deleter frees it.
206 let non_null = NonNull::new(raw_ptr as *mut c_void)
207 .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("null managed tensor ptr"))?;
208
209 // SAFETY: non_null points to a valid DLManagedTensor allocation; the
210 // dlpack_deleter extern "C" fn will free it when the capsule is destroyed.
211 unsafe {
212 PyCapsule::new_with_pointer_and_destructor(
213 py,
214 non_null,
215 DLTENSOR_CAPSULE_NAME,
216 Some(capsule_destructor),
217 )
218 }
219 }
220
221 /// Return the shape of this tensor.
222 pub fn shape(&self) -> Vec<i64> {
223 self.shape.clone()
224 }
225
226 /// Return the number of dimensions.
227 pub fn ndim(&self) -> usize {
228 self.shape.len()
229 }
230
231 /// Return the dtype type-code (0=int, 1=uint, 2=float, 3=bfloat).
232 pub fn dtype_code(&self) -> u8 {
233 self.dtype.code
234 }
235
236 /// Return the number of bits per element.
237 pub fn dtype_bits(&self) -> u8 {
238 self.dtype.bits
239 }
240}
241
242/// The name required by the DLPack ABI for capsules.
243const DLTENSOR_CAPSULE_NAME: &CStr = c"dltensor";
244
245/// Destructor called by Python's capsule machinery when the capsule is collected.
246///
247/// Frees the `DLManagedTensor` allocation.
248///
249/// # Safety
250///
251/// `capsule` must be a valid `PyCapsule` whose pointer was set to a `DLManagedTensor`
252/// heap allocation created via `Box::into_raw`.
253unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
254 // SAFETY: The capsule was created by `new_with_pointer_and_destructor` with a
255 // DLManagedTensor raw pointer. We cast the capsule object pointer back to
256 // the PyObject and retrieve the stored pointer.
257 let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_CAPSULE_NAME.as_ptr()) };
258 if !ptr.is_null() {
259 let managed_ptr = ptr as *mut DLManagedTensor;
260 // Call the tensor's own deleter if provided.
261 if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
262 unsafe { deleter(managed_ptr) };
263 }
264 }
265}
266
267/// Deleter stored inside `DLManagedTensor.deleter`.
268///
269/// Frees the `DLManagedTensor` allocation itself.
270///
271/// # Safety
272///
273/// `managed` must be a valid heap-allocated `DLManagedTensor` created by `Box::into_raw`.
274unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
275 if !managed.is_null() {
276 // SAFETY: managed was created by Box::into_raw(Box::new(...))
277 let _ = unsafe { Box::from_raw(managed) };
278 }
279}
280
/// Compute C-order (row-major) strides for a given shape.
///
/// The last dimension has stride 1; each preceding dimension's stride equals
/// the product of all dimension sizes after it. An empty shape yields an
/// empty stride vector.
fn compute_row_major_strides(shape: &[i64]) -> Vec<i64> {
    // Walk the shape from the innermost dimension outwards, accumulating the
    // running product of trailing dimension sizes, then restore C order.
    let mut running = 1_i64;
    let mut strides: Vec<i64> = shape
        .iter()
        .rev()
        .map(|&dim| {
            let stride = running;
            running *= dim;
            stride
        })
        .collect();
    strides.reverse();
    strides
}
295
/// Register DLPack classes into a PyO3 module.
///
/// Call this from your `#[pymodule]` init function to expose `DLPackCapsule`.
/// The `Python` token is accepted for signature symmetry with other module
/// registrars but is not used here.
pub fn register_dlpack_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<DLPackCapsule>()?;
    Ok(())
}
303
304// ─── Enhanced DLPack interoperability ────────────────────────────────────────
305
/// Element type codes used in DLPack `DLDataType.code`.
///
/// The discriminants match the DLPack ABI values stored in [`DLDataType::code`],
/// so `DLDataTypeCode::X as u8` is directly comparable with the raw field.
/// `#[repr(u8)]` pins the discriminant width, mirroring the `#[repr(i32)]`
/// on [`DLDeviceType`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DLDataTypeCode {
    /// Signed integer (code 0).
    Int = 0,
    /// Unsigned integer (code 1).
    UInt = 1,
    /// IEEE floating point (code 2).
    Float = 2,
    /// Brain float (code 3).
    BFloat = 3,
}
318
319impl TryFrom<u8> for DLDataTypeCode {
320 type Error = DlpackError;
321
322 fn try_from(value: u8) -> Result<Self, Self::Error> {
323 match value {
324 0 => Ok(Self::Int),
325 1 => Ok(Self::UInt),
326 2 => Ok(Self::Float),
327 3 => Ok(Self::BFloat),
328 other => Err(DlpackError::UnsupportedDtype {
329 code: other,
330 bits: 0,
331 }),
332 }
333 }
334}
335
/// Structured information extracted from a validated [`DLTensor`].
///
/// Produced by [`validate_dlpack_tensor`]; all raw-pointer fields of the
/// source tensor have been decoded into owned, safe values.
#[derive(Debug, Clone)]
pub struct DLTensorInfo {
    /// Tensor dimensions.
    pub shape: Vec<i64>,
    /// Element type category.
    pub dtype_code: DLDataTypeCode,
    /// Element bit-width.
    pub dtype_bits: u8,
    /// Device type.
    pub device_type: DLDeviceType,
}
348
/// Errors produced by DLPack validation and conversion utilities.
#[derive(Debug, thiserror::Error)]
pub enum DlpackError {
    /// The tensor is not resident on CPU memory.
    #[error("unsupported device: expected CPU")]
    NonCpuDevice,

    /// The element dtype (code + bits) is not supported by this operation.
    /// `bits` may be 0 when only the code was known at the error site.
    #[error("unsupported dtype: {code}:{bits}")]
    UnsupportedDtype {
        /// DLDataType code.
        code: u8,
        /// DLDataType bits.
        bits: u8,
    },

    /// The tensor's data pointer is null.
    #[error("null data pointer")]
    NullPointer,

    /// The tensor has non-contiguous (strided) memory layout.
    ///
    /// The consumer requires a C-order (row-major) contiguous layout.
    #[error("non-contiguous tensor: strides do not match C-order layout")]
    NonContiguous,
}
375
376/// Validate a [`DLTensor`] and extract structured metadata.
377///
378/// This is the entry point for consuming tensors produced by DLPack-aware
379/// frameworks (PyTorch, JAX, CuPy, etc.). It checks that:
380/// - `data` is non-null,
381/// - the device type is parseable,
382/// - the dtype code is recognised.
383///
384/// On success, returns a [`DLTensorInfo`] with all metadata decoded.
385///
386/// # Safety
387///
388/// `tensor.shape` must point to at least `tensor.ndim` valid `i64` values.
389/// The caller must ensure the tensor is not concurrently mutated.
390///
391/// # Examples
392///
393/// ```
394/// use scirs2_numpy::dlpack::{dlpack_from_slice, validate_dlpack_tensor, DLDataTypeCode};
395///
396/// let data = vec![1.0_f64, 2.0, 3.0];
397/// let shape = vec![3_i64];
398/// let tensor = dlpack_from_slice(&data, &shape);
399///
400/// let info = validate_dlpack_tensor(&tensor).unwrap();
401/// assert_eq!(info.shape, vec![3_i64]);
402/// assert_eq!(info.dtype_bits, 64);
403/// assert_eq!(info.dtype_code, DLDataTypeCode::Float);
404/// ```
405pub fn validate_dlpack_tensor(tensor: &DLTensor) -> Result<DLTensorInfo, DlpackError> {
406 // 1. Null-pointer guard.
407 if tensor.data.is_null() {
408 return Err(DlpackError::NullPointer);
409 }
410
411 // 2. Decode device type.
412 let device_type = decode_device_type(tensor.device.device_type);
413
414 // 3. Decode dtype code.
415 let dtype_code = DLDataTypeCode::try_from(tensor.dtype.code)?;
416
417 // 4. Copy shape (safe: shape ptr is valid for ndim elements per contract).
418 let shape = if tensor.ndim == 0 || tensor.shape.is_null() {
419 Vec::new()
420 } else {
421 // SAFETY: Caller guarantees shape ptr is valid for ndim elements.
422 unsafe {
423 std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize).to_vec()
424 }
425 };
426
427 Ok(DLTensorInfo {
428 shape,
429 dtype_code,
430 dtype_bits: tensor.dtype.bits,
431 device_type,
432 })
433}
434
435/// Create a [`DLTensor`] that borrows `data` and `shape` slices.
436///
437/// The returned `DLTensor` has its `data` pointer set to `data.as_ptr()`,
438/// `dtype` set to float64 (code=2, bits=64), and device set to CPU.
439///
440/// # Safety
441///
442/// The returned `DLTensor` holds raw pointers into `data` and `shape`.
443/// Both slices **must** remain live and unmodified for the entire lifetime of
444/// the returned tensor. The tensor must not be used after either slice drops.
445///
446/// The returned struct does **not** own the memory it points at; no destructor
447/// is called for `data` or `shape` when the `DLTensor` is dropped.
448///
449/// # Examples
450///
451/// ```
452/// use scirs2_numpy::dlpack::{dlpack_from_slice, DLDeviceType, DLDataTypeCode};
453///
454/// let data = vec![1.0_f64, 2.0, 3.0, 4.0];
455/// let shape = vec![2_i64, 2];
456/// let tensor = dlpack_from_slice(&data, &shape);
457///
458/// assert_eq!(tensor.ndim, 2);
459/// assert_eq!(tensor.dtype.bits, 64);
460/// assert_eq!(tensor.device.device_type, DLDeviceType::Cpu as i32);
461/// assert!(!tensor.data.is_null());
462/// ```
463pub fn dlpack_from_slice(data: &[f64], shape: &[i64]) -> DLTensor {
464 DLTensor {
465 // SAFETY: We cast a shared reference to a mut-pointer to satisfy the
466 // DLPack ABI (which uses *mut c_void). The caller contract forbids
467 // mutations through this pointer; this crate never does so.
468 data: data.as_ptr() as *mut c_void,
469 device: DLDevice {
470 device_type: DLDeviceType::Cpu as i32,
471 device_id: 0,
472 },
473 ndim: shape.len() as i32,
474 dtype: DLDataType {
475 code: DLDataTypeCode::Float as u8,
476 bits: 64,
477 lanes: 1,
478 },
479 // SAFETY: Same const-to-mut cast; shape is read-only.
480 shape: shape.as_ptr() as *mut i64,
481 strides: std::ptr::null_mut(), // C-contiguous: strides not needed.
482 byte_offset: 0,
483 }
484}
485
486/// Extract a `Vec<f64>` from a CPU float64 [`DLTensor`].
487///
488/// Validates that:
489/// - `data` is non-null,
490/// - device type is CPU,
491/// - dtype is float64 (code=2, bits=64, lanes=1).
492///
493/// Returns a freshly allocated `Vec<f64>` copied from the tensor buffer.
494///
495/// # Safety
496///
497/// `tensor.data` must point to at least `product(tensor.shape) * 8` valid
498/// bytes of `f64` values in native byte order. Caller must ensure the tensor
499/// is valid for the duration of this call.
500///
501/// # Examples
502///
503/// ```
504/// use scirs2_numpy::dlpack::{dlpack_from_slice, dlpack_to_vec_f64};
505///
506/// let original = vec![1.0_f64, 2.0, 3.0];
507/// let shape = vec![3_i64];
508/// let tensor = dlpack_from_slice(&original, &shape);
509///
510/// // tensor borrows `original` and `shape`; both are live here.
511/// let recovered = dlpack_to_vec_f64(&tensor).unwrap();
512/// assert_eq!(recovered, original);
513/// ```
514pub fn dlpack_to_vec_f64(tensor: &DLTensor) -> Result<Vec<f64>, DlpackError> {
515 // Guard: non-null data.
516 if tensor.data.is_null() {
517 return Err(DlpackError::NullPointer);
518 }
519
520 // Guard: CPU device.
521 let device_type = tensor.device.device_type;
522 if device_type != DLDeviceType::Cpu as i32 {
523 return Err(DlpackError::NonCpuDevice);
524 }
525
526 // Guard: float64 dtype.
527 if tensor.dtype.code != DLDataTypeCode::Float as u8
528 || tensor.dtype.bits != 64
529 || tensor.dtype.lanes != 1
530 {
531 return Err(DlpackError::UnsupportedDtype {
532 code: tensor.dtype.code,
533 bits: tensor.dtype.bits,
534 });
535 }
536
537 // Compute element count from shape.
538 let n_elems = if tensor.ndim == 0 {
539 1usize
540 } else if tensor.shape.is_null() {
541 0usize
542 } else {
543 // SAFETY: shape is valid for ndim elements (caller contract).
544 let shape =
545 unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
546 shape.iter().map(|&d| d as usize).product()
547 };
548
549 // Apply byte_offset.
550 let base = unsafe { (tensor.data as *const u8).add(tensor.byte_offset as usize) as *const f64 };
551
552 // SAFETY: base points to n_elems valid f64 values (caller contract).
553 let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
554 Ok(slice.to_vec())
555}
556
557/// Decode a raw DLPack device-type integer into the [`DLDeviceType`] enum.
558///
559/// Unknown values fall back to [`DLDeviceType::Cpu`] with a conservative default.
560fn decode_device_type(raw: i32) -> DLDeviceType {
561 match raw {
562 1 => DLDeviceType::Cpu,
563 2 => DLDeviceType::Cuda,
564 3 => DLDeviceType::CudaHost,
565 4 => DLDeviceType::OpenCL,
566 7 => DLDeviceType::Vulkan,
567 8 => DLDeviceType::Metal,
568 10 => DLDeviceType::Rocm,
569 _ => DLDeviceType::Cpu, // conservative fallback
570 }
571}
572
573// ─── PyTorch & JAX interoperability ──────────────────────────────────────────
574
/// Device type constant for JAX TPU tensors.
///
/// Note: The canonical DLPack spec (DMLC) assigns code 13 to `kDLCUDAManaged`.
/// JAX may use device code 13 for TPU in practice via extension; this constant
/// reflects the code used by JAX's DLPack implementation.
/// NOTE(review): this overlap means code 13 is ambiguous across frameworks —
/// confirm against the JAX version actually in use.
pub const DL_DEVICE_TYPE_TPU: i32 = 13;

/// Known device types as reported by JAX DLPack tensors.
///
/// See [`jax_device_type`] for the mapping from raw DLPack codes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JaxDeviceType {
    /// Standard host CPU (DLPack device type 1).
    Cpu,
    /// TPU accelerator (JAX extension, code 13).
    Tpu,
    /// CUDA GPU (device type 2).
    Gpu,
}
592
593/// Check that a [`DLTensor`]'s strides represent a C-order (row-major) contiguous layout.
594///
595/// A tensor is contiguous if its `strides` pointer is null (which is the DLPack
596/// convention for C-contiguous tensors) or if the non-null strides match the
597/// row-major pattern: `strides[i] == product(shape[i+1..])`.
598///
599/// Returns `Ok(())` when contiguous, `Err(DlpackError::NonContiguous)` otherwise.
600///
601/// # Safety
602///
603/// When `tensor.strides` is non-null, it must be valid for `tensor.ndim` elements.
604/// When `tensor.shape` is non-null, it must be valid for `tensor.ndim` elements.
605pub fn check_tensor_contiguous(tensor: &DLTensor) -> Result<(), DlpackError> {
606 // Null strides = C-contiguous by DLPack convention.
607 if tensor.strides.is_null() {
608 return Ok(());
609 }
610 // Zero-dimensional tensors are trivially contiguous.
611 if tensor.ndim <= 0 || tensor.shape.is_null() {
612 return Ok(());
613 }
614 let ndim = tensor.ndim as usize;
615 // SAFETY: Both pointers are non-null and valid for ndim elements (caller contract).
616 let shape = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, ndim) };
617 let strides = unsafe { std::slice::from_raw_parts(tensor.strides as *const i64, ndim) };
618
619 // Compute expected C-order strides: last dim = 1, each preceding = product of later dims.
620 let mut expected = 1_i64;
621 for i in (0..ndim).rev() {
622 if strides[i] != expected {
623 return Err(DlpackError::NonContiguous);
624 }
625 expected *= shape[i];
626 }
627 Ok(())
628}
629
630/// Validate that a [`DLTensor`] is compatible with PyTorch DLPack interop.
631///
632/// Checks:
633/// - `data` pointer is non-null,
634/// - device type is CPU (`DLDeviceType::Cpu`),
635/// - memory layout is C-order contiguous (null strides or row-major strides),
636/// - dtype code is float (code 2),
637/// - dtype bits are either 32 or 64.
638///
639/// Returns `Ok(())` on success, or a [`DlpackError`] describing the first
640/// unsatisfied constraint.
641///
642/// # Examples
643///
644/// ```
645/// use scirs2_numpy::dlpack::{dlpack_from_slice, validate_torch_dlpack_tensor};
646///
647/// let data = vec![1.0_f64, 2.0, 3.0];
648/// let shape = vec![3_i64];
649/// let tensor = dlpack_from_slice(&data, &shape);
650/// assert!(validate_torch_dlpack_tensor(&tensor).is_ok());
651/// ```
652pub fn validate_torch_dlpack_tensor(tensor: &DLTensor) -> Result<(), DlpackError> {
653 // 1. Null pointer check.
654 if tensor.data.is_null() {
655 return Err(DlpackError::NullPointer);
656 }
657 // 2. CPU device required.
658 if tensor.device.device_type != DLDeviceType::Cpu as i32 {
659 return Err(DlpackError::NonCpuDevice);
660 }
661 // 3. Contiguous memory layout required.
662 check_tensor_contiguous(tensor)?;
663 // 4. Float dtype required (code 2), 32- or 64-bit.
664 if tensor.dtype.code != DLDataTypeCode::Float as u8
665 || (tensor.dtype.bits != 32 && tensor.dtype.bits != 64)
666 {
667 return Err(DlpackError::UnsupportedDtype {
668 code: tensor.dtype.code,
669 bits: tensor.dtype.bits,
670 });
671 }
672 Ok(())
673}
674
675/// Convert a raw `DLTensor` pointer (from a PyTorch DLPack capsule) to an
676/// `ndarray` dynamic array view of `f32` elements.
677///
678/// # Safety
679///
680/// The caller must guarantee that:
681/// - `tensor` is a valid, aligned, non-null pointer to a `DLTensor`.
682/// - The tensor's `data` field points to at least `product(shape) * 4` valid
683/// bytes of `f32` values in native byte order.
684/// - The tensor (and its data) remains live and unmodified for the lifetime
685/// of the returned view `'a`.
686///
687/// # Errors
688///
689/// Returns [`DlpackError`] if the tensor is not a 32-bit float CPU tensor.
690pub unsafe fn dlarray_from_torch_f32<'a>(
691 tensor: *const DLTensor,
692) -> Result<ndarray::ArrayViewD<'a, f32>, DlpackError> {
693 // SAFETY: caller guarantees tensor is valid and non-null.
694 let t = unsafe { &*tensor };
695 validate_torch_dlpack_tensor(t)?;
696 // Must be f32 (32 bits).
697 if t.dtype.bits != 32 {
698 return Err(DlpackError::UnsupportedDtype {
699 code: t.dtype.code,
700 bits: t.dtype.bits,
701 });
702 }
703 // Build shape vector.
704 let shape = build_shape_vec(t);
705 // Compute element count.
706 let n_elems: usize = shape.iter().product();
707 // SAFETY: data is valid, non-null, aligned, and lives for 'a (caller contract).
708 let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f32 };
709 let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
710 ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
711 DlpackError::UnsupportedDtype {
712 code: t.dtype.code,
713 bits: t.dtype.bits,
714 }
715 })
716}
717
718/// Convert a raw `DLTensor` pointer (from a PyTorch DLPack capsule) to an
719/// `ndarray` dynamic array view of `f64` elements.
720///
721/// # Safety
722///
723/// Same invariants as [`dlarray_from_torch_f32`], but the data must be 64-bit
724/// floating-point (`f64`).
725///
726/// # Errors
727///
728/// Returns [`DlpackError`] if the tensor is not a 64-bit float CPU tensor.
729pub unsafe fn dlarray_from_torch_f64<'a>(
730 tensor: *const DLTensor,
731) -> Result<ndarray::ArrayViewD<'a, f64>, DlpackError> {
732 // SAFETY: caller guarantees tensor is valid and non-null.
733 let t = unsafe { &*tensor };
734 validate_torch_dlpack_tensor(t)?;
735 // Must be f64 (64 bits).
736 if t.dtype.bits != 64 {
737 return Err(DlpackError::UnsupportedDtype {
738 code: t.dtype.code,
739 bits: t.dtype.bits,
740 });
741 }
742 // Build shape vector.
743 let shape = build_shape_vec(t);
744 // Compute element count.
745 let n_elems: usize = shape.iter().product();
746 // SAFETY: data is valid, non-null, aligned, and lives for 'a (caller contract).
747 let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f64 };
748 let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
749 ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
750 DlpackError::UnsupportedDtype {
751 code: t.dtype.code,
752 bits: t.dtype.bits,
753 }
754 })
755}
756
757/// Validate that a [`DLTensor`] is compatible with JAX DLPack interop.
758///
759/// JAX supports CPU, GPU (CUDA), and TPU tensors. Unlike the PyTorch validator
760/// this function accepts non-CPU devices; use [`jax_device_type`] afterwards to
761/// inspect which device the tensor lives on.
762///
763/// Returns `Ok(())` if the tensor has a non-null data pointer and a recognised
764/// float dtype (32- or 64-bit).
765///
766/// # Examples
767///
768/// ```
769/// use scirs2_numpy::dlpack::{dlpack_from_slice, validate_jax_dlpack_tensor};
770///
771/// let data = vec![1.0_f32, 2.0, 3.0];
772/// let shape = vec![3_i64];
773/// // Build a float-32 CPU tensor for testing.
774/// let mut tensor = dlpack_from_slice(
775/// &[1.0_f64, 2.0, 3.0],
776/// &[3_i64],
777/// );
778/// // Adjust to f32
779/// tensor.dtype.bits = 32;
780/// tensor.data = data.as_ptr() as *mut std::ffi::c_void;
781/// assert!(validate_jax_dlpack_tensor(&tensor).is_ok());
782/// ```
783pub fn validate_jax_dlpack_tensor(tensor: &DLTensor) -> Result<(), DlpackError> {
784 // 1. Null pointer check.
785 if tensor.data.is_null() {
786 return Err(DlpackError::NullPointer);
787 }
788 // 2. Float dtype required (code 2), 32- or 64-bit.
789 if tensor.dtype.code != DLDataTypeCode::Float as u8
790 || (tensor.dtype.bits != 32 && tensor.dtype.bits != 64)
791 {
792 return Err(DlpackError::UnsupportedDtype {
793 code: tensor.dtype.code,
794 bits: tensor.dtype.bits,
795 });
796 }
797 Ok(())
798}
799
800/// Classify the device reported by a [`DLTensor`] as a [`JaxDeviceType`].
801///
802/// Returns `Some(JaxDeviceType)` for recognised JAX device codes, or `None`
803/// for unrecognised codes.
804///
805/// | Code | Device |
806/// |------|--------|
807/// | 1 | CPU |
808/// | 2 | GPU (CUDA) |
809/// | 13 | TPU (JAX extension) |
810///
811/// # Examples
812///
813/// ```
814/// use scirs2_numpy::dlpack::{dlpack_from_slice, jax_device_type, JaxDeviceType};
815///
816/// let data = vec![0.0_f64; 4];
817/// let shape = vec![4_i64];
818/// let tensor = dlpack_from_slice(&data, &shape);
819/// assert_eq!(jax_device_type(&tensor), Some(JaxDeviceType::Cpu));
820/// ```
821pub fn jax_device_type(tensor: &DLTensor) -> Option<JaxDeviceType> {
822 match tensor.device.device_type {
823 1 => Some(JaxDeviceType::Cpu),
824 2 => Some(JaxDeviceType::Gpu),
825 DL_DEVICE_TYPE_TPU => Some(JaxDeviceType::Tpu),
826 _ => None,
827 }
828}
829
830/// Generic DLPack array construction — accepts `f32` tensors from any framework.
831///
832/// Builds an `ndarray` view backed by the tensor's data pointer. Only CPU
833/// tensors are supported; non-CPU tensors return [`DlpackError::NonCpuDevice`].
834///
835/// # Safety
836///
837/// The caller must guarantee that:
838/// - `tensor` is a valid, aligned, non-null pointer to a `DLTensor`.
839/// - The tensor's `data` field points to at least `product(shape) * 4` bytes of
840/// valid `f32` values in native byte order.
841/// - The tensor remains live and unmodified for the lifetime of the returned
842/// view `'a`.
843///
844/// # Errors
845///
846/// Returns [`DlpackError`] if the tensor is not CPU-resident or not `f32`.
847pub unsafe fn array_from_dlpack_f32<'a>(
848 tensor: *const DLTensor,
849) -> Result<ndarray::ArrayViewD<'a, f32>, DlpackError> {
850 // SAFETY: caller guarantees tensor is valid and non-null.
851 let t = unsafe { &*tensor };
852 if t.data.is_null() {
853 return Err(DlpackError::NullPointer);
854 }
855 if t.device.device_type != DLDeviceType::Cpu as i32 {
856 return Err(DlpackError::NonCpuDevice);
857 }
858 if t.dtype.code != DLDataTypeCode::Float as u8 || t.dtype.bits != 32 {
859 return Err(DlpackError::UnsupportedDtype {
860 code: t.dtype.code,
861 bits: t.dtype.bits,
862 });
863 }
864 let shape = build_shape_vec(t);
865 let n_elems: usize = shape.iter().product();
866 let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f32 };
867 let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
868 ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
869 DlpackError::UnsupportedDtype {
870 code: t.dtype.code,
871 bits: t.dtype.bits,
872 }
873 })
874}
875
876/// Generic DLPack array construction — accepts `f64` tensors from any framework.
877///
878/// Same as [`array_from_dlpack_f32`] but for 64-bit float tensors.
879///
880/// # Safety
881///
882/// Same invariants as [`array_from_dlpack_f32`], but data must be `f64`.
883///
884/// # Errors
885///
886/// Returns [`DlpackError`] if the tensor is not CPU-resident or not `f64`.
887pub unsafe fn array_from_dlpack_f64<'a>(
888 tensor: *const DLTensor,
889) -> Result<ndarray::ArrayViewD<'a, f64>, DlpackError> {
890 // SAFETY: caller guarantees tensor is valid and non-null.
891 let t = unsafe { &*tensor };
892 if t.data.is_null() {
893 return Err(DlpackError::NullPointer);
894 }
895 if t.device.device_type != DLDeviceType::Cpu as i32 {
896 return Err(DlpackError::NonCpuDevice);
897 }
898 if t.dtype.code != DLDataTypeCode::Float as u8 || t.dtype.bits != 64 {
899 return Err(DlpackError::UnsupportedDtype {
900 code: t.dtype.code,
901 bits: t.dtype.bits,
902 });
903 }
904 let shape = build_shape_vec(t);
905 let n_elems: usize = shape.iter().product();
906 let base = unsafe { (t.data as *const u8).add(t.byte_offset as usize) as *const f64 };
907 let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
908 ndarray::ArrayViewD::from_shape(shape.as_slice(), slice).map_err(|_| {
909 DlpackError::UnsupportedDtype {
910 code: t.dtype.code,
911 bits: t.dtype.bits,
912 }
913 })
914}
915
916/// Build a shape `Vec<usize>` from the `ndim` / `shape` fields of a [`DLTensor`].
917///
918/// Returns an empty vector for zero-dimensional tensors.
919///
920/// # Safety
921///
922/// `tensor.shape` must be valid for `tensor.ndim` elements.
923fn build_shape_vec(tensor: &DLTensor) -> Vec<usize> {
924 if tensor.ndim <= 0 || tensor.shape.is_null() {
925 Vec::new()
926 } else {
927 // SAFETY: shape ptr is valid for ndim elements (caller contract).
928 let raw =
929 unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
930 raw.iter().map(|&d| d as usize).collect()
931 }
932}
933
934// ─── Tests ───────────────────────────────────────────────────────────────────
935
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a raw `DLTensor` describing `data` as a scalar float tensor of
    /// `bits` width with the given shape, strides pointer, and device type.
    ///
    /// The returned tensor borrows `data` and `shape` through raw pointers,
    /// so both must outlive every use of the tensor.
    fn raw_float_tensor(
        data: *mut c_void,
        shape: &[i64],
        bits: u8,
        strides: *mut i64,
        device_type: i32,
    ) -> DLTensor {
        DLTensor {
            data,
            device: DLDevice {
                device_type,
                device_id: 0,
            },
            // DLPack's ndim is a 32-bit int; test shapes are tiny, so the
            // cast cannot truncate.
            ndim: shape.len() as i32,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides,
            byte_offset: 0,
        }
    }

    // --- validate_dlpack_tensor ---

    #[test]
    fn test_validate_valid_f64_cpu_tensor() {
        // `data` and `shape` are owned locals that live to the end of the
        // test, outliving the raw pointers held by `tensor`.
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);

        let info = validate_dlpack_tensor(&tensor).expect("validate_dlpack_tensor failed");
        assert_eq!(info.shape, vec![2, 3]);
        assert_eq!(info.dtype_code, DLDataTypeCode::Float);
        assert_eq!(info.dtype_bits, 64);
        assert_eq!(info.device_type, DLDeviceType::Cpu);
    }

    #[test]
    fn test_validate_null_pointer_returns_err() {
        let shape = vec![3_i64];
        let mut tensor = dlpack_from_slice(&[0.0_f64; 3], &shape);
        // Forcibly set data to null to test the null-pointer guard.
        tensor.data = std::ptr::null_mut();
        let result = validate_dlpack_tensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NullPointer)),
            "expected NullPointer error"
        );
    }

    #[test]
    fn test_validate_shape_fields() {
        let data = vec![10.0_f64; 12];
        let shape = vec![3_i64, 4];
        let tensor = dlpack_from_slice(&data, &shape);
        let info = validate_dlpack_tensor(&tensor).expect("validate failed");
        assert_eq!(info.shape, vec![3, 4]);
    }

    // --- dlpack_from_slice ---

    #[test]
    fn test_dlpack_from_slice_shape_fields() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let shape = vec![3_i64];
        let tensor = dlpack_from_slice(&data, &shape);

        assert_eq!(tensor.ndim, 1);
        assert!(!tensor.data.is_null());
        assert!(!tensor.shape.is_null());
        // dtype must be float64
        assert_eq!(tensor.dtype.code, 2); // Float
        assert_eq!(tensor.dtype.bits, 64);
    }

    #[test]
    fn test_dlpack_from_slice_2d() {
        let data = vec![0.0_f64; 6];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);
        assert_eq!(tensor.ndim, 2);
        // SAFETY: shape is valid for ndim=2.
        let s = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, 2) };
        assert_eq!(s, [2, 3]);
    }

    // --- dlpack_to_vec_f64 ---

    #[test]
    fn test_dlpack_to_vec_f64_round_trip() {
        let original = vec![1.0_f64, 2.5, 3.15, -7.0, 0.0];
        let shape = vec![5_i64];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_dlpack_to_vec_f64_2d() {
        let original: Vec<f64> = (0..6).map(|i| i as f64).collect();
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    #[test]
    fn test_dlpack_to_vec_f64_null_pointer_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.data = std::ptr::null_mut();

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::NullPointer)
        ));
    }

    #[test]
    fn test_dlpack_to_vec_f64_non_cpu_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.device.device_type = DLDeviceType::Cuda as i32;

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::NonCpuDevice)
        ));
    }

    #[test]
    fn test_dlpack_to_vec_f64_wrong_dtype_err() {
        let data = vec![0.0_f64];
        let shape = vec![1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.dtype.code = 0; // Int, not Float

        assert!(matches!(
            dlpack_to_vec_f64(&tensor),
            Err(DlpackError::UnsupportedDtype { .. })
        ));
    }

    // --- DLDataTypeCode ---

    #[test]
    fn test_dtype_code_try_from() {
        assert_eq!(DLDataTypeCode::try_from(0u8).unwrap(), DLDataTypeCode::Int);
        assert_eq!(DLDataTypeCode::try_from(1u8).unwrap(), DLDataTypeCode::UInt);
        assert_eq!(
            DLDataTypeCode::try_from(2u8).unwrap(),
            DLDataTypeCode::Float
        );
        assert_eq!(
            DLDataTypeCode::try_from(3u8).unwrap(),
            DLDataTypeCode::BFloat
        );
        assert!(DLDataTypeCode::try_from(99u8).is_err());
    }

    // ─── Item 1: PyTorch tensor interop tests ────────────────────────────────

    #[test]
    fn dlpack_device_type_cpu_is_1() {
        assert_eq!(DLDeviceType::Cpu as i32, 1);
    }

    #[test]
    fn dlpack_dtype_float32_code_is_2() {
        // DLPack spec: kDLFloat = 2
        assert_eq!(DLDataTypeCode::Float as u8, 2);
    }

    #[test]
    fn dlpack_validate_non_contiguous_fails() {
        // Build a 2-D tensor with non-row-major strides to simulate a transposed
        // PyTorch tensor. Shape = [2, 3], but strides = [1, 2] (column-major).
        let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = [2_i64, 3];
        // Column-major strides (Fortran order): stride[0]=1, stride[1]=2
        let strides = [1_i64, 2];
        let tensor = raw_float_tensor(
            data.as_ptr() as *mut c_void,
            &shape,
            32,
            strides.as_ptr() as *mut i64,
            DLDeviceType::Cpu as i32,
        );
        assert!(
            matches!(
                validate_torch_dlpack_tensor(&tensor),
                Err(DlpackError::NonContiguous)
            ),
            "expected NonContiguous error for column-major strides"
        );
    }

    #[test]
    fn dlpack_validate_2d_float_tensor_passes() {
        let data = [1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = [2_i64, 3];
        // Null strides mean "compact row-major" per the DLPack spec.
        let tensor = raw_float_tensor(
            data.as_ptr() as *mut c_void,
            &shape,
            32,
            std::ptr::null_mut(),
            DLDeviceType::Cpu as i32,
        );
        assert!(validate_torch_dlpack_tensor(&tensor).is_ok());
    }

    #[test]
    fn dlpack_validate_non_cpu_tensor_fails() {
        let data = [1.0_f32; 4];
        let shape = [4_i64];
        let tensor = raw_float_tensor(
            data.as_ptr() as *mut c_void,
            &shape,
            32,
            std::ptr::null_mut(),
            DLDeviceType::Cuda as i32,
        );
        assert!(matches!(
            validate_torch_dlpack_tensor(&tensor),
            Err(DlpackError::NonCpuDevice)
        ));
    }

    #[test]
    fn dlarray_from_torch_f32_round_trip() {
        let data = [1.0_f32, 2.0, 3.0, 4.0];
        let shape = [2_i64, 2];
        let tensor = raw_float_tensor(
            data.as_ptr() as *mut c_void,
            &shape,
            32,
            std::ptr::null_mut(),
            DLDeviceType::Cpu as i32,
        );
        // SAFETY: tensor is valid, data and shape are alive.
        let view = unsafe { dlarray_from_torch_f32(&tensor as *const DLTensor) }
            .expect("dlarray_from_torch_f32 failed");
        assert_eq!(view.shape(), &[2, 2]);
        assert_eq!(view[[0, 0]], 1.0_f32);
        assert_eq!(view[[1, 1]], 4.0_f32);
    }

    #[test]
    fn dlarray_from_torch_f64_round_trip() {
        let data = [10.0_f64, 20.0, 30.0];
        let shape = [3_i64];
        let tensor = raw_float_tensor(
            data.as_ptr() as *mut c_void,
            &shape,
            64,
            std::ptr::null_mut(),
            DLDeviceType::Cpu as i32,
        );
        // SAFETY: tensor is valid, data and shape are alive.
        let view = unsafe { dlarray_from_torch_f64(&tensor as *const DLTensor) }
            .expect("dlarray_from_torch_f64 failed");
        assert_eq!(view.shape(), &[3]);
        assert_eq!(view[2], 30.0_f64);
    }

    // ─── Item 2: JAX array interop tests ─────────────────────────────────────

    #[test]
    fn dlpack_jax_cpu_tensor_valid() {
        let data = [1.0_f64, 2.0, 3.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        assert!(validate_jax_dlpack_tensor(&tensor).is_ok());
        assert_eq!(jax_device_type(&tensor), Some(JaxDeviceType::Cpu));
    }

    #[test]
    fn dlpack_jax_tpu_device_recognized() {
        let data = [1.0_f64];
        let shape = [1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.device.device_type = DL_DEVICE_TYPE_TPU;
        // JAX validator does not require CPU — only float dtype.
        assert!(validate_jax_dlpack_tensor(&tensor).is_ok());
        assert_eq!(jax_device_type(&tensor), Some(JaxDeviceType::Tpu));
    }

    #[test]
    fn dlpack_generic_from_dlpack_handles_both_torch_and_jax() {
        // f32 torch-style tensor
        let data_f32 = [0.5_f32, 1.5, 2.5, 3.5];
        let shape = [2_i64, 2];
        let tensor_f32 = raw_float_tensor(
            data_f32.as_ptr() as *mut c_void,
            &shape,
            32,
            std::ptr::null_mut(),
            DLDeviceType::Cpu as i32,
        );
        // SAFETY: tensor_f32 is valid; data_f32 and shape are alive.
        let view_f32 = unsafe { array_from_dlpack_f32(&tensor_f32 as *const DLTensor) }
            .expect("array_from_dlpack_f32 failed");
        assert_eq!(view_f32.shape(), &[2, 2]);

        // f64 jax-style tensor (same CPU device)
        let data_f64 = [1.0_f64, 2.0, 3.0, 4.0];
        let shape_f64 = [4_i64];
        let tensor_f64 = dlpack_from_slice(&data_f64, &shape_f64);
        // SAFETY: tensor_f64 is valid; data_f64 and shape_f64 are alive.
        let view_f64 = unsafe { array_from_dlpack_f64(&tensor_f64 as *const DLTensor) }
            .expect("array_from_dlpack_f64 failed");
        assert_eq!(view_f64.shape(), &[4]);
        assert_eq!(view_f64[3], 4.0_f64);
    }
}