1use candle_core::{DType, Device, Tensor};
20use candle_nn::{Conv1d, Conv1dConfig, Linear, Module, VarBuilder};
21use std::path::Path;
22use std::sync::Arc;
23
/// Subcarriers per timestep; this is the channel dimension of the network input.
pub const INPUT_SUBCARRIERS: usize = 56;
/// Timesteps in one CSI window; this is the length dimension of the network input.
pub const INPUT_TIMESTEPS: usize = 20;
/// Keypoints in a pose; each keypoint contributes two values to the output vector.
pub const OUTPUT_KEYPOINTS: usize = 17;
29
/// One flattened window of CSI input values.
///
/// `InferenceEngine::infer` accepts only windows holding exactly
/// `INPUT_SUBCARRIERS * INPUT_TIMESTEPS` values. It reshapes them row-major into a
/// `(1, INPUT_SUBCARRIERS, INPUT_TIMESTEPS)` tensor, so element `s * INPUT_TIMESTEPS + t`
/// is subcarrier `s` at timestep `t`.
#[derive(Debug, Clone)]
pub struct CsiWindow {
    /// Flattened samples; see the type-level docs for the required length and layout.
    pub data: Vec<f32>,
}
34
/// Result of a single inference call.
#[derive(Debug, Clone)]
pub struct PoseOutput {
    /// `OUTPUT_KEYPOINTS * 2` coordinate values; both the stub and the model path produce
    /// this many.
    /// NOTE(review): presumably interleaved `(x, y)` pairs in `[0, 1]` — confirm ordering.
    pub keypoints: Vec<f32>,
    /// Single confidence score attached to the whole pose.
    pub confidence: f32,
}

