extern crate fann_sys;
extern crate libc;
use fann_sys::*;
use libc::{c_float, c_int, c_uint};
use std::cell::RefCell;
use std::ffi::CString;
use std::mem::{forget, transmute};
use std::path::Path;
use std::ptr::{copy_nonoverlapping, null_mut};
pub use activation_func::ActivationFunc;
pub use cascade_params::CascadeParams;
pub use error::{FannError, FannErrorType, FannResult};
pub use error_func::ErrorFunc;
pub use net_type::NetType;
pub use stop_func::StopFunc;
pub use train_algorithm::TrainAlgorithm;
pub use train_algorithm::{BatchParams, IncrementalParams, QuickpropParams, RpropParams};
pub use train_data::TrainData;
mod activation_func;
mod cascade_params;
mod error;
mod error_func;
mod net_type;
mod stop_func;
mod train_algorithm;
mod train_data;
/// The type of weights, inputs and outputs in a network; an alias for the
/// `fann_sys` crate's `fann_type` (its precision depends on how the C FANN
/// library was built).
pub type FannType = fann_type;
/// A single connection between two neurons, including its weight; an alias
/// for the `fann_sys` crate's `fann_connection`.
pub type Connection = fann_connection;
/// Converts a path into a `CString` suitable for passing to the C library.
///
/// Fails if the path is not valid Unicode or contains an interior nul byte;
/// in both cases the error is reported as a `CantOpenTdR` `FannError`.
fn to_filename<P: AsRef<Path>>(path: P) -> Result<CString, FannError> {
    let utf8 = match path.as_ref().to_str() {
        Some(s) => s,
        None => {
            return Err(FannError {
                error_type: FannErrorType::CantOpenTdR,
                error_str: "File name contains invalid unicode characters".to_owned(),
            })
        }
    };
    CString::new(utf8).map_err(|e| FannError {
        error_type: FannErrorType::CantOpenTdR,
        error_str: format!(
            "File name contains a nul byte at position {}",
            e.nul_position()
        ),
    })
}
/// The training data a `FannTrainer` is working with: either loaded and owned
/// by the trainer (the load result is stored, so errors surface when training
/// starts) or borrowed from the caller.
enum CurrentTrainData<'a> {
    Own(FannResult<TrainData>),
    Ref(&'a TrainData),
}
// Thread-local pointer to the trainer currently running `train` with a
// callback. The C callback receives no user-data argument, so the trampoline
// (`FannTrainer::raw_callback`) recovers the trainer through this pointer.
// Null whenever no callback-driven training is in progress on this thread.
thread_local!(static TRAINER: RefCell<*mut FannTrainer<'static>> = RefCell::new(null_mut()));
/// The value a training callback returns to decide whether training goes on.
#[derive(Clone, Copy, Debug)]
pub enum CallbackResult {
    /// Abort training.
    Stop,
    /// Keep training.
    Continue,
}
impl CallbackResult {
    /// Returns `Stop` when `condition` holds, and `Continue` otherwise.
    pub fn stop_if(condition: bool) -> CallbackResult {
        match condition {
            true => CallbackResult::Stop,
            false => CallbackResult::Continue,
        }
    }
}
/// A builder-style handle for training a `Fann` network on a data set,
/// created via `Fann::on_data` or `Fann::on_file`.
pub struct FannTrainer<'a> {
    // The network being trained (borrowed mutably for the trainer's lifetime).
    fann: &'a mut Fann,
    // The data to train on — owned (loaded from file) or borrowed.
    cur_data: CurrentTrainData<'a>,
    // Optional user callback invoked every `interval` epochs.
    callback: Option<&'a Fn(&Fann, &TrainData, c_uint) -> CallbackResult>,
    // Epochs between reports/callback invocations; 0 with no callback means
    // no reports (passed straight through to the C training functions).
    interval: c_uint,
    // If true, use cascade training instead of ordinary epoch training.
    cascade: bool,
}
impl<'a> FannTrainer<'a> {
    /// Creates a trainer that borrows the given training data.
    fn with_data<'b>(fann: &'b mut Fann, data: &'b TrainData) -> FannTrainer<'b> {
        FannTrainer {
            fann,
            cur_data: CurrentTrainData::Ref(data),
            callback: None,
            interval: 0,
            cascade: false,
        }
    }
    /// Creates a trainer that loads training data from the given file. A load
    /// failure is stored and only reported when `train` is called.
    fn with_file<P: AsRef<Path>>(fann: &mut Fann, path: P) -> FannTrainer {
        FannTrainer {
            fann,
            cur_data: CurrentTrainData::Own(TrainData::from_file(path)),
            callback: None,
            interval: 0,
            cascade: false,
        }
    }
    /// Configures the number of epochs between status reports (forwarded to
    /// the C training functions as the report interval).
    pub fn with_reports(self, interval: c_uint) -> FannTrainer<'a> {
        FannTrainer { interval, ..self }
    }
    /// Registers a callback to be invoked every `interval` epochs; the
    /// callback decides whether training continues or stops.
    pub fn with_callback(
        self,
        interval: c_uint,
        callback: &'a Fn(&Fann, &TrainData, c_uint) -> CallbackResult,
    ) -> FannTrainer<'a> {
        FannTrainer {
            callback: Some(callback),
            interval,
            ..self
        }
    }
    /// Switches the trainer to cascade training (`fann_cascadetrain_on_data`),
    /// where `max_steps` counts added neurons rather than epochs.
    pub fn cascade(self) -> FannTrainer<'a> {
        FannTrainer {
            cascade: true,
            ..self
        }
    }
    /// C-compatible trampoline registered with FANN while a user callback is
    /// set. It recovers the active trainer from the thread-local `TRAINER`
    /// pointer (FANN passes no user-data argument), sanity-checks that the
    /// network and data are the expected ones, runs the user callback, and
    /// translates the result to the C convention: -1 stops training, 0
    /// continues.
    extern "C" fn raw_callback(
        ann: *mut fann,
        td: *mut fann_train_data,
        _: c_uint,
        _: c_uint,
        _: c_float,
        steps: c_uint,
    ) -> c_int {
        match TRAINER.with(|cell| unsafe {
            let trainer = *cell.borrow();
            // `get_data` already succeeded in `train`, so unwrap cannot fail.
            let data = (*trainer).get_data().unwrap();
            assert_eq!(ann, (*trainer).fann.raw);
            assert_eq!(td, data.get_raw());
            // `raw_callback` is only registered when a callback is present.
            let callback = (*trainer).callback.unwrap();
            callback((*trainer).fann, data, steps)
        }) {
            CallbackResult::Stop => -1,
            CallbackResult::Continue => 0,
        }
    }
    /// Returns the training data, propagating a deferred file-load error.
    fn get_data(&'a self) -> FannResult<&'a TrainData> {
        match self.cur_data {
            CurrentTrainData::Ref(data) => Ok(data),
            CurrentTrainData::Own(ref result) => result.as_ref().map_err(FannError::clone),
        }
    }
    /// Runs the training configured so far, for at most `max_steps` steps or
    /// until the error drops below `desired_error`.
    ///
    /// While training with a callback, a pointer to `self` is stored in the
    /// thread-local `TRAINER` cell so `raw_callback` can reach it; the
    /// `transmute` erases the `'a` lifetime to fit the cell's `'static` type.
    /// This is sound only because the pointer is cleared again before this
    /// function returns. NOTE(review): the pointer is not cleared if the C
    /// call unwinds — confirm that cannot happen here.
    #[cfg_attr(feature = "cargo-clippy", allow(useless_transmute))]
    pub fn train(&mut self, max_steps: c_uint, desired_error: c_float) -> FannResult<()> {
        unsafe {
            let raw_data = try!(self.get_data()).get_raw();
            if self.callback.is_some() {
                TRAINER.with(|cell| *cell.borrow_mut() = transmute(&mut *self));
                fann_set_callback(self.fann.raw, Some(FannTrainer::raw_callback));
            }
            let raw_train_fn = if self.cascade {
                fann_cascadetrain_on_data
            } else {
                fann_train_on_data
            };
            raw_train_fn(
                self.fann.raw,
                raw_data,
                max_steps,
                self.interval,
                desired_error,
            );
            if self.callback.is_some() {
                // Unregister the callback and clear the thread-local pointer
                // before `self` can go out of scope.
                fann_set_callback(self.fann.raw, None);
                TRAINER.with(|cell| *cell.borrow_mut() = null_mut());
            }
            FannError::check_no_error(self.fann.raw as *mut fann_error)
        }
    }
}
/// A safe wrapper around a neural network allocated by the C FANN library.
/// The raw pointer is owned by this struct and freed in `Drop`.
pub struct Fann {
    raw: *mut fann,
}
impl Fann {
    /// Wraps a raw pointer freshly returned by the C library, first checking
    /// the library's error state so a failed creation becomes an `Err`.
    unsafe fn from_raw(raw: *mut fann) -> FannResult<Fann> {
        try!(FannError::check_no_error(raw as *mut fann_error));
        Ok(Fann { raw })
    }
    /// Creates a fully connected network with the given layer sizes
    /// (equivalent to `new_sparse` with connection rate 1.0).
    pub fn new(layers: &[c_uint]) -> FannResult<Fann> {
        Fann::new_sparse(1.0, layers)
    }
    /// Creates a network where only `connection_rate` (0.0–1.0) of the
    /// possible connections between layers are made.
    pub fn new_sparse(connection_rate: c_float, layers: &[c_uint]) -> FannResult<Fann> {
        unsafe {
            Fann::from_raw(fann_create_sparse_array(
                connection_rate,
                layers.len() as c_uint,
                layers.as_ptr(),
            ))
        }
    }
    /// Creates a shortcut network, where connections may skip layers.
    pub fn new_shortcut(layers: &[c_uint]) -> FannResult<Fann> {
        unsafe {
            Fann::from_raw(fann_create_shortcut_array(
                layers.len() as c_uint,
                layers.as_ptr(),
            ))
        }
    }
    /// Loads a network previously saved with `save`.
    pub fn from_file<P: AsRef<Path>>(path: P) -> FannResult<Fann> {
        let filename = try!(to_filename(path));
        unsafe { Fann::from_raw(fann_create_from_file(filename.as_ptr())) }
    }
    /// Saves the network to the given file.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> FannResult<()> {
        let filename = try!(to_filename(path));
        unsafe {
            let result = fann_save(self.raw, filename.as_ptr());
            FannError::check_zero(result, self.raw as *mut fann_error, "Error saving network")
        }
    }
    /// Assigns random weights in the range `min_weight..max_weight`.
    pub fn randomize_weights(&mut self, min_weight: FannType, max_weight: FannType) {
        unsafe { fann_randomize_weights(self.raw, min_weight, max_weight) }
    }
    /// Initializes the weights based on the given training data
    /// (Widrow-Nguyen style per FANN — TODO confirm against FANN docs).
    pub fn init_weights(&mut self, train_data: &TrainData) {
        unsafe { fann_init_weights(self.raw, train_data.get_raw()) }
    }
    /// Prints the network's connections to standard output (C-side printing).
    pub fn print_connections(&self) {
        unsafe { fann_print_connections(self.raw) }
    }
    /// Prints all network parameters to standard output (C-side printing).
    pub fn print_parameters(&self) {
        unsafe { fann_print_parameters(self.raw) }
    }
    /// Verifies that `input` has exactly one value per input neuron.
    fn check_input_size(&self, input: &[FannType]) -> FannResult<()> {
        let num_input = self.get_num_input() as usize;
        if input.len() == num_input {
            Ok(())
        } else {
            Err(FannError {
                error_type: FannErrorType::IndexOutOfBound,
                error_str: format!(
                    "Input has length {}, but there are {} input neurons",
                    input.len(),
                    num_input
                ),
            })
        }
    }
    /// Verifies that `output` has exactly one value per output neuron.
    fn check_output_size(&self, output: &[FannType]) -> FannResult<()> {
        let num_output = self.get_num_output() as usize;
        if output.len() == num_output {
            Ok(())
        } else {
            Err(FannError {
                error_type: FannErrorType::IndexOutOfBound,
                error_str: format!(
                    "Output has length {}, but there are {} output neurons",
                    output.len(),
                    num_output
                ),
            })
        }
    }
    /// Trains on a single input/output pair for one iteration.
    pub fn train(&mut self, input: &[FannType], desired_output: &[FannType]) -> FannResult<()> {
        unsafe {
            try!(self.check_input_size(input));
            try!(self.check_output_size(desired_output));
            fann_train(self.raw, input.as_ptr(), desired_output.as_ptr());
            try!(FannError::check_no_error(self.raw as *mut fann_error));
        }
        Ok(())
    }
    /// Returns a `FannTrainer` that will train this network on `data`.
    pub fn on_data<'a>(&'a mut self, data: &'a TrainData) -> FannTrainer<'a> {
        FannTrainer::with_data(self, data)
    }
    /// Returns a `FannTrainer` that will train this network on data loaded
    /// from the given file.
    pub fn on_file<P: AsRef<Path>>(&mut self, path: P) -> FannTrainer {
        FannTrainer::with_file(self, path)
    }
    /// Trains for one epoch on the given data and returns the resulting MSE.
    pub fn train_epoch(&mut self, data: &TrainData) -> FannResult<c_float> {
        unsafe {
            let mse = fann_train_epoch(self.raw, data.get_raw());
            try!(FannError::check_no_error(self.raw as *mut fann_error));
            Ok(mse)
        }
    }
    /// Runs the network on `input`, updates the MSE against
    /// `desired_output`, and returns the network's output.
    pub fn test(
        &mut self,
        input: &[FannType],
        desired_output: &[FannType],
    ) -> FannResult<Vec<FannType>> {
        try!(self.check_input_size(input));
        try!(self.check_output_size(desired_output));
        let num_output = self.get_num_output() as usize;
        let mut result = Vec::with_capacity(num_output);
        unsafe {
            let output = fann_test(self.raw, input.as_ptr(), desired_output.as_ptr());
            try!(FannError::check_no_error(self.raw as *mut fann_error));
            // `output` points at a C-owned buffer of `num_output` values;
            // copy it into the Vec, whose capacity was reserved above.
            copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
            result.set_len(num_output);
        }
        Ok(result)
    }
    /// Tests the network on a whole data set and returns the resulting MSE.
    pub fn test_data(&mut self, data: &TrainData) -> FannResult<c_float> {
        unsafe {
            let mse = fann_test_data(self.raw, data.get_raw());
            try!(FannError::check_no_error(self.raw as *mut fann_error));
            Ok(mse)
        }
    }
    /// Returns the mean squared error accumulated so far.
    pub fn get_mse(&self) -> c_float {
        unsafe { fann_get_MSE(self.raw) }
    }
    /// Returns the number of fail bits, i.e. outputs that differed from the
    /// desired value by more than the bit fail limit.
    pub fn get_bit_fail(&self) -> c_uint {
        unsafe { fann_get_bit_fail(self.raw) }
    }
    /// Resets the accumulated error counters. NOTE(review): this relies on
    /// `fann_reset_MSE` also clearing the bit fail count — confirm with FANN.
    pub fn reset_mse_and_bit_fail(&mut self) {
        unsafe {
            fann_reset_MSE(self.raw);
        }
    }
    /// Runs the network on the given input and returns one value per output
    /// neuron. Errors if `input` does not match the number of input neurons.
    pub fn run(&self, input: &[FannType]) -> FannResult<Vec<FannType>> {
        try!(self.check_input_size(input));
        let num_output = self.get_num_output() as usize;
        let mut result = Vec::with_capacity(num_output);
        unsafe {
            let output = fann_run(self.raw, input.as_ptr());
            try!(FannError::check_no_error(self.raw as *mut fann_error));
            // Copy out of the C-owned output buffer; capacity reserved above.
            copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
            result.set_len(num_output);
        }
        Ok(result)
    }
    /// Returns the number of input neurons.
    pub fn get_num_input(&self) -> c_uint {
        unsafe { fann_get_num_input(self.raw) }
    }
    /// Returns the number of output neurons.
    pub fn get_num_output(&self) -> c_uint {
        unsafe { fann_get_num_output(self.raw) }
    }
    /// Returns the total number of neurons.
    pub fn get_total_neurons(&self) -> c_uint {
        unsafe { fann_get_total_neurons(self.raw) }
    }
    /// Returns the total number of connections.
    pub fn get_total_connections(&self) -> c_uint {
        unsafe { fann_get_total_connections(self.raw) }
    }
    /// Returns the network type (layered, shortcut, ...).
    pub fn get_network_type(&self) -> NetType {
        let nt_enum = unsafe { fann_get_network_type(self.raw) };
        NetType::from_nettype_enum(nt_enum)
    }
    /// Returns the connection rate used when the network was created.
    pub fn get_connection_rate(&self) -> c_float {
        unsafe { fann_get_connection_rate(self.raw) }
    }
    /// Returns the number of layers, including input and output layers.
    pub fn get_num_layers(&self) -> c_uint {
        unsafe { fann_get_num_layers(self.raw) }
    }
    /// Returns the number of neurons in each layer.
    pub fn get_layer_sizes(&self) -> Vec<c_uint> {
        let num_layers = self.get_num_layers() as usize;
        let mut result = Vec::with_capacity(num_layers);
        unsafe {
            // The C function fills one entry per layer into the buffer.
            fann_get_layer_array(self.raw, result.as_mut_ptr());
            result.set_len(num_layers);
        }
        result
    }
    /// Returns the number of bias neurons in each layer.
    pub fn get_bias_counts(&self) -> Vec<c_uint> {
        let num_layers = self.get_num_layers() as usize;
        let mut result = Vec::with_capacity(num_layers);
        unsafe {
            fann_get_bias_array(self.raw, result.as_mut_ptr());
            result.set_len(num_layers);
        }
        result
    }
    /// Returns every connection in the network with its current weight.
    pub fn get_connections(&self) -> Vec<Connection> {
        let total = self.get_total_connections() as usize;
        let mut result = Vec::with_capacity(total);
        unsafe {
            fann_get_connection_array(self.raw, result.as_mut_ptr());
            result.set_len(total);
        }
        result
    }
    /// Sets the weight of each given connection.
    pub fn set_connections<'a, I: IntoIterator<Item = &'a Connection>>(&mut self, connections: I) {
        for c in connections {
            self.set_weight(c.from_neuron, c.to_neuron, c.weight);
        }
    }
    /// Sets the weight of the connection between the two given neurons.
    pub fn set_weight(&mut self, from_neuron: c_uint, to_neuron: c_uint, weight: FannType) {
        unsafe { fann_set_weight(self.raw, from_neuron, to_neuron, weight) }
    }
    /// Returns the activation function of the given neuron in the given
    /// layer; errors for invalid indices (e.g. the input layer, see tests).
    pub fn get_activation_func(&self, layer: c_int, neuron: c_int) -> FannResult<ActivationFunc> {
        let af_enum = unsafe { fann_get_activation_function(self.raw, layer, neuron) };
        unsafe { try!(FannError::check_no_error(self.raw as *mut fann_error)) };
        ActivationFunc::from_fann_activationfunc_enum(af_enum)
    }
    /// Sets the activation function of the given neuron in the given layer.
    pub fn set_activation_func(&mut self, af: ActivationFunc, layer: c_int, neuron: c_int) {
        let af_enum = af.to_fann_activationfunc_enum();
        unsafe { fann_set_activation_function(self.raw, af_enum, layer, neuron) }
    }
    /// Sets the activation function for all hidden layers.
    pub fn set_activation_func_hidden(&mut self, activation_func: ActivationFunc) {
        unsafe {
            let af_enum = activation_func.to_fann_activationfunc_enum();
            fann_set_activation_function_hidden(self.raw, af_enum);
        }
    }
    /// Sets the activation function for the output layer.
    pub fn set_activation_func_output(&mut self, activation_func: ActivationFunc) {
        unsafe {
            let af_enum = activation_func.to_fann_activationfunc_enum();
            fann_set_activation_function_output(self.raw, af_enum)
        }
    }
    /// Returns the activation steepness of the given neuron, or `None` when
    /// the C library returns -1.0 (presumably its invalid-index sentinel —
    /// TODO confirm against FANN docs).
    #[cfg_attr(feature = "cargo-clippy", allow(float_cmp))]
    pub fn get_activation_steepness(&self, layer: c_int, neuron: c_int) -> Option<FannType> {
        let steepness = unsafe { fann_get_activation_steepness(self.raw, layer, neuron) };
        if steepness == -1.0 {
            return None;
        }
        Some(steepness)
    }
    /// Sets the activation steepness of the given neuron in the given layer.
    pub fn set_activation_steepness(&self, steepness: FannType, layer: c_int, neuron: c_int) {
        unsafe { fann_set_activation_steepness(self.raw, steepness, layer, neuron) }
    }
    /// Sets the activation steepness for every neuron in the given layer.
    pub fn set_activation_steepness_layer(&self, steepness: FannType, layer: c_int) {
        unsafe { fann_set_activation_steepness_layer(self.raw, steepness, layer) }
    }
    /// Sets the activation steepness for all hidden layers.
    pub fn set_activation_steepness_hidden(&self, steepness: FannType) {
        unsafe { fann_set_activation_steepness_hidden(self.raw, steepness) }
    }
    /// Sets the activation steepness for the output layer.
    pub fn set_activation_steepness_output(&self, steepness: FannType) {
        unsafe { fann_set_activation_steepness_output(self.raw, steepness) }
    }
    /// Returns the error function used during training.
    pub fn get_error_func(&self) -> ErrorFunc {
        let ef_enum = unsafe { fann_get_train_error_function(self.raw) };
        ErrorFunc::from_errorfunc_enum(ef_enum)
    }
    /// Sets the error function used during training.
    pub fn set_error_func(&mut self, ef: ErrorFunc) {
        let ef_enum = ef.to_errorfunc_enum();
        unsafe { fann_set_train_error_function(self.raw, ef_enum) }
    }
    /// Returns the stop criterion used during training.
    pub fn get_stop_func(&self) -> StopFunc {
        let sf_enum = unsafe { fann_get_train_stop_function(self.raw) };
        StopFunc::from_stopfunc_enum(sf_enum)
    }
    /// Sets the stop criterion used during training.
    pub fn set_stop_func(&mut self, sf: StopFunc) {
        let sf_enum = sf.to_stopfunc_enum();
        unsafe { fann_set_train_stop_function(self.raw, sf_enum) }
    }
    /// Returns the bit fail limit: the maximum tolerated difference between
    /// actual and desired output before an output counts as failed.
    pub fn get_bit_fail_limit(&self) -> FannType {
        unsafe { fann_get_bit_fail_limit(self.raw) }
    }
    /// Sets the bit fail limit.
    pub fn set_bit_fail_limit(&mut self, bit_fail_limit: FannType) {
        unsafe { fann_set_bit_fail_limit(self.raw, bit_fail_limit) }
    }
    /// Reads all cascade-training parameters from the C library into a
    /// `CascadeParams` struct.
    pub fn get_cascade_params(&self) -> CascadeParams {
        unsafe {
            let num_af = fann_get_cascade_activation_functions_count(self.raw) as usize;
            let af_enum_ptr = fann_get_cascade_activation_functions(self.raw);
            // Temporarily view the C-owned array as a Vec so it can be
            // iterated; `forget` below prevents Rust from freeing memory the
            // C library still owns.
            let af_enums = Vec::from_raw_parts(af_enum_ptr, num_af, num_af);
            let activation_functions = af_enums
                .iter()
                .map(|&af_enum| ActivationFunc::from_fann_activationfunc_enum(af_enum).unwrap())
                .collect();
            forget(af_enums);
            let num_st = fann_get_cascade_activation_steepnesses_count(self.raw) as usize;
            let mut activation_steepnesses = Vec::with_capacity(num_st);
            let st_ptr = fann_get_cascade_activation_steepnesses(self.raw);
            // Copy the steepness values out of the C-owned buffer.
            copy_nonoverlapping(st_ptr, activation_steepnesses.as_mut_ptr(), num_st);
            activation_steepnesses.set_len(num_st);
            CascadeParams {
                output_change_fraction: fann_get_cascade_output_change_fraction(self.raw),
                output_stagnation_epochs: fann_get_cascade_output_stagnation_epochs(self.raw),
                candidate_change_fraction: fann_get_cascade_candidate_change_fraction(self.raw),
                candidate_stagnation_epochs: fann_get_cascade_candidate_stagnation_epochs(self.raw),
                candidate_limit: fann_get_cascade_candidate_limit(self.raw),
                weight_multiplier: fann_get_cascade_weight_multiplier(self.raw),
                max_out_epochs: fann_get_cascade_max_out_epochs(self.raw),
                max_cand_epochs: fann_get_cascade_max_cand_epochs(self.raw),
                activation_functions,
                activation_steepnesses,
                num_candidate_groups: fann_get_cascade_num_candidate_groups(self.raw),
            }
        }
    }
    /// Writes every field of `params` into the C library's cascade-training
    /// configuration.
    pub fn set_cascade_params(&mut self, params: &CascadeParams) {
        let af_enums: Vec<_> = params
            .activation_functions
            .iter()
            .map(|af| af.to_fann_activationfunc_enum())
            .collect();
        unsafe {
            fann_set_cascade_output_change_fraction(self.raw, params.output_change_fraction);
            fann_set_cascade_output_stagnation_epochs(self.raw, params.output_stagnation_epochs);
            fann_set_cascade_candidate_change_fraction(self.raw, params.candidate_change_fraction);
            fann_set_cascade_candidate_stagnation_epochs(
                self.raw,
                params.candidate_stagnation_epochs,
            );
            fann_set_cascade_candidate_limit(self.raw, params.candidate_limit);
            fann_set_cascade_weight_multiplier(self.raw, params.weight_multiplier);
            fann_set_cascade_max_out_epochs(self.raw, params.max_out_epochs);
            fann_set_cascade_max_cand_epochs(self.raw, params.max_cand_epochs);
            fann_set_cascade_activation_functions(
                self.raw,
                af_enums.as_ptr(),
                af_enums.len() as c_uint,
            );
            fann_set_cascade_activation_steepnesses(
                self.raw,
                params.activation_steepnesses.as_ptr(),
                params.activation_steepnesses.len() as c_uint,
            );
            fann_set_cascade_num_candidate_groups(self.raw, params.num_candidate_groups);
        }
    }
    /// Returns the current training algorithm together with its parameters,
    /// read back from the C library.
    pub fn get_train_algorithm(&self) -> TrainAlgorithm {
        let ft_enum = unsafe { fann_get_training_algorithm(self.raw) };
        // Exhaustive over the C enum's variants — no catch-all on purpose.
        match ft_enum {
            FANN_TRAIN_INCREMENTAL => unsafe {
                TrainAlgorithm::Incremental(IncrementalParams {
                    learning_momentum: fann_get_learning_momentum(self.raw),
                    learning_rate: fann_get_learning_rate(self.raw),
                })
            },
            FANN_TRAIN_BATCH => unsafe {
                TrainAlgorithm::Batch(BatchParams {
                    learning_rate: fann_get_learning_rate(self.raw),
                })
            },
            FANN_TRAIN_RPROP => unsafe {
                TrainAlgorithm::Rprop(RpropParams {
                    decrease_factor: fann_get_rprop_decrease_factor(self.raw),
                    increase_factor: fann_get_rprop_increase_factor(self.raw),
                    delta_min: fann_get_rprop_delta_min(self.raw),
                    delta_max: fann_get_rprop_delta_max(self.raw),
                    delta_zero: fann_get_rprop_delta_zero(self.raw),
                })
            },
            FANN_TRAIN_QUICKPROP => unsafe {
                TrainAlgorithm::Quickprop(QuickpropParams {
                    decay: fann_get_quickprop_decay(self.raw),
                    mu: fann_get_quickprop_mu(self.raw),
                    learning_rate: fann_get_learning_rate(self.raw),
                })
            },
        }
    }
    /// Sets the training algorithm and all of its associated parameters.
    pub fn set_train_algorithm(&mut self, ta: TrainAlgorithm) {
        match ta {
            TrainAlgorithm::Incremental(params) => unsafe {
                fann_set_training_algorithm(self.raw, FANN_TRAIN_INCREMENTAL);
                fann_set_learning_momentum(self.raw, params.learning_momentum);
                fann_set_learning_rate(self.raw, params.learning_rate);
            },
            TrainAlgorithm::Batch(params) => unsafe {
                fann_set_training_algorithm(self.raw, FANN_TRAIN_BATCH);
                fann_set_learning_rate(self.raw, params.learning_rate);
            },
            TrainAlgorithm::Rprop(params) => unsafe {
                fann_set_training_algorithm(self.raw, FANN_TRAIN_RPROP);
                fann_set_rprop_decrease_factor(self.raw, params.decrease_factor);
                fann_set_rprop_increase_factor(self.raw, params.increase_factor);
                fann_set_rprop_delta_min(self.raw, params.delta_min);
                fann_set_rprop_delta_max(self.raw, params.delta_max);
                fann_set_rprop_delta_zero(self.raw, params.delta_zero);
            },
            TrainAlgorithm::Quickprop(params) => unsafe {
                fann_set_training_algorithm(self.raw, FANN_TRAIN_QUICKPROP);
                fann_set_quickprop_decay(self.raw, params.decay);
                fann_set_quickprop_mu(self.raw, params.mu);
                fann_set_learning_rate(self.raw, params.learning_rate);
            },
        }
    }
    /// Calculates input scaling parameters from the given data so future
    /// inputs can be scaled into the given range.
    pub fn set_input_scaling_params(
        &mut self,
        data: &TrainData,
        new_input_min: c_float,
        new_input_max: c_float,
    ) -> FannResult<()> {
        unsafe {
            let result = fann_set_input_scaling_params(
                self.raw,
                data.get_raw(),
                new_input_min,
                new_input_max,
            );
            FannError::check_zero(
                result,
                self.raw as *mut fann_error,
                "Error calculating scaling parameters",
            )
        }
    }
    /// Calculates output scaling parameters from the given data so future
    /// outputs can be scaled into the given range.
    pub fn set_output_scaling_params(
        &mut self,
        data: &TrainData,
        new_output_min: c_float,
        new_output_max: c_float,
    ) -> FannResult<()> {
        unsafe {
            let result = fann_set_output_scaling_params(
                self.raw,
                data.get_raw(),
                new_output_min,
                new_output_max,
            );
            FannError::check_zero(
                result,
                self.raw as *mut fann_error,
                "Error calculating scaling parameters",
            )
        }
    }
    /// Calculates both input and output scaling parameters in one call.
    pub fn set_scaling_params(
        &mut self,
        data: &TrainData,
        new_input_min: c_float,
        new_input_max: c_float,
        new_output_min: c_float,
        new_output_max: c_float,
    ) -> FannResult<()> {
        unsafe {
            let result = fann_set_scaling_params(
                self.raw,
                data.get_raw(),
                new_input_min,
                new_input_max,
                new_output_min,
                new_output_max,
            );
            FannError::check_zero(
                result,
                self.raw as *mut fann_error,
                "Error calculating scaling parameters",
            )
        }
    }
    /// Clears any previously calculated scaling parameters.
    pub fn clear_scaling_params(&mut self) -> FannResult<()> {
        unsafe {
            FannError::check_zero(
                fann_clear_scaling_params(self.raw),
                self.raw as *mut fann_error,
                "Error clearing scaling parameters",
            )
        }
    }
    /// Scales `input` in place using the stored input scaling parameters.
    pub fn scale_input(&self, input: &mut [FannType]) -> FannResult<()> {
        unsafe {
            fann_scale_input(self.raw, input.as_mut_ptr());
            FannError::check_no_error(self.raw as *mut fann_error)
        }
    }
    /// Scales `output` in place using the stored output scaling parameters.
    pub fn scale_output(&self, output: &mut [FannType]) -> FannResult<()> {
        unsafe {
            fann_scale_output(self.raw, output.as_mut_ptr());
            FannError::check_no_error(self.raw as *mut fann_error)
        }
    }
    /// Reverses `scale_input` on `input`, in place.
    pub fn descale_input(&self, input: &mut [FannType]) -> FannResult<()> {
        unsafe {
            fann_descale_input(self.raw, input.as_mut_ptr());
            FannError::check_no_error(self.raw as *mut fann_error)
        }
    }
    /// Reverses `scale_output` on `output`, in place.
    pub fn descale_output(&self, output: &mut [FannType]) -> FannResult<()> {
        unsafe {
            fann_descale_output(self.raw, output.as_mut_ptr());
            FannError::check_no_error(self.raw as *mut fann_error)
        }
    }
}
impl Drop for Fann {
    /// Frees the underlying C network when the wrapper goes out of scope.
    fn drop(&mut self) {
        unsafe {
            fann_destroy(self.raw);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use fann_sys;
    use libc::c_uint;
    use std::cell::RefCell;
    use std::ptr::null_mut;

    // Tolerance when comparing network outputs against the XOR targets.
    const EPSILON: FannType = 0.2;

    // Trains a 2-3-1 network on the XOR data file and checks all four cases.
    #[test]
    fn test_tutorial() {
        let max_epochs = 500_000;
        let desired_error = 0.0001;
        let mut fann = Fann::new(&[2, 3, 1]).unwrap();
        fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
        fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
        fann.on_file("test_files/xor.data")
            .train(max_epochs, desired_error)
            .unwrap();
        assert!(EPSILON > (1.0 - fann.run(&[-1.0, 1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (1.0 - fann.run(&[1.0, -1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (-1.0 - fann.run(&[1.0, 1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (-1.0 - fann.run(&[-1.0, -1.0]).unwrap()[0]).abs());
    }

    // Getting an activation function with invalid indices fails; setting and
    // re-reading a valid one round-trips.
    #[test]
    fn test_activation_func() {
        let mut fann = Fann::new(&[4, 3, 3, 1]).unwrap();
        unsafe {
            // Silence the C library's error log for the expected failures.
            fann_sys::fann_set_error_log(fann.raw as *mut fann_sys::fann_error, null_mut());
        }
        assert!(fann.get_activation_func(0, 1).is_err());
        assert!(fann.get_activation_func(4, 1).is_err());
        assert_eq!(
            Ok(ActivationFunc::SigmoidStepwise),
            fann.get_activation_func(2, 2)
        );
        fann.set_activation_func(ActivationFunc::Sin, 2, 2);
        assert_eq!(Ok(ActivationFunc::Sin), fann.get_activation_func(2, 2));
    }

    // Setting a training algorithm and reading it back round-trips.
    #[test]
    fn test_train_algorithm() {
        let mut fann = Fann::new(&[4, 3, 3, 1]).unwrap();
        assert_eq!(TrainAlgorithm::default(), fann.get_train_algorithm());
        let quickprop = TrainAlgorithm::Quickprop(QuickpropParams {
            decay: -0.0002,
            ..Default::default()
        });
        fann.set_train_algorithm(quickprop);
        assert_eq!(quickprop, fann.get_train_algorithm());
    }

    // Layer sizes and per-layer bias counts match the construction arguments.
    #[test]
    fn test_layer_sizes() {
        let fann = Fann::new(&[4, 3, 3, 1]).unwrap();
        assert_eq!(vec![4, 3, 3, 1], fann.get_layer_sizes());
        assert_eq!(vec![1, 1, 1, 0], fann.get_bias_counts());
    }

    // A weight set via `set_connections` can be read back with
    // `get_connections`.
    #[test]
    fn test_get_set_connections() {
        let mut fann = Fann::new(&[1, 1]).unwrap();
        let connection = Connection {
            from_neuron: 1,
            to_neuron: 2,
            weight: 0.123,
        };
        fann.set_connections(&[connection]);
        // The connection from the bias neuron is the other one.
        assert_eq!(2, fann.get_total_connections());
        assert_eq!(connection, fann.get_connections()[1]);
    }

    // A fresh network reports the default cascade parameters.
    #[test]
    fn test_cascade_params() {
        let fann = Fann::new(&[1, 1]).unwrap();
        assert_eq!(CascadeParams::default(), fann.get_cascade_params());
    }

    // Training data generated by a callback works like file-based data.
    #[test]
    fn test_train_data_from_callback() {
        let mut fann = Fann::new(&[2, 3, 1]).unwrap();
        fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
        fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
        let td = TrainData::from_callback(
            4,
            2,
            1,
            Box::new(|num| match num {
                0 => (vec![-1.0, 1.0], vec![1.0]),
                1 => (vec![1.0, -1.0], vec![1.0]),
                2 => (vec![-1.0, -1.0], vec![-1.0]),
                3 => (vec![1.0, 1.0], vec![-1.0]),
                _ => unreachable!(),
            }),
        ).unwrap();
        fann.on_data(&td).train(500_000, 0.0001).unwrap();
        assert!(EPSILON > (1.0 - fann.run(&[-1.0, 1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (1.0 - fann.run(&[1.0, -1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (-1.0 - fann.run(&[1.0, 1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (-1.0 - fann.run(&[-1.0, -1.0]).unwrap()[0]).abs());
    }

    // The training callback receives the right network and data, is invoked
    // at the configured interval, and can stop training early.
    #[test]
    fn test_train_callback() {
        let mut fann = Fann::new(&[2, 1]).unwrap();
        fann.set_activation_func_output(ActivationFunc::LinearPiece);
        let xor_data = TrainData::from_file("test_files/xor.data").unwrap();
        let raw = fann.raw;
        let callback_epochs = RefCell::new(Vec::new());
        let cb = |fann: &Fann, train_data: &TrainData, epochs: c_uint| {
            assert_eq!(raw, fann.raw);
            unsafe {
                assert_eq!(xor_data.get_raw(), train_data.get_raw());
            }
            callback_epochs.borrow_mut().push(epochs);
            CallbackResult::stop_if(epochs == 40)
        };
        fann.on_data(&xor_data)
            .with_callback(10, &cb)
            .train(100, 0.1)
            .unwrap();
        assert_eq!(vec![1, 10, 20, 30, 40], *callback_epochs.borrow());
    }
}