chess_corners_ml/lib.rs
1//! ONNX-backed ML refiner for ChESS corner candidates.
2//!
3//! **Internal crate — not published to crates.io.**
4//! `chess-corners-ml` is an implementation detail of the chess-corners
5//! workspace. It backs the optional `ml-refiner` feature of the
6//! `chess-corners` facade crate and is not a public API contract; its
7//! surface may change without semver consideration.
8//!
9//! This crate provides [`MlModel`], a thin wrapper around a
10//! [tract-onnx](https://docs.rs/tract-onnx) runtime that predicts
11//! subpixel `(dx, dy)` offsets for each corner candidate from a
12//! normalized intensity patch.
13//!
14//! # Intended use
15//!
16//! This crate is not meant to be used directly. It is consumed by the
17//! `chess-corners` facade crate when the `ml-refiner` feature is
18//! enabled. With the feature on, set the active ChESS refiner to
19//! `ChessRefiner::Ml` and call `Detector::detect` to route through
20//! the ML refiner.
21//!
22//! # Embedded model
23//!
24//! When the optional `embed-model` feature is enabled, the ONNX model
25//! and its external data file are compiled into the binary via
26//! `include_bytes!` and extracted to a temporary directory on first
27//! use. The extraction is thread-safe and idempotent (write-then-rename
28//! with byte-match skip).
29//!
30//! # Performance note
31//!
32//! ML refinement is significantly slower than the geometric refiners
33//! (~24 ms vs <1 ms for 77 corners on a 640×480 image). Use it only
34//! when maximum subpixel accuracy is required and throughput allows.
35
36use anyhow::{anyhow, Context, Result};
37use std::path::{Path, PathBuf};
38#[cfg(feature = "embed-model")]
39use std::sync::OnceLock;
40use tract_onnx::prelude::tract_ndarray::{Array4, Ix2};
41use tract_onnx::prelude::*;
42
/// Specifies where [`MlModel::load`] should read the ONNX model from.
#[derive(Clone, Debug)]
pub enum ModelSource {
    /// Load from an explicit filesystem path to the `.onnx` file.
    ///
    /// The patch size is taken from the `patch_size` key of a
    /// `fixtures/meta.json` sidecar located inside the directory that
    /// contains the model file (`<model dir>/fixtures/meta.json`).
    /// When the sidecar is missing, unreadable, or invalid, the loader
    /// falls back to the crate default: the embedded model's metadata
    /// when the `embed-model` feature is enabled, otherwise 21 px.
    Path(PathBuf),
    /// Use the model compiled into the binary via the `embed-model`
    /// Cargo feature. Returns an error when that feature is not enabled.
    EmbeddedDefault,
}
55
/// Loaded and optimised ONNX model for corner refinement.
///
/// The model accepts a batch of `f32` intensity patches with shape
/// `[N, 1, patch_size, patch_size]` (values in `[0, 1]`) and returns
/// `[N, 3]` with columns `[dx, dy, conf_logit]`. Only `dx` and `dy`
/// are currently used; `conf_logit` is ignored.
pub struct MlModel {
    // Optimised, runnable tract graph produced by `MlModel::load`.
    model: TypedRunnableModel<TypedModel>,
    // Side length of the square input patch; `load` uses it for the last
    // two input dimensions and `infer_batch` uses it to validate and
    // reshape the flat input slice.
    patch_size: usize,
    // `SymbolScope` owns the `Symbol` object for the dynamic batch
    // dimension "N". Dropping it before `model` would leave the compiled
    // graph with a dangling reference to the scope's internal table, so
    // this field must be kept alive for the lifetime of `MlModel` even
    // though it is never explicitly read after construction.
    //
    // Struct fields are dropped in declaration order, so keeping this
    // field after `model` means `model` is dropped first. Do not reorder.
    #[allow(dead_code)]
    symbols: SymbolScope,
}
73
74impl MlModel {
75 /// Load and optimise an ONNX model from the given source.
76 ///
77 /// For [`ModelSource::EmbeddedDefault`] the `embed-model` Cargo
78 /// feature must be enabled; an error is returned otherwise.
79 ///
80 /// # Errors
81 ///
82 /// Returns an error if the model file cannot be read, the ONNX
83 /// graph is malformed, or tract optimisation / compilation fails.
84 pub fn load(source: ModelSource) -> Result<Self> {
85 let (model_path, patch_size) = match source {
86 ModelSource::Path(path) => {
87 let patch_size =
88 patch_size_from_meta_path(&path).unwrap_or_else(default_patch_size);
89 (path, patch_size)
90 }
91 ModelSource::EmbeddedDefault => {
92 #[cfg(feature = "embed-model")]
93 {
94 let patch_size = patch_size_from_meta_bytes(EMBED_META_JSON)
95 .unwrap_or_else(|_| default_patch_size());
96 let path = embedded_model_path()?;
97 (path, patch_size)
98 }
99 #[cfg(not(feature = "embed-model"))]
100 {
101 return Err(anyhow!(
102 "embedded model support disabled; enable feature \"embed-model\""
103 ));
104 }
105 }
106 };
107
108 let mut model = tract_onnx::onnx()
109 .model_for_path(&model_path)
110 .with_context(|| format!("load ONNX model from {}", model_path.display()))?;
111 let symbols = SymbolScope::default();
112 let batch = symbols.sym("N");
113 let shape = tvec!(
114 batch.to_dim(),
115 1.to_dim(),
116 (patch_size as i64).to_dim(),
117 (patch_size as i64).to_dim()
118 );
119 model
120 .set_input_fact(0, InferenceFact::dt_shape(f32::datum_type(), shape))
121 .context("set ML refiner input fact")?;
122 let model = model
123 .into_optimized()
124 .context("optimize ONNX model")?
125 .into_runnable()
126 .context("make ONNX model runnable")?;
127
128 Ok(Self {
129 model,
130 patch_size,
131 symbols,
132 })
133 }
134
135 /// Side length (in pixels) of the square intensity patch the model expects.
136 pub fn patch_size(&self) -> usize {
137 self.patch_size
138 }
139
140 /// Run inference on a flat batch of intensity patches.
141 ///
142 /// `patches` must contain exactly `batch * patch_size * patch_size`
143 /// `f32` values in `[N, 1, H, W]` order (values in `[0, 1]`).
144 /// Returns one `[dx, dy, conf_logit]` triple per input patch.
145 ///
146 /// # Errors
147 ///
148 /// Returns an error if the slice length does not match
149 /// `batch * patch_size²`, if the ONNX output shape is unexpected,
150 /// or if tract inference fails.
151 pub fn infer_batch(&self, patches: &[f32], batch: usize) -> Result<Vec<[f32; 3]>> {
152 if batch == 0 {
153 return Ok(Vec::new());
154 }
155 let patch_area = self.patch_size * self.patch_size;
156 let expected = batch * patch_area;
157 if patches.len() != expected {
158 return Err(anyhow!(
159 "expected {} floats (batch {} * patch {}x{}), got {}",
160 expected,
161 batch,
162 self.patch_size,
163 self.patch_size,
164 patches.len()
165 ));
166 }
167
168 let input = Array4::from_shape_vec(
169 (batch, 1, self.patch_size, self.patch_size),
170 patches.to_vec(),
171 )
172 .context("reshape input patches")?
173 .into_tensor();
174 let result = self
175 .model
176 .run(tvec!(input.into_tvalue()))
177 .context("run ONNX inference")?;
178 let output = result[0]
179 .to_array_view::<f32>()
180 .context("read ONNX output")?
181 .into_dimensionality::<Ix2>()
182 .context("reshape ONNX output")?;
183
184 if output.ncols() != 3 {
185 return Err(anyhow!(
186 "expected output shape [N,3], got [N,{}]",
187 output.ncols()
188 ));
189 }
190
191 let mut out = Vec::with_capacity(batch);
192 for row in output.outer_iter() {
193 out.push([row[0], row[1], row[2]]);
194 }
195 Ok(out)
196 }
197}
198
199fn patch_size_from_meta_bytes(bytes: &[u8]) -> Result<usize> {
200 let meta: serde_json::Value =
201 serde_json::from_slice(bytes).context("parse ML refiner meta.json")?;
202 let size = meta
203 .get("patch_size")
204 .and_then(|v| v.as_u64())
205 .ok_or_else(|| anyhow!("meta.json missing patch_size"))?;
206 Ok(size as usize)
207}
208
209fn patch_size_from_meta_path(path: &Path) -> Option<usize> {
210 let meta_path = path.parent()?.join("fixtures").join("meta.json");
211 let bytes = std::fs::read(meta_path).ok()?;
212 patch_size_from_meta_bytes(&bytes).ok()
213}
214
/// Patch side length used when no model-specific metadata is available.
///
/// With the `embed-model` feature the value comes from the embedded
/// `meta.json`, falling back to the hard-coded size if that cannot be
/// parsed; without the feature the hard-coded size is returned directly.
fn default_patch_size() -> usize {
    // Hard-coded last-resort side length, in pixels.
    const FALLBACK_PATCH_SIZE: usize = 21;

    #[cfg(feature = "embed-model")]
    {
        patch_size_from_meta_bytes(EMBED_META_JSON).unwrap_or(FALLBACK_PATCH_SIZE)
    }
    #[cfg(not(feature = "embed-model"))]
    {
        FALLBACK_PATCH_SIZE
    }
}
225
// File name under which the embedded ONNX graph is written inside the
// extraction directory (see `embedded_model_path`).
#[cfg(feature = "embed-model")]
const EMBED_ONNX_NAME: &str = "chess_refiner_v4.onnx";
// File name under which the embedded external-data file is written, in
// the same directory as the graph file.
// NOTE(review): presumably the `.onnx` graph refers to its external data
// by this exact name — confirm before renaming either constant.
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA_NAME: &str = "chess_refiner_v4.onnx.data";

