use std::path::PathBuf;
use std::time::Duration;
use truce_core::buffer::RawBufferScratch;
#[cfg(feature = "wav")]
use truce_core::cast::sample_rate_u32;
use truce_core::cast::{len_u32, sample_count_usize};
use truce_core::events::{Event, EventBody, EventList, TransportInfo};
use truce_core::export::PluginExport;
use truce_core::info::PluginCategory;
use truce_core::plugin::Plugin;
use truce_core::process::ProcessContext;
use truce_params::Params;
/// Audio fed to the plugin's input channels during a run.
///
/// Only effect plugins receive input; other plugin categories are processed
/// with no input channels at all.
#[derive(Default)]
pub enum InputSource {
    /// All-zero input. This is the default.
    #[default]
    Silence,
    /// Every sample on every channel is this value.
    Constant(f32),
    /// One `Vec` per channel, played from frame 0. Frames past the end of a
    /// channel read as silence.
    Buffer(Vec<Vec<f32>>),
    /// Called once per frame with `(absolute frame index, sample rate)`; the
    /// returned sample is copied to every channel.
    Generator(Box<dyn FnMut(usize, f64) -> f32 + 'static>),
}
/// Transport state reported to the plugin on every processed block.
#[derive(Clone, Debug, PartialEq)]
pub struct TransportSpec {
    /// Tempo in beats per minute.
    pub bpm: f64,
    /// Whether the transport is rolling. The reported beat position only
    /// advances from block to block while this is `true`.
    pub playing: bool,
    /// Beat position at the start of the run.
    pub position_beats: f64,
    /// Time signature as `(numerator, denominator)`, e.g. `(4, 4)`.
    pub time_signature: (u8, u8),
}

impl Default for TransportSpec {
    /// 120 BPM, stopped, at beat 0, in 4/4.
    fn default() -> Self {
        Self {
            bpm: 120.0,
            playing: false,
            position_beats: 0.0,
            time_signature: (4, 4),
        }
    }
}
/// How meter values are sampled during a run.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum MeterCapture {
    /// Don't read meters.
    None,
    /// Read every meter once, after the last block. This is the default.
    #[default]
    Final,
    /// Read every meter after each block.
    PerBlock,
}

/// Selects what `PluginDriver::run` records into its `DriverResult`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct CaptureSpec {
    /// Keep the rendered audio.
    pub audio: bool,
    /// Meter sampling mode.
    pub meters: MeterCapture,
    /// Keep the events the plugin emits.
    pub output_events: bool,
    /// Keep a snapshot of every parameter's plain value after each block.
    pub block_snapshots: bool,
}

impl Default for CaptureSpec {
    /// Capture audio and the final meter values only.
    fn default() -> Self {
        Self {
            audio: true,
            meters: MeterCapture::Final,
            output_events: false,
            block_snapshots: false,
        }
    }
}
/// An event timeline fed to the plugin under test, built inside
/// [`PluginDriver::script`].
///
/// Each queued event is stamped with the current cursor position (in
/// samples); the `wait_*` methods move the cursor forward. All MIDI helpers
/// target group 0, channel 0.
#[derive(Default, Clone)]
pub struct Script {
    /// `(sample offset, event)` pairs in the order they were queued.
    events: Vec<(usize, EventBody)>,
    /// Offset, in samples, stamped on the next queued event.
    cursor_samples: usize,
    /// Rate the offsets are expressed in; `0.0` until the driver sets it.
    sample_rate: f64,
}

impl Script {
    /// Queues a note-on. `velocity` is a normalized value converted with
    /// `truce_core::midi::denorm_7bit` (presumably 0.0..=1.0 — see that helper).
    pub fn note_on(&mut self, note: u8, velocity: f32) {
        self.push(EventBody::NoteOn {
            group: 0,
            channel: 0,
            note,
            velocity: truce_core::midi::denorm_7bit(velocity),
        });
    }

    /// Queues a note-off with a release velocity of 0.
    pub fn note_off(&mut self, note: u8) {
        self.push(EventBody::NoteOff {
            group: 0,
            channel: 0,
            note,
            velocity: 0,
        });
    }

    /// Queues a control change; `value` is normalized and converted with
    /// `truce_core::midi::denorm_7bit`.
    pub fn cc(&mut self, cc: u8, value: f32) {
        self.push(EventBody::ControlChange {
            group: 0,
            channel: 0,
            cc,
            value: truce_core::midi::denorm_7bit(value),
        });
    }

