1use std::collections::VecDeque;
4
5use crate::error::{M1ndError, M1ndResult};
6use crate::graph::Graph;
7use crate::types::*;
8
/// Oscillation frequency assigned to pulses of the hot pass, which starts at the query seeds.
pub const F_HOT: f32 = 1.0;
/// Oscillation frequency assigned to pulses of the cold pass, which starts at the anti-seeds.
pub const F_COLD: f32 = 3.7;
/// Spectral bandwidth constant.
/// NOTE(review): not referenced by the code in this module as shown; confirm its consumer.
pub const SPECTRAL_BANDWIDTH: f32 = 0.8;
/// Default BFS radius (in out-edges) around the seeds inside which the cold signal is ignored.
pub const IMMUNITY_HOPS: u8 = 2;
/// Multiplier applied to the net signal before the logistic function in the sigmoid gate.
pub const SIGMOID_STEEPNESS: f32 = 6.0;
/// Number of equal-width histogram buckets used when comparing frequency spectra.
pub const SPECTRAL_BUCKETS: usize = 20;
/// Default lower clamp for the per-node density factor (out-degree / average degree).
pub const DENSITY_FLOOR: f32 = 0.3;
/// Default upper clamp for the per-node density factor (out-degree / average degree).
pub const DENSITY_CAP: f32 = 2.0;
/// Amplitude multiplier applied when a pulse crosses an inhibitory edge.
/// Despite the name, propagation applies it in both the hot and the cold pass.
pub const INHIBITORY_COLD_ATTENUATION: f32 = 0.5;
31
32#[derive(Clone, Copy, Debug)]
40pub struct SpectralPulse {
41 pub node: NodeId,
42 pub amplitude: FiniteF32,
43 pub phase: FiniteF32,
45 pub frequency: PosF32,
47 pub hops: u8,
49 pub prev_node: NodeId,
51 pub recent_path: [NodeId; 3],
53}
54
55#[derive(Clone, Debug, Default)]
62pub struct SpectralWaveBuffer {
63 pub hot_amplitudes: Vec<FiniteF32>,
65 pub hot_frequencies: Vec<FiniteF32>,
67 pub cold_amplitudes: Vec<FiniteF32>,
69 pub cold_frequencies: Vec<FiniteF32>,
71}
72
73pub struct XlrParams {
80 pub num_anti_seeds: usize,
82 pub immunity_hops: u8,
84 pub min_degree_ratio: FiniteF32,
86 pub max_jaccard_similarity: FiniteF32,
88 pub density_clamp_min: FiniteF32,
90 pub density_clamp_max: FiniteF32,
91 pub pulse_budget: u64,
93}
94
95impl Default for XlrParams {
96 fn default() -> Self {
97 Self {
98 num_anti_seeds: 3,
99 immunity_hops: IMMUNITY_HOPS,
100 min_degree_ratio: FiniteF32::new(0.3),
101 max_jaccard_similarity: FiniteF32::new(0.2),
102 density_clamp_min: FiniteF32::new(0.3),
103 density_clamp_max: FiniteF32::new(2.0),
104 pulse_budget: 50_000,
105 }
106 }
107}
108
109#[derive(Clone, Debug)]
116pub struct XlrResult {
117 pub activations: Vec<(NodeId, FiniteF32)>,
119 pub anti_seeds: Vec<NodeId>,
121 pub fallback_to_hot_only: bool,
123 pub pulses_processed: u64,
125}
126
127pub struct AdaptiveXlrEngine {
136 params: XlrParams,
137}
138
139impl AdaptiveXlrEngine {
140 pub fn new(params: XlrParams) -> Self {
141 Self { params }
142 }
143
144 pub fn with_defaults() -> Self {
145 Self::new(XlrParams::default())
146 }
147
148 pub fn query(
153 &self,
154 graph: &Graph,
155 seeds: &[(NodeId, FiniteF32)],
156 config: &PropagationConfig,
157 ) -> M1ndResult<XlrResult> {
158 let n = graph.num_nodes() as usize;
159 if n == 0 || seeds.is_empty() {
160 return Ok(XlrResult {
161 activations: Vec::new(),
162 anti_seeds: Vec::new(),
163 fallback_to_hot_only: false,
164 pulses_processed: 0,
165 });
166 }
167
168 let seed_nodes: Vec<NodeId> = seeds.iter().map(|s| s.0).collect();
169
170 let anti_seeds = self.pick_anti_seeds(graph, &seed_nodes)?;
172
173 let immunity = self.compute_immunity(graph, &seed_nodes)?;
175
176 let hot_freq = PosF32::new(F_HOT).unwrap();
178 let half_budget = self.params.pulse_budget / 2;
179 let hot_pulses = self.propagate_spectral(graph, seeds, hot_freq, config, half_budget)?;
180
181 let cold_freq = PosF32::new(F_COLD).unwrap();
183 let anti_seed_pairs: Vec<(NodeId, FiniteF32)> =
184 anti_seeds.iter().map(|&n| (n, FiniteF32::ONE)).collect();
185 let cold_pulses =
186 self.propagate_spectral(graph, &anti_seed_pairs, cold_freq, config, half_budget)?;
187
188 let total_pulses = hot_pulses.len() as u64 + cold_pulses.len() as u64;
189
190 let mut hot_amp = vec![0.0f32; n];
192 let mut cold_amp = vec![0.0f32; n];
193
194 for p in &hot_pulses {
195 let idx = p.node.as_usize();
196 if idx < n {
197 hot_amp[idx] += p.amplitude.get().abs();
198 }
199 }
200 for p in &cold_pulses {
201 let idx = p.node.as_usize();
202 if idx < n {
203 cold_amp[idx] += p.amplitude.get().abs();
204 }
205 }
206
207 let mut activations = Vec::new();
209 let mut all_zero = true;
210
211 let avg_deg = graph.avg_degree();
213
214 for i in 0..n {
215 let hot = hot_amp[i];
216 if hot <= 0.0 {
217 continue;
218 }
219
220 let immune = if i < immunity.len() {
222 immunity[i]
223 } else {
224 false
225 };
226
227 let effective_cold = if immune { 0.0 } else { cold_amp[i] };
228
229 let raw = hot - effective_cold;
231
232 let out_deg = {
234 let lo = graph.csr.offsets[i] as usize;
235 let hi = if i + 1 < graph.csr.offsets.len() {
236 graph.csr.offsets[i + 1] as usize
237 } else {
238 lo
239 };
240 (hi - lo) as f32
241 };
242 let density = if avg_deg > 0.0 {
243 (out_deg / avg_deg).max(DENSITY_FLOOR).min(DENSITY_CAP)
244 } else {
245 1.0
246 };
247
248 let gated = Self::sigmoid_gate(FiniteF32::new(raw * density));
250 let val = gated.get();
251
252 if val > 0.01 {
253 activations.push((NodeId::new(i as u32), gated));
254 all_zero = false;
255 }
256 }
257
258 let fallback = all_zero && !hot_pulses.is_empty();
260 if fallback {
261 activations.clear();
263 for i in 0..n {
264 if hot_amp[i] > 0.01 {
265 activations.push((NodeId::new(i as u32), FiniteF32::new(hot_amp[i])));
266 }
267 }
268 }
269
270 activations.sort_by(|a, b| b.1.cmp(&a.1));
271
272 Ok(XlrResult {
273 activations,
274 anti_seeds,
275 fallback_to_hot_only: fallback,
276 pulses_processed: total_pulses,
277 })
278 }
279
280 pub fn pick_anti_seeds(&self, graph: &Graph, seeds: &[NodeId]) -> M1ndResult<Vec<NodeId>> {
284 let n = graph.num_nodes() as usize;
285 if n == 0 || seeds.is_empty() {
286 return Ok(Vec::new());
287 }
288
289 let mut seed_set = vec![false; n];
291 let mut seed_neighbors = vec![false; n];
292 for &s in seeds {
293 let idx = s.as_usize();
294 if idx < n {
295 seed_set[idx] = true;
296 seed_neighbors[idx] = true;
297 let range = graph.csr.out_range(s);
298 for j in range {
299 let tgt = graph.csr.targets[j].as_usize();
300 if tgt < n {
301 seed_neighbors[tgt] = true;
302 }
303 }
304 }
305 }
306
307 let avg_seed_degree: f32 = if seeds.is_empty() {
309 0.0
310 } else {
311 let sum: usize = seeds
312 .iter()
313 .map(|s| {
314 let r = graph.csr.out_range(*s);
315 r.end - r.start
316 })
317 .sum();
318 sum as f32 / seeds.len() as f32
319 };
320
321 let mut candidates: Vec<(NodeId, f32)> = Vec::new();
323 for i in 0..n {
324 if seed_set[i] {
325 continue; }
327
328 let range = graph.csr.out_range(NodeId::new(i as u32));
329 let degree = (range.end - range.start) as f32;
330
331 if avg_seed_degree > 0.0 {
333 let ratio = degree / avg_seed_degree;
334 if ratio < self.params.min_degree_ratio.get() {
335 continue;
336 }
337 }
338
339 let mut intersection = 0usize;
341 let mut union_size = 0usize;
342 for j in range.clone() {
343 let tgt = graph.csr.targets[j].as_usize();
344 if tgt < n {
345 union_size += 1;
346 if seed_neighbors[tgt] {
347 intersection += 1;
348 }
349 }
350 }
351 let jaccard = if union_size > 0 {
352 intersection as f32 / union_size as f32
353 } else {
354 0.0
355 };
356
357 if jaccard > self.params.max_jaccard_similarity.get() {
358 continue; }
360
361 let distance_score = if seed_neighbors[i] { 0.0 } else { 1.0 };
363 let score = distance_score + (1.0 - jaccard);
364 candidates.push((NodeId::new(i as u32), score));
365 }
366
367 candidates.sort_by(|a, b| b.1.total_cmp(&a.1));
368 let result: Vec<NodeId> = candidates
369 .iter()
370 .take(self.params.num_anti_seeds)
371 .map(|c| c.0)
372 .collect();
373 Ok(result)
374 }
375
376 pub fn compute_immunity(&self, graph: &Graph, seeds: &[NodeId]) -> M1ndResult<Vec<bool>> {
381 let n = graph.num_nodes() as usize;
382 let mut immune = vec![false; n];
383
384 let mut queue = VecDeque::new();
385 let mut dist = vec![u8::MAX; n];
386
387 for &s in seeds {
388 let idx = s.as_usize();
389 if idx < n {
390 queue.push_back((s, 0u8));
391 dist[idx] = 0;
392 immune[idx] = true;
393 }
394 }
395
396 while let Some((node, d)) = queue.pop_front() {
397 if d >= self.params.immunity_hops {
398 continue;
399 }
400 let range = graph.csr.out_range(node);
401 for j in range {
402 let tgt = graph.csr.targets[j];
403 let tgt_idx = tgt.as_usize();
404 if tgt_idx < n && d + 1 < dist[tgt_idx] {
405 dist[tgt_idx] = d + 1;
406 immune[tgt_idx] = true;
407 queue.push_back((tgt, d + 1));
408 }
409 }
410 }
411
412 Ok(immune)
413 }
414
415 pub fn propagate_spectral(
420 &self,
421 graph: &Graph,
422 origins: &[(NodeId, FiniteF32)],
423 frequency: PosF32,
424 config: &PropagationConfig,
425 budget: u64,
426 ) -> M1ndResult<Vec<SpectralPulse>> {
427 let n = graph.num_nodes() as usize;
428 let decay = config.decay.get();
429 let threshold = config.threshold.get();
430 let mut pulses_out = Vec::new();
431 let mut pulse_count = 0u64;
432
433 let mut queue: VecDeque<SpectralPulse> = VecDeque::new();
434
435 for &(node, amp) in origins {
437 if node.as_usize() >= n {
438 continue;
439 }
440 let pulse = SpectralPulse {
441 node,
442 amplitude: amp,
443 phase: FiniteF32::ZERO,
444 frequency,
445 hops: 0,
446 prev_node: node,
447 recent_path: [node; 3],
448 };
449 queue.push_back(pulse);
450 pulses_out.push(pulse);
451 pulse_count += 1;
452 }
453
454 let max_depth = config.max_depth.min(20);
455
456 while let Some(pulse) = queue.pop_front() {
457 if pulse_count >= budget {
458 break; }
460 if pulse.hops >= max_depth {
461 continue;
462 }
463 if pulse.amplitude.get().abs() < threshold {
464 continue;
465 }
466
467 let range = graph.csr.out_range(pulse.node);
468 for j in range {
469 let tgt = graph.csr.targets[j];
470 if tgt == pulse.prev_node {
471 continue; }
473
474 let w = graph.csr.read_weight(EdgeIdx::new(j as u32)).get();
475 let is_inhib = graph.csr.inhibitory[j];
476
477 let mut new_amp = pulse.amplitude.get() * w * decay;
478
479 if is_inhib {
482 new_amp *= INHIBITORY_COLD_ATTENUATION;
483 }
484
485 if new_amp.abs() < threshold {
486 continue;
487 }
488
489 let phase_advance = 2.0 * std::f32::consts::PI * frequency.get();
491 let new_phase = (pulse.phase.get() + phase_advance) % (2.0 * std::f32::consts::PI);
492
493 let mut rp = pulse.recent_path;
495 rp[2] = rp[1];
496 rp[1] = rp[0];
497 rp[0] = pulse.node;
498
499 let new_pulse = SpectralPulse {
500 node: tgt,
501 amplitude: FiniteF32::new(new_amp),
502 phase: FiniteF32::new(new_phase),
503 frequency,
504 hops: pulse.hops + 1,
505 prev_node: pulse.node,
506 recent_path: rp,
507 };
508
509 pulses_out.push(new_pulse);
510 pulse_count += 1;
511 if pulse_count < budget {
512 queue.push_back(new_pulse);
513 }
514 }
515 }
516
517 Ok(pulses_out)
518 }
519
520 pub fn spectral_overlap(hot_freqs: &[FiniteF32], cold_freqs: &[FiniteF32]) -> FiniteF32 {
524 if hot_freqs.is_empty() || cold_freqs.is_empty() {
525 return FiniteF32::ZERO;
526 }
527
528 let mut hot_buckets = [0.0f32; SPECTRAL_BUCKETS];
530 let mut cold_buckets = [0.0f32; SPECTRAL_BUCKETS];
531
532 let max_freq = 10.0f32; let bucket_width = max_freq / SPECTRAL_BUCKETS as f32;
534
535 for f in hot_freqs {
536 let b = ((f.get() / bucket_width) as usize).min(SPECTRAL_BUCKETS - 1);
537 hot_buckets[b] += 1.0;
538 }
539 for f in cold_freqs {
540 let b = ((f.get() / bucket_width) as usize).min(SPECTRAL_BUCKETS - 1);
541 cold_buckets[b] += 1.0;
542 }
543
544 let mut overlap = 0.0f32;
546 let mut hot_total = 0.0f32;
547 for b in 0..SPECTRAL_BUCKETS {
548 overlap += hot_buckets[b].min(cold_buckets[b]);
549 hot_total += hot_buckets[b];
550 }
551
552 if hot_total > 0.0 {
553 FiniteF32::new(overlap / hot_total)
554 } else {
555 FiniteF32::ZERO
556 }
557 }
558
559 pub fn sigmoid_gate(net_signal: FiniteF32) -> FiniteF32 {
562 let x = net_signal.get() * SIGMOID_STEEPNESS;
563 let clamped = x.max(-20.0).min(20.0);
565 let result = 1.0 / (1.0 + (-clamped).exp());
566 FiniteF32::new(result)
567 }
568}
569
570static_assertions::assert_impl_all!(AdaptiveXlrEngine: Send, Sync);