Skip to main content

sublinear_solver/
contrastive.rs

1//! Contrastive search — find the rows whose solution diverged most from a
2//! baseline. ADR-001 roadmap item #6.
3//!
4//! The architectural shape RuView, Cognitum, and Ruflo's inner loops
5//! actually want: not "give me the whole solution vector", but "tell me
6//! which entries crossed a boundary big enough to wake an agent". This
7//! is the change-driven activation primitive the ADR's thesis turns on.
8//!
9//! ## API
10//!
11//! ```rust,no_run
12//! # use sublinear_solver::contrastive::{find_anomalous_rows, AnomalyRow};
13//! # let baseline: Vec<f64> = vec![];
14//! # let current: Vec<f64> = vec![];
15//! let top_k = find_anomalous_rows(&baseline, &current, 5);
16//! for AnomalyRow { row, baseline, current, anomaly } in top_k {
17//!     println!("row {row}: was {baseline}, now {current} (Δ={anomaly})");
18//! }
19//! ```
20//!
21//! ## Complexity
22//!
23//! The current implementation is `O(n log k)` — one pass over the
24//! solution vectors with a `k`-sized min-heap. That's already useful
25//! (avoids `O(n log n)` of a full sort) but not yet the `O(k · log n)`
26//! the ADR promised. The follow-up will land a *direct* path that
27//! computes only the top-k entries of the new solution via the sublinear-
28//! Neumann single-entry primitive, without ever materialising the full
29//! current solution. Tracked as a `// TODO(ADR-001 #6 phase 2):` in the
30//! source.
31
32use crate::complexity::{Complexity, ComplexityClass};
33use crate::types::Precision;
34use alloc::collections::BinaryHeap;
35use alloc::vec::Vec;
36use core::cmp::Ordering;
37
38/// One row's anomaly report.
39#[derive(Debug, Clone, PartialEq)]
40#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
41pub struct AnomalyRow {
42    /// Row index in the solution vector.
43    pub row: usize,
44    /// The baseline value at this row.
45    pub baseline: Precision,
46    /// The current value at this row.
47    pub current: Precision,
48    /// `|current - baseline|`. The score used for ranking. Higher = more
49    /// anomalous.
50    pub anomaly: Precision,
51}
52
53// Min-heap helper: we want to keep the k *largest* anomalies, so we use a
54// max-of-min wrapper that orders by inverted anomaly score (smallest at the
55// top), and evict the smallest whenever a new entry beats it.
56#[derive(Debug, Clone)]
57struct MinHeapEntry(AnomalyRow);
58
59impl PartialEq for MinHeapEntry {
60    fn eq(&self, other: &Self) -> bool {
61        self.0.anomaly == other.0.anomaly && self.0.row == other.0.row
62    }
63}
64impl Eq for MinHeapEntry {}
65impl PartialOrd for MinHeapEntry {
66    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
67        Some(self.cmp(other))
68    }
69}
70impl Ord for MinHeapEntry {
71    fn cmp(&self, other: &Self) -> Ordering {
72        // Invert so BinaryHeap (max-heap) acts as a min-heap on anomaly.
73        // Tie-break on row index ascending so the API is deterministic
74        // (same inputs always yield the same top-k ordering).
75        other
76            .0
77            .anomaly
78            .partial_cmp(&self.0.anomaly)
79            .unwrap_or(Ordering::Equal)
80            .then_with(|| other.0.row.cmp(&self.0.row))
81    }
82}
83
84/// Return the `k` rows whose `current` value diverged most from `baseline`.
85///
86/// Result is sorted by descending anomaly score (largest first). Ties are
87/// broken by row index ascending so the API is deterministic.
88///
89/// - `O(n log k)` time, `O(k)` space.
90/// - If `k >= baseline.len()`, returns *all* rows sorted by anomaly.
91/// - If `k == 0`, returns an empty vector.
92///
93/// Panics if `baseline.len() != current.len()`.
94pub fn find_anomalous_rows(
95    baseline: &[Precision],
96    current: &[Precision],
97    k: usize,
98) -> Vec<AnomalyRow> {
99    assert_eq!(
100        baseline.len(),
101        current.len(),
102        "find_anomalous_rows: baseline.len()={} != current.len()={}",
103        baseline.len(),
104        current.len(),
105    );
106
107    if k == 0 || baseline.is_empty() {
108        return Vec::new();
109    }
110
111    // TODO(ADR-001 #6 phase 2): replace the O(n) full scan with a direct
112    // top-k path that computes individual entries of `current` via the
113    // sublinear-Neumann single-entry primitive, giving O(k · log n)
114    // total. Today this is the cheap O(n log k) baseline so RuView /
115    // Cognitum have something callable while phase 2 lands.
116
117    let mut heap: BinaryHeap<MinHeapEntry> = BinaryHeap::with_capacity(k.min(baseline.len()) + 1);
118    for (row, (&b, &c)) in baseline.iter().zip(current.iter()).enumerate() {
119        let anomaly = (c - b).abs();
120        let entry = MinHeapEntry(AnomalyRow {
121            row,
122            baseline: b,
123            current: c,
124            anomaly,
125        });
126
127        if heap.len() < k {
128            heap.push(entry);
129        } else if let Some(smallest) = heap.peek() {
130            // Smallest is at the top because of the inverted Ord.
131            if anomaly > smallest.0.anomaly {
132                heap.pop();
133                heap.push(entry);
134            }
135        }
136    }
137
138    // Drain into a sorted-descending vector.
139    let mut out: Vec<AnomalyRow> = heap.into_iter().map(|e| e.0).collect();
140    out.sort_by(|a, b| {
141        b.anomaly
142            .partial_cmp(&a.anomaly)
143            .unwrap_or(Ordering::Equal)
144            .then_with(|| a.row.cmp(&b.row))
145    });
146    out
147}
148
149/// Return only the rows whose anomaly score exceeds `threshold`. Useful as
150/// the boundary-crossing primitive for change-driven activation: an agent
151/// stays asleep until at least one entry crosses the threshold.
152///
153/// - `O(n)` time, `O(matches)` space.
154///
155/// Panics if `baseline.len() != current.len()`.
156pub fn find_rows_above_threshold(
157    baseline: &[Precision],
158    current: &[Precision],
159    threshold: Precision,
160) -> Vec<AnomalyRow> {
161    assert_eq!(
162        baseline.len(),
163        current.len(),
164        "find_rows_above_threshold: dim mismatch {} vs {}",
165        baseline.len(),
166        current.len(),
167    );
168
169    baseline
170        .iter()
171        .zip(current.iter())
172        .enumerate()
173        .filter_map(|(row, (&b, &c))| {
174            let anomaly = (c - b).abs();
175            if anomaly > threshold {
176                Some(AnomalyRow {
177                    row,
178                    baseline: b,
179                    current: c,
180                    anomaly,
181                })
182            } else {
183                None
184            }
185        })
186        .collect()
187}
188
189// ─────────────────────────────────────────────────────────────────────────
190// Complexity declaration. The current path is Linear; phase-2 will drop
191// to SubLinear (O(k · log n)). Declared as Adaptive { Linear, Linear } for
192// now so the contract is honest about today's bound.
193// ─────────────────────────────────────────────────────────────────────────
194
195/// Marker type with a `Complexity` impl for `find_anomalous_rows`.
196pub struct FindAnomalousRowsOp;
197
198impl Complexity for FindAnomalousRowsOp {
199    const CLASS: ComplexityClass = ComplexityClass::Adaptive {
200        default: &ComplexityClass::Linear,
201        worst: &ComplexityClass::Linear,
202    };
203    const DETAIL: &'static str =
204        "O(n log k) full-scan baseline today; phase-2 lowers to O(k · log n) via the \
205         sublinear-Neumann single-entry primitive.";
206}
207
#[cfg(test)]
mod tests {
    //! Unit tests for the contrastive-search entry points. Each test pins
    //! one clause of the documented contract (empty input, `k` bounds,
    //! ordering, tie-breaking, panics, strict threshold comparison).

