use cranelift_jit::JITModule;
use std::cell::RefCell;
use std::collections::HashMap;
use std::mem;
use std::rc::Rc;
/// Compilation context that outlives the individual `DSPFunction`s built from it.
///
/// It owns the `DSPState` shared by every function it creates, the name -> index table of
/// persistent variables, and the per-instance node states (keyed by node uid) that are
/// carried over from one compiled function to the next.
pub struct DSPNodeContext {
    /// Heap-allocated shared state, created in `new()` via `Box::into_raw` and released only
    /// by `free()` (null afterwards). A copy of this pointer is given to every `DSPFunction`.
    state: *mut DSPState,
    /// Next index to hand out for a persistent variable name seen for the first time.
    persistent_var_index: usize,
    /// Persistent variable name -> index into a `DSPFunction`'s persistent variable array.
    persistent_var_map: HashMap<String, usize>,
    /// Node states keyed by node uid. Entries are created on first use; the visible code
    /// never removes them.
    node_states: HashMap<u64, Box<DSPNodeState>>,
    /// Incremented by `init_dsp_function` each time a new function is started.
    generation: u64,
    /// The function currently being built, between `init_dsp_function` and
    /// `finalize_dsp_function`.
    next_dsp_fun: Option<Box<DSPFunction>>,
}
impl DSPNodeContext {
    /// Creates a context with a freshly heap-allocated `DSPState` (sample rate 44100.0) and
    /// no node states, no persistent variables and no pending function.
    fn new() -> Self {
        Self {
            state: Box::into_raw(Box::new(DSPState {
                x: 0.0,
                y: 0.0,
                srate: 44100.0,
                israte: 1.0 / 44100.0,
            })),
            node_states: HashMap::new(),
            generation: 0,
            next_dsp_fun: None,
            persistent_var_map: HashMap::new(),
            persistent_var_index: 0,
        }
    }

    /// Creates a new context wrapped in `Rc<RefCell<_>>` for shared, single-threaded access.
    pub fn new_ref() -> Rc<RefCell<Self>> {
        Rc::new(RefCell::new(Self::new()))
    }

    /// Starts building a new `DSPFunction`.
    ///
    /// Bumps the generation counter and replaces any pending function with an empty one that
    /// is bound to this context's `DSPState`.
    pub(crate) fn init_dsp_function(&mut self) {
        self.generation += 1;
        self.next_dsp_fun = Some(Box::new(DSPFunction::new(self.state, self.generation)));
    }

    /// Returns the index of the persistent variable `pers_var_name`.
    ///
    /// The next free index is assigned the first time a name is seen. The call also makes
    /// sure the pending function has a slot for that index. Indices are stable for the
    /// lifetime of the context, so a variable keeps its index (and, through
    /// `DSPFunction::init`, its value) across recompilations.
    ///
    /// # Errors
    ///
    /// Returns an error if no function is being built (`init_dsp_function` was not called).
    /// In that case no index is assigned and the name is not registered.
    pub(crate) fn get_persistent_variable_index(
        &mut self,
        pers_var_name: &str,
    ) -> Result<usize, String> {
        // Check for a pending function first, so that a failing call has no side effects.
        let next_dsp_fun = self
            .next_dsp_fun
            .as_mut()
            .ok_or_else(|| "No DSPFunction in DSPNodeContext".to_string())?;

        let index = match self.persistent_var_map.get(pers_var_name) {
            Some(&index) => index,
            None => {
                let index = self.persistent_var_index;
                self.persistent_var_index += 1;
                self.persistent_var_map.insert(pers_var_name.to_string(), index);
                index
            }
        };

        next_dsp_fun.touch_persistent_var_index(index);
        Ok(index)
    }

