use itertools::iproduct;
use nalgebra::{DMatrix, DVector, Matrix2xX, SVector, Vector2};
use rustfft::{
num_complex::{Complex, ComplexFloat},
Fft, FftPlanner,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Sample rate (Hz) the whole pipeline is configured for.
pub const SAMPLE_RATE: f32 = 44100.;
/// Analysis frames per second.
pub const FPS: f32 = 100.;
/// Samples between consecutive frames; `FramedSignalProcessor::new` asserts this equals
/// `SAMPLE_RATE / FPS` rounded.
const HOP_SIZE: usize = 441;
/// Samples per analysis frame; also the FFT length planned by the STFT stage.
const FRAME_SIZE: usize = 2048;
/// Number of FFT magnitude bins kept per frame (non-negative frequencies only).
const SPECTROGRAM_SIZE: usize = FRAME_SIZE / 2;
// Filterbank: one triangular filter per distinct center bin for every note strictly between
// FILTER_MIN_NOTE and FILTER_MAX_NOTE (note numbers as used by `note2freq`, where 69 = 440 Hz).
// `gen_filterbank` asserts that exactly N_FILTERS filters result.
const FILTER_MIN_NOTE: i32 = 23; const FILTER_MAX_NOTE: i32 = 132; pub const N_FILTERS: usize = 81;
/// Added to filterbank outputs before `log10`, so silence maps to 0 instead of -inf.
const LOG_OFFSET: f32 = 1.;
/// Slowest tempo (beats per minute) the HMM can represent.
const MIN_BPM: f32 = 55.;
/// Fastest tempo (beats per minute) the HMM can represent.
const MAX_BPM: f32 = 215.;
/// Shortest beat period in frames (corresponds to MAX_BPM; checked in `BeatStateSpace::new`).
const MIN_INTERVAL: usize = 28;
/// Longest beat period in frames (corresponds to MIN_BPM; checked in `BeatStateSpace::new`).
const MAX_INTERVAL: usize = 109;
/// Total HMM states: a beat period of `k` frames owns `k` states, so this is the sum of
/// `k` for `k` in `MIN_INTERVAL..=MAX_INTERVAL`.
const N_STATES: usize = ((MAX_INTERVAL + 1) * MAX_INTERVAL - (MIN_INTERVAL - 1) * MIN_INTERVAL) / 2;
/// States in the first `1 / OBSERVATION_LAMBDA` of a beat period are "beat" states; the
/// no-beat density is additionally divided by `OBSERVATION_LAMBDA - 1` (see `om_densities`).
const OBSERVATION_LAMBDA: f32 = 16.;
/// Tempo-change penalty: transition weight is `exp(-TRANSITION_LAMBDA * |ratio - 1|)`.
const TRANSITION_LAMBDA: f32 = 100.;
/// f64 machine epsilon narrowed to f32; tempo transitions whose weight is not above this
/// are dropped from the transition model.
const PROBABILITY_EPSILON: f32 = f64::EPSILON as f32;
/// Below this length of the expected phase vector the tracker resets its tracked-beat count.
const MIN_CONFIDENCE_RADIUS: f32 = 0.8;
/// Number of confidently tracked phase wraps required before beats are reported.
const MIN_TRACKED_BEATS: usize = 3;
/// Filters with index below this are averaged into the "low" level.
const FILTER_LOW_CUTOFF: usize = 7;
/// Filters with index above this are averaged into the "high" level; the rest are "mid".
const FILTER_HIGH_CUTOFF: usize = 70;
/// Gain applied to filterbank values before spectrum smoothing.
const SPECTRUM_GAIN: f32 = 0.2;
/// Smoothing weight of new spectrum values when they rise.
const SPECTRUM_UP_ALPHA: f32 = 0.8;
/// Smoothing weight of new spectrum values when they fall.
const SPECTRUM_DOWN_ALPHA: f32 = 0.1;
/// Smoothing weight for rising low/mid/high levels.
const LEVEL_FREQ_UP_ALPHA: f32 = 0.9;
/// Smoothing weight for falling low/mid/high levels.
const LEVEL_FREQ_DOWN_ALPHA: f32 = 0.1;
/// Smoothing weight for a rising overall level.
const LEVEL_OVERALL_UP_ALPHA: f32 = 0.9;
/// Smoothing weight for a falling overall level (slow decay).
const LEVEL_OVERALL_DOWN_ALPHA: f32 = 0.01;
/// Smoothed audio levels reported once per analysis frame by `LevelsProcessor`.
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct AudioLevels {
/// Smoothed mean of the low filterbank bands (filter index below `FILTER_LOW_CUTOFF`).
pub low: f32,
/// Smoothed mean of the mid filterbank bands.
pub mid: f32,
/// Smoothed mean of the high filterbank bands (filter index above `FILTER_HIGH_CUTOFF`).
pub high: f32,
/// Smoothed maximum of `low`, `mid` and `high`.
pub level: f32,
}
mod ml_models;
use ml_models::*;
/// Splits a stream of 16-bit samples into overlapping frames of `FRAME_SIZE` samples,
/// emitting one frame every `HOP_SIZE` input samples.
///
/// Every sample is written twice into a `2 * FRAME_SIZE` buffer (at `p` and
/// `p + FRAME_SIZE`). As a result, the most recent `FRAME_SIZE` samples are always one
/// contiguous slice starting at the write pointer, ordered oldest first.
struct FramedSignalProcessor {
    buffer: Vec<i16>,
    write_pointer: usize,
    hop_counter: i32,
}

impl FramedSignalProcessor {
    /// Starting hop counter: the first frame is emitted after `FRAME_SIZE / 2` samples.
    /// Its first half is zero padding and its second half starts with sample 0.
    const INITIAL_HOP_COUNTER: i32 = (HOP_SIZE as i32) - ((FRAME_SIZE / 2) as i32);

    /// Creates a processor with a zeroed buffer.
    ///
    /// # Panics
    /// If `HOP_SIZE` does not match `SAMPLE_RATE / FPS`.
    pub fn new() -> Self {
        assert_eq!(
            HOP_SIZE,
            (SAMPLE_RATE / FPS).round() as usize,
            "Incorrect HOP_SIZE setting"
        );
        Self {
            buffer: vec![0_i16; FRAME_SIZE * 2],
            write_pointer: 0,
            hop_counter: Self::INITIAL_HOP_COUNTER,
        }
    }

    /// Returns the processor to the state produced by [`Self::new`].
    ///
    /// The sample buffer is zeroed too. Without that, the first frames after a reset
    /// would contain audio recorded before the reset where a fresh processor yields zeros.
    pub fn reset(&mut self) {
        self.buffer.fill(0);
        self.write_pointer = 0;
        self.hop_counter = Self::INITIAL_HOP_COUNTER;
    }

    /// Appends `samples` and returns every frame completed by them (possibly none),
    /// each `FRAME_SIZE` samples long, oldest sample first.
    pub fn process(&mut self, samples: &[i16]) -> Vec<DVector<i16>> {
        let mut frames = Vec::with_capacity(samples.len() / HOP_SIZE + 1);
        for &sample in samples {
            // Mirror the sample so the window starting at `write_pointer` is contiguous.
            self.buffer[self.write_pointer] = sample;
            self.buffer[self.write_pointer + FRAME_SIZE] = sample;
            self.write_pointer = (self.write_pointer + 1) % FRAME_SIZE;

            self.hop_counter += 1;
            debug_assert!(self.hop_counter <= HOP_SIZE as i32);
            if self.hop_counter == HOP_SIZE as i32 {
                self.hop_counter = 0;
                frames.push(DVector::from_column_slice(
                    &self.buffer[self.write_pointer..self.write_pointer + FRAME_SIZE],
                ));
            }
        }
        frames
    }
}
/// Windowed-FFT stage: turns one frame of 16-bit samples into a magnitude spectrum.
///
/// `window` is a Hann window pre-scaled by `1 / i16::MAX`, and `fft` is a forward FFT of
/// length `FRAME_SIZE`; both are built in `new`.
struct ShortTimeFourierTransformProcessor {
window: DVector<f32>, fft: Arc<dyn Fft<f32>>,
}
/// Symmetric Hann window of length `m`, evaluated at sample `n`:
/// `0.5 - 0.5 * cos(2π n / (m - 1))`.
fn hann(n: usize, m: usize) -> f32 {
    let phase = std::f32::consts::TAU * n as f32 / (m as f32 - 1.);
    0.5 - 0.5 * phase.cos()
}
impl ShortTimeFourierTransformProcessor {
pub fn new() -> Self {
let window = DVector::from_fn(FRAME_SIZE, |i, _| {
hann(i, FRAME_SIZE) * (1_f32 / (i16::MAX as f32))
});
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(FRAME_SIZE);
Self { window, fft }
}
pub fn process(
&mut self,
frame: &DVector<i16>, ) -> DVector<f32> {
let mut buffer = vec![
Complex {
re: 0_f32,
im: 0_f32
};
FRAME_SIZE
];
for i in 0..FRAME_SIZE {
buffer[i].re = (frame[i] as f32) * self.window[i];
}
self.fft.process(&mut buffer);
DVector::from_fn(SPECTROGRAM_SIZE, |i, _| buffer[i].abs())
}
}
/// Builds a triangular filter over `SPECTROGRAM_SIZE` bins.
///
/// The filter rises linearly over `start+1..=center` (reaching 1 at `center`) and falls
/// over `center+1..stop`. It is normalized so that its weights sum to 1, with the sum
/// taken over the whole triangle, even bins outside the spectrogram that are not stored.
///
/// # Panics
/// Unless `start < center < stop <= SPECTROGRAM_SIZE`.
fn triangle_filter(start: i32, center: i32, stop: i32) -> DVector<f32> {
    assert!(start < center);
    assert!(center < stop);
    assert!(stop <= SPECTROGRAM_SIZE as i32);
    let in_range = |bin: i32| (0..(SPECTROGRAM_SIZE as i32)).contains(&bin);
    let mut filter = DVector::<f32>::zeros(SPECTROGRAM_SIZE);
    let mut total = 0_f32;
    for bin in (start + 1)..stop {
        let weight = if bin <= center {
            (bin as f32 - start as f32) / (center as f32 - start as f32)
        } else {
            (bin as f32 - stop as f32) / (center as f32 - stop as f32)
        };
        if in_range(bin) {
            filter[bin as usize] = weight;
        }
        total += weight;
    }
    for bin in ((start + 1)..stop).filter(|&b| in_range(b)) {
        filter[bin as usize] /= total;
    }
    filter
}
/// Converts a note number to its frequency in Hz (equal temperament, note 69 = 440 Hz).
fn note2freq(note: i32) -> f32 {
    let semitones_from_a4 = note as f32 - 69.;
    440. * 2_f32.powf(semitones_from_a4 / 12.)
}
/// Center frequency in Hz of each of the `SPECTROGRAM_SIZE` FFT bins.
fn spectrogram_frequencies() -> DVector<f32> {
    let fft_length = (SPECTROGRAM_SIZE * 2) as f32;
    DVector::from_iterator(
        SPECTROGRAM_SIZE,
        (0..SPECTROGRAM_SIZE).map(|bin| bin as f32 * SAMPLE_RATE / fft_length),
    )
}
/// Returns the index of the bin whose frequency is closest to `freq`.
///
/// Frequencies at or above the last bin map to `SPECTROGRAM_SIZE - 1`. An exact tie
/// between two neighbouring bins resolves to the upper one.
fn freq2bin(
    spectrogram_frequencies: &DVector<f32>,
    freq: f32,
) -> usize {
    let first_above = (1..SPECTROGRAM_SIZE).find(|&bin| freq < spectrogram_frequencies[bin]);
    match first_above {
        None => SPECTROGRAM_SIZE - 1,
        Some(bin) => {
            let below = spectrogram_frequencies[bin - 1];
            let above = spectrogram_frequencies[bin];
            if (freq - below).abs() < (freq - above).abs() {
                bin - 1
            } else {
                bin
            }
        }
    }
}
/// Builds the `N_FILTERS x SPECTROGRAM_SIZE` triangular filterbank.
///
/// For every note strictly between `FILTER_MIN_NOTE` and `FILTER_MAX_NOTE`, a filter is
/// centered on the note's nearest bin and spans from the bins of its neighbouring notes.
/// A note is skipped if its center bin equals the previous filter's center. When the
/// neighbouring bins are less than two apart, the filter collapses to
/// `center - 1 .. center + 1` (a single-bin filter).
///
/// # Panics
/// If the number of generated filters differs from `N_FILTERS`.
pub fn gen_filterbank() -> DMatrix<f32> {
    let freqs = spectrogram_frequencies();
    let bin_of = |note: i32| freq2bin(&freqs, note2freq(note)) as i32;
    let mut filterbank = DMatrix::<f32>::zeros(N_FILTERS, SPECTROGRAM_SIZE);
    let mut n_filters = 0_usize;
    let mut last_center: Option<i32> = None;
    for note in (FILTER_MIN_NOTE + 1)..FILTER_MAX_NOTE {
        let center = bin_of(note);
        if last_center == Some(center) {
            continue;
        }
        let (lower, upper) = (bin_of(note - 1), bin_of(note + 1));
        let (start, stop) = if upper - lower < 2 {
            (center - 1, center + 1)
        } else {
            (lower, upper)
        };
        filterbank.set_row(n_filters, &triangle_filter(start, center, stop).transpose());
        n_filters += 1;
        last_center = Some(center);
    }
    assert_eq!(n_filters, N_FILTERS, "N_FILTERS set incorrectly");
    filterbank
}
/// Applies the triangular filterbank to a magnitude spectrum and compresses the result:
/// `log10(filterbank · spectrum + LOG_OFFSET)`, giving `N_FILTERS` values per frame.
struct FilteredSpectrogramProcessor {
    filterbank: DMatrix<f32>,
}

impl FilteredSpectrogramProcessor {
    pub fn new() -> Self {
        let filterbank = gen_filterbank();
        Self { filterbank }
    }

    pub fn process(&mut self, spectrogram: &DVector<f32>) -> DVector<f32> {
        let mut bands = &self.filterbank * spectrogram;
        for band in bands.iter_mut() {
            *band = (*band + LOG_OFFSET).log10();
        }
        bands
    }
}
/// Stacks the filtered spectrum with its positive (half-wave rectified) difference to
/// the previous frame, producing `2 * N_FILTERS` values: `[filtered; max(0, diff)]`.
struct SpectrogramDifferenceProcessor {
    /// Filtered spectrum of the previous frame; `None` before the first frame.
    prev: Option<DVector<f32>>,
}

impl SpectrogramDifferenceProcessor {
    pub fn new() -> Self {
        Self { prev: None }
    }

    /// Forgets the previous frame, so the next difference is all zeros.
    pub fn reset(&mut self) {
        self.prev = None;
    }

    pub fn process(&mut self, filtered_data: &DVector<f32>) -> DVector<f32> {
        // The first frame is compared with itself, which yields a zero difference.
        let positive_diff = {
            let reference = self.prev.as_ref().unwrap_or(filtered_data);
            (filtered_data - reference).map(|d| 0_f32.max(d))
        };
        self.prev = Some(filtered_data.clone());

        let mut stacked = DVector::<f32>::zeros(2 * N_FILTERS);
        stacked
            .rows_range_mut(0..N_FILTERS)
            .copy_from_slice(filtered_data.as_slice());
        stacked
            .rows_range_mut(N_FILTERS..2 * N_FILTERS)
            .copy_from_slice(positive_diff.as_slice());
        stacked
    }
}
/// State space of the beat-tracking HMM.
///
/// Each beat period (interval) `k` in `MIN_INTERVAL..=MAX_INTERVAL` frames owns `k`
/// consecutive states, one per frame of the period. Intervals are laid out in
/// increasing order.
struct BeatStateSpace {
    /// Position of every state within its beat period: `offset / interval`, in `[0, 1)`.
    state_positions: DVector<f32>,
    /// Beat interval (in frames) each state belongs to.
    state_intervals: DVector<usize>,
    /// First state of every interval, in interval order.
    first_states: Vec<usize>,
    /// Last state of every interval, in interval order.
    last_states: Vec<usize>,
}

impl BeatStateSpace {
    /// Builds the state space.
    ///
    /// # Panics
    /// If `MIN_INTERVAL` / `MAX_INTERVAL` disagree with `MAX_BPM` / `MIN_BPM` at `FPS`.
    fn new() -> Self {
        assert_eq!(
            MIN_INTERVAL,
            (60. * FPS / MAX_BPM).round() as usize,
            "MIN_INTERVAL set incorrectly"
        );
        assert_eq!(
            MAX_INTERVAL,
            (60. * FPS / MIN_BPM).round() as usize,
            "MAX_INTERVAL set incorrectly"
        );

        // Number of states owned by all intervals shorter than `interval`.
        let offset =
            |interval: usize| ((interval - 1) * interval - (MIN_INTERVAL - 1) * MIN_INTERVAL) / 2;
        let first_states: Vec<usize> = (MIN_INTERVAL..=MAX_INTERVAL).map(offset).collect();
        let last_states: Vec<usize> = (MIN_INTERVAL..=MAX_INTERVAL)
            .map(|interval| offset(interval + 1) - 1)
            .collect();

        let mut state_positions = DVector::<f32>::zeros(N_STATES);
        let mut state_intervals = DVector::<usize>::zeros(N_STATES);
        for ((interval, &first), &last) in (MIN_INTERVAL..=MAX_INTERVAL)
            .zip(&first_states)
            .zip(&last_states)
        {
            let inverse_size = 1. / (last + 1 - first) as f32;
            for (offset_in_beat, state) in (first..=last).enumerate() {
                state_positions[state] = offset_in_beat as f32 * inverse_size;
                state_intervals[state] = interval;
            }
        }

        Self {
            state_positions,
            state_intervals,
            first_states,
            last_states,
        }
    }
}
/// Builds the HMM transition matrix in CSR form, indexed by *destination* state.
///
/// The source states of `s` are `indices[indptr[s]..indptr[s + 1]]` (sorted ascending),
/// with matching probabilities in `data`.
/// - Within a beat period, a state is reached from its predecessor with probability 1.
/// - The first state of an interval is reached from the last state of every interval.
///   The weight is `exp(-TRANSITION_LAMBDA * |to/from - 1|)`; weights not above
///   `PROBABILITY_EPSILON` are dropped, and the rest are normalized per source state.
fn transition_model(
    ss: &BeatStateSpace,
) -> (
    Vec<usize>, // pointers (indptr)
    Vec<usize>, // states (indices)
    Vec<f32>,   // probabilities (data)
) {
    // Incoming (source state, probability) pairs for every destination state.
    let mut incoming: Vec<Vec<(usize, f32)>> = vec![Vec::new(); N_STATES];

    // Inside a beat period the phase advances deterministically by one state per frame.
    for state in (0..N_STATES).filter(|s| !ss.first_states.contains(s)) {
        incoming[state].push((state - 1, 1.));
    }

    // Tempo may only change at a beat boundary: last state of a period -> first state.
    let tempo_entries: Vec<(usize, usize, f32)> =
        iproduct!(ss.last_states.iter(), ss.first_states.iter())
            .map(|(&from_state, &to_state)| {
                let ratio = ss.state_intervals[to_state] as f32
                    / ss.state_intervals[from_state] as f32;
                (
                    from_state,
                    to_state,
                    (-TRANSITION_LAMBDA * (ratio - 1.).abs()).exp(),
                )
            })
            .filter(|&(_, _, prob)| prob > PROBABILITY_EPSILON)
            .collect();

    // Normalize each source state's outgoing tempo transitions to sum to one.
    let mut row_sums = vec![0_f32; N_STATES];
    for &(from_state, _, prob) in &tempo_entries {
        row_sums[from_state] += prob;
    }
    for (from_state, to_state, prob) in tempo_entries {
        let prob = prob * (1. / row_sums[from_state]);
        assert!((0. ..=1.).contains(&prob));
        incoming[to_state].push((from_state, prob));
    }

    // Flatten into CSR arrays.
    let mut indptr = Vec::with_capacity(N_STATES + 1);
    let mut indices = Vec::new();
    let mut data = Vec::new();
    indptr.push(0);
    for mut sources in incoming {
        sources.sort_by_key(|&(source, _)| source);
        for (source, prob) in sources {
            indices.push(source);
            data.push(prob);
        }
        indptr.push(indices.len());
    }
    (indptr, indices, data)
}
/// Maps each state to an index into `om_densities`' output.
///
/// States in the first `1 / OBSERVATION_LAMBDA` of their beat period map to 1 (beat);
/// all other states map to 0 (no beat).
fn observation_model(ss: &BeatStateSpace) -> DVector<usize> {
    let border = 1. / OBSERVATION_LAMBDA;
    ss.state_positions
        .map(|position| usize::from(position < border))
}
/// Maps every state to the unit-circle point of its beat phase.
///
/// Row 0 holds `cos(2π·position)` and row 1 holds `sin(2π·position)`, one column per
/// state. Multiplying by a state distribution yields the expected phase vector.
fn result_model(ss: &BeatStateSpace) -> Matrix2xX<f32> {
    let positions = &ss.state_positions;
    Matrix2xX::<f32>::from_fn(positions.len(), |row, col| {
        let angle = std::f32::consts::TAU * positions[col];
        if row == 0 {
            angle.cos()
        } else {
            angle.sin()
        }
    })
}
fn om_densities(observation: f32) -> SVector<f32, 2> {
let p_no_beat = (1. - observation) / (OBSERVATION_LAMBDA - 1.);
let p_beat = observation;
SVector::<f32, 2>::from([p_no_beat, p_beat])
}
/// Uniform prior over all `N_STATES` HMM states.
fn hmm_initial_distribution() -> DVector<f32> {
    let uniform = 1. / N_STATES as f32;
    DVector::from_fn(N_STATES, |_, _| uniform)
}
/// Online (forward-only) HMM beat tracker.
///
/// Each frame, the forward probabilities over the `BeatStateSpace` are updated with the
/// network activation. The expected phase vector is then formed by multiplying the
/// distribution with `result_model`; its angle is the estimated beat phase and its
/// length a confidence measure. A beat is reported when the phase wraps from the upper
/// half of the circle back past zero and at least `MIN_TRACKED_BEATS` such wraps have
/// already been counted.
struct HMMBeatTrackingProcessor {
    /// Not read after construction; retained alongside the models derived from it.
    _ss: BeatStateSpace,
    // Transition model in CSR form indexed by destination state: the sources of `state`
    // are `tm_states[tm_pointers[state]..tm_pointers[state + 1]]`.
    tm_pointers: Vec<usize>,
    tm_states: Vec<usize>,
    tm_probabilities: Vec<f32>,
    /// Per state, the index (0 = no beat, 1 = beat) into `om_densities`' output.
    om_pointers: DVector<usize>,
    /// Normalized forward probabilities of the previous frame.
    hmm_fwd_prev: DVector<f32>,
    result_model: Matrix2xX<f32>,
    /// Phase angle of the previous frame, in `[0, 2π)`.
    last_theta: f32,
    /// Phase wraps counted since confidence last dropped below `MIN_CONFIDENCE_RADIUS`.
    tracking: usize,
}

impl HMMBeatTrackingProcessor {
    pub fn new() -> Self {
        let ss = BeatStateSpace::new();
        let (tm_pointers, tm_states, tm_probabilities) = transition_model(&ss);
        let om_pointers = observation_model(&ss);
        let result_model = result_model(&ss);
        Self {
            _ss: ss,
            tm_pointers,
            tm_states,
            tm_probabilities,
            om_pointers,
            hmm_fwd_prev: hmm_initial_distribution(),
            result_model,
            last_theta: 0.,
            tracking: 0,
        }
    }

    /// Restores the uniform prior and clears the phase/tracking state.
    pub fn reset(&mut self) {
        self.hmm_fwd_prev = hmm_initial_distribution();
        self.last_theta = 0.;
        self.tracking = 0;
    }

    /// One step of the HMM forward algorithm; returns (and stores) the normalized
    /// state distribution after observing `observation`.
    fn hmm_forward(&mut self, observation: f32) -> DVector<f32> {
        let densities = om_densities(observation);
        let mut fwd = DVector::<f32>::zeros(N_STATES);
        let mut prob_sum: f32 = 0.;
        for state in 0..N_STATES {
            let mut fwd_state: f32 = 0.;
            for prev_pointer in self.tm_pointers[state]..self.tm_pointers[state + 1] {
                fwd_state += self.hmm_fwd_prev[self.tm_states[prev_pointer]]
                    * self.tm_probabilities[prev_pointer];
            }
            fwd_state *= densities[self.om_pointers[state]];
            fwd[state] = fwd_state;
            prob_sum += fwd_state;
        }
        if prob_sum > 0. && prob_sum.is_finite() {
            fwd *= 1. / prob_sum;
        } else {
            // No usable probability mass is left. One density is exactly 0 when the
            // activation is exactly 0.0 or 1.0, and a NaN activation poisons the sum.
            // Normalizing would divide by zero and leave NaN in the belief for every
            // later frame, so restart from the uninformative prior instead.
            fwd = hmm_initial_distribution();
        }
        self.hmm_fwd_prev.copy_from(&fwd);
        fwd
    }

    /// Feeds one frame's beat activation; returns `true` if a beat is reported.
    pub fn process(&mut self, activation: f32) -> bool {
        let state_probabilities = self.hmm_forward(activation);
        let result: Vector2<f32> = &self.result_model * &state_probabilities;
        let r = result.norm();
        let theta = result.y.atan2(result.x).rem_euclid(std::f32::consts::TAU);
        let mut beat = false;
        if r < MIN_CONFIDENCE_RADIUS {
            self.tracking = 0;
        }
        // A drop of more than half a turn means the phase wrapped past zero.
        if self.last_theta - theta > 0.5 * std::f32::consts::TAU {
            if self.tracking >= MIN_TRACKED_BEATS {
                beat = true;
            } else {
                self.tracking += 1;
            }
        }
        self.last_theta = theta;
        beat
    }
}
/// Produces a smoothed per-filter spectrum and low/mid/high/overall audio levels from
/// the log-filtered spectrum, once per analysis frame.
pub struct LevelsProcessor {
    /// Smoothed `SPECTRUM_GAIN`-scaled filterbank values, one per filter.
    spectrum: DVector<f32>,
    audio: AudioLevels,
}

impl Default for LevelsProcessor {
    fn default() -> Self {
        Self::new()
    }
}

impl LevelsProcessor {
    pub fn new() -> Self {
        Self {
            spectrum: DVector::<f32>::zeros(N_FILTERS),
            audio: Default::default(),
        }
    }

    /// Clears all smoothing state back to zero.
    pub fn reset(&mut self) {
        self.spectrum = DVector::<f32>::zeros(N_FILTERS);
        self.audio = Default::default();
    }

    /// Asymmetric one-pole smoother: moves `stored` toward `value` with weight `up_alpha`
    /// when the value rises and `down_alpha` when it does not.
    fn diode_lpf(stored: &mut f32, value: f32, up_alpha: f32, down_alpha: f32) {
        let alpha = if value > *stored { up_alpha } else { down_alpha };
        *stored = value * alpha + *stored * (1. - alpha);
    }

    /// Updates the smoothers with one frame of filterbank output (`N_FILTERS` values)
    /// and returns copies of the smoothed spectrum and levels.
    pub fn process(&mut self, filtered: &DVector<f32>) -> (DVector<f32>, AudioLevels) {
        for i in 0..N_FILTERS {
            Self::diode_lpf(
                &mut self.spectrum[i],
                filtered[i] * SPECTRUM_GAIN,
                SPECTRUM_UP_ALPHA,
                SPECTRUM_DOWN_ALPHA,
            );
        }

        // Band means: low = [0, LOW_CUTOFF), high = (HIGH_CUTOFF, N_FILTERS), mid = rest.
        let mut low = 0_f32;
        let mut mid = 0_f32;
        let mut high = 0_f32;
        let mut n_low = 0_usize;
        let mut n_mid = 0_usize;
        let mut n_high = 0_usize;
        for i in 0..N_FILTERS {
            if i < FILTER_LOW_CUTOFF {
                low += self.spectrum[i];
                n_low += 1;
            } else if i > FILTER_HIGH_CUTOFF {
                high += self.spectrum[i];
                n_high += 1;
            } else {
                mid += self.spectrum[i];
                n_mid += 1;
            }
        }
        assert!(n_low > 0);
        assert!(n_mid > 0);
        assert!(n_high > 0);
        low /= n_low as f32;
        mid /= n_mid as f32;
        high /= n_high as f32;

        Self::diode_lpf(&mut self.audio.low, low, LEVEL_FREQ_UP_ALPHA, LEVEL_FREQ_DOWN_ALPHA);
        Self::diode_lpf(&mut self.audio.mid, mid, LEVEL_FREQ_UP_ALPHA, LEVEL_FREQ_DOWN_ALPHA);
        Self::diode_lpf(&mut self.audio.high, high, LEVEL_FREQ_UP_ALPHA, LEVEL_FREQ_DOWN_ALPHA);

        let level = self.audio.low.max(self.audio.mid.max(self.audio.high));
        Self::diode_lpf(
            &mut self.audio.level,
            level,
            LEVEL_OVERALL_UP_ALPHA,
            LEVEL_OVERALL_DOWN_ALPHA,
        );

        (self.spectrum.clone(), self.audio.clone())
    }
}
/// Complete streaming beat-tracking pipeline.
///
/// The stages are applied per analysis frame, in this order:
/// 1. framing;
/// 2. windowed FFT;
/// 3. log-compressed filterbank;
/// 4. stacking with its positive frame difference;
/// 5. averaged neural-network ensemble activation;
/// 6. online HMM beat decision.
///
/// The filterbank output also feeds `LevelsProcessor` for the reported spectrum and
/// levels.
pub struct BeatTracker {
    framed_processor: FramedSignalProcessor,
    stft_processor: ShortTimeFourierTransformProcessor,
    filter_processor: FilteredSpectrogramProcessor,
    difference_processor: SpectrogramDifferenceProcessor,
    neural_networks: Vec<NeuralNetwork>,
    hmm: HMMBeatTrackingProcessor,
    levels: LevelsProcessor,
}

impl Default for BeatTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl BeatTracker {
    pub fn new() -> Self {
        Self {
            framed_processor: FramedSignalProcessor::new(),
            stft_processor: ShortTimeFourierTransformProcessor::new(),
            filter_processor: FilteredSpectrogramProcessor::new(),
            difference_processor: SpectrogramDifferenceProcessor::new(),
            neural_networks: models(),
            hmm: HMMBeatTrackingProcessor::new(),
            levels: LevelsProcessor::new(),
        }
    }

    /// Resets every stateful stage, including the framing, the previous spectrum, each
    /// network (via its `reset`), the HMM and the level smoothers.
    pub fn reset(&mut self) {
        self.framed_processor.reset();
        self.difference_processor.reset();
        for nn in &mut self.neural_networks {
            nn.reset();
        }
        self.hmm.reset();
        self.levels.reset();
    }

    /// Feeds 16-bit samples (presumably mono at `SAMPLE_RATE`) through the pipeline.
    ///
    /// Returns one entry per analysis frame completed by `samples`. Each entry holds the
    /// smoothed levels, the smoothed filterbank spectrum, the mean network activation
    /// and the beat decision.
    pub fn process(
        &mut self,
        samples: &[i16],
    ) -> Vec<(
        AudioLevels,
        DVector<f32>, // Smoothed filterbank spectrum: size = N_FILTERS
        f32,          // Activation
        bool,         // Beat
    )> {
        let frames = self.framed_processor.process(samples);
        frames
            .iter()
            .map(|frame| {
                let spectrogram = self.stft_processor.process(frame);
                let filtered = self.filter_processor.process(&spectrogram);
                let diff = self.difference_processor.process(&filtered);
                let ensemble_activations =
                    self.neural_networks.iter_mut().map(|nn| nn.process(&diff));
                let activation =
                    ensemble_activations.sum::<f32>() / self.neural_networks.len() as f32;
                let beat = self.hmm.process(activation);
                let (spectrum, audio) = self.levels.process(&filtered);
                (audio, spectrum, activation, beat)
            })
            .collect()
    }
}
mod layers {
    //! Minimal neural-network building blocks (feed-forward and LSTM layers), evaluated
    //! one time step at a time.
    use nalgebra::{DMatrix, DVector};

    /// Logistic sigmoid `1 / (1 + e^-x)`, evaluated as `(1 + tanh(x / 2)) / 2`.
    pub fn sigmoid(x: f32) -> f32 {
        let half = 0.5_f32;
        half * (1_f32 + (half * x).tanh())
    }

    /// Hyperbolic tangent as a free function, so it can be passed to `map`.
    pub fn tanh(x: f32) -> f32 {
        x.tanh()
    }

    /// Fully connected layer with sigmoid activation: `sigmoid(W·x + b)`.
    pub struct FeedForwardLayer {
        weights: DMatrix<f32>,
        bias: DVector<f32>,
    }

    impl FeedForwardLayer {
        pub fn new(weights: DMatrix<f32>, bias: DVector<f32>) -> Self {
            Self { weights, bias }
        }

        pub fn process(&self, data: &DVector<f32>) -> DVector<f32> {
            let pre_activation = &self.weights * data + &self.bias;
            pre_activation.map(sigmoid)
        }
    }

    /// LSTM gate with peephole connection:
    /// `sigmoid(W·x + p∘state + R·prev + b)`.
    pub struct Gate {
        weights: DMatrix<f32>,
        bias: DVector<f32>,
        recurrent_weights: DMatrix<f32>,
        peephole_weights: DVector<f32>,
    }

    impl Gate {
        pub fn new(
            weights: DMatrix<f32>,
            bias: DVector<f32>,
            recurrent_weights: DMatrix<f32>,
            peephole_weights: DVector<f32>,
        ) -> Self {
            Self {
                weights,
                bias,
                recurrent_weights,
                peephole_weights,
            }
        }

        pub fn process(
            &self,
            data: &DVector<f32>,
            prev: &DVector<f32>,
            state: &DVector<f32>,
        ) -> DVector<f32> {
            // Terms are accumulated in a fixed order; the unit tests pin exact f32 results.
            let mut activation = &self.weights * data;
            activation += state.component_mul(&self.peephole_weights);
            activation += &self.recurrent_weights * prev;
            activation += &self.bias;
            activation.map(sigmoid)
        }
    }

    /// LSTM cell input: `tanh(W·x + R·prev + b)`.
    pub struct Cell {
        weights: DMatrix<f32>,
        bias: DVector<f32>,
        recurrent_weights: DMatrix<f32>,
    }

    impl Cell {
        pub fn new(
            weights: DMatrix<f32>,
            bias: DVector<f32>,
            recurrent_weights: DMatrix<f32>,
        ) -> Self {
            Self {
                weights,
                bias,
                recurrent_weights,
            }
        }

        pub fn process(&self, data: &DVector<f32>, prev: &DVector<f32>) -> DVector<f32> {
            let mut activation = &self.weights * data;
            activation += &self.recurrent_weights * prev;
            activation += &self.bias;
            activation.map(tanh)
        }
    }

    /// Peephole LSTM layer that keeps its previous output and cell state between calls.
    pub struct LSTMLayer {
        prev: DVector<f32>,
        state: DVector<f32>,
        input_gate: Gate,
        forget_gate: Gate,
        cell: Cell,
        output_gate: Gate,
    }

    impl LSTMLayer {
        pub fn new(input_gate: Gate, forget_gate: Gate, cell: Cell, output_gate: Gate) -> Self {
            let output_size = input_gate.weights.nrows();
            Self {
                prev: DVector::zeros(output_size),
                state: DVector::zeros(output_size),
                input_gate,
                forget_gate,
                cell,
                output_gate,
            }
        }

        /// Zeroes the stored output and cell state.
        pub fn reset(&mut self) {
            self.prev.fill(0.);
            self.state.fill(0.);
        }

        /// Advances the layer by one time step and returns its output.
        pub fn process(&mut self, data: &DVector<f32>) -> DVector<f32> {
            let input = self.input_gate.process(data, &self.prev, &self.state);
            let forget = self.forget_gate.process(data, &self.prev, &self.state);
            let candidate = self.cell.process(data, &self.prev);
            let new_state =
                candidate.component_mul(&input) + self.state.component_mul(&forget);
            self.state.copy_from(&new_state);
            // The output gate's peephole sees the *updated* cell state.
            let output = self.output_gate.process(data, &self.prev, &self.state);
            let out = self.state.map(tanh).component_mul(&output);
            self.prev.copy_from(&out);
            out
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use nalgebra::{dmatrix, dvector};

        #[test]
        fn test_sigmoid() {
            assert_eq!(sigmoid(0.), 0.5);
            assert_eq!(sigmoid(2.), 0.8807971);
            assert_eq!(sigmoid(-4.), 0.017986208);
        }

        #[test]
        fn test_feed_forward_layer() {
            let weights = dmatrix!(1_f32, 1., 1.; 0., 1., 0.);
            let bias = dvector!(0_f32, 5.);
            let layer = FeedForwardLayer::new(weights, bias);
            let out = layer.process(&dvector!(0.5_f32, 0.6, 0.7));
            assert_eq!(out, dvector!(sigmoid(1.8), sigmoid(5.6)));
        }

        #[test]
        fn test_lstm_layer() {
            let weights = dmatrix!(2_f32, 3., -1.; 0., 2., 0.);
            let bias = dvector!(1_f32, 3.);
            let recurrent_weights = dmatrix!(0_f32, -1.; 1., 0.);
            let peephole_weights = dvector!(2_f32, 3.);
            let ig = Gate::new(weights, bias, recurrent_weights, peephole_weights);
            let weights = dmatrix!(-1_f32, 1., 0.; -1., 0., 2.);
            let bias = dvector!(-2_f32, -1.);
            let recurrent_weights = dmatrix!(2_f32, 0.; 0., 2.);
            let peephole_weights = dvector!(3_f32, -1.);
            let fg = Gate::new(weights, bias, recurrent_weights, peephole_weights);
            let weights = dmatrix!(1_f32, -1., 3.; 2., -3., -2.);
            let bias = dvector!(1_f32, 2.);
            let recurrent_weights = dmatrix!(2_f32, 0.; 0., 2.);
            let c = Cell::new(weights, bias, recurrent_weights);
            let weights = dmatrix!(1_f32, -1., 2.; 1., -3., 4.);
            let bias = dvector!(1_f32, -2.);
            let recurrent_weights = dmatrix!(2_f32, -2.; 3., -1.);
            let peephole_weights = dvector!(1_f32, 3.);
            let og = Gate::new(weights, bias, recurrent_weights, peephole_weights);
            let mut l = LSTMLayer::new(ig, fg, c, og);
            let input1 = dvector!(2_f32, 3., 4.);
            let input2 = dvector!(6_f32, 7., 6.);
            let input3 = dvector!(4_f32, 3., 2.);
            let output1 = l.process(&input1);
            let output2 = l.process(&input2);
            let output3 = l.process(&input3);
            assert_eq!(output1, dvector!(0.76148117, -0.74785006));
            assert_eq!(output2, dvector!(0.9619418, -0.94699216));
            assert_eq!(output3, dvector!(0.99459594, -0.4957872));
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_framed_signal_processor() {
        let audio_signal: Vec<i16> = (0_i16..2048).collect();
        let mut framed_signal_processor = FramedSignalProcessor::new();
        let frames = framed_signal_processor.process(&audio_signal[0..512]);
        assert_eq!(frames.len(), 0);
        let frames = framed_signal_processor.process(&audio_signal[512..1024]);
        assert_eq!(frames.len(), 1);
        assert_eq!(frames[0][0], 0);
        assert_eq!(frames[0][1023], 0);
        assert_eq!(frames[0][1024], 0);
        assert_eq!(frames[0][1025], 1);
        assert_eq!(frames[0][2047], 1023);
        let frames = framed_signal_processor.process(&audio_signal[1024..2048]);
        assert_eq!(frames.len(), 2);
        assert_eq!(frames[0][584], 1);
        assert_eq!(frames[0][1024], 441);
        assert_eq!(frames[0][2047], 1464);
        assert_eq!(frames[1][143], 1);
        assert_eq!(frames[1][1024], 882);
        assert_eq!(frames[1][2047], 1905);
    }

    #[test]
    fn test_hann() {
        assert_eq!(hann(0, 7), 0.);
        assert_eq!(hann(1, 7), 0.25);
        assert_eq!(hann(2, 7), 0.75);
        assert_eq!(hann(3, 7), 1.);
        assert_eq!(hann(4, 7), 0.74999994);
        assert_eq!(hann(5, 7), 0.24999982);
        assert_eq!(hann(6, 7), 0.);
    }

    #[test]
    fn test_stft_processor() {
        let audio_frame = DVector::from_fn(2048, |i, _| i as i16);
        let mut stft_processor = ShortTimeFourierTransformProcessor::new();
        let result = stft_processor.process(&audio_frame);
        assert_eq!(result[0], 31.96973);
        assert_eq!(result[1], 17.724249);
        assert_eq!(result[2], 1.7021317);
        assert_eq!(result[1023], 0.);
    }

    #[test]
    fn test_triangle_filter() {
        let filt = triangle_filter(5, 7, 15);
        assert_eq!(filt[5], 0.);
        assert_eq!(filt[6], 0.1);
        assert_eq!(filt[7], 0.2);
        assert_eq!(filt[8], 0.175);
        assert_eq!(filt[9], 0.15);
        assert_eq!(filt[10], 0.125);
        assert_eq!(filt[11], 0.1);
        assert_eq!(filt[12], 0.075);
        assert_eq!(filt[13], 0.05);
        assert_eq!(filt[14], 0.025);
        assert_eq!(filt[15], 0.);
    }

    #[test]
    fn test_note2freq() {
        assert_eq!(note2freq(69), 440.);
        assert_eq!(note2freq(57), 220.);
        assert_eq!(note2freq(60), 261.62555);
    }

    #[test]
    fn test_spectrogram_frequencies() {
        let freqs = spectrogram_frequencies();
        assert_eq!(freqs[0], 0.);
        assert_eq!(freqs[1], 21.533203);
        assert_eq!(freqs[1022], 22006.934);
        assert_eq!(freqs[1023], 22028.467);
    }

    #[test]
    fn test_freq2bin() {
        let freqs = spectrogram_frequencies();
        assert_eq!(freq2bin(&freqs, 0.), 0);
        assert_eq!(freq2bin(&freqs, 24000.), 1023);
        assert_eq!(freq2bin(&freqs, 440.), 20);
    }

    #[test]
    fn test_gen_filterbank() {
        let filterbank = gen_filterbank();
        assert_eq!(filterbank[(0, 2)], 1.);
        assert_eq!(filterbank[(1, 3)], 1.);
        assert_eq!(filterbank[(2, 4)], 1.);
        assert_eq!(filterbank[(78, 654)], 0.02631579);
        assert_eq!(filterbank[(79, 693)], 0.025);
        assert_eq!(filterbank[(80, 734)], 0.023529412);
    }

    #[test]
    fn test_spectrogram_difference_processor() {
        let mut data = DVector::<f32>::zeros(N_FILTERS);
        let mut proc = SpectrogramDifferenceProcessor::new();
        data[0] = 1.;
        let r1 = proc.process(&data);
        data[0] = 2.;
        let r2 = proc.process(&data);
        data[0] = 1.;
        let r3 = proc.process(&data);
        assert_eq!(r1[0], 1.);
        assert_eq!(r2[0], 2.);
        assert_eq!(r3[0], 1.);
        assert_eq!(r1[N_FILTERS], 0.);
        assert_eq!(r2[N_FILTERS], 1.);
        assert_eq!(r3[N_FILTERS], 0.);
    }

    #[test]
    fn test_beat_state_space() {
        let ss = BeatStateSpace::new();
        assert_eq!(ss.first_states[0], 0);
        assert_eq!(ss.first_states[10], 325);
        assert_eq!(ss.first_states[81], 5508);
        assert_eq!(ss.last_states[0], 27);
        assert_eq!(ss.last_states[10], 362);
        assert_eq!(ss.last_states[81], 5616);
        assert_eq!(ss.state_positions[0], 0.);
        assert_eq!(ss.state_positions[2000], 0.46376812);
        assert_eq!(ss.state_positions[5616], 0.99082565);
        assert_eq!(ss.state_intervals[0], 28);
        assert_eq!(ss.state_intervals[2000], 69);
        assert_eq!(ss.state_intervals[5616], 109);
    }

    #[test]
    fn test_transition_model() {
        let ss = BeatStateSpace::new();
        let (pointers, states, probabilities) = transition_model(&ss);
        assert_eq!(pointers.len(), 5618);
        assert_eq!(states.len(), 8934);
        assert_eq!(probabilities.len(), 8934);
        assert_eq!(pointers[0], 0);
        assert_eq!(pointers[2000], 3609);
        assert_eq!(pointers[5617], 8934);
        assert_eq!(states[0], 27);
        assert_eq!(states[1000], 702);
        assert_eq!(states[8933], 5615);
        assert_eq!(probabilities[0], 0.97188437);
        assert_eq!(probabilities[1000], 0.010293146);
        assert_eq!(probabilities[8933], 1.0);
    }

    #[test]
    fn test_observation_model() {
        let ss = BeatStateSpace::new();
        let om = observation_model(&ss);
        assert_eq!(om.len(), N_STATES);
        // Every period starts in the beat region and ends outside it.
        assert!(ss.first_states.iter().all(|&s| om[s] == 1));
        assert!(ss.last_states.iter().all(|&s| om[s] == 0));
        // Interval 28 occupies states 0..=27: positions 0/28 and 1/28 are below 1/16.
        assert_eq!(om[0], 1);
        assert_eq!(om[1], 1);
        assert_eq!(om[2], 0);
        assert_eq!(om[27], 0);
        // Interval 109 occupies states 5508..=5616: 6/109 < 1/16 < 7/109.
        assert_eq!(om[5508], 1);
        assert_eq!(om[5514], 1);
        assert_eq!(om[5515], 0);
        assert_eq!(om[5616], 0);
    }

    #[test]
    fn test_om_density() {
        assert_eq!(om_densities(0.), SVector::<f32, 2>::from([0.06666667, 0.]));
        assert_eq!(
            om_densities(0.5),
            SVector::<f32, 2>::from([0.033333335, 0.5])
        );
        assert_eq!(om_densities(1.), SVector::<f32, 2>::from([0., 1.]));
    }

    #[test]
    fn test_hmm_initial_distribution() {
        let dist = hmm_initial_distribution();
        assert_eq!(dist[0], 0.00017803098);
        assert_eq!(dist[2000], 0.00017803098);
        assert_eq!(dist[5616], 0.00017803098);
    }

    #[test]
    fn test_hmm_beat_tracking_processor() {
        let mut hmm = HMMBeatTrackingProcessor::new();
        hmm.process(0.1);
        hmm.process(0.2);
        hmm.process(0.9);
        hmm.process(0.7);
        hmm.process(0.3);
        let probs = &hmm.hmm_fwd_prev;
        assert_eq!(probs[0], 2.3026321e-7);
        assert_eq!(probs[5404], 0.0068785422);
        assert_eq!(probs[5191], 0.0068932814);
        assert_eq!(probs[5297], 0.0069466503);
        assert_eq!(probs[5616], 3.571157e-8);
    }

    #[test]
    fn test_beat_tracker_output_shape() {
        let silence = vec![0_i16; 2048];
        let mut bt = BeatTracker::new();
        let results = bt.process(&silence);
        // Frames complete after 1024 samples and then every HOP_SIZE samples.
        assert_eq!(results.len(), 3);
        for (_, spectrum, _, _) in &results {
            assert_eq!(spectrum.len(), N_FILTERS);
        }
    }

    #[ignore]
    #[test]
    fn test_music() {
        use std::fs::File;
        use std::path::Path;
        let mut inp_file = File::open(Path::new("src/lib/test/frontier.wav")).unwrap();
        let (header, data) = wav::read(&mut inp_file).unwrap();
        assert_eq!(header.audio_format, wav::WAV_FORMAT_PCM);
        assert_eq!(header.channel_count, 1);
        assert_eq!(header.sampling_rate, 44100);
        assert_eq!(header.bits_per_sample, 16);
        let data = data.try_into_sixteen().unwrap();
        let mut bt = BeatTracker::new();
        let beats = bt.process(&data);
        let expected_beat_indices = vec![
            294, 366, 679, 729, 777, 825, 873, 922, 971, 1021, 1068, 1116, 1164, 1212, 1260, 1308,
            1358, 1409, 1459, 1507, 1552, 1599, 1647, 1696, 1745, 1795, 1844, 1892, 1939, 1986,
            2035, 2084, 2132, 2180,
        ];
        for (i, &(_, _, _, beat)) in beats.iter().enumerate() {
            assert_eq!(beat, expected_beat_indices.contains(&i));
        }
    }
}