Skip to main content

dda_rs/
engine.rs

1mod dataset;
2mod model;
3mod solver;
4mod variant_config;
5mod window;
6
7#[cfg(test)]
8mod tests;
9
10use crate::error::{DDAError, Result};
11use crate::types::{CcdConditioningStrategy, DDARequest, DDAResult, VariantResult};
12use dataset::{AnalysisBounds, MatrixDataset};
13use model::ModelSpec;
14use serde::{Deserialize, Serialize};
15use solver::{
16    bic_like_score, build_channel_regression_window_with_inputs, causal_improvement,
17    circular_shift_series, compute_de_value, conditional_causal_improvement,
18    empirical_significance_confidence, greedy_sparse_unique_improvements,
19    solve_channel_with_inputs, solve_channel_with_surrogate_inputs, solve_channels_parallel,
20    solve_directed_pair, solve_group_block, solve_temporally_regularized_windows,
21    synchronization_value, SolvedBlock,
22};
23use std::time::{Duration, Instant};
24use uuid::Uuid;
25use variant_config::{
26    collect_analysis_channels, flip_pairs, labels_for_channels, labels_for_groups,
27    labels_for_pairs, labels_for_sy, resolve_ccd_candidate_channels,
28    resolve_ccd_conditioning_strategy, resolve_ccd_max_active_sources, resolve_ccd_pairs,
29    resolve_ccd_surrogate_shifts, resolve_ccd_temporal_lambda, resolve_cd_pairs, resolve_ct_groups,
30    resolve_de_groups, resolve_sy_pairs, resolve_variant_selected_channels, VariantMode,
31};
32use window::PreparedWindow;
33
/// Crate-visible tuning threshold.
///
/// NOTE(review): not referenced in this part of the file; the name suggests the
/// minimum batch length before work is split across threads (e.g. by
/// `solver::solve_channels_parallel`) — confirm at the use sites.
pub(crate) const PARALLEL_BATCH_MIN_LEN: usize = 4;

