1use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
10use std::sync::Arc;
11
12use parking_lot::RwLock;
13use tokio::sync::mpsc;
14use tracing::{info, trace, warn};
15
16use ringkernel_core::k2k::K2KStats;
17use ringkernel_core::prelude::*;
18
19use crate::error::{AudioFftError, Result};
20use crate::messages::{Complex, FrequencyBin, NeighborData, SeparatedBin};
21use crate::separation::{CoherenceAnalyzer, SeparationConfig};
22
23#[derive(Debug, Clone)]
25pub struct BinActorState {
26 pub bin_index: u32,
28 pub current_frame: u64,
30 pub current_value: Complex,
32 pub prev_value: Option<Complex>,
34 pub left_neighbor: Option<NeighborData>,
36 pub right_neighbor: Option<NeighborData>,
38 pub coherence: f32,
40 pub smoothed_coherence: f32,
42 pub phase_derivative: f32,
44 pub spectral_flux: f32,
46}
47
48impl BinActorState {
49 pub fn new(bin_index: u32) -> Self {
51 Self {
52 bin_index,
53 current_frame: 0,
54 current_value: Complex::default(),
55 prev_value: None,
56 left_neighbor: None,
57 right_neighbor: None,
58 coherence: 0.5,
59 smoothed_coherence: 0.5,
60 phase_derivative: 0.0,
61 spectral_flux: 0.0,
62 }
63 }
64
65 pub fn update(&mut self, bin: &FrequencyBin) {
67 self.prev_value = Some(self.current_value);
68 self.current_value = bin.value;
69 self.current_frame = bin.frame_id;
70
71 if let Some(prev) = self.prev_value {
73 let prev_phase = prev.phase();
74 let curr_phase = self.current_value.phase();
75 let mut phase_diff = curr_phase - prev_phase;
77 while phase_diff > std::f32::consts::PI {
78 phase_diff -= 2.0 * std::f32::consts::PI;
79 }
80 while phase_diff < -std::f32::consts::PI {
81 phase_diff += 2.0 * std::f32::consts::PI;
82 }
83 self.phase_derivative = phase_diff;
84
85 let prev_mag = prev.magnitude();
87 let curr_mag = self.current_value.magnitude();
88 self.spectral_flux = (curr_mag - prev_mag).max(0.0); }
90
91 self.left_neighbor = None;
93 self.right_neighbor = None;
94 }
95
96 pub fn set_neighbor(&mut self, data: NeighborData, is_left: bool) {
98 if is_left {
99 self.left_neighbor = Some(data);
100 } else {
101 self.right_neighbor = Some(data);
102 }
103 }
104
105 pub fn has_all_neighbors(&self, has_left: bool, has_right: bool) -> bool {
107 (!has_left || self.left_neighbor.is_some()) && (!has_right || self.right_neighbor.is_some())
108 }
109
110 pub fn to_neighbor_data(&self) -> NeighborData {
112 NeighborData {
113 source_bin: self.bin_index,
114 frame_id: self.current_frame,
115 value: self.current_value,
116 magnitude: self.current_value.magnitude(),
117 phase: self.current_value.phase(),
118 phase_derivative: self.phase_derivative,
119 spectral_flux: self.spectral_flux,
120 }
121 }
122}
123
124pub struct BinActorHandle {
126 pub bin_index: u32,
128 kernel_id: KernelId,
130 #[allow(dead_code)]
132 endpoint: K2KEndpoint,
133 state: Arc<RwLock<BinActorState>>,
135 input_tx: mpsc::Sender<FrequencyBin>,
137 output_rx: mpsc::Receiver<SeparatedBin>,
139 running: Arc<AtomicBool>,
141}
142
143impl BinActorHandle {
144 pub async fn send_bin(&self, bin: FrequencyBin) -> Result<()> {
146 self.input_tx
147 .send(bin)
148 .await
149 .map_err(|e| AudioFftError::kernel(format!("Failed to send bin data: {}", e)))
150 }
151
152 pub async fn receive_separated(&mut self) -> Option<SeparatedBin> {
154 self.output_rx.recv().await
155 }
156
157 pub fn state(&self) -> BinActorState {
159 self.state.read().clone()
160 }
161
162 pub fn kernel_id(&self) -> &KernelId {
164 &self.kernel_id
165 }
166
167 pub fn is_running(&self) -> bool {
169 self.running.load(Ordering::Relaxed)
170 }
171
172 pub fn stop(&self) {
174 self.running.store(false, Ordering::Relaxed);
175 }
176}
177
178pub struct BinActor {
180 bin_index: u32,
182 #[allow(dead_code)]
184 total_bins: u32,
185 #[allow(dead_code)]
187 kernel_id: KernelId,
188 state: Arc<RwLock<BinActorState>>,
190 endpoint: K2KEndpoint,
192 left_neighbor_id: Option<KernelId>,
194 right_neighbor_id: Option<KernelId>,
196 input_rx: mpsc::Receiver<FrequencyBin>,
198 output_tx: mpsc::Sender<SeparatedBin>,
200 analyzer: CoherenceAnalyzer,
202 config: SeparationConfig,
204 running: Arc<AtomicBool>,
206 frame_counter: AtomicU64,
208}
209
210impl BinActor {
211 pub fn new(
213 bin_index: u32,
214 total_bins: u32,
215 broker: &Arc<K2KBroker>,
216 config: SeparationConfig,
217 ) -> (Self, BinActorHandle) {
218 let kernel_id = KernelId::new(format!("bin_actor_{}", bin_index));
219 let endpoint = broker.register(kernel_id.clone());
220
221 let state = Arc::new(RwLock::new(BinActorState::new(bin_index)));
222 let running = Arc::new(AtomicBool::new(true));
223
224 let (input_tx, input_rx) = mpsc::channel(64);
225 let (output_tx, output_rx) = mpsc::channel(64);
226
227 let handle_endpoint =
229 broker.register(KernelId::new(format!("bin_actor_{}_handle", bin_index)));
230
231 let handle = BinActorHandle {
232 bin_index,
233 kernel_id: kernel_id.clone(),
234 endpoint: handle_endpoint,
235 state: state.clone(),
236 input_tx,
237 output_rx,
238 running: running.clone(),
239 };
240
241 let actor = Self {
242 bin_index,
243 total_bins,
244 kernel_id,
245 state,
246 endpoint,
247 left_neighbor_id: None,
248 right_neighbor_id: None,
249 input_rx,
250 output_tx,
251 analyzer: CoherenceAnalyzer::new(config.clone()),
252 config,
253 running,
254 frame_counter: AtomicU64::new(0),
255 };
256
257 (actor, handle)
258 }
259
260 pub fn set_neighbors(&mut self, left: Option<KernelId>, right: Option<KernelId>) {
262 self.left_neighbor_id = left;
263 self.right_neighbor_id = right;
264 }
265
266 pub async fn run(&mut self) -> Result<()> {
268 info!("Bin actor {} starting", self.bin_index);
269
270 while self.running.load(Ordering::Relaxed) {
271 let bin = match tokio::time::timeout(
273 std::time::Duration::from_millis(100),
274 self.input_rx.recv(),
275 )
276 .await
277 {
278 Ok(Some(bin)) => bin,
279 Ok(None) => {
280 break;
282 }
283 Err(_) => {
284 continue;
286 }
287 };
288
289 trace!("Bin {} processing frame {}", self.bin_index, bin.frame_id);
290
291 {
293 let mut state = self.state.write();
294 state.update(&bin);
295 }
296
297 self.send_neighbor_data().await?;
299
300 self.receive_neighbor_data().await?;
302
303 let separated = self.compute_separation();
305
306 if self.output_tx.send(separated).await.is_err() {
308 warn!("Output channel closed for bin {}", self.bin_index);
309 break;
310 }
311
312 self.frame_counter.fetch_add(1, Ordering::Relaxed);
313 }
314
315 info!("Bin actor {} stopped", self.bin_index);
316 Ok(())
317 }
318
319 async fn send_neighbor_data(&mut self) -> Result<()> {
321 let neighbor_data = self.state.read().to_neighbor_data();
322
323 if let Some(left_id) = &self.left_neighbor_id {
325 let envelope = MessageEnvelope::new(
326 &neighbor_data,
327 self.bin_index as u64,
328 (self.bin_index - 1) as u64,
329 HlcTimestamp::now(self.bin_index as u64),
330 );
331
332 match self.endpoint.send(left_id.clone(), envelope).await {
333 Ok(receipt) if receipt.status == DeliveryStatus::Delivered => {
334 trace!("Sent to left neighbor {}", left_id);
335 }
336 Ok(receipt) => {
337 trace!("Left neighbor delivery status: {:?}", receipt.status);
338 }
339 Err(e) => {
340 trace!("Failed to send to left neighbor: {}", e);
341 }
342 }
343 }
344
345 if let Some(right_id) = &self.right_neighbor_id {
347 let envelope = MessageEnvelope::new(
348 &neighbor_data,
349 self.bin_index as u64,
350 (self.bin_index + 1) as u64,
351 HlcTimestamp::now(self.bin_index as u64),
352 );
353
354 match self.endpoint.send(right_id.clone(), envelope).await {
355 Ok(receipt) if receipt.status == DeliveryStatus::Delivered => {
356 trace!("Sent to right neighbor {}", right_id);
357 }
358 Ok(receipt) => {
359 trace!("Right neighbor delivery status: {:?}", receipt.status);
360 }
361 Err(e) => {
362 trace!("Failed to send to right neighbor: {}", e);
363 }
364 }
365 }
366
367 Ok(())
368 }
369
370 async fn receive_neighbor_data(&mut self) -> Result<()> {
372 let has_left = self.left_neighbor_id.is_some();
373 let has_right = self.right_neighbor_id.is_some();
374
375 let timeout = std::time::Duration::from_millis(10);
377 let deadline = std::time::Instant::now() + timeout;
378
379 while std::time::Instant::now() < deadline {
380 match self.endpoint.try_receive() {
381 Some(k2k_msg) => {
382 if let Ok(neighbor_data) = NeighborData::deserialize(&k2k_msg.envelope.payload)
384 {
385 let is_left = neighbor_data.source_bin < self.bin_index;
386 let mut state = self.state.write();
387 state.set_neighbor(neighbor_data, is_left);
388
389 if state.has_all_neighbors(has_left, has_right) {
390 break;
391 }
392 }
393 }
394 None => {
395 tokio::task::yield_now().await;
397 }
398 }
399 }
400
401 Ok(())
402 }
403
404 fn compute_separation(&mut self) -> SeparatedBin {
406 let state = self.state.read();
407
408 let (coherence, transient) = self.analyzer.analyze(
410 &state.current_value,
411 state.left_neighbor.as_ref(),
412 state.right_neighbor.as_ref(),
413 state.phase_derivative,
414 state.spectral_flux,
415 );
416
417 drop(state);
419 {
420 let mut state = self.state.write();
421 state.coherence = coherence;
422 state.smoothed_coherence = state.smoothed_coherence * self.config.temporal_smoothing
423 + coherence * (1.0 - self.config.temporal_smoothing);
424 }
425
426 let state = self.state.read();
427 let smoothed = state.smoothed_coherence;
428
429 let direct_ratio = smoothed.powf(self.config.separation_curve);
431 let ambient_ratio = 1.0 - direct_ratio;
432
433 let direct = state.current_value.scale(direct_ratio);
434 let ambience = state.current_value.scale(ambient_ratio);
435
436 SeparatedBin::new(
437 state.current_frame,
438 self.bin_index,
439 direct,
440 ambience,
441 smoothed,
442 transient,
443 )
444 }
445}
446
447pub struct BinNetwork {
449 num_bins: usize,
451 broker: Arc<K2KBroker>,
453 handles: Vec<BinActorHandle>,
455 tasks: Vec<tokio::task::JoinHandle<Result<()>>>,
457 #[allow(dead_code)]
459 config: SeparationConfig,
460 running: Arc<AtomicBool>,
462}
463
464impl BinNetwork {
465 pub async fn new(num_bins: usize, config: SeparationConfig) -> Result<Self> {
467 info!("Creating bin network with {} bins", num_bins);
468
469 let broker = K2KBuilder::new()
470 .max_pending_messages(num_bins * 4)
471 .delivery_timeout_ms(100)
472 .build();
473
474 let mut actors: Vec<BinActor> = Vec::with_capacity(num_bins);
475 let mut handles: Vec<BinActorHandle> = Vec::with_capacity(num_bins);
476
477 for i in 0..num_bins {
479 let (actor, handle) = BinActor::new(i as u32, num_bins as u32, &broker, config.clone());
480 actors.push(actor);
481 handles.push(handle);
482 }
483
484 for (i, actor) in actors.iter_mut().enumerate() {
486 let left = if i > 0 {
487 Some(KernelId::new(format!("bin_actor_{}", i - 1)))
488 } else {
489 None
490 };
491 let right = if i < num_bins - 1 {
492 Some(KernelId::new(format!("bin_actor_{}", i + 1)))
493 } else {
494 None
495 };
496 actor.set_neighbors(left, right);
497 }
498
499 let running = Arc::new(AtomicBool::new(true));
500
501 let mut tasks = Vec::with_capacity(num_bins);
503 for mut actor in actors {
504 let task = tokio::spawn(async move { actor.run().await });
505 tasks.push(task);
506 }
507
508 Ok(Self {
509 num_bins,
510 broker,
511 handles,
512 tasks,
513 config,
514 running,
515 })
516 }
517
518 pub fn num_bins(&self) -> usize {
520 self.num_bins
521 }
522
523 pub fn get_handle(&self, bin_index: usize) -> Option<&BinActorHandle> {
525 self.handles.get(bin_index)
526 }
527
528 pub async fn send_bins(&self, bins: &[FrequencyBin]) -> Result<()> {
530 for (i, bin) in bins.iter().enumerate() {
531 if i < self.handles.len() {
532 self.handles[i].send_bin(bin.clone()).await?;
533 }
534 }
535 Ok(())
536 }
537
538 pub async fn receive_separated(&mut self) -> Result<Vec<SeparatedBin>> {
540 let mut results = Vec::with_capacity(self.num_bins);
541
542 for handle in &mut self.handles {
543 if let Some(separated) = handle.receive_separated().await {
544 results.push(separated);
545 }
546 }
547
548 results.sort_by_key(|b| b.bin_index);
550
551 Ok(results)
552 }
553
554 pub async fn process_frame(
556 &mut self,
557 frame_id: u64,
558 bins: &[Complex],
559 sample_rate: u32,
560 fft_size: usize,
561 ) -> Result<Vec<SeparatedBin>> {
562 let freq_bins: Vec<FrequencyBin> = bins
564 .iter()
565 .enumerate()
566 .map(|(i, &value)| {
567 let frequency_hz = i as f32 * sample_rate as f32 / fft_size as f32;
568 FrequencyBin::new(frame_id, i as u32, bins.len() as u32, value, frequency_hz)
569 })
570 .collect();
571
572 self.send_bins(&freq_bins).await?;
574
575 self.receive_separated().await
577 }
578
579 pub async fn stop(&mut self) -> Result<()> {
581 info!("Stopping bin network");
582 self.running.store(false, Ordering::Relaxed);
583
584 for handle in &self.handles {
585 handle.stop();
586 }
587
588 for task in self.tasks.drain(..) {
590 let _ = task.await;
591 }
592
593 Ok(())
594 }
595
596 pub fn k2k_stats(&self) -> K2KStats {
598 self.broker.stats()
599 }
600}
601
602impl Drop for BinNetwork {
603 fn drop(&mut self) {
604 self.running.store(false, Ordering::Relaxed);
605 for handle in &self.handles {
606 handle.stop();
607 }
608 }
609}
610
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bin_network_creation() {
        let config = SeparationConfig::default();
        let mut network = BinNetwork::new(16, config).await.unwrap();

        assert_eq!(network.num_bins(), 16);
        assert!(network.get_handle(0).is_some());
        assert!(network.get_handle(15).is_some());
        assert!(network.get_handle(16).is_none());

        let stats = network.k2k_stats();
        assert!(stats.registered_endpoints >= 16);

        // Join the spawned actors explicitly instead of leaking them past the test.
        network.stop().await.unwrap();
        assert!(!network.get_handle(0).unwrap().is_running());
    }

    #[test]
    fn test_bin_actor_state() {
        let mut state = BinActorState::new(5);
        assert_eq!(state.bin_index, 5);
        assert_eq!(state.coherence, 0.5);

        let bin = FrequencyBin::new(1, 5, 1024, Complex::new(1.0, 0.0), 440.0);
        state.update(&bin);

        assert_eq!(state.current_frame, 1);
        assert!((state.current_value.magnitude() - 1.0).abs() < 1e-6);

        // Neighbor bookkeeping.
        assert!(state.left_neighbor.is_none());
        assert!(state.has_all_neighbors(false, false));
        assert!(!state.has_all_neighbors(true, false));

        let data = state.to_neighbor_data();
        assert_eq!(data.source_bin, 5);
        assert_eq!(data.frame_id, 1);
        state.set_neighbor(data, true);
        assert!(state.has_all_neighbors(true, false));
        assert!(!state.has_all_neighbors(true, true));

        // A louder frame yields positive flux and clears last frame's neighbors.
        let louder = FrequencyBin::new(2, 5, 1024, Complex::new(3.0, 0.0), 440.0);
        state.update(&louder);
        assert!(state.left_neighbor.is_none());
        assert!(state.spectral_flux > 0.0);

        // A quieter frame is rectified to zero flux.
        let quieter = FrequencyBin::new(3, 5, 1024, Complex::new(0.5, 0.0), 440.0);
        state.update(&quieter);
        assert_eq!(state.spectral_flux, 0.0);
    }
}