legato 0.0.9

Legato is a work-in-progress (WIP) audio graph and DSL for quickly developing audio applications.
use crate::{
    context::AudioContext,
    midi::MidiMessageKind,
    node::{Inputs, Node},
    ports::{PortBuilder, Ports},
};

/// Monophonic MIDI voice node.
///
/// Holds the most recent gate, frequency and velocity values derived from the
/// note events of a single MIDI channel, so they can be carried across
/// processing blocks.
#[derive(Clone)]
pub struct Voice {
    /// MIDI channel whose events drive this voice.
    midi_channel: usize,
    /// Port description returned by `Node::ports`.
    ports: Ports,
    /// Frequency of the most recent note-on, as computed by `mtof`.
    cur_freq: f32,
    /// Gate level: `1.0` after a note-on, `0.0` after a note-off.
    cur_gate: f32,
    /// Velocity of the most recent note-on, scaled by `1 / 127`.
    cur_vel: f32,
}

impl Voice {
    /// Creates a silent voice listening on `midi_channel`.
    ///
    /// The node exposes three named audio outputs, in this order: `gate`,
    /// `freq` and `velocity`. All cached values start at `0.0`.
    pub fn new(midi_channel: usize) -> Self {
        let ports = PortBuilder::default()
            .audio_out_named(&["gate", "freq", "velocity"])
            .build();

        Self {
            midi_channel,
            ports,
            cur_freq: 0.0,
            cur_gate: 0.0,
            cur_vel: 0.0,
        }
    }
}

impl Node for Voice {
    /// Renders one block of gate / frequency / velocity signals.
    ///
    /// Output layout: `outputs[0]` is the gate, `outputs[1]` the frequency and
    /// `outputs[2]` the velocity. Each MIDI event on `self.midi_channel` is
    /// converted to a sample offset relative to the block start; the samples
    /// before that offset are filled with the state that was current before
    /// the event, then the state is updated. The remainder of the block is
    /// filled with the final state.
    ///
    /// The outputs are written for the full block even when the context has no
    /// MIDI store, so downstream nodes never read stale buffer contents.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` has fewer than three channels or if a channel is
    /// shorter than the configured block size.
    fn process(&mut self, ctx: &mut AudioContext, _: &Inputs, outputs: &mut [&mut [f32]]) {
        let block_start = ctx.get_instant();

        let cfg = ctx.get_config();
        let block_size = cfg.block_size;
        let fs = cfg.sample_rate as f32;

        // Index of the first sample in this block that has not been written yet.
        let mut last_sample = 0;

        if let Some(store) = ctx.get_midi_store() {
            let res = store.get_channel(self.midi_channel);

            for item in res {
                if item.data == MidiMessageKind::Dummy {
                    continue;
                }

                let offset_duration = item.instant - block_start;

                let idx = (offset_duration.as_secs_f32() * fs) as usize;

                // Clamp to the block, and never step backwards: an event that
                // maps to an earlier sample than the previous one must not
                // rewind the write cursor, otherwise the tail fill below would
                // overwrite samples that already hold newer state.
                let end_sample = idx.min(block_size).max(last_sample);

                // Hold the previous state up to the event position.
                if end_sample > last_sample {
                    outputs[0][last_sample..end_sample].fill(self.cur_gate);
                    outputs[1][last_sample..end_sample].fill(self.cur_freq);
                    outputs[2][last_sample..end_sample].fill(self.cur_vel);
                }

                match item.data {
                    MidiMessageKind::NoteOn { note, velocity } => {
                        self.cur_freq = mtof(note);
                        self.cur_gate = 1.0;
                        self.cur_vel = velocity as f32 / 127.0;
                    }
                    // Keep velocity and frequency here, as there may be a synth with aftertouch logic
                    MidiMessageKind::NoteOff { .. } => {
                        self.cur_gate = 0.0;
                    }
                    // TODO: Pitch bend? Aftertouch logic
                    _ => {}
                }

                last_sample = end_sample;
            }
        }

        // Hold the current state to the end of the block. This runs whether or
        // not a MIDI store is available, so the outputs are always defined.
        if last_sample < block_size {
            outputs[0][last_sample..block_size].fill(self.cur_gate);
            outputs[1][last_sample..block_size].fill(self.cur_freq);
            outputs[2][last_sample..block_size].fill(self.cur_vel);
        }
    }

    /// Returns the port description built in `Voice::new`.
    fn ports(&self) -> &Ports {
        &self.ports
    }
}

/// Converts a MIDI note number to a frequency in hertz.
///
/// Uses twelve-tone equal temperament with note 69 tuned to 440 Hz, so every
/// 12 notes double (or halve) the frequency.
#[inline(always)]
fn mtof(note: u8) -> f32 {
    const REFERENCE_HZ: f32 = 440.0;
    const REFERENCE_NOTE: f32 = 69.0;
    const NOTES_PER_OCTAVE: f32 = 12.0;

    let octaves_from_reference = (f32::from(note) - REFERENCE_NOTE) / NOTES_PER_OCTAVE;
    REFERENCE_HZ * 2.0_f32.powf(octaves_from_reference)
}

