use crate::{Actor, ActorBehavior, Message, Port};
use anyhow::{Error, Result};
use reflow_actor::{
message::EncodableValue,
stream::{spawn_stream_task, stream_collect, StreamFrame},
ActorContext,
};
use reflow_actor_macro::actor;
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
#[actor(
CorrelatorActor,
inports::<100>(stream_a, stream_b),
outports::<50>(stream, stats, error),
state(MemoryState)
)]
pub async fn correlator_actor(context: ActorContext) -> Result<HashMap<String, Message>, Error> {
let config = context.get_config_hashmap();
let sample_rate = config
.get("sampleRate")
.and_then(|v| v.as_f64())
.unwrap_or(44100.0);
let max_lag_ms = config
.get("maxLagMs")
.and_then(|v| v.as_f64())
.unwrap_or(100.0);
let rx_a = match context.take_stream_receiver("stream_a") {
Some(rx) => rx,
None => return Ok(error_output("No StreamHandle on stream_a port")),
};
let rx_b = match context.take_stream_receiver("stream_b") {
Some(rx) => rx,
None => return Ok(error_output("No StreamHandle on stream_b port")),
};
let (tx, handle) =
context.create_stream("stream", Some("audio/correlation".to_string()), None, None);
let (stats_tx, stats_rx) = flume::bounded::<serde_json::Value>(1);
spawn_stream_task(async move {
let result_a = stream_collect(rx_a).await;
let result_b = stream_collect(rx_b).await;
let (samples_a, samples_b) = match (result_a, result_b) {
(Ok((_, _, bytes_a)), Ok((_, _, bytes_b))) => {
let a: Vec<f32> = bytes_a
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
let b: Vec<f32> = bytes_b
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
(a, b)
}
_ => {
let _ = tx
.send_async(StreamFrame::Error("Failed to collect streams".to_string()))
.await;
return;
}
};
let max_lag = (max_lag_ms * sample_rate / 1000.0) as usize;
let len = samples_a.len().min(samples_b.len());
if len == 0 {
let _ = tx
.send_async(StreamFrame::Error("Empty input streams".to_string()))
.await;
return;
}
let rms_a: f32 = (samples_a[..len].iter().map(|s| s * s).sum::<f32>() / len as f32).sqrt();
let rms_b: f32 = (samples_b[..len].iter().map(|s| s * s).sum::<f32>() / len as f32).sqrt();
let norm = rms_a * rms_b * len as f32;
let mut correlations: Vec<f32> = Vec::with_capacity(max_lag * 2 + 1);
let mut best_lag: i32 = 0;
let mut best_corr: f32 = f32::NEG_INFINITY;
for lag in -(max_lag as i32)..=(max_lag as i32) {
let mut sum: f32 = 0.0;
let mut count = 0;
for i in 0..len {
let j = i as i32 + lag;
if j >= 0 && (j as usize) < len {
sum += samples_a[i] * samples_b[j as usize];
count += 1;
}
}
let corr = if norm > 0.0 && count > 0 {
sum / norm
} else {
0.0
};
correlations.push(corr);
if corr > best_corr {
best_corr = corr;
best_lag = lag;
}
}
let _ = tx
.send_async(StreamFrame::Begin {
content_type: Some("audio/correlation".to_string()),
size_hint: None,
metadata: Some(json!({
"maxLag": max_lag,
"sampleRate": sample_rate,
})),
})
.await;
let bytes: Vec<u8> = correlations.iter().flat_map(|c| c.to_le_bytes()).collect();
let _ = tx.send_async(StreamFrame::Data(Arc::new(bytes))).await;
let _ = tx.send_async(StreamFrame::End).await;
let delay_ms = best_lag as f64 / sample_rate * 1000.0;
let _ = stats_tx.send(json!({
"bestLagSamples": best_lag,
"bestLagMs": delay_ms,
"peakCorrelation": best_corr,
"signalLength": len,
}));
});
let stats_value = stats_rx.recv_async().await.unwrap_or(json!({}));
let mut results = HashMap::new();
results.insert("stream".to_string(), Message::stream_handle(handle));
results.insert(
"stats".to_string(),
Message::object(EncodableValue::from(stats_value)),
);
Ok(results)
}
fn error_output(msg: &str) -> HashMap<String, Message> {
let mut out = HashMap::new();
out.insert("error".to_string(), Message::Error(msg.to_string().into()));
out
}