Skip to main content

edgefirst_tflite/
tensor.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Type-safe tensor wrappers for the TensorFlow Lite C API.
5//!
6//! This module provides [`Tensor`] (immutable) and [`TensorMut`] (mutable)
7//! views over the raw `TfLiteTensor` pointers returned by the C API. Both
8//! types expose shape introspection, quantization parameters, and typed
9//! data access via slices.
10//!
11//! # Tensor types
12//!
13//! The [`TensorType`] enum mirrors the `TfLiteType` constants from the C
14//! header, providing a safe Rust-side representation that can be pattern
15//! matched.
16//!
17//! # Data access
18//!
19//! Use [`Tensor::as_slice`] for read-only access and
20//! [`TensorMut::as_mut_slice`] or [`TensorMut::copy_from_slice`] for
21//! write access to the underlying tensor buffer.
22
23use std::ffi::CStr;
24use std::fmt;
25use std::ptr::NonNull;
26
27use edgefirst_tflite_sys::{
28    self as sys, TfLiteTensor, TfLiteType_kTfLiteBFloat16, TfLiteType_kTfLiteBool,
29    TfLiteType_kTfLiteComplex128, TfLiteType_kTfLiteComplex64, TfLiteType_kTfLiteFloat16,
30    TfLiteType_kTfLiteFloat32, TfLiteType_kTfLiteFloat64, TfLiteType_kTfLiteInt16,
31    TfLiteType_kTfLiteInt32, TfLiteType_kTfLiteInt4, TfLiteType_kTfLiteInt64,
32    TfLiteType_kTfLiteInt8, TfLiteType_kTfLiteNoType, TfLiteType_kTfLiteResource,
33    TfLiteType_kTfLiteString, TfLiteType_kTfLiteUInt16, TfLiteType_kTfLiteUInt32,
34    TfLiteType_kTfLiteUInt64, TfLiteType_kTfLiteUInt8, TfLiteType_kTfLiteVariant,
35};
36use num_traits::FromPrimitive;
37
38use crate::error::{Error, Result};
39
40// ---------------------------------------------------------------------------
41// TensorType
42// ---------------------------------------------------------------------------
43
/// Element data type of a TensorFlow Lite tensor.
///
/// Each variant corresponds to a `kTfLite*` constant from the C API header
/// `common.h`. Every variant's discriminant is set to its C constant, so the
/// derived [`FromPrimitive`] implementation maps a raw C type value back to
/// the matching variant and returns `None` for values that have no variant
/// here (callers such as `tensor_type()` fall back to [`TensorType::NoType`]).
///
/// # Example
///
/// ```ignore
/// let ty = tensor.tensor_type();
/// match ty {
///     TensorType::Float32 => println!("32-bit float tensor"),
///     TensorType::UInt8   => println!("quantized uint8 tensor"),
///     _ => println!("other type: {ty:?}"),
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, num_derive::FromPrimitive)]
#[repr(isize)]
#[allow(clippy::cast_possible_wrap)] // C constants are small u32 values; no wrap on any target.
pub enum TensorType {
    /// No type information (`kTfLiteNoType`).
    NoType = TfLiteType_kTfLiteNoType as isize,
    /// 32-bit IEEE 754 float (`kTfLiteFloat32`).
    Float32 = TfLiteType_kTfLiteFloat32 as isize,
    /// 32-bit signed integer (`kTfLiteInt32`).
    Int32 = TfLiteType_kTfLiteInt32 as isize,
    /// 8-bit unsigned integer (`kTfLiteUInt8`).
    UInt8 = TfLiteType_kTfLiteUInt8 as isize,
    /// 64-bit signed integer (`kTfLiteInt64`).
    Int64 = TfLiteType_kTfLiteInt64 as isize,
    /// Variable-length string (`kTfLiteString`).
    String = TfLiteType_kTfLiteString as isize,
    /// Boolean (`kTfLiteBool`).
    Bool = TfLiteType_kTfLiteBool as isize,
    /// 16-bit signed integer (`kTfLiteInt16`).
    Int16 = TfLiteType_kTfLiteInt16 as isize,
    /// 64-bit complex float (`kTfLiteComplex64`).
    Complex64 = TfLiteType_kTfLiteComplex64 as isize,
    /// 8-bit signed integer (`kTfLiteInt8`).
    Int8 = TfLiteType_kTfLiteInt8 as isize,
    /// 16-bit IEEE 754 half-precision float (`kTfLiteFloat16`).
    Float16 = TfLiteType_kTfLiteFloat16 as isize,
    /// 64-bit IEEE 754 double-precision float (`kTfLiteFloat64`).
    Float64 = TfLiteType_kTfLiteFloat64 as isize,
    /// 128-bit complex float (`kTfLiteComplex128`).
    Complex128 = TfLiteType_kTfLiteComplex128 as isize,
    /// 64-bit unsigned integer (`kTfLiteUInt64`).
    UInt64 = TfLiteType_kTfLiteUInt64 as isize,
    /// Resource handle (`kTfLiteResource`).
    Resource = TfLiteType_kTfLiteResource as isize,
    /// Variant type (`kTfLiteVariant`).
    Variant = TfLiteType_kTfLiteVariant as isize,
    /// 32-bit unsigned integer (`kTfLiteUInt32`).
    UInt32 = TfLiteType_kTfLiteUInt32 as isize,
    /// 16-bit unsigned integer (`kTfLiteUInt16`).
    UInt16 = TfLiteType_kTfLiteUInt16 as isize,
    /// 4-bit signed integer (`kTfLiteInt4`).
    Int4 = TfLiteType_kTfLiteInt4 as isize,
    /// Brain floating-point 16-bit (`kTfLiteBFloat16`).
    BFloat16 = TfLiteType_kTfLiteBFloat16 as isize,
}
105
106// ---------------------------------------------------------------------------
107// QuantizationParams
108// ---------------------------------------------------------------------------
109
/// Affine quantization parameters for a tensor.
///
/// Quantized values can be converted back to floating point using:
///
/// ```text
/// real_value = scale * (quantized_value - zero_point)
/// ```
///
/// [`QuantizationParams::dequantize`] applies this formula to a single value.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct QuantizationParams {
    /// Scale factor for dequantization.
    pub scale: f32,
    /// Zero-point offset for dequantization.
    pub zero_point: i32,
}

