Skip to main content

rknn_runtime/
inference.rs

//! High-level inference API.
//!
//! This module contains [`RknnModel`], the main entry point for loading
//! and running RKNN models.
5
6use crate::context::RknnContext;
7use crate::error::Error;
8use crate::memory::ZeroCopyMem;
9use crate::tensor::{dequantize_affine, TensorAttr};
10
/// Default path of the RKNN runtime library, used by [`RknnModel::load`]
/// when the caller does not supply an explicit library path.
const DEFAULT_LIB_PATH: &str = "/usr/lib/librknnmrt.so";
12
13/// A loaded RKNN model ready for inference.
14///
15/// This is the main type you interact with. It holds the model, pre-allocated
16/// zero-copy memory buffers for input and outputs, and handles all
17/// communication with the NPU.
18///
19/// # Lifecycle
20///
21/// 1. **Load** a model with [`load`](Self::load) or [`load_with_lib`](Self::load_with_lib).
22///    This initializes the NPU context and allocates memory buffers.
23/// 2. **Inspect** tensor metadata via [`input_attr`](Self::input_attr) and
24///    [`output_attrs`](Self::output_attrs) to learn expected shapes and formats.
25/// 3. **Run** inference with [`run`](Self::run) - pass raw RGB bytes (NHWC, u8,
26///    no normalization).
27/// 4. **Read** results with [`output_raw`](Self::output_raw) (zero-copy `&[i8]`)
28///    or [`output_f32`](Self::output_f32) (dequantized `Vec<f32>`).
29///
30/// # Example
31///
32/// ```rust,no_run
33/// use rknn_runtime::RknnModel;
34///
35/// let model = RknnModel::load("model.rknn")?;
36///
37/// // Check what the model expects
38/// let input = model.input_attr();
39/// // e.g. [1, 320, 320, 3]
40/// println!("Input shape: {:?}", input.shape);
41///
42/// // Run inference
43/// # let rgb_bytes = vec![0u8; 320 * 320 * 3];
44/// model.run(&rgb_bytes)?;
45///
46/// // Get raw INT8 output (zero-copy - no allocation, just a slice into NPU memory)
47/// let raw = model.output_raw(0)?;
48///
49/// // Or get dequantized f32 output (allocates a new Vec)
50/// let floats = model.output_f32(0)?;
51/// # Ok::<(), rknn_runtime::Error>(())
52/// ```
53///
54/// # Drop order
55///
56/// Internally, memory buffers are dropped before the RKNN context.
57/// This is handled automatically - you don't need to worry about it.
58pub struct RknnModel {
59    output_mems: Vec<ZeroCopyMem>,
60    input_mem: ZeroCopyMem,
61    rknn: RknnContext,
62    input_attr: TensorAttr,
63    output_attrs: Vec<TensorAttr>,
64}
65
66impl RknnModel {
67    /// Load a `.rknn` model from a file.
68    ///
69    /// Uses the default library path (`/usr/lib/librknnmrt.so`).
70    /// If your `librknnmrt.so` is elsewhere, use [`load_with_lib`](Self::load_with_lib).
71    ///
72    /// # Errors
73    ///
74    /// - [`Error::IoError`] if the file cannot be read.
75    /// - [`Error::LibraryNotFound`] if `librknnmrt.so` is not found.
76    /// - [`Error::InitFailed`] if the NPU rejects the model.
77    pub fn load(model_path: &str) -> Result<Self, Error> {
78        Self::load_with_lib(model_path, DEFAULT_LIB_PATH)
79    }
80
81    /// Load a `.rknn` model from a file, using a custom library path.
82    ///
83    /// ```rust,no_run
84    /// # use rknn_runtime::RknnModel;
85    /// let model = RknnModel::load_with_lib(
86    ///     "model.rknn",
87    ///     "/opt/rknn/lib/librknnmrt.so",
88    /// )?;
89    /// # Ok::<(), rknn_runtime::Error>(())
90    /// ```
91    pub fn load_with_lib(model_path: &str, lib_path: &str) -> Result<Self, Error> {
92        let model_data = std::fs::read(model_path)?;
93        Self::load_from_bytes(&model_data, lib_path)
94    }
95
96    /// Load a model from raw bytes already in memory.
97    ///
98    /// Useful when the `.rknn` file is embedded in your binary or received
99    /// over the network.
100    pub fn load_from_bytes(model_data: &[u8], lib_path: &str) -> Result<Self, Error> {
101        let rknn = RknnContext::load(model_data, lib_path)?;
102        let (n_input, n_output) = rknn.query_io_num()?;
103
104        // Query input attributes (NHWC native for zero-copy)
105        // We only support single-input models for now
106        let raw_input_attr = rknn.query_input_attr_nhwc(0)?;
107        let input_attr = TensorAttr::from(&raw_input_attr);
108
109        // Query output attributes (native format for zero-copy)
110        let mut raw_output_attrs = Vec::with_capacity(n_output as usize);
111        let mut output_attrs = Vec::with_capacity(n_output as usize);
112        for i in 0..n_output {
113            let attr = rknn.query_output_attr_native(i)?;
114            output_attrs.push(TensorAttr::from(&attr));
115            raw_output_attrs.push(attr);
116        }
117
118        // Allocate zero-copy memory
119        let input_mem = ZeroCopyMem::new(&rknn, raw_input_attr)?;
120
121        let mut output_mems = Vec::with_capacity(n_output as usize);
122        for attr in raw_output_attrs {
123            output_mems.push(ZeroCopyMem::new(&rknn, attr)?);
124        }
125
126        // Suppress unused variable warning for single-input models
127        let _ = n_input;
128
129        Ok(Self {
130            rknn,
131            input_mem,
132            output_mems,
133            input_attr,
134            output_attrs,
135        })
136    }
137
138    /// Input tensor metadata (shape, format, data type).
139    ///
140    /// The shape is typically `[1, H, W, 3]` (NHWC).
141    /// Use this to know what image size the model expects:
142    ///
143    /// ```rust,no_run
144    /// # use rknn_runtime::RknnModel;
145    /// # let model = RknnModel::load("m.rknn").unwrap();
146    /// let input = model.input_attr();
147    /// let (h, w) = (input.shape[1], input.shape[2]);
148    /// println!("Model expects {}x{} RGB image", h, w);
149    /// ```
150    pub fn input_attr(&self) -> &TensorAttr {
151        &self.input_attr
152    }
153
154    /// Output tensor metadata for all outputs.
155    ///
156    /// Most models have a single output, but some could have several.
157    /// Each [`TensorAttr`] contains the shape, format, quantization zero-point
158    /// and scale - everything you need to decode the output.
159    pub fn output_attrs(&self) -> &[TensorAttr] {
160        &self.output_attrs
161    }
162
163    /// Run inference on the NPU.
164    ///
165    /// `input` must be raw RGB bytes in **NHWC** format (`[1, H, W, 3]`).
166    /// No normalization, no channel reordering - just plain `u8` pixel values.
167    ///
168    /// After this returns, read results with [`output_raw`](Self::output_raw)
169    /// or [`output_f32`](Self::output_f32).
170    ///
171    /// # What happens inside
172    ///
173    /// 1. Copies `input` bytes into the pre-allocated NPU input buffer.
174    /// 2. Calls `rknn_run()` - the NPU executes the model.
175    /// 3. Calls `rknn_mem_sync()` on each output buffer (syncs NPU cache to CPU).
176    /// In my case: this step is critical on RV1106 - without it, I get stale data.
177    ///
178    pub fn run(&self, input: &[u8]) -> Result<(), Error> {
179        // Copy input data to zero-copy memory
180        self.input_mem.write(input);
181
182        // Run NPU inference
183        let ret = unsafe {
184            (self.rknn.funcs.run)(self.rknn.ctx, std::ptr::null())
185        };
186        if ret != 0 {
187            return Err(Error::InferenceFailed(ret));
188        }
189
190        // Sync all output memories from NPU to CPU
191        for mem in &self.output_mems {
192            mem.sync_from_device(&self.rknn)?;
193        }
194
195        Ok(())
196    }
197
198    /// Raw INT8 output data for the given output index.
199    ///
200    /// Returns a slice pointing directly into the NPU's zero-copy buffer.
201    /// No allocation, no copying - this is as fast as it gets.
202    ///
203    /// The data is in whatever layout the NPU uses (often NC1HWC2).
204    /// Use [`nc1hwc2_to_flat`](crate::nc1hwc2_to_flat) to convert it
205    /// to standard NCHW if needed.
206    ///
207    /// # Errors
208    ///
209    /// Returns [`Error::InvalidIndex`] if `index` is out of range.
210    /// 
211    pub fn output_raw(&self, index: usize) -> Result<&[i8], Error> {
212        if index >= self.output_mems.len() {
213            return Err(Error::InvalidIndex {
214                requested: index,
215                available: self.output_mems.len(),
216            });
217        }
218        Ok(self.output_mems[index].as_i8_slice())
219    }
220
221    /// Dequantized f32 output for the given output index.
222    ///
223    /// Converts each raw INT8 value to f32 using affine dequantization:
224    ///
225    /// ```text
226    /// value = (raw_i8 - zero_point) * scale
227    /// ```
228    ///
229    /// Zero-point and scale are read from the tensor's quantization parameters
230    /// (set during model conversion).
231    ///
232    /// **Note:** This allocates a new `Vec<f32>`. If you need to dequantize
233    /// only part of the output (e.g. after NC1HWC2 conversion), use
234    /// [`dequantize_affine`] directly.
235    /// 
236    pub fn output_f32(&self, index: usize) -> Result<Vec<f32>, Error> {
237        let raw = self.output_raw(index)?;
238        let attr = &self.output_attrs[index];
239        Ok(dequantize_affine(raw, attr.zp, attr.scale))
240    }
241}