use std::{collections::HashMap, marker::PhantomData};
use autd3_driver::{
derive::prelude::*,
geometry::{Device, Geometry},
};
/// A gain adaptor that post-processes every [`Drive`] produced by an inner gain.
///
/// When computed, the wrapped gain `gain` runs first. Then the closure `f` is invoked once per
/// transducer with the owning device, the transducer, and the drive the inner gain produced.
/// The closure's return value replaces that drive. Values are typically built via
/// [`IntoTransform::with_transform`].
///
/// Trait bounds are intentionally placed on the `impl` blocks rather than on this definition.
/// Every impl already restates them, and keeping the struct unbounded follows the Rust API
/// guidelines.
pub struct Transform<T, G, F> {
    /// The gain whose output is transformed.
    gain: G,
    /// Per-transducer mapping applied to each drive of `gain`.
    f: F,
    /// Ties the transducer type `T` to the struct without storing a value of it.
    phantom: PhantomData<T>,
}
/// Extension trait that wraps a [`Gain`] in a [`Transform`].
pub trait IntoTransform<T: Transducer, G: Gain<T>> {
    /// Consumes `self` and returns a [`Transform`] that applies `f` to each computed drive.
    ///
    /// `f` receives the device, the transducer, and the drive produced by the inner gain.
    /// It returns the drive to use in its place.
    fn with_transform<F>(self, f: F) -> Transform<T, G, F>
    where
        F: Fn(&Device<T>, &T, &Drive) -> Drive;
}
/// Blanket implementation: every gain can be wrapped with a transform closure.
impl<T, G> IntoTransform<T, G> for G
where
    T: Transducer,
    G: Gain<T>,
{
    fn with_transform<F>(self, f: F) -> Transform<T, G, F>
    where
        F: Fn(&Device<T>, &T, &Drive) -> Drive,
    {
        let gain = self;
        Transform {
            gain,
            f,
            phantom: PhantomData,
        }
    }
}
/// Lets a [`Transform`] be sent as a datagram.
///
/// The transform itself becomes the payload of a `GainOp`, paired with a `NullOp` as the
/// second operation.
impl<T, G, F> autd3_driver::datagram::Datagram<T> for Transform<T, G, F>
where
    T: Transducer + 'static,
    G: Gain<T> + 'static,
    F: Fn(&Device<T>, &T, &Drive) -> Drive + 'static,
    autd3_driver::operation::GainOp<T, Self>: autd3_driver::operation::Operation<T>,
{
    type O1 = autd3_driver::operation::GainOp<T, Self>;
    type O2 = autd3_driver::operation::NullOp;

    fn operation(self) -> Result<(Self::O1, Self::O2), autd3_driver::error::AUTDInternalError> {
        let gain_op = Self::O1::new(self);
        let null_op = Self::O2::default();
        Ok((gain_op, null_op))
    }
}
/// Type-erased access to a [`Transform`].
impl<T, G, F> autd3_driver::datagram::GainAsAny for Transform<T, G, F>
where
    T: Transducer + 'static,
    G: Gain<T> + 'static,
    F: Fn(&Device<T>, &T, &Drive) -> Drive + 'static,
{
    /// Returns this value as a `&dyn Any`.
    fn as_any(&self) -> &dyn std::any::Any {
        self as &dyn std::any::Any
    }
}
/// Computes the inner gain and then rewrites each drive through the transform closure.
impl<T, G, F> Gain<T> for Transform<T, G, F>
where
    T: Transducer + 'static,
    G: Gain<T> + 'static,
    F: Fn(&Device<T>, &T, &Drive) -> Drive + 'static,
{
    /// Returns the inner gain's drives, each replaced by `f(device, transducer, drive)`.
    ///
    /// The inner gain's map is keyed by device index. Within each vector, position `i` is
    /// paired with transducer `i` of that device.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wrapped gain's `calc`. The closure itself cannot fail.
    fn calc(
        &self,
        geometry: &Geometry<T>,
        filter: GainFilter,
    ) -> Result<HashMap<usize, Vec<Drive>>, AUTDInternalError> {
        let mut drives = self.gain.calc(geometry, filter)?;
        for (&dev_idx, dev_drives) in drives.iter_mut() {
            for (tr_idx, drive) in dev_drives.iter_mut().enumerate() {
                *drive = (self.f)(&geometry[dev_idx], &geometry[dev_idx][tr_idx], drive);
            }
        }
        Ok(drives)
    }
}