impl QuantizationParams {
    /// Converts one quantized value back to floating point using
    /// `scale * (quantized - zero_point)`.
    ///
    /// The subtraction is performed in `i64`, so extreme inputs (for example
    /// `i32::MIN` with a positive zero point) cannot overflow.
    #[must_use]
    pub fn dequantize(&self, quantized: i32) -> f32 {
        // The difference of two `i32` values always fits in `i64`.
        let offset = i64::from(quantized) - i64::from(self.zero_point);
        // Rounding to `f32` is inherent to producing a float result.
        #[allow(clippy::cast_precision_loss)]
        let offset = offset as f32;
        self.scale * offset
    }
}
124
125// ---------------------------------------------------------------------------
126// Tensor (immutable view)
127// ---------------------------------------------------------------------------
128
/// An immutable view of a TensorFlow Lite tensor.
///
/// `Tensor` stores the raw C tensor pointer together with a borrow of the
/// dynamically loaded library handle for lifetime `'a`. It provides
/// read-only access to tensor metadata (name, shape, type, quantization
/// parameters) and data.
///
/// Only the `lib` field carries the lifetime `'a`; `ptr` is a raw pointer
/// whose validity is not tracked by the type system. Because it holds a raw
/// pointer, `Tensor` is neither `Send` nor `Sync`.
///
/// Use [`Tensor::as_slice`] to obtain a typed slice over the tensor data.
// NOTE(review): presumably the constructor ties `'a` to the interpreter that
// owns the tensor so `ptr` cannot outlive it — confirm at the construction site.
pub struct Tensor<'a> {
    /// Raw pointer to the C `TfLiteTensor`.
    ///
    /// This is a raw `*const` pointer (not `NonNull`) because the C API
    /// returns `*const TfLiteTensor` for output tensors.
    pub(crate) ptr: *const TfLiteTensor,

    /// Reference to the dynamically loaded `TFLite` C library.
    pub(crate) lib: &'a sys::tensorflowlite_c,
}
146
147impl Tensor<'_> {
148    /// Returns the element data type of this tensor.
149    ///
150    /// If the C API returns a type value not represented by [`TensorType`],
151    /// this method defaults to [`TensorType::NoType`].
152    #[must_use]
153    pub fn tensor_type(&self) -> TensorType {
154        // SAFETY: `self.ptr` is a valid tensor pointer obtained from the
155        // interpreter and `self.lib` is a valid reference to the loaded library.
156        let raw = unsafe { self.lib.TfLiteTensorType(self.ptr) };
157        FromPrimitive::from_u32(raw).unwrap_or(TensorType::NoType)
158    }
159
160    /// Returns the name of this tensor as a string slice.
161    ///
162    /// Returns `"<invalid-utf8>"` if the C API returns a name that is not
163    /// valid UTF-8.
164    #[must_use]
165    pub fn name(&self) -> &str {
166        // SAFETY: `self.ptr` is a valid tensor pointer; the C API returns a
167        // NUL-terminated string that lives as long as the tensor.
168        unsafe { CStr::from_ptr(self.lib.TfLiteTensorName(self.ptr)) }
169            .to_str()
170            .unwrap_or("<invalid-utf8>")
171    }
172
173    /// Returns the number of dimensions (rank) of this tensor.
174    ///
175    /// # Errors
176    ///
177    /// Returns an error if the tensor does not have its dimensions set
178    /// (the C API returns -1).
179    pub fn num_dims(&self) -> Result<usize> {
180        // SAFETY: `self.ptr` is a valid tensor pointer.
181        let n = unsafe { self.lib.TfLiteTensorNumDims(self.ptr) };
182        usize::try_from(n).map_err(|_| {
183            Error::invalid_argument(format!(
184                "tensor `{}` does not have dimensions set",
185                self.name()
186            ))
187        })
188    }
189
190    /// Returns the size of the `index`-th dimension.
191    ///
192    /// # Errors
193    ///
194    /// Returns an error if `index` is out of bounds (>= `num_dims`).
195    pub fn dim(&self, index: usize) -> Result<usize> {
196        let num_dims = self.num_dims()?;
197        if index >= num_dims {
198            return Err(Error::invalid_argument(format!(
199                "dimension index {index} out of bounds for tensor with {num_dims} dimensions"
200            )));
201        }
202        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
203        let i = index as i32;
204        // SAFETY: `self.ptr` is valid and `i` is bounds-checked above.
205        let d = unsafe { self.lib.TfLiteTensorDim(self.ptr, i) };
206        // `d` is non-negative because the C API guarantees valid dimension
207        // sizes for in-bounds indices.
208        #[allow(clippy::cast_sign_loss)]
209        Ok(d as usize)
210    }
211
212    /// Returns the full shape of this tensor as a `Vec<usize>`.
213    ///
214    /// # Errors
215    ///
216    /// Returns an error if the tensor dimensions are not set.
217    pub fn shape(&self) -> Result<Vec<usize>> {
218        let num_dims = self.num_dims()?;
219        let mut dims = Vec::with_capacity(num_dims);
220        for i in 0..num_dims {
221            dims.push(self.dim(i)?);
222        }
223        Ok(dims)
224    }
225
226    /// Returns the total number of bytes required to store this tensor's data.
227    #[must_use]
228    pub fn byte_size(&self) -> usize {
229        // SAFETY: `self.ptr` is a valid tensor pointer.
230        unsafe { self.lib.TfLiteTensorByteSize(self.ptr) }
231    }
232
233    /// Returns the total number of elements in this tensor (product of all
234    /// dimensions).
235    ///
236    /// # Errors
237    ///
238    /// Returns an error if the tensor dimensions are not set.
239    pub fn volume(&self) -> Result<usize> {
240        Ok(self.shape()?.iter().product::<usize>())
241    }
242
243    /// Returns the affine quantization parameters for this tensor.
244    #[must_use]
245    pub fn quantization_params(&self) -> QuantizationParams {
246        // SAFETY: `self.ptr` is a valid tensor pointer.
247        let params = unsafe { self.lib.TfLiteTensorQuantizationParams(self.ptr) };
248        QuantizationParams {
249            scale: params.scale,
250            zero_point: params.zero_point,
251        }
252    }
253
254    /// Returns an immutable slice over the tensor data, interpreted as
255    /// elements of type `T`.
256    ///
257    /// The slice length equals [`Tensor::volume`]. The caller must ensure
258    /// that `T` matches the tensor's actual element type (e.g., `f32` for
259    /// a `Float32` tensor, `u8` for a `UInt8` tensor).
260    ///
261    /// # Errors
262    ///
263    /// Returns an error if:
264    /// - `size_of::<T>() * volume` exceeds [`Tensor::byte_size`]
265    /// - The underlying data pointer is null (tensor not yet allocated)
266    pub fn as_slice<T: Copy>(&self) -> Result<&[T]> {
267        let volume = self.volume()?;
268        if std::mem::size_of::<T>() * volume > self.byte_size() {
269            return Err(Error::invalid_argument(format!(
270                "tensor byte size {} is too small for {} elements of {}",
271                self.byte_size(),
272                volume,
273                std::any::type_name::<T>(),
274            )));
275        }
276        // SAFETY: `self.ptr` is a valid tensor pointer.
277        let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr) };
278        if ptr.is_null() {
279            return Err(Error::null_pointer("TfLiteTensorData returned null"));
280        }
281        // SAFETY: `ptr` is non-null and points to at least `volume * size_of::<T>()`
282        // bytes (checked above). The data is valid for reads for the tensor's lifetime
283        // which is tied to the interpreter borrow. `T: Copy` ensures no drop glue.
284        Ok(unsafe { std::slice::from_raw_parts(ptr.cast::<T>(), volume) })
285    }
286}
287
288/// Formats the tensor as `"name: 1x224x224x3 Float32"`.
289impl fmt::Debug for Tensor<'_> {
290    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291        write_tensor_debug(
292            f,
293            self.name(),
294            self.num_dims(),
295            |i| self.dim(i),
296            self.tensor_type(),
297        )
298    }
299}
300
301/// Displays the tensor as `"name: 1x224x224x3 Float32"`.
302impl fmt::Display for Tensor<'_> {
303    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304        write_tensor_debug(
305            f,
306            self.name(),
307            self.num_dims(),
308            |i| self.dim(i),
309            self.tensor_type(),
310        )
311    }
312}
313
314// ---------------------------------------------------------------------------
315// TensorMut (mutable view)
316// ---------------------------------------------------------------------------
317
/// A mutable view of a TensorFlow Lite tensor.
///
/// `TensorMut` provides all the read-only operations of [`Tensor`] plus
/// mutable data access via [`TensorMut::as_mut_slice`] and
/// [`TensorMut::copy_from_slice`].
///
/// The pointer is stored as [`NonNull`] because the C API returns
/// `*mut TfLiteTensor` for input tensors, which must be non-null after
/// successful interpreter creation.
///
/// Only the `lib` field carries the lifetime `'a`; the validity of `ptr` is
/// not tracked by the type system. Because it holds a raw pointer,
/// `TensorMut` is neither `Send` nor `Sync`.
// NOTE(review): presumably the constructor ties `'a` to a mutable borrow of
// the owning interpreter so `ptr` cannot outlive it — confirm at the
// construction site.
pub struct TensorMut<'a> {
    /// Non-null pointer to the C `TfLiteTensor`.
    pub(crate) ptr: NonNull<TfLiteTensor>,

    /// Reference to the dynamically loaded `TFLite` C library.
    pub(crate) lib: &'a sys::tensorflowlite_c,
}
334
335impl TensorMut<'_> {
336    /// Returns the element data type of this tensor.
337    ///
338    /// If the C API returns a type value not represented by [`TensorType`],
339    /// this method defaults to [`TensorType::NoType`].
340    #[must_use]
341    pub fn tensor_type(&self) -> TensorType {
342        // SAFETY: `self.ptr` is a valid non-null tensor pointer obtained from
343        // the interpreter and `self.lib` is a valid reference to the loaded library.
344        let raw = unsafe { self.lib.TfLiteTensorType(self.ptr.as_ptr()) };
345        FromPrimitive::from_u32(raw).unwrap_or(TensorType::NoType)
346    }
347
348    /// Returns the name of this tensor as a string slice.
349    ///
350    /// Returns `"<invalid-utf8>"` if the C API returns a name that is not
351    /// valid UTF-8.
352    #[must_use]
353    pub fn name(&self) -> &str {
354        // SAFETY: `self.ptr` is a valid tensor pointer; the C API returns a
355        // NUL-terminated string that lives as long as the tensor.
356        unsafe { CStr::from_ptr(self.lib.TfLiteTensorName(self.ptr.as_ptr())) }
357            .to_str()
358            .unwrap_or("<invalid-utf8>")
359    }
360
361    /// Returns the number of dimensions (rank) of this tensor.
362    ///
363    /// # Errors
364    ///
365    /// Returns an error if the tensor does not have its dimensions set
366    /// (the C API returns -1).
367    pub fn num_dims(&self) -> Result<usize> {
368        // SAFETY: `self.ptr` is a valid tensor pointer.
369        let n = unsafe { self.lib.TfLiteTensorNumDims(self.ptr.as_ptr()) };
370        usize::try_from(n).map_err(|_| {
371            Error::invalid_argument(format!(
372                "tensor `{}` does not have dimensions set",
373                self.name()
374            ))
375        })
376    }
377
378    /// Returns the size of the `index`-th dimension.
379    ///
380    /// # Errors
381    ///
382    /// Returns an error if `index` is out of bounds (>= `num_dims`).
383    pub fn dim(&self, index: usize) -> Result<usize> {
384        let num_dims = self.num_dims()?;
385        if index >= num_dims {
386            return Err(Error::invalid_argument(format!(
387                "dimension index {index} out of bounds for tensor with {num_dims} dimensions"
388            )));
389        }
390        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
391        let i = index as i32;
392        // SAFETY: `self.ptr` is valid and `i` is bounds-checked above.
393        let d = unsafe { self.lib.TfLiteTensorDim(self.ptr.as_ptr(), i) };
394        // `d` is non-negative because the C API guarantees valid dimension
395        // sizes for in-bounds indices.
396        #[allow(clippy::cast_sign_loss)]
397        Ok(d as usize)
398    }
399
400    /// Returns the full shape of this tensor as a `Vec<usize>`.
401    ///
402    /// # Errors
403    ///
404    /// Returns an error if the tensor dimensions are not set.
405    pub fn shape(&self) -> Result<Vec<usize>> {
406        let num_dims = self.num_dims()?;
407        let mut dims = Vec::with_capacity(num_dims);
408        for i in 0..num_dims {
409            dims.push(self.dim(i)?);
410        }
411        Ok(dims)
412    }
413
414    /// Returns the total number of bytes required to store this tensor's data.
415    #[must_use]
416    pub fn byte_size(&self) -> usize {
417        // SAFETY: `self.ptr` is a valid tensor pointer.
418        unsafe { self.lib.TfLiteTensorByteSize(self.ptr.as_ptr()) }
419    }
420
421    /// Returns the total number of elements in this tensor (product of all
422    /// dimensions).
423    ///
424    /// # Errors
425    ///
426    /// Returns an error if the tensor dimensions are not set.
427    pub fn volume(&self) -> Result<usize> {
428        Ok(self.shape()?.iter().product::<usize>())
429    }
430
431    /// Returns the affine quantization parameters for this tensor.
432    #[must_use]
433    pub fn quantization_params(&self) -> QuantizationParams {
434        // SAFETY: `self.ptr` is a valid tensor pointer.
435        let params = unsafe { self.lib.TfLiteTensorQuantizationParams(self.ptr.as_ptr()) };
436        QuantizationParams {
437            scale: params.scale,
438            zero_point: params.zero_point,
439        }
440    }
441
442    /// Returns an immutable slice over the tensor data, interpreted as
443    /// elements of type `T`.
444    ///
445    /// The slice length equals [`TensorMut::volume`]. The caller must
446    /// ensure that `T` matches the tensor's actual element type (e.g.,
447    /// `f32` for a `Float32` tensor, `u8` for a `UInt8` tensor).
448    ///
449    /// # Errors
450    ///
451    /// Returns an error if:
452    /// - `size_of::<T>() * volume` exceeds [`TensorMut::byte_size`]
453    /// - The underlying data pointer is null (tensor not yet allocated)
454    pub fn as_slice<T: Copy>(&self) -> Result<&[T]> {
455        let volume = self.volume()?;
456        if std::mem::size_of::<T>() * volume > self.byte_size() {
457            return Err(Error::invalid_argument(format!(
458                "tensor byte size {} is too small for {} elements of {}",
459                self.byte_size(),
460                volume,
461                std::any::type_name::<T>(),
462            )));
463        }
464        // SAFETY: `self.ptr` is a valid tensor pointer.
465        let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr.as_ptr()) };
466        if ptr.is_null() {
467            return Err(Error::null_pointer("TfLiteTensorData returned null"));
468        }
469        // SAFETY: `ptr` is non-null and points to at least `volume * size_of::<T>()`
470        // bytes (checked above). The data is valid for reads for the tensor's lifetime
471        // which is tied to the interpreter borrow. `T: Copy` ensures no drop glue.
472        Ok(unsafe { std::slice::from_raw_parts(ptr.cast::<T>(), volume) })
473    }
474
475    /// Returns a mutable slice over the tensor data, interpreted as elements
476    /// of type `T`.
477    ///
478    /// The slice length equals [`TensorMut::volume`]. The caller must
479    /// ensure that `T` matches the tensor's actual element type (e.g.,
480    /// `f32` for a `Float32` tensor, `u8` for a `UInt8` tensor).
481    ///
482    /// # Errors
483    ///
484    /// Returns an error if:
485    /// - `size_of::<T>() * volume` exceeds [`TensorMut::byte_size`]
486    /// - The underlying data pointer is null (tensor not yet allocated)
487    pub fn as_mut_slice<T: Copy>(&mut self) -> Result<&mut [T]> {
488        let volume = self.volume()?;
489        if std::mem::size_of::<T>() * volume > self.byte_size() {
490            return Err(Error::invalid_argument(format!(
491                "tensor byte size {} is too small for {} elements of {}",
492                self.byte_size(),
493                volume,
494                std::any::type_name::<T>(),
495            )));
496        }
497        // SAFETY: `self.ptr` is a valid tensor pointer.
498        let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr.as_ptr()) };
499        if ptr.is_null() {
500            return Err(Error::null_pointer("TfLiteTensorData returned null"));
501        }
502        // SAFETY: `ptr` is non-null and points to at least `volume * size_of::<T>()`
503        // bytes (checked above). We hold `&mut self` ensuring exclusive access.
504        // `T: Copy` ensures no drop glue.
505        Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), volume) })
506    }
507
508    /// Copies the contents of `data` into this tensor's buffer.
509    ///
510    /// This is a convenience wrapper around [`TensorMut::as_mut_slice`] that
511    /// copies elements from the provided slice into the tensor.
512    ///
513    /// # Errors
514    ///
515    /// Returns an error if:
516    /// - The tensor cannot be mapped as a mutable slice of `T`
517    /// - `data.len()` does not match [`TensorMut::volume`]
518    pub fn copy_from_slice<T: Copy>(&mut self, data: &[T]) -> Result<()> {
519        let slice = self.as_mut_slice::<T>()?;
520        if data.len() != slice.len() {
521            return Err(Error::invalid_argument(format!(
522                "data length {} does not match tensor volume {}",
523                data.len(),
524                slice.len(),
525            )));
526        }
527        slice.copy_from_slice(data);
528        Ok(())
529    }
530}
531
532/// Formats the tensor as `"name: 1x224x224x3 Float32"`.
533impl fmt::Debug for TensorMut<'_> {
534    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535        write_tensor_debug(
536            f,
537            self.name(),
538            self.num_dims(),
539            |i| self.dim(i),
540            self.tensor_type(),
541        )
542    }
543}
544
545/// Displays the tensor as `"name: 1x224x224x3 Float32"`.
546impl fmt::Display for TensorMut<'_> {
547    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548        write_tensor_debug(
549            f,
550            self.name(),
551            self.num_dims(),
552            |i| self.dim(i),
553            self.tensor_type(),
554        )
555    }
556}
557
558// ---------------------------------------------------------------------------
559// Shared formatting helper
560// ---------------------------------------------------------------------------
561
562/// Writes the common tensor representation: `"name: 1x224x224x3 Float32"`.
563///
564/// Used by both `Tensor` and `TensorMut` `Debug` and `Display` implementations
565/// to avoid code duplication.
566fn write_tensor_debug(
567    f: &mut fmt::Formatter<'_>,
568    name: &str,
569    num_dims: Result<usize>,
570    dim_fn: impl Fn(usize) -> Result<usize>,
571    tensor_type: TensorType,
572) -> fmt::Result {
573    let num_dims = num_dims.unwrap_or(0);
574    write!(f, "{name}: ")?;
575    for i in 0..num_dims {
576        if i > 0 {
577            f.write_str("x")?;
578        }
579        write!(f, "{}", dim_fn(i).unwrap_or(0))?;
580    }
581    write!(f, " {tensor_type:?}")
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    /// Every `TensorType` variant paired with the `kTfLite*` value it is
    /// expected to have in the C header.
    const ALL_CASES: &[(isize, TensorType)] = &[
        (0, TensorType::NoType),
        (1, TensorType::Float32),
        (2, TensorType::Int32),
        (3, TensorType::UInt8),
        (4, TensorType::Int64),
        (5, TensorType::String),
        (6, TensorType::Bool),
        (7, TensorType::Int16),
        (8, TensorType::Complex64),
        (9, TensorType::Int8),
        (10, TensorType::Float16),
        (11, TensorType::Float64),
        (12, TensorType::Complex128),
        (13, TensorType::UInt64),
        (14, TensorType::Resource),
        (15, TensorType::Variant),
        (16, TensorType::UInt32),
        (17, TensorType::UInt16),
        (18, TensorType::Int4),
        (19, TensorType::BFloat16),
    ];

    // -----------------------------------------------------------------------
    // TensorType -- FromPrimitive conversion
    // -----------------------------------------------------------------------

    #[test]
    fn tensor_type_from_primitive_all_variants() {
        for &(raw, expected) in ALL_CASES {
            let result = TensorType::from_isize(raw);
            assert_eq!(
                result,
                Some(expected),
                "TensorType::from_isize({raw}) should be Some({expected:?})"
            );
        }
    }

    #[test]
    fn tensor_type_discriminants_match_table() {
        for &(raw, variant) in ALL_CASES {
            assert_eq!(
                variant as isize, raw,
                "{variant:?} has an unexpected discriminant"
            );
        }
    }

    #[test]
    fn tensor_type_from_u32_all_variants() {
        // `tensor_type()` converts via `from_u32`, so check the exact mapping,
        // not merely that some variant is produced.
        for &(raw, expected) in ALL_CASES {
            let raw = u32::try_from(raw).expect("table discriminants are non-negative");
            assert_eq!(
                TensorType::from_u32(raw),
                Some(expected),
                "TensorType::from_u32({raw}) should be Some({expected:?})"
            );
        }
    }

    #[test]
    fn tensor_type_unknown_value_returns_none() {
        assert_eq!(TensorType::from_isize(999), None);
        assert_eq!(TensorType::from_u32(999), None);
        assert_eq!(TensorType::from_isize(-1), None);
        assert_eq!(TensorType::from_isize(20), None);
    }

    // -----------------------------------------------------------------------
    // TensorType -- Clone, PartialEq, Hash
    // -----------------------------------------------------------------------

    #[test]
    fn tensor_type_clone() {
        let original = TensorType::Float32;
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn tensor_type_partial_eq() {
        assert_eq!(TensorType::Int8, TensorType::Int8);
        assert_ne!(TensorType::Int8, TensorType::UInt8);
    }

    #[test]
    fn tensor_type_hash() {
        let mut set = HashSet::new();
        set.insert(TensorType::Float32);
        set.insert(TensorType::Float32);
        set.insert(TensorType::Int32);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn tensor_type_all_variants_unique_in_hashset() {
        let set: HashSet<_> = ALL_CASES.iter().map(|&(_, ty)| ty).collect();
        assert_eq!(set.len(), ALL_CASES.len());
        assert_eq!(set.len(), 20);
    }

    // -----------------------------------------------------------------------
    // TensorType -- Debug formatting
    // -----------------------------------------------------------------------

    #[test]
    fn tensor_type_debug_format() {
        assert_eq!(format!("{:?}", TensorType::Float32), "Float32");
        assert_eq!(format!("{:?}", TensorType::NoType), "NoType");
        assert_eq!(format!("{:?}", TensorType::BFloat16), "BFloat16");
        assert_eq!(format!("{:?}", TensorType::Complex128), "Complex128");
    }

    // -----------------------------------------------------------------------
    // QuantizationParams -- construction and field access
    // -----------------------------------------------------------------------

    #[test]
    fn quantization_params_construction() {
        let params = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        assert!((params.scale - 0.5).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, 128);
    }

    #[test]
    fn quantization_params_zero_values() {
        let params = QuantizationParams {
            scale: 0.0,
            zero_point: 0,
        };
        assert!((params.scale - 0.0).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, 0);
    }

    #[test]
    fn quantization_params_negative_zero_point() {
        let params = QuantizationParams {
            scale: 0.007_812_5,
            zero_point: -128,
        };
        assert!((params.scale - 0.007_812_5).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, -128);
    }

    // -----------------------------------------------------------------------
    // QuantizationParams -- Debug, Clone, PartialEq
    // -----------------------------------------------------------------------

    #[test]
    fn quantization_params_debug() {
        let params = QuantizationParams {
            scale: 1.0,
            zero_point: 0,
        };
        let debug = format!("{params:?}");
        assert!(debug.contains("QuantizationParams"));
        assert!(debug.contains("scale"));
        assert!(debug.contains("zero_point"));
    }

    #[test]
    fn quantization_params_clone() {
        let original = QuantizationParams {
            scale: 0.25,
            zero_point: 64,
        };
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn quantization_params_partial_eq() {
        let a = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        let b = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        let c = QuantizationParams {
            scale: 0.25,
            zero_point: 128,
        };
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    // -----------------------------------------------------------------------
    // write_tensor_debug -- shared Debug/Display formatting
    // -----------------------------------------------------------------------

    /// Stand-in that drives `write_tensor_debug` without a live interpreter.
    /// `dims: None` simulates a tensor whose dimensions are not set.
    struct FakeTensor {
        name: &'static str,
        dims: Option<Vec<usize>>,
        ty: TensorType,
    }

    impl fmt::Display for FakeTensor {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let num_dims = self
                .dims
                .as_ref()
                .map(Vec::len)
                .ok_or_else(|| Error::invalid_argument(String::from("dimensions not set")));
            write_tensor_debug(
                f,
                self.name,
                num_dims,
                |i| Ok(self.dims.as_ref().map_or(0, |d| d[i])),
                self.ty,
            )
        }
    }

    #[test]
    fn write_tensor_debug_formats_shape_and_type() {
        let t = FakeTensor {
            name: "input",
            dims: Some(vec![1, 224, 224, 3]),
            ty: TensorType::Float32,
        };
        assert_eq!(t.to_string(), "input: 1x224x224x3 Float32");
    }

    #[test]
    fn write_tensor_debug_single_dimension_has_no_separator() {
        let t = FakeTensor {
            name: "logits",
            dims: Some(vec![10]),
            ty: TensorType::UInt8,
        };
        assert_eq!(t.to_string(), "logits: 10 UInt8");
    }

    #[test]
    fn write_tensor_debug_without_dims_keeps_name_and_type() {
        let t = FakeTensor {
            name: "x",
            dims: None,
            ty: TensorType::Int8,
        };
        assert_eq!(t.to_string(), "x:  Int8");
    }
}