    /// Queues a pitch bend; `normalized` is converted with
    /// `truce_core::midi::denorm_pitch_bend`.
    pub fn pitch_bend(&mut self, normalized: f32) {
        self.push(EventBody::PitchBend {
            group: 0,
            channel: 0,
            value: truce_core::midi::denorm_pitch_bend(normalized),
        });
    }

    /// Queues a channel-pressure message; `value` is normalized and converted
    /// with `truce_core::midi::denorm_7bit`.
    pub fn channel_pressure(&mut self, value: f32) {
        self.push(EventBody::ChannelPressure {
            group: 0,
            channel: 0,
            pressure: truce_core::midi::denorm_7bit(value),
        });
    }

    /// Queues a change of parameter `id` to the normalized value `normalized`.
    pub fn set_param(&mut self, id: impl Into<u32>, normalized: f64) {
        self.push(EventBody::ParamChange {
            id: id.into(),
            value: normalized,
        });
    }

    /// Queues an arbitrary event body at the cursor.
    pub fn raw(&mut self, body: EventBody) {
        self.push(body);
    }

    /// Advances the cursor by `ms` milliseconds at the script's sample rate
    /// (44.1 kHz if none has been set yet), saturating instead of overflowing.
    ///
    /// `wait_ms(0)` trips a debug assertion, because it does nothing and is
    /// most likely a mistake.
    #[allow(clippy::cast_precision_loss)] // u64 ms -> f64 is exact for any realistic wait
    pub fn wait_ms(&mut self, ms: u64) {
        debug_assert!(
            ms != 0,
            "wait_ms(0) is a no-op - drop the call, or use wait_samples(0) if you mean it"
        );
        let sr = if self.sample_rate > 0.0 {
            self.sample_rate
        } else {
            44_100.0
        };
        let samples_f = (sr * ms as f64) / 1000.0;
        let samples = sample_count_usize(samples_f);
        self.cursor_samples = self.cursor_samples.saturating_add(samples);
    }

    /// Advances the cursor by `n` samples.
    ///
    /// Saturates at `usize::MAX` like [`Script::wait_ms`] rather than
    /// overflowing (which would panic in debug builds).
    pub fn wait_samples(&mut self, n: usize) {
        self.cursor_samples = self.cursor_samples.saturating_add(n);
    }

    /// Stamps `body` with the current cursor and appends it.
    fn push(&mut self, body: EventBody) {
        self.events.push((self.cursor_samples, body));
    }
}
/// Everything recorded by [`PluginDriver::run`].
pub struct DriverResult<P: PluginExport> {
    /// Rendered audio, one `Vec` per channel, each `total_frames` samples
    /// long. Empty when audio capture was turned off.
    pub output: Vec<Vec<f32>>,
    /// Sample rate the run used.
    pub sample_rate: f64,
    /// Maximum block size the run used.
    pub block_size: usize,
    /// Number of frames rendered.
    pub total_frames: usize,
    /// Meter values, shaped by the configured [`MeterCapture`] mode.
    pub meters: MeterReadings,
    /// Events the plugin emitted, with `sample_offset` made relative to the
    /// start of the run. Empty unless output-event capture was enabled.
    pub output_events: Vec<Event>,
    /// Per block, the plain value of every parameter as `(id, value)`
    /// (`0.0` where no value was available). Empty unless enabled.
    pub block_snapshots: Vec<Vec<(u32, f64)>>,
    /// The plugin instance after the run, for further inspection.
    pub plugin: P,
}
/// Meter values captured by a run, shaped by the selected `MeterCapture` mode.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum MeterReadings {
    /// Meters were not captured.
    #[default]
    None,
    /// `(meter id, value)` for every meter, read after the final block.
    Final(Vec<(u32, f32)>),
    /// One `(meter id, value)` list per processed block.
    PerBlock(Vec<Vec<(u32, f32)>>),
}
#[cfg(feature = "wav")]
impl<P: PluginExport> DriverResult<P> {
    /// Writes the captured audio to `path` as an interleaved 32-bit float WAV.
    ///
    /// # Errors
    /// - `InvalidData` if no audio was captured (`CaptureSpec::audio` was false).
    /// - `InvalidInput` if there are more channels than a WAV header can
    ///   describe (`u16::MAX`).
    /// - Any I/O or encoding error reported by `hound`.
    pub fn write_wav(&self, path: impl AsRef<std::path::Path>) -> std::io::Result<()> {
        if self.output.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "no audio captured (CaptureSpec::audio was false)",
            ));
        }
        // A plain `as u16` would silently wrap and write a corrupt header.
        let channels = u16::try_from(self.output.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!(
                    "{} channels exceed the WAV limit of {}",
                    self.output.len(),
                    u16::MAX
                ),
            )
        })?;
        let spec = hound::WavSpec {
            channels,
            sample_rate: sample_rate_u32(self.sample_rate),
            bits_per_sample: 32,
            sample_format: hound::SampleFormat::Float,
        };
        let mut wav = hound::WavWriter::create(path, spec).map_err(io_err)?;
        // WAV stores samples frame-major: every channel of frame 0, then frame 1, ...
        for frame in 0..self.total_frames {
            for ch in &self.output {
                wav.write_sample(ch[frame]).map_err(io_err)?;
            }
        }
        wav.finalize().map_err(io_err)?;
        Ok(())
    }
}
/// Wraps a `hound` error as an `std::io::Error` of kind `Other`, so that
/// `write_wav` can return a plain `io::Result`.
#[cfg(feature = "wav")]
fn io_err(err: hound::Error) -> std::io::Error {
    std::io::Error::other(err)
}
/// Boxed one-shot hook registered through `PluginDriver::setup` and invoked
/// by `run` on the freshly prepared plugin, before processing starts.
type SetupFn<P> = Box<dyn FnOnce(&mut P, &SetupContext)>;

