//! # zuna-rs — ZUNA EEG Foundation Model inference in Rust
//!
//! Pure-Rust inference for the [ZUNA](https://huggingface.co/Zyphra/ZUNA)
//! EEG foundation model, built on [Burn 0.20](https://burn.dev) and
//! [exg](https://github.com/eugenehp/exg) for FIF preprocessing.
//!
//! ## Three entry points
//!
//! | Type | Loads | Use case |
//! |---|---|---|
//! | [`ZunaInference`] | encoder + decoder | full encode → diffuse → decode pipeline |
//! | [`ZunaEncoder`] | encoder only | produce latent embeddings with a smaller memory footprint |
//! | [`ZunaDecoder`] | decoder only | reconstruct from stored embeddings |
//!
//! ## Quick start — full pipeline
//!
//! ```rust,ignore
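//! use std::path::Path;
//! use zuna_rs::{ZunaInference, InferenceResult};
//!
//! // `B` is a Burn backend type and `device` a device for it; both are assumed in scope.
//! let (model, _ms) = ZunaInference::<B>::load(
//!     Path::new("config.json"),
//!     Path::new("model.safetensors"),
//!     device,
//! )?;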
//! let result: InferenceResult = model.run_fif(Path::new("recording.fif"), 50, 1.0, 10.0)?;
//! result.save_safetensors("output.safetensors")?;
//! ```
//!
//! ## Quick start — encode only
//!
//! ```rust,ignore
//! use std::path::Path;
//! use zuna_rs::{ZunaEncoder, EncodingResult};
//!
//! let (enc, _ms) = ZunaEncoder::<B>::load(
//!     Path::new("config.json"),
//!     Path::new("model.safetensors"),
//!     device,
//! )?;
//! let result: EncodingResult = enc.encode_fif(Path::new("recording.fif"), 10.0)?;
//! result.save_safetensors("data/embeddings.safetensors")?;
//! ```
//!
//! ## Quick start — decode from stored embeddings
//!
//! ```rust,ignore
//! use std::path::Path;
//! use zuna_rs::{ZunaDecoder, EncodingResult};
//!
//! let embeddings = EncodingResult::load_safetensors("data/embeddings.safetensors")?;
//! let (dec, _ms) = ZunaDecoder::<B>::load(
//!     Path::new("config.json"),
//!     Path::new("model.safetensors"),
//!     device,
//! )?;
//! let result = dec.decode_embeddings(&embeddings, 50, 1.0, 10.0)?;
//! result.save_safetensors("output.safetensors")?;
//! ```
//!
//! ## Embedding regularisation
//!
//! The encoder uses an **MMD (Maximum Mean Discrepancy) bottleneck**: during
//! training, an MMD loss pulls the distribution of embeddings toward **N(0, I)**.
//! At inference the bottleneck is a pure passthrough; no reparameterisation is
//! applied. Embeddings from [`ZunaEncoder`] or [`ZunaInference::encode_fif`]
//! are therefore already in the regularised latent space and can be used
//! directly for downstream tasks.
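//!
//! The training-time MMD penalty can be illustrated with a small, std-only
//! sketch. It assumes a Gaussian (RBF) kernel `k(a, b) = exp(-|a - b|² / (2σ²))`
//! and the biased estimator
//! `MMD² = mean k(x, x') + mean k(y, y') - 2 · mean k(x, y)`; the kernel,
//! bandwidth and estimator actually used to train ZUNA are not stated here.
//! The names `XorShift`, `gaussian_samples`, `rbf`, `mean_kernel` and
//! `mmd2_rbf` are hypothetical and not part of the `zuna_rs` API.
//!
//! ```rust
//! /// Tiny xorshift64* generator so the sketch needs no external crates.
//! struct XorShift(u64);
//!
//! impl XorShift {
//!     /// Uniform sample in [0, 1) built from the top 53 bits of the output.
//!     fn next_f64(&mut self) -> f64 {
//!         self.0 ^= self.0 >> 12;
//!         self.0 ^= self.0 << 25;
//!         self.0 ^= self.0 >> 27;
//!         let v = self.0.wrapping_mul(0x2545_F491_4F6C_DD1D);
//!         (v >> 11) as f64 / (1u64 << 53) as f64
//!     }
//! }
//!
//! /// `n` draws from N(0, I) in `d` dimensions via the Box-Muller transform.
//! fn gaussian_samples(n: usize, d: usize, seed: u64) -> Vec<Vec<f64>> {
//!     let mut rng = XorShift(seed.max(1)); // xorshift state must be non-zero
//!     (0..n)
//!         .map(|_| {
//!             (0..d)
//!                 .map(|_| {
//!                     let u1 = 1.0 - rng.next_f64(); // in (0, 1], so ln(u1) is finite
//!                     let u2 = rng.next_f64();
//!                     (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
//!                 })
//!                 .collect::<Vec<f64>>()
//!         })
//!         .collect()
//! }
//!
//! /// Gaussian (RBF) kernel with bandwidth `sigma`.
//! fn rbf(a: &[f64], b: &[f64], sigma: f64) -> f64 {
//!     let sq: f64 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
//!     (-sq / (2.0 * sigma * sigma)).exp()
//! }
//!
//! /// Mean of k(x, y) over all pairs (x in xs, y in ys), diagonal included.
//! fn mean_kernel(xs: &[Vec<f64>], ys: &[Vec<f64>], sigma: f64) -> f64 {
//!     let mut total = 0.0;
//!     for x in xs {
//!         for y in ys {
//!             total += rbf(x, y, sigma);
//!         }
//!     }
//!     total / (xs.len() * ys.len()) as f64
//! }
//!
//! /// Biased (V-statistic) estimate of squared MMD between two samples.
//! fn mmd2_rbf(xs: &[Vec<f64>], ys: &[Vec<f64>], sigma: f64) -> f64 {
//!     mean_kernel(xs, xs, sigma) + mean_kernel(ys, ys, sigma)
//!         - 2.0 * mean_kernel(xs, ys, sigma)
//! }
//! ```
//!
//! With `x` a batch of embedding vectors and `prior = gaussian_samples(n, d, seed)`,
//! `mmd2_rbf(&x, &prior, 1.0)` stays near zero when `x` looks like draws from
//! N(0, I) and grows as the batch drifts away (for example, after adding 3.0 to
//! every coordinate).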
// ── Thread configuration ─────────────────────────────────────────────────────
/// Configure the global Rayon thread pool used by the NdArray backend.
///
/// Call this **once**, before any model operations. If `n` is `None` or `0`,
/// Rayon uses all logical CPUs (its default).
///
/// Returns the actual number of threads in the pool.
///
/// # Example
/// ```rust,ignore
/// let n = zuna_rs::init_threads(Some(4));
/// println!("Using {n} threads");
/// ```
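///
/// The `None`/`0` rule above can be sketched with the standard library alone.
/// `resolve_threads` is a hypothetical helper, not part of `zuna_rs`; it assumes
/// "all logical CPUs" means `std::thread::available_parallelism`, and unlike
/// Rayon's automatic choice it ignores the `RAYON_NUM_THREADS` environment
/// variable.
///
/// ```rust
/// use std::thread;
///
/// // Hypothetical helper: map an `init_threads`-style argument to a thread count.
/// fn resolve_threads(n: Option<usize>) -> usize {
///     match n {
///         // An explicit, non-zero request is used as-is.
///         Some(k) if k > 0 => k,
///         // `None` or `Some(0)`: use the number of logical CPUs,
///         // falling back to 1 if the platform cannot report it.
///         _ => thread::available_parallelism().map(|p| p.get()).unwrap_or(1),
///     }
/// }
/// ```
///
/// Rayon's `ThreadPoolBuilder::num_threads` also treats `0` as "choose
/// automatically", so `None` can be forwarded to it as `0`.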
// ── Internal modules ─────────────────────────────────────────────────────────
// ── Flat re-exports ───────────────────────────────────────────────────────────
//
// Everything a downstream user needs is available as `zuna_rs::Foo` without
// knowing the internal module layout.
// Full pipeline
pub use ;
// Encoder-only
pub use ;
// Decoder-only
pub use ZunaDecoder;
// Configs
pub use ;
// Data types needed for the lower-level API
pub use ;
// Channel position lookup
pub use ;
// CSV / tensor data loading
pub use ;
// CSV export from FIF
pub use fif_to_csv;