1mod dataset;
2mod model;
3mod solver;
4mod variant_config;
5mod window;
6
7#[cfg(test)]
8mod tests;
9
10use crate::error::{DDAError, Result};
11use crate::types::{CcdConditioningStrategy, DDARequest, DDAResult, VariantResult};
12use dataset::{AnalysisBounds, MatrixDataset};
13use model::ModelSpec;
14use serde::{Deserialize, Serialize};
15use solver::{
16 bic_like_score, build_channel_regression_window_with_inputs, causal_improvement,
17 circular_shift_series, compute_de_value, conditional_causal_improvement,
18 empirical_significance_confidence, greedy_sparse_unique_improvements,
19 solve_channel_with_inputs, solve_channel_with_surrogate_inputs, solve_channels_parallel,
20 solve_directed_pair, solve_group_block, solve_temporally_regularized_windows,
21 synchronization_value, SolvedBlock,
22};
23use std::time::{Duration, Instant};
24use uuid::Uuid;
25use variant_config::{
26 collect_analysis_channels, flip_pairs, labels_for_channels, labels_for_groups,
27 labels_for_pairs, labels_for_sy, resolve_ccd_candidate_channels,
28 resolve_ccd_conditioning_strategy, resolve_ccd_max_active_sources, resolve_ccd_pairs,
29 resolve_ccd_surrogate_shifts, resolve_ccd_temporal_lambda, resolve_cd_pairs, resolve_ct_groups,
30 resolve_de_groups, resolve_sy_pairs, resolve_variant_selected_channels, VariantMode,
31};
32use window::PreparedWindow;
33
/// Crate-wide minimum batch length associated with parallel processing (currently 4).
///
/// NOTE(review): presumably batches shorter than this are handled sequentially by
/// the parallel helpers — confirm against the code that reads this constant.
pub(crate) const PARALLEL_BATCH_MIN_LEN: usize = 4;
35
/// Normalization choice carried by `PureRustOptions::normalization_mode`.
///
/// `ZScore` is the default selected by `PureRustOptions::default`.
/// NOTE(review): the transform itself is applied during window preparation —
/// confirm the exact formulas in the `window` module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormalizationMode {
    /// Z-score normalization (the default).
    ZScore,
    /// No normalization.
    Raw,
    /// Min/max normalization.
    MinMax,
}
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(rename_all = "snake_case")]
45pub enum SvdBackend {
46 RobustSvd,
47 NativeCompatSvd,
48}
49
50#[derive(Debug, Clone)]
51pub struct PureRustOptions {
52 pub nr_exclude: usize,
53 pub normalization_mode: NormalizationMode,
54 pub derivative_step: usize,
55 pub svd_backend: SvdBackend,
56}
57
58impl Default for PureRustOptions {
59 fn default() -> Self {
60 Self {
61 nr_exclude: 10,
62 normalization_mode: NormalizationMode::ZScore,
63 derivative_step: 1,
64 svd_backend: SvdBackend::RobustSvd,
65 }
66 }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70pub struct PureRustProgress {
71 pub stage_id: String,
72 pub stage_label: String,
73 pub step_index: usize,
74 pub total_steps: usize,
75 pub window_index: usize,
76 pub total_windows: usize,
77 pub item_index: usize,
78 pub total_items: usize,
79 pub item_kind: String,
80 pub item_label: String,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
84pub struct CcdConditioningInspection {
85 pub pairs: Vec<[usize; 2]>,
86 pub conditioning_sets: Vec<Vec<usize>>,
87 pub candidate_channels: Vec<usize>,
88 pub strategy: CcdConditioningStrategy,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
92pub struct CcdConditioningSubsetScore {
93 pub pair: [usize; 2],
94 pub confounds: Vec<usize>,
95 pub bic_like_score: f64,
96 pub mean_rmse: f64,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
100pub struct CcdConditioningSubsetProfile {
101 pub pair: [usize; 2],
102 pub confounds: Vec<usize>,
103 pub bic_like_score: f64,
104 pub mean_rmse: f64,
105 pub window_bic_scores: Vec<f64>,
106 pub window_rmses: Vec<f64>,
107}
108
109#[derive(Debug, Clone)]
110pub struct PureRustRunner {
111 options: PureRustOptions,
112}
113
114impl Default for PureRustRunner {
115 fn default() -> Self {
116 Self::new(PureRustOptions::default())
117 }
118}
119
120impl PureRustRunner {
121 pub fn new(options: PureRustOptions) -> Self {
122 Self { options }
123 }
124
125 pub fn run_on_matrix(
126 &self,
127 request: &DDARequest,
128 samples: &[Vec<f64>],
129 channel_labels: Option<&[String]>,
130 ) -> Result<DDAResult> {
131 self.run_on_matrix_internal(request, samples, channel_labels, None)
132 }
133
134 pub fn run_on_matrix_with_progress<F>(
135 &self,
136 request: &DDARequest,
137 samples: &[Vec<f64>],
138 channel_labels: Option<&[String]>,
139 on_progress: F,
140 ) -> Result<DDAResult>
141 where
142 F: FnMut(&PureRustProgress),
143 {
144 let mut callback = on_progress;
145 self.run_on_matrix_internal(request, samples, channel_labels, Some(&mut callback))
146 }
147
148 pub fn inspect_ccd_conditioning_sets_on_matrix(
149 &self,
150 request: &DDARequest,
151 samples: &[Vec<f64>],
152 channel_labels: Option<&[String]>,
153 ) -> Result<CcdConditioningInspection> {
154 let dataset = MatrixDataset::new(samples, channel_labels)?;
155 let model = ModelSpec::from_request(request)?;
156 let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
157 let ccd_pairs = resolve_ccd_pairs(request, dataset.cols);
158 let strategy = resolve_ccd_conditioning_strategy(request);
159 let candidate_channels = resolve_ccd_candidate_channels(request, dataset.cols);
160 let needs_prepared_windows = !matches!(strategy, CcdConditioningStrategy::AllSelected);
161
162 let conditioning_sets = if needs_prepared_windows {
163 let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
164 let required_rows = native_window_marker.saturating_sub(1);
165 if bounds.len < required_rows {
166 return Err(DDAError::InvalidParameter(format!(
167 "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
168 bounds.len, required_rows
169 )));
170 }
171 if model.window_step == 0 {
172 return Err(DDAError::InvalidParameter(
173 "window_step must be greater than zero".to_string(),
174 ));
175 }
176 let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
177 let windows = (0..num_windows)
178 .map(|window_idx| {
179 prepare_window_for_analysis(
180 &dataset,
181 &bounds,
182 &model,
183 window_idx,
184 &self.options,
185 )
186 })
187 .collect::<Result<Vec<_>>>()?;
188 compute_ccd_pair_conditioning_sets(
189 Some(&windows),
190 &ccd_pairs,
191 &candidate_channels,
192 strategy,
193 &model,
194 resolve_ccd_max_active_sources(request).unwrap_or(3),
195 self.options.svd_backend,
196 )
197 } else {
198 compute_ccd_pair_conditioning_sets(
199 None,
200 &ccd_pairs,
201 &candidate_channels,
202 strategy,
203 &model,
204 resolve_ccd_max_active_sources(request).unwrap_or(3),
205 self.options.svd_backend,
206 )
207 };
208
209 Ok(CcdConditioningInspection {
210 pairs: ccd_pairs,
211 conditioning_sets,
212 candidate_channels,
213 strategy,
214 })
215 }
216
217 pub fn score_ccd_conditioning_subsets_on_matrix(
218 &self,
219 request: &DDARequest,
220 samples: &[Vec<f64>],
221 channel_labels: Option<&[String]>,
222 pair: [usize; 2],
223 confound_sets: &[Vec<usize>],
224 ) -> Result<Vec<CcdConditioningSubsetScore>> {
225 let dataset = MatrixDataset::new(samples, channel_labels)?;
226 let model = ModelSpec::from_request(request)?;
227 let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
228 let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
229 let required_rows = native_window_marker.saturating_sub(1);
230 if bounds.len < required_rows {
231 return Err(DDAError::InvalidParameter(format!(
232 "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
233 bounds.len, required_rows
234 )));
235 }
236 if model.window_step == 0 {
237 return Err(DDAError::InvalidParameter(
238 "window_step must be greater than zero".to_string(),
239 ));
240 }
241 let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
242 let windows = (0..num_windows)
243 .map(|window_idx| {
244 prepare_window_for_analysis(&dataset, &bounds, &model, window_idx, &self.options)
245 })
246 .collect::<Result<Vec<_>>>()?;
247
248 Ok(confound_sets
249 .iter()
250 .map(|confounds| CcdConditioningSubsetScore {
251 pair,
252 confounds: confounds.clone(),
253 bic_like_score: average_conditioned_baseline_score(
254 &windows,
255 pair[0],
256 confounds,
257 &model,
258 self.options.svd_backend,
259 ),
260 mean_rmse: average_conditioned_baseline_rmse(
261 &windows,
262 pair[0],
263 confounds,
264 &model,
265 self.options.svd_backend,
266 ),
267 })
268 .collect())
269 }
270
271 pub fn profile_ccd_conditioning_subsets_on_matrix(
272 &self,
273 request: &DDARequest,
274 samples: &[Vec<f64>],
275 channel_labels: Option<&[String]>,
276 pair: [usize; 2],
277 confound_sets: &[Vec<usize>],
278 ) -> Result<Vec<CcdConditioningSubsetProfile>> {
279 let dataset = MatrixDataset::new(samples, channel_labels)?;
280 let model = ModelSpec::from_request(request)?;
281 let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
282 let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
283 let required_rows = native_window_marker.saturating_sub(1);
284 if bounds.len < required_rows {
285 return Err(DDAError::InvalidParameter(format!(
286 "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
287 bounds.len, required_rows
288 )));
289 }
290 if model.window_step == 0 {
291 return Err(DDAError::InvalidParameter(
292 "window_step must be greater than zero".to_string(),
293 ));
294 }
295 let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
296 let windows = (0..num_windows)
297 .map(|window_idx| {
298 prepare_window_for_analysis(&dataset, &bounds, &model, window_idx, &self.options)
299 })
300 .collect::<Result<Vec<_>>>()?;
301
302 Ok(confound_sets
303 .iter()
304 .map(|confounds| {
305 let (window_bic_scores, window_rmses) = conditioned_baseline_window_metrics(
306 &windows,
307 pair[0],
308 confounds,
309 &model,
310 self.options.svd_backend,
311 );
312 let bic_like_score = finite_mean(&window_bic_scores).unwrap_or(f64::INFINITY);
313 let mean_rmse = finite_mean(&window_rmses).unwrap_or(f64::INFINITY);
314 CcdConditioningSubsetProfile {
315 pair,
316 confounds: confounds.clone(),
317 bic_like_score,
318 mean_rmse,
319 window_bic_scores,
320 window_rmses,
321 }
322 })
323 .collect())
324 }
325
326 fn run_on_matrix_internal(
327 &self,
328 request: &DDARequest,
329 samples: &[Vec<f64>],
330 channel_labels: Option<&[String]>,
331 mut on_progress: Option<&mut dyn FnMut(&PureRustProgress)>,
332 ) -> Result<DDAResult> {
333 let dataset = MatrixDataset::new(samples, channel_labels)?;
334 let variant_mode = VariantMode::from_request(request);
335 let model = ModelSpec::from_request(request)?;
336 let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
337 let st_channels = resolve_variant_selected_channels(
338 request,
339 dataset.cols,
340 &["ST", "st", "single_timeseries"],
341 );
342 let de_channels = resolve_variant_selected_channels(
343 request,
344 dataset.cols,
345 &["DE", "de", "dynamical_ergodicity"],
346 );
347 let sy_channels = resolve_variant_selected_channels(
348 request,
349 dataset.cols,
350 &["SY", "sy", "synchronization"],
351 );
352 let ct_groups = resolve_ct_groups(request, dataset.cols);
353 let de_groups = resolve_de_groups(request, dataset.cols, &de_channels);
354 let cd_pairs = resolve_cd_pairs(request, dataset.cols);
355 let ccd_pairs = resolve_ccd_pairs(request, dataset.cols);
356 let ccd_conditioning_strategy = resolve_ccd_conditioning_strategy(request);
357 let ccd_candidate_channels = resolve_ccd_candidate_channels(request, dataset.cols);
358 let ccd_surrogate_shifts = resolve_ccd_surrogate_shifts(request);
359 let ccd_temporal_lambda = resolve_ccd_temporal_lambda(request).unwrap_or(0.25);
360 let ccd_max_active_sources = resolve_ccd_max_active_sources(request).unwrap_or(3);
361 let sy_pairs = resolve_sy_pairs(&sy_channels);
362 let analysis_channels = collect_analysis_channels(
363 &st_channels,
364 &ct_groups,
365 &de_groups,
366 &cd_pairs,
367 &ccd_pairs,
368 &ccd_candidate_channels,
369 );
370
371 let enabled_st = variant_mode.st_enabled;
372 let enabled_ct = variant_mode.ct_enabled;
373 let enabled_cd = variant_mode.cd_enabled;
374 let enabled_ccd_core = (variant_mode.ccd_enabled
375 || variant_mode.ccdsig_enabled
376 || variant_mode.ccdstab_enabled
377 || variant_mode.trccd_enabled
378 || variant_mode.mvccd_enabled)
379 && !ccd_pairs.is_empty();
380 let enabled_ccd = variant_mode.ccd_enabled && !ccd_pairs.is_empty();
381 let enabled_ccdsig = variant_mode.ccdsig_enabled && !ccd_pairs.is_empty();
382 let enabled_ccdstab = variant_mode.ccdstab_enabled && !ccd_pairs.is_empty();
383 let enabled_trccd = variant_mode.trccd_enabled && !ccd_pairs.is_empty();
384 let enabled_mvccd = variant_mode.mvccd_enabled && !ccd_pairs.is_empty();
385 let enabled_de = variant_mode.de_enabled;
386 let enabled_sy = variant_mode.sy_mode > 0 && !sy_pairs.is_empty();
387
388 if !enabled_st
389 && !enabled_ct
390 && !enabled_cd
391 && !enabled_ccd_core
392 && !enabled_de
393 && !enabled_sy
394 {
395 return Err(DDAError::InvalidParameter(
396 "No DDA variants enabled for pure Rust engine".to_string(),
397 ));
398 }
399
400 let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
401 let required_rows = native_window_marker.saturating_sub(1);
402 if bounds.len < required_rows {
403 return Err(DDAError::InvalidParameter(format!(
404 "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
405 bounds.len, required_rows
406 )));
407 }
408 if model.window_step == 0 {
409 return Err(DDAError::InvalidParameter(
410 "window_step must be greater than zero".to_string(),
411 ));
412 }
413
414 let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
415 let needs_prepared_windows = enabled_trccd
416 || !matches!(
417 ccd_conditioning_strategy,
418 CcdConditioningStrategy::AllSelected
419 );
420 let mut prepared_windows = None;
421 let progress_enabled = on_progress.is_some();
422 let analysis_channel_labels = progress_enabled
423 .then(|| labels_for_channels(&dataset.channel_labels, &analysis_channels));
424 let ct_group_labels =
425 progress_enabled.then(|| labels_for_groups(&dataset.channel_labels, &ct_groups, " & "));
426 let de_group_labels =
427 progress_enabled.then(|| labels_for_groups(&dataset.channel_labels, &de_groups, " & "));
428 let cd_pair_labels =
429 progress_enabled.then(|| labels_for_pairs(&dataset.channel_labels, &cd_pairs, " <- "));
430 let ccd_pair_labels =
431 progress_enabled.then(|| labels_for_pairs(&dataset.channel_labels, &ccd_pairs, " <- "));
432 let sy_forward_labels =
433 progress_enabled.then(|| labels_for_pairs(&dataset.channel_labels, &sy_pairs, " -> "));
434 let sy_reverse_labels = progress_enabled.then(|| {
435 let sy_reverse_pairs = flip_pairs(&sy_pairs);
436 labels_for_pairs(&dataset.channel_labels, &sy_reverse_pairs, " -> ")
437 });
438 let shared_block_steps = if enabled_st || enabled_cd || enabled_de {
439 analysis_channels.len()
440 } else {
441 0
442 };
443 let steps_per_window = 1
444 + shared_block_steps
445 + if enabled_ct { ct_groups.len() } else { 0 }
446 + if enabled_de { de_groups.len() } else { 0 }
447 + if enabled_cd { cd_pairs.len() } else { 0 }
448 + if enabled_ccd_core { ccd_pairs.len() } else { 0 }
449 + if enabled_ccdsig { ccd_pairs.len() } else { 0 }
450 + if enabled_mvccd { ccd_pairs.len() } else { 0 }
451 + if enabled_sy { sy_pairs.len() * 2 } else { 0 };
452 let total_steps = num_windows * steps_per_window
453 + if enabled_trccd { ccd_pairs.len() } else { 0 }
454 + if enabled_ccdstab { ccd_pairs.len() } else { 0 };
455 let mut emitted_steps = 0usize;
456 let mut last_progress_emit = Instant::now() - Duration::from_secs(1);
457 let mut report = |stage_id: &str,
458 stage_label: &str,
459 window_number: usize,
460 item_index: usize,
461 total_items: usize,
462 item_kind: &str,
463 item_label: Option<&str>| {
464 emitted_steps += 1;
465 let should_emit = emitted_steps <= 1
466 || emitted_steps >= total_steps
467 || last_progress_emit.elapsed() >= Duration::from_millis(125);
468 if !should_emit {
469 return;
470 }
471 last_progress_emit = Instant::now();
472 if let Some(callback) = on_progress.as_deref_mut() {
473 callback(&PureRustProgress {
474 stage_id: stage_id.to_string(),
475 stage_label: stage_label.to_string(),
476 step_index: emitted_steps,
477 total_steps,
478 window_index: window_number,
479 total_windows: num_windows,
480 item_index,
481 total_items,
482 item_kind: item_kind.to_string(),
483 item_label: item_label.unwrap_or("").to_string(),
484 });
485 }
486 };
487
488 let native_window_markers: Vec<f64> = (0..num_windows)
489 .map(|window_idx| {
490 (bounds.start + window_idx * model.window_step + native_window_marker) as f64
491 })
492 .collect();
493
494 if needs_prepared_windows {
495 let mut windows = Vec::with_capacity(num_windows);
496 for window_idx in 0..num_windows {
497 report(
498 "prepare-window",
499 "Preparing analysis window",
500 window_idx + 1,
501 window_idx + 1,
502 num_windows,
503 "window",
504 None,
505 );
506 windows.push(prepare_window_for_analysis(
507 &dataset,
508 &bounds,
509 &model,
510 window_idx,
511 &self.options,
512 )?);
513 }
514 prepared_windows = Some(windows);
515 }
516
517 let ccd_pair_conditioning_sets = if enabled_ccd_core {
518 compute_ccd_pair_conditioning_sets(
519 prepared_windows.as_deref(),
520 &ccd_pairs,
521 &ccd_candidate_channels,
522 ccd_conditioning_strategy,
523 &model,
524 ccd_max_active_sources,
525 self.options.svd_backend,
526 )
527 } else {
528 Vec::new()
529 };
530 let ccd_target_conditioning_sets =
531 build_target_conditioning_sets(&ccd_pairs, &ccd_pair_conditioning_sets);
532
533 let mut st_matrix =
534 enabled_st.then(|| vec![vec![f64::NAN; num_windows]; st_channels.len()]);
535 let mut ct_matrix = enabled_ct.then(|| vec![vec![f64::NAN; num_windows]; ct_groups.len()]);
536 let mut cd_matrix = enabled_cd.then(|| vec![vec![f64::NAN; num_windows]; cd_pairs.len()]);
537 let mut ccd_matrix =
538 enabled_ccd_core.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
539 let mut ccdsig_matrix =
540 enabled_ccdsig.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
541 let mut mvccd_matrix =
542 enabled_mvccd.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
543 let mut trccd_matrix =
544 enabled_trccd.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
545 let mut ccdstab_matrix =
546 enabled_ccdstab.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
547 let mut de_matrix = enabled_de.then(|| vec![vec![f64::NAN; num_windows]; de_groups.len()]);
548 let mut sy_matrix = enabled_sy.then(|| {
549 let rows = if variant_mode.sy_mode == 2 {
550 sy_pairs.len() * 2
551 } else {
552 sy_pairs.len()
553 };
554 vec![vec![f64::NAN; num_windows]; rows]
555 });
556
557 for window_idx in 0..num_windows {
558 let prepared_storage;
559 let prepared = if let Some(windows) = prepared_windows.as_ref() {
560 &windows[window_idx]
561 } else {
562 report(
563 "prepare-window",
564 "Preparing analysis window",
565 window_idx + 1,
566 window_idx + 1,
567 num_windows,
568 "window",
569 None,
570 );
571 prepared_storage = prepare_window_for_analysis(
572 &dataset,
573 &bounds,
574 &model,
575 window_idx,
576 &self.options,
577 )?;
578 &prepared_storage
579 };
580
581 let mut st_blocks: Vec<Option<SolvedBlock>> = vec![None; dataset.cols];
582 if enabled_st || enabled_cd || enabled_de {
583 let computed_st_blocks = solve_channels_parallel(&analysis_channels, |&channel| {
584 (
585 channel,
586 solve_group_block(
587 &prepared,
588 &[channel],
589 &model.primary_terms,
590 model.window_length,
591 self.options.svd_backend,
592 ),
593 )
594 });
595 for (channel_idx, (channel, block)) in computed_st_blocks.into_iter().enumerate() {
596 report(
597 "st-blocks",
598 "Solving baseline channel dynamics",
599 window_idx + 1,
600 channel_idx + 1,
601 analysis_channels.len(),
602 "channel",
603 analysis_channel_labels
604 .as_ref()
605 .and_then(|labels| labels.get(channel_idx).map(String::as_str)),
606 );
607 if channel < st_blocks.len() {
608 st_blocks[channel] = Some(block);
609 }
610 }
611 }
612
613 if let Some(matrix) = st_matrix.as_mut() {
614 for (row_idx, &channel) in st_channels.iter().enumerate() {
615 if let Some(block) = st_blocks.get(channel).and_then(Option::as_ref) {
616 matrix[row_idx][window_idx] =
617 block.coefficients.first().copied().unwrap_or(f64::NAN);
618 }
619 }
620 }
621
622 let mut ct_blocks = Vec::new();
623 if enabled_ct {
624 ct_blocks = solve_channels_parallel(&ct_groups, |group| {
625 solve_group_block(
626 &prepared,
627 group,
628 &model.primary_terms,
629 model.window_length,
630 self.options.svd_backend,
631 )
632 });
633 for (group_idx, _) in ct_groups.iter().enumerate() {
634 report(
635 "ct",
636 "Computing cross-timeseries groups",
637 window_idx + 1,
638 group_idx + 1,
639 ct_groups.len(),
640 "group",
641 ct_group_labels
642 .as_ref()
643 .and_then(|labels| labels.get(group_idx).map(String::as_str)),
644 );
645 }
646 }
647
648 if let Some(matrix) = ct_matrix.as_mut() {
649 for (row_idx, block) in ct_blocks.iter().enumerate() {
650 matrix[row_idx][window_idx] =
651 block.coefficients.first().copied().unwrap_or(f64::NAN);
652 }
653 }
654
655 let mut de_blocks = Vec::new();
656 if enabled_de {
657 de_blocks = solve_channels_parallel(&de_groups, |group| {
658 solve_group_block(
659 &prepared,
660 group,
661 &model.primary_terms,
662 model.window_length,
663 self.options.svd_backend,
664 )
665 });
666 for (group_idx, _) in de_groups.iter().enumerate() {
667 report(
668 "de",
669 "Computing dynamical ergodicity groups",
670 window_idx + 1,
671 group_idx + 1,
672 de_groups.len(),
673 "group",
674 de_group_labels
675 .as_ref()
676 .and_then(|labels| labels.get(group_idx).map(String::as_str)),
677 );
678 }
679 }
680
681 if let Some(matrix) = de_matrix.as_mut() {
682 for (row_idx, group) in de_groups.iter().enumerate() {
683 let ct_rmse = de_blocks
684 .get(row_idx)
685 .map(|block| block.rmse)
686 .unwrap_or(f64::NAN);
687 let de_value = compute_de_value(group, &st_blocks, ct_rmse);
688 matrix[row_idx][window_idx] = de_value;
689 }
690 }
691
692 if enabled_cd {
693 let cd_values = solve_channels_parallel(&cd_pairs, |pair| {
694 let forward = solve_directed_pair(
695 &prepared,
696 pair[0],
697 pair[1],
698 pair[0],
699 &model.primary_terms,
700 &model.secondary_terms,
701 model.window_length,
702 self.options.svd_backend,
703 );
704 let baseline = st_blocks
705 .get(pair[0])
706 .and_then(Option::as_ref)
707 .map(|block| block.rmse)
708 .unwrap_or(f64::NAN);
709 causal_improvement(baseline, forward.rmse)
710 });
711 for (pair_idx, _) in cd_pairs.iter().enumerate() {
712 report(
713 "cd",
714 "Computing directed causal pairs",
715 window_idx + 1,
716 pair_idx + 1,
717 cd_pairs.len(),
718 "pair",
719 cd_pair_labels
720 .as_ref()
721 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
722 );
723 }
724 if let Some(matrix) = cd_matrix.as_mut() {
725 for (pair_idx, value) in cd_values.into_iter().enumerate() {
726 matrix[pair_idx][window_idx] = value;
727 }
728 }
729 }
730
731 if enabled_ccd_core {
732 let ccd_values = solve_channels_parallel(
733 &ccd_pairs
734 .iter()
735 .zip(ccd_pair_conditioning_sets.iter())
736 .collect::<Vec<_>>(),
737 |(pair, confounds)| {
738 let baseline = solve_channel_with_inputs(
739 prepared,
740 pair[0],
741 confounds,
742 &model.primary_terms,
743 &model.secondary_terms,
744 model.window_length,
745 self.options.svd_backend,
746 );
747 let mut conditioned_inputs = (*confounds).clone();
748 conditioned_inputs.push(pair[1]);
749 let conditioned = solve_channel_with_inputs(
750 prepared,
751 pair[0],
752 &conditioned_inputs,
753 &model.primary_terms,
754 &model.secondary_terms,
755 model.window_length,
756 self.options.svd_backend,
757 );
758 let observed =
759 conditional_causal_improvement(baseline.rmse, conditioned.rmse);
760
761 let significance = if enabled_ccdsig {
762 let surrogate_shifts =
763 ccd_surrogate_shifts.clone().unwrap_or_else(|| {
764 default_surrogate_shifts(prepared.shifted.len())
765 });
766 let surrogate_inputs = confounds
767 .iter()
768 .map(|channel| extract_shifted_channel_series(prepared, *channel))
769 .collect::<Vec<_>>();
770 let source_series = extract_shifted_channel_series(prepared, pair[1]);
771 let null_scores = surrogate_shifts
772 .into_iter()
773 .filter(|shift| *shift > 0)
774 .map(|shift| {
775 let shifted_source =
776 circular_shift_series(&source_series, shift);
777 let mut conditioned_surrogates = surrogate_inputs.clone();
778 conditioned_surrogates.push(shifted_source);
779 let surrogate_block = solve_channel_with_surrogate_inputs(
780 prepared,
781 pair[0],
782 &conditioned_surrogates,
783 &model.primary_terms,
784 &model.secondary_terms,
785 model.window_length,
786 self.options.svd_backend,
787 );
788 conditional_causal_improvement(
789 baseline.rmse,
790 surrogate_block.rmse,
791 )
792 })
793 .collect::<Vec<_>>();
794 empirical_significance_confidence(observed, &null_scores)
795 } else {
796 f64::NAN
797 };
798
799 (observed, significance)
800 },
801 );
802 for (pair_idx, _) in ccd_pairs.iter().enumerate() {
803 report(
804 "ccd",
805 "Computing conditional directed causal pairs",
806 window_idx + 1,
807 pair_idx + 1,
808 ccd_pairs.len(),
809 "pair",
810 ccd_pair_labels
811 .as_ref()
812 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
813 );
814 }
815 if let Some(matrix) = ccd_matrix.as_mut() {
816 for (pair_idx, value) in ccd_values.iter().enumerate() {
817 matrix[pair_idx][window_idx] = value.0;
818 }
819 }
820 if let Some(matrix) = ccdsig_matrix.as_mut() {
821 for (pair_idx, value) in ccd_values.into_iter().enumerate() {
822 matrix[pair_idx][window_idx] = value.1;
823 }
824 }
825 }
826
827 if enabled_mvccd {
828 let mvccd_values = compute_mvccd_window_scores(
829 prepared,
830 &ccd_pairs,
831 &ccd_target_conditioning_sets,
832 &model,
833 ccd_max_active_sources,
834 self.options.svd_backend,
835 );
836 for (pair_idx, _) in ccd_pairs.iter().enumerate() {
837 report(
838 "mvccd",
839 "Computing sparse multivariate conditional pairs",
840 window_idx + 1,
841 pair_idx + 1,
842 ccd_pairs.len(),
843 "pair",
844 ccd_pair_labels
845 .as_ref()
846 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
847 );
848 }
849 if let Some(matrix) = mvccd_matrix.as_mut() {
850 for (pair_idx, value) in mvccd_values.into_iter().enumerate() {
851 matrix[pair_idx][window_idx] = value;
852 }
853 }
854 }
855
856 if let Some(matrix) = sy_matrix.as_mut() {
857 let sy_values = solve_channels_parallel(&sy_pairs, |pair| {
858 let forward = solve_directed_pair(
859 &prepared,
860 pair[0],
861 pair[1],
862 pair[1],
863 &model.primary_terms,
864 &model.secondary_terms,
865 model.window_length,
866 self.options.svd_backend,
867 );
868 let reverse = solve_directed_pair(
869 &prepared,
870 pair[1],
871 pair[0],
872 pair[0],
873 &model.primary_terms,
874 &model.secondary_terms,
875 model.window_length,
876 self.options.svd_backend,
877 );
878 (forward.rmse, reverse.rmse)
879 });
880 for (pair_idx, _) in sy_pairs.iter().enumerate() {
881 report(
882 "sy",
883 "Computing synchronization directions",
884 window_idx + 1,
885 pair_idx * 2 + 1,
886 sy_pairs.len() * 2,
887 "direction",
888 sy_forward_labels
889 .as_ref()
890 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
891 );
892 report(
893 "sy",
894 "Computing synchronization directions",
895 window_idx + 1,
896 pair_idx * 2 + 2,
897 sy_pairs.len() * 2,
898 "direction",
899 sy_reverse_labels
900 .as_ref()
901 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
902 );
903 }
904 for (pair_idx, (forward_rmse, reverse_rmse)) in sy_values.into_iter().enumerate() {
905 if variant_mode.sy_mode == 2 {
906 let row_base = pair_idx * 2;
907 matrix[row_base][window_idx] = forward_rmse;
908 matrix[row_base + 1][window_idx] = reverse_rmse;
909 } else {
910 matrix[pair_idx][window_idx] =
911 synchronization_value(1, forward_rmse, reverse_rmse);
912 }
913 }
914 }
915 }
916
917 if enabled_trccd {
918 if let Some(matrix) = trccd_matrix.as_mut() {
919 let windows = prepared_windows.as_deref().unwrap_or(&[]);
920 let regularized = compute_trccd_matrix(
921 windows,
922 &ccd_pairs,
923 &ccd_pair_conditioning_sets,
924 &model,
925 ccd_temporal_lambda,
926 self.options.svd_backend,
927 );
928 for (pair_idx, row) in regularized.into_iter().enumerate() {
929 report(
930 "trccd",
931 "Computing temporally regularized conditional pairs",
932 num_windows,
933 pair_idx + 1,
934 ccd_pairs.len(),
935 "pair",
936 ccd_pair_labels
937 .as_ref()
938 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
939 );
940 matrix[pair_idx] = row;
941 }
942 }
943 }
944
945 if enabled_ccdstab {
946 if let Some(base_ccd) = ccd_matrix.as_ref() {
947 let stability = self.compute_ccd_stability_matrix(
948 request,
949 samples,
950 channel_labels,
951 &native_window_markers,
952 &ccd_pairs,
953 base_ccd,
954 )?;
955 if let Some(matrix) = ccdstab_matrix.as_mut() {
956 for (pair_idx, row) in stability.into_iter().enumerate() {
957 report(
958 "ccdstab",
959 "Computing conditional-pair stability",
960 num_windows,
961 pair_idx + 1,
962 ccd_pairs.len(),
963 "pair",
964 ccd_pair_labels
965 .as_ref()
966 .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
967 );
968 matrix[pair_idx] = row;
969 }
970 }
971 }
972 }
973
974 let mut variant_results = Vec::new();
975 if let Some(q_matrix) = st_matrix {
976 variant_results.push(VariantResult {
977 variant_id: "ST".to_string(),
978 variant_name: "Single Timeseries (ST)".to_string(),
979 q_matrix,
980 channel_labels: Some(labels_for_channels(&dataset.channel_labels, &st_channels)),
981 error_values: Some(native_window_markers.clone()),
982 });
983 }
984 if let Some(q_matrix) = ct_matrix {
985 variant_results.push(VariantResult {
986 variant_id: "CT".to_string(),
987 variant_name: "Cross-Timeseries (CT)".to_string(),
988 q_matrix,
989 channel_labels: Some(labels_for_groups(&dataset.channel_labels, &ct_groups, "&")),
990 error_values: Some(native_window_markers.clone()),
991 });
992 }
993 if let Some(q_matrix) = cd_matrix {
994 variant_results.push(VariantResult {
995 variant_id: "CD".to_string(),
996 variant_name: "Cross-Dynamical (CD)".to_string(),
997 q_matrix,
998 channel_labels: Some(labels_for_pairs(&dataset.channel_labels, &cd_pairs, " <- ")),
999 error_values: Some(native_window_markers.clone()),
1000 });
1001 }
1002 if enabled_ccd {
1003 if let Some(q_matrix) = ccd_matrix.clone() {
1004 variant_results.push(VariantResult {
1005 variant_id: "CCD".to_string(),
1006 variant_name: "Conditional Cross-Dynamical (CCD)".to_string(),
1007 q_matrix,
1008 channel_labels: Some(labels_for_pairs(
1009 &dataset.channel_labels,
1010 &ccd_pairs,
1011 " <- ",
1012 )),
1013 error_values: Some(native_window_markers.clone()),
1014 });
1015 }
1016 }
1017 if let Some(q_matrix) = ccdsig_matrix {
1018 variant_results.push(VariantResult {
1019 variant_id: "CCDSIG".to_string(),
1020 variant_name: "Conditional Cross-Dynamical Significance (CCDSIG)".to_string(),
1021 q_matrix,
1022 channel_labels: Some(labels_for_pairs(
1023 &dataset.channel_labels,
1024 &ccd_pairs,
1025 " <- ",
1026 )),
1027 error_values: Some(native_window_markers.clone()),
1028 });
1029 }
1030 if let Some(q_matrix) = ccdstab_matrix {
1031 variant_results.push(VariantResult {
1032 variant_id: "CCDSTAB".to_string(),
1033 variant_name: "Conditional Cross-Dynamical Stability (CCDSTAB)".to_string(),
1034 q_matrix,
1035 channel_labels: Some(labels_for_pairs(
1036 &dataset.channel_labels,
1037 &ccd_pairs,
1038 " <- ",
1039 )),
1040 error_values: Some(native_window_markers.clone()),
1041 });
1042 }
1043 if let Some(q_matrix) = trccd_matrix {
1044 variant_results.push(VariantResult {
1045 variant_id: "TRCCD".to_string(),
1046 variant_name: "Temporally Regularized Conditional Cross-Dynamical (TRCCD)"
1047 .to_string(),
1048 q_matrix,
1049 channel_labels: Some(labels_for_pairs(
1050 &dataset.channel_labels,
1051 &ccd_pairs,
1052 " <- ",
1053 )),
1054 error_values: Some(native_window_markers.clone()),
1055 });
1056 }
1057 if let Some(q_matrix) = mvccd_matrix {
1058 variant_results.push(VariantResult {
1059 variant_id: "MVCCD".to_string(),
1060 variant_name: "Sparse Multivariate Conditional Cross-Dynamical (MVCCD)".to_string(),
1061 q_matrix,
1062 channel_labels: Some(labels_for_pairs(
1063 &dataset.channel_labels,
1064 &ccd_pairs,
1065 " <- ",
1066 )),
1067 error_values: Some(native_window_markers.clone()),
1068 });
1069 }
1070 if let Some(q_matrix) = de_matrix {
1071 variant_results.push(VariantResult {
1072 variant_id: "DE".to_string(),
1073 variant_name: "Dynamical Ergodicity (DE)".to_string(),
1074 q_matrix,
1075 channel_labels: Some(labels_for_groups(&dataset.channel_labels, &de_groups, "&")),
1076 error_values: Some(native_window_markers.clone()),
1077 });
1078 }
1079 if let Some(q_matrix) = sy_matrix {
1080 variant_results.push(VariantResult {
1081 variant_id: "SY".to_string(),
1082 variant_name: "Synchronization (SY)".to_string(),
1083 q_matrix,
1084 channel_labels: Some(labels_for_sy(
1085 &dataset.channel_labels,
1086 &sy_pairs,
1087 variant_mode.sy_mode,
1088 )),
1089 error_values: Some(native_window_markers.clone()),
1090 });
1091 }
1092
1093 let primary_q = variant_results
1094 .first()
1095 .map(|variant| variant.q_matrix.clone())
1096 .unwrap_or_default();
1097
1098 Ok(DDAResult {
1099 id: Uuid::new_v4().to_string(),
1100 file_path: request.file_path.clone(),
1101 channels: dataset.channel_labels.clone(),
1102 q_matrix: primary_q,
1103 variant_results: Some(variant_results),
1104 raw_output: None,
1105 window_parameters: request.window_parameters.clone(),
1106 delay_parameters: request.delay_parameters.clone(),
1107 created_at: chrono::Utc::now().to_rfc3339(),
1108 error_values: Some(native_window_markers),
1109 })
1110 }
1111
1112 fn compute_ccd_stability_matrix(
1113 &self,
1114 request: &DDARequest,
1115 samples: &[Vec<f64>],
1116 channel_labels: Option<&[String]>,
1117 base_markers: &[f64],
1118 ccd_pairs: &[[usize; 2]],
1119 base_ccd: &[Vec<f64>],
1120 ) -> Result<Vec<Vec<f64>>> {
1121 let perturbed_requests = build_ccd_stability_requests(request);
1122 let mut runs = Vec::new();
1123 for perturbed in perturbed_requests {
1124 if let Ok(result) =
1125 self.run_on_matrix_internal(&perturbed, samples, channel_labels, None)
1126 {
1127 if let Some(variant) = result.variant_results.as_ref().and_then(|variants| {
1128 variants.iter().find(|variant| variant.variant_id == "CCD")
1129 }) {
1130 let markers = variant
1131 .error_values
1132 .clone()
1133 .or_else(|| result.error_values.clone())
1134 .unwrap_or_default();
1135 runs.push((markers, variant.q_matrix.clone()));
1136 }
1137 }
1138 }
1139
1140 let mut stability = vec![vec![f64::NAN; base_markers.len()]; ccd_pairs.len()];
1141 if runs.is_empty() {
1142 return Ok(stability);
1143 }
1144
1145 for pair_idx in 0..ccd_pairs.len() {
1146 for (window_idx, marker) in base_markers.iter().enumerate() {
1147 let reference = base_ccd
1148 .get(pair_idx)
1149 .and_then(|row| row.get(window_idx))
1150 .copied()
1151 .unwrap_or(f64::NAN);
1152 if !reference.is_finite() {
1153 continue;
1154 }
1155 let threshold = reference.abs().max(1e-9) * 0.5;
1156 let mut valid = 0usize;
1157 let mut support = 0usize;
1158 for (markers, matrix) in &runs {
1159 let aligned = nearest_aligned_value(markers, matrix, pair_idx, *marker);
1160 if let Some(value) = aligned.filter(|value| value.is_finite()) {
1161 valid += 1;
1162 if same_sign(reference, value) && value.abs() >= threshold {
1163 support += 1;
1164 }
1165 }
1166 }
1167 if valid > 0 {
1168 stability[pair_idx][window_idx] = support as f64 / valid as f64;
1169 }
1170 }
1171 }
1172
1173 Ok(stability)
1174 }
1175}
1176
1177fn extract_shifted_channel_series(prepared: &PreparedWindow, channel: usize) -> Vec<f64> {
1178 prepared
1179 .shifted
1180 .iter()
1181 .map(|row| row[channel])
1182 .collect::<Vec<_>>()
1183}
1184
1185fn prepare_window_for_analysis(
1186 dataset: &MatrixDataset<'_>,
1187 bounds: &AnalysisBounds,
1188 model: &ModelSpec,
1189 window_idx: usize,
1190 options: &PureRustOptions,
1191) -> Result<PreparedWindow> {
1192 let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
1193 let slice_start = bounds.start + window_idx * model.window_step;
1194 let slice_end = slice_start + native_window_marker;
1195 let padded_window = if slice_end <= dataset.rows {
1196 None
1197 } else {
1198 let available = dataset.samples[slice_start..dataset.rows].to_vec();
1199 let filler = available
1200 .last()
1201 .and_then(|row| row.last())
1202 .copied()
1203 .unwrap_or(f64::NAN);
1204 let mut padded = available;
1205 while padded.len() < native_window_marker {
1206 padded.push(vec![filler; dataset.cols]);
1207 }
1208 Some(padded)
1209 };
1210 let raw_window = padded_window
1211 .as_deref()
1212 .unwrap_or(&dataset.samples[slice_start..slice_end.min(dataset.rows)]);
1213 PreparedWindow::from_raw(raw_window, model, options)
1214}
1215
1216fn compute_ccd_pair_conditioning_sets(
1217 prepared_windows: Option<&[PreparedWindow]>,
1218 ccd_pairs: &[[usize; 2]],
1219 candidate_channels: &[usize],
1220 strategy: CcdConditioningStrategy,
1221 model: &ModelSpec,
1222 auto_cap: usize,
1223 svd_backend: SvdBackend,
1224) -> Vec<Vec<usize>> {
1225 match strategy {
1226 CcdConditioningStrategy::AllSelected => ccd_pairs
1227 .iter()
1228 .map(|pair| {
1229 candidate_channels
1230 .iter()
1231 .copied()
1232 .filter(|channel| *channel != pair[0] && *channel != pair[1])
1233 .collect::<Vec<_>>()
1234 })
1235 .collect(),
1236 CcdConditioningStrategy::AutoTargetSparse
1237 | CcdConditioningStrategy::AutoSharedParents
1238 | CcdConditioningStrategy::AutoGroupOmp => {
1239 let Some(windows) = prepared_windows else {
1240 return ccd_pairs
1241 .iter()
1242 .map(|pair| {
1243 candidate_channels
1244 .iter()
1245 .copied()
1246 .filter(|channel| *channel != pair[0] && *channel != pair[1])
1247 .collect::<Vec<_>>()
1248 })
1249 .collect();
1250 };
1251 ccd_pairs
1252 .iter()
1253 .map(|pair| {
1254 auto_select_conditioning_channels_for_pair(
1255 windows,
1256 pair[0],
1257 pair[1],
1258 candidate_channels,
1259 strategy,
1260 model,
1261 auto_cap,
1262 svd_backend,
1263 )
1264 })
1265 .collect()
1266 }
1267 }
1268}
1269
1270fn auto_select_conditioning_channels_for_pair(
1271 prepared_windows: &[PreparedWindow],
1272 target: usize,
1273 source: usize,
1274 candidate_channels: &[usize],
1275 strategy: CcdConditioningStrategy,
1276 model: &ModelSpec,
1277 auto_cap: usize,
1278 svd_backend: SvdBackend,
1279) -> Vec<usize> {
1280 let usable_candidates = candidate_channels
1281 .iter()
1282 .copied()
1283 .filter(|channel| *channel != target && *channel != source)
1284 .filter(|channel| channel_is_usable(prepared_windows, *channel))
1285 .collect::<Vec<_>>();
1286 if usable_candidates.is_empty() {
1287 return Vec::new();
1288 }
1289
1290 if matches!(strategy, CcdConditioningStrategy::AutoGroupOmp) {
1291 return omp_select_conditioning_subset(
1292 prepared_windows,
1293 target,
1294 &usable_candidates,
1295 model,
1296 auto_cap.max(1),
1297 svd_backend,
1298 );
1299 }
1300
1301 let target_scores = aggregate_parent_support_scores(
1302 prepared_windows,
1303 target,
1304 &usable_candidates,
1305 model,
1306 auto_cap,
1307 svd_backend,
1308 );
1309 let ranked = match strategy {
1310 CcdConditioningStrategy::AutoTargetSparse => rank_channels_by_scores(&target_scores),
1311 CcdConditioningStrategy::AutoSharedParents => {
1312 let source_scores = aggregate_parent_support_scores(
1313 prepared_windows,
1314 source,
1315 &usable_candidates,
1316 model,
1317 auto_cap,
1318 svd_backend,
1319 );
1320 let mut shared = target_scores
1321 .iter()
1322 .map(|(channel, score)| {
1323 (
1324 *channel,
1325 score.min(source_scores.get(channel).copied().unwrap_or(0.0)),
1326 )
1327 })
1328 .filter(|(_, score)| score.is_finite() && *score > 0.0)
1329 .collect::<Vec<_>>();
1330 shared.sort_by(|left, right| {
1331 right
1332 .1
1333 .partial_cmp(&left.1)
1334 .unwrap_or(std::cmp::Ordering::Equal)
1335 .then_with(|| left.0.cmp(&right.0))
1336 });
1337 shared.into_iter().map(|(channel, _)| channel).collect()
1338 }
1339 CcdConditioningStrategy::AutoGroupOmp => usable_candidates,
1340 CcdConditioningStrategy::AllSelected => usable_candidates,
1341 };
1342
1343 greedy_select_conditioning_subset(
1344 prepared_windows,
1345 target,
1346 &ranked,
1347 model,
1348 auto_cap.max(1),
1349 svd_backend,
1350 )
1351}
1352
1353fn channel_is_usable(prepared_windows: &[PreparedWindow], channel: usize) -> bool {
1354 prepared_windows.iter().any(|prepared| {
1355 prepared
1356 .shifted
1357 .iter()
1358 .any(|row| row.get(channel).copied().unwrap_or(f64::NAN).is_finite())
1359 && prepared
1360 .deriv
1361 .get(channel)
1362 .map(|values| values.iter().any(|value| value.is_finite()))
1363 .unwrap_or(false)
1364 })
1365}
1366
1367fn aggregate_parent_support_scores(
1368 prepared_windows: &[PreparedWindow],
1369 target: usize,
1370 candidate_channels: &[usize],
1371 model: &ModelSpec,
1372 auto_cap: usize,
1373 svd_backend: SvdBackend,
1374) -> std::collections::BTreeMap<usize, f64> {
1375 let mut sums = std::collections::BTreeMap::<usize, f64>::new();
1376 let mut counts = std::collections::BTreeMap::<usize, usize>::new();
1377 for prepared in prepared_windows {
1378 for (channel, improvement) in greedy_sparse_unique_improvements(
1379 prepared,
1380 target,
1381 candidate_channels,
1382 &[],
1383 &model.primary_terms,
1384 &model.secondary_terms,
1385 model.window_length,
1386 auto_cap.max(1),
1387 svd_backend,
1388 ) {
1389 if improvement.is_finite() && improvement > 0.0 {
1390 *sums.entry(channel).or_insert(0.0) += improvement;
1391 *counts.entry(channel).or_insert(0) += 1;
1392 }
1393 }
1394 }
1395 candidate_channels
1396 .iter()
1397 .copied()
1398 .map(|channel| {
1399 let score = match (sums.get(&channel), counts.get(&channel)) {
1400 (Some(sum), Some(count)) if *count > 0 => *sum / (*count as f64),
1401 _ => 0.0,
1402 };
1403 (channel, score)
1404 })
1405 .collect()
1406}
1407
/// Returns the channels with a finite, strictly positive score, best first.
///
/// Channels are ordered by descending score; equal scores are ordered by ascending
/// channel index. Zero, negative, infinite and `NaN` scores are dropped.
fn rank_channels_by_scores(scores: &std::collections::BTreeMap<usize, f64>) -> Vec<usize> {
    let mut positive: Vec<(usize, f64)> = Vec::new();
    for (&channel, &score) in scores {
        if score.is_finite() && score > 0.0 {
            positive.push((channel, score));
        }
    }
    // All retained scores are finite and positive, so `total_cmp` agrees with `partial_cmp`.
    positive.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
    positive.into_iter().map(|(channel, _)| channel).collect()
}
1423
1424fn greedy_select_conditioning_subset(
1425 prepared_windows: &[PreparedWindow],
1426 target: usize,
1427 ranked_candidates: &[usize],
1428 model: &ModelSpec,
1429 auto_cap: usize,
1430 svd_backend: SvdBackend,
1431) -> Vec<usize> {
1432 let mut selected = Vec::new();
1433 let mut current_score = average_conditioned_baseline_score(
1434 prepared_windows,
1435 target,
1436 &selected,
1437 model,
1438 svd_backend,
1439 );
1440 for &candidate in ranked_candidates.iter().take(auto_cap) {
1441 let mut trial = selected.clone();
1442 trial.push(candidate);
1443 let trial_score = average_conditioned_baseline_score(
1444 prepared_windows,
1445 target,
1446 &trial,
1447 model,
1448 svd_backend,
1449 );
1450 if trial_score + 1e-9 < current_score {
1451 selected = trial;
1452 current_score = trial_score;
1453 }
1454 }
1455 selected
1456}
1457
1458fn omp_select_conditioning_subset(
1459 prepared_windows: &[PreparedWindow],
1460 target: usize,
1461 candidate_channels: &[usize],
1462 model: &ModelSpec,
1463 auto_cap: usize,
1464 svd_backend: SvdBackend,
1465) -> Vec<usize> {
1466 let mut selected = Vec::<usize>::new();
1467 let mut remaining = candidate_channels.to_vec();
1468 let mut current_score = average_conditioned_baseline_score(
1469 prepared_windows,
1470 target,
1471 &selected,
1472 model,
1473 svd_backend,
1474 );
1475
1476 for _ in 0..auto_cap.min(remaining.len()) {
1477 let mut best_candidate = None::<(usize, f64, f64)>;
1478 for &candidate in &remaining {
1479 let mut trial = selected.clone();
1480 trial.push(candidate);
1481 let trial_rmse = average_conditioned_baseline_rmse(
1482 prepared_windows,
1483 target,
1484 &trial,
1485 model,
1486 svd_backend,
1487 );
1488 let trial_score = average_conditioned_baseline_score(
1489 prepared_windows,
1490 target,
1491 &trial,
1492 model,
1493 svd_backend,
1494 );
1495 let required_gain = 1e-4 * current_score.abs().max(1.0);
1496 if current_score - trial_score <= required_gain {
1497 continue;
1498 }
1499 let take = match best_candidate {
1500 None => true,
1501 Some((best_channel, best_rmse, best_score)) => {
1502 trial_rmse < best_rmse - 1e-12
1503 || ((trial_rmse - best_rmse).abs() <= 1e-12
1504 && (trial_score < best_score - 1e-12
1505 || ((trial_score - best_score).abs() <= 1e-12
1506 && candidate < best_channel)))
1507 }
1508 };
1509 if take {
1510 best_candidate = Some((candidate, trial_rmse, trial_score));
1511 }
1512 }
1513
1514 let Some((candidate, _trial_rmse, trial_score)) = best_candidate else {
1515 break;
1516 };
1517 selected.push(candidate);
1518 remaining.retain(|channel| *channel != candidate);
1519 current_score = trial_score;
1520 }
1521
1522 selected.sort_unstable();
1523 selected
1524}
1525
1526fn average_conditioned_baseline_score(
1527 prepared_windows: &[PreparedWindow],
1528 target: usize,
1529 confounds: &[usize],
1530 model: &ModelSpec,
1531 svd_backend: SvdBackend,
1532) -> f64 {
1533 let (window_scores, _) = conditioned_baseline_window_metrics(
1534 prepared_windows,
1535 target,
1536 confounds,
1537 model,
1538 svd_backend,
1539 );
1540 finite_mean(&window_scores).unwrap_or(f64::INFINITY)
1541}
1542
1543fn average_conditioned_baseline_rmse(
1544 prepared_windows: &[PreparedWindow],
1545 target: usize,
1546 confounds: &[usize],
1547 model: &ModelSpec,
1548 svd_backend: SvdBackend,
1549) -> f64 {
1550 let (_, window_rmses) = conditioned_baseline_window_metrics(
1551 prepared_windows,
1552 target,
1553 confounds,
1554 model,
1555 svd_backend,
1556 );
1557 finite_mean(&window_rmses).unwrap_or(f64::INFINITY)
1558}
1559
1560fn conditioned_baseline_window_metrics(
1561 prepared_windows: &[PreparedWindow],
1562 target: usize,
1563 confounds: &[usize],
1564 model: &ModelSpec,
1565 svd_backend: SvdBackend,
1566) -> (Vec<f64>, Vec<f64>) {
1567 let parameter_count = model.primary_terms.len() + confounds.len() * model.secondary_terms.len();
1568 let mut window_scores = Vec::with_capacity(prepared_windows.len());
1569 let mut window_rmses = Vec::with_capacity(prepared_windows.len());
1570 for prepared in prepared_windows {
1571 let block = solve_channel_with_inputs(
1572 prepared,
1573 target,
1574 confounds,
1575 &model.primary_terms,
1576 &model.secondary_terms,
1577 model.window_length,
1578 svd_backend,
1579 );
1580 let rmse = block.rmse;
1581 let score = bic_like_score(rmse, model.window_length, parameter_count);
1582 window_scores.push(score);
1583 window_rmses.push(rmse);
1584 }
1585 (window_scores, window_rmses)
1586}
1587
/// Arithmetic mean of the finite entries of `values`, or `None` if there are none.
///
/// `NaN` and infinite entries are skipped; values are summed in slice order.
fn finite_mean(values: &[f64]) -> Option<f64> {
    let (sum, finite_count) = values
        .iter()
        .copied()
        .filter(|value| value.is_finite())
        .fold((0.0, 0usize), |(sum, n), value| (sum + value, n + 1));
    if finite_count == 0 {
        None
    } else {
        Some(sum / (finite_count as f64))
    }
}
1599
/// Merges per-pair conditioning sets into one sorted, de-duplicated set per target channel.
///
/// `pair_conditioning_sets[i]` belongs to `ccd_pairs[i]`; surplus entries of either slice are
/// ignored. Every target that appears among the zipped pairs gets an entry, possibly empty.
fn build_target_conditioning_sets(
    ccd_pairs: &[[usize; 2]],
    pair_conditioning_sets: &[Vec<usize>],
) -> std::collections::BTreeMap<usize, Vec<usize>> {
    use std::collections::{BTreeMap, BTreeSet};

    let union_by_target = ccd_pairs.iter().zip(pair_conditioning_sets).fold(
        BTreeMap::<usize, BTreeSet<usize>>::new(),
        |mut acc, (pair, confounds)| {
            acc.entry(pair[0]).or_default().extend(confounds);
            acc
        },
    );
    union_by_target
        .into_iter()
        .map(|(target, members)| (target, members.into_iter().collect::<Vec<_>>()))
        .collect()
}
1618
/// Default circular-shift offsets for surrogate series of length `series_len`.
///
/// Uses 1/6, 1/4, 1/3, 1/2 and 2/3 of the length (integer division), keeping only offsets
/// strictly between 0 and `series_len`, sorted ascending and de-duplicated. Series shorter
/// than 8 samples get no shifts.
fn default_surrogate_shifts(series_len: usize) -> Vec<usize> {
    if series_len < 8 {
        return Vec::new();
    }
    // (numerator, denominator) of each fractional shift.
    let fractions: [(usize, usize); 5] = [(1, 6), (1, 4), (1, 3), (1, 2), (2, 3)];
    let mut shifts: Vec<usize> = fractions
        .iter()
        .map(|&(numerator, denominator)| (numerator * series_len) / denominator)
        .filter(|&shift| shift > 0 && shift < series_len)
        .collect();
    shifts.sort_unstable();
    shifts.dedup();
    shifts
}
1635
1636fn compute_mvccd_window_scores(
1637 prepared: &PreparedWindow,
1638 ccd_pairs: &[[usize; 2]],
1639 target_conditioning_sets: &std::collections::BTreeMap<usize, Vec<usize>>,
1640 model: &ModelSpec,
1641 max_active_sources: usize,
1642 svd_backend: SvdBackend,
1643) -> Vec<f64> {
1644 use std::collections::{BTreeMap, BTreeSet};
1645
1646 let mut pairs_by_target: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
1647 for (pair_idx, pair) in ccd_pairs.iter().enumerate() {
1648 pairs_by_target.entry(pair[0]).or_default().push(pair_idx);
1649 }
1650
1651 let mut values = vec![0.0; ccd_pairs.len()];
1652 for (target, pair_indices) in pairs_by_target {
1653 let candidate_sources = pair_indices
1654 .iter()
1655 .map(|pair_idx| ccd_pairs[*pair_idx][1])
1656 .collect::<BTreeSet<_>>()
1657 .into_iter()
1658 .collect::<Vec<_>>();
1659 let fixed_inputs = target_conditioning_sets
1660 .get(&target)
1661 .into_iter()
1662 .flat_map(|channels| channels.iter().copied())
1663 .filter(|channel| *channel != target && !candidate_sources.contains(channel))
1664 .collect::<Vec<_>>();
1665 let improvements = greedy_sparse_unique_improvements(
1666 prepared,
1667 target,
1668 &candidate_sources,
1669 &fixed_inputs,
1670 &model.primary_terms,
1671 &model.secondary_terms,
1672 model.window_length,
1673 max_active_sources,
1674 svd_backend,
1675 );
1676 for pair_idx in pair_indices {
1677 let source = ccd_pairs[pair_idx][1];
1678 values[pair_idx] = improvements
1679 .iter()
1680 .find(|(candidate, _)| *candidate == source)
1681 .map(|(_, value)| *value)
1682 .unwrap_or(0.0);
1683 }
1684 }
1685 values
1686}
1687
1688fn compute_trccd_matrix(
1689 prepared_windows: &[PreparedWindow],
1690 ccd_pairs: &[[usize; 2]],
1691 pair_conditioning_sets: &[Vec<usize>],
1692 model: &ModelSpec,
1693 lambda: f64,
1694 svd_backend: SvdBackend,
1695) -> Vec<Vec<f64>> {
1696 solve_channels_parallel(
1697 &ccd_pairs
1698 .iter()
1699 .zip(pair_conditioning_sets.iter())
1700 .collect::<Vec<_>>(),
1701 |(pair, confounds)| {
1702 let conditioned_inputs = {
1703 let mut inputs = (*confounds).clone();
1704 inputs.push(pair[1]);
1705 inputs
1706 };
1707 let baseline_windows = prepared_windows
1708 .iter()
1709 .map(|prepared| {
1710 build_channel_regression_window_with_inputs(
1711 prepared,
1712 pair[0],
1713 &confounds,
1714 &model.primary_terms,
1715 &model.secondary_terms,
1716 model.window_length,
1717 )
1718 })
1719 .collect::<Vec<_>>();
1720 let conditioned_windows = prepared_windows
1721 .iter()
1722 .map(|prepared| {
1723 build_channel_regression_window_with_inputs(
1724 prepared,
1725 pair[0],
1726 &conditioned_inputs,
1727 &model.primary_terms,
1728 &model.secondary_terms,
1729 model.window_length,
1730 )
1731 })
1732 .collect::<Vec<_>>();
1733 let baseline_blocks =
1734 solve_temporally_regularized_windows(&baseline_windows, lambda, svd_backend);
1735 let conditioned_blocks = solve_temporally_regularized_windows(
1736 &conditioned_windows,
1737 lambda,
1738 svd_backend,
1739 );
1740 baseline_blocks
1741 .iter()
1742 .zip(conditioned_blocks.iter())
1743 .map(|(baseline, conditioned)| {
1744 conditional_causal_improvement(baseline.rmse, conditioned.rmse)
1745 })
1746 .collect::<Vec<_>>()
1747 },
1748 )
1749}
1750
1751fn build_ccd_stability_requests(request: &DDARequest) -> Vec<DDARequest> {
1752 let mut requests = Vec::new();
1753 let mut base = request.clone();
1754 base.algorithm_selection.enabled_variants = vec!["CCD".to_string()];
1755 base.algorithm_selection.select_mask = None;
1756
1757 let base_wl = base.window_parameters.window_length.max(32);
1758 let base_ws = base.window_parameters.window_step.max(1);
1759 let mut delays = base.delay_parameters.delays.clone();
1760 if delays.is_empty() {
1761 delays = crate::types::DEFAULT_DELAYS.to_vec();
1762 }
1763
1764 let mut shorter = base.clone();
1765 shorter.window_parameters.window_length = (base_wl.saturating_mul(4) / 5).max(32);
1766 shorter.window_parameters.window_step = (base_ws.saturating_mul(4) / 5).max(1);
1767 requests.push(shorter);
1768
1769 let mut longer = base.clone();
1770 longer.window_parameters.window_length = (base_wl.saturating_mul(6) / 5).max(base_wl + 1);
1771 longer.window_parameters.window_step = (base_ws.saturating_mul(6) / 5).max(base_ws + 1);
1772 requests.push(longer);
1773
1774 if delays.iter().all(|delay| *delay > 0) {
1775 let mut lower_delays = base.clone();
1776 lower_delays.delay_parameters.delays = delays.iter().map(|delay| delay - 1).collect();
1777 requests.push(lower_delays);
1778 }
1779
1780 let mut mixed = base;
1781 mixed.window_parameters.window_length = (base_wl.saturating_mul(4) / 5).max(32);
1782 mixed.window_parameters.window_step = (base_ws.saturating_mul(4) / 5).max(1);
1783 if delays.iter().all(|delay| *delay > 0) {
1784 mixed.delay_parameters.delays = delays.iter().map(|delay| delay - 1).collect();
1785 }
1786 requests.push(mixed);
1787
1788 dedup_stability_requests(requests)
1789}
1790
1791fn dedup_stability_requests(requests: Vec<DDARequest>) -> Vec<DDARequest> {
1792 use std::collections::BTreeSet;
1793
1794 let mut seen = BTreeSet::new();
1795 let mut deduped = Vec::new();
1796 for request in requests {
1797 let key = (
1798 request.window_parameters.window_length,
1799 request.window_parameters.window_step,
1800 request.delay_parameters.delays.clone(),
1801 );
1802 if seen.insert(key) {
1803 deduped.push(request);
1804 }
1805 }
1806 deduped
1807}
1808
/// Looks up `matrix[row_idx][k]`, where `k` indexes the marker closest to `target_marker`.
///
/// Distance is `|marker - target_marker|`. On ties the earliest marker wins. `NaN` markers
/// are ranked after every real distance, so they are chosen only if no other marker exists.
///
/// Returns `None` when `row_idx` is out of range, `markers` is empty, or the row has no
/// entry at the nearest marker's index.
fn nearest_aligned_value(
    markers: &[f64],
    matrix: &[Vec<f64>],
    row_idx: usize,
    target_marker: f64,
) -> Option<f64> {
    let row = matrix.get(row_idx)?;
    // `abs()` clears the sign bit, so any NaN distance sorts above +inf under `total_cmp`.
    let nearest_index = markers
        .iter()
        .map(|marker| (marker - target_marker).abs())
        .enumerate()
        .min_by(|(_, left), (_, right)| left.total_cmp(right))
        .map(|(index, _)| index)?;
    row.get(nearest_index).copied()
}
1828
/// True when both values are strictly positive or both strictly negative.
///
/// Zero (of either sign) and `NaN` never share a sign with anything.
fn same_sign(left: f64, right: f64) -> bool {
    use std::cmp::Ordering::{Greater, Less};

    let sign = |value: f64| value.partial_cmp(&0.0);
    matches!(
        (sign(left), sign(right)),
        (Some(Greater), Some(Greater)) | (Some(Less), Some(Less))
    )
}
1832
/// Runs `request` against an in-memory sample matrix.
///
/// Convenience wrapper: builds a `PureRustRunner` via `Default` and delegates to its
/// `run_on_matrix`, forwarding `samples` and the optional `channel_labels` unchanged.
///
/// # Errors
/// Returns whatever error `PureRustRunner::run_on_matrix` reports.
pub fn run_request_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
) -> Result<DDAResult> {
    PureRustRunner::default().run_on_matrix(request, samples, channel_labels)
}
1840
/// Reports the CCD conditioning sets that would be used for `request` on an in-memory matrix.
///
/// Convenience wrapper: builds a `PureRustRunner` via `Default` and delegates to its
/// `inspect_ccd_conditioning_sets_on_matrix` with the same arguments.
///
/// # Errors
/// Returns whatever error the runner method reports.
pub fn inspect_ccd_conditioning_sets_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
) -> Result<CcdConditioningInspection> {
    PureRustRunner::default().inspect_ccd_conditioning_sets_on_matrix(
        request,
        samples,
        channel_labels,
    )
}
1852
/// Scores each candidate confound set in `confound_sets` for the CCD `pair` on an
/// in-memory matrix.
///
/// Convenience wrapper: builds a `PureRustRunner` via `Default` and delegates to its
/// `score_ccd_conditioning_subsets_on_matrix` with the same arguments.
/// NOTE(review): `pair` is presumably `[target, source]`, matching how `ccd_pairs` are
/// indexed elsewhere in this module — confirm in the runner method.
///
/// # Errors
/// Returns whatever error the runner method reports.
pub fn score_ccd_conditioning_subsets_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
    pair: [usize; 2],
    confound_sets: &[Vec<usize>],
) -> Result<Vec<CcdConditioningSubsetScore>> {
    PureRustRunner::default().score_ccd_conditioning_subsets_on_matrix(
        request,
        samples,
        channel_labels,
        pair,
        confound_sets,
    )
}
1868
/// Profiles each candidate confound set in `confound_sets` for the CCD `pair` on an
/// in-memory matrix.
///
/// Convenience wrapper: builds a `PureRustRunner` via `Default` and delegates to its
/// `profile_ccd_conditioning_subsets_on_matrix` with the same arguments.
/// NOTE(review): `pair` is presumably `[target, source]`, matching how `ccd_pairs` are
/// indexed elsewhere in this module — confirm in the runner method.
///
/// # Errors
/// Returns whatever error the runner method reports.
pub fn profile_ccd_conditioning_subsets_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
    pair: [usize; 2],
    confound_sets: &[Vec<usize>],
) -> Result<Vec<CcdConditioningSubsetProfile>> {
    PureRustRunner::default().profile_ccd_conditioning_subsets_on_matrix(
        request,
        samples,
        channel_labels,
        pair,
        confound_sets,
    )
}
1884
/// Runs `request` against an in-memory sample matrix, reporting progress through
/// `on_progress`.
///
/// Convenience wrapper: builds a `PureRustRunner` via `Default` and delegates to its
/// `run_on_matrix_with_progress`, passing `on_progress` through unchanged.
///
/// # Errors
/// Returns whatever error `PureRustRunner::run_on_matrix_with_progress` reports.
pub fn run_request_on_matrix_with_progress<F>(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
    on_progress: F,
) -> Result<DDAResult>
where
    F: FnMut(&PureRustProgress),
{
    PureRustRunner::default().run_on_matrix_with_progress(
        request,
        samples,
        channel_labels,
        on_progress,
    )
}