/// Run configuration handed to the setup hook.
#[derive(Debug, Clone, Copy)]
pub struct SetupContext {
    /// Channel count the run will use.
    pub channels: usize,
    /// Host sample rate.
    pub sample_rate: f64,
    /// Maximum samples per processed block.
    pub block_size: usize,
}
/// Where the plugin state restored at the start of a run comes from.
enum StateSource {
    /// Raw state bytes supplied in memory.
    Blob(Vec<u8>),
    /// A state file read at the start of `run`. Relative paths have already
    /// been resolved against the driver's `manifest_dir`.
    File(PathBuf),
}
/// Offline host that instantiates a plugin `P`, feeds it scripted events and
/// input audio block by block, and collects the results into a
/// [`DriverResult`].
///
/// Configure it with the builder methods, then call [`PluginDriver::run`].
pub struct PluginDriver<P: PluginExport> {
    /// Host sample rate.
    sample_rate: f64,
    /// Explicit channel count; `None` derives it from the plugin's bus layout.
    channels: Option<usize>,
    /// Maximum samples per `process` call.
    block_size: usize,
    /// Total length of audio to render.
    duration: Duration,
    /// Transport reported to the plugin.
    transport: TransportSpec,
    /// Input signal (only fed to effect plugins).
    input: InputSource,
    /// Scripted events.
    script: Script,
    /// Optional state restored before processing.
    state_source: Option<StateSource>,
    /// Base directory for relative `state_file` paths.
    manifest_dir: PathBuf,
    /// `(parameter id, normalized value)` overrides applied after state restore.
    param_overrides: Vec<(u32, f64)>,
    /// Optional hook run on the plugin before processing.
    setup: Option<SetupFn<P>>,
    /// What `run` records.
    capture: CaptureSpec,
}
impl<P: PluginExport> Default for PluginDriver<P> {
fn default() -> Self {
Self::new()
}
}
impl<P: PluginExport> PluginDriver<P> {
#[must_use]
pub fn new() -> Self {
Self {
sample_rate: 44_100.0,
channels: None,
block_size: 512,
duration: Duration::from_secs(1),
transport: TransportSpec::default(),
input: InputSource::Silence,
script: Script::default(),
state_source: None,
manifest_dir: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
param_overrides: Vec::new(),
setup: None,
capture: CaptureSpec::default(),
}
}
#[must_use]
pub fn sample_rate(mut self, sr: f64) -> Self {
self.sample_rate = sr;
self
}
#[must_use]
pub fn channels(mut self, n: usize) -> Self {
self.channels = Some(n);
self
}
#[must_use]
pub fn block_size(mut self, n: usize) -> Self {
self.block_size = n;
self
}
#[must_use]
pub fn duration(mut self, d: Duration) -> Self {
self.duration = d;
self
}
#[must_use]
pub fn transport(mut self, t: TransportSpec) -> Self {
self.transport = t;
self
}
#[must_use]
pub fn bpm(mut self, bpm: f64) -> Self {
self.transport.bpm = bpm;
self
}
#[must_use]
pub fn playing(mut self, playing: bool) -> Self {
self.transport.playing = playing;
self
}
#[must_use]
pub fn input(mut self, source: InputSource) -> Self {
self.input = source;
self
}
#[allow(clippy::cast_precision_loss)]
#[must_use]
pub fn script(mut self, f: impl FnOnce(&mut Script)) -> Self {
let old_sr = self.script.sample_rate;
let new_sr = self.sample_rate;
if old_sr > 0.0 && (old_sr - new_sr).abs() > f64::EPSILON {
let scale = new_sr / old_sr;
self.script.cursor_samples =
sample_count_usize(((self.script.cursor_samples as f64) * scale).round());
for (off, _) in &mut self.script.events {
*off = sample_count_usize(((*off as f64) * scale).round());
}
}
self.script.sample_rate = new_sr;
f(&mut self.script);
self
}
#[must_use]
pub fn set_param(mut self, id: impl Into<u32>, normalized: f64) -> Self {
self.param_overrides.push((id.into(), normalized));
self
}
#[must_use]
pub fn manifest_dir(mut self, dir: impl Into<PathBuf>) -> Self {
self.manifest_dir = dir.into();
self
}
#[must_use]
pub fn setup<F: FnOnce(&mut P, &SetupContext) + 'static>(mut self, f: F) -> Self {
self.setup = Some(Box::new(f));
self
}
#[must_use]
pub fn state_blob(mut self, bytes: Vec<u8>) -> Self {
self.state_source = Some(StateSource::Blob(bytes));
self
}
#[must_use]
pub fn state_file(mut self, path: impl Into<PathBuf>) -> Self {
let raw = path.into();
let resolved = if raw.is_absolute() {
raw
} else {
self.manifest_dir.join(&raw)
};
self.state_source = Some(StateSource::File(resolved));
self
}
#[must_use]
pub fn capture_audio(mut self, on: bool) -> Self {
self.capture.audio = on;
self
}
#[must_use]
pub fn capture_meters(mut self, m: MeterCapture) -> Self {
self.capture.meters = m;
self
}
#[must_use]
pub fn capture_output_events(mut self, on: bool) -> Self {
self.capture.output_events = on;
self
}
#[must_use]
pub fn capture_block_snapshots(mut self, on: bool) -> Self {
self.capture.block_snapshots = on;
self
}
#[allow(clippy::cast_precision_loss, clippy::too_many_lines)]
#[must_use]
pub fn run(mut self) -> DriverResult<P> {
let mut plugin = P::create();
plugin.init();
plugin.reset(self.sample_rate, self.block_size);
plugin.params().set_sample_rate(self.sample_rate);
plugin.params().snap_smoothers();
let state_bytes =
match self.state_source.take() {
Some(StateSource::Blob(b)) => Some(b),
Some(StateSource::File(path)) => Some(std::fs::read(&path).unwrap_or_else(|e| {
panic!("state_file: failed to read {}: {e}", path.display())
})),
None => None,
};
if let Some(bytes) = state_bytes.as_deref()
&& let Err(e) = plugin.load_state(bytes)
{
eprintln!("truce-driver: load_state failed: {e}");
}
for (id, value) in &self.param_overrides {
plugin.params().set_normalized(*id, *value);
}
plugin.params().snap_smoothers();
let channels = self.channels.unwrap_or_else(|| {
let layouts = P::bus_layouts();
let layout = &layouts[0];
let outs = layout.total_output_channels() as usize;
if outs > 0 { outs } else { 2 }
});
if let Some(f) = self.setup.take() {
let ctx = SetupContext {
channels,
sample_rate: self.sample_rate,
block_size: self.block_size,
};
f(&mut plugin, &ctx);
}
let is_effect = P::info().category == PluginCategory::Effect;
let total_frames = sample_count_usize(self.duration.as_secs_f64() * self.sample_rate);
let mut output: Vec<Vec<f32>> = if self.capture.audio {
(0..channels)
.map(|_| Vec::with_capacity(total_frames))
.collect()
} else {
Vec::new()
};
let mut output_events_capture: Vec<Event> = Vec::new();
let mut per_block_meters: Vec<Vec<(u32, f32)>> = Vec::new();
let mut block_snapshots: Vec<Vec<(u32, f64)>> = Vec::new();
let constant_value: Option<f32> = match &self.input {
InputSource::Constant(v) => Some(*v),
InputSource::Silence => Some(0.0),
_ => None,
};
let script_events = prepare_script_events(&mut self.script, self.sample_rate, total_frames);
let mut transport_pos_beats = self.transport.position_beats;
let beats_per_second = self.transport.bpm / 60.0;
let meter_ids: Vec<u32> = plugin.params().meter_ids();
if let InputSource::Buffer(bufs) = &self.input {
assert_eq!(
bufs.len(),
channels,
"InputSource::Buffer channel count {} doesn't match driver channels {channels}",
bufs.len(),
);
}
let mut out_bufs: Vec<Vec<f32>> = (0..channels)
.map(|_| vec![0.0f32; self.block_size])
.collect();
let mut in_bufs: Vec<Vec<f32>> = if is_effect {
(0..channels)
.map(|_| vec![0.0f32; self.block_size])
.collect()
} else {
Vec::new()
};
let mut cursor = 0usize;
let mut event_list = EventList::with_capacity(script_events.len().min(256));
let mut output_events_block = EventList::default();
let mut scratch: RawBufferScratch<<P as Plugin>::Sample> = RawBufferScratch::default();
scratch.ensure_capacity(in_bufs.len(), out_bufs.len(), self.block_size);
let mut in_ptrs: Vec<*const f32> = Vec::with_capacity(in_bufs.len());
let mut out_ptrs: Vec<*mut f32> = Vec::with_capacity(out_bufs.len());
while cursor < total_frames {
let block_len = self.block_size.min(total_frames - cursor);
for b in &mut out_bufs {
b.clear();
b.resize(block_len, 0.0);
}
event_list.clear();
for (off, body) in &script_events {
if *off >= cursor && *off < cursor + block_len {
event_list.push(Event {
sample_offset: len_u32(*off - cursor),
body: *body,
});
}
}
if is_effect {
fill_input_block(
&mut in_bufs,
&mut self.input,
constant_value,
cursor,
block_len,
self.sample_rate,
);
}
in_ptrs.clear();
out_ptrs.clear();
for b in &in_bufs {
in_ptrs.push(b.as_ptr());
}
for b in &mut out_bufs {
out_ptrs.push(b.as_mut_ptr());
}
let block_u32 = len_u32(block_len);
let num_in_u32 = len_u32(in_ptrs.len());
let num_out_u32 = len_u32(out_ptrs.len());
let mut audio = unsafe {
scratch.build(
in_ptrs.as_ptr(),
out_ptrs.as_mut_ptr(),
num_in_u32,
num_out_u32,
block_u32,
P::supports_in_place(),
)
};
let transport_info = TransportInfo {
playing: self.transport.playing,
tempo: self.transport.bpm,
time_sig_num: self.transport.time_signature.0,
time_sig_den: self.transport.time_signature.1,
position_seconds: cursor as f64 / self.sample_rate,
position_beats: transport_pos_beats,
bar_start_beats: 0.0,
..Default::default()
};
output_events_block.clear();
let mut ctx = ProcessContext::new(
&transport_info,
self.sample_rate,
block_len,
&mut output_events_block,
);
plugin.process(&mut audio, &event_list, &mut ctx);
let _ = audio;
unsafe {
scratch.finish_widening_f32(out_ptrs.as_mut_ptr(), num_out_u32, block_u32);
}
if self.capture.audio {
for (ch, buf) in out_bufs.iter().enumerate() {
output[ch].extend_from_slice(buf);
}
}
if self.capture.output_events {
let cursor_u32 = u32::try_from(cursor).unwrap_or(u32::MAX);
for ev in output_events_block.iter() {
let mut e = *ev;
e.sample_offset = e.sample_offset.saturating_add(cursor_u32);
output_events_capture.push(e);
}
}
if matches!(self.capture.meters, MeterCapture::PerBlock) {
per_block_meters.push(
meter_ids
.iter()
.map(|id| (*id, plugin.get_meter(*id)))
.collect(),
);
}
if self.capture.block_snapshots {
let infos = plugin.params().param_infos();
block_snapshots.push(
infos
.iter()
.map(|pi| (pi.id, plugin.params().get_plain(pi.id).unwrap_or(0.0)))
.collect(),
);
}
if self.transport.playing {
let block_seconds = block_len as f64 / self.sample_rate;
transport_pos_beats += block_seconds * beats_per_second;
}
cursor += block_len;
}
let meters = match self.capture.meters {
MeterCapture::None => MeterReadings::None,
MeterCapture::Final => MeterReadings::Final(
meter_ids
.iter()
.map(|id| (*id, plugin.get_meter(*id)))
.collect(),
),
MeterCapture::PerBlock => MeterReadings::PerBlock(per_block_meters),
};
DriverResult {
output,
sample_rate: self.sample_rate,
block_size: self.block_size,
total_frames,
meters,
output_events: output_events_capture,
block_snapshots,
plugin,
}
}
}
/// Turns a [`Script`] into the offset-sorted event list that `PluginDriver::run`
/// delivers.
///
/// If the script was written at a different sample rate than `sample_rate`,
/// every event offset and the script cursor are rescaled to `sample_rate`.
/// This mirrors what `PluginDriver::script` does.
///
/// Events are then sorted by offset. The sort is stable, so events sharing an
/// offset keep the order they were scripted in.
///
/// Events at or past `total_frames` can never be delivered. They are reported
/// on stderr and left out of the result. The script's own event list is left
/// empty.
#[allow(clippy::cast_precision_loss)] // sample offsets -> f64 for rescaling
fn prepare_script_events(
    script: &mut Script,
    sample_rate: f64,
    total_frames: usize,
) -> Vec<(usize, EventBody)> {
    let build_sr = script.sample_rate;
    if build_sr > 0.0 && (build_sr - sample_rate).abs() > f64::EPSILON {
        let scale = sample_rate / build_sr;
        let rescale = |off: usize| sample_count_usize(((off as f64) * scale).round());
        for (off, _) in &mut script.events {
            *off = rescale(*off);
        }
        // Keep the cursor in the same units as the events it stamps.
        script.cursor_samples = rescale(script.cursor_samples);
    }
    script.sample_rate = sample_rate;

    let mut events = std::mem::take(&mut script.events);
    // Stable: same-offset events (e.g. note-off then note-on) keep their order.
    events.sort_by_key(|(off, _)| *off);

    let deliverable = events.partition_point(|(off, _)| *off < total_frames);
    let dropped = events.len() - deliverable;
    if dropped > 0 {
        eprintln!(
            "[truce-driver] warning: {dropped} script event(s) scheduled past \
             total_frames ({total_frames}) - they will not be delivered. Check \
             `.duration(...)` vs `wait_ms`/`wait_samples` calls in your script."
        );
    }
    events.truncate(deliverable);
    events
}
/// Fills the per-channel input buffers for one block that starts at frame
/// `cursor`.
///
/// Every buffer is first resized to `block_len`. When `constant_value` is
/// `Some`, it takes precedence and every sample is set to it. The driver
/// passes it for `Silence`/`Constant` sources so they skip the per-block
/// match. Otherwise the samples come from `input`:
/// - `Buffer`: channel `i` copies `src[cursor..cursor + block_len]` and is
///   zero-padded past the end of the source. Destination channels with no
///   matching source channel are silenced.
/// - `Generator`: called once per frame with `(cursor + i, sample_rate)`.
///   Channel 0 receives the values and every other channel is a copy of it.
/// - `Silence`: zeros.
/// - `Constant(v)`: `v`, even without the `constant_value` hint.
fn fill_input_block(
    in_bufs: &mut [Vec<f32>],
    input: &mut InputSource,
    constant_value: Option<f32>,
    cursor: usize,
    block_len: usize,
    sample_rate: f64,
) {
    for buf in in_bufs.iter_mut() {
        buf.resize(block_len, 0.0);
    }

    // Caller-supplied fast path for constant-valued sources.
    if let Some(v) = constant_value {
        for buf in in_bufs {
            buf.fill(v);
        }
        return;
    }

    match input {
        InputSource::Buffer(srcs) => {
            for (dst, src) in in_bufs.iter_mut().zip(srcs.iter()) {
                let start = cursor.min(src.len());
                let end = cursor.saturating_add(block_len).min(src.len());
                let copied = end - start;
                dst[..copied].copy_from_slice(&src[start..end]);
                dst[copied..].fill(0.0);
            }
            // Unmatched channels get silence, not whatever the last block left.
            for dst in in_bufs.iter_mut().skip(srcs.len()) {
                dst.fill(0.0);
            }
        }
        InputSource::Generator(generate) => {
            if let Some((first, rest)) = in_bufs.split_first_mut() {
                for (i, slot) in first.iter_mut().enumerate() {
                    *slot = generate(cursor + i, sample_rate);
                }
                for ch in rest {
                    ch.copy_from_slice(first);
                }
            }
        }
        InputSource::Silence => {
            for buf in in_bufs {
                buf.fill(0.0);
            }
        }
        InputSource::Constant(v) => {
            let v = *v;
            for buf in in_bufs {
                buf.fill(v);
            }
        }
    }
}