Skip to main content

plato_audio_jepa/
lib.rs

1//! # plato-audio-jepa
2//!
3//! Audio / vibration JEPA for the PLATO nervous system. Processes microphone
4//! input into structured room state vectors suitable for downstream tiles.
5
6use serde::{Deserialize, Serialize};
7use uuid::Uuid;
8
9// ---------------------------------------------------------------------------
10// Types
11// ---------------------------------------------------------------------------
12
/// Structured room state produced by the audio JEPA from a chunk of audio.
///
/// This is the compact output artifact: `audio_state_to_tile` builds one by
/// copying the four scalar features out of a `RoomAudioState`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTile {
    /// Identifier of this tile (`audio_state_to_tile` assigns a fresh v4 UUID).
    pub id: Uuid,
    /// Volume, copied from `RoomAudioState::volume`.
    pub volume: f32,
    /// Dominant frequency, copied from `RoomAudioState::dominant_frequency`.
    pub dominant_frequency: f32,
    /// Spectral centroid, copied from `RoomAudioState::spectral_centroid`.
    pub spectral_centroid: f32,
    /// Anomaly score, copied from `RoomAudioState::anomaly_score`.
    pub anomaly: f32,
    /// Timestamp of the tile. `audio_state_to_tile` leaves this at `0`.
    /// NOTE(review): unit and epoch are not defined in this file — confirm with the consumer.
    pub timestamp: u64,
}
23
/// Spectral deadband filter — only process chunks when frequency content changes.
///
/// `should_process` compares each incoming spectrum against `last_spectrum`
/// using a root-mean-square difference and reports a change only when that
/// difference exceeds `threshold`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioDeadband {
    /// RMS difference a new spectrum must exceed to count as a change.
    pub threshold: f64,
    /// Most recent spectrum that was accepted as a change; `None` until the
    /// first spectrum has been accepted.
    pub last_spectrum: Option<Vec<f32>>,
}
30
31impl Default for AudioDeadband {
32    fn default() -> Self {
33        Self {
34            threshold: 0.05,
35            last_spectrum: None,
36        }
37    }
38}
39
40impl AudioDeadband {
41    pub fn new(threshold: f64) -> Self {
42        Self {
43            threshold,
44            last_spectrum: None,
45        }
46    }
47
48    /// Returns `true` if the new spectrum represents a significant change.
49    pub fn should_process(&mut self, spectrum: &[f32]) -> bool {
50        let significant = match self.last_spectrum {
51            None => true,
52            Some(ref prev) => {
53                let n = prev.len().min(spectrum.len());
54                if n == 0 {
55                    return true;
56                }
57                let diff: f64 = (0..n)
58                    .map(|i| {
59                        let d = (prev[i] - spectrum[i]) as f64;
60                        d * d
61                    })
62                    .sum::<f64>()
63                    / n as f64;
64                diff.sqrt() > self.threshold
65            }
66        };
67        if significant {
68            self.last_spectrum = Some(spectrum.to_vec());
69        }
70        significant
71    }
72}
73
/// 16-dimensional audio state vector for a room.
///
/// The bracketed indices on each field give its position in the flat
/// `[f32; 16]` produced by `to_vector` and consumed by `from_vector`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoomAudioState {
    /// [0] volume (0-1 normalized)
    pub volume: f32,
    /// [1] dominant_frequency (Hz normalized to 0-1, where 1 = Nyquist)
    pub dominant_frequency: f32,
    /// [2] spectral_centroid (brightness of sound)
    pub spectral_centroid: f32,
    /// [3] anomaly_score (0-1)
    pub anomaly_score: f32,
    /// [4-7] frequency band energies (sub-bass, bass, mid, high)
    pub band_energies: [f32; 4],
    /// [8-11] temporal patterns (volume trend, frequency drift, onset rate, rhythm)
    pub temporal_patterns: [f32; 4],
    /// [12-15] reserved
    pub reserved: [f32; 4],
}
92
93impl Default for RoomAudioState {
94    fn default() -> Self {
95        Self {
96            volume: 0.0,
97            dominant_frequency: 0.0,
98            spectral_centroid: 0.0,
99            anomaly_score: 0.0,
100            band_energies: [0.0; 4],
101            temporal_patterns: [0.0; 4],
102            reserved: [0.0; 4],
103        }
104    }
105}
106
107impl RoomAudioState {
108    /// Convert to a flat 16-element f32 array.
109    pub fn to_vector(&self) -> [f32; 16] {
110        let mut v = [0.0f32; 16];
111        v[0] = self.volume;
112        v[1] = self.dominant_frequency;
113        v[2] = self.spectral_centroid;
114        v[3] = self.anomaly_score;
115        v[4..8].copy_from_slice(&self.band_energies);
116        v[8..12].copy_from_slice(&self.temporal_patterns);
117        v[12..16].copy_from_slice(&self.reserved);
118        v
119    }
120
121    /// Reconstruct from a flat 16-element f32 array.
122    pub fn from_vector(v: &[f32; 16]) -> Self {
123        let mut bands = [0.0f32; 4];
124        let mut temporal = [0.0f32; 4];
125        let mut reserved = [0.0f32; 4];
126        bands.copy_from_slice(&v[4..8]);
127        temporal.copy_from_slice(&v[8..12]);
128        reserved.copy_from_slice(&v[12..16]);
129        Self {
130            volume: v[0],
131            dominant_frequency: v[1],
132            spectral_centroid: v[2],
133            anomaly_score: v[3],
134            band_energies: bands,
135            temporal_patterns: temporal,
136            reserved,
137        }
138    }
139}
140
141// ---------------------------------------------------------------------------
142// Functions
143// ---------------------------------------------------------------------------
144
/// Evaluate a direct DFT and return the magnitudes of its first 16 bins.
///
/// Each bin `k` (0..16) is computed by summing `s[i] * e^(-2πi·k·i/N)` over
/// all samples, and the resulting magnitude is divided by `N`, the number of
/// samples. The cost is 16 passes over `samples`. An empty input yields 16
/// zeros (`N` is clamped to 1 to avoid dividing by zero).
pub fn compute_spectrum(samples: &[f32]) -> Vec<f32> {
    const N_BINS: usize = 16;
    let len = samples.len().max(1) as f32;

    (0..N_BINS)
        .map(|bin| {
            // Per-bin phase factor; the per-sample angle is `phase_step * i / len`.
            let phase_step = -2.0 * std::f32::consts::PI * (bin as f32);
            let (re, im) = samples.iter().enumerate().fold(
                (0.0f32, 0.0f32),
                |(re, im), (i, &sample)| {
                    let angle = phase_step * (i as f32) / len;
                    (re + sample * angle.cos(), im + sample * angle.sin())
                },
            );
            (re * re + im * im).sqrt() / len
        })
        .collect()
}
165
/// Compute the spectral centroid: the magnitude-weighted mean bin index.
///
/// The result is expressed in bin units, not in magnitude units. For
/// non-negative magnitudes it lies between `0` and `spectrum.len() - 1`.
/// Returns `0.0` for an empty spectrum or when the magnitudes sum to zero.
pub fn compute_spectral_centroid(spectrum: &[f32]) -> f32 {
    let (weighted_sum, weight_total) = spectrum.iter().enumerate().fold(
        (0.0f32, 0.0f32),
        |(weighted, total), (bin, &mag)| (weighted + bin as f32 * mag, total + mag),
    );

    if weight_total != 0.0 {
        weighted_sum / weight_total
    } else {
        // No energy at all (or no bins): there is no meaningful centroid.
        0.0
    }
}
188
/// Compute the energy (sum of squared magnitudes) of the bins in `low..high`.
///
/// The range is half-open and both bounds are clamped to `spectrum.len()`.
/// An empty or inverted range yields `0.0`.
pub fn compute_band_energy(spectrum: &[f32], low: usize, high: usize) -> f32 {
    let len = spectrum.len();
    let (start, end) = (low.min(len), high.min(len));

    match spectrum.get(start..end) {
        Some(band) if start < end => band.iter().map(|&m| m * m).sum(),
        _ => 0.0,
    }
}
198
/// Estimate the onset rate from a volume history.
///
/// An onset is a step between two consecutive samples whose increase is
/// strictly greater than `0.1`. The result is the fraction of steps that are
/// onsets (onsets divided by `len - 1`), or `0.0` when there are fewer than
/// two samples.
pub fn compute_onset_rate(volume_history: &[f32]) -> f32 {
    const ONSET_THRESHOLD: f32 = 0.1;

    let steps = volume_history.len().saturating_sub(1);
    if steps == 0 {
        return 0.0;
    }

    let onsets = volume_history
        .windows(2)
        .filter(|pair| pair[1] - pair[0] > ONSET_THRESHOLD)
        .count();

    onsets as f32 / steps as f32
}
218
/// Estimate BPM from a volume envelope by measuring the spacing of its peaks.
///
/// `sample_rate` is the rate of the volume-history samples (not audio samples).
///
/// A peak is an interior sample that is strictly greater than both of its
/// neighbours and greater than `0.05`. The spacing between consecutive peaks
/// is converted to seconds, and only spacings strictly between 0.1 s and 5 s
/// are averaged. The result is `60 / average_spacing`.
///
/// Returns `0.0` when there are fewer than four samples, `sample_rate` is not
/// positive, or no peak spacing falls inside the accepted window.
pub fn detect_rhythm(volume_history: &[f32], sample_rate: f32) -> f32 {
    const PEAK_FLOOR: f32 = 0.05;
    const MIN_PERIOD_S: f32 = 0.1;
    const MAX_PERIOD_S: f32 = 5.0;

    if sample_rate <= 0.0 || volume_history.len() < 4 {
        return 0.0;
    }

    // Window `offset` covers samples offset..offset+3, so its centre (the
    // candidate peak) sits at index `offset + 1`.
    let peak_positions = volume_history
        .windows(3)
        .enumerate()
        .filter(|(_, w)| w[1] > w[0] && w[1] > w[2] && w[1] > PEAK_FLOOR)
        .map(|(offset, _)| offset + 1);

    let mut previous_peak: Option<usize> = None;
    let mut period_sum = 0.0f32;
    let mut period_count = 0usize;

    for position in peak_positions {
        if let Some(prev) = previous_peak {
            let seconds = (position - prev) as f32 / sample_rate;
            if seconds > MIN_PERIOD_S && seconds < MAX_PERIOD_S {
                period_sum += seconds;
                period_count += 1;
            }
        }
        previous_peak = Some(position);
    }

    if period_count == 0 {
        return 0.0;
    }

    // Every accepted period exceeds MIN_PERIOD_S, so the mean is strictly positive.
    60.0 / (period_sum / period_count as f32)
}
264
265/// Convert a RoomAudioState into an AudioTile (the output artifact).
266pub fn audio_state_to_tile(state: &RoomAudioState) -> AudioTile {
267    AudioTile {
268        id: Uuid::new_v4(),
269        volume: state.volume,
270        dominant_frequency: state.dominant_frequency,
271        spectral_centroid: state.spectral_centroid,
272        anomaly: state.anomaly_score,
273        timestamp: 0,
274    }
275}