edgefirst_tflite/tensor.rs
1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 Au-Zone Technologies. All Rights Reserved.
3
4//! Type-safe tensor wrappers for the TensorFlow Lite C API.
5//!
6//! This module provides [`Tensor`] (immutable) and [`TensorMut`] (mutable)
7//! views over the raw `TfLiteTensor` pointers returned by the C API. Both
8//! types expose shape introspection, quantization parameters, and typed
9//! data access via slices.
10//!
11//! # Tensor types
12//!
13//! The [`TensorType`] enum mirrors the `TfLiteType` constants from the C
14//! header, providing a safe Rust-side representation that can be pattern
15//! matched.
16//!
17//! # Data access
18//!
19//! Use [`Tensor::as_slice`] for read-only access and
20//! [`TensorMut::as_mut_slice`] or [`TensorMut::copy_from_slice`] for
21//! write access to the underlying tensor buffer.
22
23use std::ffi::CStr;
24use std::fmt;
25use std::ptr::NonNull;
26
27use edgefirst_tflite_sys::{
28 self as sys, TfLiteTensor, TfLiteType_kTfLiteBFloat16, TfLiteType_kTfLiteBool,
29 TfLiteType_kTfLiteComplex128, TfLiteType_kTfLiteComplex64, TfLiteType_kTfLiteFloat16,
30 TfLiteType_kTfLiteFloat32, TfLiteType_kTfLiteFloat64, TfLiteType_kTfLiteInt16,
31 TfLiteType_kTfLiteInt32, TfLiteType_kTfLiteInt4, TfLiteType_kTfLiteInt64,
32 TfLiteType_kTfLiteInt8, TfLiteType_kTfLiteNoType, TfLiteType_kTfLiteResource,
33 TfLiteType_kTfLiteString, TfLiteType_kTfLiteUInt16, TfLiteType_kTfLiteUInt32,
34 TfLiteType_kTfLiteUInt64, TfLiteType_kTfLiteUInt8, TfLiteType_kTfLiteVariant,
35};
36use num_traits::FromPrimitive;
37
38use crate::error::{Error, Result};
39
40// ---------------------------------------------------------------------------
41// TensorType
42// ---------------------------------------------------------------------------
43
44/// Element data type of a TensorFlow Lite tensor.
45///
46/// Each variant corresponds to a `kTfLite*` constant from the C API header
47/// `common.h`. The discriminant values match the C constants so that
48/// conversion via [`FromPrimitive`] is a zero-cost identity check.
49///
50/// # Example
51///
52/// ```ignore
53/// let ty = tensor.tensor_type();
54/// match ty {
55/// TensorType::Float32 => println!("32-bit float tensor"),
56/// TensorType::UInt8 => println!("quantized uint8 tensor"),
57/// _ => println!("other type: {ty:?}"),
58/// }
59/// ```
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, num_derive::FromPrimitive)]
61#[repr(isize)]
62#[allow(clippy::cast_possible_wrap)] // C constants are small u32 values; no wrap on any target.
63pub enum TensorType {
64 /// No type information (`kTfLiteNoType`).
65 NoType = TfLiteType_kTfLiteNoType as isize,
66 /// 32-bit IEEE 754 float (`kTfLiteFloat32`).
67 Float32 = TfLiteType_kTfLiteFloat32 as isize,
68 /// 32-bit signed integer (`kTfLiteInt32`).
69 Int32 = TfLiteType_kTfLiteInt32 as isize,
70 /// 8-bit unsigned integer (`kTfLiteUInt8`).
71 UInt8 = TfLiteType_kTfLiteUInt8 as isize,
72 /// 64-bit signed integer (`kTfLiteInt64`).
73 Int64 = TfLiteType_kTfLiteInt64 as isize,
74 /// Variable-length string (`kTfLiteString`).
75 String = TfLiteType_kTfLiteString as isize,
76 /// Boolean (`kTfLiteBool`).
77 Bool = TfLiteType_kTfLiteBool as isize,
78 /// 16-bit signed integer (`kTfLiteInt16`).
79 Int16 = TfLiteType_kTfLiteInt16 as isize,
80 /// 64-bit complex float (`kTfLiteComplex64`).
81 Complex64 = TfLiteType_kTfLiteComplex64 as isize,
82 /// 8-bit signed integer (`kTfLiteInt8`).
83 Int8 = TfLiteType_kTfLiteInt8 as isize,
84 /// 16-bit IEEE 754 half-precision float (`kTfLiteFloat16`).
85 Float16 = TfLiteType_kTfLiteFloat16 as isize,
86 /// 64-bit IEEE 754 double-precision float (`kTfLiteFloat64`).
87 Float64 = TfLiteType_kTfLiteFloat64 as isize,
88 /// 128-bit complex float (`kTfLiteComplex128`).
89 Complex128 = TfLiteType_kTfLiteComplex128 as isize,
90 /// 64-bit unsigned integer (`kTfLiteUInt64`).
91 UInt64 = TfLiteType_kTfLiteUInt64 as isize,
92 /// Resource handle (`kTfLiteResource`).
93 Resource = TfLiteType_kTfLiteResource as isize,
94 /// Variant type (`kTfLiteVariant`).
95 Variant = TfLiteType_kTfLiteVariant as isize,
96 /// 32-bit unsigned integer (`kTfLiteUInt32`).
97 UInt32 = TfLiteType_kTfLiteUInt32 as isize,
98 /// 16-bit unsigned integer (`kTfLiteUInt16`).
99 UInt16 = TfLiteType_kTfLiteUInt16 as isize,
100 /// 4-bit signed integer (`kTfLiteInt4`).
101 Int4 = TfLiteType_kTfLiteInt4 as isize,
102 /// Brain floating-point 16-bit (`kTfLiteBFloat16`).
103 BFloat16 = TfLiteType_kTfLiteBFloat16 as isize,
104}
105
106// ---------------------------------------------------------------------------
107// QuantizationParams
108// ---------------------------------------------------------------------------
109
/// Affine (scale / zero-point) quantization parameters of a tensor.
///
/// A quantized value `q` maps back to a real value as:
///
/// ```text
/// real_value = scale * (q - zero_point)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct QuantizationParams {
    /// Real-valued step between two adjacent quantized values.
    pub scale: f32,
    /// Quantized value that represents a real value of zero.
    pub zero_point: i32,
}
124
125// ---------------------------------------------------------------------------
126// Tensor (immutable view)
127// ---------------------------------------------------------------------------
128
129/// An immutable view of a TensorFlow Lite tensor.
130///
131/// `Tensor` borrows the underlying C tensor pointer and the dynamically
132/// loaded library handle for the duration of its lifetime `'a`. It provides
133/// read-only access to tensor metadata (name, shape, type) and data.
134///
135/// Use [`Tensor::as_slice`] to obtain a typed slice over the tensor data.
136pub struct Tensor<'a> {
137 /// Raw pointer to the C `TfLiteTensor`.
138 ///
139 /// This is a raw `*const` pointer (not `NonNull`) because the C API
140 /// returns `*const TfLiteTensor` for output tensors.
141 pub(crate) ptr: *const TfLiteTensor,
142
143 /// Reference to the dynamically loaded `TFLite` C library.
144 pub(crate) lib: &'a sys::tensorflowlite_c,
145}
146
147impl Tensor<'_> {
148 /// Returns the element data type of this tensor.
149 ///
150 /// If the C API returns a type value not represented by [`TensorType`],
151 /// this method defaults to [`TensorType::NoType`].
152 #[must_use]
153 pub fn tensor_type(&self) -> TensorType {
154 // SAFETY: `self.ptr` is a valid tensor pointer obtained from the
155 // interpreter and `self.lib` is a valid reference to the loaded library.
156 let raw = unsafe { self.lib.TfLiteTensorType(self.ptr) };
157 FromPrimitive::from_u32(raw).unwrap_or(TensorType::NoType)
158 }
159
160 /// Returns the name of this tensor as a string slice.
161 ///
162 /// Returns `"<invalid-utf8>"` if the C API returns a name that is not
163 /// valid UTF-8.
164 #[must_use]
165 pub fn name(&self) -> &str {
166 // SAFETY: `self.ptr` is a valid tensor pointer; the C API returns a
167 // NUL-terminated string that lives as long as the tensor.
168 unsafe { CStr::from_ptr(self.lib.TfLiteTensorName(self.ptr)) }
169 .to_str()
170 .unwrap_or("<invalid-utf8>")
171 }
172
173 /// Returns the number of dimensions (rank) of this tensor.
174 ///
175 /// # Errors
176 ///
177 /// Returns an error if the tensor does not have its dimensions set
178 /// (the C API returns -1).
179 pub fn num_dims(&self) -> Result<usize> {
180 // SAFETY: `self.ptr` is a valid tensor pointer.
181 let n = unsafe { self.lib.TfLiteTensorNumDims(self.ptr) };
182 usize::try_from(n).map_err(|_| {
183 Error::invalid_argument(format!(
184 "tensor `{}` does not have dimensions set",
185 self.name()
186 ))
187 })
188 }
189
190 /// Returns the size of the `index`-th dimension.
191 ///
192 /// # Errors
193 ///
194 /// Returns an error if `index` is out of bounds (>= `num_dims`).
195 pub fn dim(&self, index: usize) -> Result<usize> {
196 let num_dims = self.num_dims()?;
197 if index >= num_dims {
198 return Err(Error::invalid_argument(format!(
199 "dimension index {index} out of bounds for tensor with {num_dims} dimensions"
200 )));
201 }
202 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
203 let i = index as i32;
204 // SAFETY: `self.ptr` is valid and `i` is bounds-checked above.
205 let d = unsafe { self.lib.TfLiteTensorDim(self.ptr, i) };
206 // `d` is non-negative because the C API guarantees valid dimension
207 // sizes for in-bounds indices.
208 #[allow(clippy::cast_sign_loss)]
209 Ok(d as usize)
210 }
211
212 /// Returns the full shape of this tensor as a `Vec<usize>`.
213 ///
214 /// # Errors
215 ///
216 /// Returns an error if the tensor dimensions are not set.
217 pub fn shape(&self) -> Result<Vec<usize>> {
218 let num_dims = self.num_dims()?;
219 let mut dims = Vec::with_capacity(num_dims);
220 for i in 0..num_dims {
221 dims.push(self.dim(i)?);
222 }
223 Ok(dims)
224 }
225
226 /// Returns the total number of bytes required to store this tensor's data.
227 #[must_use]
228 pub fn byte_size(&self) -> usize {
229 // SAFETY: `self.ptr` is a valid tensor pointer.
230 unsafe { self.lib.TfLiteTensorByteSize(self.ptr) }
231 }
232
233 /// Returns the total number of elements in this tensor (product of all
234 /// dimensions).
235 ///
236 /// # Errors
237 ///
238 /// Returns an error if the tensor dimensions are not set.
239 pub fn volume(&self) -> Result<usize> {
240 Ok(self.shape()?.iter().product::<usize>())
241 }
242
243 /// Returns the affine quantization parameters for this tensor.
244 #[must_use]
245 pub fn quantization_params(&self) -> QuantizationParams {
246 // SAFETY: `self.ptr` is a valid tensor pointer.
247 let params = unsafe { self.lib.TfLiteTensorQuantizationParams(self.ptr) };
248 QuantizationParams {
249 scale: params.scale,
250 zero_point: params.zero_point,
251 }
252 }
253
254 /// Returns an immutable slice over the tensor data, interpreted as
255 /// elements of type `T`.
256 ///
257 /// The slice length equals [`Tensor::volume`]. The caller must ensure
258 /// that `T` matches the tensor's actual element type (e.g., `f32` for
259 /// a `Float32` tensor, `u8` for a `UInt8` tensor).
260 ///
261 /// # Errors
262 ///
263 /// Returns an error if:
264 /// - `size_of::<T>() * volume` exceeds [`Tensor::byte_size`]
265 /// - The underlying data pointer is null (tensor not yet allocated)
266 pub fn as_slice<T: Copy>(&self) -> Result<&[T]> {
267 let volume = self.volume()?;
268 if std::mem::size_of::<T>() * volume > self.byte_size() {
269 return Err(Error::invalid_argument(format!(
270 "tensor byte size {} is too small for {} elements of {}",
271 self.byte_size(),
272 volume,
273 std::any::type_name::<T>(),
274 )));
275 }
276 // SAFETY: `self.ptr` is a valid tensor pointer.
277 let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr) };
278 if ptr.is_null() {
279 return Err(Error::null_pointer("TfLiteTensorData returned null"));
280 }
281 // SAFETY: `ptr` is non-null and points to at least `volume * size_of::<T>()`
282 // bytes (checked above). The data is valid for reads for the tensor's lifetime
283 // which is tied to the interpreter borrow. `T: Copy` ensures no drop glue.
284 Ok(unsafe { std::slice::from_raw_parts(ptr.cast::<T>(), volume) })
285 }
286}
287
288/// Formats the tensor as `"name: 1x224x224x3 Float32"`.
289impl fmt::Debug for Tensor<'_> {
290 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
291 write_tensor_debug(
292 f,
293 self.name(),
294 self.num_dims(),
295 |i| self.dim(i),
296 self.tensor_type(),
297 )
298 }
299}
300
301/// Displays the tensor as `"name: 1x224x224x3 Float32"`.
302impl fmt::Display for Tensor<'_> {
303 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
304 write_tensor_debug(
305 f,
306 self.name(),
307 self.num_dims(),
308 |i| self.dim(i),
309 self.tensor_type(),
310 )
311 }
312}
313
314// ---------------------------------------------------------------------------
315// TensorMut (mutable view)
316// ---------------------------------------------------------------------------
317
318/// A mutable view of a TensorFlow Lite tensor.
319///
320/// `TensorMut` provides all the read-only operations of [`Tensor`] plus
321/// mutable data access via [`TensorMut::as_mut_slice`] and
322/// [`TensorMut::copy_from_slice`].
323///
324/// The pointer is stored as [`NonNull`] because the C API returns
325/// `*mut TfLiteTensor` for input tensors, which must be non-null after
326/// successful interpreter creation.
327pub struct TensorMut<'a> {
328 /// Non-null pointer to the C `TfLiteTensor`.
329 pub(crate) ptr: NonNull<TfLiteTensor>,
330
331 /// Reference to the dynamically loaded `TFLite` C library.
332 pub(crate) lib: &'a sys::tensorflowlite_c,
333}
334
335impl TensorMut<'_> {
336 /// Returns the element data type of this tensor.
337 ///
338 /// If the C API returns a type value not represented by [`TensorType`],
339 /// this method defaults to [`TensorType::NoType`].
340 #[must_use]
341 pub fn tensor_type(&self) -> TensorType {
342 // SAFETY: `self.ptr` is a valid non-null tensor pointer obtained from
343 // the interpreter and `self.lib` is a valid reference to the loaded library.
344 let raw = unsafe { self.lib.TfLiteTensorType(self.ptr.as_ptr()) };
345 FromPrimitive::from_u32(raw).unwrap_or(TensorType::NoType)
346 }
347
348 /// Returns the name of this tensor as a string slice.
349 ///
350 /// Returns `"<invalid-utf8>"` if the C API returns a name that is not
351 /// valid UTF-8.
352 #[must_use]
353 pub fn name(&self) -> &str {
354 // SAFETY: `self.ptr` is a valid tensor pointer; the C API returns a
355 // NUL-terminated string that lives as long as the tensor.
356 unsafe { CStr::from_ptr(self.lib.TfLiteTensorName(self.ptr.as_ptr())) }
357 .to_str()
358 .unwrap_or("<invalid-utf8>")
359 }
360
361 /// Returns the number of dimensions (rank) of this tensor.
362 ///
363 /// # Errors
364 ///
365 /// Returns an error if the tensor does not have its dimensions set
366 /// (the C API returns -1).
367 pub fn num_dims(&self) -> Result<usize> {
368 // SAFETY: `self.ptr` is a valid tensor pointer.
369 let n = unsafe { self.lib.TfLiteTensorNumDims(self.ptr.as_ptr()) };
370 usize::try_from(n).map_err(|_| {
371 Error::invalid_argument(format!(
372 "tensor `{}` does not have dimensions set",
373 self.name()
374 ))
375 })
376 }
377
378 /// Returns the size of the `index`-th dimension.
379 ///
380 /// # Errors
381 ///
382 /// Returns an error if `index` is out of bounds (>= `num_dims`).
383 pub fn dim(&self, index: usize) -> Result<usize> {
384 let num_dims = self.num_dims()?;
385 if index >= num_dims {
386 return Err(Error::invalid_argument(format!(
387 "dimension index {index} out of bounds for tensor with {num_dims} dimensions"
388 )));
389 }
390 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
391 let i = index as i32;
392 // SAFETY: `self.ptr` is valid and `i` is bounds-checked above.
393 let d = unsafe { self.lib.TfLiteTensorDim(self.ptr.as_ptr(), i) };
394 // `d` is non-negative because the C API guarantees valid dimension
395 // sizes for in-bounds indices.
396 #[allow(clippy::cast_sign_loss)]
397 Ok(d as usize)
398 }
399
400 /// Returns the full shape of this tensor as a `Vec<usize>`.
401 ///
402 /// # Errors
403 ///
404 /// Returns an error if the tensor dimensions are not set.
405 pub fn shape(&self) -> Result<Vec<usize>> {
406 let num_dims = self.num_dims()?;
407 let mut dims = Vec::with_capacity(num_dims);
408 for i in 0..num_dims {
409 dims.push(self.dim(i)?);
410 }
411 Ok(dims)
412 }
413
414 /// Returns the total number of bytes required to store this tensor's data.
415 #[must_use]
416 pub fn byte_size(&self) -> usize {
417 // SAFETY: `self.ptr` is a valid tensor pointer.
418 unsafe { self.lib.TfLiteTensorByteSize(self.ptr.as_ptr()) }
419 }
420
421 /// Returns the total number of elements in this tensor (product of all
422 /// dimensions).
423 ///
424 /// # Errors
425 ///
426 /// Returns an error if the tensor dimensions are not set.
427 pub fn volume(&self) -> Result<usize> {
428 Ok(self.shape()?.iter().product::<usize>())
429 }
430
431 /// Returns the affine quantization parameters for this tensor.
432 #[must_use]
433 pub fn quantization_params(&self) -> QuantizationParams {
434 // SAFETY: `self.ptr` is a valid tensor pointer.
435 let params = unsafe { self.lib.TfLiteTensorQuantizationParams(self.ptr.as_ptr()) };
436 QuantizationParams {
437 scale: params.scale,
438 zero_point: params.zero_point,
439 }
440 }
441
442 /// Returns an immutable slice over the tensor data, interpreted as
443 /// elements of type `T`.
444 ///
445 /// The slice length equals [`TensorMut::volume`]. The caller must
446 /// ensure that `T` matches the tensor's actual element type (e.g.,
447 /// `f32` for a `Float32` tensor, `u8` for a `UInt8` tensor).
448 ///
449 /// # Errors
450 ///
451 /// Returns an error if:
452 /// - `size_of::<T>() * volume` exceeds [`TensorMut::byte_size`]
453 /// - The underlying data pointer is null (tensor not yet allocated)
454 pub fn as_slice<T: Copy>(&self) -> Result<&[T]> {
455 let volume = self.volume()?;
456 if std::mem::size_of::<T>() * volume > self.byte_size() {
457 return Err(Error::invalid_argument(format!(
458 "tensor byte size {} is too small for {} elements of {}",
459 self.byte_size(),
460 volume,
461 std::any::type_name::<T>(),
462 )));
463 }
464 // SAFETY: `self.ptr` is a valid tensor pointer.
465 let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr.as_ptr()) };
466 if ptr.is_null() {
467 return Err(Error::null_pointer("TfLiteTensorData returned null"));
468 }
469 // SAFETY: `ptr` is non-null and points to at least `volume * size_of::<T>()`
470 // bytes (checked above). The data is valid for reads for the tensor's lifetime
471 // which is tied to the interpreter borrow. `T: Copy` ensures no drop glue.
472 Ok(unsafe { std::slice::from_raw_parts(ptr.cast::<T>(), volume) })
473 }
474
475 /// Returns a mutable slice over the tensor data, interpreted as elements
476 /// of type `T`.
477 ///
478 /// The slice length equals [`TensorMut::volume`]. The caller must
479 /// ensure that `T` matches the tensor's actual element type (e.g.,
480 /// `f32` for a `Float32` tensor, `u8` for a `UInt8` tensor).
481 ///
482 /// # Errors
483 ///
484 /// Returns an error if:
485 /// - `size_of::<T>() * volume` exceeds [`TensorMut::byte_size`]
486 /// - The underlying data pointer is null (tensor not yet allocated)
487 pub fn as_mut_slice<T: Copy>(&mut self) -> Result<&mut [T]> {
488 let volume = self.volume()?;
489 if std::mem::size_of::<T>() * volume > self.byte_size() {
490 return Err(Error::invalid_argument(format!(
491 "tensor byte size {} is too small for {} elements of {}",
492 self.byte_size(),
493 volume,
494 std::any::type_name::<T>(),
495 )));
496 }
497 // SAFETY: `self.ptr` is a valid tensor pointer.
498 let ptr = unsafe { self.lib.TfLiteTensorData(self.ptr.as_ptr()) };
499 if ptr.is_null() {
500 return Err(Error::null_pointer("TfLiteTensorData returned null"));
501 }
502 // SAFETY: `ptr` is non-null and points to at least `volume * size_of::<T>()`
503 // bytes (checked above). We hold `&mut self` ensuring exclusive access.
504 // `T: Copy` ensures no drop glue.
505 Ok(unsafe { std::slice::from_raw_parts_mut(ptr.cast::<T>(), volume) })
506 }
507
508 /// Copies the contents of `data` into this tensor's buffer.
509 ///
510 /// This is a convenience wrapper around [`TensorMut::as_mut_slice`] that
511 /// copies elements from the provided slice into the tensor.
512 ///
513 /// # Errors
514 ///
515 /// Returns an error if:
516 /// - The tensor cannot be mapped as a mutable slice of `T`
517 /// - `data.len()` does not match [`TensorMut::volume`]
518 pub fn copy_from_slice<T: Copy>(&mut self, data: &[T]) -> Result<()> {
519 let slice = self.as_mut_slice::<T>()?;
520 if data.len() != slice.len() {
521 return Err(Error::invalid_argument(format!(
522 "data length {} does not match tensor volume {}",
523 data.len(),
524 slice.len(),
525 )));
526 }
527 slice.copy_from_slice(data);
528 Ok(())
529 }
530}
531
532/// Formats the tensor as `"name: 1x224x224x3 Float32"`.
533impl fmt::Debug for TensorMut<'_> {
534 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
535 write_tensor_debug(
536 f,
537 self.name(),
538 self.num_dims(),
539 |i| self.dim(i),
540 self.tensor_type(),
541 )
542 }
543}
544
545/// Displays the tensor as `"name: 1x224x224x3 Float32"`.
546impl fmt::Display for TensorMut<'_> {
547 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
548 write_tensor_debug(
549 f,
550 self.name(),
551 self.num_dims(),
552 |i| self.dim(i),
553 self.tensor_type(),
554 )
555 }
556}
557
558// ---------------------------------------------------------------------------
559// Shared formatting helper
560// ---------------------------------------------------------------------------
561
562/// Writes the common tensor representation: `"name: 1x224x224x3 Float32"`.
563///
564/// Used by both `Tensor` and `TensorMut` `Debug` and `Display` implementations
565/// to avoid code duplication.
566fn write_tensor_debug(
567 f: &mut fmt::Formatter<'_>,
568 name: &str,
569 num_dims: Result<usize>,
570 dim_fn: impl Fn(usize) -> Result<usize>,
571 tensor_type: TensorType,
572) -> fmt::Result {
573 let num_dims = num_dims.unwrap_or(0);
574 write!(f, "{name}: ")?;
575 for i in 0..num_dims {
576 if i > 0 {
577 f.write_str("x")?;
578 }
579 write!(f, "{}", dim_fn(i).unwrap_or(0))?;
580 }
581 write!(f, " {tensor_type:?}")
582}
583
#[cfg(test)]
mod tests {
    use super::*;

