1#![allow(clippy::useless_conversion)] use ndarray::{Array1, Array2};
36use pyo3::exceptions::PyValueError;
37use pyo3::prelude::*;
38
/// Threshold on the sum of the combined channel weights: when the sum is at or
/// below this value, [`HretObserver::update`] falls back to uniform weights
/// instead of normalizing.
const WEIGHT_SUM_EPS: f64 = 1.0e-12;

/// Output of [`HretObserver::update`], in order:
///
/// 0. `delta_x` — the correction `k_k · (weights ∘ residuals)` (length `p`);
/// 1. the normalized channel weights (length `m`);
/// 2. the channel envelopes `s_k` after the update (length `m`);
/// 3. the group envelopes `s_g` after the update (length `g`).
pub type HretUpdate = (Vec<f64>, Vec<f64>, Vec<f64>, Vec<f64>);
49
/// Validation error produced by [`HretObserver`] for invalid configuration or
/// invalid residual input.
///
/// The message is exposed verbatim through `Display`; the Python bindings
/// re-raise it as a `ValueError` with that text.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HretError {
    /// Human-readable description of the failed check.
    message: String,
}

impl HretError {
    /// Wraps any string-like message into an error value.
    fn new(message: impl Into<String>) -> Self {
        let message: String = message.into();
        HretError { message }
    }
}

impl std::fmt::Display for HretError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.message)
    }
}