impl PoseOutput {
    /// Returns `true` when no keypoint value and not the confidence is NaN or infinite.
    pub fn is_finite(&self) -> bool {
        let keypoints_ok = !self.keypoints.iter().any(|k| !k.is_finite());
        keypoints_ok && self.confidence.is_finite()
    }
}
48
49struct PoseNet {
51 c1: Conv1d,
52 c2: Conv1d,
53 c3: Conv1d,
54 fc1: Linear,
55 fc2: Linear,
56}
57
58impl PoseNet {
59 fn new(vb: VarBuilder<'_>) -> candle_core::Result<Self> {
60 let enc = vb.pp("enc");
61 let head = vb.pp("head");
62
63 let c1 = candle_nn::conv1d(
64 56,
65 64,
66 3,
67 Conv1dConfig {
68 padding: 1,
69 stride: 1,
70 dilation: 1,
71 groups: 1,
72 ..Default::default()
73 },
74 enc.pp("c1"),
75 )?;
76 let c2 = candle_nn::conv1d(
77 64,
78 128,
79 3,
80 Conv1dConfig {
81 padding: 2,
82 stride: 1,
83 dilation: 2,
84 groups: 1,
85 ..Default::default()
86 },
87 enc.pp("c2"),
88 )?;
89 let c3 = candle_nn::conv1d(
90 128,
91 128,
92 3,
93 Conv1dConfig {
94 padding: 4,
95 stride: 1,
96 dilation: 4,
97 groups: 1,
98 ..Default::default()
99 },
100 enc.pp("c3"),
101 )?;
102 let fc1 = candle_nn::linear(128, 256, head.pp("fc1"))?;
103 let fc2 = candle_nn::linear(256, 34, head.pp("fc2"))?;
104
105 Ok(Self {
106 c1,
107 c2,
108 c3,
109 fc1,
110 fc2,
111 })
112 }
113
114 fn forward(&self, x: &Tensor) -> candle_core::Result<Tensor> {
116 let h = self.c1.forward(x)?.relu()?;
117 let h = self.c2.forward(&h)?.relu()?;
118 let h = self.c3.forward(&h)?.relu()?;
119 let h = h.mean(2)?;
121 let h = self.fc1.forward(&h)?.relu()?;
122 let h = self.fc2.forward(&h)?;
123 candle_nn::ops::sigmoid(&h)
125 }
126}
127
128pub struct InferenceEngine {
129 inner: Option<Arc<LoadedModel>>,
130 device: Device,
131}
132
133struct LoadedModel {
134 net: PoseNet,
135}
136
137impl InferenceEngine {
138 pub fn new() -> Result<Self, Box<dyn std::error::Error>> {
144 Self::with_weights(default_weights_path().as_deref())
145 }
146
147 pub fn with_weights(weights_path: Option<&Path>) -> Result<Self, Box<dyn std::error::Error>> {
151 let device = pick_device();
152 let inner = match weights_path {
153 Some(p) if p.exists() => {
154 let vb = unsafe {
159 VarBuilder::from_mmaped_safetensors(&[p.to_path_buf()], DType::F32, &device)?
160 };
161 let net = PoseNet::new(vb)?;
162 Some(Arc::new(LoadedModel { net }))
163 }
164 _ => None,
165 };
166 Ok(Self { inner, device })
167 }
168
169 pub fn backend(&self) -> &'static str {
171 match (&self.inner, &self.device) {
172 (Some(_), Device::Cuda(_)) => "candle-cuda",
173 (Some(_), _) => "candle-cpu",
174 (None, _) => "stub",
175 }
176 }
177
178 pub fn infer(&self, window: &CsiWindow) -> Result<PoseOutput, Box<dyn std::error::Error>> {
179 if window.data.len() != INPUT_SUBCARRIERS * INPUT_TIMESTEPS {
180 return Err(format!(
181 "expected {} input values, got {}",
182 INPUT_SUBCARRIERS * INPUT_TIMESTEPS,
183 window.data.len()
184 )
185 .into());
186 }
187
188 let Some(model) = &self.inner else {
189 return Ok(PoseOutput {
191 keypoints: vec![0.5f32; OUTPUT_KEYPOINTS * 2],
192 confidence: 0.0,
193 });
194 };
195
196 let t = Tensor::from_slice(
198 &window.data,
199 (1, INPUT_SUBCARRIERS, INPUT_TIMESTEPS),
200 &self.device,
201 )?;
202 let out = model.net.forward(&t)?; let flat: Vec<f32> = out.flatten_all()?.to_vec1()?;
204 Ok(PoseOutput {
209 keypoints: flat,
210 confidence: 0.185,
211 })
212 }
213}
214
215pub struct SyntheticInput;
218
219impl Default for SyntheticInput {
220 fn default() -> Self {
221 Self
222 }
223}
224
225impl SyntheticInput {
226 pub fn as_window(&self) -> CsiWindow {
227 CsiWindow {
228 data: vec![0.0; INPUT_SUBCARRIERS * INPUT_TIMESTEPS],
229 }
230 }
231}
232
233fn pick_device() -> Device {
238 #[cfg(feature = "cuda")]
239 if let Ok(d) = Device::cuda_if_available(0) {
240 return d;
241 }
242 Device::Cpu
243}
244
/// Returns the first existing `pose_v1.safetensors`, or `None` if no candidate exists.
///
/// Candidates are probed in a fixed priority order: the absolute deployment path first,
/// then several paths relative to the current working directory.
fn default_weights_path() -> Option<std::path::PathBuf> {
    const CANDIDATES: [&str; 5] = [
        "/var/lib/cognitum/apps/pose-estimation/pose_v1.safetensors",
        "./pose_v1.safetensors",
        "./cog/artifacts/pose_v1.safetensors",
        "v2/crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
        "crates/cog-pose-estimation/cog/artifacts/pose_v1.safetensors",
    ];
    CANDIDATES
        .iter()
        .map(std::path::Path::new)
        .find(|path| path.exists())
        .map(std::path::Path::to_path_buf)
}