Skip to main content

gam_terms/analytic_penalties/
nested_prefix.rs

1use super::*;
2
3// ---------------------------------------------------------------------------
4// NestedPrefixPenalty — Matryoshka SAE
5// ---------------------------------------------------------------------------
6
7/// Nested-prefix sparsity penalty used by the Matryoshka SAE
8/// (Bussmann/Nabeshima/Karvonen/Nanda, ICML 2025, arXiv:2503.17547).
9///
10/// Given K nested prefix sizes `m_1 < m_2 < ... < m_K ≤ F` over the latent
11/// dimension `F`, and per-shell weights `λ_k = w_k · exp(ρ_k)`, the penalty is
12///
13/// ```text
14///   P(t; ρ) = Σ_k λ_k · Σ_{i=0}^{m_k - 1} sqrt(t_i² + ε²)
15/// ```
16///
17/// summed over all rows of the latent target. Equivalently, coordinate `i`
18/// contributes with effective weight `W_i = Σ_{k: m_k > i} λ_k`, so the
19/// earliest atoms (small `i`) are penalized by every shell (= strongest L¹)
20/// and the latest atoms only by the outermost shell. This is exactly the
21/// mask-weighted sum-of-L¹ over K prefixes used to enforce shell-wise
22/// reconstruction during Matryoshka training.
23///
24/// Closed forms (per row, summed across all rows):
25///
26/// ```text
27///   ∂P/∂t_i      = W_i · t_i / sqrt(t_i² + ε²)
28///   Hess_diag(i) = W_i · ε² / (t_i² + ε²)^{3/2}           (PSD)
29///   ∂P/∂ρ_k      = λ_k · Σ_{i < m_k} sqrt(t_i² + ε²)
30/// ```
31///
32/// `target` lays out `n_rows × latent_dim` in row-major order (`row * F + col`).
33/// `latent_dim` is taken from `PsiSlice::latent_dim`; if absent we fall back to
34/// the maximum prefix size, which is the standard Matryoshka convention.
35#[derive(Debug, Clone)]
36pub struct NestedPrefixPenalty {
37    pub target: PsiSlice,
38    pub target_tier: PenaltyTier,
39    /// Sorted strictly-increasing prefix sizes `m_1 < m_2 < ... < m_K`.
40    pub prefix_sizes: Vec<usize>,
41    /// Per-shell base weights `w_k`. The effective strength is
42    /// `λ_k = w_k · exp(ρ_k)`.
43    pub shell_weights: Vec<f64>,
44    /// Smoothing parameter ε > 0 for the smoothed-L¹ surrogate
45    /// `sqrt(x² + ε²)`; the Hessian needs ε > 0 for differentiability at 0.
46    pub eps: f64,
47    /// Local ρ indices for the K per-shell log-strengths.
48    pub rho_indices: Vec<usize>,
49    pub weight_schedule: Option<ScalarWeightSchedule>,
50}
51
52impl NestedPrefixPenalty {
53    /// Build a new nested-prefix penalty.
54    ///
55    /// Errors when:
56    ///  * `prefix_sizes` is empty.
57    ///  * `prefix_sizes` is not strictly increasing.
58    ///  * any prefix exceeds the latent dimension (when known).
59    ///  * `shell_weights.len() != prefix_sizes.len()`.
60    ///  * `eps <= 0` (the smoothed-L¹ gradient `1/sqrt(x²+ε²)` and Hessian
61    ///    `ε²/(x²+ε²)^{3/2}` both need ε > 0).
62    #[must_use = "build error must be handled"]
63    pub fn new(
64        target: PsiSlice,
65        target_tier: PenaltyTier,
66        prefix_sizes: Vec<usize>,
67        shell_weights: Vec<f64>,
68        eps: f64,
69    ) -> Result<Self, String> {
70        if prefix_sizes.is_empty() {
71            return Err("NestedPrefixPenalty requires at least one prefix".into());
72        }
73        if shell_weights.len() != prefix_sizes.len() {
74            return Err(format!(
75                "NestedPrefixPenalty requires shell_weights.len() == prefix_sizes.len(); \
76                 got {} weights for {} prefixes",
77                shell_weights.len(),
78                prefix_sizes.len()
79            ));
80        }
81        for w in &shell_weights {
82            if !w.is_finite() || *w < 0.0 {
83                return Err(format!(
84                    "NestedPrefixPenalty shell weights must be finite and ≥ 0; got {w}"
85                ));
86            }
87        }
88        for i in 0..prefix_sizes.len() {
89            if prefix_sizes[i] == 0 {
90                return Err("NestedPrefixPenalty prefixes must be > 0".into());
91            }
92            if i > 0 && prefix_sizes[i] <= prefix_sizes[i - 1] {
93                return Err(format!(
94                    "NestedPrefixPenalty prefixes must be strictly increasing; got {:?}",
95                    prefix_sizes
96                ));
97            }
98        }
99        if let Some(d) = target.latent_dim {
100            let max_prefix = *prefix_sizes.last().expect("non-empty");
101            if max_prefix > d {
102                return Err(format!(
103                    "NestedPrefixPenalty largest prefix {max_prefix} exceeds latent_dim {d}"
104                ));
105            }
106        }
107        if !(eps.is_finite() && eps > 0.0) {
108            return Err(format!(
109                "NestedPrefixPenalty requires eps > 0 (1/sqrt(x²+ε²) singularity at 0); got {eps}"
110            ));
111        }
112        let rho_indices = (0..prefix_sizes.len()).collect();
113        Ok(Self {
114            target,
115            target_tier,
116            prefix_sizes,
117            shell_weights,
118            eps,
119            rho_indices,
120            weight_schedule: None,
121        })
122    }
123
124    /// Attach a global annealing schedule shared by all shell weights. The
125    /// REML loop still picks per-shell ρ_k on top of this baseline.
126    #[must_use]
127    pub fn with_weight_schedule(mut self, schedule: ScalarWeightSchedule) -> Self {
128        self.weight_schedule = Some(schedule);
129        self
130    }
131
132    /// Latent dimension used to slice rows. Falls back to the largest prefix.
133    fn latent_dim(&self) -> usize {
134        self.target
135            .latent_dim
136            .unwrap_or_else(|| *self.prefix_sizes.last().expect("non-empty"))
137    }
138
139    /// Resolve per-shell effective weights `λ_k = w_k · exp(ρ_k)`.
140    fn lambdas(&self, rho: ArrayView1<'_, f64>) -> Vec<f64> {
141        self.prefix_sizes
142            .iter()
143            .enumerate()
144            .map(|(k, _)| resolve_learnable_weight(self.shell_weights[k], rho[self.rho_indices[k]]))
145            .collect()
146    }
147
148    /// Per-axis cumulative weight `W_i = Σ_{k: m_k > i} λ_k`. Length = F.
149    /// Computed in `O(F + K)` by scanning prefixes from outer to inner.
150    fn per_axis_weights(&self, lambdas: &[f64]) -> Vec<f64> {
151        let f = self.latent_dim();
152        let mut w = vec![0.0_f64; f];
153        // For each shell k, every axis i ∈ [0, m_k) gets +λ_k.
154        // Equivalent reverse-cumulative form, but the direct O(K·F) loop is
155        // K≤8 in practice, so this is O(F) for the use cases we ship.
156        for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
157            let lam = lambdas[k];
158            if lam == 0.0 {
159                continue;
160            }
161            let end = m_k.min(f);
162            for entry in w.iter_mut().take(end) {
163                *entry += lam;
164            }
165        }
166        w
167    }
168}
169
170impl AnalyticPenalty for NestedPrefixPenalty {
171    fn tier(&self) -> PenaltyTier {
172        self.target_tier
173    }
174
175    fn value(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> f64 {
176        let f = self.latent_dim();
177        assert!(
178            target.len().is_multiple_of(f),
179            "target length must be n_rows · F"
180        );
181        let n_rows = target.len() / f;
182        let lambdas = self.lambdas(rho);
183        let eps2 = self.eps * self.eps;
184        // Per-axis L¹ totals s_i = Σ_n sqrt(t_{n,i}² + ε²).
185        let mut s_axis = vec![0.0_f64; f];
186        for n in 0..n_rows {
187            let row = &target.as_slice().expect("contiguous")[n * f..(n + 1) * f];
188            for (i, &x) in row.iter().enumerate() {
189                s_axis[i] += (x * x + eps2).sqrt();
190            }
191        }
192        // Now P = Σ_k λ_k · Σ_{i<m_k} s_i.
193        let mut total = 0.0;
194        for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
195            let end = m_k.min(f);
196            let mut acc = 0.0;
197            for &v in s_axis.iter().take(end) {
198                acc += v;
199            }
200            total += lambdas[k] * acc;
201        }
202        total
203    }
204
205    fn grad_target(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
206        let f = self.latent_dim();
207        let n_rows = target.len() / f;
208        let lambdas = self.lambdas(rho);
209        let w_per_axis = self.per_axis_weights(&lambdas);
210        let eps2 = self.eps * self.eps;
211        let src = target.as_slice().expect("contiguous");
212        let mut g = Array1::<f64>::zeros(target.len());
213        let g_slice = g.as_slice_mut().expect("contiguous");
214        for n in 0..n_rows {
215            for i in 0..f {
216                let x = src[n * f + i];
217                let w = w_per_axis[i];
218                if w == 0.0 {
219                    continue;
220                }
221                g_slice[n * f + i] = w * x / (x * x + eps2).sqrt();
222            }
223        }
224        g
225    }
226
227    fn hessian_diag(
228        &self,
229        target: ArrayView1<'_, f64>,
230        rho: ArrayView1<'_, f64>,
231    ) -> Option<Array1<f64>> {
232        let f = self.latent_dim();
233        let n_rows = target.len() / f;
234        let lambdas = self.lambdas(rho);
235        let w_per_axis = self.per_axis_weights(&lambdas);
236        let eps2 = self.eps * self.eps;
237        let src = target.as_slice().expect("contiguous");
238        let mut d = Array1::<f64>::zeros(target.len());
239        let d_slice = d.as_slice_mut().expect("contiguous");
240        for n in 0..n_rows {
241            for i in 0..f {
242                let w = w_per_axis[i];
243                if w == 0.0 {
244                    continue;
245                }
246                let x = src[n * f + i];
247                let r = (x * x + eps2).sqrt();
248                d_slice[n * f + i] = w * eps2 / (r * r * r);
249            }
250        }
251        Some(d)
252    }
253
254    fn grad_rho(&self, target: ArrayView1<'_, f64>, rho: ArrayView1<'_, f64>) -> Array1<f64> {
255        let f = self.latent_dim();
256        let n_rows = target.len() / f;
257        let lambdas = self.lambdas(rho);
258        let eps2 = self.eps * self.eps;
259        // Same axis-wise reduction as `value`, but we need the per-shell
260        // (not cumulative) sums for the ρ-gradient.
261        let mut s_axis = vec![0.0_f64; f];
262        let src = target.as_slice().expect("contiguous");
263        for n in 0..n_rows {
264            for i in 0..f {
265                let x = src[n * f + i];
266                s_axis[i] += (x * x + eps2).sqrt();
267            }
268        }
269        let n_rho = self.rho_count();
270        let mut out = Array1::<f64>::zeros(n_rho);
271        for (k, &m_k) in self.prefix_sizes.iter().enumerate() {
272            let end = m_k.min(f);
273            let mut shell_sum = 0.0;
274            for &v in s_axis.iter().take(end) {
275                shell_sum += v;
276            }
277            // ∂P/∂ρ_k = λ_k · shell_sum  because λ_k = w_k · exp(ρ_k).
278            out[self.rho_indices[k]] = lambdas[k] * shell_sum;
279        }
280        out
281    }
282
283    fn rho_count(&self) -> usize {
284        self.prefix_sizes.len()
285    }
286
287    fn name(&self) -> &str {
288        "nested_prefix"
289    }
290
291    fn apply_schedule(&mut self, iter: usize) {
292        if let Some(schedule) = self.weight_schedule.as_mut() {
293            let prev = schedule.current_weight(schedule.iter_count);
294            let next = schedule.current_weight(iter);
295            if prev > 0.0 {
296                let ratio = next / prev;
297                for w in &mut self.shell_weights {
298                    *w *= ratio;
299                }
300            }
301            schedule.iter_count = iter + 1;
302        }
303    }
304}