// m1nd_core/xlr.rs
// === crates/m1nd-core/src/xlr.rs ===

3use std::collections::VecDeque;
4
5use crate::error::{M1ndError, M1ndResult};
6use crate::graph::Graph;
7use crate::types::*;
8
// ---------------------------------------------------------------------------
// Constants ported from xlr_v2.py
// ---------------------------------------------------------------------------

/// Frequency tagged on hot (seed-origin) pulses — xlr_v2.py `F_HOT = 1.0`.
pub const F_HOT: f32 = 1.0_f32;
/// Frequency tagged on cold (anti-seed-origin) pulses — xlr_v2.py `F_COLD = 3.7`.
pub const F_COLD: f32 = 3.7_f32;
/// Gaussian kernel bandwidth used for spectral overlap in xlr_v2.py (`bw=0.8`).
/// The bucket-based overlap in this file does not use a Gaussian kernel.
pub const SPECTRAL_BANDWIDTH: f32 = 0.8_f32;
/// Default number of BFS hops around the seeds that are immune to cold cancellation
/// (xlr_v2.py 2-hop BFS).
pub const IMMUNITY_HOPS: u8 = 2_u8;
/// Multiplier applied to the net signal before the logistic gate (xlr_v2.py `* 6.0`).
pub const SIGMOID_STEEPNESS: f32 = 6.0_f32;
/// Number of frequency buckets in the bucketed spectral overlap (DEC-003).
pub const SPECTRAL_BUCKETS: usize = 20_usize;
/// Lower bound of the per-node density factor (`out_degree / avg_degree`).
pub const DENSITY_FLOOR: f32 = 0.3_f32;
/// Upper bound of the per-node density factor (`out_degree / avg_degree`).
pub const DENSITY_CAP: f32 = 2.0_f32;
/// Factor applied to a pulse's amplitude when it crosses an inhibitory edge (DEC-010).
pub const INHIBITORY_COLD_ATTENUATION: f32 = 0.5_f32;

