#![doc(html_root_url = "https://docs.rs/ocl-convolution/0.3.0")]
#![warn(missing_debug_implementations, missing_docs, bare_trait_objects)]
#![warn(clippy::all, clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::must_use_candidate,
clippy::module_name_repetitions,
clippy::doc_markdown
)]
use ndarray::{Array4, ArrayView4};
use ocl::OclPrm;
use std::{fmt, marker::PhantomData};
mod base;
mod buffers;
mod params;
use crate::{
base::Base,
buffers::{Filters, Pinned},
};
pub use crate::{
base::ConvolutionBuilder,
buffers::{FeatureMap, FeatureMapShape, Layout},
params::{I8Params, Params},
};
// Kernel source (`conv.cl`) embedded at compile time from Cargo's `OUT_DIR`; it is
// passed as the program source to every `ConvolutionBuilder` created in this file.
// NOTE(review): `OUT_DIR` implies the file is produced or copied by the build script;
// confirm in `build.rs`.
const SOURCE: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.cl"));
/// Element type of tensors that convolutions can operate on.
///
/// This crate implements the trait for `f32` and `i8`. Convolutions over these types
/// are created via [`Convolution::f32()`] and [`Convolution::i8()`], respectively.
pub trait ConvElement: OclPrm + Copy + 'static {
    /// Type of filter biases, as accepted by [`Convolution::with_biased_filters()`]
    /// and [`Convolution::compute_with_biases()`]. It is `f32` for `f32` elements and
    /// `i32` for `i8` elements.
    // NOTE(review): the name suggests this is also the accumulator type used inside the
    // OpenCL kernel; confirm against `conv.cl`.
    type Acc: OclPrm + Copy + 'static;

    /// User-facing convolution parameters for this element type ([`Params`] for `f32`,
    /// [`I8Params`] for `i8`). They must be convertible both into the generic
    /// [`Params`] and into [`ConvElement::ClParams`].
    type Params: Copy + Into<Params> + Into<Self::ClParams>;

    /// OpenCL primitive (`OclPrm`) representation that [`ConvElement::Params`]
    /// converts into.
    type ClParams: OclPrm;
}
// `f32` convolutions take `f32` biases and use the generic `Params` directly.
impl ConvElement for f32 {
    type ClParams = params::ClParams;
    type Params = Params;
    type Acc = f32;
}
// `i8` convolutions take `i32` biases and use the `i8`-specific `I8Params`.
impl ConvElement for i8 {
    type ClParams = params::ClI8Params;
    type Params = I8Params;
    type Acc = i32;
}
impl ConvolutionBuilder<f32> {
    /// Creates an `f32` convolution with the specified `params`.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while constructing the convolution.
    pub fn build(&self, params: Params) -> ocl::Result<Convolution<f32>> {
        Base::new(self, params).map(Convolution)
    }
}
impl ConvolutionBuilder<i8> {
    /// Creates an `i8` convolution with the specified `params`.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while constructing the convolution.
    pub fn build(&self, params: I8Params) -> ocl::Result<Convolution<i8>> {
        Base::new(self, params).map(Convolution)
    }
}
/// Convolution without attached filters.
///
/// Filters are either supplied on each call to [`Convolution::compute()`] /
/// [`Convolution::compute_with_biases()`], or attached up front with
/// [`Convolution::with_filters()`] / [`Convolution::with_biased_filters()`], which
/// consume the convolution and return a [`FiltersConvolution`].
///
/// Instances are created with the `build` method of a [`ConvolutionBuilder`] obtained
/// from [`Convolution::f32()`] or [`Convolution::i8()`].
pub struct Convolution<T: ConvElement>(Base<PhantomData<T>>);
// Formats as a single-field tuple struct wrapping the inner `Base`. `Debug` is
// available whenever the element's parameter type is itself `Debug`.
impl<T: ConvElement> fmt::Debug for Convolution<T>
where
    T::Params: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("Convolution");
        tuple.field(&self.0);
        tuple.finish()
    }
}
impl Convolution<f32> {
    /// Creates a builder for convolutions over `f32` values with the given `size`.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while creating the builder.
    pub fn f32(size: u32) -> ocl::Result<ConvolutionBuilder<f32>> {
        // `KERNEL_TYPE` is handed to the builder together with the kernel source;
        // presumably it selects the `f32` variant of the kernels in `conv.cl`.
        ConvolutionBuilder::new(size, &[("KERNEL_TYPE", 32)], SOURCE)
    }
}
impl Convolution<i8> {
    /// Creates a builder for convolutions over `i8` values with the given `size`.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while creating the builder.
    pub fn i8(size: u32) -> ocl::Result<ConvolutionBuilder<i8>> {
        // `KERNEL_TYPE` is handed to the builder together with the kernel source;
        // presumably it selects the `i8` variant of the kernels in `conv.cl`.
        ConvolutionBuilder::new(size, &[("KERNEL_TYPE", 8)], SOURCE)
    }
}
impl<T: ConvElement> Convolution<T> {
    /// Returns the size of this convolution.
    pub fn size(&self) -> u32 {
        self.0.size()
    }

    /// Returns the current convolution parameters.
    pub fn params(&self) -> T::Params {
        self.0.params()
    }

