use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use parking_lot::RwLock;
use tokio::sync::mpsc;
use tracing::{info, trace, warn};
use ringkernel_core::k2k::K2KStats;
use ringkernel_core::prelude::*;
use crate::error::{AudioFftError, Result};
use crate::messages::{Complex, FrequencyBin, NeighborData, SeparatedBin};
use crate::separation::{CoherenceAnalyzer, SeparationConfig};
/// Per-bin analysis state owned by a single bin actor.
///
/// Shared between the actor task and its `BinActorHandle` behind an
/// `Arc<RwLock<_>>`, so the handle can take snapshots of it via
/// `BinActorHandle::state`.
#[derive(Debug, Clone)]
pub struct BinActorState {
/// Index of the frequency bin this state tracks.
pub bin_index: u32,
/// `frame_id` of the most recently applied `FrequencyBin` (0 before any frame).
pub current_frame: u64,
/// Latest complex spectral value (`Complex::default()` before any frame).
pub current_value: Complex,
/// Value that `current_value` held before the latest `update`; `None` only
/// until the first `update`.
pub prev_value: Option<Complex>,
/// Neighbour data received from the lower-index bin for the current frame;
/// cleared on every `update`.
pub left_neighbor: Option<NeighborData>,
/// Neighbour data received from the higher-index bin for the current frame;
/// cleared on every `update`.
pub right_neighbor: Option<NeighborData>,
/// Raw coherence from the most recent analysis (starts at 0.5).
pub coherence: f32,
/// Temporally smoothed coherence used to split the bin into direct/ambient
/// parts (starts at 0.5).
pub smoothed_coherence: f32,
/// Phase change between the last two values, wrapped into `[-π, π]`.
pub phase_derivative: f32,
/// Magnitude increase between the last two values, clamped at 0.
pub spectral_flux: f32,
}
impl BinActorState {
    /// Creates a fresh state for `bin_index`.
    ///
    /// No frame, previous value or neighbour data has been seen yet. Both
    /// coherence values start at `0.5` and the temporal features start at `0.0`.
    pub fn new(bin_index: u32) -> Self {
        Self {
            bin_index,
            current_frame: 0,
            current_value: Complex::default(),
            prev_value: None,
            left_neighbor: None,
            right_neighbor: None,
            coherence: 0.5,
            smoothed_coherence: 0.5,
            phase_derivative: 0.0,
            spectral_flux: 0.0,
        }
    }

    /// Advances the state to the frame carried by `bin`.
    ///
    /// The old `current_value` moves into `prev_value`. Two temporal features are
    /// then derived from the old and new values:
    /// * `phase_derivative`: the phase change, wrapped into `[-π, π]`;
    /// * `spectral_flux`: the magnitude increase, clamped at zero so decreases give 0.
    ///
    /// Both neighbour slots are cleared; they are refilled through `set_neighbor`.
    ///
    /// NOTE(review): `prev_value` is always `Some` after this call, including the
    /// first call, where it holds the initial `Complex::default()`. The first
    /// frame's features are therefore measured against that default value.
    pub fn update(&mut self, bin: &FrequencyBin) {
        let previous = std::mem::replace(&mut self.current_value, bin.value);
        self.prev_value = Some(previous);
        self.current_frame = bin.frame_id;

        let raw_delta = self.current_value.phase() - previous.phase();
        self.phase_derivative = Self::wrap_phase(raw_delta);

        let rise = self.current_value.magnitude() - previous.magnitude();
        self.spectral_flux = rise.max(0.0);

        self.left_neighbor = None;
        self.right_neighbor = None;
    }

    /// Folds a phase difference back into `[-π, π]` by whole turns of `2π`.
    fn wrap_phase(mut delta: f32) -> f32 {
        use std::f32::consts::PI;
        while delta > PI {
            delta -= 2.0 * PI;
        }
        while delta < -PI {
            delta += 2.0 * PI;
        }
        delta
    }

    /// Stores `data` in the left slot when `is_left`, otherwise in the right slot,
    /// replacing any previous value.
    pub fn set_neighbor(&mut self, data: NeighborData, is_left: bool) {
        let slot = if is_left {
            &mut self.left_neighbor
        } else {
            &mut self.right_neighbor
        };
        *slot = Some(data);
    }