    use super::*;

    /// Project a result down to its row indices, for compact order checks.
    fn rows_of(found: &[AnomalyRow]) -> Vec<usize> {
        found.iter().map(|r| r.row).collect()
    }

    #[test]
    fn empty_inputs_return_empty() {
        let v: Vec<Precision> = alloc::vec![];
        assert_eq!(find_anomalous_rows(&v, &v, 5), alloc::vec![]);
        assert_eq!(find_anomalous_rows(&v, &v, 0), alloc::vec![]);
    }

    #[test]
    fn k_zero_returns_empty() {
        let b = alloc::vec![1.0, 2.0, 3.0];
        let c = alloc::vec![10.0, 20.0, 30.0];
        assert_eq!(find_anomalous_rows(&b, &c, 0), alloc::vec![]);
    }

    #[test]
    fn top_k_is_correct_for_small_case() {
        let b = alloc::vec![1.0, 1.0, 1.0, 1.0, 1.0];
        let c = alloc::vec![1.0, 5.0, 1.0, 9.0, 2.0];
        // anomalies: 0, 4, 0, 8, 1 — sorted desc: 8 (row 3), 4 (row 1), 1 (row 4).
        let top = find_anomalous_rows(&b, &c, 3);
        assert_eq!(top.len(), 3);
        assert_eq!(top[0].row, 3);
        assert_eq!(top[0].anomaly, 8.0);
        assert_eq!(top[1].row, 1);
        assert_eq!(top[1].anomaly, 4.0);
        assert_eq!(top[2].row, 4);
        assert_eq!(top[2].anomaly, 1.0);
    }