    /// Installs the node instance `dsp_node_uid` of type `node_type` into the pending
    /// function and returns its index in that function's node-state array.
    ///
    /// The node's state is created on first use and reused by every later function that
    /// contains the same uid, so node state survives recompilation.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - no function is being built;
    /// - `dsp_node_uid` was already installed into the pending function;
    /// - the uid was previously used with a node type of a different name.
    pub(crate) fn add_dsp_node_instance(
        &mut self,
        node_type: Rc<dyn DSPNodeType>,
        dsp_node_uid: u64,
    ) -> Result<usize, String> {
        let next_dsp_fun = self
            .next_dsp_fun
            .as_mut()
            .ok_or_else(|| "No DSPFunction in DSPNodeContext".to_string())?;

        if next_dsp_fun.has_dsp_node_state_uid(dsp_node_uid) {
            return Err(format!(
                "node_state_uid has been used multiple times in same AST: {}",
                dsp_node_uid
            ));
        }

        // A single lookup: create the state on first use, otherwise reuse the existing one.
        let state = self
            .node_states
            .entry(dsp_node_uid)
            .or_insert_with(|| Box::new(DSPNodeState::new(dsp_node_uid, Rc::clone(&node_type))));

        // A uid must keep its node type, because the existing state block was allocated by it.
        if state.node_type().name() != node_type.name() {
            return Err(format!(
                "Different DSPNodeType for uid {}: {} != {}",
                dsp_node_uid,
                state.node_type().name(),
                node_type.name()
            ));
        }

        Ok(next_dsp_fun.install(state))
    }

    /// Completes the pending function and hands it out.
    ///
    /// Installs the compiled entry point `function_ptr` together with the JIT `module` that
    /// owns its code, then marks every node state of this context as initialized.
    ///
    /// Returns `None` (and drops `module`) if no function is being built.
    pub(crate) fn finalize_dsp_function(
        &mut self,
        function_ptr: *const u8,
        module: JITModule,
    ) -> Option<Box<DSPFunction>> {
        let mut next_dsp_fun = self.next_dsp_fun.take()?;
        next_dsp_fun.set_function_ptr(function_ptr, module);

        for node_state in self.node_states.values_mut() {
            node_state.set_initialized();
        }

        Some(next_dsp_fun)
    }

    /// Accepts a function that is no longer in use.
    ///
    /// Currently this only drops `_fun`, which frees its JIT memory through `DSPFunction`'s
    /// `Drop`. NOTE(review): node states no longer referenced by any function are not
    /// garbage collected here.
    pub fn cleanup_dsp_fun_after_user(&mut self, _fun: Box<DSPFunction>) {}

    /// Frees the shared `DSPState`. Calling it more than once is harmless.
    ///
    /// Every `DSPFunction` created by this context keeps a copy of the state pointer. Only
    /// call this once none of them will be executed, reset or re-initialized again.
    pub fn free(&mut self) {
        if !self.state.is_null() {
            // SAFETY: `state` was created by `Box::into_raw` in `new()` and is nulled right
            // below, so it is reclaimed at most once.
            drop(unsafe { Box::from_raw(self.state) });
            self.state = std::ptr::null_mut();
        }
    }
}
impl Drop for DSPNodeContext {
    /// Reports on stderr when the context is dropped while it still holds its `DSPState`.
    ///
    /// This does not free the state itself; only `free()` releases it.
    fn drop(&mut self) {
        if self.state.is_null() {
            // `free()` has already run, so there is nothing to report.
            return;
        }
        eprintln!("WBlockDSP JIT DSPNodeContext not cleaned up on exit. Forgot to call free() or keep it alive long enough?");
    }
}
/// An ordered collection of the `DSPNodeType`s that are available to the compiler.
#[derive(Default)]
pub struct DSPNodeTypeLibrary {
    types: Vec<Rc<dyn DSPNodeType>>,
}

impl DSPNodeTypeLibrary {
    /// Creates an empty library.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `typ` to the library. Registration order is preserved by `for_each`.
    pub fn add(&mut self, typ: Rc<dyn DSPNodeType>) {
        self.types.push(typ);
    }