impl std::error::Error for HretError {}
71
72#[derive(Clone, Debug)]
73#[pyclass]
74pub struct HretObserver {
79 m: usize,
80 g: usize,
81 group_mapping: Array1<usize>,
82 group_indices: Vec<Vec<usize>>,
83 rho: f64,
84 rho_g: Array1<f64>,
85 beta_k: Array1<f64>,
86 beta_g: Array1<f64>,
87 s_k: Array1<f64>,
88 s_g: Array1<f64>,
89 k_k: Array2<f64>,
90}
91
92impl HretObserver {
93 #[allow(clippy::too_many_arguments)]
98 pub fn new(
99 m: usize,
100 g: usize,
101 group_mapping: Vec<usize>,
102 rho: f64,
103 rho_g: Vec<f64>,
104 beta_k: Vec<f64>,
105 beta_g: Vec<f64>,
106 k_k: Vec<Vec<f64>>,
107 ) -> Result<Self, HretError> {
108 validate_positive("m", m)?;
109 validate_positive("g", g)?;
110 validate_len("group_mapping", m, group_mapping.len())?;
111 validate_len("rho_g", g, rho_g.len())?;
112 validate_len("beta_k", m, beta_k.len())?;
113 validate_len("beta_g", g, beta_g.len())?;
114 validate_forgetting_factor("rho", rho)?;
115 validate_forgetting_factors("rho_g", &rho_g)?;
116 validate_non_negative_finite("beta_k", &beta_k)?;
117 validate_non_negative_finite("beta_g", &beta_g)?;
118
119 let mut group_indices = vec![Vec::new(); g];
120 for (channel_idx, &group_idx) in group_mapping.iter().enumerate() {
121 if group_idx >= g {
122 return Err(HretError::new(format!(
123 "group_mapping[{channel_idx}] = {group_idx} is out of range 0..{g}",
124 )));
125 }
126 group_indices[group_idx].push(channel_idx);
127 }
128
129 if k_k.is_empty() {
130 return Err(HretError::new("k_k must contain at least one gain row"));
131 }
132
133 let p = k_k.len();
134 let mut k_k_flat = Vec::with_capacity(p * m);
135 for (row_idx, row) in k_k.into_iter().enumerate() {
136 validate_len(&format!("k_k[{row_idx}]"), m, row.len())?;
137 for (col_idx, value) in row.into_iter().enumerate() {
138 if !value.is_finite() {
139 return Err(HretError::new(format!(
140 "k_k[{row_idx}][{col_idx}] must be finite (got {value})",
141 )));
142 }
143 k_k_flat.push(value);
144 }
145 }
146
147 let k_k = Array2::from_shape_vec((p, m), k_k_flat).map_err(|e| {
148 HretError::new(format!(
149 "failed to build gain matrix with shape ({p}, {m}): {e}",
150 ))
151 })?;
152
153 Ok(Self {
154 m,
155 g,
156 group_mapping: Array1::from(group_mapping),
157 group_indices,
158 rho,
159 rho_g: Array1::from(rho_g),
160 beta_k: Array1::from(beta_k),
161 beta_g: Array1::from(beta_g),
162 s_k: Array1::zeros(m),
163 s_g: Array1::zeros(g),
164 k_k,
165 })
166 }
167
168 pub fn update(&mut self, residuals: Vec<f64>) -> Result<HretUpdate, HretError> {
173 validate_len("residuals", self.m, residuals.len())?;
174 validate_finite("residuals", &residuals)?;
175
176 let r_arr = Array1::from(residuals);
177
178 self.s_k = self.rho * &self.s_k + (1.0 - self.rho) * r_arr.mapv(f64::abs);
180
181 for (group_idx, channels) in self.group_indices.iter().enumerate() {
183 if channels.is_empty() {
184 continue;
185 }
186
187 let avg_abs_r =
188 channels.iter().map(|&i| r_arr[i].abs()).sum::<f64>() / channels.len() as f64;
189 self.s_g[group_idx] = self.rho_g[group_idx] * self.s_g[group_idx]
190 + (1.0 - self.rho_g[group_idx]) * avg_abs_r;
191 }
192
193 let w_k =
195 Array1::from_iter((0..self.m).map(|i| 1.0 / (1.0 + self.beta_k[i] * self.s_k[i])));
196 let w_g =
197 Array1::from_iter((0..self.g).map(|i| 1.0 / (1.0 + self.beta_g[i] * self.s_g[i])));
198
199 let w_g_mapped =
201 Array1::from_iter(self.group_mapping.iter().map(|&group_idx| w_g[group_idx]));
202 let hat_w_k = &w_k * &w_g_mapped;
203 let sum_hat = hat_w_k.sum();
204 let tilde_w_k = if sum_hat > WEIGHT_SUM_EPS {
205 hat_w_k / sum_hat
206 } else {
207 Array1::from_elem(self.m, 1.0 / self.m as f64)
208 };
209
210 let weighted_r = &tilde_w_k * &r_arr;
212 let delta_x = self.k_k.dot(&weighted_r);
213
214 debug_assert!(tilde_w_k.iter().all(|&w| w >= -1e-12));
215 debug_assert!((tilde_w_k.sum() - 1.0).abs() < 1e-8);
216
217 Ok((
218 delta_x.to_vec(),
219 tilde_w_k.to_vec(),
220 self.s_k.to_vec(),
221 self.s_g.to_vec(),
222 ))
223 }
224
225 pub fn reset_envelopes(&mut self) {
227 self.s_k.fill(0.0);
228 self.s_g.fill(0.0);
229 }
230
231 pub fn channel_count(&self) -> usize {
233 self.m
234 }
235
236 pub fn group_count(&self) -> usize {
238 self.g
239 }
240
241 pub fn group_mapping_vec(&self) -> Vec<usize> {
243 self.group_mapping.to_vec()
244 }
245}
246
247#[pymethods]
248impl HretObserver {
249 #[new]
250 #[pyo3(signature = (m, g, group_mapping, rho, rho_g, beta_k, beta_g, k_k))]
251 #[allow(clippy::too_many_arguments)]
252 fn py_new(
253 m: usize,
254 g: usize,
255 group_mapping: Vec<usize>,
256 rho: f64,
257 rho_g: Vec<f64>,
258 beta_k: Vec<f64>,
259 beta_g: Vec<f64>,
260 k_k: Vec<Vec<f64>>,
261 ) -> PyResult<Self> {
262 Self::new(m, g, group_mapping, rho, rho_g, beta_k, beta_g, k_k)
263 .map_err(|error| PyValueError::new_err(error.to_string()))
264 }
265
266 #[pyo3(name = "update")]
267 #[allow(clippy::useless_conversion)]
268 fn py_update(&mut self, residuals: Vec<f64>) -> PyResult<HretUpdate> {
269 self.update(residuals)
270 .map_err(|error| PyValueError::new_err(error.to_string()))
271 }
272
273 #[pyo3(name = "reset_envelopes")]
274 fn py_reset_envelopes(&mut self) {
275 self.reset_envelopes();
276 }
277
278 #[getter]
279 fn m(&self) -> usize {
280 self.channel_count()
281 }
282
283 #[getter]
284 fn g(&self) -> usize {
285 self.group_count()
286 }
287
288 #[getter]
289 fn group_mapping(&self) -> Vec<usize> {
290 self.group_mapping_vec()
291 }
292
293 fn __repr__(&self) -> String {
294 format!(
295 "HretObserver(m={}, g={}, p={})",
296 self.m,
297 self.g,
298 self.k_k.nrows()
299 )
300 }
301}
302
303fn validate_positive(field: &str, value: usize) -> Result<(), HretError> {
304 if value == 0 {
305 return Err(HretError::new(format!("{field} must be > 0 (got 0)")));
306 }
307 Ok(())
308}
309
310fn validate_len(field: &str, expected: usize, got: usize) -> Result<(), HretError> {
311 if expected != got {
312 return Err(HretError::new(format!(
313 "{field} length mismatch: expected {expected}, got {got}",
314 )));
315 }
316 Ok(())
317}
318
319fn validate_forgetting_factor(field: &str, value: f64) -> Result<(), HretError> {
320 if !value.is_finite() || value <= 0.0 || value >= 1.0 {
321 return Err(HretError::new(format!(
322 "{field} must be finite and in (0, 1); got {value}",
323 )));
324 }
325 Ok(())
326}
327
328fn validate_forgetting_factors(field: &str, values: &[f64]) -> Result<(), HretError> {
329 for (idx, value) in values.iter().copied().enumerate() {
330 if !value.is_finite() || value <= 0.0 || value >= 1.0 {
331 return Err(HretError::new(format!(
332 "{field}[{idx}] must be finite and in (0, 1); got {value}",
333 )));
334 }
335 }
336 Ok(())
337}
338
339fn validate_non_negative_finite(field: &str, values: &[f64]) -> Result<(), HretError> {
340 for (idx, value) in values.iter().copied().enumerate() {
341 if !value.is_finite() || value < 0.0 {
342 return Err(HretError::new(format!(
343 "{field}[{idx}] must be finite and >= 0; got {value}",
344 )));
345 }
346 }
347 Ok(())
348}
349
350fn validate_finite(field: &str, values: &[f64]) -> Result<(), HretError> {
351 for (idx, value) in values.iter().copied().enumerate() {
352 if !value.is_finite() {
353 return Err(HretError::new(format!(
354 "{field}[{idx}] must be finite; got {value}",
355 )));
356 }
357 }
358 Ok(())
359}
360
361#[pymodule]
362fn dsfb_hret(m: &Bound<'_, PyModule>) -> PyResult<()> {
363 m.add_class::<HretObserver>()?;
364 Ok(())
365}
366
367#[cfg(test)]
368mod tests;