// Bytes of the ONNX graph, compiled into the binary from the crate's
// `assets/ml` directory.
#[cfg(feature = "embed-model")]
const EMBED_ONNX: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v4.onnx"
));
// Bytes of the external-data file that accompanies the ONNX graph.
#[cfg(feature = "embed-model")]
const EMBED_ONNX_DATA: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/chess_refiner_v4.onnx.data"
));
// Bytes of the metadata JSON for the embedded model; parsed for its
// `patch_size` key.
#[cfg(feature = "embed-model")]
const EMBED_META_JSON: &[u8] = include_bytes!(concat!(
    env!("CARGO_MANIFEST_DIR"),
    "/assets/ml/fixtures/v4/meta.json"
));
246
/// Extract the embedded ONNX model (and its external data file) into a
/// `chess_corners_ml` directory under the system temp dir and return the
/// path of the `.onnx` file.
///
/// The extraction runs at most once per process on success; later calls
/// return the cached path. A failed extraction is not cached, so a
/// subsequent call retries.
///
/// # Errors
///
/// Returns an error if the temp directory cannot be created or either
/// file cannot be written. I/O failures are reported to the caller
/// instead of panicking.
#[cfg(feature = "embed-model")]
fn embedded_model_path() -> Result<PathBuf> {
    // Extraction must be serialized across threads in this process.
    // Without that, parallel `#[test]` runs all entered
    // `write_if_changed`, the second `std::fs::write` truncated the
    // file to 0 bytes mid-rewrite, and a concurrent `tract_onnx`
    // model load saw an empty `.data` slice and panicked
    // (`range start index 768 out of range for slice of length 0`).
    //
    // `PATH` caches the successful result; `EXTRACT_LOCK` serializes the
    // writers. A plain `OnceLock::get_or_init` cannot return an error
    // from its initializer, which is why the lock is separate.
    //
    // For cross-process races (e.g. `cargo test -p A` and
    // `cargo test -p B` sharing `/tmp/chess_corners_ml/`), this relies
    // on `write_if_changed` replacing files via write-then-rename rather
    // than rewriting them in place.
    static PATH: OnceLock<PathBuf> = OnceLock::new();
    static EXTRACT_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    if let Some(path) = PATH.get() {
        return Ok(path.clone());
    }

    // The mutex guards no data, so a poisoned lock is safe to reuse.
    let _guard = EXTRACT_LOCK
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    // Another thread may have finished the extraction while we waited.
    if let Some(path) = PATH.get() {
        return Ok(path.clone());
    }

    let dir = std::env::temp_dir().join("chess_corners_ml");
    std::fs::create_dir_all(&dir)
        .with_context(|| format!("create ML model temp dir {}", dir.display()))?;
    let onnx_path = dir.join(EMBED_ONNX_NAME);
    let data_path = dir.join(EMBED_ONNX_DATA_NAME);
    // Write `.data` before `.onnx` so tract never sees an `.onnx`
    // that references a missing or partially-written `.data`.
    write_if_changed(&data_path, EMBED_ONNX_DATA)
        .with_context(|| format!("write embedded ONNX data to {}", data_path.display()))?;
    write_if_changed(&onnx_path, EMBED_ONNX)
        .with_context(|| format!("write embedded ONNX model to {}", onnx_path.display()))?;

    Ok(PATH.get_or_init(|| onnx_path).clone())
}
275
/// Write `data` to `path` only if the file doesn't already contain
/// the same bytes.
///
/// The new contents are written to a temporary sibling file and then
/// renamed over `path`, so `path` itself is never rewritten in place.
/// The temporary name includes the process id and a per-process
/// counter, so concurrent writers (threads or processes) never share a
/// temporary file. With a shared name, one writer could truncate the
/// file another writer was about to rename into place.
///
/// The byte-match check avoids rewriting unchanged files across re-runs
/// in a shared temp dir; the length comparison skips reading the file
/// when the sizes already differ.
///
/// # Errors
///
/// Returns an error if `path` has no final file-name component, or if
/// writing the temporary file or renaming it fails. The temporary file
/// is removed on failure (best effort).
#[cfg(feature = "embed-model")]
fn write_if_changed(path: &std::path::Path, data: &[u8]) -> std::io::Result<()> {
    // Distinguishes temporary files created by this process.
    static TMP_SEQ: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);

    let same_len = std::fs::metadata(path)
        .map(|meta| meta.len() == data.len() as u64)
        .unwrap_or(false);
    if same_len
        && std::fs::read(path)
            .map(|existing| existing == data)
            .unwrap_or(false)
    {
        return Ok(());
    }

    // Append the suffix to the full file name instead of replacing the
    // extension, so a target that already ends in `.tmp` cannot alias
    // its own temporary file.
    let mut tmp_name = path
        .file_name()
        .ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "target path has no file name",
            )
        })?
        .to_os_string();
    let seq = TMP_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    tmp_name.push(format!(".{}.{}.tmp", std::process::id(), seq));
    // Same directory as the target, so the rename does not cross directories.
    let tmp = path.with_file_name(tmp_name);

    let outcome = std::fs::write(&tmp, data).and_then(|()| std::fs::rename(&tmp, path));
    if outcome.is_err() {
        // Best-effort cleanup; the original error is what the caller needs.
        let _ = std::fs::remove_file(&tmp);
    }
    outcome
}
296
#[cfg(all(test, feature = "embed-model"))]
mod tests {
    use super::write_if_changed;

    /// Two payloads of equal length but different bytes: the length
    /// check alone must not be treated as "unchanged".
    #[test]
    fn write_if_changed_rewrites_same_size_changed_bytes() {
        let scratch = tempfile::tempdir().expect("tempdir");
        let target = scratch.path().join("model.bin");

        let first: &[u8] = b"abc";
        let second: &[u8] = b"xyz";
        write_if_changed(&target, first).expect("initial write");
        write_if_changed(&target, second).expect("rewrite same-size bytes");

        let on_disk = std::fs::read(&target).expect("read rewritten bytes");
        assert_eq!(on_disk, b"xyz");
    }
}