Skip to main content

cog_pose_estimation/
inference.rs

1//! Inference engine — loads `pose_v1.safetensors` (produced by the
2//! Candle training run on `ruvultra`'s RTX 5080, see
3//! `cog/artifacts/pose_v1.safetensors` + `docs/benchmarks/pose-estimation-cog.md`)
4//! and runs the encoder + pose head on each CSI window.
5//!
6//! Architecture mirrors the training script exactly:
7//!     Conv1d(56 -> 64,  k=3, dilation=1, padding=1)
8//!     Conv1d(64 -> 128, k=3, dilation=2, padding=2)
9//!     Conv1d(128 -> 128, k=3, dilation=4, padding=4)
10//!     mean over time -> [128]
11//!     Linear(128 -> 256) -> ReLU
12//!     Linear(256 -> 34)  -> sigmoid -> reshape [17, 2]
13//!
14//! When the safetensors file is missing the engine falls back to a
15//! centred-skeleton baseline with `confidence=0` so the cog still
16//! satisfies the ADR-100 runtime contract and the dashboard surfaces
17//! "no model yet" instead of dropping frames silently.
18
19use candle_core::{DType, Device, Tensor};
20use candle_nn::{Conv1d, Conv1dConfig, Linear, Module, VarBuilder};
21use std::path::Path;
22use std::sync::Arc;
23
/// Number of CSI subcarriers per frame — the channel dimension of the model
/// input. A window is 56 subcarriers × 20 frames, matching the format
/// produced by `scripts/align-ground-truth.js` after #641.
pub const INPUT_SUBCARRIERS: usize = 56;
/// Number of frames (timesteps) in one CSI window.
pub const INPUT_TIMESTEPS: usize = 20;
/// Number of keypoints in a pose; each keypoint contributes an (x, y) pair,
/// so the flat output holds `OUTPUT_KEYPOINTS * 2` values.
pub const OUTPUT_KEYPOINTS: usize = 17;
29
/// One window of CSI samples handed to `InferenceEngine::infer`.
#[derive(Debug, Clone)]
pub struct CsiWindow {
    /// Flat sample buffer. `infer` rejects any length other than
    /// `INPUT_SUBCARRIERS * INPUT_TIMESTEPS` and interprets the buffer as a
    /// row-major `[INPUT_SUBCARRIERS, INPUT_TIMESTEPS]` matrix.
    pub data: Vec<f32>, // length INPUT_SUBCARRIERS * INPUT_TIMESTEPS
}
34
/// Pose estimate produced for a single CSI window.
#[derive(Debug, Clone)]
pub struct PoseOutput {
    /// Flat `[OUTPUT_KEYPOINTS * 2]` keypoints in `[0, 1]` normalised
    /// image coords, ordered (x0, y0, x1, y1, …).
    pub keypoints: Vec<f32>,
    /// Confidence attached to this estimate.
    pub confidence: f32,
}