/// Whether a voice slot is currently sounding a note.
#[derive(Default, Clone, PartialEq, Debug)]
enum VoiceStateKind {
    /// The slot is free and is preferred when a new note needs a voice.
    #[default]
    Idle,
    /// The slot is holding a note that has not been released yet.
    Active,
}

/// Bookkeeping for a single voice slot of a [`VoiceAllocator`].
#[derive(Default, Clone, PartialEq, Debug)]
struct VoiceState {
    /// Whether the slot is idle or holding a note.
    kind: VoiceStateKind,
    /// MIDI note number most recently assigned to this slot.
    note: u8,
    /// Velocity of the most recent note-on or note-off handled by this slot.
    velocity: u8,
    /// Value of the allocator counter when this slot was last (re)triggered;
    /// the smallest value identifies the oldest voice.
    last_used: u64,
}

/// Fixed-size voice pool that prefers idle slots and steals the oldest one.
#[derive(Default, Clone, PartialEq)]
struct VoiceAllocator {
    /// One entry per voice; the index is the voice number handed to callers.
    voices: Box<[VoiceState]>,
    /// Monotonic counter incremented on every note-on, used for age ordering.
    counter: u64,
}

impl VoiceAllocator {
    /// Creates an allocator with `capacity` idle voices.
    ///
    /// A capacity of zero is allowed; such an allocator never assigns a voice.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            voices: vec![VoiceState::default(); capacity].into(),
            counter: 0,
        }
    }

    /// Picks the voice a new note should use and stamps it as most recent.
    ///
    /// The first idle voice wins; when every voice is active, the one with the
    /// smallest `last_used` stamp (the oldest) is taken over.
    ///
    /// # Panics
    ///
    /// Panics if the pool is empty. `on_note_on` checks for that beforehand.
    fn steal_voice(&mut self) -> (usize, &mut VoiceState) {
        self.counter += 1;

        let target_idx = self
            .voices
            .iter()
            .position(|voice| voice.kind == VoiceStateKind::Idle)
            .or_else(|| {
                // No idle voice left: fall back to the least recently used one.
                self.voices
                    .iter()
                    .enumerate()
                    .min_by_key(|(_, voice)| voice.last_used)
                    .map(|(i, _)| i)
            })
            .expect("steal_voice called on an empty voice pool");

        let state = &mut self.voices[target_idx];
        state.last_used = self.counter;
        (target_idx, state)
    }

    /// Assigns a voice to `note` and returns its index.
    ///
    /// A note that is already sounding re-uses its voice (retrigger) and only
    /// refreshes velocity and age. Otherwise a voice is taken via
    /// `steal_voice`. Returns `None` when the allocator has no voices.
    pub fn on_note_on(&mut self, note: u8, velocity: u8) -> Option<usize> {
        // With no voices there is nothing to assign or steal.
        if self.voices.is_empty() {
            return None;
        }

        let already_playing = self
            .voices
            .iter()
            .position(|voice| voice.kind == VoiceStateKind::Active && voice.note == note);

        if let Some(i) = already_playing {
            self.counter += 1;
            let voice = &mut self.voices[i];
            voice.velocity = velocity;
            voice.last_used = self.counter;
            return Some(i);
        }

        let (i, state) = self.steal_voice();
        state.note = note;
        state.velocity = velocity;
        state.kind = VoiceStateKind::Active;
        Some(i)
    }

    /// Releases the active voice playing `note` and returns its index.
    ///
    /// The release velocity is stored on the voice. Returns `None` when no
    /// active voice holds `note`, e.g. for a duplicate note-off or when the
    /// voice has since been stolen by another note. Idle voices are never
    /// matched, even if their last note number happens to equal `note`.
    fn on_note_off(&mut self, note: u8, velocity: u8) -> Option<usize> {
        let i = self
            .voices
            .iter()
            .position(|voice| voice.kind == VoiceStateKind::Active && voice.note == note)?;

        let voice = &mut self.voices[i];
        voice.kind = VoiceStateKind::Idle;
        voice.velocity = velocity;
        Some(i)
    }
}

/// Number of output channels each voice occupies: gate, frequency and velocity.
const PER_VOICE_CHANS: usize = 3;

/// Last values written for one voice, carried over between events and blocks.
#[derive(Default, Clone)]
struct NodePortCached {
    /// Gate level: `1.0` after a note-on, `0.0` after a note-off.
    gate: f32,
    /// Frequency of the most recent note-on, as computed by `mtof`.
    freq: f32,
    /// Velocity of the most recent note-on or note-off, scaled by `1 / 127`.
    vel: f32,
}

/// Polyphonic MIDI voice node.
///
/// Distributes the note events of one MIDI channel over a fixed number of
/// voices and emits `PER_VOICE_CHANS` output channels per voice.
#[derive(Clone)]
pub struct PolyVoice {
    /// Decides which voice handles each note-on / note-off.
    voice_allocator: VoiceAllocator,
    /// Per-voice cached gate / frequency / velocity values.
    port_caches: Vec<NodePortCached>,
    /// Per-voice index of the first sample not yet written in the current
    /// block; reset to zero at the start of every `process` call.
    last_index_buffers: Box<[usize]>,
    /// MIDI channel whose events drive the voices.
    midi_channel: usize,
    /// Port description returned by `Node::ports`.
    ports: Ports,
}