    /// Returns `true` once every expected neighbour has been received.
    ///
    /// `has_left` / `has_right` state whether a neighbour exists on that side.
    /// A side without a neighbour never blocks completion.
    pub fn has_all_neighbors(&self, has_left: bool, has_right: bool) -> bool {
        let left_missing = has_left && self.left_neighbor.is_none();
        let right_missing = has_right && self.right_neighbor.is_none();
        !(left_missing || right_missing)
    }

    /// Packages this bin's current value and features for sending to neighbours.
    pub fn to_neighbor_data(&self) -> NeighborData {
        let value = self.current_value;
        NeighborData {
            source_bin: self.bin_index,
            frame_id: self.current_frame,
            magnitude: value.magnitude(),
            phase: value.phase(),
            value,
            phase_derivative: self.phase_derivative,
            spectral_flux: self.spectral_flux,
        }
    }
}
/// Caller-side handle for a bin actor.
///
/// Feeds input bins into the actor, receives its separated output, takes
/// snapshots of its shared state and controls its shared `running` flag.
pub struct BinActorHandle {
/// Index of the bin served by the paired actor.
pub bin_index: u32,
/// K2K id of the paired actor (`bin_actor_{bin_index}`), not of this handle.
kernel_id: KernelId,
// Registered as `bin_actor_{bin_index}_handle` in `BinActor::new` but never read
// in this file. Presumably kept so the registration lives as long as the
// handle; TODO confirm, then document or remove.
#[allow(dead_code)]
endpoint: K2KEndpoint,
/// State shared with the actor task.
state: Arc<RwLock<BinActorState>>,
/// Sending half of the actor's bounded (64-slot) input channel.
input_tx: mpsc::Sender<FrequencyBin>,
/// Receiving half of the actor's bounded (64-slot) output channel.
output_rx: mpsc::Receiver<SeparatedBin>,
/// Stop flag shared with the actor; its run loop exits once it observes `false`.
running: Arc<AtomicBool>,
}
impl BinActorHandle {
pub async fn send_bin(&self, bin: FrequencyBin) -> Result<()> {
self.input_tx
.send(bin)
.await
.map_err(|e| AudioFftError::kernel(format!("Failed to send bin data: {}", e)))
}
pub async fn receive_separated(&mut self) -> Option<SeparatedBin> {
self.output_rx.recv().await
}
pub fn state(&self) -> BinActorState {
self.state.read().clone()
}
pub fn kernel_id(&self) -> &KernelId {
&self.kernel_id
}
pub fn is_running(&self) -> bool {
self.running.load(Ordering::Relaxed)
}
pub fn stop(&self) {
self.running.store(false, Ordering::Relaxed);
}
}
/// Actor that processes one frequency bin.
///
/// For every incoming frame it does three things:
/// 1. updates its `BinActorState`;
/// 2. exchanges `NeighborData` with adjacent bins over K2K;
/// 3. emits a `SeparatedBin` that splits the bin into direct and ambient parts.
pub struct BinActor {
/// Index of the bin this actor processes.
bin_index: u32,
// Stored at construction but not read in this file.
#[allow(dead_code)]
total_bins: u32,
// K2K id this actor registered under; not read after construction in this file.
#[allow(dead_code)]
kernel_id: KernelId,
/// State shared with the paired `BinActorHandle`.
state: Arc<RwLock<BinActorState>>,
/// This actor's K2K endpoint, used to send to and receive from neighbours.
endpoint: K2KEndpoint,
/// K2K id of the left neighbour (index `i - 1` as wired by `BinNetwork::new`), if any.
left_neighbor_id: Option<KernelId>,
/// K2K id of the right neighbour (index `i + 1` as wired by `BinNetwork::new`), if any.
right_neighbor_id: Option<KernelId>,
/// Incoming frequency bins sent by the handle.
input_rx: mpsc::Receiver<FrequencyBin>,
/// Outgoing separated bins read by the handle.
output_tx: mpsc::Sender<SeparatedBin>,
/// Coherence/transient analyser built from `config`.
analyzer: CoherenceAnalyzer,
/// Separation tuning; `temporal_smoothing` and `separation_curve` are read here.
config: SeparationConfig,
/// Stop flag shared with the handle.
running: Arc<AtomicBool>,
/// Number of frames fully processed; incremented in `run`, not read in this file.
frame_counter: AtomicU64,
}
impl BinActor {
    /// Creates a bin actor and the [`BinActorHandle`] used to drive it.
    ///
    /// Registers two K2K endpoints with `broker`:
    /// * `bin_actor_{bin_index}` for the actor (the id neighbours address it by);
    /// * `bin_actor_{bin_index}_handle` for the handle.
    ///
    /// Input and output channels are bounded to 64 messages each. Neighbour links
    /// start unset; wire them with [`BinActor::set_neighbors`] before calling
    /// [`BinActor::run`].
    pub fn new(
        bin_index: u32,
        total_bins: u32,
        broker: &Arc<K2KBroker>,
        config: SeparationConfig,
    ) -> (Self, BinActorHandle) {
        let kernel_id = KernelId::new(format!("bin_actor_{}", bin_index));
        let endpoint = broker.register(kernel_id.clone());
        let state = Arc::new(RwLock::new(BinActorState::new(bin_index)));
        let running = Arc::new(AtomicBool::new(true));
        let (input_tx, input_rx) = mpsc::channel(64);
        let (output_tx, output_rx) = mpsc::channel(64);
        let handle_endpoint =
            broker.register(KernelId::new(format!("bin_actor_{}_handle", bin_index)));
        let handle = BinActorHandle {
            bin_index,
            kernel_id: kernel_id.clone(),
            endpoint: handle_endpoint,
            state: state.clone(),
            input_tx,
            output_rx,
            running: running.clone(),
        };
        let actor = Self {
            bin_index,
            total_bins,
            kernel_id,
            state,
            endpoint,
            left_neighbor_id: None,
            right_neighbor_id: None,
            input_rx,
            output_tx,
            analyzer: CoherenceAnalyzer::new(config.clone()),
            config,
            running,
            frame_counter: AtomicU64::new(0),
        };
        (actor, handle)
    }

