use std::fmt::Debug;
pub mod categorical_cross_entropy;
pub mod mean_absolute;
pub mod mean_bias;
pub mod mean_squared;
pub use categorical_cross_entropy::CategoricalCrossEntropy;
pub use mean_absolute::MeanAbsolute;
pub use mean_bias::MeanBias;
pub use mean_squared::MeanSquared;
use crate::{
types::{KernelNotFoundError, ProgramNotFoundError},
utils::{
opencl::{BufferOperationError, EnsureKernelsAndProgramError},
OpenCLState,
},
};
use intricate_macros::FromForAllUnnamedVariants;
use opencl3::{device::cl_float, error_codes::ClError, memory::Buffer};
use self::{
categorical_cross_entropy::{compile_categorical_cross_entropy, ReduceOutputsPerSampleError},
mean_absolute::compile_mean_absolute,
mean_bias::compile_mean_bias,
mean_squared::compile_mean_squared,
};
/// Runs the `compile_*` routine of every built-in loss module against `state`.
///
/// The modules are processed in a fixed order: mean squared, categorical cross
/// entropy, mean absolute, then mean bias. Processing stops at the first failure,
/// and the modules after it are not compiled.
///
/// # Errors
///
/// Returns the first error produced by any of the per-loss compile routines,
/// converted into an [`EnsureKernelsAndProgramError`].
pub(crate) fn compile_losses(
    state: &mut OpenCLState,
) -> Result<(), EnsureKernelsAndProgramError> {
    // Order matters only for error reporting: the first failing module's error wins.
    compile_mean_squared(state)?;
    compile_categorical_cross_entropy(state)?;
    compile_mean_absolute(state)?;
    compile_mean_bias(state)?;

    // Every loss module compiled successfully.
    Ok(())
}
/// Error returned by [`LossFunction::compute_loss`].
///
/// Variants that wrap a single value (e.g. `OpenCL(ClError)`) can be produced with
/// `?` from the wrapped error type. NOTE(review): this assumes
/// `FromForAllUnnamedVariants` generates a `From` impl per unnamed variant; confirm
/// in `intricate_macros`.
#[derive(Debug, FromForAllUnnamedVariants)]
pub enum LossComputationError {
/// NOTE(review): presumably returned when the loss function is used before `init`
/// has been called; confirm in the implementations.
NotInitialized,
/// NOTE(review): presumably returned when no OpenCL command queue is available;
/// confirm in the implementations.
NoCommandQueue,
/// Wraps a raw OpenCL error code.
OpenCL(ClError),
/// Wraps a [`ReduceOutputsPerSampleError`] from the categorical cross entropy module.
/// The variant name misspells "Sample". It is kept as-is for API compatibility.
SumOutputsPerSmaple(ReduceOutputsPerSampleError),
/// NOTE(review): presumably the output and expected-output buffers have
/// incompatible sizes; confirm in the implementations.
OutputsAndExpectedOutputsDoNotMatch,
/// NOTE(review): presumably the buffers do not contain `samples_amount` samples;
/// confirm in the implementations.
TrainingDataDoesNotHaveExpectedSamplesAmount,
/// Wraps a [`KernelNotFoundError`].
KernelNotFound(KernelNotFoundError),
/// Wraps a [`ProgramNotFoundError`].
ProgramNotFound(ProgramNotFoundError),
/// Wraps a [`BufferOperationError`].
BufferOperation(BufferOperationError),
}
/// Error returned by
/// [`LossFunction::compute_loss_derivative_with_respect_to_output_samples`].
///
/// It has the same set of variants as [`LossComputationError`]. NOTE(review):
/// `FromForAllUnnamedVariants` presumably generates a `From` impl per unnamed
/// variant, so `?` works on the wrapped error types; confirm in `intricate_macros`.
#[derive(Debug, FromForAllUnnamedVariants)]
pub enum LossToModelOutputsDerivativesComputationError {
/// NOTE(review): presumably returned when the loss function is used before `init`
/// has been called; confirm in the implementations.
NotInitialized,
/// NOTE(review): presumably returned when no OpenCL command queue is available;
/// confirm in the implementations.
NoCommandQueue,
/// Wraps a raw OpenCL error code.
OpenCL(ClError),
/// Wraps a [`ReduceOutputsPerSampleError`] from the categorical cross entropy module.
/// The variant name misspells "Sample". It is kept as-is for API compatibility.
SumOutputsPerSmaple(ReduceOutputsPerSampleError),
/// NOTE(review): presumably the output and expected-output buffers have
/// incompatible sizes; confirm in the implementations.
OutputsAndExpectedOutputsDoNotMatch,
/// NOTE(review): presumably the buffers do not contain `samples_amount` samples;
/// confirm in the implementations.
TrainingDataDoesNotHaveExpectedSamplesAmount,
/// Wraps a [`KernelNotFoundError`].
KernelNotFound(KernelNotFoundError),
/// Wraps a [`ProgramNotFoundError`].
ProgramNotFound(ProgramNotFoundError),
/// Wraps a [`BufferOperationError`].
BufferOperation(BufferOperationError),
}
/// A loss function that operates on OpenCL `cl_float` buffers.
///
/// Implementors must also implement `Debug`. The lifetime `'a` is the lifetime of the
/// [`OpenCLState`] reference passed to [`LossFunction::init`]. NOTE(review):
/// presumably implementors keep that reference for later computations; confirm.
pub trait LossFunction<'a>
where
Self: Debug,
{
/// Computes the loss of `output_samples` against `expected_outputs` and returns it
/// as a single `f32`.
///
/// NOTE(review): presumably both buffers hold `samples_amount` samples with the
/// same per-sample width; confirm against the implementations.
///
/// # Errors
///
/// Returns a [`LossComputationError`] if the computation fails.
fn compute_loss(
&self,
output_samples: &Buffer<cl_float>,
expected_outputs: &Buffer<cl_float>,
samples_amount: usize,
) -> Result<f32, LossComputationError>;
/// Sets up the loss function to use `opencl_state`.
///
/// NOTE(review): presumably this must be called before the compute methods (see
/// the `NotInitialized` error variants); confirm.
///
/// # Errors
///
/// Returns a [`ClError`] if OpenCL setup fails.
fn init(&mut self, opencl_state: &'a OpenCLState) -> Result<(), ClError>;
/// Computes the derivative of the loss with respect to `output_samples` and returns
/// it in a newly owned buffer.
///
/// NOTE(review): presumably the returned buffer has the same layout as
/// `output_samples`; confirm against the implementations.
///
/// # Errors
///
/// Returns a [`LossToModelOutputsDerivativesComputationError`] if the computation
/// fails.
fn compute_loss_derivative_with_respect_to_output_samples(
&self,
output_samples: &Buffer<cl_float>,
expected_outputs: &Buffer<cl_float>,
samples_amount: usize,
) -> Result<Buffer<cl_float>, LossToModelOutputsDerivativesComputationError>;
}
/// Enum that wraps each built-in loss function, so one type can hold any of them.
///
/// `LossFn` implements [`LossFunction`] by forwarding every call to the wrapped value.
/// NOTE(review): `FromForAllUnnamedVariants` presumably generates `From` impls for each
/// payload type (e.g. `LossFn::from(MeanSquared)`); confirm in `intricate_macros`.
#[derive(Debug, FromForAllUnnamedVariants)]
pub enum LossFn<'a> {
/// Wraps a [`MeanSquared`] loss.
MeanSquared(MeanSquared<'a>),
/// Wraps a [`MeanBias`] loss.
MeanBias(MeanBias<'a>),
/// Wraps a [`MeanAbsolute`] loss.
MeanAbsolute(MeanAbsolute<'a>),
/// Wraps a [`CategoricalCrossEntropy`] loss.
CategoricalCrossEntropy(CategoricalCrossEntropy<'a>),
}
impl<'a> LossFunction<'a> for LossFn<'a> {
fn init(&mut self, opencl_state: &'a OpenCLState) -> Result<(), ClError> {
match self {
LossFn::MeanSquared(loss) => loss.init(opencl_state),
LossFn::MeanBias(loss) => loss.init(opencl_state),
LossFn::MeanAbsolute(loss) => loss.init(opencl_state),
LossFn::CategoricalCrossEntropy(loss) => loss.init(opencl_state),
}
}
fn compute_loss(
&self,
output_samples: &Buffer<cl_float>,
expected_outputs: &Buffer<cl_float>,
samples_amount: usize,
) -> Result<f32, LossComputationError> {
match self {
LossFn::MeanSquared(loss) => {
loss.compute_loss(output_samples, expected_outputs, samples_amount)
}
LossFn::MeanBias(loss) => {
loss.compute_loss(output_samples, expected_outputs, samples_amount)
}
LossFn::MeanAbsolute(loss) => {
loss.compute_loss(output_samples, expected_outputs, samples_amount)
}
LossFn::CategoricalCrossEntropy(loss) => {
loss.compute_loss(output_samples, expected_outputs, samples_amount)
}
}
}
fn compute_loss_derivative_with_respect_to_output_samples(
&self,
output_samples: &Buffer<cl_float>,
expected_outputs: &Buffer<cl_float>,
samples_amount: usize,
) -> Result<Buffer<cl_float>, LossToModelOutputsDerivativesComputationError> {
match self {
LossFn::MeanSquared(loss) => loss
.compute_loss_derivative_with_respect_to_output_samples(
output_samples,
expected_outputs,
samples_amount,
),
LossFn::MeanBias(loss) => loss.compute_loss_derivative_with_respect_to_output_samples(
output_samples,
expected_outputs,
samples_amount,
),
LossFn::MeanAbsolute(loss) => loss
.compute_loss_derivative_with_respect_to_output_samples(
output_samples,
expected_outputs,
samples_amount,
),
LossFn::CategoricalCrossEntropy(loss) => loss
.compute_loss_derivative_with_respect_to_output_samples(
output_samples,
expected_outputs,
samples_amount,
),
}
}
}