use opencl3::error_codes::ClError;
use savefile_derive::Savefile;
use intricate_macros::{EnumLayer, FromForAllUnnamedVariants};
use crate::{
layers::{
activations::{ReLU, Sigmoid, SoftMax, TanH},
Dense, conv2d::Conv2D,
},
loss_functions::LossFn,
optimizers::Optimizer, utils::opencl::BufferConversionError,
};
/// An error raised when an OpenCL program with the given name could not be
/// found; the wrapped `String` is the name of the program that was looked up.
#[derive(Debug)]
pub struct ProgramNotFoundError(pub String);
/// An error that can happen while trying to sync data with an OpenCL device.
///
/// `FromForAllUnnamedVariants` derives a `From` impl for each unnamed-field
/// variant, so `ClError` and `BufferConversionError` convert into this enum
/// via `?`.
#[derive(Debug, FromForAllUnnamedVariants)]
pub enum SyncDataError {
/// An underlying OpenCL error code.
OpenCL(ClError),
/// The data being synced was never initialized.
NotInitialized,
/// A specific field's buffer was not allocated on the device.
NotAllocatedInDevice {
/// The name of the field whose device buffer is missing.
field_name: String,
},
/// Converting between a device buffer and host data failed.
BufferConversion(BufferConversionError),
/// No OpenCL command queue was available to enqueue the sync with.
NoCommandQueue,
}
impl From<String> for ProgramNotFoundError {
fn from(program: String) -> Self {
ProgramNotFoundError(program)
}
}
/// An error raised when an OpenCL kernel with the given name could not be
/// found; the wrapped `String` is the name of the kernel that was looked up.
#[derive(Debug)]
pub struct KernelNotFoundError(pub String);
impl From<String> for KernelNotFoundError {
fn from(kernel: String) -> Self {
KernelNotFoundError(kernel)
}
}
/// All of the layer types that can go inside a model.
///
/// NOTE(review): the `EnumLayer` derive presumably implements the crate's
/// Layer trait for this enum by delegating to each variant, and
/// `FromForAllUnnamedVariants` generates `From<Dense>`, `From<Conv2D>`, etc.
/// — confirm against `intricate_macros`. `Savefile` makes the model
/// serializable to disk.
#[derive(Debug, Savefile, EnumLayer, FromForAllUnnamedVariants)]
#[allow(missing_docs)]
pub enum ModelLayer<'a> {
Dense(Dense<'a>),
Conv2D(Conv2D<'a>),
TanH(TanH<'a>),
SoftMax(SoftMax<'a>),
ReLU(ReLU<'a>),
Sigmoid(Sigmoid<'a>),
}
/// Flags controlling what gets printed to stdout while a model is training.
#[derive(Debug)]
pub struct TrainingVerbosity {
/// Whether a "current epoch" message is shown at the start of each epoch.
pub(crate) show_current_epoch: bool,
/// Whether a progress indicator is shown while an epoch runs.
pub(crate) show_epoch_progress: bool,
/// Whether the time an epoch took is shown once it finishes.
pub(crate) show_epoch_elapsed: bool,
/// Whether the computed loss is printed (requires the loss to be computed).
pub(crate) print_loss: bool,
/// Whether the computed accuracy is printed (requires accuracy to be computed).
pub(crate) print_accuracy: bool,
/// Whether a warning is printed related to the configured halting condition.
pub(crate) halting_condition_warning: bool,
}
impl Default for TrainingVerbosity {
fn default() -> Self {
TrainingVerbosity {
show_current_epoch: true,
show_epoch_progress: true,
show_epoch_elapsed: true,
print_loss: true,
print_accuracy: false,
halting_condition_warning: false,
}
}
}
/// A condition that, once reached, stops the training process early.
#[derive(Debug)]
pub enum HaltingCondition {
/// Halt once the loss goes at or below this value.
/// NOTE(review): inclusive/exclusive comparison is decided where training
/// runs, not here — confirm at the call site.
MinLossReached(f32),
/// Halt once the accuracy goes at or above this value.
MinAccuracyReached(f32),
}
/// The configuration for a model's training process: which loss function and
/// optimizer to use, how much data per step, for how long, and what to print.
/// Built via [`TrainingOptions::new`] and its builder-style methods.
#[derive(Debug)]
pub struct TrainingOptions<'a> {
/// The loss function used to compute gradients (and optionally the loss).
pub(crate) loss_fn: &'a mut LossFn<'a>,
/// The number of samples that go through before each parameter update.
pub(crate) batch_size: usize,
/// The optimizer that applies the computed gradients to the model.
pub(crate) optimizer: &'a mut dyn Optimizer<'a>,
/// What should be printed while training runs.
pub(crate) verbosity: TrainingVerbosity,
/// An optional condition that stops training early once met.
pub(crate) halting_condition: Option<HaltingCondition>,
/// Whether the loss should be computed each training step.
pub(crate) compute_loss: bool,
/// Whether the accuracy should be computed each training step.
pub(crate) compute_accuracy: bool,
/// How many full passes over the training data should be done.
pub(crate) epochs: usize,
}
/// An error returned by `TrainingOptions`' builder methods when a value being
/// set conflicts with the options already configured.
#[derive(Debug)]
pub struct InvalidTrainingOptionError<T> {
/// The value that was rejected.
pub value_trying_to_be_set: T,
/// The name of the training option the value was being assigned to.
pub parameter_name: &'static str,
/// A human-readable explanation of why the value was rejected.
pub error_message: String,
}
impl<'a> TrainingOptions<'a> {
    /// Creates a new set of training options using the given loss function
    /// and optimizer.
    ///
    /// Everything else starts at its default: `batch_size` and `epochs` are
    /// `0`, the loss is computed but accuracy is not, the verbosity is
    /// [`TrainingVerbosity::default`] and no halting condition is set. Use
    /// the builder-style methods to configure these.
    pub fn new(
        loss_fn: &'a mut LossFn<'a>,
        optimizer: &'a mut dyn Optimizer<'a>,
    ) -> Self {
        TrainingOptions {
            loss_fn,
            batch_size: 0,
            optimizer,
            verbosity: TrainingVerbosity::default(),
            halting_condition: None,
            compute_loss: true,
            compute_accuracy: false,
            epochs: 0,
        }
    }

    /// Sets the amount of samples that go through the model before each
    /// update of its parameters.
    pub fn set_batch_size(mut self, new_batch_size: usize) -> Self {
        self.batch_size = new_batch_size;
        self
    }

    /// Sets the condition that, once met, halts training early.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidTrainingOptionError`] when the metric the
    /// condition depends on (loss or accuracy) is not set to be computed,
    /// since the condition could never be evaluated.
    pub fn set_halting_condition(
        mut self,
        halting_condition: HaltingCondition,
    ) -> Result<Self, InvalidTrainingOptionError<HaltingCondition>> {
        match halting_condition {
            HaltingCondition::MinLossReached(_) if !self.compute_loss => {
                return Err(InvalidTrainingOptionError {
                    value_trying_to_be_set: halting_condition,
                    parameter_name: "halting_condition",
                    error_message: "Unable to set the halting condition to MinLossReached since the loss is not set to be calculated!".to_string(),
                });
            }
            HaltingCondition::MinAccuracyReached(_) if !self.compute_accuracy => {
                return Err(InvalidTrainingOptionError {
                    value_trying_to_be_set: halting_condition,
                    parameter_name: "halting_condition",
                    error_message: "Unable to set the halting condition to MinAccuracyReached since the accuracy is not set to be calculated!".to_string(),
                });
            }
            _ => {}
        }
        self.halting_condition = Some(halting_condition);
        Ok(self)
    }

    /// Defines whether or not the loss should be computed on each training
    /// step.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidTrainingOptionError`] when trying to disable the
    /// loss computation while `print_loss` is enabled in the verbosity,
    /// since printing the loss requires it to have been computed.
    pub fn should_compute_loss(
        mut self,
        should: bool,
    ) -> Result<Self, InvalidTrainingOptionError<bool>> {
        if self.verbosity.print_loss && !should {
            return Err(InvalidTrainingOptionError {
                value_trying_to_be_set: should,
                parameter_name: "compute_loss",
                error_message: "Could not set 'compute_loss' = false since there will be a need for the loss because of print_loss being enabled in the TrainingVerbosity!".to_string(),
            });
        }
        self.compute_loss = should;
        Ok(self)
    }

    /// Defines whether or not the accuracy should be computed on each
    /// training step.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidTrainingOptionError`] when trying to disable the
    /// accuracy computation while `print_accuracy` is enabled in the
    /// verbosity, since printing the accuracy requires it to have been
    /// computed.
    pub fn should_compute_accuracy(
        mut self,
        should: bool,
    ) -> Result<Self, InvalidTrainingOptionError<bool>> {
        // Bug fix: this previously checked `print_loss` and reported
        // `compute_loss` (a copy-paste from `should_compute_loss`), which
        // allowed `compute_accuracy` to be disabled while `print_accuracy`
        // was still on.
        if self.verbosity.print_accuracy && !should {
            return Err(InvalidTrainingOptionError {
                value_trying_to_be_set: should,
                parameter_name: "compute_accuracy",
                error_message: "Could not set 'compute_accuracy' = false since there will be a need for the accuracy because of print_accuracy being enabled in the TrainingVerbosity!".to_string(),
            });
        }
        self.compute_accuracy = should;
        Ok(self)
    }

    /// Defines whether or not the loss should be printed on each training
    /// step.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidTrainingOptionError`] when trying to enable loss
    /// printing while the loss is not set to be computed.
    pub fn should_print_loss(mut self, should: bool) -> Result<Self, InvalidTrainingOptionError<bool>> {
        if should && !self.compute_loss {
            return Err(InvalidTrainingOptionError {
                value_trying_to_be_set: should,
                parameter_name: "print_loss",
                error_message: "Cannot print the loss without having first calculated it! Please use should_compute_loss first with 'true' as parameter!".to_string(),
            });
        }
        self.verbosity.print_loss = should;
        Ok(self)
    }

    /// Defines whether or not the accuracy should be printed on each
    /// training step.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidTrainingOptionError`] when trying to enable
    /// accuracy printing while the accuracy is not set to be computed.
    pub fn should_print_accuracy(mut self, should: bool) -> Result<Self, InvalidTrainingOptionError<bool>> {
        // Bug fix: this previously guarded on `compute_loss` instead of
        // `compute_accuracy`, and the parameter name carried a stray
        // trailing space ("print_accuracy ").
        if should && !self.compute_accuracy {
            return Err(InvalidTrainingOptionError {
                value_trying_to_be_set: should,
                parameter_name: "print_accuracy",
                error_message: "Cannot print the accuracy without having first calculated it! Please use should_compute_accuracy first with 'true' as parameter!".to_string(),
            });
        }
        self.verbosity.print_accuracy = should;
        Ok(self)
    }

    /// Defines whether or not a progress indicator should be shown while an
    /// epoch runs.
    pub fn should_show_epoch_progress(mut self, should: bool) -> Self {
        self.verbosity.show_epoch_progress = should;
        self
    }

    /// Defines whether or not a message with the current epoch should be
    /// shown at the start of each epoch.
    pub fn should_show_current_epoch_message(mut self, should: bool) -> Self {
        self.verbosity.show_current_epoch = should;
        self
    }

    /// Defines whether or not the time an epoch took should be shown once it
    /// finishes.
    ///
    /// Added for consistency: `show_epoch_elapsed` was the only verbosity
    /// field without a builder-style setter.
    pub fn should_show_epoch_elapsed(mut self, should: bool) -> Self {
        self.verbosity.show_epoch_elapsed = should;
        self
    }

    /// Defines whether or not a warning related to the halting condition
    /// should be shown.
    ///
    /// # Errors
    ///
    /// Returns an [`InvalidTrainingOptionError`] when trying to enable the
    /// warning while no halting condition has been set.
    pub fn should_show_halting_condition_warning(mut self, should: bool) -> Result<Self, InvalidTrainingOptionError<bool>> {
        if should && self.halting_condition.is_none() {
            return Err(InvalidTrainingOptionError {
                value_trying_to_be_set: should,
                parameter_name: "halting_condition_warning",
                error_message: "Cannot have a halting condition warning without a Halting Condition being defined!".to_string(),
            });
        }
        self.verbosity.halting_condition_warning = should;
        Ok(self)
    }

    /// Sets the amount of full passes over the training data to be done.
    pub fn set_epochs(mut self, epochs: usize) -> Self {
        self.epochs = epochs;
        self
    }
}
/// The data gathered over a model's training process.
#[derive(Debug)]
pub struct TrainingResults {
/// The loss recorded at each training step (empty if the loss was not
/// set to be computed).
pub loss_per_training_steps: Vec<f32>,
/// The accuracy recorded at each training step (empty if the accuracy
/// was not set to be computed).
pub accuracy_per_training_steps: Vec<f32>,
}