    /// Calls `f` on every registered node type in registration order.
    ///
    /// Iteration stops at the first `Err`, which is returned unchanged. Otherwise the result
    /// is `Ok(())`.
    pub fn for_each<T, F: FnMut(&Rc<dyn DSPNodeType>) -> Result<(), T>>(
        &self,
        f: F,
    ) -> Result<(), T> {
        self.types.iter().try_for_each(f)
    }
}
/// A compiled DSP function together with everything its machine code needs at run time.
///
/// Holds the node-state pointers, the persistent variable storage and the JIT module that
/// owns the generated code. The code is called through `exec`.
pub struct DSPFunction {
    /// Shared state owned by the creating `DSPNodeContext` (not freed by this struct).
    state: *mut DSPState,
    /// Node type of each installed node state; parallel to `node_states`.
    node_state_types: Vec<Rc<dyn DSPNodeType>>,
    /// State block of each installed node. Passed to the compiled code as `*mut *mut u8`;
    /// the position in this vector is the index returned by `install`.
    node_states: Vec<*mut u8>,
    /// Indices (into `node_states`) of states that were not yet initialized when installed;
    /// `init` resets exactly these.
    node_state_init_reset: Vec<usize>,
    /// Node uid of each installed node state; parallel to `node_states`.
    node_state_uids: Vec<u64>,
    /// Generation of the owning context at creation time; stamped onto installed states.
    dsp_ctx_generation: u64,
    /// JIT module owning the compiled code; freed in `Drop`.
    module: Option<JITModule>,
    /// Persistent variable values, indexed by the context's name -> index table.
    persistent_vars: Vec<f64>,
    /// Compiled entry point, `None` until `set_function_ptr`. Arguments, in order (see
    /// `exec`): in1, in2, alpha, beta, delta, gamma, srate, israte, sig1, sig2, shared
    /// state, node-state array, persistent variables, multi-return buffer.
    ///
    /// NOTE(review): this is a Rust-ABI `fn`. If the JIT emits code with the platform C
    /// calling convention, `extern "C" fn` would be the exact type; confirm against the codegen.
    function: Option<
        fn(
            f64,
            f64,
            f64,
            f64,
            f64,
            f64,
            f64,
            f64,
            *mut f64,
            *mut f64,
            *mut DSPState,
            *mut *mut u8,
            *mut f64,
            *mut f64,
        ) -> f64,
    >,
}
// SAFETY: NOTE(review): this lets a `DSPFunction` move to another thread. However, it holds
// `Rc<dyn DSPNodeType>` clones (`node_state_types`) whose non-atomic reference counts are
// shared with the `DSPNodeState`s kept by `DSPNodeContext`. It also holds raw pointers to
// the context's `DSPState` and node-state blocks. This is only sound if, once sent, the
// function's `Rc`s are never cloned or dropped concurrently with the context's copies. In
// particular, the function presumably has to be dropped back on the context's thread (cf.
// `DSPNodeContext::cleanup_dsp_fun_after_user`). Confirm this invariant with the callers.
unsafe impl Send for DSPFunction {}
impl DSPFunction {
    /// Creates an empty function that will run against the shared `state`.
    ///
    /// `dsp_ctx_generation` is the generation counter of the owning `DSPNodeContext`. It is
    /// stamped onto every node state installed into this function.
    pub(crate) fn new(state: *mut DSPState, dsp_ctx_generation: u64) -> Self {
        Self {
            state,
            node_state_types: vec![],
            node_states: vec![],
            node_state_init_reset: vec![],
            node_state_uids: vec![],
            persistent_vars: vec![],
            function: None,
            dsp_ctx_generation,
            module: None,
        }
    }

    /// Installs the compiled entry point and takes ownership of the JIT module holding its
    /// code. The module is freed when this function is dropped.
    pub(crate) fn set_function_ptr(&mut self, function: *const u8, module: JITModule) {
        self.module = Some(module);
        // SAFETY: assumes (unchecked) that `function` is the entry point compiled into
        // `module` with exactly this signature. `module` stays alive in `self.module` for as
        // long as the pointer can be called.
        self.function = Some(unsafe {
            mem::transmute::<
                *const u8,
                fn(
                    f64,
                    f64,
                    f64,
                    f64,
                    f64,
                    f64,
                    f64,
                    f64,
                    *mut f64,
                    *mut f64,
                    *mut DSPState,
                    *mut *mut u8,
                    *mut f64,
                    *mut f64,
                ) -> f64,
            >(function)
        });
    }

