Skip to main content

rknn_runtime/
tensor.rs

/// Dequantize raw INT8 values to `f32` with an affine (zero-point + scale) mapping.
///
/// RKNN conversion quantizes weights and activations to INT8. Every tensor
/// carries its own zero-point (`zp`) and `scale`, both exposed through
/// [`TensorAttr`]. This function undoes that quantization element by element:
///
/// ```text
/// f32_value = (raw_i8 - zero_point) * scale
/// ```
///
/// The output has the same length and order as `data`; an empty slice yields
/// an empty vector.
///
/// # Example
///
/// ```
/// use rknn_runtime::dequantize_affine;
///
/// let raw = vec![10i8, 20, 30];
/// let zp = 5;
/// let scale = 0.1;
/// let result = dequantize_affine(&raw, zp, scale);
/// assert_eq!(result, vec![0.5, 1.5, 2.5]);
/// ```
pub fn dequantize_affine(data: &[i8], zp: i32, scale: f32) -> Vec<f32> {
    // The zero-point is shared by every element, so convert it once up front.
    let zero_point = zp as f32;
    data.iter()
        .copied()
        .map(f32::from)
        .map(|raw| (raw - zero_point) * scale)
        .collect()
}
28
/// Unpack an RKNN NC1HWC2 tensor into flat, channel-major NCHW order.
///
/// The RKNN NPU emits many output tensors in a packed NC1HWC2 layout. Rather
/// than storing each channel plane one after another (as NCHW does), it
/// interleaves channels in groups of `c2` (typically 16) at every pixel:
///
/// ```text
/// NC1HWC2 shape: [1, c1, H, W, c2]
///
/// c1 = ceil(total_channels / c2)
/// Channels at index >= total_channels are padding and are discarded.
/// ```
///
/// The result holds `total_channels * H * W` elements laid out as NCHW, so a
/// value is addressed as:
///
/// ```text
/// value = output[channel * H * W + y * W + x]
/// ```
///
/// The function is generic over the element type, so it handles raw `i8`
/// output as well as `f32` data that has already been dequantized.
///
/// Two kinds of output element are left at `T::default()`:
///
/// - elements whose source offset lies past the end of `data`;
/// - channels that no block covers (`channel >= c1 * c2`).
///
/// # Arguments
///
/// - `data` - raw tensor data in NC1HWC2 layout
/// - `c1` - number of channel blocks (shape\[1\])
/// - `h` - height dimension (shape\[2\])
/// - `w` - width dimension (shape\[3\])
/// - `c2` - channels per block, typically 16 (shape\[4\])
/// - `total_channels` - actual number of channels (e.g. 84 for YOLOv8 with 80 classes)
///
/// # Example
///
/// ```
/// use rknn_runtime::nc1hwc2_to_flat;
///
/// // 4 channels packed into blocks of 2 (c1=2, c2=2), spatial 1x1
/// let nc1hwc2_data: Vec<i8> = vec![
///     10, 20, // block 0: channels 0, 1
///     30, 40, // block 1: channels 2, 3
/// ];
/// let flat = nc1hwc2_to_flat(&nc1hwc2_data, 2, 1, 1, 2, 4);
/// assert_eq!(flat, vec![10, 20, 30, 40]);
/// ```
pub fn nc1hwc2_to_flat<T: Copy + Default>(
    data: &[T],
    c1: usize,
    h: usize,
    w: usize,
    c2: usize,
    total_channels: usize,
) -> Vec<T> {
    let mut flat = vec![T::default(); total_channels * h * w];

    // A channel is copied only if it is real (< total_channels) and some
    // block actually stores it (< c1 * c2). Every other channel keeps its default.
    let copied = total_channels.min(c1.saturating_mul(c2));

    for ch in 0..copied {
        // `copied > 0` implies `c2 > 0`, so this division cannot fault.
        let (block, lane) = (ch / c2, ch % c2);
        let plane_start = ch * h * w;
        for y in 0..h {
            for x in 0..w {
                let pixel = (block * h + y) * w + x;
                // Truncated input: leave the destination at its default.
                if let Some(&value) = data.get(pixel * c2 + lane) {
                    flat[plane_start + y * w + x] = value;
                }
            }
        }
    }

    flat
}
103
104/// Metadata for a single tensor (input or output).
105///
106/// Contains everything you need to interpret the tensor data:
107/// shape, memory layout, data type, and quantization parameters.
108///
109/// # Quantization fields
110///
111/// - `zp` (zero-point) and `scale` are used for INT8 affine dequantization:
112///   `f32_value = (raw_i8 - zp) * scale`
113/// - These are set during model conversion and are different for each tensor.
114///
115/// # Shape
116///
117/// For NC1HWC2 outputs (common on RV1106), the shape is `[1, c1, H, W, c2]`.
118/// For NHWC inputs, the shape is `[1, H, W, C]`.
119/// 
120#[derive(Debug, Clone)]
121pub struct TensorAttr {
122    /// Tensor index (0 for first input/output, 1 for second, etc.).
123    pub index: u32,
124    /// Tensor dimensions. See [`TensorFormat`] for how to interpret them.
125    pub shape: Vec<u32>,
126    /// Total number of elements in the tensor.
127    pub n_elems: u32,
128    /// Size in bytes.
129    pub size: u32,
130    /// Size in bytes including stride padding (used for memory allocation).
131    pub size_with_stride: u32,
132    /// Memory layout of the tensor data.
133    pub format: TensorFormat,
134    /// Element data type.
135    pub data_type: TensorType,
136    /// Quantization method.
137    pub qnt_type: QuantType,
138    /// Quantization zero-point (for affine dequantization).
139    pub zp: i32,
140    /// Quantization scale (for affine dequantization).
141    pub scale: f32,
142    /// Human-readable tensor name from the model.
143    pub name: String,
144}
145
/// Memory layout of a tensor's data.
///
/// - **NCHW** - channels first, as commonly used by PyTorch.
///   Shape: `[batch, channels, height, width]`.
/// - **NHWC** - channels last; the layout RKNN uses for inputs.
///   Shape: `[batch, height, width, channels]`.
/// - **NC1HWC2** - RKNN NPU packed layout in which channels are grouped into
///   blocks of `c2`. Shape: `[batch, c1, height, width, c2]`. Use
///   [`nc1hwc2_to_flat`] to unpack it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorFormat {
    /// Channels first: `[N, C, H, W]`.
    NCHW,
    /// Channels last: `[N, H, W, C]`. Standard for RKNN inputs.
    NHWC,
    /// NPU packed format: `[N, c1, H, W, c2]`. Common for RKNN outputs on RV1106.
    NC1HWC2,
    /// Unknown or unsupported format.
    Undefined,
}