impl PoseOutput {
    /// Returns `true` only when the confidence and every keypoint coordinate
    /// are finite, i.e. neither NaN nor ±infinity. An empty keypoint list
    /// counts as finite.
    pub fn is_finite(&self) -> bool {
        if !self.confidence.is_finite() {
            return false;
        }
        // A single NaN/inf coordinate disqualifies the whole output.
        !self
            .keypoints
            .iter()
            .any(|coord| coord.is_nan() || coord.is_infinite())
    }
}
48
49/// Internal model — mirrors the training script's `PoseModel` exactly.
50struct PoseNet {
51    c1: Conv1d,
52    c2: Conv1d,
53    c3: Conv1d,
54    fc1: Linear,
55    fc2: Linear,
56}
57
58impl PoseNet {
59    fn new(vb: VarBuilder<'_>) -> candle_core::Result<Self> {
60        let enc = vb.pp("enc");
61        let head = vb.pp("head");
62
63        let c1 = candle_nn::conv1d(
64            56,
65            64,
66            3,
67            Conv1dConfig {
68                padding: 1,
69                stride: 1,
70                dilation: 1,
71                groups: 1,
72                ..Default::default()
73            },
74            enc.pp("c1"),
75        )?;
76        let c2 = candle_nn::conv1d(
77            64,
78            128,
79            3,
80            Conv1dConfig {
81                padding: 2,
82                stride: 1,
83                dilation: 2,
84                groups: 1,
85                ..Default::default()
86            },
87            enc.pp("c2"),
88        )?;
89        let c3 = candle_nn::conv1d(
90            128,
91            128,
92            3,
93            Conv1dConfig {
94                padding: 4,
95                stride: 1,
96                dilation: 4,
97                groups: 1,
98                ..Default::default()
99            },
100            enc.pp("c3"),
101        )?;
102        let fc1 = candle_nn::linear(128, 256, head.pp("fc1"))?;
103        let fc2 = candle_nn::linear(256, 34, head.pp("fc2"))?;
104
105        Ok(Self {
106            c1,
107            c2,
108            c3,
109            fc1,
110            fc2,
111        })
112    }
113
114    /// Forward pass: `[B, 56, 20]` -> `[B, 34]` in `[0, 1]`.
115    fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
116        let h = self.c1.forward(x)?.relu()?;
117        let h = self.c2.forward(&h)?.relu()?;
118        let h = self.c3.forward(&h)?.relu()?;
119        // Global average pool over time dim (last dim) -> [B, 128]
120        let h = h.mean(2)?;
121        let h = self.fc1.forward(&h)?.relu()?;
122        let h = self.fc2.forward(&h)?;
123        // sigmoid -> keep in [0, 1]
124        candle_nn::ops::sigmoid(&h)
125    }
126}
127
128pub struct InferenceEngine {
129    inner: Option<Arc<LoadedModel>>,
130    device: Device,
131}
132
133struct LoadedModel {
134    net: PoseNet,
135}
136
137impl InferenceEngine {
138    /// Create an engine. Tries to load weights from `cog/artifacts/pose_v1.safetensors`
139    /// (relative to current dir or the cog install dir under
140    /// `/var/lib/cognitum/apps/pose-estimation/`). Returns a usable
141    /// engine either way — without weights, `infer` produces the
142    /// stub output.
143    pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
144        Self::with_weights(default_weights_path().as_deref())
145    }
146
147    /// Create an engine with a specific weights path (used by `--config`
148    /// in `cog-pose-estimation run`). If `weights_path` is `None`, the
149    /// stub fallback is used.
150    pub fn with_weights(weights_path: Option<&Path>) -> Result<Self, Box<dyn std::error::Error>> {
151        let device = pick_device();
152        let inner = match weights_path {
153            Some(p) if p.exists() => {
154                // SAFETY: `from_mmaped_safetensors` mmaps the file for the
155                // VarBuilder's lifetime. We don't modify the file while the
156                // VarBuilder is alive, and the file is read-only on disk on
157                // appliance installs.
158                let vb = unsafe {
159                    VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)?
160                };
161                let net = PoseNet::new(vb)?;
162                Some(Arc::new(LoadedModel { net }))
163            }
164            _ => None,
165        };
166        Ok(Self { inner, device })
167    }
168
169    /// Where the weights actually came from. Useful for the run.started event.
170    pub fn backend(&self) -> &'static str {
171        match (&self.inner, &self.device) {
172            (Some(_), Device::Cuda(_)) => "candle-cuda",
173            (Some(_), _) => "candle-cpu",
174            (None, _) => "stub",
175        }
176    }
177
178    pub fn infer(&self, window: &CsiWindow) -> Result<PoseOutput, Box<dyn std::error::Error>> {
179        if window.data.len() != INPUT_SUBCARRIERS * INPUT_TIMESTEPS {
180            return Err(format!(
181                "expected {} input values, got {}",
182                INPUT_SUBCARRIERS * INPUT_TIMESTEPS,
183                window.data.len()
184            )
185            .into());
186        }
187
188        let Some(model) = &self.inner else {
189            // Stub fallback — model not loaded.
190            return Ok(PoseOutput {
191                keypoints: vec![0.5f32; OUTPUT_KEYPOINTS * 2],
192                confidence: 0.0,
193            });
194        };
195
196        // Build [1, 56, 20] tensor from the flat row-major buffer.
197        let t = Tensor::from_slice(
198            &window.data,
199            (1, INPUT_SUBCARRIERS, INPUT_TIMESTEPS),
200            &self.device,
201        )?;
202        let out = model.net.forward(&t)?; // [1, 34]
203        let flat: Vec<f32> = out.flatten_all()?.to_vec1()?;
204        // Confidence from pose_v1 is a published constant rather than per-frame —
205        // the trained model didn't emit a confidence head. Use the validation-set
206        // PCK@50 (18.5%) as the published self-reported confidence so downstream
207        // consumers can gate display decisions on it.
208        Ok(PoseOutput {
209            keypoints: flat,
210            confidence: 0.185,
211        })
212    }
213}
214
215/// Synthetic CSI window for the `health` subcommand. Zeros — exercises
216/// the I/O surface; the model never touches values that produce NaN.
217pub struct SyntheticInput;
218
219impl Default for SyntheticInput {
220    fn default() -> Self {
221        Self
222    }
223}
224
225impl SyntheticInput {
226    pub fn as_window(&self) -> CsiWindow {
227        CsiWindow {
228            data: vec![0.0; INPUT_SUBCARRIERS * INPUT_TIMESTEPS],
229        }
230    }
231}
232
233// ---------------------------------------------------------------------------
234// Helpers
235// ---------------------------------------------------------------------------
236
237fn pick_device() -> Device {
238    #[cfg(feature = "cuda")]
239    if let Ok(d) = Device::cuda_if_available(0) {
240        return d;
241    }
242    Device::Cpu
243}
244
/// Returns the first existing `pose_v1.safetensors` among the known
/// locations, or `None` when none of them exists.
fn default_weights_path() -> Option<std::path::PathBuf> {
    // Listed in the order an installed Cog would see it.
    const CANDIDATES: [&str; 5] = [
        "/var/lib/cognitum/apps/pose-estimation/pose_v1.safetensors",
        "./pose_v1.safetensors",
        "./cog/artifacts/pose_v1.safetensors",
        // From the repo root.
        "v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
        // From inside v2/.
        "crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
    ];

    CANDIDATES
        .iter()
        .map(std::path::PathBuf::from)
        .find(|candidate| candidate.exists())
}