Skip to main content

ringkernel_audio_fft/
bin_actor.rs

//! FFT bin actors with K2K neighbor messaging.
//!
//! Each frequency bin is represented as an independent GPU actor (hosted in this
//! module as a Tokio task) that:
//! 1. Receives FFT bin data from the host
//! 2. Exchanges information with neighboring bins via K2K messaging
//! 3. Performs coherence analysis for direct/ambience separation
//! 4. Outputs separated bin data

9use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use tokio::sync::mpsc;
14use tracing::{info, trace, warn};
15
16use ringkernel_core::k2k::K2KStats;
17use ringkernel_core::prelude::*;
18
19use crate::error::{AudioFftError, Result};
20use crate::messages::{Complex, FrequencyBin, NeighborData, SeparatedBin};
21use crate::separation::{CoherenceAnalyzer, SeparationConfig};
22
23/// State maintained by each bin actor.
24#[derive(Debug, Clone)]
25pub struct BinActorState {
26    /// Bin index.
27    pub bin_index: u32,
28    /// Current frame ID.
29    pub current_frame: u64,
30    /// Current bin value.
31    pub current_value: Complex,
32    /// Previous frame value.
33    pub prev_value: Option<Complex>,
34    /// Left neighbor data (lower frequency).
35    pub left_neighbor: Option<NeighborData>,
36    /// Right neighbor data (higher frequency).
37    pub right_neighbor: Option<NeighborData>,
38    /// Computed coherence with neighbors.
39    pub coherence: f32,
40    /// Temporal smoothing state.
41    pub smoothed_coherence: f32,
42    /// Phase derivative (for transient detection).
43    pub phase_derivative: f32,
44    /// Spectral flux.
45    pub spectral_flux: f32,
46}
47
48impl BinActorState {
49    /// Create a new bin actor state.
50    pub fn new(bin_index: u32) -> Self {
51        Self {
52            bin_index,
53            current_frame: 0,
54            current_value: Complex::default(),
55            prev_value: None,
56            left_neighbor: None,
57            right_neighbor: None,
58            coherence: 0.5,
59            smoothed_coherence: 0.5,
60            phase_derivative: 0.0,
61            spectral_flux: 0.0,
62        }
63    }
64
65    /// Update with new bin data.
66    pub fn update(&mut self, bin: &FrequencyBin) {
67        self.prev_value = Some(self.current_value);
68        self.current_value = bin.value;
69        self.current_frame = bin.frame_id;
70
71        // Calculate phase derivative
72        if let Some(prev) = self.prev_value {
73            let prev_phase = prev.phase();
74            let curr_phase = self.current_value.phase();
75            // Unwrap phase difference
76            let mut phase_diff = curr_phase - prev_phase;
77            while phase_diff > std::f32::consts::PI {
78                phase_diff -= 2.0 * std::f32::consts::PI;
79            }
80            while phase_diff < -std::f32::consts::PI {
81                phase_diff += 2.0 * std::f32::consts::PI;
82            }
83            self.phase_derivative = phase_diff;
84
85            // Calculate spectral flux
86            let prev_mag = prev.magnitude();
87            let curr_mag = self.current_value.magnitude();
88            self.spectral_flux = (curr_mag - prev_mag).max(0.0); // Only onset (positive flux)
89        }
90
91        // Clear neighbor data for new frame
92        self.left_neighbor = None;
93        self.right_neighbor = None;
94    }
95
96    /// Set neighbor data.
97    pub fn set_neighbor(&mut self, data: NeighborData, is_left: bool) {
98        if is_left {
99            self.left_neighbor = Some(data);
100        } else {
101            self.right_neighbor = Some(data);
102        }
103    }
104
105    /// Check if we have all neighbor data.
106    pub fn has_all_neighbors(&self, has_left: bool, has_right: bool) -> bool {
107        (!has_left || self.left_neighbor.is_some()) && (!has_right || self.right_neighbor.is_some())
108    }
109
110    /// Create neighbor data to send to adjacent bins.
111    pub fn to_neighbor_data(&self) -> NeighborData {
112        NeighborData {
113            source_bin: self.bin_index,
114            frame_id: self.current_frame,
115            value: self.current_value,
116            magnitude: self.current_value.magnitude(),
117            phase: self.current_value.phase(),
118            phase_derivative: self.phase_derivative,
119            spectral_flux: self.spectral_flux,
120        }
121    }
122}
123
124/// Handle to a single bin actor.
125pub struct BinActorHandle {
126    /// Bin index.
127    pub bin_index: u32,
128    /// Kernel ID.
129    kernel_id: KernelId,
130    /// K2K endpoint for this actor (reserved for future direct communication).
131    #[allow(dead_code)]
132    endpoint: K2KEndpoint,
133    /// State (shared for monitoring).
134    state: Arc<RwLock<BinActorState>>,
135    /// Input channel for bin data.
136    input_tx: mpsc::Sender<FrequencyBin>,
137    /// Output channel for separated data.
138    output_rx: mpsc::Receiver<SeparatedBin>,
139    /// Running flag.
140    running: Arc<AtomicBool>,
141}
142
143impl BinActorHandle {
144    /// Send bin data to the actor.
145    pub async fn send_bin(&self, bin: FrequencyBin) -> Result<()> {
146        self.input_tx
147            .send(bin)
148            .await
149            .map_err(|e| AudioFftError::kernel(format!("Failed to send bin data: {}", e)))
150    }
151
152    /// Receive separated bin data.
153    pub async fn receive_separated(&mut self) -> Option<SeparatedBin> {
154        self.output_rx.recv().await
155    }
156
157    /// Get the current state.
158    pub fn state(&self) -> BinActorState {
159        self.state.read().clone()
160    }
161
162    /// Get the kernel ID.
163    pub fn kernel_id(&self) -> &KernelId {
164        &self.kernel_id
165    }
166
167    /// Check if running.
168    pub fn is_running(&self) -> bool {
169        self.running.load(Ordering::Relaxed)
170    }
171
172    /// Stop the actor.
173    pub fn stop(&self) {
174        self.running.store(false, Ordering::Relaxed);
175    }
176}
177
178/// A single bin actor that processes one frequency bin.
179pub struct BinActor {
180    /// Bin index.
181    bin_index: u32,
182    /// Total number of bins (reserved for frequency-dependent processing).
183    #[allow(dead_code)]
184    total_bins: u32,
185    /// Kernel ID (reserved for multi-actor coordination).
186    #[allow(dead_code)]
187    kernel_id: KernelId,
188    /// State.
189    state: Arc<RwLock<BinActorState>>,
190    /// K2K endpoint.
191    endpoint: K2KEndpoint,
192    /// Left neighbor kernel ID (if any).
193    left_neighbor_id: Option<KernelId>,
194    /// Right neighbor kernel ID (if any).
195    right_neighbor_id: Option<KernelId>,
196    /// Input channel.
197    input_rx: mpsc::Receiver<FrequencyBin>,
198    /// Output channel.
199    output_tx: mpsc::Sender<SeparatedBin>,
200    /// Coherence analyzer.
201    analyzer: CoherenceAnalyzer,
202    /// Separation config.
203    config: SeparationConfig,
204    /// Running flag.
205    running: Arc<AtomicBool>,
206    /// Frame counter.
207    frame_counter: AtomicU64,
208}
209
210impl BinActor {
211    /// Create a new bin actor.
212    pub fn new(
213        bin_index: u32,
214        total_bins: u32,
215        broker: &Arc<K2KBroker>,
216        config: SeparationConfig,
217    ) -> (Self, BinActorHandle) {
218        let kernel_id = KernelId::new(format!("bin_actor_{}", bin_index));
219        let endpoint = broker.register(kernel_id.clone());
220
221        let state = Arc::new(RwLock::new(BinActorState::new(bin_index)));
222        let running = Arc::new(AtomicBool::new(true));
223
224        let (input_tx, input_rx) = mpsc::channel(64);
225        let (output_tx, output_rx) = mpsc::channel(64);
226
227        // Create handle's endpoint separately
228        let handle_endpoint =
229            broker.register(KernelId::new(format!("bin_actor_{}_handle", bin_index)));
230
231        let handle = BinActorHandle {
232            bin_index,
233            kernel_id: kernel_id.clone(),
234            endpoint: handle_endpoint,
235            state: state.clone(),
236            input_tx,
237            output_rx,
238            running: running.clone(),
239        };
240
241        let actor = Self {
242            bin_index,
243            total_bins,
244            kernel_id,
245            state,
246            endpoint,
247            left_neighbor_id: None,
248            right_neighbor_id: None,
249            input_rx,
250            output_tx,
251            analyzer: CoherenceAnalyzer::new(config.clone()),
252            config,
253            running,
254            frame_counter: AtomicU64::new(0),
255        };
256
257        (actor, handle)
258    }
259
260    /// Set neighbor kernel IDs.
261    pub fn set_neighbors(&mut self, left: Option<KernelId>, right: Option<KernelId>) {
262        self.left_neighbor_id = left;
263        self.right_neighbor_id = right;
264    }
265
266    /// Run the actor processing loop.
267    pub async fn run(&mut self) -> Result<()> {
268        info!("Bin actor {} starting", self.bin_index);
269
270        while self.running.load(Ordering::Relaxed) {
271            // Wait for input bin data
272            let bin = match tokio::time::timeout(
273                std::time::Duration::from_millis(100),
274                self.input_rx.recv(),
275            )
276            .await
277            {
278                Ok(Some(bin)) => bin,
279                Ok(None) => {
280                    // Channel closed
281                    break;
282                }
283                Err(_) => {
284                    // Timeout, check if still running
285                    continue;
286                }
287            };
288
289            trace!("Bin {} processing frame {}", self.bin_index, bin.frame_id);
290
291            // Update state
292            {
293                let mut state = self.state.write();
294                state.update(&bin);
295            }
296
297            // Send neighbor data via K2K
298            self.send_neighbor_data().await?;
299
300            // Receive neighbor data via K2K
301            self.receive_neighbor_data().await?;
302
303            // Perform separation
304            let separated = self.compute_separation();
305
306            // Send output
307            if self.output_tx.send(separated).await.is_err() {
308                warn!("Output channel closed for bin {}", self.bin_index);
309                break;
310            }
311
312            self.frame_counter.fetch_add(1, Ordering::Relaxed);
313        }
314
315        info!("Bin actor {} stopped", self.bin_index);
316        Ok(())
317    }
318
319    /// Send neighbor data to adjacent bins.
320    async fn send_neighbor_data(&mut self) -> Result<()> {
321        let neighbor_data = self.state.read().to_neighbor_data();
322
323        // Send to left neighbor
324        if let Some(left_id) = &self.left_neighbor_id {
325            let envelope = MessageEnvelope::new(
326                &neighbor_data,
327                self.bin_index as u64,
328                (self.bin_index - 1) as u64,
329                HlcTimestamp::now(self.bin_index as u64),
330            );
331
332            match self.endpoint.send(left_id.clone(), envelope).await {
333                Ok(receipt) if receipt.status == DeliveryStatus::Delivered => {
334                    trace!("Sent to left neighbor {}", left_id);
335                }
336                Ok(receipt) => {
337                    trace!("Left neighbor delivery status: {:?}", receipt.status);
338                }
339                Err(e) => {
340                    trace!("Failed to send to left neighbor: {}", e);
341                }
342            }
343        }
344
345        // Send to right neighbor
346        if let Some(right_id) = &self.right_neighbor_id {
347            let envelope = MessageEnvelope::new(
348                &neighbor_data,
349                self.bin_index as u64,
350                (self.bin_index + 1) as u64,
351                HlcTimestamp::now(self.bin_index as u64),
352            );
353
354            match self.endpoint.send(right_id.clone(), envelope).await {
355                Ok(receipt) if receipt.status == DeliveryStatus::Delivered => {
356                    trace!("Sent to right neighbor {}", right_id);
357                }
358                Ok(receipt) => {
359                    trace!("Right neighbor delivery status: {:?}", receipt.status);
360                }
361                Err(e) => {
362                    trace!("Failed to send to right neighbor: {}", e);
363                }
364            }
365        }
366
367        Ok(())
368    }
369
370    /// Receive neighbor data from adjacent bins.
371    async fn receive_neighbor_data(&mut self) -> Result<()> {
372        let has_left = self.left_neighbor_id.is_some();
373        let has_right = self.right_neighbor_id.is_some();
374
375        // Try to receive with a short timeout
376        let timeout = std::time::Duration::from_millis(10);
377        let deadline = std::time::Instant::now() + timeout;
378
379        while std::time::Instant::now() < deadline {
380            match self.endpoint.try_receive() {
381                Some(k2k_msg) => {
382                    // Deserialize neighbor data
383                    if let Ok(neighbor_data) = NeighborData::deserialize(&k2k_msg.envelope.payload)
384                    {
385                        let is_left = neighbor_data.source_bin < self.bin_index;
386                        let mut state = self.state.write();
387                        state.set_neighbor(neighbor_data, is_left);
388
389                        if state.has_all_neighbors(has_left, has_right) {
390                            break;
391                        }
392                    }
393                }
394                None => {
395                    // No message available, brief yield
396                    tokio::task::yield_now().await;
397                }
398            }
399        }
400
401        Ok(())
402    }
403
404    /// Compute the separation of direct and ambient signals.
405    fn compute_separation(&mut self) -> SeparatedBin {
406        let state = self.state.read();
407
408        // Compute coherence based on neighbor data
409        let (coherence, transient) = self.analyzer.analyze(
410            &state.current_value,
411            state.left_neighbor.as_ref(),
412            state.right_neighbor.as_ref(),
413            state.phase_derivative,
414            state.spectral_flux,
415        );
416
417        // Update smoothed coherence
418        drop(state);
419        {
420            let mut state = self.state.write();
421            state.coherence = coherence;
422            state.smoothed_coherence = state.smoothed_coherence * self.config.temporal_smoothing
423                + coherence * (1.0 - self.config.temporal_smoothing);
424        }
425
426        let state = self.state.read();
427        let smoothed = state.smoothed_coherence;
428
429        // Separate direct and ambient components
430        let direct_ratio = smoothed.powf(self.config.separation_curve);
431        let ambient_ratio = 1.0 - direct_ratio;
432
433        let direct = state.current_value.scale(direct_ratio);
434        let ambience = state.current_value.scale(ambient_ratio);
435
436        SeparatedBin::new(
437            state.current_frame,
438            self.bin_index,
439            direct,
440            ambience,
441            smoothed,
442            transient,
443        )
444    }
445}
446
447/// Network of bin actors with K2K messaging.
448pub struct BinNetwork {
449    /// Number of bins.
450    num_bins: usize,
451    /// K2K broker.
452    broker: Arc<K2KBroker>,
453    /// Actor handles.
454    handles: Vec<BinActorHandle>,
455    /// Actor tasks.
456    tasks: Vec<tokio::task::JoinHandle<Result<()>>>,
457    /// Configuration (reserved for runtime reconfiguration).
458    #[allow(dead_code)]
459    config: SeparationConfig,
460    /// Running flag.
461    running: Arc<AtomicBool>,
462}
463
464impl BinNetwork {
465    /// Create a new bin network.
466    pub async fn new(num_bins: usize, config: SeparationConfig) -> Result<Self> {
467        info!("Creating bin network with {} bins", num_bins);
468
469        let broker = K2KBuilder::new()
470            .max_pending_messages(num_bins * 4)
471            .delivery_timeout_ms(100)
472            .build();
473
474        let mut actors: Vec<BinActor> = Vec::with_capacity(num_bins);
475        let mut handles: Vec<BinActorHandle> = Vec::with_capacity(num_bins);
476
477        // Create all actors
478        for i in 0..num_bins {
479            let (actor, handle) = BinActor::new(i as u32, num_bins as u32, &broker, config.clone());
480            actors.push(actor);
481            handles.push(handle);
482        }
483
484        // Set up neighbor relationships
485        for (i, actor) in actors.iter_mut().enumerate() {
486            let left = if i > 0 {
487                Some(KernelId::new(format!("bin_actor_{}", i - 1)))
488            } else {
489                None
490            };
491            let right = if i < num_bins - 1 {
492                Some(KernelId::new(format!("bin_actor_{}", i + 1)))
493            } else {
494                None
495            };
496            actor.set_neighbors(left, right);
497        }
498
499        let running = Arc::new(AtomicBool::new(true));
500
501        // Spawn actor tasks
502        let mut tasks = Vec::with_capacity(num_bins);
503        for mut actor in actors {
504            let task = tokio::spawn(async move { actor.run().await });
505            tasks.push(task);
506        }
507
508        Ok(Self {
509            num_bins,
510            broker,
511            handles,
512            tasks,
513            config,
514            running,
515        })
516    }
517
518    /// Get the number of bins.
519    pub fn num_bins(&self) -> usize {
520        self.num_bins
521    }
522
523    /// Get a handle to a specific bin.
524    pub fn get_handle(&self, bin_index: usize) -> Option<&BinActorHandle> {
525        self.handles.get(bin_index)
526    }
527
528    /// Send bin data to all actors.
529    pub async fn send_bins(&self, bins: &[FrequencyBin]) -> Result<()> {
530        for (i, bin) in bins.iter().enumerate() {
531            if i < self.handles.len() {
532                self.handles[i].send_bin(bin.clone()).await?;
533            }
534        }
535        Ok(())
536    }
537
538    /// Receive separated bins from all actors.
539    pub async fn receive_separated(&mut self) -> Result<Vec<SeparatedBin>> {
540        let mut results = Vec::with_capacity(self.num_bins);
541
542        for handle in &mut self.handles {
543            if let Some(separated) = handle.receive_separated().await {
544                results.push(separated);
545            }
546        }
547
548        // Sort by bin index
549        results.sort_by_key(|b| b.bin_index);
550
551        Ok(results)
552    }
553
554    /// Process a frame of FFT bins and return separated bins.
555    pub async fn process_frame(
556        &mut self,
557        frame_id: u64,
558        bins: &[Complex],
559        sample_rate: u32,
560        fft_size: usize,
561    ) -> Result<Vec<SeparatedBin>> {
562        // Convert to FrequencyBin messages
563        let freq_bins: Vec<FrequencyBin> = bins
564            .iter()
565            .enumerate()
566            .map(|(i, &value)| {
567                let frequency_hz = i as f32 * sample_rate as f32 / fft_size as f32;
568                FrequencyBin::new(frame_id, i as u32, bins.len() as u32, value, frequency_hz)
569            })
570            .collect();
571
572        // Send to actors
573        self.send_bins(&freq_bins).await?;
574
575        // Receive separated results
576        self.receive_separated().await
577    }
578
579    /// Stop all actors.
580    pub async fn stop(&mut self) -> Result<()> {
581        info!("Stopping bin network");
582        self.running.store(false, Ordering::Relaxed);
583
584        for handle in &self.handles {
585            handle.stop();
586        }
587
588        // Wait for tasks to complete
589        for task in self.tasks.drain(..) {
590            let _ = task.await;
591        }
592
593        Ok(())
594    }
595
596    /// Get K2K broker statistics.
597    pub fn k2k_stats(&self) -> K2KStats {
598        self.broker.stats()
599    }
600}
601
602impl Drop for BinNetwork {
603    fn drop(&mut self) {
604        self.running.store(false, Ordering::Relaxed);
605        for handle in &self.handles {
606            handle.stop();
607        }
608    }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built network reports its size and registers at least one
    /// endpoint per bin.
    #[tokio::test]
    async fn test_bin_network_creation() {
        let net = BinNetwork::new(16, SeparationConfig::default()).await.unwrap();

        assert_eq!(net.num_bins(), 16);
        assert!(net.k2k_stats().registered_endpoints >= 16);
    }

    /// `update` records the incoming frame ID and value.
    #[test]
    fn test_bin_actor_state() {
        let mut s = BinActorState::new(5);
        assert_eq!(s.bin_index, 5);
        assert_eq!(s.coherence, 0.5);

        s.update(&FrequencyBin::new(1, 5, 1024, Complex::new(1.0, 0.0), 440.0));

        assert_eq!(s.current_frame, 1);
        assert!((s.current_value.magnitude() - 1.0).abs() < 1e-6);
    }
}