/// Decodes the raw layout ID reported by the RKNN runtime.
///
/// The IDs are decoded as follows:
///
/// - `0` becomes [`TensorFormat::NCHW`];
/// - `1` becomes [`TensorFormat::NHWC`];
/// - `2` becomes [`TensorFormat::NC1HWC2`];
/// - every other value becomes [`TensorFormat::Undefined`].
impl From<u32> for TensorFormat {
    fn from(raw_fmt: u32) -> Self {
        match raw_fmt {
            0 => Self::NCHW,
            1 => Self::NHWC,
            2 => Self::NC1HWC2,
            _ => Self::Undefined,
        }
    }
}
178
/// Element data type of a tensor.
///
/// INT8 quantized models (the most common on RKNN) use [`Int8`](Self::Int8) for outputs
/// and [`Uint8`](Self::Uint8) for inputs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorType {
    /// 32-bit float (`RKNN_TENSOR_FLOAT32`, ID 0).
    Float32,
    /// 16-bit float (`RKNN_TENSOR_FLOAT16`, ID 1).
    Float16,
    /// Signed 8-bit integer (`RKNN_TENSOR_INT8`, ID 2).
    Int8,
    /// Unsigned 8-bit integer (`RKNN_TENSOR_UINT8`, ID 3).
    Uint8,
    /// Signed 16-bit integer (`RKNN_TENSOR_INT16`, ID 4).
    Int16,
    /// Signed 32-bit integer (`RKNN_TENSOR_INT32`, ID 6).
    Int32,
    /// Unrecognized type ID from the RKNN runtime.
    ///
    /// This also carries RKNN types that have no dedicated variant here, such
    /// as `RKNN_TENSOR_UINT16` (ID 5) and `RKNN_TENSOR_UINT32` (ID 7).
    Unknown(u32),
}

/// Decodes the raw `rknn_tensor_type` ID reported by the RKNN runtime.
///
/// IDs follow the order of the `rknn_tensor_type` enum in RKNPU2's
/// `rknn_api.h`: `FLOAT32 = 0, FLOAT16, INT8, UINT8, INT16, UINT16, INT32, ...`.
/// Any ID without a matching variant is preserved in [`TensorType::Unknown`].
impl From<u32> for TensorType {
    fn from(v: u32) -> Self {
        match v {
            0 => TensorType::Float32,
            1 => TensorType::Float16,
            2 => TensorType::Int8,
            3 => TensorType::Uint8,
            4 => TensorType::Int16,
            // ID 5 is RKNN_TENSOR_UINT16, which has no dedicated variant, so it
            // falls through to `Unknown(5)`; INT32 is ID 6.
            6 => TensorType::Int32,
            other => TensorType::Unknown(other),
        }
    }
}
210
/// Quantization scheme applied to a tensor.
///
/// Most RKNN INT8 models use [`Affine`](Self::Affine) quantization, where each
/// value is recovered as `f32 = (i8 - zp) * scale`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuantType {
    /// No quantization (float model).
    None,
    /// Dynamic fixed-point quantization.
    Dfp,
    /// Affine quantization: `value = (raw - zp) * scale`. The most common type.
    Affine,
    /// Unrecognized quantization type ID from the RKNN runtime.
    Unknown(u32),
}

/// Decodes the raw quantization-type ID reported by the RKNN runtime.
///
/// `0`, `1` and `2` map to [`QuantType::None`], [`QuantType::Dfp`] and
/// [`QuantType::Affine`]; any other ID is kept verbatim in [`QuantType::Unknown`].
impl From<u32> for QuantType {
    fn from(raw_qnt: u32) -> Self {
        match raw_qnt {
            0 => Self::None,
            1 => Self::Dfp,
            2 => Self::Affine,
            unrecognized => Self::Unknown(unrecognized),
        }
    }
}
239
240/// Implements `From` for TensorAttr
241impl From<&crate::ffi::RknnTensorAttr> for TensorAttr {
242    fn from(raw: &crate::ffi::RknnTensorAttr) -> Self {
243        Self {
244            index: raw.index,
245            shape: raw.shape().to_vec(),
246            n_elems: raw.n_elems,
247            size: raw.size,
248            size_with_stride: raw.size_with_stride,
249            format: TensorFormat::from(raw.fmt),
250            data_type: TensorType::from(raw.type_),
251            qnt_type: QuantType::from(raw.qnt_type),
252            zp: raw.zp,
253            scale: raw.scale,
254            name: raw.name_str().to_string(),
255        }
256    }
257}