    /// Sets the convolution parameters.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while applying the new parameters.
    pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
        self.0.set_params(params)
    }

    /// Attaches `filters` to this convolution (without biases), consuming it and
    /// returning a [`FiltersConvolution`] whose `compute` only needs the signal.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while attaching the filters.
    pub fn with_filters<'a>(
        self,
        filters: impl Into<ArrayView4<'a, T>>,
    ) -> ocl::Result<FiltersConvolution<T>> {
        self.0
            .with_filters(filters.into(), None)
            .map(FiltersConvolution)
    }

    /// Attaches `filters` together with `filter_biases` to this convolution, consuming
    /// it and returning a [`FiltersConvolution`] whose `compute` only needs the signal.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while attaching the filters and biases.
    // NOTE(review): `filter_biases` presumably holds one bias per filter (output
    // channel); confirm how a length mismatch is handled in `Base::with_filters`.
    pub fn with_biased_filters<'a>(
        self,
        filters: impl Into<ArrayView4<'a, T>>,
        filter_biases: &[T::Acc],
    ) -> ocl::Result<FiltersConvolution<T>> {
        self.0
            .with_filters(filters.into(), Some(filter_biases))
            .map(FiltersConvolution)
    }

    /// Convolves `signal` with `filters` (without biases) and returns the output as an
    /// owned 4D array. The filters are used for this call only.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised during the computation.
    pub fn compute<'a>(
        &self,
        signal: FeatureMap<'_, T>,
        filters: impl Into<ArrayView4<'a, T>>,
    ) -> ocl::Result<Array4<T>> {
        self.0.compute(signal, filters.into(), None)
    }

    /// Convolves `signal` with `filters` and `filter_biases`, returning the output as
    /// an owned 4D array. The filters and biases are used for this call only.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised during the computation.
    pub fn compute_with_biases<'a>(
        &self,
        signal: FeatureMap<'_, T>,
        filters: impl Into<ArrayView4<'a, T>>,
        filter_biases: &[T::Acc],
    ) -> ocl::Result<Array4<T>> {
        self.0.compute(signal, filters.into(), Some(filter_biases))
    }
}
/// Convolution with attached filters (and optionally filter biases).
///
/// Created by [`Convolution::with_filters()`] or [`Convolution::with_biased_filters()`].
/// It can be further specialized for a fixed signal shape with
/// [`FiltersConvolution::pin()`].
pub struct FiltersConvolution<T: ConvElement>(Base<Filters<T>>);
// Formats as a single-field tuple struct wrapping the inner `Base`. `Debug` is
// available whenever the element's parameter type is itself `Debug`.
impl<T: ConvElement> fmt::Debug for FiltersConvolution<T>
where
    T::Params: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("FiltersConvolution");
        tuple.field(&self.0);
        tuple.finish()
    }
}
impl<T: ConvElement> FiltersConvolution<T> {
    /// Returns the size of this convolution.
    pub fn size(&self) -> u32 {
        self.0.size()
    }

    /// Returns the current convolution parameters.
    pub fn params(&self) -> T::Params {
        self.0.params()
    }

    /// Sets the convolution parameters.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while applying the new parameters.
    pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
        self.0.set_params(params)
    }

    /// Fixes the input signal shape to `signal_shape`, consuming this convolution and
    /// returning a [`PinnedConvolution`] for signals of that shape.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while preparing the pinned convolution.
    // NOTE(review): the `Pinned` buffer type suggests this sets up pinned (host-mapped)
    // memory for the signal/output; confirm in `buffers.rs`.
    pub fn pin(self, signal_shape: FeatureMapShape) -> ocl::Result<PinnedConvolution<T>> {
        self.0.pinned(signal_shape).map(PinnedConvolution)
    }

    /// Convolves `signal` with the attached filters and returns the output as an owned
    /// 4D array.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised during the computation.
    pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
        self.0.compute(signal)
    }
}
/// Convolution with attached filters and a fixed input signal shape.
///
/// Created by [`FiltersConvolution::pin()`].
pub struct PinnedConvolution<T: ConvElement>(Base<Pinned<T>>);
// Formats as a single-field tuple struct wrapping the inner `Base`. `Debug` is
// available whenever the element's parameter type is itself `Debug`.
impl<T: ConvElement> fmt::Debug for PinnedConvolution<T>
where
    T::Params: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut tuple = f.debug_tuple("PinnedConvolution");
        tuple.field(&self.0);
        tuple.finish()
    }
}
impl<T: ConvElement> PinnedConvolution<T> {
    /// Returns the size of this convolution.
    pub fn size(&self) -> u32 {
        self.0.size()
    }

    /// Returns the current convolution parameters.
    pub fn params(&self) -> T::Params {
        self.0.params()
    }

    /// Sets the convolution parameters.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised while applying the new parameters.
    pub fn set_params(&mut self, params: T::Params) -> ocl::Result<()> {
        self.0.set_params(params)
    }

    /// Convolves `signal` with the attached filters and returns the output as an owned
    /// 4D array.
    ///
    /// # Errors
    ///
    /// Propagates any `ocl::Error` raised during the computation.
    // NOTE(review): `signal` presumably must match the shape passed to
    // `FiltersConvolution::pin`; confirm how a mismatch is reported by `Base<Pinned<T>>`.
    pub fn compute(&self, signal: FeatureMap<'_, T>) -> ocl::Result<Array4<T>> {
        self.0.compute(signal)
    }
}
// Compiles the code examples in README.md as doctests. This only exists under
// `cfg(doctest)`, i.e. when rustdoc runs doctests, so the README examples are
// checked against the current API.
#[cfg(doctest)]
doc_comment::doctest!("../README.md");