impl PolyVoice {
    /// Creates a polyphonic voice node with `voices` voices on `midi_channel`.
    ///
    /// The node exposes `voices * PER_VOICE_CHANS` audio outputs. All voices
    /// start idle with zeroed cached values.
    pub fn new(voices: usize, midi_channel: usize) -> Self {
        let voice_allocator = VoiceAllocator::with_capacity(voices);
        let port_caches = vec![NodePortCached::default(); voices];
        let last_index_buffers = vec![0_usize; voices].into_boxed_slice();
        let ports = PortBuilder::default()
            .audio_out(voices * PER_VOICE_CHANS)
            .build();

        Self {
            voice_allocator,
            port_caches,
            last_index_buffers,
            midi_channel,
            ports,
        }
    }
}

impl Node for PolyVoice {
    /// Renders one block of gate / frequency / velocity signals for every voice.
    ///
    /// Output layout: voice `v` owns the channels starting at
    /// `v * PER_VOICE_CHANS`, in the order gate, frequency, velocity. For each
    /// note event the allocator picks the voice; that voice's channels are
    /// filled with its previous state up to the event's sample offset, then
    /// its cached state is updated. Afterwards every voice is filled to the
    /// end of the block with its current state.
    ///
    /// The outputs are written for the full block even when the context has no
    /// MIDI store, so downstream nodes never read stale buffer contents.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` has fewer than `voices * PER_VOICE_CHANS` channels
    /// or if a channel is shorter than the configured block size.
    fn process(&mut self, ctx: &mut AudioContext, _: &Inputs, outputs: &mut [&mut [f32]]) {
        let block_start = ctx.get_instant();

        let cfg = ctx.get_config();
        let block_size = cfg.block_size;
        let fs = cfg.sample_rate as f32;

        // Reset the per-voice write cursors. They mark where the next slice
        // of each voice's channels starts.
        self.last_index_buffers.fill(0);

        if let Some(store) = ctx.get_midi_store() {
            let res = store.get_channel(self.midi_channel);

            for item in res {
                if item.data == MidiMessageKind::Dummy {
                    continue;
                }

                // The voice allocator decides which voice this event is
                // written to. Voices work like tracks in a mixer: each one
                // owns PER_VOICE_CHANS consecutive output channels, so e.g.
                // 3 voices produce 9 output channels in total.
                let chan_option = match item.data {
                    MidiMessageKind::NoteOn { note, velocity } => {
                        self.voice_allocator.on_note_on(note, velocity)
                    }
                    MidiMessageKind::NoteOff { note, velocity } => {
                        self.voice_allocator.on_note_off(note, velocity)
                    }
                    _ => None,
                };

                if let Some(chan_idx) = chan_option {
                    let offset_duration = item.instant - block_start;

                    let start = chan_idx * PER_VOICE_CHANS;

                    let last_index = &mut self.last_index_buffers[chan_idx];

                    let state = &mut self.port_caches[chan_idx];

                    let event_idx =
                        (offset_duration.as_secs_f32() * fs).clamp(0.0, block_size as f32) as usize;

                    // Never step backwards: an event that maps to an earlier
                    // sample than the previous one for this voice must not
                    // rewind the cursor, otherwise the tail fill below would
                    // overwrite samples that already hold newer state.
                    let idx = event_idx.max(*last_index);

                    // Hold the previous state up to the event position.
                    if idx > *last_index {
                        outputs[start][*last_index..idx].fill(state.gate);
                        outputs[start + 1][*last_index..idx].fill(state.freq);
                        outputs[start + 2][*last_index..idx].fill(state.vel);
                    }

                    match item.data {
                        MidiMessageKind::NoteOn { note, velocity } => {
                            state.freq = mtof(note);
                            state.gate = 1.0;
                            state.vel = velocity as f32 / 127.0;
                        }
                        // TODO: Keep velocity and frequency here, as there may be a synth with aftertouch logic
                        MidiMessageKind::NoteOff { note: _, velocity } => {
                            state.gate = 0.0;
                            state.vel = velocity as f32 / 127.0;
                        }
                        // TODO: Pitch bend? Aftertouch logic
                        _ => {}
                    }

                    *last_index = idx;
                }
            }
        }

        // Hold every voice's current state to the end of the block. This runs
        // whether or not a MIDI store is available, so the outputs are always
        // defined.
        for (i, state) in self.port_caches.iter().enumerate() {
            let last_sample = self.last_index_buffers[i];

            let start = i * PER_VOICE_CHANS;

            if last_sample < block_size {
                outputs[start][last_sample..block_size].fill(state.gate);
                outputs[start + 1][last_sample..block_size].fill(state.freq);
                outputs[start + 2][last_sample..block_size].fill(state.vel);
            }
        }
    }

    /// Returns the port description built in `PolyVoice::new`.
    fn ports(&self) -> &Ports {
        &self.ports
    }
}