1use serde::{Deserialize, Serialize};
7
/// Length of a variant select mask. Each variant's `position` is an index
/// into a mask of this many slots (see `generate_select_mask`).
pub const SELECT_MASK_SIZE: usize = 6;

/// File name of the DDA executable that `find_binary` searches for.
pub const BINARY_NAME: &str = "run_DDA_AsciiEdf";

/// Whether the binary must be launched through a shell wrapper.
/// NOTE(review): the launcher that reads this is not in this file — confirm
/// it pairs with `SHELL_COMMAND`.
pub const REQUIRES_SHELL_WRAPPER: bool = true;

/// Shell program used as the wrapper (presumably when
/// `REQUIRES_SHELL_WRAPPER` is set — confirm with the launcher).
pub const SHELL_COMMAND: &str = "sh";

/// Platform identifiers the binary is declared to support.
pub const SUPPORTED_PLATFORMS: &[&str] = &["linux", "macos", "windows"];
26
27#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29#[serde(rename_all = "snake_case")]
30pub enum ChannelFormat {
31 Individual,
32 Pairs,
33 DirectedPairs,
34}
35
36impl ChannelFormat {
37 pub fn from_str(s: &str) -> Option<Self> {
38 match s {
39 "individual" => Some(Self::Individual),
40 "pairs" => Some(Self::Pairs),
41 "directed_pairs" => Some(Self::DirectedPairs),
42 _ => None,
43 }
44 }
45}
46
47#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
49pub struct OutputColumns {
50 pub coefficients: u8,
51 pub has_error: bool,
52}
53
54#[derive(Debug, Clone, Serialize)]
57pub struct VariantMetadata {
58 pub abbreviation: &'static str,
59 pub name: &'static str,
60 pub position: u8,
61 pub output_suffix: &'static str,
62 pub stride: u8,
63 pub reserved: bool,
64 #[serde(skip)]
65 pub required_params: &'static [&'static str],
66 pub channel_format: ChannelFormat,
67 pub output_columns: OutputColumns,
68 pub documentation: &'static str,
69}
70
71impl VariantMetadata {
72 pub fn from_abbrev(abbrev: &str) -> Option<&'static VariantMetadata> {
74 VARIANT_REGISTRY.iter().find(|v| v.abbreviation == abbrev)
75 }
76
77 pub fn from_suffix(suffix: &str) -> Option<&'static VariantMetadata> {
79 VARIANT_REGISTRY.iter().find(|v| v.output_suffix == suffix)
80 }
81
82 pub fn from_position(pos: u8) -> Option<&'static VariantMetadata> {
84 VARIANT_REGISTRY.iter().find(|v| v.position == pos)
85 }
86
87 pub fn active_variants() -> impl Iterator<Item = &'static VariantMetadata> {
89 VARIANT_REGISTRY.iter().filter(|v| !v.reserved)
90 }
91}
92
93pub const ST: VariantMetadata = VariantMetadata {
102 abbreviation: "ST",
103 name: "Single Timeseries",
104 position: 0,
105 output_suffix: "_ST",
106 stride: 4,
107 reserved: false,
108 required_params: &[],
109 channel_format: ChannelFormat::Individual,
110 output_columns: OutputColumns {
111 coefficients: 3,
112 has_error: true,
113 },
114 documentation: "Analyzes individual channels independently. Most basic variant. One result row per channel.",
115};
116
117
118pub const CT: VariantMetadata = VariantMetadata {
122 abbreviation: "CT",
123 name: "Cross-Timeseries",
124 position: 1,
125 output_suffix: "_CT",
126 stride: 4,
127 reserved: false,
128 required_params: &["-WL_CT", "-WS_CT"],
129 channel_format: ChannelFormat::Pairs,
130 output_columns: OutputColumns {
131 coefficients: 3,
132 has_error: true,
133 },
134 documentation: "Analyzes relationships between channel pairs. Symmetric: pair (1,2) equals (2,1). When enabled with ST, wrapper must run CT pairs separately.",
135};
136
137
138pub const CD: VariantMetadata = VariantMetadata {
142 abbreviation: "CD",
143 name: "Cross-Dynamical",
144 position: 2,
145 output_suffix: "_CD_DDA_ST",
146 stride: 2,
147 reserved: false,
148 required_params: &["-WL_CT", "-WS_CT"],
149 channel_format: ChannelFormat::DirectedPairs,
150 output_columns: OutputColumns {
151 coefficients: 1,
152 has_error: true,
153 },
154 documentation: "Analyzes directed causal relationships. Asymmetric: (1->2) differs from (2->1). CD is independent (no longer requires ST+CT).",
155};
156
157
158pub const RESERVED: VariantMetadata = VariantMetadata {
162 abbreviation: "RESERVED",
163 name: "Reserved",
164 position: 3,
165 output_suffix: "_RESERVED",
166 stride: 1,
167 reserved: true,
168 required_params: &[],
169 channel_format: ChannelFormat::Individual,
170 output_columns: OutputColumns {
171 coefficients: 0,
172 has_error: false,
173 },
174 documentation: "Internal development function. Should always be set to 0 in production.",
175};
176
177
178pub const DE: VariantMetadata = VariantMetadata {
182 abbreviation: "DE",
183 name: "Delay Embedding",
184 position: 4,
185 output_suffix: "_DE",
186 stride: 1,
187 reserved: false,
188 required_params: &["-WL_CT", "-WS_CT"],
189 channel_format: ChannelFormat::Individual,
190 output_columns: OutputColumns {
191 coefficients: 0,
192 has_error: false,
193 },
194 documentation: "Tests for ergodic behavior in dynamical systems. Produces single aggregate measure per time window (not per-channel).",
195};
196
197
198pub const SY: VariantMetadata = VariantMetadata {
202 abbreviation: "SY",
203 name: "Synchronization",
204 position: 5,
205 output_suffix: "_SY",
206 stride: 1,
207 reserved: false,
208 required_params: &[],
209 channel_format: ChannelFormat::Individual,
210 output_columns: OutputColumns {
211 coefficients: 0,
212 has_error: false,
213 },
214 documentation: "Detects synchronized behavior between signals. Produces one value per channel/measure per time window.",
215};
216
217
218
219pub const VARIANT_REGISTRY: &[VariantMetadata] = &[
221 ST,
222 CT,
223 CD,
224 RESERVED,
225 DE,
226 SY,
227];
228
229pub const VARIANT_ORDER: &[&str] = &[
231 "ST",
232 "CT",
233 "CD",
234 "RESERVED",
235 "DE",
236 "SY",
237];
238
239pub fn generate_select_mask(variants: &[&str]) -> [u8; SELECT_MASK_SIZE] {
245 let mut mask = [0u8; SELECT_MASK_SIZE];
246 for abbrev in variants {
247 if let Some(variant) = VariantMetadata::from_abbrev(abbrev) {
248 mask[variant.position as usize] = 1;
249 }
250 }
251 mask
252}
253
254pub fn parse_select_mask(mask: &[u8]) -> Vec<&'static str> {
256 mask.iter()
257 .enumerate()
258 .filter(|(_, &bit)| bit == 1)
259 .filter_map(|(pos, _)| VariantMetadata::from_position(pos as u8))
260 .filter(|v| !v.reserved)
261 .map(|v| v.abbreviation)
262 .collect()
263}
264
265pub fn format_select_mask(mask: &[u8; SELECT_MASK_SIZE]) -> String {
267 mask.iter()
268 .map(|b| b.to_string())
269 .collect::<Vec<_>>()
270 .join(" ")
271}
272
273pub mod select_mask_positions {
278 pub const ST: usize = 0;
280 pub const CT: usize = 1;
282 pub const CD: usize = 2;
284 pub const RESERVED: usize = 3;
286 pub const DE: usize = 4;
288 pub const SY: usize = 5;
290}
291
292#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
297pub enum FileType {
298 EDF,
299 ASCII,
300}
301
302impl FileType {
303 pub fn flag(&self) -> &'static str {
304 match self {
305 Self::EDF => "-EDF",
306 Self::ASCII => "-ASCII",
307 }
308 }
309
310 pub fn from_extension(ext: &str) -> Option<Self> {
311 match ext.to_lowercase().as_str() {
312 "edf" => Some(Self::EDF),
313 "ascii" => Some(Self::ASCII),
314 "txt" => Some(Self::ASCII),
315 "csv" => Some(Self::ASCII),
316 _ => None,
317 }
318 }
319}
320
/// Environment variable holding the full path of the DDA binary itself.
pub const BINARY_ENV_VAR: &str = "DDA_BINARY_PATH";