    use std::collections::HashSet;

    // -----------------------------------------------------------------------
    // TensorType -- FromPrimitive conversion
    // -----------------------------------------------------------------------

    #[test]
    fn tensor_type_from_primitive_all_variants() {
        let cases: &[(isize, TensorType)] = &[
            (0, TensorType::NoType),
            (1, TensorType::Float32),
            (2, TensorType::Int32),
            (3, TensorType::UInt8),
            (4, TensorType::Int64),
            (5, TensorType::String),
            (6, TensorType::Bool),
            (7, TensorType::Int16),
            (8, TensorType::Complex64),
            (9, TensorType::Int8),
            (10, TensorType::Float16),
            (11, TensorType::Float64),
            (12, TensorType::Complex128),
            (13, TensorType::UInt64),
            (14, TensorType::Resource),
            (15, TensorType::Variant),
            (16, TensorType::UInt32),
            (17, TensorType::UInt16),
            (18, TensorType::Int4),
            (19, TensorType::BFloat16),
        ];

        for &(raw, expected) in cases {
            let result = TensorType::from_isize(raw);
            assert_eq!(
                result,
                Some(expected),
                "TensorType::from_isize({raw}) should be Some({expected:?})"
            );
        }
    }

    #[test]
    fn tensor_type_from_u32_all_variants() {
        for raw in 0u32..=19 {
            let via_u32 = TensorType::from_u32(raw);
            assert!(
                via_u32.is_some(),
                "TensorType::from_u32({raw}) should be Some"
            );
            // Every integer width must map to the same variant, since the raw
            // C value may be converted through a different width than here.
            assert_eq!(
                via_u32,
                TensorType::from_i64(i64::from(raw)),
                "from_u32({raw}) and from_i64({raw}) should agree"
            );
        }
    }

