use core::array::{from_fn, repeat};
use crate::{
Channels, Inplace, Major, Minor, Parallel, Process, SplitInplace, SplitProcess, Transpose,
};
/// A processor separated into a configuration part and a state part.
///
/// The processing impls in this module borrow `config` immutably and `state` mutably
/// (see the `SplitProcess`/`SplitInplace` signatures), so one configuration can drive
/// independent states.
#[derive(Clone, Copy, Debug, Default)]
pub struct Split<C, S> {
    /// Configuration, passed by shared reference when processing.
    pub config: C,
    /// State, passed by mutable reference when processing.
    pub state: S,
}
/// A [`Split`] behaves as an ordinary [`Process`] by handing its own state to the
/// configuration's [`SplitProcess`] implementation.
impl<X: Copy, Y, S, C: SplitProcess<X, Y, S>> Process<X, Y> for Split<C, S> {
    fn process(&mut self, x: X) -> Y {
        // Disjoint field borrows: config shared, state exclusive.
        <C as SplitProcess<X, Y, S>>::process(&self.config, &mut self.state, x)
    }

    fn block(&mut self, x: &[X], y: &mut [Y]) {
        <C as SplitProcess<X, Y, S>>::block(&self.config, &mut self.state, x, y)
    }
}
/// A [`Split`] behaves as an ordinary [`Inplace`] processor by handing its own state
/// to the configuration's [`SplitInplace`] implementation.
impl<X: Copy, S, C: SplitInplace<X, S>> Inplace<X> for Split<C, S> {
    fn inplace(&mut self, xy: &mut [X]) {
        <C as SplitInplace<X, S>>::inplace(&self.config, &mut self.state, xy);
    }
}
impl<C, S> Split<C, S> {
pub const fn new(config: C, state: S) -> Self {
Self { config, state }
}
pub const fn assert_process<X: Copy, Y>(&self)
where
Self: Process<X, Y>,
{
}
}
/// Wrapper that lets a plain processor `P` be stored as the state of a
/// `Split<(), Unsplit<P>>`; the `()` configuration forwards all calls to it.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Default)]
pub struct Unsplit<P>(pub P);
impl<C> Split<C, ()> {
    /// Creates a [`Split`] whose state is the unit type `()`.
    ///
    /// Intended for configurations that carry no mutable state. Like [`Split::new`], this
    /// is a `const fn`, so it can also be used in `const`/`static` initializers.
    pub const fn stateless(config: C) -> Self {
        Self::new(config, ())
    }
}
impl<S> Split<(), Unsplit<S>> {
    /// Wraps an ordinary processor `state` as a [`Split`] with a `()` configuration.
    ///
    /// The processor is stored in [`Unsplit`]. The `()` configuration's `SplitProcess` /
    /// `SplitInplace` impls forward to it. Like [`Split::new`], this is a `const fn`.
    pub const fn stateful(state: S) -> Self {
        Self::new((), Unsplit(state))
    }
}
impl<C0, C1, S0, S1> core::ops::Mul<Split<C1, S1>> for Split<C0, S0> {
type Output = Split<(C0, C1), (S0, S1)>;
fn mul(self, rhs: Split<C1, S1>) -> Self::Output {
Split::from((self, rhs))
}
}
impl<C0, C1, S0, S1> core::ops::Add<Split<C1, S1>> for Split<C0, S0> {
type Output = Split<Parallel<(C0, C1)>, (S0, S1)>;
fn add(self, rhs: Split<C1, S1>) -> Self::Output {
Split::from((self, rhs)).parallel()
}
}
impl<C0, C1, S0, S1> From<(Split<C0, S0>, Split<C1, S1>)> for Split<(C0, C1), (S0, S1)> {
fn from(value: (Split<C0, S0>, Split<C1, S1>)) -> Self {
Split::new(
(value.0.config, value.1.config),
(value.0.state, value.1.state),
)
}
}
impl<C, S, const N: usize> From<[Split<C, S>; N]> for Split<[C; N], [S; N]> {
fn from(splits: [Split<C, S>; N]) -> Self {
let mut splits = splits.map(|s| (Some(s.config), Some(s.state)));
Self::new(
from_fn(|i| splits[i].0.take().unwrap()),
from_fn(|i| splits[i].1.take().unwrap()),
)
}
}
impl<C, S> Split<C, S> {
pub fn minor<U>(self) -> Split<Minor<C, U>, S> {
Split::new(Minor::new(self.config), self.state)
}
pub fn major<U>(self) -> Split<Major<C, U>, S> {
Split::new(Major::new(self.config), self.state)
}
pub fn parallel(self) -> Split<Parallel<C>, S> {
Split::new(Parallel(self.config), self.state)
}
pub fn repeat<const N: usize>(self) -> Split<[C; N], [S; N]>
where
C: Clone,
S: Clone,
{
Split::new(repeat(self.config), repeat(self.state))
}
pub fn channels<const N: usize>(self) -> Split<Channels<C>, [S; N]>
where
S: Clone,
{
Split::new(Channels(self.config), repeat(self.state))
}
pub fn transpose(self) -> Split<Transpose<C>, S> {
Split::new(Transpose(self.config), self.state)
}
}
impl<C, S, U> Split<Minor<C, U>, S> {
    /// Removes the [`Minor`] wrapper, returning a [`Split`] of its `inner` configuration
    /// with the same state.
    pub fn inter(self) -> Split<C, S> {
        let Self { config, state } = self;
        Split::new(config.inner, state)
    }
}
impl<C, S> Split<Parallel<C>, S> {
    /// Removes the [`Parallel`] wrapper, returning a [`Split`] of the wrapped
    /// configuration with the same state.
    pub fn inter(self) -> Split<C, S> {
        let Self { config, state } = self;
        Split::new(config.0, state)
    }
}
impl<C, S> Split<Transpose<C>, S> {
    /// Removes the [`Transpose`] wrapper, returning a [`Split`] of the wrapped
    /// configuration with the same state.
    pub fn inter(self) -> Split<C, S> {
        let Self { config, state } = self;
        Split::new(config.0, state)
    }
}
impl<C, S, B> Split<Major<C, B>, S> {
    /// Removes the [`Major`] wrapper, returning a [`Split`] of its `inner` configuration
    /// with the same state.
    pub fn inter(self) -> Split<C, S> {
        let Self { config, state } = self;
        Split::new(config.inner, state)
    }
}
impl<C0, C1, S0, S1> Split<(C0, C1), (S0, S1)> {
pub fn zip(self) -> (Split<C0, S0>, Split<C1, S1>) {
(
Split::new(self.config.0, self.state.0),
Split::new(self.config.1, self.state.1),
)
}
}
impl<C, S, const N: usize> Split<[C; N], [S; N]> {
pub fn zip(self) -> [Split<C, S>; N] {
let mut it = self.config.into_iter().zip(self.state);
from_fn(|_| {
let (c, s) = it.next().unwrap();
Split::new(c, s)
})
}
}
/// The unit configuration `()` drives an [`Unsplit`] state by forwarding each call to
/// the wrapped [`Process`] implementation.
impl<X: Copy, Y, P: Process<X, Y>> SplitProcess<X, Y, Unsplit<P>> for () {
    fn process(&self, state: &mut Unsplit<P>, x: X) -> Y {
        <P as Process<X, Y>>::process(&mut state.0, x)
    }

    fn block(&self, state: &mut Unsplit<P>, x: &[X], y: &mut [Y]) {
        <P as Process<X, Y>>::block(&mut state.0, x, y)
    }
}
/// The unit configuration `()` drives an [`Unsplit`] state in place by forwarding to the
/// wrapped [`Inplace`] implementation.
impl<X: Copy, P: Inplace<X>> SplitInplace<X, Unsplit<P>> for () {
    fn inplace(&self, state: &mut Unsplit<P>, xy: &mut [X]) {
        <P as Inplace<X>>::inplace(&mut state.0, xy)
    }
}