    /// Prepares the function for its first `exec` call.
    ///
    /// The steps are:
    /// - copy the persistent variable values over from `previous_function`;
    /// - store `srate` in the shared state;
    /// - reset every node state that was not yet initialized when it was installed.
    pub fn init(&mut self, srate: f64, previous_function: Option<&DSPFunction>) {
        if let Some(previous_function) = previous_function {
            // Both functions index persistent variables through the same context table, but
            // either one may know more variables than the other. For example, the new code may
            // no longer reference the highest-indexed variable. Copy only the common prefix,
            // since copying the previous length would panic on a shorter vector.
            let common_len = previous_function
                .persistent_vars
                .len()
                .min(self.persistent_vars.len());
            self.persistent_vars[..common_len]
                .copy_from_slice(&previous_function.persistent_vars[..common_len]);
        }

        self.write_sample_rate(srate);

        for &idx in &self.node_state_init_reset {
            self.node_state_types[idx].reset_state(self.state, self.node_states[idx]);
        }
    }

    /// Changes the sample rate in the shared state, then resets all node states.
    pub fn set_sample_rate(&mut self, srate: f64) {
        self.write_sample_rate(srate);
        self.reset();
    }

    /// Resets every installed node state through its node type's `reset_state`.
    pub fn reset(&mut self) {
        for (typ, &ptr) in self.node_state_types.iter().zip(&self.node_states) {
            typ.reset_state(self.state, ptr);
        }
    }

    /// Stores `srate` and its reciprocal in the shared `DSPState`.
    fn write_sample_rate(&mut self, srate: f64) {
        // SAFETY: assumes `self.state` still points at the live `DSPState` owned by the
        // creating `DSPNodeContext`, i.e. `DSPNodeContext::free` has not been called yet.
        unsafe {
            (*self.state).srate = srate;
            (*self.state).israte = 1.0 / srate;
        }
    }

    /// Returns the raw pointer to the shared `DSPState`.
    pub fn get_dsp_state_ptr(&self) -> *mut DSPState {
        self.state
    }

    /// Calls `f` with the raw pointer to the shared `DSPState` and returns its result.
    ///
    /// # Safety
    ///
    /// The pointer is shared with the owning context and with every other function created
    /// from it. The caller is responsible for the validity of every access made through it,
    /// including that the context has not freed the state yet.
    pub unsafe fn with_dsp_state<R, F: FnMut(*mut DSPState) -> R>(&mut self, mut f: F) -> R {
        f(self.get_dsp_state_ptr())
    }

    /// Looks up the state of the node with uid `node_state_uid`, casts it to `*mut T` and
    /// calls `f` with it.
    ///
    /// Returns `Err(())` if no node state with that uid is installed in this function.
    ///
    /// # Safety
    ///
    /// `T` must match the layout of the state block allocated by the node's type; the cast
    /// is not checked.
    pub unsafe fn with_node_state<T, R, F: FnMut(*mut T) -> R>(
        &mut self,
        node_state_uid: u64,
        mut f: F,
    ) -> Result<R, ()> {
        self.get_node_state_ptr(node_state_uid)
            .map(|state_ptr| f(state_ptr.cast::<T>()))
            .ok_or(())
    }

    /// Returns the state pointer of the first installed node with uid `node_state_uid`.
    pub fn get_node_state_ptr(&self, node_state_uid: u64) -> Option<*mut u8> {
        self.node_state_uids
            .iter()
            .position(|&uid| uid == node_state_uid)
            .map(|idx| self.node_states[idx])
    }

    /// Runs one `exec` step with two inputs and all other parameters set to `0.0`.
    ///
    /// Returns `(sig1, sig2, return_value)`.
    pub fn exec_2in_2out(&mut self, in1: f64, in2: f64) -> (f64, f64, f64) {
        let mut s1 = 0.0;
        let mut s2 = 0.0;
        let r = self.exec(in1, in2, 0.0, 0.0, 0.0, 0.0, &mut s1, &mut s2);
        (s1, s2, r)
    }