    #[test]
    fn tensor_type_unknown_value_returns_none() {
        assert_eq!(TensorType::from_isize(999), None);
        assert_eq!(TensorType::from_u32(999), None);
        assert_eq!(TensorType::from_isize(-1), None);
        assert_eq!(TensorType::from_isize(20), None);
        assert_eq!(TensorType::from_i64(-1), None);
        assert_eq!(TensorType::from_i64(20), None);
    }

    // -----------------------------------------------------------------------
    // TensorType -- Clone, PartialEq, Hash
    // -----------------------------------------------------------------------

    #[test]
    fn tensor_type_clone() {
        let original = TensorType::Float32;
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn tensor_type_partial_eq() {
        assert_eq!(TensorType::Int8, TensorType::Int8);
        assert_ne!(TensorType::Int8, TensorType::UInt8);
    }

    #[test]
    fn tensor_type_hash() {
        let mut set = HashSet::new();
        set.insert(TensorType::Float32);
        set.insert(TensorType::Float32);
        set.insert(TensorType::Int32);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn tensor_type_all_variants_unique_in_hashset() {
        let all = [
            TensorType::NoType,
            TensorType::Float32,
            TensorType::Int32,
            TensorType::UInt8,
            TensorType::Int64,
            TensorType::String,
            TensorType::Bool,
            TensorType::Int16,
            TensorType::Complex64,
            TensorType::Int8,
            TensorType::Float16,
            TensorType::Float64,
            TensorType::Complex128,
            TensorType::UInt64,
            TensorType::Resource,
            TensorType::Variant,
            TensorType::UInt32,
            TensorType::UInt16,
            TensorType::Int4,
            TensorType::BFloat16,
        ];
        let set: HashSet<_> = all.iter().copied().collect();
        assert_eq!(set.len(), 20);
    }

    // -----------------------------------------------------------------------
    // TensorType -- Debug formatting
    // -----------------------------------------------------------------------

    #[test]
    fn tensor_type_debug_format() {
        assert_eq!(format!("{:?}", TensorType::Float32), "Float32");
        assert_eq!(format!("{:?}", TensorType::NoType), "NoType");
        assert_eq!(format!("{:?}", TensorType::BFloat16), "BFloat16");
        assert_eq!(format!("{:?}", TensorType::Complex128), "Complex128");
    }

    // -----------------------------------------------------------------------
    // QuantizationParams -- construction and field access
    // -----------------------------------------------------------------------

    #[test]
    fn quantization_params_construction() {
        let params = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        assert!((params.scale - 0.5).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, 128);
    }

    #[test]
    fn quantization_params_zero_values() {
        let params = QuantizationParams {
            scale: 0.0,
            zero_point: 0,
        };
        assert!((params.scale - 0.0).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, 0);
    }

    #[test]
    fn quantization_params_negative_zero_point() {
        let params = QuantizationParams {
            scale: 0.007_812_5,
            zero_point: -128,
        };
        assert!((params.scale - 0.007_812_5).abs() < f32::EPSILON);
        assert_eq!(params.zero_point, -128);
    }

    // -----------------------------------------------------------------------
    // QuantizationParams -- Debug, Clone, PartialEq
    // -----------------------------------------------------------------------

    #[test]
    fn quantization_params_debug() {
        let params = QuantizationParams {
            scale: 1.0,
            zero_point: 0,
        };
        let debug = format!("{params:?}");
        assert!(debug.contains("QuantizationParams"));
        assert!(debug.contains("scale"));
        assert!(debug.contains("zero_point"));
    }

    #[test]
    fn quantization_params_clone() {
        let original = QuantizationParams {
            scale: 0.25,
            zero_point: 64,
        };
        let cloned = original;
        assert_eq!(original, cloned);
    }

    #[test]
    fn quantization_params_partial_eq() {
        let a = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        let b = QuantizationParams {
            scale: 0.5,
            zero_point: 128,
        };
        let c = QuantizationParams {
            scale: 0.25,
            zero_point: 128,
        };
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    // -----------------------------------------------------------------------
    // write_tensor_debug -- shared Debug/Display formatting
    // -----------------------------------------------------------------------

    /// Runs `write_tensor_debug` through a `Display` adapter and returns the
    /// rendered text, so the shared formatter can be tested without a loaded
    /// `TFLite` library.
    fn render<N, D>(name: &str, num_dims: N, dim_fn: D, tensor_type: TensorType) -> String
    where
        N: Fn() -> crate::error::Result<usize>,
        D: Fn(usize) -> crate::error::Result<usize>,
    {
        struct Adapter<'n, RankFn, DimFn> {
            name: &'n str,
            num_dims: RankFn,
            dim_fn: DimFn,
            tensor_type: TensorType,
        }

        impl<RankFn, DimFn> std::fmt::Display for Adapter<'_, RankFn, DimFn>
        where
            RankFn: Fn() -> crate::error::Result<usize>,
            DimFn: Fn(usize) -> crate::error::Result<usize>,
        {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                write_tensor_debug(
                    f,
                    self.name,
                    (self.num_dims)(),
                    &self.dim_fn,
                    self.tensor_type,
                )
            }
        }

        Adapter {
            name,
            num_dims,
            dim_fn,
            tensor_type,
        }
        .to_string()
    }

    /// Builds an arbitrary error to simulate a failed metadata query.
    fn metadata_error() -> crate::error::Error {
        crate::error::Error::invalid_argument("dims not set".to_string())
    }

    #[test]
    fn write_tensor_debug_full_shape() {
        let dims = [1usize, 224, 224, 3];
        let out = render("input", || Ok(dims.len()), |i| Ok(dims[i]), TensorType::Float32);
        assert_eq!(out, "input: 1x224x224x3 Float32");
    }

    #[test]
    fn write_tensor_debug_single_dim() {
        let out = render("logits", || Ok(1), |_| Ok(1000), TensorType::UInt8);
        assert_eq!(out, "logits: 1000 UInt8");
    }

    #[test]
    fn write_tensor_debug_scalar_has_empty_shape() {
        let out = render("scalar", || Ok(0), |_| Ok(7), TensorType::Int8);
        assert_eq!(out, "scalar:  Int8");
    }

    #[test]
    fn write_tensor_debug_missing_rank_renders_empty_shape() {
        let out = render("pending", || Err(metadata_error()), |_| Ok(1), TensorType::Int32);
        assert_eq!(out, "pending:  Int32");
    }

    #[test]
    fn write_tensor_debug_failed_dim_renders_zero() {
        let out = render(
            "t",
            || Ok(3),
            |i| if i == 1 { Err(metadata_error()) } else { Ok(i + 2) },
            TensorType::Float16,
        );
        assert_eq!(out, "t: 2x0x4 Float16");
    }
}