    /// Sets the K2K ids of the left and right neighbour bins.
    ///
    /// `None` means there is no neighbour on that side, as for edge bins.
    pub fn set_neighbors(&mut self, left: Option<KernelId>, right: Option<KernelId>) {
        self.left_neighbor_id = left;
        self.right_neighbor_id = right;
    }

    /// Main processing loop.
    ///
    /// Each iteration waits up to 100 ms for an input bin and re-checks the
    /// running flag on timeout. For each bin it then:
    /// 1. updates the state;
    /// 2. exchanges neighbour data;
    /// 3. computes the separation and forwards it to the handle.
    ///
    /// The loop exits when the running flag is cleared, the input channel
    /// closes, or the output channel closes.
    ///
    /// # Errors
    /// Propagates errors from the neighbour exchange; both helpers currently
    /// always return `Ok`.
    pub async fn run(&mut self) -> Result<()> {
        info!("Bin actor {} starting", self.bin_index);
        while self.running.load(Ordering::Relaxed) {
            let bin = match tokio::time::timeout(
                std::time::Duration::from_millis(100),
                self.input_rx.recv(),
            )
            .await
            {
                Ok(Some(bin)) => bin,
                // All senders dropped: nobody can feed this actor any more.
                Ok(None) => {
                    break;
                }
                // Timed out: loop around to re-check the running flag.
                Err(_) => {
                    continue;
                }
            };
            trace!("Bin {} processing frame {}", self.bin_index, bin.frame_id);
            {
                let mut state = self.state.write();
                state.update(&bin);
            }
            self.send_neighbor_data().await?;
            self.receive_neighbor_data().await?;
            let separated = self.compute_separation();
            if self.output_tx.send(separated).await.is_err() {
                warn!("Output channel closed for bin {}", self.bin_index);
                break;
            }
            self.frame_counter.fetch_add(1, Ordering::Relaxed);
        }
        info!("Bin actor {} stopped", self.bin_index);
        Ok(())
    }