/// Normalization choice carried in `PureRustOptions::normalization_mode`.
///
/// The default options select `NormalizationMode::ZScore`. The transforms
/// themselves are applied outside this section of the file (presumably during
/// window preparation), so the per-variant notes below are inferred from the
/// variant names and should be confirmed there.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NormalizationMode {
    /// Presumably standardizes each channel (zero mean, unit variance).
    ZScore,
    /// Presumably leaves the samples unscaled.
    Raw,
    /// Presumably rescales each channel using its minimum and maximum.
    MinMax,
}
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44#[serde(rename_all = "snake_case")]
45pub enum SvdBackend {
46    RobustSvd,
47    NativeCompatSvd,
48}
49
50#[derive(Debug, Clone)]
51pub struct PureRustOptions {
52    pub nr_exclude: usize,
53    pub normalization_mode: NormalizationMode,
54    pub derivative_step: usize,
55    pub svd_backend: SvdBackend,
56}
57
58impl Default for PureRustOptions {
59    fn default() -> Self {
60        Self {
61            nr_exclude: 10,
62            normalization_mode: NormalizationMode::ZScore,
63            derivative_step: 1,
64            svd_backend: SvdBackend::RobustSvd,
65        }
66    }
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
70pub struct PureRustProgress {
71    pub stage_id: String,
72    pub stage_label: String,
73    pub step_index: usize,
74    pub total_steps: usize,
75    pub window_index: usize,
76    pub total_windows: usize,
77    pub item_index: usize,
78    pub total_items: usize,
79    pub item_kind: String,
80    pub item_label: String,
81}
82
83#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
84pub struct CcdConditioningInspection {
85    pub pairs: Vec<[usize; 2]>,
86    pub conditioning_sets: Vec<Vec<usize>>,
87    pub candidate_channels: Vec<usize>,
88    pub strategy: CcdConditioningStrategy,
89}
90
91#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
92pub struct CcdConditioningSubsetScore {
93    pub pair: [usize; 2],
94    pub confounds: Vec<usize>,
95    pub bic_like_score: f64,
96    pub mean_rmse: f64,
97}
98
99#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
100pub struct CcdConditioningSubsetProfile {
101    pub pair: [usize; 2],
102    pub confounds: Vec<usize>,
103    pub bic_like_score: f64,
104    pub mean_rmse: f64,
105    pub window_bic_scores: Vec<f64>,
106    pub window_rmses: Vec<f64>,
107}
108
109#[derive(Debug, Clone)]
110pub struct PureRustRunner {
111    options: PureRustOptions,
112}
113
114impl Default for PureRustRunner {
115    fn default() -> Self {
116        Self::new(PureRustOptions::default())
117    }
118}
119
120impl PureRustRunner {
121    pub fn new(options: PureRustOptions) -> Self {
122        Self { options }
123    }
124
125    pub fn run_on_matrix(
126        &self,
127        request: &DDARequest,
128        samples: &[Vec<f64>],
129        channel_labels: Option<&[String]>,
130    ) -> Result<DDAResult> {
131        self.run_on_matrix_internal(request, samples, channel_labels, None)
132    }
133
134    pub fn run_on_matrix_with_progress<F>(
135        &self,
136        request: &DDARequest,
137        samples: &[Vec<f64>],
138        channel_labels: Option<&[String]>,
139        on_progress: F,
140    ) -> Result<DDAResult>
141    where
142        F: FnMut(&PureRustProgress),
143    {
144        let mut callback = on_progress;
145        self.run_on_matrix_internal(request, samples, channel_labels, Some(&mut callback))
146    }
147
148    pub fn inspect_ccd_conditioning_sets_on_matrix(
149        &self,
150        request: &DDARequest,
151        samples: &[Vec<f64>],
152        channel_labels: Option<&[String]>,
153    ) -> Result<CcdConditioningInspection> {
154        let dataset = MatrixDataset::new(samples, channel_labels)?;
155        let model = ModelSpec::from_request(request)?;
156        let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
157        let ccd_pairs = resolve_ccd_pairs(request, dataset.cols);
158        let strategy = resolve_ccd_conditioning_strategy(request);
159        let candidate_channels = resolve_ccd_candidate_channels(request, dataset.cols);
160        let needs_prepared_windows = !matches!(strategy, CcdConditioningStrategy::AllSelected);
161
162        let conditioning_sets = if needs_prepared_windows {
163            let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
164            let required_rows = native_window_marker.saturating_sub(1);
165            if bounds.len < required_rows {
166                return Err(DDAError::InvalidParameter(format!(
167                    "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
168                    bounds.len, required_rows
169                )));
170            }
171            if model.window_step == 0 {
172                return Err(DDAError::InvalidParameter(
173                    "window_step must be greater than zero".to_string(),
174                ));
175            }
176            let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
177            let windows = (0..num_windows)
178                .map(|window_idx| {
179                    prepare_window_for_analysis(
180                        &dataset,
181                        &bounds,
182                        &model,
183                        window_idx,
184                        &self.options,
185                    )
186                })
187                .collect::<Result<Vec<_>>>()?;
188            compute_ccd_pair_conditioning_sets(
189                Some(&windows),
190                &ccd_pairs,
191                &candidate_channels,
192                strategy,
193                &model,
194                resolve_ccd_max_active_sources(request).unwrap_or(3),
195                self.options.svd_backend,
196            )
197        } else {
198            compute_ccd_pair_conditioning_sets(
199                None,
200                &ccd_pairs,
201                &candidate_channels,
202                strategy,
203                &model,
204                resolve_ccd_max_active_sources(request).unwrap_or(3),
205                self.options.svd_backend,
206            )
207        };
208
209        Ok(CcdConditioningInspection {
210            pairs: ccd_pairs,
211            conditioning_sets,
212            candidate_channels,
213            strategy,
214        })
215    }
216
217    pub fn score_ccd_conditioning_subsets_on_matrix(
218        &self,
219        request: &DDARequest,
220        samples: &[Vec<f64>],
221        channel_labels: Option<&[String]>,
222        pair: [usize; 2],
223        confound_sets: &[Vec<usize>],
224    ) -> Result<Vec<CcdConditioningSubsetScore>> {
225        let dataset = MatrixDataset::new(samples, channel_labels)?;
226        let model = ModelSpec::from_request(request)?;
227        let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
228        let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
229        let required_rows = native_window_marker.saturating_sub(1);
230        if bounds.len < required_rows {
231            return Err(DDAError::InvalidParameter(format!(
232                "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
233                bounds.len, required_rows
234            )));
235        }
236        if model.window_step == 0 {
237            return Err(DDAError::InvalidParameter(
238                "window_step must be greater than zero".to_string(),
239            ));
240        }
241        let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
242        let windows = (0..num_windows)
243            .map(|window_idx| {
244                prepare_window_for_analysis(&dataset, &bounds, &model, window_idx, &self.options)
245            })
246            .collect::<Result<Vec<_>>>()?;
247
248        Ok(confound_sets
249            .iter()
250            .map(|confounds| CcdConditioningSubsetScore {
251                pair,
252                confounds: confounds.clone(),
253                bic_like_score: average_conditioned_baseline_score(
254                    &windows,
255                    pair[0],
256                    confounds,
257                    &model,
258                    self.options.svd_backend,
259                ),
260                mean_rmse: average_conditioned_baseline_rmse(
261                    &windows,
262                    pair[0],
263                    confounds,
264                    &model,
265                    self.options.svd_backend,
266                ),
267            })
268            .collect())
269    }
270
271    pub fn profile_ccd_conditioning_subsets_on_matrix(
272        &self,
273        request: &DDARequest,
274        samples: &[Vec<f64>],
275        channel_labels: Option<&[String]>,
276        pair: [usize; 2],
277        confound_sets: &[Vec<usize>],
278    ) -> Result<Vec<CcdConditioningSubsetProfile>> {
279        let dataset = MatrixDataset::new(samples, channel_labels)?;
280        let model = ModelSpec::from_request(request)?;
281        let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
282        let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
283        let required_rows = native_window_marker.saturating_sub(1);
284        if bounds.len < required_rows {
285            return Err(DDAError::InvalidParameter(format!(
286                "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
287                bounds.len, required_rows
288            )));
289        }
290        if model.window_step == 0 {
291            return Err(DDAError::InvalidParameter(
292                "window_step must be greater than zero".to_string(),
293            ));
294        }
295        let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
296        let windows = (0..num_windows)
297            .map(|window_idx| {
298                prepare_window_for_analysis(&dataset, &bounds, &model, window_idx, &self.options)
299            })
300            .collect::<Result<Vec<_>>>()?;
301
302        Ok(confound_sets
303            .iter()
304            .map(|confounds| {
305                let (window_bic_scores, window_rmses) = conditioned_baseline_window_metrics(
306                    &windows,
307                    pair[0],
308                    confounds,
309                    &model,
310                    self.options.svd_backend,
311                );
312                let bic_like_score = finite_mean(&window_bic_scores).unwrap_or(f64::INFINITY);
313                let mean_rmse = finite_mean(&window_rmses).unwrap_or(f64::INFINITY);
314                CcdConditioningSubsetProfile {
315                    pair,
316                    confounds: confounds.clone(),
317                    bic_like_score,
318                    mean_rmse,
319                    window_bic_scores,
320                    window_rmses,
321                }
322            })
323            .collect())
324    }
325
326    fn run_on_matrix_internal(
327        &self,
328        request: &DDARequest,
329        samples: &[Vec<f64>],
330        channel_labels: Option<&[String]>,
331        mut on_progress: Option<&mut dyn FnMut(&PureRustProgress)>,
332    ) -> Result<DDAResult> {
333        let dataset = MatrixDataset::new(samples, channel_labels)?;
334        let variant_mode = VariantMode::from_request(request);
335        let model = ModelSpec::from_request(request)?;
336        let bounds = AnalysisBounds::from_request(request, dataset.rows)?;
337        let st_channels = resolve_variant_selected_channels(
338            request,
339            dataset.cols,
340            &["ST", "st", "single_timeseries"],
341        );
342        let de_channels = resolve_variant_selected_channels(
343            request,
344            dataset.cols,
345            &["DE", "de", "dynamical_ergodicity"],
346        );
347        let sy_channels = resolve_variant_selected_channels(
348            request,
349            dataset.cols,
350            &["SY", "sy", "synchronization"],
351        );
352        let ct_groups = resolve_ct_groups(request, dataset.cols);
353        let de_groups = resolve_de_groups(request, dataset.cols, &de_channels);
354        let cd_pairs = resolve_cd_pairs(request, dataset.cols);
355        let ccd_pairs = resolve_ccd_pairs(request, dataset.cols);
356        let ccd_conditioning_strategy = resolve_ccd_conditioning_strategy(request);
357        let ccd_candidate_channels = resolve_ccd_candidate_channels(request, dataset.cols);
358        let ccd_surrogate_shifts = resolve_ccd_surrogate_shifts(request);
359        let ccd_temporal_lambda = resolve_ccd_temporal_lambda(request).unwrap_or(0.25);
360        let ccd_max_active_sources = resolve_ccd_max_active_sources(request).unwrap_or(3);
361        let sy_pairs = resolve_sy_pairs(&sy_channels);
362        let analysis_channels = collect_analysis_channels(
363            &st_channels,
364            &ct_groups,
365            &de_groups,
366            &cd_pairs,
367            &ccd_pairs,
368            &ccd_candidate_channels,
369        );
370
371        let enabled_st = variant_mode.st_enabled;
372        let enabled_ct = variant_mode.ct_enabled;
373        let enabled_cd = variant_mode.cd_enabled;
374        let enabled_ccd_core = (variant_mode.ccd_enabled
375            || variant_mode.ccdsig_enabled
376            || variant_mode.ccdstab_enabled
377            || variant_mode.trccd_enabled
378            || variant_mode.mvccd_enabled)
379            && !ccd_pairs.is_empty();
380        let enabled_ccd = variant_mode.ccd_enabled && !ccd_pairs.is_empty();
381        let enabled_ccdsig = variant_mode.ccdsig_enabled && !ccd_pairs.is_empty();
382        let enabled_ccdstab = variant_mode.ccdstab_enabled && !ccd_pairs.is_empty();
383        let enabled_trccd = variant_mode.trccd_enabled && !ccd_pairs.is_empty();
384        let enabled_mvccd = variant_mode.mvccd_enabled && !ccd_pairs.is_empty();
385        let enabled_de = variant_mode.de_enabled;
386        let enabled_sy = variant_mode.sy_mode > 0 && !sy_pairs.is_empty();
387
388        if !enabled_st
389            && !enabled_ct
390            && !enabled_cd
391            && !enabled_ccd_core
392            && !enabled_de
393            && !enabled_sy
394        {
395            return Err(DDAError::InvalidParameter(
396                "No DDA variants enabled for pure Rust engine".to_string(),
397            ));
398        }
399
400        let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
401        let required_rows = native_window_marker.saturating_sub(1);
402        if bounds.len < required_rows {
403            return Err(DDAError::InvalidParameter(format!(
404                "Selected range has {} samples but the current DDA contract needs at least {} samples (WL + 2*dm + max(TAU) - 1)",
405                bounds.len, required_rows
406            )));
407        }
408        if model.window_step == 0 {
409            return Err(DDAError::InvalidParameter(
410                "window_step must be greater than zero".to_string(),
411            ));
412        }
413
414        let num_windows = 1 + (bounds.len - required_rows) / model.window_step;
415        let needs_prepared_windows = enabled_trccd
416            || !matches!(
417                ccd_conditioning_strategy,
418                CcdConditioningStrategy::AllSelected
419            );
420        let mut prepared_windows = None;
421        let progress_enabled = on_progress.is_some();
422        let analysis_channel_labels = progress_enabled
423            .then(|| labels_for_channels(&dataset.channel_labels, &analysis_channels));
424        let ct_group_labels =
425            progress_enabled.then(|| labels_for_groups(&dataset.channel_labels, &ct_groups, " & "));
426        let de_group_labels =
427            progress_enabled.then(|| labels_for_groups(&dataset.channel_labels, &de_groups, " & "));
428        let cd_pair_labels =
429            progress_enabled.then(|| labels_for_pairs(&dataset.channel_labels, &cd_pairs, " <- "));
430        let ccd_pair_labels =
431            progress_enabled.then(|| labels_for_pairs(&dataset.channel_labels, &ccd_pairs, " <- "));
432        let sy_forward_labels =
433            progress_enabled.then(|| labels_for_pairs(&dataset.channel_labels, &sy_pairs, " -> "));
434        let sy_reverse_labels = progress_enabled.then(|| {
435            let sy_reverse_pairs = flip_pairs(&sy_pairs);
436            labels_for_pairs(&dataset.channel_labels, &sy_reverse_pairs, " -> ")
437        });
438        let shared_block_steps = if enabled_st || enabled_cd || enabled_de {
439            analysis_channels.len()
440        } else {
441            0
442        };
443        let steps_per_window = 1
444            + shared_block_steps
445            + if enabled_ct { ct_groups.len() } else { 0 }
446            + if enabled_de { de_groups.len() } else { 0 }
447            + if enabled_cd { cd_pairs.len() } else { 0 }
448            + if enabled_ccd_core { ccd_pairs.len() } else { 0 }
449            + if enabled_ccdsig { ccd_pairs.len() } else { 0 }
450            + if enabled_mvccd { ccd_pairs.len() } else { 0 }
451            + if enabled_sy { sy_pairs.len() * 2 } else { 0 };
452        let total_steps = num_windows * steps_per_window
453            + if enabled_trccd { ccd_pairs.len() } else { 0 }
454            + if enabled_ccdstab { ccd_pairs.len() } else { 0 };
455        let mut emitted_steps = 0usize;
456        let mut last_progress_emit = Instant::now() - Duration::from_secs(1);
457        let mut report = |stage_id: &str,
458                          stage_label: &str,
459                          window_number: usize,
460                          item_index: usize,
461                          total_items: usize,
462                          item_kind: &str,
463                          item_label: Option<&str>| {
464            emitted_steps += 1;
465            let should_emit = emitted_steps <= 1
466                || emitted_steps >= total_steps
467                || last_progress_emit.elapsed() >= Duration::from_millis(125);
468            if !should_emit {
469                return;
470            }
471            last_progress_emit = Instant::now();
472            if let Some(callback) = on_progress.as_deref_mut() {
473                callback(&PureRustProgress {
474                    stage_id: stage_id.to_string(),
475                    stage_label: stage_label.to_string(),
476                    step_index: emitted_steps,
477                    total_steps,
478                    window_index: window_number,
479                    total_windows: num_windows,
480                    item_index,
481                    total_items,
482                    item_kind: item_kind.to_string(),
483                    item_label: item_label.unwrap_or("").to_string(),
484                });
485            }
486        };
487
488        let native_window_markers: Vec<f64> = (0..num_windows)
489            .map(|window_idx| {
490                (bounds.start + window_idx * model.window_step + native_window_marker) as f64
491            })
492            .collect();
493
494        if needs_prepared_windows {
495            let mut windows = Vec::with_capacity(num_windows);
496            for window_idx in 0..num_windows {
497                report(
498                    "prepare-window",
499                    "Preparing analysis window",
500                    window_idx + 1,
501                    window_idx + 1,
502                    num_windows,
503                    "window",
504                    None,
505                );
506                windows.push(prepare_window_for_analysis(
507                    &dataset,
508                    &bounds,
509                    &model,
510                    window_idx,
511                    &self.options,
512                )?);
513            }
514            prepared_windows = Some(windows);
515        }
516
517        let ccd_pair_conditioning_sets = if enabled_ccd_core {
518            compute_ccd_pair_conditioning_sets(
519                prepared_windows.as_deref(),
520                &ccd_pairs,
521                &ccd_candidate_channels,
522                ccd_conditioning_strategy,
523                &model,
524                ccd_max_active_sources,
525                self.options.svd_backend,
526            )
527        } else {
528            Vec::new()
529        };
530        let ccd_target_conditioning_sets =
531            build_target_conditioning_sets(&ccd_pairs, &ccd_pair_conditioning_sets);
532
533        let mut st_matrix =
534            enabled_st.then(|| vec![vec![f64::NAN; num_windows]; st_channels.len()]);
535        let mut ct_matrix = enabled_ct.then(|| vec![vec![f64::NAN; num_windows]; ct_groups.len()]);
536        let mut cd_matrix = enabled_cd.then(|| vec![vec![f64::NAN; num_windows]; cd_pairs.len()]);
537        let mut ccd_matrix =
538            enabled_ccd_core.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
539        let mut ccdsig_matrix =
540            enabled_ccdsig.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
541        let mut mvccd_matrix =
542            enabled_mvccd.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
543        let mut trccd_matrix =
544            enabled_trccd.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
545        let mut ccdstab_matrix =
546            enabled_ccdstab.then(|| vec![vec![f64::NAN; num_windows]; ccd_pairs.len()]);
547        let mut de_matrix = enabled_de.then(|| vec![vec![f64::NAN; num_windows]; de_groups.len()]);
548        let mut sy_matrix = enabled_sy.then(|| {
549            let rows = if variant_mode.sy_mode == 2 {
550                sy_pairs.len() * 2
551            } else {
552                sy_pairs.len()
553            };
554            vec![vec![f64::NAN; num_windows]; rows]
555        });
556
557        for window_idx in 0..num_windows {
558            let prepared_storage;
559            let prepared = if let Some(windows) = prepared_windows.as_ref() {
560                &windows[window_idx]
561            } else {
562                report(
563                    "prepare-window",
564                    "Preparing analysis window",
565                    window_idx + 1,
566                    window_idx + 1,
567                    num_windows,
568                    "window",
569                    None,
570                );
571                prepared_storage = prepare_window_for_analysis(
572                    &dataset,
573                    &bounds,
574                    &model,
575                    window_idx,
576                    &self.options,
577                )?;
578                &prepared_storage
579            };
580
581            let mut st_blocks: Vec<Option<SolvedBlock>> = vec![None; dataset.cols];
582            if enabled_st || enabled_cd || enabled_de {
583                let computed_st_blocks = solve_channels_parallel(&analysis_channels, |&channel| {
584                    (
585                        channel,
586                        solve_group_block(
587                            &prepared,
588                            &[channel],
589                            &model.primary_terms,
590                            model.window_length,
591                            self.options.svd_backend,
592                        ),
593                    )
594                });
595                for (channel_idx, (channel, block)) in computed_st_blocks.into_iter().enumerate() {
596                    report(
597                        "st-blocks",
598                        "Solving baseline channel dynamics",
599                        window_idx + 1,
600                        channel_idx + 1,
601                        analysis_channels.len(),
602                        "channel",
603                        analysis_channel_labels
604                            .as_ref()
605                            .and_then(|labels| labels.get(channel_idx).map(String::as_str)),
606                    );
607                    if channel < st_blocks.len() {
608                        st_blocks[channel] = Some(block);
609                    }
610                }
611            }
612
613            if let Some(matrix) = st_matrix.as_mut() {
614                for (row_idx, &channel) in st_channels.iter().enumerate() {
615                    if let Some(block) = st_blocks.get(channel).and_then(Option::as_ref) {
616                        matrix[row_idx][window_idx] =
617                            block.coefficients.first().copied().unwrap_or(f64::NAN);
618                    }
619                }
620            }
621
622            let mut ct_blocks = Vec::new();
623            if enabled_ct {
624                ct_blocks = solve_channels_parallel(&ct_groups, |group| {
625                    solve_group_block(
626                        &prepared,
627                        group,
628                        &model.primary_terms,
629                        model.window_length,
630                        self.options.svd_backend,
631                    )
632                });
633                for (group_idx, _) in ct_groups.iter().enumerate() {
634                    report(
635                        "ct",
636                        "Computing cross-timeseries groups",
637                        window_idx + 1,
638                        group_idx + 1,
639                        ct_groups.len(),
640                        "group",
641                        ct_group_labels
642                            .as_ref()
643                            .and_then(|labels| labels.get(group_idx).map(String::as_str)),
644                    );
645                }
646            }
647
648            if let Some(matrix) = ct_matrix.as_mut() {
649                for (row_idx, block) in ct_blocks.iter().enumerate() {
650                    matrix[row_idx][window_idx] =
651                        block.coefficients.first().copied().unwrap_or(f64::NAN);
652                }
653            }
654
655            let mut de_blocks = Vec::new();
656            if enabled_de {
657                de_blocks = solve_channels_parallel(&de_groups, |group| {
658                    solve_group_block(
659                        &prepared,
660                        group,
661                        &model.primary_terms,
662                        model.window_length,
663                        self.options.svd_backend,
664                    )
665                });
666                for (group_idx, _) in de_groups.iter().enumerate() {
667                    report(
668                        "de",
669                        "Computing dynamical ergodicity groups",
670                        window_idx + 1,
671                        group_idx + 1,
672                        de_groups.len(),
673                        "group",
674                        de_group_labels
675                            .as_ref()
676                            .and_then(|labels| labels.get(group_idx).map(String::as_str)),
677                    );
678                }
679            }
680
681            if let Some(matrix) = de_matrix.as_mut() {
682                for (row_idx, group) in de_groups.iter().enumerate() {
683                    let ct_rmse = de_blocks
684                        .get(row_idx)
685                        .map(|block| block.rmse)
686                        .unwrap_or(f64::NAN);
687                    let de_value = compute_de_value(group, &st_blocks, ct_rmse);
688                    matrix[row_idx][window_idx] = de_value;
689                }
690            }
691
692            if enabled_cd {
693                let cd_values = solve_channels_parallel(&cd_pairs, |pair| {
694                    let forward = solve_directed_pair(
695                        &prepared,
696                        pair[0],
697                        pair[1],
698                        pair[0],
699                        &model.primary_terms,
700                        &model.secondary_terms,
701                        model.window_length,
702                        self.options.svd_backend,
703                    );
704                    let baseline = st_blocks
705                        .get(pair[0])
706                        .and_then(Option::as_ref)
707                        .map(|block| block.rmse)
708                        .unwrap_or(f64::NAN);
709                    causal_improvement(baseline, forward.rmse)
710                });
711                for (pair_idx, _) in cd_pairs.iter().enumerate() {
712                    report(
713                        "cd",
714                        "Computing directed causal pairs",
715                        window_idx + 1,
716                        pair_idx + 1,
717                        cd_pairs.len(),
718                        "pair",
719                        cd_pair_labels
720                            .as_ref()
721                            .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
722                    );
723                }
724                if let Some(matrix) = cd_matrix.as_mut() {
725                    for (pair_idx, value) in cd_values.into_iter().enumerate() {
726                        matrix[pair_idx][window_idx] = value;
727                    }
728                }
729            }
730
731            if enabled_ccd_core {
732                let ccd_values = solve_channels_parallel(
733                    &ccd_pairs
734                        .iter()
735                        .zip(ccd_pair_conditioning_sets.iter())
736                        .collect::<Vec<_>>(),
737                    |(pair, confounds)| {
738                        let baseline = solve_channel_with_inputs(
739                            prepared,
740                            pair[0],
741                            confounds,
742                            &model.primary_terms,
743                            &model.secondary_terms,
744                            model.window_length,
745                            self.options.svd_backend,
746                        );
747                        let mut conditioned_inputs = (*confounds).clone();
748                        conditioned_inputs.push(pair[1]);
749                        let conditioned = solve_channel_with_inputs(
750                            prepared,
751                            pair[0],
752                            &conditioned_inputs,
753                            &model.primary_terms,
754                            &model.secondary_terms,
755                            model.window_length,
756                            self.options.svd_backend,
757                        );
758                        let observed =
759                            conditional_causal_improvement(baseline.rmse, conditioned.rmse);
760
761                        let significance = if enabled_ccdsig {
762                            let surrogate_shifts =
763                                ccd_surrogate_shifts.clone().unwrap_or_else(|| {
764                                    default_surrogate_shifts(prepared.shifted.len())
765                                });
766                            let surrogate_inputs = confounds
767                                .iter()
768                                .map(|channel| extract_shifted_channel_series(prepared, *channel))
769                                .collect::<Vec<_>>();
770                            let source_series = extract_shifted_channel_series(prepared, pair[1]);
771                            let null_scores = surrogate_shifts
772                                .into_iter()
773                                .filter(|shift| *shift > 0)
774                                .map(|shift| {
775                                    let shifted_source =
776                                        circular_shift_series(&source_series, shift);
777                                    let mut conditioned_surrogates = surrogate_inputs.clone();
778                                    conditioned_surrogates.push(shifted_source);
779                                    let surrogate_block = solve_channel_with_surrogate_inputs(
780                                        prepared,
781                                        pair[0],
782                                        &conditioned_surrogates,
783                                        &model.primary_terms,
784                                        &model.secondary_terms,
785                                        model.window_length,
786                                        self.options.svd_backend,
787                                    );
788                                    conditional_causal_improvement(
789                                        baseline.rmse,
790                                        surrogate_block.rmse,
791                                    )
792                                })
793                                .collect::<Vec<_>>();
794                            empirical_significance_confidence(observed, &null_scores)
795                        } else {
796                            f64::NAN
797                        };
798
799                        (observed, significance)
800                    },
801                );
802                for (pair_idx, _) in ccd_pairs.iter().enumerate() {
803                    report(
804                        "ccd",
805                        "Computing conditional directed causal pairs",
806                        window_idx + 1,
807                        pair_idx + 1,
808                        ccd_pairs.len(),
809                        "pair",
810                        ccd_pair_labels
811                            .as_ref()
812                            .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
813                    );
814                }
815                if let Some(matrix) = ccd_matrix.as_mut() {
816                    for (pair_idx, value) in ccd_values.iter().enumerate() {
817                        matrix[pair_idx][window_idx] = value.0;
818                    }
819                }
820                if let Some(matrix) = ccdsig_matrix.as_mut() {
821                    for (pair_idx, value) in ccd_values.into_iter().enumerate() {
822                        matrix[pair_idx][window_idx] = value.1;
823                    }
824                }
825            }
826
827            if enabled_mvccd {
828                let mvccd_values = compute_mvccd_window_scores(
829                    prepared,
830                    &ccd_pairs,
831                    &ccd_target_conditioning_sets,
832                    &model,
833                    ccd_max_active_sources,
834                    self.options.svd_backend,
835                );
836                for (pair_idx, _) in ccd_pairs.iter().enumerate() {
837                    report(
838                        "mvccd",
839                        "Computing sparse multivariate conditional pairs",
840                        window_idx + 1,
841                        pair_idx + 1,
842                        ccd_pairs.len(),
843                        "pair",
844                        ccd_pair_labels
845                            .as_ref()
846                            .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
847                    );
848                }
849                if let Some(matrix) = mvccd_matrix.as_mut() {
850                    for (pair_idx, value) in mvccd_values.into_iter().enumerate() {
851                        matrix[pair_idx][window_idx] = value;
852                    }
853                }
854            }
855
856            if let Some(matrix) = sy_matrix.as_mut() {
857                let sy_values = solve_channels_parallel(&sy_pairs, |pair| {
858                    let forward = solve_directed_pair(
859                        &prepared,
860                        pair[0],
861                        pair[1],
862                        pair[1],
863                        &model.primary_terms,
864                        &model.secondary_terms,
865                        model.window_length,
866                        self.options.svd_backend,
867                    );
868                    let reverse = solve_directed_pair(
869                        &prepared,
870                        pair[1],
871                        pair[0],
872                        pair[0],
873                        &model.primary_terms,
874                        &model.secondary_terms,
875                        model.window_length,
876                        self.options.svd_backend,
877                    );
878                    (forward.rmse, reverse.rmse)
879                });
880                for (pair_idx, _) in sy_pairs.iter().enumerate() {
881                    report(
882                        "sy",
883                        "Computing synchronization directions",
884                        window_idx + 1,
885                        pair_idx * 2 + 1,
886                        sy_pairs.len() * 2,
887                        "direction",
888                        sy_forward_labels
889                            .as_ref()
890                            .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
891                    );
892                    report(
893                        "sy",
894                        "Computing synchronization directions",
895                        window_idx + 1,
896                        pair_idx * 2 + 2,
897                        sy_pairs.len() * 2,
898                        "direction",
899                        sy_reverse_labels
900                            .as_ref()
901                            .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
902                    );
903                }
904                for (pair_idx, (forward_rmse, reverse_rmse)) in sy_values.into_iter().enumerate() {
905                    if variant_mode.sy_mode == 2 {
906                        let row_base = pair_idx * 2;
907                        matrix[row_base][window_idx] = forward_rmse;
908                        matrix[row_base + 1][window_idx] = reverse_rmse;
909                    } else {
910                        matrix[pair_idx][window_idx] =
911                            synchronization_value(1, forward_rmse, reverse_rmse);
912                    }
913                }
914            }
915        }
916
917        if enabled_trccd {
918            if let Some(matrix) = trccd_matrix.as_mut() {
919                let windows = prepared_windows.as_deref().unwrap_or(&[]);
920                let regularized = compute_trccd_matrix(
921                    windows,
922                    &ccd_pairs,
923                    &ccd_pair_conditioning_sets,
924                    &model,
925                    ccd_temporal_lambda,
926                    self.options.svd_backend,
927                );
928                for (pair_idx, row) in regularized.into_iter().enumerate() {
929                    report(
930                        "trccd",
931                        "Computing temporally regularized conditional pairs",
932                        num_windows,
933                        pair_idx + 1,
934                        ccd_pairs.len(),
935                        "pair",
936                        ccd_pair_labels
937                            .as_ref()
938                            .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
939                    );
940                    matrix[pair_idx] = row;
941                }
942            }
943        }
944
945        if enabled_ccdstab {
946            if let Some(base_ccd) = ccd_matrix.as_ref() {
947                let stability = self.compute_ccd_stability_matrix(
948                    request,
949                    samples,
950                    channel_labels,
951                    &native_window_markers,
952                    &ccd_pairs,
953                    base_ccd,
954                )?;
955                if let Some(matrix) = ccdstab_matrix.as_mut() {
956                    for (pair_idx, row) in stability.into_iter().enumerate() {
957                        report(
958                            "ccdstab",
959                            "Computing conditional-pair stability",
960                            num_windows,
961                            pair_idx + 1,
962                            ccd_pairs.len(),
963                            "pair",
964                            ccd_pair_labels
965                                .as_ref()
966                                .and_then(|labels| labels.get(pair_idx).map(String::as_str)),
967                        );
968                        matrix[pair_idx] = row;
969                    }
970                }
971            }
972        }
973
974        let mut variant_results = Vec::new();
975        if let Some(q_matrix) = st_matrix {
976            variant_results.push(VariantResult {
977                variant_id: "ST".to_string(),
978                variant_name: "Single Timeseries (ST)".to_string(),
979                q_matrix,
980                channel_labels: Some(labels_for_channels(&dataset.channel_labels, &st_channels)),
981                error_values: Some(native_window_markers.clone()),
982            });
983        }
984        if let Some(q_matrix) = ct_matrix {
985            variant_results.push(VariantResult {
986                variant_id: "CT".to_string(),
987                variant_name: "Cross-Timeseries (CT)".to_string(),
988                q_matrix,
989                channel_labels: Some(labels_for_groups(&dataset.channel_labels, &ct_groups, "&")),
990                error_values: Some(native_window_markers.clone()),
991            });
992        }
993        if let Some(q_matrix) = cd_matrix {
994            variant_results.push(VariantResult {
995                variant_id: "CD".to_string(),
996                variant_name: "Cross-Dynamical (CD)".to_string(),
997                q_matrix,
998                channel_labels: Some(labels_for_pairs(&dataset.channel_labels, &cd_pairs, " <- ")),
999                error_values: Some(native_window_markers.clone()),
1000            });
1001        }
1002        if enabled_ccd {
1003            if let Some(q_matrix) = ccd_matrix.clone() {
1004                variant_results.push(VariantResult {
1005                    variant_id: "CCD".to_string(),
1006                    variant_name: "Conditional Cross-Dynamical (CCD)".to_string(),
1007                    q_matrix,
1008                    channel_labels: Some(labels_for_pairs(
1009                        &dataset.channel_labels,
1010                        &ccd_pairs,
1011                        " <- ",
1012                    )),
1013                    error_values: Some(native_window_markers.clone()),
1014                });
1015            }
1016        }
1017        if let Some(q_matrix) = ccdsig_matrix {
1018            variant_results.push(VariantResult {
1019                variant_id: "CCDSIG".to_string(),
1020                variant_name: "Conditional Cross-Dynamical Significance (CCDSIG)".to_string(),
1021                q_matrix,
1022                channel_labels: Some(labels_for_pairs(
1023                    &dataset.channel_labels,
1024                    &ccd_pairs,
1025                    " <- ",
1026                )),
1027                error_values: Some(native_window_markers.clone()),
1028            });
1029        }
1030        if let Some(q_matrix) = ccdstab_matrix {
1031            variant_results.push(VariantResult {
1032                variant_id: "CCDSTAB".to_string(),
1033                variant_name: "Conditional Cross-Dynamical Stability (CCDSTAB)".to_string(),
1034                q_matrix,
1035                channel_labels: Some(labels_for_pairs(
1036                    &dataset.channel_labels,
1037                    &ccd_pairs,
1038                    " <- ",
1039                )),
1040                error_values: Some(native_window_markers.clone()),
1041            });
1042        }
1043        if let Some(q_matrix) = trccd_matrix {
1044            variant_results.push(VariantResult {
1045                variant_id: "TRCCD".to_string(),
1046                variant_name: "Temporally Regularized Conditional Cross-Dynamical (TRCCD)"
1047                    .to_string(),
1048                q_matrix,
1049                channel_labels: Some(labels_for_pairs(
1050                    &dataset.channel_labels,
1051                    &ccd_pairs,
1052                    " <- ",
1053                )),
1054                error_values: Some(native_window_markers.clone()),
1055            });
1056        }
1057        if let Some(q_matrix) = mvccd_matrix {
1058            variant_results.push(VariantResult {
1059                variant_id: "MVCCD".to_string(),
1060                variant_name: "Sparse Multivariate Conditional Cross-Dynamical (MVCCD)".to_string(),
1061                q_matrix,
1062                channel_labels: Some(labels_for_pairs(
1063                    &dataset.channel_labels,
1064                    &ccd_pairs,
1065                    " <- ",
1066                )),
1067                error_values: Some(native_window_markers.clone()),
1068            });
1069        }
1070        if let Some(q_matrix) = de_matrix {
1071            variant_results.push(VariantResult {
1072                variant_id: "DE".to_string(),
1073                variant_name: "Dynamical Ergodicity (DE)".to_string(),
1074                q_matrix,
1075                channel_labels: Some(labels_for_groups(&dataset.channel_labels, &de_groups, "&")),
1076                error_values: Some(native_window_markers.clone()),
1077            });
1078        }
1079        if let Some(q_matrix) = sy_matrix {
1080            variant_results.push(VariantResult {
1081                variant_id: "SY".to_string(),
1082                variant_name: "Synchronization (SY)".to_string(),
1083                q_matrix,
1084                channel_labels: Some(labels_for_sy(
1085                    &dataset.channel_labels,
1086                    &sy_pairs,
1087                    variant_mode.sy_mode,
1088                )),
1089                error_values: Some(native_window_markers.clone()),
1090            });
1091        }
1092
1093        let primary_q = variant_results
1094            .first()
1095            .map(|variant| variant.q_matrix.clone())
1096            .unwrap_or_default();
1097
1098        Ok(DDAResult {
1099            id: Uuid::new_v4().to_string(),
1100            file_path: request.file_path.clone(),
1101            channels: dataset.channel_labels.clone(),
1102            q_matrix: primary_q,
1103            variant_results: Some(variant_results),
1104            raw_output: None,
1105            window_parameters: request.window_parameters.clone(),
1106            delay_parameters: request.delay_parameters.clone(),
1107            created_at: chrono::Utc::now().to_rfc3339(),
1108            error_values: Some(native_window_markers),
1109        })
1110    }
1111
1112    fn compute_ccd_stability_matrix(
1113        &self,
1114        request: &DDARequest,
1115        samples: &[Vec<f64>],
1116        channel_labels: Option<&[String]>,
1117        base_markers: &[f64],
1118        ccd_pairs: &[[usize; 2]],
1119        base_ccd: &[Vec<f64>],
1120    ) -> Result<Vec<Vec<f64>>> {
1121        let perturbed_requests = build_ccd_stability_requests(request);
1122        let mut runs = Vec::new();
1123        for perturbed in perturbed_requests {
1124            if let Ok(result) =
1125                self.run_on_matrix_internal(&perturbed, samples, channel_labels, None)
1126            {
1127                if let Some(variant) = result.variant_results.as_ref().and_then(|variants| {
1128                    variants.iter().find(|variant| variant.variant_id == "CCD")
1129                }) {
1130                    let markers = variant
1131                        .error_values
1132                        .clone()
1133                        .or_else(|| result.error_values.clone())
1134                        .unwrap_or_default();
1135                    runs.push((markers, variant.q_matrix.clone()));
1136                }
1137            }
1138        }
1139
1140        let mut stability = vec![vec![f64::NAN; base_markers.len()]; ccd_pairs.len()];
1141        if runs.is_empty() {
1142            return Ok(stability);
1143        }
1144
1145        for pair_idx in 0..ccd_pairs.len() {
1146            for (window_idx, marker) in base_markers.iter().enumerate() {
1147                let reference = base_ccd
1148                    .get(pair_idx)
1149                    .and_then(|row| row.get(window_idx))
1150                    .copied()
1151                    .unwrap_or(f64::NAN);
1152                if !reference.is_finite() {
1153                    continue;
1154                }
1155                let threshold = reference.abs().max(1e-9) * 0.5;
1156                let mut valid = 0usize;
1157                let mut support = 0usize;
1158                for (markers, matrix) in &runs {
1159                    let aligned = nearest_aligned_value(markers, matrix, pair_idx, *marker);
1160                    if let Some(value) = aligned.filter(|value| value.is_finite()) {
1161                        valid += 1;
1162                        if same_sign(reference, value) && value.abs() >= threshold {
1163                            support += 1;
1164                        }
1165                    }
1166                }
1167                if valid > 0 {
1168                    stability[pair_idx][window_idx] = support as f64 / valid as f64;
1169                }
1170            }
1171        }
1172
1173        Ok(stability)
1174    }
1175}
1176
1177fn extract_shifted_channel_series(prepared: &PreparedWindow, channel: usize) -> Vec<f64> {
1178    prepared
1179        .shifted
1180        .iter()
1181        .map(|row| row[channel])
1182        .collect::<Vec<_>>()
1183}
1184
1185fn prepare_window_for_analysis(
1186    dataset: &MatrixDataset<'_>,
1187    bounds: &AnalysisBounds,
1188    model: &ModelSpec,
1189    window_idx: usize,
1190    options: &PureRustOptions,
1191) -> Result<PreparedWindow> {
1192    let native_window_marker = model.window_length + model.max_delay + 2 * model.dm;
1193    let slice_start = bounds.start + window_idx * model.window_step;
1194    let slice_end = slice_start + native_window_marker;
1195    let padded_window = if slice_end <= dataset.rows {
1196        None
1197    } else {
1198        let available = dataset.samples[slice_start..dataset.rows].to_vec();
1199        let filler = available
1200            .last()
1201            .and_then(|row| row.last())
1202            .copied()
1203            .unwrap_or(f64::NAN);
1204        let mut padded = available;
1205        while padded.len() < native_window_marker {
1206            padded.push(vec![filler; dataset.cols]);
1207        }
1208        Some(padded)
1209    };
1210    let raw_window = padded_window
1211        .as_deref()
1212        .unwrap_or(&dataset.samples[slice_start..slice_end.min(dataset.rows)]);
1213    PreparedWindow::from_raw(raw_window, model, options)
1214}
1215
1216fn compute_ccd_pair_conditioning_sets(
1217    prepared_windows: Option<&[PreparedWindow]>,
1218    ccd_pairs: &[[usize; 2]],
1219    candidate_channels: &[usize],
1220    strategy: CcdConditioningStrategy,
1221    model: &ModelSpec,
1222    auto_cap: usize,
1223    svd_backend: SvdBackend,
1224) -> Vec<Vec<usize>> {
1225    match strategy {
1226        CcdConditioningStrategy::AllSelected => ccd_pairs
1227            .iter()
1228            .map(|pair| {
1229                candidate_channels
1230                    .iter()
1231                    .copied()
1232                    .filter(|channel| *channel != pair[0] && *channel != pair[1])
1233                    .collect::<Vec<_>>()
1234            })
1235            .collect(),
1236        CcdConditioningStrategy::AutoTargetSparse
1237        | CcdConditioningStrategy::AutoSharedParents
1238        | CcdConditioningStrategy::AutoGroupOmp => {
1239            let Some(windows) = prepared_windows else {
1240                return ccd_pairs
1241                    .iter()
1242                    .map(|pair| {
1243                        candidate_channels
1244                            .iter()
1245                            .copied()
1246                            .filter(|channel| *channel != pair[0] && *channel != pair[1])
1247                            .collect::<Vec<_>>()
1248                    })
1249                    .collect();
1250            };
1251            ccd_pairs
1252                .iter()
1253                .map(|pair| {
1254                    auto_select_conditioning_channels_for_pair(
1255                        windows,
1256                        pair[0],
1257                        pair[1],
1258                        candidate_channels,
1259                        strategy,
1260                        model,
1261                        auto_cap,
1262                        svd_backend,
1263                    )
1264                })
1265                .collect()
1266        }
1267    }
1268}
1269
1270fn auto_select_conditioning_channels_for_pair(
1271    prepared_windows: &[PreparedWindow],
1272    target: usize,
1273    source: usize,
1274    candidate_channels: &[usize],
1275    strategy: CcdConditioningStrategy,
1276    model: &ModelSpec,
1277    auto_cap: usize,
1278    svd_backend: SvdBackend,
1279) -> Vec<usize> {
1280    let usable_candidates = candidate_channels
1281        .iter()
1282        .copied()
1283        .filter(|channel| *channel != target && *channel != source)
1284        .filter(|channel| channel_is_usable(prepared_windows, *channel))
1285        .collect::<Vec<_>>();
1286    if usable_candidates.is_empty() {
1287        return Vec::new();
1288    }
1289
1290    if matches!(strategy, CcdConditioningStrategy::AutoGroupOmp) {
1291        return omp_select_conditioning_subset(
1292            prepared_windows,
1293            target,
1294            &usable_candidates,
1295            model,
1296            auto_cap.max(1),
1297            svd_backend,
1298        );
1299    }
1300
1301    let target_scores = aggregate_parent_support_scores(
1302        prepared_windows,
1303        target,
1304        &usable_candidates,
1305        model,
1306        auto_cap,
1307        svd_backend,
1308    );
1309    let ranked = match strategy {
1310        CcdConditioningStrategy::AutoTargetSparse => rank_channels_by_scores(&target_scores),
1311        CcdConditioningStrategy::AutoSharedParents => {
1312            let source_scores = aggregate_parent_support_scores(
1313                prepared_windows,
1314                source,
1315                &usable_candidates,
1316                model,
1317                auto_cap,
1318                svd_backend,
1319            );
1320            let mut shared = target_scores
1321                .iter()
1322                .map(|(channel, score)| {
1323                    (
1324                        *channel,
1325                        score.min(source_scores.get(channel).copied().unwrap_or(0.0)),
1326                    )
1327                })
1328                .filter(|(_, score)| score.is_finite() && *score > 0.0)
1329                .collect::<Vec<_>>();
1330            shared.sort_by(|left, right| {
1331                right
1332                    .1
1333                    .partial_cmp(&left.1)
1334                    .unwrap_or(std::cmp::Ordering::Equal)
1335                    .then_with(|| left.0.cmp(&right.0))
1336            });
1337            shared.into_iter().map(|(channel, _)| channel).collect()
1338        }
1339        CcdConditioningStrategy::AutoGroupOmp => usable_candidates,
1340        CcdConditioningStrategy::AllSelected => usable_candidates,
1341    };
1342
1343    greedy_select_conditioning_subset(
1344        prepared_windows,
1345        target,
1346        &ranked,
1347        model,
1348        auto_cap.max(1),
1349        svd_backend,
1350    )
1351}
1352
1353fn channel_is_usable(prepared_windows: &[PreparedWindow], channel: usize) -> bool {
1354    prepared_windows.iter().any(|prepared| {
1355        prepared
1356            .shifted
1357            .iter()
1358            .any(|row| row.get(channel).copied().unwrap_or(f64::NAN).is_finite())
1359            && prepared
1360                .deriv
1361                .get(channel)
1362                .map(|values| values.iter().any(|value| value.is_finite()))
1363                .unwrap_or(false)
1364    })
1365}
1366
1367fn aggregate_parent_support_scores(
1368    prepared_windows: &[PreparedWindow],
1369    target: usize,
1370    candidate_channels: &[usize],
1371    model: &ModelSpec,
1372    auto_cap: usize,
1373    svd_backend: SvdBackend,
1374) -> std::collections::BTreeMap<usize, f64> {
1375    let mut sums = std::collections::BTreeMap::<usize, f64>::new();
1376    let mut counts = std::collections::BTreeMap::<usize, usize>::new();
1377    for prepared in prepared_windows {
1378        for (channel, improvement) in greedy_sparse_unique_improvements(
1379            prepared,
1380            target,
1381            candidate_channels,
1382            &[],
1383            &model.primary_terms,
1384            &model.secondary_terms,
1385            model.window_length,
1386            auto_cap.max(1),
1387            svd_backend,
1388        ) {
1389            if improvement.is_finite() && improvement > 0.0 {
1390                *sums.entry(channel).or_insert(0.0) += improvement;
1391                *counts.entry(channel).or_insert(0) += 1;
1392            }
1393        }
1394    }
1395    candidate_channels
1396        .iter()
1397        .copied()
1398        .map(|channel| {
1399            let score = match (sums.get(&channel), counts.get(&channel)) {
1400                (Some(sum), Some(count)) if *count > 0 => *sum / (*count as f64),
1401                _ => 0.0,
1402            };
1403            (channel, score)
1404        })
1405        .collect()
1406}
1407
/// Order channels by descending score, keeping only finite, strictly positive scores.
///
/// Ties are broken by ascending channel index, so the ranking is fully deterministic.
fn rank_channels_by_scores(scores: &std::collections::BTreeMap<usize, f64>) -> Vec<usize> {
    let mut candidates: Vec<(usize, f64)> = scores
        .iter()
        .filter_map(|(&channel, &score)| {
            (score.is_finite() && score > 0.0).then_some((channel, score))
        })
        .collect();
    // Scores are finite here, so `total_cmp` orders them exactly like `partial_cmp` would.
    candidates.sort_by(|(left_channel, left_score), (right_channel, right_score)| {
        right_score
            .total_cmp(left_score)
            .then(left_channel.cmp(right_channel))
    });
    candidates.into_iter().map(|(channel, _)| channel).collect()
}
1423
1424fn greedy_select_conditioning_subset(
1425    prepared_windows: &[PreparedWindow],
1426    target: usize,
1427    ranked_candidates: &[usize],
1428    model: &ModelSpec,
1429    auto_cap: usize,
1430    svd_backend: SvdBackend,
1431) -> Vec<usize> {
1432    let mut selected = Vec::new();
1433    let mut current_score = average_conditioned_baseline_score(
1434        prepared_windows,
1435        target,
1436        &selected,
1437        model,
1438        svd_backend,
1439    );
1440    for &candidate in ranked_candidates.iter().take(auto_cap) {
1441        let mut trial = selected.clone();
1442        trial.push(candidate);
1443        let trial_score = average_conditioned_baseline_score(
1444            prepared_windows,
1445            target,
1446            &trial,
1447            model,
1448            svd_backend,
1449        );
1450        if trial_score + 1e-9 < current_score {
1451            selected = trial;
1452            current_score = trial_score;
1453        }
1454    }
1455    selected
1456}
1457
1458fn omp_select_conditioning_subset(
1459    prepared_windows: &[PreparedWindow],
1460    target: usize,
1461    candidate_channels: &[usize],
1462    model: &ModelSpec,
1463    auto_cap: usize,
1464    svd_backend: SvdBackend,
1465) -> Vec<usize> {
1466    let mut selected = Vec::<usize>::new();
1467    let mut remaining = candidate_channels.to_vec();
1468    let mut current_score = average_conditioned_baseline_score(
1469        prepared_windows,
1470        target,
1471        &selected,
1472        model,
1473        svd_backend,
1474    );
1475
1476    for _ in 0..auto_cap.min(remaining.len()) {
1477        let mut best_candidate = None::<(usize, f64, f64)>;
1478        for &candidate in &remaining {
1479            let mut trial = selected.clone();
1480            trial.push(candidate);
1481            let trial_rmse = average_conditioned_baseline_rmse(
1482                prepared_windows,
1483                target,
1484                &trial,
1485                model,
1486                svd_backend,
1487            );
1488            let trial_score = average_conditioned_baseline_score(
1489                prepared_windows,
1490                target,
1491                &trial,
1492                model,
1493                svd_backend,
1494            );
1495            let required_gain = 1e-4 * current_score.abs().max(1.0);
1496            if current_score - trial_score <= required_gain {
1497                continue;
1498            }
1499            let take = match best_candidate {
1500                None => true,
1501                Some((best_channel, best_rmse, best_score)) => {
1502                    trial_rmse < best_rmse - 1e-12
1503                        || ((trial_rmse - best_rmse).abs() <= 1e-12
1504                            && (trial_score < best_score - 1e-12
1505                                || ((trial_score - best_score).abs() <= 1e-12
1506                                    && candidate < best_channel)))
1507                }
1508            };
1509            if take {
1510                best_candidate = Some((candidate, trial_rmse, trial_score));
1511            }
1512        }
1513
1514        let Some((candidate, _trial_rmse, trial_score)) = best_candidate else {
1515            break;
1516        };
1517        selected.push(candidate);
1518        remaining.retain(|channel| *channel != candidate);
1519        current_score = trial_score;
1520    }
1521
1522    selected.sort_unstable();
1523    selected
1524}
1525
1526fn average_conditioned_baseline_score(
1527    prepared_windows: &[PreparedWindow],
1528    target: usize,
1529    confounds: &[usize],
1530    model: &ModelSpec,
1531    svd_backend: SvdBackend,
1532) -> f64 {
1533    let (window_scores, _) = conditioned_baseline_window_metrics(
1534        prepared_windows,
1535        target,
1536        confounds,
1537        model,
1538        svd_backend,
1539    );
1540    finite_mean(&window_scores).unwrap_or(f64::INFINITY)
1541}
1542
1543fn average_conditioned_baseline_rmse(
1544    prepared_windows: &[PreparedWindow],
1545    target: usize,
1546    confounds: &[usize],
1547    model: &ModelSpec,
1548    svd_backend: SvdBackend,
1549) -> f64 {
1550    let (_, window_rmses) = conditioned_baseline_window_metrics(
1551        prepared_windows,
1552        target,
1553        confounds,
1554        model,
1555        svd_backend,
1556    );
1557    finite_mean(&window_rmses).unwrap_or(f64::INFINITY)
1558}
1559
1560fn conditioned_baseline_window_metrics(
1561    prepared_windows: &[PreparedWindow],
1562    target: usize,
1563    confounds: &[usize],
1564    model: &ModelSpec,
1565    svd_backend: SvdBackend,
1566) -> (Vec<f64>, Vec<f64>) {
1567    let parameter_count = model.primary_terms.len() + confounds.len() * model.secondary_terms.len();
1568    let mut window_scores = Vec::with_capacity(prepared_windows.len());
1569    let mut window_rmses = Vec::with_capacity(prepared_windows.len());
1570    for prepared in prepared_windows {
1571        let block = solve_channel_with_inputs(
1572            prepared,
1573            target,
1574            confounds,
1575            &model.primary_terms,
1576            &model.secondary_terms,
1577            model.window_length,
1578            svd_backend,
1579        );
1580        let rmse = block.rmse;
1581        let score = bic_like_score(rmse, model.window_length, parameter_count);
1582        window_scores.push(score);
1583        window_rmses.push(rmse);
1584    }
1585    (window_scores, window_rmses)
1586}
1587
/// Arithmetic mean of the finite entries of `values`.
///
/// NaN and ±infinity are ignored. Returns `None` when `values` contains no
/// finite entry at all, including when it is empty.
fn finite_mean(values: &[f64]) -> Option<f64> {
    let (sum, finite_count) = values
        .iter()
        .copied()
        .filter(|value| value.is_finite())
        .fold((0.0, 0usize), |(acc, n), value| (acc + value, n + 1));
    if finite_count == 0 {
        None
    } else {
        Some(sum / finite_count as f64)
    }
}
1599
/// Merges per-pair conditioning sets into one conditioning set per target.
///
/// `ccd_pairs[i][0]` is taken as the target of pair `i`. The channels in
/// `pair_conditioning_sets[i]` are added to that target's set. The two slices
/// are zipped, so trailing entries of the longer slice are ignored.
///
/// Each returned channel list is sorted ascending and free of duplicates. A
/// target is present in the map even if all of its conditioning sets were
/// empty.
fn build_target_conditioning_sets(
    ccd_pairs: &[[usize; 2]],
    pair_conditioning_sets: &[Vec<usize>],
) -> std::collections::BTreeMap<usize, Vec<usize>> {
    use std::collections::{BTreeMap, BTreeSet};

    let mut by_target: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
    for (pair, pair_confounds) in ccd_pairs.iter().zip(pair_conditioning_sets) {
        by_target
            .entry(pair[0])
            .or_default()
            .extend(pair_confounds.iter().copied());
    }

    let mut merged = BTreeMap::new();
    for (target, channel_set) in by_target {
        merged.insert(target, channel_set.into_iter().collect::<Vec<_>>());
    }
    merged
}
1618
/// Default shift offsets for a series of length `series_len`. These are
/// presumably consumed as circular-shift amounts for surrogate series (TODO
/// confirm at call site).
///
/// The offsets are `series_len * p / q` for the fractions 1/6, 1/4, 1/3, 1/2
/// and 2/3, using integer division. Only offsets strictly inside
/// `(0, series_len)` are kept. The result is sorted ascending with duplicates
/// removed. Series shorter than 8 samples get no shifts.
fn default_surrogate_shifts(series_len: usize) -> Vec<usize> {
    const FRACTIONS: [(usize, usize); 5] = [(1, 6), (1, 4), (1, 3), (1, 2), (2, 3)];

    if series_len < 8 {
        return Vec::new();
    }
    let mut offsets = FRACTIONS
        .iter()
        .map(|&(numerator, denominator)| (numerator * series_len) / denominator)
        .filter(|&offset| offset > 0 && offset < series_len)
        .collect::<Vec<usize>>();
    offsets.sort_unstable();
    offsets.dedup();
    offsets
}
1635
1636fn compute_mvccd_window_scores(
1637    prepared: &PreparedWindow,
1638    ccd_pairs: &[[usize; 2]],
1639    target_conditioning_sets: &std::collections::BTreeMap<usize, Vec<usize>>,
1640    model: &ModelSpec,
1641    max_active_sources: usize,
1642    svd_backend: SvdBackend,
1643) -> Vec<f64> {
1644    use std::collections::{BTreeMap, BTreeSet};
1645
1646    let mut pairs_by_target: BTreeMap<usize, Vec<usize>> = BTreeMap::new();
1647    for (pair_idx, pair) in ccd_pairs.iter().enumerate() {
1648        pairs_by_target.entry(pair[0]).or_default().push(pair_idx);
1649    }
1650
1651    let mut values = vec![0.0; ccd_pairs.len()];
1652    for (target, pair_indices) in pairs_by_target {
1653        let candidate_sources = pair_indices
1654            .iter()
1655            .map(|pair_idx| ccd_pairs[*pair_idx][1])
1656            .collect::<BTreeSet<_>>()
1657            .into_iter()
1658            .collect::<Vec<_>>();
1659        let fixed_inputs = target_conditioning_sets
1660            .get(&target)
1661            .into_iter()
1662            .flat_map(|channels| channels.iter().copied())
1663            .filter(|channel| *channel != target && !candidate_sources.contains(channel))
1664            .collect::<Vec<_>>();
1665        let improvements = greedy_sparse_unique_improvements(
1666            prepared,
1667            target,
1668            &candidate_sources,
1669            &fixed_inputs,
1670            &model.primary_terms,
1671            &model.secondary_terms,
1672            model.window_length,
1673            max_active_sources,
1674            svd_backend,
1675        );
1676        for pair_idx in pair_indices {
1677            let source = ccd_pairs[pair_idx][1];
1678            values[pair_idx] = improvements
1679                .iter()
1680                .find(|(candidate, _)| *candidate == source)
1681                .map(|(_, value)| *value)
1682                .unwrap_or(0.0);
1683        }
1684    }
1685    values
1686}
1687
1688fn compute_trccd_matrix(
1689    prepared_windows: &[PreparedWindow],
1690    ccd_pairs: &[[usize; 2]],
1691    pair_conditioning_sets: &[Vec<usize>],
1692    model: &ModelSpec,
1693    lambda: f64,
1694    svd_backend: SvdBackend,
1695) -> Vec<Vec<f64>> {
1696    solve_channels_parallel(
1697        &ccd_pairs
1698            .iter()
1699            .zip(pair_conditioning_sets.iter())
1700            .collect::<Vec<_>>(),
1701        |(pair, confounds)| {
1702            let conditioned_inputs = {
1703                let mut inputs = (*confounds).clone();
1704                inputs.push(pair[1]);
1705                inputs
1706            };
1707            let baseline_windows = prepared_windows
1708                .iter()
1709                .map(|prepared| {
1710                    build_channel_regression_window_with_inputs(
1711                        prepared,
1712                        pair[0],
1713                        &confounds,
1714                        &model.primary_terms,
1715                        &model.secondary_terms,
1716                        model.window_length,
1717                    )
1718                })
1719                .collect::<Vec<_>>();
1720            let conditioned_windows = prepared_windows
1721                .iter()
1722                .map(|prepared| {
1723                    build_channel_regression_window_with_inputs(
1724                        prepared,
1725                        pair[0],
1726                        &conditioned_inputs,
1727                        &model.primary_terms,
1728                        &model.secondary_terms,
1729                        model.window_length,
1730                    )
1731                })
1732                .collect::<Vec<_>>();
1733            let baseline_blocks =
1734                solve_temporally_regularized_windows(&baseline_windows, lambda, svd_backend);
1735            let conditioned_blocks = solve_temporally_regularized_windows(
1736                &conditioned_windows,
1737                lambda,
1738                svd_backend,
1739            );
1740            baseline_blocks
1741                .iter()
1742                .zip(conditioned_blocks.iter())
1743                .map(|(baseline, conditioned)| {
1744                    conditional_causal_improvement(baseline.rmse, conditioned.rmse)
1745                })
1746                .collect::<Vec<_>>()
1747        },
1748    )
1749}
1750
1751fn build_ccd_stability_requests(request: &DDARequest) -> Vec<DDARequest> {
1752    let mut requests = Vec::new();
1753    let mut base = request.clone();
1754    base.algorithm_selection.enabled_variants = vec!["CCD".to_string()];
1755    base.algorithm_selection.select_mask = None;
1756
1757    let base_wl = base.window_parameters.window_length.max(32);
1758    let base_ws = base.window_parameters.window_step.max(1);
1759    let mut delays = base.delay_parameters.delays.clone();
1760    if delays.is_empty() {
1761        delays = crate::types::DEFAULT_DELAYS.to_vec();
1762    }
1763
1764    let mut shorter = base.clone();
1765    shorter.window_parameters.window_length = (base_wl.saturating_mul(4) / 5).max(32);
1766    shorter.window_parameters.window_step = (base_ws.saturating_mul(4) / 5).max(1);
1767    requests.push(shorter);
1768
1769    let mut longer = base.clone();
1770    longer.window_parameters.window_length = (base_wl.saturating_mul(6) / 5).max(base_wl + 1);
1771    longer.window_parameters.window_step = (base_ws.saturating_mul(6) / 5).max(base_ws + 1);
1772    requests.push(longer);
1773
1774    if delays.iter().all(|delay| *delay > 0) {
1775        let mut lower_delays = base.clone();
1776        lower_delays.delay_parameters.delays = delays.iter().map(|delay| delay - 1).collect();
1777        requests.push(lower_delays);
1778    }
1779
1780    let mut mixed = base;
1781    mixed.window_parameters.window_length = (base_wl.saturating_mul(4) / 5).max(32);
1782    mixed.window_parameters.window_step = (base_ws.saturating_mul(4) / 5).max(1);
1783    if delays.iter().all(|delay| *delay > 0) {
1784        mixed.delay_parameters.delays = delays.iter().map(|delay| delay - 1).collect();
1785    }
1786    requests.push(mixed);
1787
1788    dedup_stability_requests(requests)
1789}
1790
1791fn dedup_stability_requests(requests: Vec<DDARequest>) -> Vec<DDARequest> {
1792    use std::collections::BTreeSet;
1793
1794    let mut seen = BTreeSet::new();
1795    let mut deduped = Vec::new();
1796    for request in requests {
1797        let key = (
1798            request.window_parameters.window_length,
1799            request.window_parameters.window_step,
1800            request.delay_parameters.delays.clone(),
1801        );
1802        if seen.insert(key) {
1803            deduped.push(request);
1804        }
1805    }
1806    deduped
1807}
1808
/// Returns the entry of `matrix[row_idx]` whose column marker is closest to
/// `target_marker`.
///
/// `markers[i]` is treated as the position of column `i`. The column with the
/// smallest absolute distance `|markers[i] - target_marker|` is chosen, and
/// ties go to the lowest index.
///
/// Returns `None` when `row_idx` is out of range, when `markers` is empty, or
/// when the chosen column lies past the end of the row.
fn nearest_aligned_value(
    markers: &[f64],
    matrix: &[Vec<f64>],
    row_idx: usize,
    target_marker: f64,
) -> Option<f64> {
    let row = matrix.get(row_idx)?;
    let nearest_index = markers
        .iter()
        .map(|marker| (marker - target_marker).abs())
        .enumerate()
        // `total_cmp` orders the (sign-cleared) NaN distances of NaN markers
        // after every finite distance. A `partial_cmp(..).unwrap_or(Equal)`
        // comparison would treat NaN as equal to everything, letting a
        // leading NaN marker shadow the real nearest column.
        .min_by(|(_, left), (_, right)| left.total_cmp(right))
        .map(|(index, _)| index)?;
    row.get(nearest_index).copied()
}
1828
/// `true` when both values are strictly positive or both are strictly
/// negative.
///
/// Zero (including `-0.0`) and NaN never share a sign with anything.
fn same_sign(left: f64, right: f64) -> bool {
    if left > 0.0 {
        right > 0.0
    } else if left < 0.0 {
        right < 0.0
    } else {
        false
    }
}
1832
/// Runs `request` on an in-memory sample matrix with a default-configured
/// [`PureRustRunner`].
///
/// This is a convenience wrapper around `PureRustRunner::run_on_matrix`.
/// `samples` and `channel_labels` are forwarded unchanged.
/// NOTE(review): the expected layout of `samples` (channel-major or
/// sample-major) is defined by the runner. Confirm it there.
///
/// # Errors
/// Propagates any error returned by the runner.
pub fn run_request_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
) -> Result<DDAResult> {
    PureRustRunner::default().run_on_matrix(request, samples, channel_labels)
}
1840
/// Inspects the CCD conditioning sets for `request` on an in-memory sample
/// matrix, using a default-configured [`PureRustRunner`].
///
/// This is a convenience wrapper around
/// `PureRustRunner::inspect_ccd_conditioning_sets_on_matrix`. All arguments
/// are forwarded unchanged.
///
/// # Errors
/// Propagates any error returned by the runner.
pub fn inspect_ccd_conditioning_sets_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
) -> Result<CcdConditioningInspection> {
    PureRustRunner::default().inspect_ccd_conditioning_sets_on_matrix(
        request,
        samples,
        channel_labels,
    )
}
1852
/// Scores each candidate confound set in `confound_sets` for a single CCD
/// `pair`, using a default-configured [`PureRustRunner`].
///
/// This is a convenience wrapper around
/// `PureRustRunner::score_ccd_conditioning_subsets_on_matrix`. All arguments
/// are forwarded unchanged. `pair` is presumably `[target, source]`, matching
/// how CCD pairs are indexed elsewhere in this module. Confirm this in the
/// runner.
///
/// # Errors
/// Propagates any error returned by the runner.
pub fn score_ccd_conditioning_subsets_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
    pair: [usize; 2],
    confound_sets: &[Vec<usize>],
) -> Result<Vec<CcdConditioningSubsetScore>> {
    PureRustRunner::default().score_ccd_conditioning_subsets_on_matrix(
        request,
        samples,
        channel_labels,
        pair,
        confound_sets,
    )
}
1868
/// Profiles each candidate confound set in `confound_sets` for a single CCD
/// `pair`, using a default-configured [`PureRustRunner`].
///
/// This is a convenience wrapper around
/// `PureRustRunner::profile_ccd_conditioning_subsets_on_matrix`. All
/// arguments are forwarded unchanged. `pair` is presumably
/// `[target, source]`. Confirm this in the runner.
///
/// # Errors
/// Propagates any error returned by the runner.
pub fn profile_ccd_conditioning_subsets_on_matrix(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
    pair: [usize; 2],
    confound_sets: &[Vec<usize>],
) -> Result<Vec<CcdConditioningSubsetProfile>> {
    PureRustRunner::default().profile_ccd_conditioning_subsets_on_matrix(
        request,
        samples,
        channel_labels,
        pair,
        confound_sets,
    )
}
1884
/// Like [`run_request_on_matrix`], but forwards `on_progress` to the runner.
///
/// Builds a default-configured [`PureRustRunner`] and delegates to
/// `PureRustRunner::run_on_matrix_with_progress`. The runner presumably calls
/// `on_progress` with [`PureRustProgress`] snapshots while the analysis runs.
/// Confirm the call frequency in the runner.
///
/// # Errors
/// Propagates any error returned by the runner.
pub fn run_request_on_matrix_with_progress<F>(
    request: &DDARequest,
    samples: &[Vec<f64>],
    channel_labels: Option<&[String]>,
    on_progress: F,
) -> Result<DDAResult>
where
    F: FnMut(&PureRustProgress),
{
    PureRustRunner::default().run_on_matrix_with_progress(
        request,
        samples,
        channel_labels,
        on_progress,
    )
}