use super::GpuTrainer;
use crate::error::{Result, TuneError};
use crate::train::config::TrainingConfig;
use lattice_fann::{Activation, NetworkBuilder};
/// Fluent builder that assembles a network description and training
/// configuration, then turns them into a [`GpuTrainer`].
///
/// Create one with [`GpuTrainerBuilder::new`], optionally append hidden
/// layers and override the output activation or training configuration,
/// and finish with [`GpuTrainerBuilder::build`].
pub struct GpuTrainerBuilder {
    /// Size handed to the network builder's input layer.
    input_size: usize,
    /// Hidden layers as `(size, activation)` pairs, in the order they were added.
    hidden_layers: Vec<(usize, Activation)>,
    /// Size handed to the network builder's output layer.
    output_size: usize,
    /// Activation applied to the output layer.
    output_activation: Activation,
    /// Training configuration forwarded to the trainer on build.
    config: TrainingConfig,
}
impl GpuTrainerBuilder {
    /// Starts a builder for a network with the given input and output sizes.
    ///
    /// The builder begins with no hidden layers, a `Softmax` output
    /// activation, and `TrainingConfig::default()` as its configuration.
    pub fn new(input_size: usize, output_size: usize) -> Self {
        let hidden_layers = Vec::new();
        let config = TrainingConfig::default();
        Self {
            input_size,
            output_size,
            hidden_layers,
            output_activation: Activation::Softmax,
            config,
        }
    }

    /// Appends a hidden layer of `size` units using `activation`.
    ///
    /// Layers are kept in call order and handed to the network builder in
    /// that same order when [`build`](Self::build) runs.
    pub fn hidden(mut self, size: usize, activation: Activation) -> Self {
        let layer = (size, activation);
        self.hidden_layers.push(layer);
        self
    }

    /// Replaces the activation used for the output layer.
    pub fn output_activation(self, activation: Activation) -> Self {
        Self {
            output_activation: activation,
            ..self
        }
    }

    /// Replaces the training configuration passed to the trainer.
    pub fn config(self, config: TrainingConfig) -> Self {
        Self { config, ..self }
    }

    /// Consumes the builder, constructs the network, and creates the trainer.
    ///
    /// # Errors
    ///
    /// Returns `TuneError::Training` with a `"Network build failed: ..."`
    /// message when the underlying network builder reports an error.
    /// Otherwise the result of `GpuTrainer::new` is returned unchanged.
    pub fn build(self) -> Result<GpuTrainer> {
        let Self {
            input_size,
            hidden_layers,
            output_size,
            output_activation,
            config,
        } = self;

        // Thread the network builder through every hidden layer in insertion order.
        let topology = hidden_layers.into_iter().fold(
            NetworkBuilder::new().input(input_size),
            |net, (width, act)| net.hidden(width, act),
        );

        let network = topology
            .output(output_size, output_activation)
            .build()
            .map_err(|e| TuneError::Training(format!("Network build failed: {e}")))?;

        GpuTrainer::new(network, config)
    }
}