    /// Publishes this bin's current [`NeighborData`] to each configured neighbour.
    ///
    /// Delivery is best-effort: send errors and non-`Delivered` statuses are only
    /// traced, never propagated, so this currently always returns `Ok`.
    async fn send_neighbor_data(&mut self) -> Result<()> {
        let neighbor_data = self.state.read().to_neighbor_data();
        // Do the ±1 arithmetic in u64. `set_neighbors` is public, so bin 0 can still
        // be given a left neighbour, and `bin_index - 1` in u32 would underflow
        // (a panic in debug builds). For in-range indices the values are unchanged.
        let source = u64::from(self.bin_index);

        if let Some(left_id) = &self.left_neighbor_id {
            let envelope = MessageEnvelope::new(
                &neighbor_data,
                source,
                source.saturating_sub(1),
                HlcTimestamp::now(source),
            );
            match self.endpoint.send(left_id.clone(), envelope).await {
                Ok(receipt) if receipt.status == DeliveryStatus::Delivered => {
                    trace!("Sent to left neighbor {}", left_id);
                }
                Ok(receipt) => {
                    trace!("Left neighbor delivery status: {:?}", receipt.status);
                }
                Err(e) => {
                    trace!("Failed to send to left neighbor: {}", e);
                }
            }
        }

        if let Some(right_id) = &self.right_neighbor_id {
            let envelope = MessageEnvelope::new(
                &neighbor_data,
                source,
                source + 1,
                HlcTimestamp::now(source),
            );
            match self.endpoint.send(right_id.clone(), envelope).await {
                Ok(receipt) if receipt.status == DeliveryStatus::Delivered => {
                    trace!("Sent to right neighbor {}", right_id);
                }
                Ok(receipt) => {
                    trace!("Right neighbor delivery status: {:?}", receipt.status);
                }
                Err(e) => {
                    trace!("Failed to send to right neighbor: {}", e);
                }
            }
        }
        Ok(())
    }

    /// Collects neighbour data for the current frame, polling for up to 10 ms.
    ///
    /// Returns as soon as every configured neighbour has reported, or when the
    /// deadline passes. A neighbour that is still missing then stays `None` for
    /// this frame.
    ///
    /// * Payloads that fail to deserialize are skipped.
    /// * Payloads from an earlier frame are discarded. Such a payload arrived
    ///   after a previous 10 ms window expired, and must not be mistaken for
    ///   this frame's data.
    /// * Always returns `Ok`.
    async fn receive_neighbor_data(&mut self) -> Result<()> {
        let has_left = self.left_neighbor_id.is_some();
        let has_right = self.right_neighbor_id.is_some();
        // With no neighbours (e.g. a single-bin network) there is nothing to wait
        // for; don't spin for the whole window.
        if !has_left && !has_right {
            return Ok(());
        }

        let deadline = std::time::Instant::now() + std::time::Duration::from_millis(10);
        while std::time::Instant::now() < deadline {
            match self.endpoint.try_receive() {
                Some(k2k_msg) => {
                    let neighbor_data =
                        match NeighborData::deserialize(&k2k_msg.envelope.payload) {
                            Ok(data) => data,
                            Err(_) => continue,
                        };
                    // The guard is scoped to this arm, so it is never held across
                    // the `.await` below.
                    let mut state = self.state.write();
                    if neighbor_data.frame_id < state.current_frame {
                        trace!(
                            "Bin {} dropping stale neighbor data from bin {} (frame {} < {})",
                            self.bin_index,
                            neighbor_data.source_bin,
                            neighbor_data.frame_id,
                            state.current_frame
                        );
                        continue;
                    }
                    let is_left = neighbor_data.source_bin < self.bin_index;
                    state.set_neighbor(neighbor_data, is_left);
                    if state.has_all_neighbors(has_left, has_right) {
                        break;
                    }
                }
                None => tokio::task::yield_now().await,
            }
        }
        Ok(())
    }

    /// Runs coherence analysis and splits the current value into direct and
    /// ambient components.
    ///
    /// Steps:
    /// 1. Update `coherence` and the exponentially smoothed coherence, using
    ///    `config.temporal_smoothing` as the weight on the old value.
    /// 2. Compute `direct_ratio = smoothed^separation_curve`.
    /// 3. Give the remainder `1 - direct_ratio` to the ambient part.
    ///
    /// A single write guard is held throughout, so handle readers never see a
    /// half-updated state.
    fn compute_separation(&mut self) -> SeparatedBin {
        let mut state = self.state.write();
        let (coherence, transient) = self.analyzer.analyze(
            &state.current_value,
            state.left_neighbor.as_ref(),
            state.right_neighbor.as_ref(),
            state.phase_derivative,
            state.spectral_flux,
        );

        let smoothing = self.config.temporal_smoothing;
        state.coherence = coherence;
        state.smoothed_coherence =
            state.smoothed_coherence * smoothing + coherence * (1.0 - smoothing);

        let smoothed = state.smoothed_coherence;
        let direct_ratio = smoothed.powf(self.config.separation_curve);
        let ambient_ratio = 1.0 - direct_ratio;
        let direct = state.current_value.scale(direct_ratio);
        let ambience = state.current_value.scale(ambient_ratio);
        SeparatedBin::new(
            state.current_frame,
            self.bin_index,
            direct,
            ambience,
            smoothed,
            transient,
        )
    }
}
/// A chain of bin actors, one per FFT bin, connected through a shared K2K broker.
pub struct BinNetwork {
/// Number of bins (and actors) in the network.
num_bins: usize,
/// Broker through which all actors and handles are registered.
broker: Arc<K2KBroker>,
/// One handle per actor, indexed by bin.
handles: Vec<BinActorHandle>,
/// Join handles of the spawned actor tasks; drained by `stop`.
tasks: Vec<tokio::task::JoinHandle<Result<()>>>,
// Kept for reference; not read after construction in this file.
#[allow(dead_code)]
config: SeparationConfig,
/// Network-level flag. It is NOT shared with the actors; each actor has its own
/// flag, reached through its handle.
running: Arc<AtomicBool>,
}
impl BinNetwork {
    /// Builds a network of `num_bins` bin actors wired as a 1-D chain, and spawns
    /// each actor on its own tokio task.
    ///
    /// Bin `i` exchanges neighbour data with bins `i - 1` and `i + 1`, where those
    /// exist, over a shared K2K broker.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    pub async fn new(num_bins: usize, config: SeparationConfig) -> Result<Self> {
        info!("Creating bin network with {} bins", num_bins);
        let broker = K2KBuilder::new()
            .max_pending_messages(num_bins * 4)
            .delivery_timeout_ms(100)
            .build();