    /// Calls the compiled function once and returns its return value.
    ///
    /// `sig1` and `sig2` are passed to the compiled code as mutable outputs.
    pub fn exec(
        &mut self,
        in1: f64,
        in2: f64,
        alpha: f64,
        beta: f64,
        delta: f64,
        gamma: f64,
        sig1: &mut f64,
        sig2: &mut f64,
    ) -> f64 {
        debug_assert!(
            self.function.is_some(),
            "DSPFunction::exec called before set_function_ptr()"
        );

        // SAFETY: assumes the shared state is still alive (see `write_sample_rate`).
        let (srate, israte) = unsafe { ((*self.state).srate, (*self.state).israte) };
        let states_ptr: *mut *mut u8 = self.node_states.as_mut_ptr();
        let pers_vars_ptr: *mut f64 = self.persistent_vars.as_mut_ptr();
        // Scratch buffer handed to the compiled code (presumably for nodes with multiple
        // return values). Its contents are not read after the call.
        let mut multi_returns = [0.0; 5];

        // SAFETY: `function` is only `None` before `set_function_ptr`. Functions handed out by
        // `DSPNodeContext::finalize_dsp_function` always have it set.
        let function = unsafe { self.function.unwrap_unchecked() };
        function(
            in1,
            in2,
            alpha,
            beta,
            delta,
            gamma,
            srate,
            israte,
            sig1,
            sig2,
            self.state,
            states_ptr,
            pers_vars_ptr,
            multi_returns.as_mut_ptr(),
        )
    }

    /// Appends `node_state` to this function and returns its index in the node-state array
    /// passed to the compiled code.
    ///
    /// States that are not initialized yet are queued to be reset by `init`.
    pub(crate) fn install(&mut self, node_state: &mut DSPNodeState) -> usize {
        let idx = self.node_states.len();
        node_state.mark(self.dsp_ctx_generation, idx);
        self.node_states.push(node_state.ptr());
        self.node_state_types.push(node_state.node_type());
        self.node_state_uids.push(node_state.uid());
        if !node_state.is_initialized() {
            self.node_state_init_reset.push(idx);
        }
        idx
    }

    /// Grows the persistent variable storage (zero-filled) so that `idx` is a valid index.
    pub(crate) fn touch_persistent_var_index(&mut self, idx: usize) {
        if idx >= self.persistent_vars.len() {
            self.persistent_vars.resize(idx + 1, 0.0);
        }
    }

