//! A Rust wrapper for the Fast Artificial Neural Network library.
//!
//! A new neural network with random weights can be created with the `Fann::new` method, or, for
//! different network topologies, with its variants `Fann::new_sparse` and `Fann::new_shortcut`.
//! Existing neural networks can be saved to and loaded from files.
//!
//! Similarly, training data sets can be loaded from and saved to human-readable files, or training
//! data can be provided directly to the network as slices of floating point numbers.
//!
//! Example:
//!
//! ```
//! extern crate fann;
//! use fann::{ActivationFunc, Fann, TrainAlgorithm, QuickpropParams};
//!
//! fn main() {
//! // Create a new network with two input neurons, a hidden layer with three neurons, and one
//! // output neuron.
//! let mut fann = Fann::new(&[2, 3, 1]).unwrap();
//! // Configure the activation functions for the hidden and output neurons.
//! fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
//! fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
//! // Use the Quickprop learning algorithm, with default parameters.
//! // (Otherwise, Rprop would be used.)
//! fann.set_train_algorithm(TrainAlgorithm::Quickprop(Default::default()));
//! // Train for up to 500000 epochs, displaying progress information after intervals of 1000
//! // epochs. Stop when the network's error on the training data drops to 0.001.
//! let max_epochs = 500000;
//! let epochs_between_reports = 1000;
//! let desired_error = 0.001;
//! // Train directly on data loaded from the file "xor.data".
//! fann.on_file("test_files/xor.data")
//! .with_reports(epochs_between_reports)
//! .train(max_epochs, desired_error).unwrap();
//! // The network now approximates the XOR problem:
//! assert!(fann.run(&[-1.0, 1.0]).unwrap()[0] > 0.9);
//! assert!(fann.run(&[ 1.0, -1.0]).unwrap()[0] > 0.9);
//! assert!(fann.run(&[ 1.0, 1.0]).unwrap()[0] < 0.1);
//! assert!(fann.run(&[-1.0, -1.0]).unwrap()[0] < 0.1);
//! }
//! ```
//!
//! FANN also supports cascade training, where the network's topology is changed during training by
//! adding additional neurons:
//!
//! ```
//! extern crate fann;
//! use fann::{ActivationFunc, CascadeParams, Fann};
//!
//! fn main() {
//! // Create a new network with two input neurons and one output neuron.
//! let mut fann = Fann::new_shortcut(&[2, 1]).unwrap();
//! // Use the default cascade training parameters, but a higher weight multiplier:
//! fann.set_cascade_params(&CascadeParams {
//! weight_multiplier: 0.6,
//! ..CascadeParams::default()
//! });
//! // Add up to 50 neurons, displaying progress information after each.
//! // Stop when the network's error on the training data drops to 0.001.
//! let max_neurons = 50;
//! let neurons_between_reports = 1;
//! let desired_error = 0.001;
//! // Train directly on data loaded from the file "xor.data".
//! fann.on_file("test_files/xor.data")
//! .with_reports(neurons_between_reports)
//! .cascade()
//! .train(max_neurons, desired_error).unwrap();
//! // The network now approximates the XOR problem:
//! assert!(fann.run(&[-1.0, 1.0]).unwrap()[0] > 0.9);
//! assert!(fann.run(&[ 1.0, -1.0]).unwrap()[0] > 0.9);
//! assert!(fann.run(&[ 1.0, 1.0]).unwrap()[0] < 0.1);
//! assert!(fann.run(&[-1.0, -1.0]).unwrap()[0] < 0.1);
//! }
//! ```
extern crate fann_sys;
extern crate libc;
use fann_sys::*;
use libc::{c_float, c_int, c_uint};
use std::cell::RefCell;
use std::ffi::CString;
use std::mem::{forget, transmute};
use std::path::Path;
use std::ptr::{copy_nonoverlapping, null_mut};
pub use activation_func::ActivationFunc;
pub use cascade_params::CascadeParams;
pub use error::{FannError, FannErrorType, FannResult};
pub use error_func::ErrorFunc;
pub use net_type::NetType;
pub use stop_func::StopFunc;
pub use train_algorithm::TrainAlgorithm;
pub use train_algorithm::{BatchParams, IncrementalParams, QuickpropParams, RpropParams};
pub use train_data::TrainData;
mod activation_func;
mod cascade_params;
mod error;
mod error_func;
mod net_type;
mod stop_func;
mod train_algorithm;
mod train_data;
/// The type of weights, inputs and outputs in a neural network. It is defined as `c_float` by
/// default, and as `c_double` if the `double` feature is configured.
pub type FannType = fann_type;
/// A single connection between two neurons (`from_neuron` → `to_neuron`) together with its
/// `weight`, as returned by `Fann::get_connections` and accepted by `Fann::set_connections`.
pub type Connection = fann_connection;
/// Turn `path` into a nul-terminated C string that can be handed to FANN's file functions.
///
/// Fails with a `CantOpenTdR` error if the path is not valid Unicode or if it contains an
/// interior nul byte (C strings cannot represent one).
fn to_filename<P: AsRef<Path>>(path: P) -> Result<CString, FannError> {
    let utf8_path = match path.as_ref().to_str() {
        Some(s) => s,
        None => {
            return Err(FannError {
                error_type: FannErrorType::CantOpenTdR,
                error_str: "File name contains invalid unicode characters".to_owned(),
            })
        }
    };
    CString::new(utf8_path).map_err(|nul_err| FannError {
        error_type: FannErrorType::CantOpenTdR,
        error_str: format!(
            "File name contains a nul byte at position {}",
            nul_err.nul_position()
        ),
    })
}
/// Either an owned or a borrowed `TrainData`: the data set a `FannTrainer` works on.
enum CurrentTrainData<'a> {
/// Data the trainer loaded itself (via `Fann::on_file`). If loading failed, this holds the
/// error, which is reported when training starts.
Own(FannResult<TrainData>),
/// Data borrowed from the caller (via `Fann::on_data`).
Ref(&'a TrainData),
}
// Thread-local slot holding a pointer to the `FannTrainer` whose `train` method is currently
// running with a user callback. It is null when no such training is in progress.
//
// FANN's callback (installed with `fann_set_callback`) is a plain C function pointer and gets no
// user-data argument, so it cannot carry the user's closure. Instead, `FannTrainer::raw_callback`
// is installed. It finds the active trainer, and through it the user's callback, via this slot.
//
// The 'static lifetime is a lie! The stored pointer really lives only as long as the trainer.
// The trainer outlives the call to `train`, and `train` resets this slot to null before it
// returns normally.
// NOTE(review): if the user callback panics, that reset is skipped — confirm this is acceptable.
thread_local!(static TRAINER: RefCell<*mut FannTrainer<'static>> = RefCell::new(null_mut()));
/// The value a training callback (see `FannTrainer::with_callback`) returns to tell the trainer
/// whether to go on.
///
/// It implements `PartialEq`/`Eq`/`Hash`, so results can be compared directly, e.g. in tests.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CallbackResult {
    /// End training now.
    Stop,
    /// Keep training.
    Continue,
}
impl CallbackResult {
    /// Return `Stop` if `condition` is true and `Continue` otherwise.
    ///
    /// This is convenient inside callbacks, e.g.
    /// `|fann, _, _| CallbackResult::stop_if(fann.get_mse() < 0.01)`.
    pub fn stop_if(condition: bool) -> CallbackResult {
        if condition {
            CallbackResult::Stop
        } else {
            CallbackResult::Continue
        }
    }
}
/// A training configuration. Create this with `Fann::on_data` or `Fann::on_file` and run the
/// training with `train`.
pub struct FannTrainer<'a> {
fann: &'a mut Fann,
cur_data: CurrentTrainData<'a>,
callback: Option<&'a Fn(&Fann, &TrainData, c_uint) -> CallbackResult>,
interval: c_uint,
cascade: bool,
}
impl<'a> FannTrainer<'a> {
/// Create a trainer over data borrowed from the caller. It starts with no reports, no callback,
/// and regular (non-cascade) training.
fn with_data<'b>(fann: &'b mut Fann, data: &'b TrainData) -> FannTrainer<'b> {
FannTrainer {
fann,
cur_data: CurrentTrainData::Ref(data),
callback: None,
interval: 0,
cascade: false,
}
}
/// Create a trainer that loads its data from the file at `path`. If loading fails, the error is
/// stored and only returned when `train` is called.
fn with_file<P: AsRef<Path>>(fann: &mut Fann, path: P) -> FannTrainer {
FannTrainer {
fann,
cur_data: CurrentTrainData::Own(TrainData::from_file(path)),
callback: None,
interval: 0,
cascade: false,
}
}
/// Activates printing reports periodically. Between two reports, `interval` neurons are added
/// (for cascade training) or training goes on for `interval` epochs (otherwise).
///
/// This shares its interval with `with_callback`: whichever of the two is called last wins.
pub fn with_reports(self, interval: c_uint) -> FannTrainer<'a> {
FannTrainer { interval, ..self }
}
/// Configures a callback to be called periodically during training. Every `interval` epochs
/// (for regular training) or every time `interval` new neurons have been added (for cascade
/// training), the callback runs. It receives as arguments:
///
/// * a reference to the current `Fann`,
/// * a reference to the training data,
/// * the number of steps (added neurons or epochs) taken so far.
///
/// Returning `CallbackResult::Stop` ends training early. This replaces any interval set by
/// `with_reports`.
pub fn with_callback(
self,
interval: c_uint,
callback: &'a Fn(&Fann, &TrainData, c_uint) -> CallbackResult,
) -> FannTrainer<'a> {
FannTrainer {
callback: Some(callback),
interval,
..self
}
}
/// Use the Cascade2 algorithm: This adds neurons to the neural network while training, starting
/// with an ANN without any hidden layers. The network should use shortcut connections, so it
/// needs to be created like this:
///
/// ```
/// let td = fann::TrainData::from_file("test_files/xor.data").unwrap();
/// let fann = fann::Fann::new_shortcut(&[td.num_input(), td.num_output()]).unwrap();
/// ```
pub fn cascade(self) -> FannTrainer<'a> {
FannTrainer {
cascade: true,
..self
}
}
/// C-ABI trampoline that `train` installs with `fann_set_callback` when a user callback is set.
///
/// It looks up the active trainer in `TRAINER`, checks that FANN passed the same network and
/// data the trainer holds, and forwards to the user's callback. The three ignored arguments are
/// FANN's maximum steps, steps between reports, and desired error. It returns -1 for
/// `CallbackResult::Stop` and 0 for `CallbackResult::Continue`.
extern "C" fn raw_callback(
ann: *mut fann,
td: *mut fann_train_data,
_: c_uint,
_: c_uint,
_: c_float,
steps: c_uint,
) -> c_int {
// TODO: This is an ugly hack - find better ways to solve the following issues:
// * The C callback is not a closure, so it cannot access the user-supplied argument.
// https://aatch.github.io/blog/2015/01/17/unboxed-closures-and-ffi-callbacks doesn't
// work here because the C callback doesn't take a user-defined pointer as an argument.
// Instead, we store a pointer to the FannTrainer, which contains a fat pointer to the
// callback, in a thread-local variable that is accessed by the raw callback.
// * The lifetime isn't known at the point where the thread-local variable is declared, so
// we just use 'static and transmute the pointer!
// * The C callback is only given pointers to the raw structs, not to self and data. We
// read these from the thread-local variable, too, and assert that they correspond to the
// given raw structs.
// * https://github.com/rust-lang/rust/issues/24010 seems to make it impossible to define a
// trait that would act as a shortcut for Fn(...) -> CallbackResult.
match TRAINER.with(|cell| unsafe {
let trainer = *cell.borrow();
// `train` only installs this callback after `get_data` succeeded, so this unwrap is
// expected not to fail.
let data = (*trainer).get_data().unwrap();
assert_eq!(ann, (*trainer).fann.raw);
assert_eq!(td, data.get_raw());
// `train` only installs this trampoline when `callback` is `Some`.
let callback = (*trainer).callback.unwrap();
callback((*trainer).fann, data, steps)
}) {
CallbackResult::Stop => -1,
CallbackResult::Continue => 0,
}
}
/// Borrow the training data. If it was loaded from a file and loading failed, return a clone
/// of the stored error instead.
fn get_data(&'a self) -> FannResult<&'a TrainData> {
match self.cur_data {
CurrentTrainData::Ref(data) => Ok(data),
CurrentTrainData::Own(ref result) => result.as_ref().map_err(FannError::clone),
}
}
/// Train the network until either the mean square error drops below the `desired_error`, or
/// the maximum number of steps is reached. If cascade training is activated, `max_steps`
/// refers to the number of neurons that are added, otherwise it is the maximum number of
/// training epochs.
///
/// Returns an error if the training data could not be loaded, or if FANN reports an error
/// after training.
// Clippy's suggestion fails: https://github.com/rust-lang-nursery/rust-clippy/issues/3340
#[cfg_attr(feature = "cargo-clippy", allow(useless_transmute))]
pub fn train(&mut self, max_steps: c_uint, desired_error: c_float) -> FannResult<()> {
unsafe {
let raw_data = try!(self.get_data()).get_raw();
if self.callback.is_some() {
// Publish this trainer to `raw_callback` (lifetime erased to 'static, see `TRAINER`).
TRAINER.with(|cell| *cell.borrow_mut() = transmute(&mut *self));
fann_set_callback(self.fann.raw, Some(FannTrainer::raw_callback));
}
// Both C functions take (ann, data, max steps, steps between reports, desired error).
let raw_train_fn = if self.cascade {
fann_cascadetrain_on_data
} else {
fann_train_on_data
};
raw_train_fn(
self.fann.raw,
raw_data,
max_steps,
self.interval,
desired_error,
);
if self.callback.is_some() {
// Undo the setup above so that no dangling trainer pointer stays reachable.
fann_set_callback(self.fann.raw, None);
TRAINER.with(|cell| *cell.borrow_mut() = null_mut());
}
FannError::check_no_error(self.fann.raw as *mut fann_error)
}
}
}
/// A neural network: a wrapper around a raw pointer to FANN's `struct fann`.
pub struct Fann {
// We don't consider setting and clearing the error string and number a mutation, and every
// method should leave these fields cleared, either because it succeeded or because it read the
// fields and returned the corresponding error.
// We also don't consider writing the output data a mutation, as we don't provide access to it
// and copy it before returning it.
raw: *mut fann,
}
impl Fann {
unsafe fn from_raw(raw: *mut fann) -> FannResult<Fann> {
try!(FannError::check_no_error(raw as *mut fann_error));
Ok(Fann { raw })
}
/// Create a fully connected neural network.
///
/// Every layer except the output layer gets a bias neuron. That bias neuron is connected to all
/// neurons of the following layer and always emits 1 when the network runs.
///
/// # Arguments
///
/// * `layers` - The number of neurons per layer, from the input layer to the output layer.
///
/// # Example
///
/// ```
/// // Creating a network with 2 input neurons, 1 output neuron,
/// // and two hidden layers with 8 and 9 neurons.
/// let layers = [2, 8, 9, 1];
/// fann::Fann::new(&layers).unwrap();
/// ```
pub fn new(layers: &[c_uint]) -> FannResult<Fann> {
    // A fully connected network is a sparse network in which every connection is present.
    let all_connections = 1.0;
    Fann::new_sparse(all_connections, layers)
}
/// Create a neural network in which consecutive layers need not be fully connected.
///
/// Every layer except the output layer gets a bias neuron. That bias neuron is connected to all
/// neurons of the following layer and always emits 1 when the network runs.
///
/// # Arguments
///
/// * `connection_rate` - The fraction of neuron pairs in consecutive layers that get connected.
/// * `layers` - The number of neurons per layer, from the input layer to the output layer.
pub fn new_sparse(connection_rate: c_float, layers: &[c_uint]) -> FannResult<Fann> {
    let layer_count = layers.len() as c_uint;
    unsafe {
        let raw = fann_create_sparse_array(connection_rate, layer_count, layers.as_ptr());
        Fann::from_raw(raw)
    }
}
/// Create a neural network with shortcut connections. Each layer is connected not only to its
/// successor but to every later layer, so every neuron feeds all neurons in all subsequent
/// layers.
pub fn new_shortcut(layers: &[c_uint]) -> FannResult<Fann> {
    let layer_count = layers.len() as c_uint;
    unsafe {
        let raw = fann_create_shortcut_array(layer_count, layers.as_ptr());
        Fann::from_raw(raw)
    }
}
/// Read a neural network from a file.
pub fn from_file<P: AsRef<Path>>(path: P) -> FannResult<Fann> {
let filename = try!(to_filename(path));
unsafe { Fann::from_raw(fann_create_from_file(filename.as_ptr())) }
}
/// Save the network to a configuration file.
///
/// The file will contain all information about the neural network, except parameters generated
/// during training, like mean square error and the bit fail limit.
pub fn save<P: AsRef<Path>>(&self, path: P) -> FannResult<()> {
let filename = try!(to_filename(path));
unsafe {
let result = fann_save(self.raw, filename.as_ptr());
FannError::check_zero(result, self.raw as *mut fann_error, "Error saving network")
}
}
/// Assign every connection a random weight in the range from `min_weight` to `max_weight`.
///
/// A freshly created network starts with random weights between -0.1 and 0.1.
pub fn randomize_weights(&mut self, min_weight: FannType, max_weight: FannType) {
    let ann = self.raw;
    unsafe {
        fann_randomize_weights(ann, min_weight, max_weight);
    }
}
/// Initialize the weights with the algorithm of Derrick Nguyen and Bernard Widrow.
///
/// The weights are chosen based on `train_data` in a way that can speed up training on that
/// data. The technique does not always help, and in some cases it is even less efficient than a
/// purely random initialization.
pub fn init_weights(&mut self, train_data: &TrainData) {
    let ann = self.raw;
    unsafe {
        let data = train_data.get_raw();
        fann_init_weights(ann, data);
    }
}
/// Print the network's connections as a compact matrix, for a quick look at its internals.
///
/// This is the output for a small (2 2 1) network trained on the XOR problem:
///
/// ```text
/// Layer / Neuron 012345
/// L 1 / N 3 BBa...
/// L 1 / N 4 BBA...
/// L 1 / N 5 ......
/// L 2 / N 6 ...BBA
/// L 2 / N 7 ......
/// ```
///
/// This network has five real neurons and two bias neurons, seven neurons in total, numbered
/// 0 to 6. A "." in the matrix marks a missing connection. A letter marks an existing connection
/// and shows its strength on a scale from a to z. The two real hidden neurons (3 and 4 in
/// layer 1) receive connections from the three neurons of the previous layer, as the first two
/// lines show. The output neuron 6 receives connections from the three hidden-layer neurons
/// 3 - 5, as the fourth line shows.
///
/// To keep the matrix small, output neurons are not shown as sources of connections, and input
/// and bias neurons are not shown as targets.
pub fn print_connections(&self) {
    let ann = self.raw;
    unsafe {
        fann_print_connections(ann);
    }
}
/// Print every parameter and option of the network.
pub fn print_parameters(&self) {
    let ann = self.raw;
    unsafe {
        fann_print_parameters(ann);
    }
}
/// Make sure `input` has exactly one value per input neuron.
///
/// Returns `Ok(())` if it does, and an `IndexOutOfBound` error otherwise.
fn check_input_size(&self, input: &[FannType]) -> FannResult<()> {
    let expected = self.get_num_input() as usize;
    if input.len() != expected {
        return Err(FannError {
            error_type: FannErrorType::IndexOutOfBound,
            error_str: format!(
                "Input has length {}, but there are {} input neurons",
                input.len(),
                expected
            ),
        });
    }
    Ok(())
}
/// Make sure `output` has exactly one value per output neuron.
///
/// Returns `Ok(())` if it does, and an `IndexOutOfBound` error otherwise.
fn check_output_size(&self, output: &[FannType]) -> FannResult<()> {
    let expected = self.get_num_output() as usize;
    if output.len() != expected {
        return Err(FannError {
            error_type: FannErrorType::IndexOutOfBound,
            error_str: format!(
                "Output has length {}, but there are {} output neurons",
                output.len(),
                expected
            ),
        });
    }
    Ok(())
}
/// Train with a single pair of input and output. This is always incremental training (see
/// `TrainAlg`), since only one pattern is presented.
pub fn train(&mut self, input: &[FannType], desired_output: &[FannType]) -> FannResult<()> {
unsafe {
try!(self.check_input_size(input));
try!(self.check_output_size(desired_output));
fann_train(self.raw, input.as_ptr(), desired_output.as_ptr());
try!(FannError::check_no_error(self.raw as *mut fann_error));
}
Ok(())
}
/// Start a training configuration that works on the borrowed data set `data`. Run it with
/// `FannTrainer::train`.
pub fn on_data<'a>(&'a mut self, data: &'a TrainData) -> FannTrainer<'a> {
    let trainer = FannTrainer::with_data(self, data);
    trainer
}
/// Start a training configuration whose data set is read from the file at `path`. A read error
/// is reported when `FannTrainer::train` runs.
pub fn on_file<P: AsRef<Path>>(&mut self, path: P) -> FannTrainer {
    let trainer = FannTrainer::with_file(self, path);
    trainer
}
/// Train one epoch with a set of training data, i. e. each sample from the training data is
/// considered exactly once.
///
/// Returns the mean square error as it is calculated either before or during the actual
/// training. This is not the actual MSE after the training epoch, but since calculating this
/// will require to go through the entire training set once more, it is more than adequate to
/// use this value during training.
pub fn train_epoch(&mut self, data: &TrainData) -> FannResult<c_float> {
unsafe {
let mse = fann_train_epoch(self.raw, data.get_raw());
try!(FannError::check_no_error(self.raw as *mut fann_error));
Ok(mse)
}
}
/// Test with a single pair of input and output. This operation updates the mean square error
/// but does not change the network.
///
/// Returns the actual output of the network.
pub fn test(
&mut self,
input: &[FannType],
desired_output: &[FannType],
) -> FannResult<Vec<FannType>> {
try!(self.check_input_size(input));
try!(self.check_output_size(desired_output));
let num_output = self.get_num_output() as usize;
let mut result = Vec::with_capacity(num_output);
unsafe {
let output = fann_test(self.raw, input.as_ptr(), desired_output.as_ptr());
try!(FannError::check_no_error(self.raw as *mut fann_error));
copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
result.set_len(num_output);
}
Ok(result)
}
/// Test with a training data set and calculate the mean square error.
pub fn test_data(&mut self, data: &TrainData) -> FannResult<c_float> {
unsafe {
let mse = fann_test_data(self.raw, data.get_raw());
try!(FannError::check_no_error(self.raw as *mut fann_error));
Ok(mse)
}
}
/// Return the mean square error accumulated so far.
pub fn get_mse(&self) -> c_float {
    let ann = self.raw;
    unsafe { fann_get_MSE(ann) }
}
/// Return the number of failed bits since the last reset. A failed bit is an output neuron
/// whose value differed from the desired output by more than the bit fail limit.
pub fn get_bit_fail(&self) -> c_uint {
    let ann = self.raw;
    unsafe { fann_get_bit_fail(ann) }
}
/// Set the accumulated mean square error and the bit fail count back to zero.
pub fn reset_mse_and_bit_fail(&mut self) {
    let ann = self.raw;
    unsafe { fann_reset_MSE(ann) };
}
/// Run the input through the neural network and returns the output. The length of the input
/// must equal the number of input neurons and the length of the output will equal the number
/// of output neurons.
pub fn run(&self, input: &[FannType]) -> FannResult<Vec<FannType>> {
try!(self.check_input_size(input));
let num_output = self.get_num_output() as usize;
let mut result = Vec::with_capacity(num_output);
unsafe {
let output = fann_run(self.raw, input.as_ptr());
try!(FannError::check_no_error(self.raw as *mut fann_error));
copy_nonoverlapping(output, result.as_mut_ptr(), num_output);
result.set_len(num_output);
}
Ok(result)
}
/// Return how many neurons the input layer has.
pub fn get_num_input(&self) -> c_uint {
    let ann = self.raw;
    unsafe { fann_get_num_input(ann) }
}
/// Return how many neurons the output layer has.
pub fn get_num_output(&self) -> c_uint {
    let ann = self.raw;
    unsafe { fann_get_num_output(ann) }
}
/// Return the total number of neurons, bias neurons included.
///
/// For example, a 2-4-2 network has 3 + 5 + 2 = 10 neurons, because the first two layers each
/// have a bias neuron.
pub fn get_total_neurons(&self) -> c_uint {
    let ann = self.raw;
    unsafe { fann_get_total_neurons(ann) }
}
/// Return how many connections the network has in total.
pub fn get_total_connections(&self) -> c_uint {
    let ann = self.raw;
    unsafe { fann_get_total_connections(ann) }
}
/// Return the network's topology type.
pub fn get_network_type(&self) -> NetType {
    NetType::from_nettype_enum(unsafe { fann_get_network_type(self.raw) })
}
/// Return the connection rate the network was created with.
pub fn get_connection_rate(&self) -> c_float {
    let ann = self.raw;
    unsafe { fann_get_connection_rate(ann) }
}
/// Return how many layers the network has, counting the input and output layers.
pub fn get_num_layers(&self) -> c_uint {
    let ann = self.raw;
    unsafe { fann_get_num_layers(ann) }
}
/// Return the number of neurons in each layer, from the input layer to the output layer.
pub fn get_layer_sizes(&self) -> Vec<c_uint> {
    // Pre-size the buffer (one entry per layer) and let FANN fill it in.
    let mut sizes = vec![0; self.get_num_layers() as usize];
    unsafe { fann_get_layer_array(self.raw, sizes.as_mut_ptr()) };
    sizes
}
/// Return the number of bias neurons in each layer, from the input layer to the output layer.
pub fn get_bias_counts(&self) -> Vec<c_uint> {
    // Pre-size the buffer (one entry per layer) and let FANN fill it in.
    let mut counts = vec![0; self.get_num_layers() as usize];
    unsafe { fann_get_bias_array(self.raw, counts.as_mut_ptr()) };
    counts
}
/// Return every connection of the network, with its source neuron, target neuron and weight.
pub fn get_connections(&self) -> Vec<Connection> {
    let count = self.get_total_connections() as usize;
    let mut connections = Vec::with_capacity(count);
    unsafe {
        // FANN writes exactly `count` entries into the reserved buffer.
        fann_get_connection_array(self.raw, connections.as_mut_ptr());
        connections.set_len(count);
    }
    connections
}
/// Set the weights of all given connections.
///
/// Connections that don't already exist are ignored.
pub fn set_connections<'a, I: IntoIterator<Item = &'a Connection>>(&mut self, connections: I) {
for c in connections {
self.set_weight(c.from_neuron, c.to_neuron, c.weight);
}
}
/// Set the weight of the connection from `from_neuron` to `to_neuron`.
pub fn set_weight(&mut self, from_neuron: c_uint, to_neuron: c_uint, weight: FannType) {
    let ann = self.raw;
    unsafe { fann_set_weight(ann, from_neuron, to_neuron, weight) }
}
/// Get the activation function for neuron number `neuron` in layer number `layer`, counting
/// the input layer as number 0. Input layer neurons do not have an activation function, so
/// `layer` must be at least 1.
pub fn get_activation_func(&self, layer: c_int, neuron: c_int) -> FannResult<ActivationFunc> {
let af_enum = unsafe { fann_get_activation_function(self.raw, layer, neuron) };
unsafe { try!(FannError::check_no_error(self.raw as *mut fann_error)) };
ActivationFunc::from_fann_activationfunc_enum(af_enum)
}
/// Set the activation function for neuron number `neuron` in layer number `layer`, counting
/// the input layer as number 0. Input layer neurons do not have an activation function, so
/// `layer` must be at least 1.
pub fn set_activation_func(&mut self, af: ActivationFunc, layer: c_int, neuron: c_int) {
let af_enum = af.to_fann_activationfunc_enum();
unsafe { fann_set_activation_function(self.raw, af_enum, layer, neuron) }
}
/// Use `activation_func` for every neuron in every hidden layer.
pub fn set_activation_func_hidden(&mut self, activation_func: ActivationFunc) {
    let af_enum = activation_func.to_fann_activationfunc_enum();
    let ann = self.raw;
    unsafe { fann_set_activation_function_hidden(ann, af_enum) };
}
/// Use `activation_func` for every neuron in the output layer.
pub fn set_activation_func_output(&mut self, activation_func: ActivationFunc) {
    let af_enum = activation_func.to_fann_activationfunc_enum();
    let ann = self.raw;
    unsafe { fann_set_activation_function_output(ann, af_enum) }
}
/// Get the activation steepness for neuron number `neuron` in layer number `layer`.
#[cfg_attr(feature = "cargo-clippy", allow(float_cmp))]
pub fn get_activation_steepness(&self, layer: c_int, neuron: c_int) -> Option<FannType> {
let steepness = unsafe { fann_get_activation_steepness(self.raw, layer, neuron) };
// This returns exactly -1 if the neuron is not defined.
if steepness == -1.0 {
return None;
}
Some(steepness)
}
/// Set the activation steepness for neuron number `neuron` in layer number `layer`, counting
/// the input layer as number 0. Input layer neurons do not have an activation steepness, so
/// layer must be at least 1.
///
/// The steepness determines how fast the function goes from minimum to maximum. A higher value
/// will result in more aggressive training.
///
/// A steep activation function is adequate if outputs are binary, e. e. they are supposed to
/// be either almost 0 or almost 1.
///
/// The default value is 0.5.
pub fn set_activation_steepness(&self, steepness: FannType, layer: c_int, neuron: c_int) {
unsafe { fann_set_activation_steepness(self.raw, steepness, layer, neuron) }
}
/// Set the activation steepness for layer number `layer`.
pub fn set_activation_steepness_layer(&self, steepness: FannType, layer: c_int) {
unsafe { fann_set_activation_steepness_layer(self.raw, steepness, layer) }
}
/// Use `steepness` as the activation steepness of every hidden neuron.
pub fn set_activation_steepness_hidden(&self, steepness: FannType) {
    let ann = self.raw;
    unsafe { fann_set_activation_steepness_hidden(ann, steepness) }
}
/// Use `steepness` as the activation steepness of every output neuron.
pub fn set_activation_steepness_output(&self, steepness: FannType) {
    let ann = self.raw;
    unsafe { fann_set_activation_steepness_output(ann, steepness) }
}
/// Return the error function that training currently uses.
pub fn get_error_func(&self) -> ErrorFunc {
    ErrorFunc::from_errorfunc_enum(unsafe { fann_get_train_error_function(self.raw) })
}
/// Choose the error function used during training.
///
/// Defaults to `Tanh`.
pub fn set_error_func(&mut self, ef: ErrorFunc) {
    let ann = self.raw;
    unsafe { fann_set_train_error_function(ann, ef.to_errorfunc_enum()) }
}
/// Return the criterion that currently decides when training stops.
pub fn get_stop_func(&self) -> StopFunc {
    StopFunc::from_stopfunc_enum(unsafe { fann_get_train_stop_function(self.raw) })
}
/// Choose the criterion that decides when training stops.
///
/// Defaults to `Mse`.
pub fn set_stop_func(&mut self, sf: StopFunc) {
    let ann = self.raw;
    unsafe { fann_set_train_stop_function(ann, sf.to_stopfunc_enum()) }
}
/// Return the current bit fail limit.
pub fn get_bit_fail_limit(&self) -> FannType {
    let ann = self.raw;
    unsafe { fann_get_bit_fail_limit(ann) }
}
/// Change the bit fail limit.
///
/// Any output neuron value that is further than this limit from the desired output counts as
/// a failed bit.
pub fn set_bit_fail_limit(&mut self, bit_fail_limit: FannType) {
    let ann = self.raw;
    unsafe { fann_set_bit_fail_limit(ann, bit_fail_limit) }
}
/// Get cascade training parameters.
pub fn get_cascade_params(&self) -> CascadeParams {
unsafe {
let num_af = fann_get_cascade_activation_functions_count(self.raw) as usize;
let af_enum_ptr = fann_get_cascade_activation_functions(self.raw);
let af_enums = Vec::from_raw_parts(af_enum_ptr, num_af, num_af);
let activation_functions = af_enums
.iter()
.map(|&af_enum| ActivationFunc::from_fann_activationfunc_enum(af_enum).unwrap())
.collect();
forget(af_enums);
let num_st = fann_get_cascade_activation_steepnesses_count(self.raw) as usize;
let mut activation_steepnesses = Vec::with_capacity(num_st);
let st_ptr = fann_get_cascade_activation_steepnesses(self.raw);
copy_nonoverlapping(st_ptr, activation_steepnesses.as_mut_ptr(), num_st);
activation_steepnesses.set_len(num_st);
CascadeParams {
output_change_fraction: fann_get_cascade_output_change_fraction(self.raw),
output_stagnation_epochs: fann_get_cascade_output_stagnation_epochs(self.raw),
candidate_change_fraction: fann_get_cascade_candidate_change_fraction(self.raw),
candidate_stagnation_epochs: fann_get_cascade_candidate_stagnation_epochs(self.raw),
candidate_limit: fann_get_cascade_candidate_limit(self.raw),
weight_multiplier: fann_get_cascade_weight_multiplier(self.raw),
max_out_epochs: fann_get_cascade_max_out_epochs(self.raw),
max_cand_epochs: fann_get_cascade_max_cand_epochs(self.raw),
activation_functions,
activation_steepnesses,
num_candidate_groups: fann_get_cascade_num_candidate_groups(self.raw),
}
}
}
/// Copy all cascade training parameters from `params` into the network.
///
/// The candidate activation functions and steepnesses are passed to FANN as pointer and
/// length pairs.
pub fn set_cascade_params(&mut self, params: &CascadeParams) {
    let ann = self.raw;
    let mut af_enums = Vec::with_capacity(params.activation_functions.len());
    for af in params.activation_functions.iter() {
        af_enums.push(af.to_fann_activationfunc_enum());
    }
    let steepnesses = &params.activation_steepnesses;
    unsafe {
        fann_set_cascade_output_change_fraction(ann, params.output_change_fraction);
        fann_set_cascade_output_stagnation_epochs(ann, params.output_stagnation_epochs);
        fann_set_cascade_candidate_change_fraction(ann, params.candidate_change_fraction);
        fann_set_cascade_candidate_stagnation_epochs(ann, params.candidate_stagnation_epochs);
        fann_set_cascade_candidate_limit(ann, params.candidate_limit);
        fann_set_cascade_weight_multiplier(ann, params.weight_multiplier);
        fann_set_cascade_max_out_epochs(ann, params.max_out_epochs);
        fann_set_cascade_max_cand_epochs(ann, params.max_cand_epochs);
        fann_set_cascade_activation_functions(ann, af_enums.as_ptr(), af_enums.len() as c_uint);
        fann_set_cascade_activation_steepnesses(
            ann,
            steepnesses.as_ptr(),
            steepnesses.len() as c_uint,
        );
        fann_set_cascade_num_candidate_groups(ann, params.num_candidate_groups);
    }
}
/// Return the training algorithm the network is configured with, together with the
/// parameters that algorithm uses.
pub fn get_train_algorithm(&self) -> TrainAlgorithm {
    let ann = self.raw;
    unsafe {
        match fann_get_training_algorithm(ann) {
            FANN_TRAIN_INCREMENTAL => TrainAlgorithm::Incremental(IncrementalParams {
                learning_momentum: fann_get_learning_momentum(ann),
                learning_rate: fann_get_learning_rate(ann),
            }),
            FANN_TRAIN_BATCH => TrainAlgorithm::Batch(BatchParams {
                learning_rate: fann_get_learning_rate(ann),
            }),
            FANN_TRAIN_RPROP => TrainAlgorithm::Rprop(RpropParams {
                decrease_factor: fann_get_rprop_decrease_factor(ann),
                increase_factor: fann_get_rprop_increase_factor(ann),
                delta_min: fann_get_rprop_delta_min(ann),
                delta_max: fann_get_rprop_delta_max(ann),
                delta_zero: fann_get_rprop_delta_zero(ann),
            }),
            FANN_TRAIN_QUICKPROP => TrainAlgorithm::Quickprop(QuickpropParams {
                decay: fann_get_quickprop_decay(ann),
                mu: fann_get_quickprop_mu(ann),
                learning_rate: fann_get_learning_rate(ann),
            }),
        }
    }
}
/// Select the training algorithm and apply its parameters to the network.
pub fn set_train_algorithm(&mut self, ta: TrainAlgorithm) {
    let ann = self.raw;
    unsafe {
        match ta {
            TrainAlgorithm::Incremental(p) => {
                fann_set_training_algorithm(ann, FANN_TRAIN_INCREMENTAL);
                fann_set_learning_momentum(ann, p.learning_momentum);
                fann_set_learning_rate(ann, p.learning_rate);
            }
            TrainAlgorithm::Batch(p) => {
                fann_set_training_algorithm(ann, FANN_TRAIN_BATCH);
                fann_set_learning_rate(ann, p.learning_rate);
            }
            TrainAlgorithm::Rprop(p) => {
                fann_set_training_algorithm(ann, FANN_TRAIN_RPROP);
                fann_set_rprop_decrease_factor(ann, p.decrease_factor);
                fann_set_rprop_increase_factor(ann, p.increase_factor);
                fann_set_rprop_delta_min(ann, p.delta_min);
                fann_set_rprop_delta_max(ann, p.delta_max);
                fann_set_rprop_delta_zero(ann, p.delta_zero);
            }
            TrainAlgorithm::Quickprop(p) => {
                fann_set_training_algorithm(ann, FANN_TRAIN_QUICKPROP);
                fann_set_quickprop_decay(ann, p.decay);
                fann_set_quickprop_mu(ann, p.mu);
                fann_set_learning_rate(ann, p.learning_rate);
            }
        }
    }
}
/// Calculate input scaling parameters for future use based on the given training data.
pub fn set_input_scaling_params(
&mut self,
data: &TrainData,
new_input_min: c_float,
new_input_max: c_float,
) -> FannResult<()> {
unsafe {
let result = fann_set_input_scaling_params(
self.raw,
data.get_raw(),
new_input_min,
new_input_max,
);
FannError::check_zero(
result,
self.raw as *mut fann_error,
"Error calculating scaling parameters",
)
}
}
/// Calculate output scaling parameters for future use based on the given training data.
pub fn set_output_scaling_params(
&mut self,
data: &TrainData,
new_output_min: c_float,
new_output_max: c_float,
) -> FannResult<()> {
unsafe {
let result = fann_set_output_scaling_params(
self.raw,
data.get_raw(),
new_output_min,
new_output_max,
);
FannError::check_zero(
result,
self.raw as *mut fann_error,
"Error calculating scaling parameters",
)
}
}
/// Calculate scaling parameters for future use based on the given training data.
pub fn set_scaling_params(
&mut self,
data: &TrainData,
new_input_min: c_float,
new_input_max: c_float,
new_output_min: c_float,
new_output_max: c_float,
) -> FannResult<()> {
unsafe {
let result = fann_set_scaling_params(
self.raw,
data.get_raw(),
new_input_min,
new_input_max,
new_output_min,
new_output_max,
);
FannError::check_zero(
result,
self.raw as *mut fann_error,
"Error calculating scaling parameters",
)
}
}
/// Clear scaling parameters.
pub fn clear_scaling_params(&mut self) -> FannResult<()> {
unsafe {
FannError::check_zero(
fann_clear_scaling_params(self.raw),
self.raw as *mut fann_error,
"Error clearing scaling parameters",
)
}
}
/// Scale data in input vector before feeding it to the network, based on previously calculated
/// parameters.
pub fn scale_input(&self, input: &mut [FannType]) -> FannResult<()> {
unsafe {
fann_scale_input(self.raw, input.as_mut_ptr());
FannError::check_no_error(self.raw as *mut fann_error)
}
}
/// Scale data in output vector before feeding it to the network, based on previously calculated
/// parameters.
pub fn scale_output(&self, output: &mut [FannType]) -> FannResult<()> {
unsafe {
fann_scale_output(self.raw, output.as_mut_ptr());
FannError::check_no_error(self.raw as *mut fann_error)
}
}
/// Descale data in input vector after feeding it to the network, based on previously calculated
/// parameters.
pub fn descale_input(&self, input: &mut [FannType]) -> FannResult<()> {
unsafe {
fann_descale_input(self.raw, input.as_mut_ptr());
FannError::check_no_error(self.raw as *mut fann_error)
}
}
/// Descale data in output vector after getting it from the network, based on previously
/// calculated parameters.
pub fn descale_output(&self, output: &mut [FannType]) -> FannResult<()> {
unsafe {
fann_descale_output(self.raw, output.as_mut_ptr());
FannError::check_no_error(self.raw as *mut fann_error)
}
}
// TODO: set_error_log: Always disable, due to different error handling?
// TODO: save_to_fixed?
// TODO: user_data methods?
}
impl Drop for Fann {
    /// Release the underlying FANN network by calling `fann_destroy` on it.
    fn drop(&mut self) {
        let ann = self.raw;
        // SAFETY: assumes `ann` is a valid network owned solely by this `Fann` (TODO confirm no
        // other path frees it); `drop` runs at most once per value.
        unsafe { fann_destroy(ann) }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use fann_sys;
    use libc::c_uint;
    use std::cell::RefCell;
    use std::ptr::null_mut;

    /// Largest accepted distance between a network output and its XOR target.
    const EPSILON: FannType = 0.2;

    /// Build a 2-3-1 network with symmetric sigmoid activations in the hidden and output layers.
    fn xor_network() -> Fann {
        let mut fann = Fann::new(&[2, 3, 1]).unwrap();
        fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
        fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
        fann
    }

    /// Assert that `fann` maps every XOR input pair to its +1/-1 target within `EPSILON`.
    fn assert_solves_xor(fann: &Fann) {
        assert!(EPSILON > (1.0 - fann.run(&[-1.0, 1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (1.0 - fann.run(&[1.0, -1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (-1.0 - fann.run(&[1.0, 1.0]).unwrap()[0]).abs());
        assert!(EPSILON > (-1.0 - fann.run(&[-1.0, -1.0]).unwrap()[0]).abs());
    }

    #[test]
    fn test_tutorial() {
        let mut fann = xor_network();
        fann.on_file("test_files/xor.data")
            .train(500_000, 0.0001)
            .unwrap();
        assert_solves_xor(&fann);
    }

    #[test]
    fn test_activation_func() {
        let mut fann = Fann::new(&[4, 3, 3, 1]).unwrap();
        // Silence FANN's error log: the two out-of-range lookups below are expected to fail.
        let errdat = fann.raw as *mut fann_sys::fann_error;
        unsafe {
            fann_sys::fann_set_error_log(errdat, null_mut());
        }
        assert!(fann.get_activation_func(0, 1).is_err());
        assert!(fann.get_activation_func(4, 1).is_err());
        assert_eq!(
            Ok(ActivationFunc::SigmoidStepwise),
            fann.get_activation_func(2, 2)
        );
        fann.set_activation_func(ActivationFunc::Sin, 2, 2);
        assert_eq!(Ok(ActivationFunc::Sin), fann.get_activation_func(2, 2));
    }

    #[test]
    fn test_train_algorithm() {
        let mut fann = Fann::new(&[4, 3, 3, 1]).unwrap();
        assert_eq!(TrainAlgorithm::default(), fann.get_train_algorithm());
        let params = QuickpropParams {
            decay: -0.0002,
            ..Default::default()
        };
        let quickprop = TrainAlgorithm::Quickprop(params);
        fann.set_train_algorithm(quickprop);
        assert_eq!(quickprop, fann.get_train_algorithm());
    }

    #[test]
    fn test_layer_sizes() {
        let fann = Fann::new(&[4, 3, 3, 1]).unwrap();
        assert_eq!(vec![4, 3, 3, 1], fann.get_layer_sizes());
        assert_eq!(vec![1, 1, 1, 0], fann.get_bias_counts());
    }

    #[test]
    fn test_get_set_connections() {
        let mut fann = Fann::new(&[1, 1]).unwrap();
        let link = Connection {
            from_neuron: 1,
            to_neuron: 2,
            weight: 0.123,
        };
        fann.set_connections(&[link]);
        // Two connections, because layer 0 also has a bias neuron feeding the output.
        assert_eq!(2, fann.get_total_connections());
        assert_eq!(link, fann.get_connections()[1]);
    }

    #[test]
    fn test_cascade_params() {
        let fann = Fann::new(&[1, 1]).unwrap();
        assert_eq!(CascadeParams::default(), fann.get_cascade_params());
    }

    #[test]
    fn test_train_data_from_callback() {
        let mut fann = xor_network();
        let td = TrainData::from_callback(
            4,
            2,
            1,
            Box::new(|num| match num {
                0 => (vec![-1.0, 1.0], vec![1.0]),
                1 => (vec![1.0, -1.0], vec![1.0]),
                2 => (vec![-1.0, -1.0], vec![-1.0]),
                3 => (vec![1.0, 1.0], vec![-1.0]),
                _ => unreachable!(),
            }),
        ).unwrap();
        fann.on_data(&td).train(500_000, 0.0001).unwrap();
        assert_solves_xor(&fann);
    }

    #[test]
    fn test_train_callback() {
        // A network without a hidden layer cannot learn XOR, so training only ends when the
        // callback asks it to stop.
        let mut fann = Fann::new(&[2, 1]).unwrap();
        fann.set_activation_func_output(ActivationFunc::LinearPiece);
        let xor_data = TrainData::from_file("test_files/xor.data").unwrap();
        let expected_raw = fann.raw;
        let seen_epochs = RefCell::new(Vec::new());
        let cb = |fann: &Fann, train_data: &TrainData, epochs: c_uint| {
            // The callback must receive the very network and data set being trained.
            assert_eq!(expected_raw, fann.raw);
            unsafe {
                assert_eq!(xor_data.get_raw(), train_data.get_raw());
            }
            seen_epochs.borrow_mut().push(epochs);
            CallbackResult::stop_if(epochs == 40)
        };
        fann.on_data(&xor_data)
            .with_callback(10, &cb)
            .train(100, 0.1)
            .unwrap();
        // Reports every 10 epochs, plus the one FANN always makes after the first epoch.
        assert_eq!(vec![1, 10, 20, 30, 40], *seen_epochs.borrow());
    }
}