        let mut actors: Vec<BinActor> = Vec::with_capacity(num_bins);
        let mut handles: Vec<BinActorHandle> = Vec::with_capacity(num_bins);
        for i in 0..num_bins {
            let (actor, handle) = BinActor::new(i as u32, num_bins as u32, &broker, config.clone());
            actors.push(actor);
            handles.push(handle);
        }

        for (i, actor) in actors.iter_mut().enumerate() {
            // The closures only run when the guard holds, so `i - 1` cannot underflow.
            let left = (i > 0).then(|| KernelId::new(format!("bin_actor_{}", i - 1)));
            let right = (i + 1 < num_bins).then(|| KernelId::new(format!("bin_actor_{}", i + 1)));
            actor.set_neighbors(left, right);
        }

        let running = Arc::new(AtomicBool::new(true));
        let mut tasks = Vec::with_capacity(num_bins);
        for mut actor in actors {
            let task = tokio::spawn(async move { actor.run().await });
            tasks.push(task);
        }

        Ok(Self {
            num_bins,
            broker,
            handles,
            tasks,
            config,
            running,
        })
    }

    /// Number of bins (actors) in the network.
    pub fn num_bins(&self) -> usize {
        self.num_bins
    }

    /// Handle for bin `bin_index`, or `None` if the index is out of range.
    pub fn get_handle(&self, bin_index: usize) -> Option<&BinActorHandle> {
        self.handles.get(bin_index)
    }

    /// Sends `bins[i]` to actor `i`. Bins beyond the network size are ignored.
    ///
    /// # Errors
    /// Fails if an actor's input channel has been closed.
    pub async fn send_bins(&self, bins: &[FrequencyBin]) -> Result<()> {
        for (handle, bin) in self.handles.iter().zip(bins) {
            handle.send_bin(bin.clone()).await?;
        }
        Ok(())
    }

    /// Awaits one separated output from every actor and returns them sorted by
    /// bin index.
    ///
    /// Actors whose output channel has closed are skipped.
    pub async fn receive_separated(&mut self) -> Result<Vec<SeparatedBin>> {
        let count = self.handles.len();
        self.receive_from_first(count).await
    }

    /// Awaits one output from each of the first `count` handles and returns them
    /// sorted by bin index. Closed channels are skipped.
    async fn receive_from_first(&mut self, count: usize) -> Result<Vec<SeparatedBin>> {
        let mut results = Vec::with_capacity(count);
        for handle in self.handles.iter_mut().take(count) {
            if let Some(separated) = handle.receive_separated().await {
                results.push(separated);
            }
        }
        results.sort_by_key(|b| b.bin_index);
        Ok(results)
    }

    /// Pushes one FFT frame through the network and returns the separated bins.
    ///
    /// Bin `i` is tagged with centre frequency `i * sample_rate / fft_size`.
    /// Only actors that actually received a bin are awaited. A frame shorter
    /// than the network therefore returns fewer outputs instead of blocking
    /// forever on idle actors.
    ///
    /// NOTE(review): `fft_size == 0` yields non-finite frequencies; callers are
    /// presumably expected never to pass 0.
    ///
    /// # Errors
    /// Fails if an actor's input channel has been closed.
    pub async fn process_frame(
        &mut self,
        frame_id: u64,
        bins: &[Complex],
        sample_rate: u32,
        fft_size: usize,
    ) -> Result<Vec<SeparatedBin>> {
        let freq_bins: Vec<FrequencyBin> = bins
            .iter()
            .enumerate()
            .map(|(i, &value)| {
                let frequency_hz = i as f32 * sample_rate as f32 / fft_size as f32;
                FrequencyBin::new(frame_id, i as u32, bins.len() as u32, value, frequency_hz)
            })
            .collect();
        let fed = freq_bins.len().min(self.handles.len());
        self.send_bins(&freq_bins).await?;
        self.receive_from_first(fed).await
    }

    /// Stops every actor and waits for all actor tasks to finish.
    ///
    /// Task results (including panics) are ignored.
    pub async fn stop(&mut self) -> Result<()> {
        info!("Stopping bin network");
        self.running.store(false, Ordering::Relaxed);
        for handle in &mut self.handles {
            handle.stop();
            // An actor parked in `output_tx.send(..)` on a full output channel never
            // re-checks its stop flag. Closing the receiver fails that send, so the
            // actor exits and the joins below cannot hang. Buffered outputs can
            // still be drained afterwards.
            handle.output_rx.close();
        }
        for task in self.tasks.drain(..) {
            let _ = task.await;
        }
        Ok(())
    }

    /// Snapshot of the shared K2K broker's statistics.
    pub fn k2k_stats(&self) -> K2KStats {
        self.broker.stats()
    }
}
impl Drop for BinNetwork {
fn drop(&mut self) {
self.running.store(false, Ordering::Relaxed);
for handle in &self.handles {
handle.stop();
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_bin_network_creation() {
        let config = SeparationConfig::default();
        let network = BinNetwork::new(16, config).await.unwrap();
        assert_eq!(network.num_bins(), 16);
        let stats = network.k2k_stats();
        assert!(stats.registered_endpoints >= 16);
    }

    #[test]
    fn test_bin_actor_state() {
        let mut state = BinActorState::new(5);
        assert_eq!(state.bin_index, 5);
        assert_eq!(state.coherence, 0.5);
        let bin = FrequencyBin::new(1, 5, 1024, Complex::new(1.0, 0.0), 440.0);
        state.update(&bin);
        assert_eq!(state.current_frame, 1);
        assert!((state.current_value.magnitude() - 1.0).abs() < 1e-6);
    }

    /// Neighbour slots fill independently, satisfy `has_all_neighbors` only for
    /// the sides that are expected, and are cleared by the next `update`.
    #[test]
    fn test_neighbor_bookkeeping() {
        let mut state = BinActorState::new(3);
        assert!(state.has_all_neighbors(false, false));
        assert!(!state.has_all_neighbors(true, true));

        state.set_neighbor(BinActorState::new(2).to_neighbor_data(), true);
        assert!(state.has_all_neighbors(true, false));
        assert!(!state.has_all_neighbors(true, true));

        state.set_neighbor(BinActorState::new(4).to_neighbor_data(), false);
        assert!(state.has_all_neighbors(true, true));
        assert_eq!(state.left_neighbor.as_ref().map(|n| n.source_bin), Some(2));
        assert_eq!(state.right_neighbor.as_ref().map(|n| n.source_bin), Some(4));

        let bin = FrequencyBin::new(1, 3, 1024, Complex::new(1.0, 0.0), 440.0);
        state.update(&bin);
        assert!(state.left_neighbor.is_none());
        assert!(state.right_neighbor.is_none());
        assert!(!state.has_all_neighbors(true, true));
    }

    /// Spectral flux counts only magnitude increases; decreases clamp to zero.
    #[test]
    fn test_spectral_flux_only_counts_rising_energy() {
        let mut state = BinActorState::new(0);
        state.update(&FrequencyBin::new(0, 0, 8, Complex::new(2.0, 0.0), 0.0));
        state.update(&FrequencyBin::new(1, 0, 8, Complex::new(1.0, 0.0), 0.0));
        assert_eq!(state.spectral_flux, 0.0);
        state.update(&FrequencyBin::new(2, 0, 8, Complex::new(3.0, 0.0), 0.0));
        assert!((state.spectral_flux - 2.0).abs() < 1e-5);
        assert_eq!(state.current_frame, 2);
    }
}