    /// Returns `true` if a node state with uid `uid` is already installed in this function.
    pub fn has_dsp_node_state_uid(&self, uid: u64) -> bool {
        self.node_state_uids.contains(&uid)
    }
}
impl Drop for DSPFunction {
    /// Releases the JIT-compiled code owned by this function, if a module was installed.
    fn drop(&mut self) {
        // `take()` leaves `None` behind, so the module can be released at most once.
        if let Some(jit_module) = self.module.take() {
            // SAFETY: `self` is being dropped, so the function pointer installed alongside
            // this module by `set_function_ptr` can no longer be called afterwards.
            unsafe { jit_module.free_memory() }
        }
    }
}
/// State shared between the host and the compiled DSP code.
///
/// `DSPNodeContext::new` heap-allocates one instance and hands its raw pointer to every
/// `DSPFunction`. `DSPFunction` passes that pointer to the compiled function and to
/// `DSPNodeType::reset_state`.
///
/// NOTE(review): this struct crosses into JIT-generated code as a raw pointer but has no
/// `#[repr(C)]`. If the generated code accesses fields by offset, add `#[repr(C)]`; confirm.
pub struct DSPState {
    /// NOTE(review): not used in the visible code; presumably general-purpose scratch values.
    pub x: f64,
    /// NOTE(review): not used in the visible code; presumably general-purpose scratch values.
    pub y: f64,
    /// Current sample rate (initialized to 44100.0).
    pub srate: f64,
    /// Reciprocal of `srate` (`1.0 / srate`), written together with it.
    pub israte: f64,
}
/// Kind of one argument in a node function's signature, as reported by
/// `DSPNodeType::signature`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum DSPNodeSigBit {
    /// A plain value argument. NOTE(review): presumably an `f64`; confirm against the codegen.
    Value,
    /// Presumably a pointer to the shared `DSPState`.
    DSPStatePtr,
    /// Presumably a pointer to the calling node instance's state block.
    NodeStatePtr,
    /// Presumably a pointer to the multi-return scratch buffer that `DSPFunction::exec`
    /// passes to the compiled code.
    MultReturnPtr,
}
/// A kind of node that compiled DSP code can use, plus the management of its
/// per-instance state.
pub trait DSPNodeType {
    /// Name of the node type.
    ///
    /// When a persisted node state is re-attached to a new function, node types with equal
    /// names are treated as the same type.
    fn name(&self) -> &str;
    /// Address of the native function implementing this node.
    /// NOTE(review): presumably called from the generated code; confirm against the codegen.
    fn function_ptr(&self) -> *const u8;
    /// Kind of the `_i`-th argument of the node function. The default returns `None` for
    /// every index. NOTE(review): presumably `None` marks the end of the argument list.
    fn signature(&self, _i: usize) -> Option<DSPNodeSigBit> {
        None
    }
    /// Whether the node function produces a return value.
    fn has_return_value(&self) -> bool;
    /// Resets the state block at `_state_ptr`. The default does nothing.
    ///
    /// Called by `DSPFunction::init` for newly installed states and by
    /// `DSPFunction::reset` for all states.
    fn reset_state(&self, _dsp_state: *mut DSPState, _state_ptr: *mut u8) {}
    /// Allocates a fresh state block for one node instance.
    ///
    /// The default returns `None`. Note that `DSPNodeState::new` panics when this returns
    /// `None`.
    fn allocate_state(&self) -> Option<*mut u8> {
        None
    }
    /// Frees a state block previously returned by `allocate_state`. The default does
    /// nothing. Called from `DSPNodeState`'s `Drop`.
    fn deallocate_state(&self, _ptr: *mut u8) {}
}
/// Persistent per-instance state of one stateful node, identified by its uid.
///
/// Owns the state block returned by `DSPNodeType::allocate_state` and hands it back to
/// `DSPNodeType::deallocate_state` when dropped.
pub(crate) struct DSPNodeState {
    /// Uid of the node instance.
    uid: u64,
    /// Node type that allocated `ptr`; also used to reset and free it.
    node_type: Rc<dyn DSPNodeType>,
    /// Opaque state block from `allocate_state`; released in `Drop`.
    ptr: *mut u8,
    /// Context generation of the function this state was last installed into (set by
    /// `mark`). NOTE(review): not read anywhere in the visible code.
    generation: u64,
    /// Index of this state in that function's node-state array (set by `mark`).
    /// NOTE(review): not read anywhere in the visible code.
    function_index: usize,
    /// Set by `DSPNodeContext::finalize_dsp_function`. States still uninitialized when
    /// installed are reset by `DSPFunction::init`.
    initialized: bool,
}
impl DSPNodeState {
    /// Creates the state record for node instance `uid` and allocates its state block via
    /// `node_type.allocate_state()`. The record starts out uninitialized.
    ///
    /// # Panics
    ///
    /// Panics if `allocate_state()` returns `None`, i.e. if `node_type` provides no
    /// per-instance state.
    pub(crate) fn new(uid: u64, node_type: Rc<dyn DSPNodeType>) -> Self {
        // Allocate before `node_type` is moved into the struct, so no extra `Rc` clone is needed.
        let ptr = node_type
            .allocate_state()
            .expect("DSPNodeState created for stateful node type");
        Self {
            uid,
            node_type,
            ptr,
            generation: 0,
            function_index: 0,
            initialized: false,
        }
    }

    /// Uid of the node instance this state belongs to.
    pub(crate) fn uid(&self) -> u64 {
        self.uid
    }

    /// Records the context generation and node-state index this state was last installed at.
    // `generation` rather than `gen`: `gen` is a reserved keyword in the Rust 2024 edition.
    pub(crate) fn mark(&mut self, generation: u64, index: usize) {
        self.generation = generation;
        self.function_index = index;
    }

    /// Whether the state has been through a finalized function already.
    pub(crate) fn is_initialized(&self) -> bool {
        self.initialized
    }

    /// Marks the state as initialized, so that later `DSPFunction::init` calls skip its reset.
    pub(crate) fn set_initialized(&mut self) {
        self.initialized = true;
    }

    /// Raw pointer to the state block.
    pub(crate) fn ptr(&self) -> *mut u8 {
        self.ptr
    }

    /// A new shared handle to the node type that owns this state.
    pub(crate) fn node_type(&self) -> Rc<dyn DSPNodeType> {
        Rc::clone(&self.node_type)
    }
}
impl Drop for DSPNodeState {
    /// Hands the state block back to the node type that allocated it.
    fn drop(&mut self) {
        // Swap a null pointer into the field, then release the previous allocation through
        // the owning node type.
        let state_block = std::mem::replace(&mut self.ptr, std::ptr::null_mut());
        self.node_type.deallocate_state(state_block);
    }
}