    #[test]
    fn reports_carry_the_input_values() {
        let b = alloc::vec![1.0, 1.0];
        let c = alloc::vec![1.0, 9.0];
        let top = find_anomalous_rows(&b, &c, 1);
        assert_eq!(
            top,
            alloc::vec![AnomalyRow {
                row: 1,
                baseline: 1.0,
                current: 9.0,
                anomaly: 8.0,
            }]
        );
    }

    #[test]
    fn k_larger_than_n_returns_all_sorted() {
        let b = alloc::vec![0.0, 0.0, 0.0];
        let c = alloc::vec![3.0, 1.0, 2.0];
        let top = find_anomalous_rows(&b, &c, 10);
        assert_eq!(top.len(), 3);
        // Sorted desc by anomaly: 3 (row 0), 2 (row 2), 1 (row 1).
        assert!(top[0].anomaly >= top[1].anomaly);
        assert!(top[1].anomaly >= top[2].anomaly);
        assert_eq!(rows_of(&top), alloc::vec![0, 2, 1]);
    }

    #[test]
    fn tie_breaks_on_row_index_ascending() {
        let b = alloc::vec![0.0, 0.0, 0.0];
        let c = alloc::vec![5.0, 5.0, 5.0]; // all tied
        let top = find_anomalous_rows(&b, &c, 2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].row, 0);
        assert_eq!(top[1].row, 1);
    }

    #[test]
    fn anomaly_is_absolute_value() {
        let b = alloc::vec![0.0, 10.0];
        let c = alloc::vec![-7.0, 3.0];
        // anomalies: 7, 7 — both tied. Tie-break: row 0 before row 1.
        let top = find_anomalous_rows(&b, &c, 2);
        assert_eq!(top[0].anomaly, 7.0);
        assert_eq!(top[1].anomaly, 7.0);
        assert_eq!(top[0].row, 0);
        assert_eq!(top[1].row, 1);
    }

    #[test]
    #[should_panic(expected = "find_anomalous_rows")]
    fn top_k_panics_on_dim_mismatch() {
        let b = alloc::vec![1.0, 2.0];
        let c = alloc::vec![1.0];
        let _ = find_anomalous_rows(&b, &c, 1);
    }

    #[test]
    #[should_panic(expected = "dim mismatch")]
    fn threshold_panics_on_dim_mismatch() {
        let b = alloc::vec![1.0, 2.0];
        let c = alloc::vec![1.0];
        let _ = find_rows_above_threshold(&b, &c, 0.5);
    }

    #[test]
    fn threshold_filters_correctly() {
        let b = alloc::vec![0.0, 0.0, 0.0, 0.0];
        let c = alloc::vec![0.1, 0.5, 2.0, 0.05];
        let above = find_rows_above_threshold(&b, &c, 0.3);
        // 0.5 and 2.0 pass; 0.1 and 0.05 don't.
        assert_eq!(above.len(), 2);
        assert_eq!(above[0].row, 1);
        assert_eq!(above[1].row, 2);
    }

    #[test]
    fn threshold_comparison_is_strict() {
        let b = alloc::vec![0.0, 0.0];
        let c = alloc::vec![0.5, 0.75];
        // A score exactly equal to the threshold has not crossed it.
        let above = find_rows_above_threshold(&b, &c, 0.5);
        assert_eq!(rows_of(&above), alloc::vec![1]);
    }

    #[test]
    fn threshold_returns_empty_when_nothing_crosses() {
        let b = alloc::vec![0.0; 5];
        let c = alloc::vec![0.01, 0.02, 0.03, 0.04, 0.05];
        let above = find_rows_above_threshold(&b, &c, 1.0);
        assert!(above.is_empty(), "no entry above threshold should return empty");
    }
}