/// Environment variable naming an install root. `find_binary` looks for
/// `<root>/bin/<BINARY_NAME>`.
pub const BINARY_HOME_ENV_VAR: &str = "DDA_HOME";

/// Directories searched, in order, after the environment variables. A leading
/// `~/` is expanded against the user's home directory.
pub const DEFAULT_BINARY_PATHS: &[&str] = &[
    "~/.local/bin",
    "~/bin",
    "/usr/local/bin",
    "/opt/dda/bin",
];
338
339pub fn find_binary(explicit_path: Option<&str>) -> Option<std::path::PathBuf> {
353 use std::path::PathBuf;
354
355 fn expand_path(path: &str) -> PathBuf {
357 if path.starts_with("~/") {
358 if let Some(home) = std::env::var_os("HOME") {
359 return PathBuf::from(home).join(&path[2..]);
360 }
361 }
362 PathBuf::from(path)
363 }
364
365 if let Some(path) = explicit_path {
367 let p = expand_path(path);
368 if p.exists() {
369 return Some(p);
370 }
371 return None;
372 }
373
374 if let Ok(env_path) = std::env::var(BINARY_ENV_VAR) {
376 let p = expand_path(&env_path);
377 if p.exists() {
378 return Some(p);
379 }
380 }
381
382 if let Ok(home_path) = std::env::var(BINARY_HOME_ENV_VAR) {
384 let p = expand_path(&home_path).join("bin").join(BINARY_NAME);
385 if p.exists() {
386 return Some(p);
387 }
388 }
389
390 for search_path in DEFAULT_BINARY_PATHS {
392 let p = expand_path(search_path).join(BINARY_NAME);
393 if p.exists() {
394 return Some(p);
395 }
396 }
397
398 None
399}
400
401pub fn require_binary(explicit_path: Option<&str>) -> Result<std::path::PathBuf, String> {
405 find_binary(explicit_path).ok_or_else(|| {
406 format!(
407 "DDA binary '{}' not found. Set ${} or ${}, or install to one of: {:?}",
408 BINARY_NAME, BINARY_ENV_VAR, BINARY_HOME_ENV_VAR, DEFAULT_BINARY_PATHS
409 )
410 })
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_variant_registry_size() {
        assert_eq!(VARIANT_REGISTRY.len(), 6);
    }

    #[test]
    fn test_variant_lookup_by_abbrev() {
        assert!(VariantMetadata::from_abbrev("ST").is_some());
        assert!(VariantMetadata::from_abbrev("CT").is_some());
        assert!(VariantMetadata::from_abbrev("CD").is_some());
        assert!(VariantMetadata::from_abbrev("DE").is_some());
        assert!(VariantMetadata::from_abbrev("SY").is_some());
        assert!(VariantMetadata::from_abbrev("INVALID").is_none());
    }

    #[test]
    fn test_variant_lookup_by_suffix() {
        assert!(VariantMetadata::from_suffix("_ST").is_some());
        assert!(VariantMetadata::from_suffix("_CT").is_some());
        assert!(VariantMetadata::from_suffix("_CD_DDA_ST").is_some());
        assert!(VariantMetadata::from_suffix("_DE").is_some());
        assert!(VariantMetadata::from_suffix("_SY").is_some());
    }

    #[test]
    fn test_select_mask_generation() {
        let mask = generate_select_mask(&["ST", "SY"]);
        assert_eq!(mask[select_mask_positions::ST], 1);
        assert_eq!(mask[select_mask_positions::SY], 1);
        assert_eq!(mask[select_mask_positions::CT], 0);
    }

    #[test]
    fn test_select_mask_parsing() {
        let mask = [1, 1, 0, 0, 1, 0];
        let variants = parse_select_mask(&mask);
        assert!(variants.contains(&"ST"));
        assert!(variants.contains(&"CT"));
        assert!(variants.contains(&"DE"));
        assert!(!variants.contains(&"CD"));
    }

    #[test]
    fn test_file_type_flags() {
        assert_eq!(FileType::EDF.flag(), "-EDF");
        assert_eq!(FileType::ASCII.flag(), "-ASCII");
    }

    #[test]
    fn test_file_type_detection() {
        assert_eq!(FileType::from_extension("edf"), Some(FileType::EDF));
        assert_eq!(FileType::from_extension("ascii"), Some(FileType::ASCII));
        assert_eq!(FileType::from_extension("txt"), Some(FileType::ASCII));
        assert_eq!(FileType::from_extension("csv"), Some(FileType::ASCII));
        assert!(FileType::from_extension("unknown").is_none());
    }

    #[test]
    fn test_binary_name() {
        assert_eq!(BINARY_NAME, "run_DDA_AsciiEdf");
    }

    #[test]
    fn test_stride_values() {
        assert_eq!(ST.stride, 4);
        assert_eq!(CT.stride, 4);
        assert_eq!(CD.stride, 2);
        assert_eq!(DE.stride, 1);
        assert_eq!(SY.stride, 1);
    }

    // The invariants below are what the mask helpers rely on:
    // `generate_select_mask` indexes a `[u8; SELECT_MASK_SIZE]` by `position`,
    // and `parse_select_mask` maps slot indices back through `from_position`.

    #[test]
    fn test_registry_positions_match_mask_slots() {
        assert_eq!(VARIANT_REGISTRY.len(), SELECT_MASK_SIZE);
        for (index, variant) in VARIANT_REGISTRY.iter().enumerate() {
            assert_eq!(
                usize::from(variant.position),
                index,
                "{} is out of slot order",
                variant.abbreviation
            );
        }
    }

    #[test]
    fn test_variant_order_matches_registry() {
        let abbreviations: Vec<&str> = VARIANT_REGISTRY.iter().map(|v| v.abbreviation).collect();
        assert_eq!(abbreviations, VARIANT_ORDER);
    }

    #[test]
    fn test_select_mask_positions_match_registry() {
        let expected: [(&str, usize); 6] = [
            ("ST", select_mask_positions::ST),
            ("CT", select_mask_positions::CT),
            ("CD", select_mask_positions::CD),
            ("RESERVED", select_mask_positions::RESERVED),
            ("DE", select_mask_positions::DE),
            ("SY", select_mask_positions::SY),
        ];
        for &(abbrev, slot) in expected.iter() {
            let variant = VariantMetadata::from_abbrev(abbrev).expect("abbreviation is registered");
            assert_eq!(usize::from(variant.position), slot, "{}", abbrev);
        }
    }

    #[test]
    fn test_lookups_resolve_to_unique_entries() {
        for variant in VARIANT_REGISTRY {
            assert_eq!(
                VariantMetadata::from_abbrev(variant.abbreviation).map(|v| v.position),
                Some(variant.position)
            );
            assert_eq!(
                VariantMetadata::from_suffix(variant.output_suffix).map(|v| v.position),
                Some(variant.position)
            );
        }
    }

    #[test]
    fn test_select_mask_round_trip() {
        let active: Vec<&str> = VariantMetadata::active_variants().map(|v| v.abbreviation).collect();
        let mask = generate_select_mask(&active);
        assert_eq!(parse_select_mask(&mask), active);
    }

    #[test]
    fn test_reserved_slot_is_never_reported() {
        let mut mask = [0u8; SELECT_MASK_SIZE];
        mask[select_mask_positions::RESERVED] = 1;
        assert!(parse_select_mask(&mask).is_empty());
    }

    #[test]
    fn test_format_select_mask() {
        assert_eq!(format_select_mask(&[1, 0, 1, 0, 0, 1]), "1 0 1 0 0 1");
    }

    #[test]
    fn test_channel_format_from_str() {
        assert_eq!(ChannelFormat::from_str("individual"), Some(ChannelFormat::Individual));
        assert_eq!(ChannelFormat::from_str("pairs"), Some(ChannelFormat::Pairs));
        assert_eq!(ChannelFormat::from_str("directed_pairs"), Some(ChannelFormat::DirectedPairs));
        assert_eq!(ChannelFormat::from_str("Pairs"), None);
    }

    #[test]
    fn test_file_type_detection_is_case_insensitive() {
        assert_eq!(FileType::from_extension("EDF"), Some(FileType::EDF));
        assert_eq!(FileType::from_extension("Txt"), Some(FileType::ASCII));
    }
}
491}