// ha-ndarray 0.5.0
//
// A hardware-accelerated n-dimensional array
// Documentation
use std::fmt;

use crate::Error;

use super::{CLElement, OpenCL, TILE_SIZE, WG_SIZE};

pub mod constructors;
pub mod elementwise;
pub mod gather;
pub mod linalg;
pub mod reduce;
pub mod slice;
pub mod view;

/// A piece of source text that is interpolated into generated function bodies.
///
/// The text is stored either as a borrowed `'static` string or as an owned
/// [`String`]; both variants are written out verbatim by the `Display` impl.
#[derive(Clone, Eq, PartialEq, Hash, fmt::Debug)]
pub(super) enum CLExpr {
    /// Text borrowed for the `'static` lifetime (e.g. a string literal).
    Static(&'static str),
    /// Text owned by this value.
    String(String),
}

impl From<&'static str> for CLExpr {
    fn from(s: &'static str) -> Self {
        CLExpr::Static(s)
    }
}

impl From<String> for CLExpr {
    fn from(s: String) -> Self {
        CLExpr::String(s)
    }
}

impl fmt::Display for CLExpr {
    /// Writes the wrapped text verbatim, whichever variant holds it.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Resolve both variants to a plain `&str` first so there is a single write.
        let text: &str = match self {
            Self::Static(borrowed) => *borrowed,
            Self::String(owned) => owned.as_str(),
        };

        f.write_str(text)
    }
}

/// A value that can be rendered into a `String`.
///
/// The implementations in this file return the text of an `inline` function
/// definition.
pub trait Builder {
    /// Consumes `self` and returns the rendered text.
    fn build(self) -> String;
}

/// Description of a two-argument `inline` function, rendered to source text by
/// its [`Builder`] implementation.
#[derive(Clone, Eq, PartialEq, Hash, fmt::Debug)]
pub struct ElementDual {
    /// Type name used for both parameters (`lhs` and `rhs`) of the generated function.
    pub(super) i_type: &'static str,
    /// Return type name of the generated function.
    pub(super) o_type: &'static str,
    /// Name of the generated function.
    pub(super) name: &'static str,
    /// Text placed inside the braces of the generated function.
    pub(super) op: CLExpr,
}

impl ElementDual {
    /// Describes a two-argument function called `name` whose body is `op`.
    ///
    /// The parameter type name is taken from `I::TYPE` and the return type
    /// name from `O::TYPE`.
    pub(super) fn new<I, O, Op>(name: &'static str, op: Op) -> Self
    where
        I: CLElement,
        O: CLElement,
        Op: Into<CLExpr>,
    {
        let (i_type, o_type) = (I::TYPE, O::TYPE);
        let op: CLExpr = op.into();

        Self {
            i_type,
            o_type,
            name,
            op,
        }
    }
}

impl Builder for ElementDual {
    /// Renders an `inline` function that takes two `const` arguments, `lhs` and
    /// `rhs`, both of type `i_type`, returns `o_type`, and has `op` as its body.
    fn build(self) -> String {
        // Pull the fields into locals so the template can capture them by name.
        let Self {
            i_type,
            o_type,
            name,
            op,
        } = self;

        format!(
            r#"
            inline {o_type} {name}(const {i_type} lhs, const {i_type} rhs) {{
                {op}
            }}
            "#
        )
    }
}

/// Description of a one-argument `inline` function, rendered to source text by
/// its [`Builder`] implementation.
#[derive(Clone, Eq, PartialEq, Hash, fmt::Debug)]
pub struct ElementUnary {
    /// Type name of the single parameter (`n`) of the generated function.
    pub(super) i_type: &'static str,
    /// Return type name of the generated function.
    pub(super) o_type: &'static str,
    /// Name of the generated function.
    pub(super) name: &'static str,
    /// Text placed inside the braces of the generated function.
    pub(super) op: CLExpr,
}

impl ElementUnary {
    /// Describes a one-argument function called `name` whose body is `op`.
    ///
    /// The parameter type name is taken from `I::TYPE` and the return type
    /// name from `O::TYPE`.
    pub(super) fn new<I, O, Op>(name: &'static str, op: Op) -> Self
    where
        I: CLElement,
        O: CLElement,
        Op: Into<CLExpr>,
    {
        let (i_type, o_type) = (I::TYPE, O::TYPE);
        let op: CLExpr = op.into();

        Self {
            i_type,
            o_type,
            name,
            op,
        }
    }
}

impl Builder for ElementUnary {
    /// Renders an `inline` function that takes one `const` argument `n` of type
    /// `i_type`, returns `o_type`, and has `op` as its body.
    fn build(self) -> String {
        // Pull the fields into locals so the template can capture them by name.
        let Self {
            i_type,
            o_type,
            name,
            op,
        } = self;

        format!(
            "
            inline {o_type} {name}(const {i_type} n) {{
                {op}
            }}
            "
        )
    }
}

/// Display adapter that prints a borrowed slice as a brace-delimited list.
///
/// The output is `"{ "`, then every item followed by `", "` (including the
/// last one), then `" }"`. An empty slice therefore prints as `"{  }"`.
struct ArrayFormat<'a, T> {
    /// The items to print, in order.
    arr: &'a [T],
}

impl<'a, T> From<&'a [T]> for ArrayFormat<'a, T> {
    /// Borrows `items` for formatting; nothing is copied.
    fn from(items: &'a [T]) -> Self {
        ArrayFormat { arr: items }
    }
}

impl<'a, T: fmt::Display> fmt::Display for ArrayFormat<'a, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("{ ")?;

        // Stop at the first formatting error, exactly like `?` inside a loop.
        self.arr
            .iter()
            .try_for_each(|element| write!(f, "{element}, "))?;

        f.write_str(" }")
    }
}

#[inline]
fn build(src: &str) -> Result<ocl::Program, Error> {
    ocl::Program::builder()
        .source(src)
        .build(OpenCL::context())
        .map_err(Error::from)
}