32// ---------------------------------------------------------------------------
33// SpectralPulse — per-node pulse (xlr_v2.py SpectralPulse)
34// ---------------------------------------------------------------------------
35
36/// A spectral pulse carrying amplitude, phase, and frequency.
37/// 48 bytes — fits in one cache line.
38/// Replaces: xlr_v2.py SpectralPulse dataclass
39#[derive(Clone, Copy, Debug)]
40pub struct SpectralPulse {
41    pub node: NodeId,
42    pub amplitude: FiniteF32,
43    /// Phase in [0, 2*pi).
44    pub phase: FiniteF32,
45    /// Frequency: F_HOT for seeds, F_COLD for anti-seeds.
46    pub frequency: PosF32,
47    /// Hops from origin (for immunity check).
48    pub hops: u8,
49    /// Previous node (for path tracking, replaces unbounded Vec — FM-RES-007).
50    pub prev_node: NodeId,
51    /// Recent path (last 3 nodes) — bounded, replaces full path Vec.
52    pub recent_path: [NodeId; 3],
53}
54
55// ---------------------------------------------------------------------------
56// SpectralWaveBuffer — per-node accumulation (xlr_v2.py SpectralWaveBuffer)
57// ---------------------------------------------------------------------------
58
59/// Accumulated spectral energy at a node.
60/// Replaces: xlr_v2.py SpectralWaveBuffer
61#[derive(Clone, Debug, Default)]
62pub struct SpectralWaveBuffer {
63    /// Hot signal accumulated amplitudes.
64    pub hot_amplitudes: Vec<FiniteF32>,
65    /// Hot signal accumulated frequencies.
66    pub hot_frequencies: Vec<FiniteF32>,
67    /// Cold signal accumulated amplitudes.
68    pub cold_amplitudes: Vec<FiniteF32>,
69    /// Cold signal accumulated frequencies.
70    pub cold_frequencies: Vec<FiniteF32>,
71}
72
73// ---------------------------------------------------------------------------
74// XlrParams — configuration
75// ---------------------------------------------------------------------------
76
77/// XLR engine configuration.
78/// Replaces: xlr_v2.py AdaptiveXLREngine.__init__ parameters
79pub struct XlrParams {
80    /// Number of anti-seeds to pick. Default: 3.
81    pub num_anti_seeds: usize,
82    /// Immunity hop distance from seeds. Default: 2 (FM-XLR-008 fix: BFS-based, not count-based).
83    pub immunity_hops: u8,
84    /// Minimum degree ratio for anti-seed candidates. Default: 0.3.
85    pub min_degree_ratio: FiniteF32,
86    /// Maximum Jaccard similarity between seed and anti-seed neighborhoods. Default: 0.2.
87    pub max_jaccard_similarity: FiniteF32,
88    /// Density adaptive clamp range. Default: [0.3, 2.0].
89    pub density_clamp_min: FiniteF32,
90    pub density_clamp_max: FiniteF32,
91    /// Pulse propagation budget (FM-RES-004). Default: 50_000.
92    pub pulse_budget: u64,
93}
94
95impl Default for XlrParams {
96    fn default() -> Self {
97        Self {
98            num_anti_seeds: 3,
99            immunity_hops: IMMUNITY_HOPS,
100            min_degree_ratio: FiniteF32::new(0.3),
101            max_jaccard_similarity: FiniteF32::new(0.2),
102            density_clamp_min: FiniteF32::new(0.3),
103            density_clamp_max: FiniteF32::new(2.0),
104            pulse_budget: 50_000,
105        }
106    }
107}
108
109// ---------------------------------------------------------------------------
110// XlrResult — output of XLR pipeline
111// ---------------------------------------------------------------------------
112
113/// Result of XLR adaptive noise cancellation.
114/// Replaces: xlr_v2.py AdaptiveXLREngine.query() return
115#[derive(Clone, Debug)]
116pub struct XlrResult {
117    /// Per-node activation after spectral cancellation + sigmoid gating.
118    pub activations: Vec<(NodeId, FiniteF32)>,
119    /// Anti-seed nodes that were selected.
120    pub anti_seeds: Vec<NodeId>,
121    /// Whether over-cancellation fallback was triggered (FM-XLR-010).
122    pub fallback_to_hot_only: bool,
123    /// Pulses processed (for budget monitoring).
124    pub pulses_processed: u64,
125}
126
127// ---------------------------------------------------------------------------
128// AdaptiveXlrEngine — main engine (xlr_v2.py AdaptiveXLREngine)
129// ---------------------------------------------------------------------------
130
131/// Adaptive XLR noise cancellation engine.
132/// Dual propagation: hot from seeds, cold from anti-seeds.
133/// Spectral overlap modulation, density-adaptive strength, sigmoid gating.
134/// Replaces: xlr_v2.py AdaptiveXLREngine
135pub struct AdaptiveXlrEngine {
136    params: XlrParams,
137}
138
139impl AdaptiveXlrEngine {
140    pub fn new(params: XlrParams) -> Self {
141        Self { params }
142    }
143
144    pub fn with_defaults() -> Self {
145        Self::new(XlrParams::default())
146    }
147
148    /// Run full XLR pipeline on a set of seed nodes.
149    /// Steps: pick anti-seeds -> compute immunity -> propagate hot -> propagate cold
150    ///        -> spectral overlap -> density modulation -> sigmoid gating -> rescale.
151    /// Replaces: xlr_v2.py AdaptiveXLREngine.query()
152    pub fn query(
153        &self,
154        graph: &Graph,
155        seeds: &[(NodeId, FiniteF32)],
156        config: &PropagationConfig,
157    ) -> M1ndResult<XlrResult> {
158        let n = graph.num_nodes() as usize;
159        if n == 0 || seeds.is_empty() {
160            return Ok(XlrResult {
161                activations: Vec::new(),
162                anti_seeds: Vec::new(),
163                fallback_to_hot_only: false,
164                pulses_processed: 0,
165            });
166        }
167
168        let seed_nodes: Vec<NodeId> = seeds.iter().map(|s| s.0).collect();
169
170        // Step 1: Pick anti-seeds
171        let anti_seeds = self.pick_anti_seeds(graph, &seed_nodes)?;
172
173        // Step 2: Compute immunity
174        let immunity = self.compute_immunity(graph, &seed_nodes)?;
175
176        // Step 3: Propagate hot pulses from seeds
177        let hot_freq = PosF32::new(F_HOT).unwrap();
178        let half_budget = self.params.pulse_budget / 2;
179        let hot_pulses = self.propagate_spectral(graph, seeds, hot_freq, config, half_budget)?;
180
181        // Step 4: Propagate cold pulses from anti-seeds
182        let cold_freq = PosF32::new(F_COLD).unwrap();
183        let anti_seed_pairs: Vec<(NodeId, FiniteF32)> =
184            anti_seeds.iter().map(|&n| (n, FiniteF32::ONE)).collect();
185        let cold_pulses =
186            self.propagate_spectral(graph, &anti_seed_pairs, cold_freq, config, half_budget)?;
187
188        let total_pulses = hot_pulses.len() as u64 + cold_pulses.len() as u64;
189
190        // Step 5: Accumulate per-node hot/cold amplitudes
191        let mut hot_amp = vec![0.0f32; n];
192        let mut cold_amp = vec![0.0f32; n];
193
194        for p in &hot_pulses {
195            let idx = p.node.as_usize();
196            if idx < n {
197                hot_amp[idx] += p.amplitude.get().abs();
198            }
199        }
200        for p in &cold_pulses {
201            let idx = p.node.as_usize();
202            if idx < n {
203                cold_amp[idx] += p.amplitude.get().abs();
204            }
205        }
206
207        // Step 6: Adaptive differential with immunity, density, and sigmoid gating
208        let mut activations = Vec::new();
209        let mut all_zero = true;
210
211        // Compute average degree for density modulation
212        let avg_deg = graph.avg_degree();
213
214        for i in 0..n {
215            let hot = hot_amp[i];
216            if hot <= 0.0 {
217                continue;
218            }
219
220            // Immunity factor: immune nodes get full hot signal, no cold cancellation
221            let immune = if i < immunity.len() {
222                immunity[i]
223            } else {
224                false
225            };
226
227            let effective_cold = if immune { 0.0 } else { cold_amp[i] };
228
229            // Raw differential
230            let raw = hot - effective_cold;
231
232            // Density modulation: nodes with degree near avg get density=1.0
233            let out_deg = {
234                let lo = graph.csr.offsets[i] as usize;
235                let hi = if i + 1 < graph.csr.offsets.len() {
236                    graph.csr.offsets[i + 1] as usize
237                } else {
238                    lo
239                };
240                (hi - lo) as f32
241            };
242            let density = if avg_deg > 0.0 {
243                (out_deg / avg_deg).max(DENSITY_FLOOR).min(DENSITY_CAP)
244            } else {
245                1.0
246            };
247
248            // Sigmoid gate
249            let gated = Self::sigmoid_gate(FiniteF32::new(raw * density));
250            let val = gated.get();
251
252            if val > 0.01 {
253                activations.push((NodeId::new(i as u32), gated));
254                all_zero = false;
255            }
256        }
257
258        // FM-XLR-010: over-cancellation fallback
259        let fallback = all_zero && !hot_pulses.is_empty();
260        if fallback {
261            // Return hot-only
262            activations.clear();
263            for i in 0..n {
264                if hot_amp[i] > 0.01 {
265                    activations.push((NodeId::new(i as u32), FiniteF32::new(hot_amp[i])));
266                }
267            }
268        }
269
270        activations.sort_by(|a, b| b.1.cmp(&a.1));
271
272        Ok(XlrResult {
273            activations,
274            anti_seeds,
275            fallback_to_hot_only: fallback,
276            pulses_processed: total_pulses,
277        })
278    }
279
280    /// Pick anti-seeds: structurally similar (degree), semantically different (Jaccard).
281    /// Replaces: xlr_v2.py pick_anti_seeds()
282    /// FM-XLR-008 fix: immunity computed from BFS reach, not seed count.
283    pub fn pick_anti_seeds(&self, graph: &Graph, seeds: &[NodeId]) -> M1ndResult<Vec<NodeId>> {
284        let n = graph.num_nodes() as usize;
285        if n == 0 || seeds.is_empty() {
286            return Ok(Vec::new());
287        }
288
289        // BFS to find seed neighborhood
290        let mut seed_set = vec![false; n];
291        let mut seed_neighbors = vec![false; n];
292        for &s in seeds {
293            let idx = s.as_usize();
294            if idx < n {
295                seed_set[idx] = true;
296                seed_neighbors[idx] = true;
297                let range = graph.csr.out_range(s);
298                for j in range {
299                    let tgt = graph.csr.targets[j].as_usize();
300                    if tgt < n {
301                        seed_neighbors[tgt] = true;
302                    }
303                }
304            }
305        }
306
307        // Compute average seed degree
308        let avg_seed_degree: f32 = if seeds.is_empty() {
309            0.0
310        } else {
311            let sum: usize = seeds
312                .iter()
313                .map(|s| {
314                    let r = graph.csr.out_range(*s);
315                    r.end - r.start
316                })
317                .sum();
318            sum as f32 / seeds.len() as f32
319        };
320
321        // Candidate scoring: structurally distant + similar degree
322        let mut candidates: Vec<(NodeId, f32)> = Vec::new();
323        for i in 0..n {
324            if seed_set[i] {
325                continue; // Skip seeds
326            }
327
328            let range = graph.csr.out_range(NodeId::new(i as u32));
329            let degree = (range.end - range.start) as f32;
330
331            // Degree ratio filter
332            if avg_seed_degree > 0.0 {
333                let ratio = degree / avg_seed_degree;
334                if ratio < self.params.min_degree_ratio.get() {
335                    continue;
336                }
337            }
338
339            // Jaccard similarity with seed neighborhood (lower = better anti-seed)
340            let mut intersection = 0usize;
341            let mut union_size = 0usize;
342            for j in range.clone() {
343                let tgt = graph.csr.targets[j].as_usize();
344                if tgt < n {
345                    union_size += 1;
346                    if seed_neighbors[tgt] {
347                        intersection += 1;
348                    }
349                }
350            }
351            let jaccard = if union_size > 0 {
352                intersection as f32 / union_size as f32
353            } else {
354                0.0
355            };
356
357            if jaccard > self.params.max_jaccard_similarity.get() {
358                continue; // Too similar to seeds
359            }
360
361            // Score: higher = better anti-seed (distant + adequate degree)
362            let distance_score = if seed_neighbors[i] { 0.0 } else { 1.0 };
363            let score = distance_score + (1.0 - jaccard);
364            candidates.push((NodeId::new(i as u32), score));
365        }
366
367        candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
368        let result: Vec<NodeId> = candidates
369            .iter()
370            .take(self.params.num_anti_seeds)
371            .map(|c| c.0)
372            .collect();
373        Ok(result)
374    }
375
376    /// Compute seed neighborhood immunity set via BFS.
377    /// Returns bitset of immune nodes (within immunity_hops of any seed).
378    /// Replaces: xlr_v2.py compute_seed_neighborhood()
379    /// FM-XLR-008 fix: BFS-based distance, not seed count threshold.
380    pub fn compute_immunity(&self, graph: &Graph, seeds: &[NodeId]) -> M1ndResult<Vec<bool>> {
381        let n = graph.num_nodes() as usize;
382        let mut immune = vec![false; n];
383
384        let mut queue = VecDeque::new();
385        let mut dist = vec![u8::MAX; n];
386
387        for &s in seeds {
388            let idx = s.as_usize();
389            if idx < n {
390                queue.push_back((s, 0u8));
391                dist[idx] = 0;
392                immune[idx] = true;
393            }
394        }
395
396        while let Some((node, d)) = queue.pop_front() {
397            if d >= self.params.immunity_hops {
398                continue;
399            }
400            let range = graph.csr.out_range(node);
401            for j in range {
402                let tgt = graph.csr.targets[j];
403                let tgt_idx = tgt.as_usize();
404                if tgt_idx < n && d + 1 < dist[tgt_idx] {
405                    dist[tgt_idx] = d + 1;
406                    immune[tgt_idx] = true;
407                    queue.push_back((tgt, d + 1));
408                }
409            }
410        }
411
412        Ok(immune)
413    }
414
415    /// Propagate spectral pulses (hot or cold) from origins.
416    /// Budget-limited (FM-RES-004).
417    /// Replaces: xlr_v2.py SpectralPropagator.propagate()
418    /// FM-XLR-014 fix: inhibitory edges do NOT flip cold phase.
419    pub fn propagate_spectral(
420        &self,
421        graph: &Graph,
422        origins: &[(NodeId, FiniteF32)],
423        frequency: PosF32,
424        config: &PropagationConfig,
425        budget: u64,
426    ) -> M1ndResult<Vec<SpectralPulse>> {
427        let n = graph.num_nodes() as usize;
428        let decay = config.decay.get();
429        let threshold = config.threshold.get();
430        let mut pulses_out = Vec::new();
431        let mut pulse_count = 0u64;
432
433        let mut queue: VecDeque<SpectralPulse> = VecDeque::new();
434
435        // Init from origins
436        for &(node, amp) in origins {
437            if node.as_usize() >= n {
438                continue;
439            }
440            let pulse = SpectralPulse {
441                node,
442                amplitude: amp,
443                phase: FiniteF32::ZERO,
444                frequency,
445                hops: 0,
446                prev_node: node,
447                recent_path: [node; 3],
448            };
449            queue.push_back(pulse);
450            pulses_out.push(pulse);
451            pulse_count += 1;
452        }
453
454        let max_depth = config.max_depth.min(20);
455
456        while let Some(pulse) = queue.pop_front() {
457            if pulse_count >= budget {
458                break; // FM-RES-004: budget exhausted
459            }
460            if pulse.hops >= max_depth {
461                continue;
462            }
463            if pulse.amplitude.get().abs() < threshold {
464                continue;
465            }
466
467            let range = graph.csr.out_range(pulse.node);
468            for j in range {
469                let tgt = graph.csr.targets[j];
470                if tgt == pulse.prev_node {
471                    continue; // Don't backtrack to immediate predecessor
472                }
473
474                let w = graph.csr.read_weight(EdgeIdx::new(j as u32)).get();
475                let is_inhib = graph.csr.inhibitory[j];
476
477                let mut new_amp = pulse.amplitude.get() * w * decay;
478
479                // FM-XLR-014 FIX: inhibitory + cold does NOT flip phase.
480                // Just attenuate by INHIBITORY_COLD_ATTENUATION (DEC-010).
481                if is_inhib {
482                    new_amp *= INHIBITORY_COLD_ATTENUATION;
483                }
484
485                if new_amp.abs() < threshold {
486                    continue;
487                }
488
489                // Phase advance
490                let phase_advance = 2.0 * std::f32::consts::PI * frequency.get();
491                let new_phase = (pulse.phase.get() + phase_advance) % (2.0 * std::f32::consts::PI);
492
493                // Update recent path (shift)
494                let mut rp = pulse.recent_path;
495                rp[2] = rp[1];
496                rp[1] = rp[0];
497                rp[0] = pulse.node;
498
499                let new_pulse = SpectralPulse {
500                    node: tgt,
501                    amplitude: FiniteF32::new(new_amp),
502                    phase: FiniteF32::new(new_phase),
503                    frequency,
504                    hops: pulse.hops + 1,
505                    prev_node: pulse.node,
506                    recent_path: rp,
507                };
508
509                pulses_out.push(new_pulse);
510                pulse_count += 1;
511                if pulse_count < budget {
512                    queue.push_back(new_pulse);
513                }
514            }
515        }
516
517        Ok(pulses_out)
518    }
519
520    /// Compute spectral overlap between hot and cold signals at each node.
521    /// DEC-003: bucket-based overlap for O(B) per node.
522    /// Replaces: xlr_v2.py adaptive_differential() spectral overlap section
523    pub fn spectral_overlap(hot_freqs: &[FiniteF32], cold_freqs: &[FiniteF32]) -> FiniteF32 {
524        if hot_freqs.is_empty() || cold_freqs.is_empty() {
525            return FiniteF32::ZERO;
526        }
527
528        // Bucket both signals
529        let mut hot_buckets = [0.0f32; SPECTRAL_BUCKETS];
530        let mut cold_buckets = [0.0f32; SPECTRAL_BUCKETS];
531
532        let max_freq = 10.0f32; // Reasonable max for bucketing
533        let bucket_width = max_freq / SPECTRAL_BUCKETS as f32;
534
535        for f in hot_freqs {
536            let b = ((f.get() / bucket_width) as usize).min(SPECTRAL_BUCKETS - 1);
537            hot_buckets[b] += 1.0;
538        }
539        for f in cold_freqs {
540            let b = ((f.get() / bucket_width) as usize).min(SPECTRAL_BUCKETS - 1);
541            cold_buckets[b] += 1.0;
542        }
543
544        // Overlap = sum(min(hot, cold)) / sum(hot)
545        let mut overlap = 0.0f32;
546        let mut hot_total = 0.0f32;
547        for b in 0..SPECTRAL_BUCKETS {
548            overlap += hot_buckets[b].min(cold_buckets[b]);
549            hot_total += hot_buckets[b];
550        }
551
552        if hot_total > 0.0 {
553            FiniteF32::new(overlap / hot_total)
554        } else {
555            FiniteF32::ZERO
556        }
557    }
558
559    /// Sigmoid gating: activation = sigmoid(x * SIGMOID_STEEPNESS).
560    /// Replaces: xlr_v2.py sigmoid gating in adaptive_differential()
561    pub fn sigmoid_gate(net_signal: FiniteF32) -> FiniteF32 {
562        let x = net_signal.get() * SIGMOID_STEEPNESS;
563        // Clamp to avoid overflow in exp
564        let clamped = x.max(-20.0).min(20.0);
565        let result = 1.0 / (1.0 + (-clamped).exp());
566        FiniteF32::new(result)
567    }
568}
569
570// Ensure Send + Sync for concurrent query serving.
571static_assertions::assert_impl_all!(AdaptiveXlrEngine: Send, Sync);