1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
extern crate fann_sys;
use super::{to_filename, Fann};
use error::{FannError, FannErrorType, FannResult};
use fann_sys::*;
use libc::c_uint;
use std::cell::RefCell;
use std::path::Path;
use std::ptr::copy_nonoverlapping;
/// Signature of the user callback accepted by `TrainData::from_callback`.
///
/// The callback receives the index of a training sample and returns the pair
/// `(input, output)` of value vectors for that sample.
pub type TrainCallback = dyn Fn(c_uint) -> (Vec<fann_type>, Vec<fann_type>);
// Thread-local container for user-supplied callback functions.
// This is necessary because the raw fann_create_train_from_callback C function takes a function
// pointer and not a closure. So instead of the user-supplied function we pass a function to it
// which will call the content of CALLBACK.
// The slot holds `Some(..)` only while `TrainData::from_callback` is running the FANN call; it is
// reset to `None` as soon as that call returns.
thread_local!(static CALLBACK: RefCell<Option<Box<TrainCallback>>> = RefCell::new(None));
/// Owning wrapper around a raw FANN `fann_train_data` structure.
///
/// The wrapped pointer is handed to `fann_destroy_train` when the value is dropped, and
/// `Clone` produces an independent copy via `fann_duplicate_train_data`.
pub struct TrainData {
// Pointer to the underlying C structure; every method forwards it to the FANN C API.
raw: *mut fann_train_data,
}
impl TrainData {
/// Load a training data set from the file at `path`.
///
/// The expected file layout is:
///
/// ```text
/// num_train_data num_input num_output
/// inputdata separated by space
/// outputdata separated by space
/// .
/// .
/// .
/// inputdata separated by space
/// outputdata separated by space
/// ```
///
/// # Errors
///
/// Fails if `to_filename` cannot convert `path`, or if `FannError::check_no_error` rejects
/// the pointer returned by `fann_read_train_from_file`.
pub fn from_file<P: AsRef<Path>>(path: P) -> FannResult<TrainData> {
    let c_path = to_filename(path)?;
    // `c_path` stays alive until the end of this function, so the pointer handed to FANN
    // remains valid for the whole call.
    let raw = unsafe { fann_read_train_from_file(c_path.as_ptr()) };
    // Only wrap the pointer once the error check has accepted it.
    let checked = unsafe { FannError::check_no_error(raw as *mut fann_error) };
    checked.map(|()| TrainData { raw })
}
/// Create training data using the given callback which for each number between `0` (included)
/// and `num_data` (excluded) returns a pair of input and output vectors with `num_input` and
/// `num_output` entries respectively.
pub fn from_callback(
num_data: c_uint,
num_input: c_uint,
num_output: c_uint,
cb: Box<TrainCallback>,
) -> FannResult<TrainData> {
extern "C" fn raw_callback(
num: c_uint,
num_input: c_uint,
num_output: c_uint,
input: *mut fann_type,
output: *mut fann_type,
) {
// Call the callback we stored in the thread-local container.
let (in_vec, out_vec) = CALLBACK.with(|cell| cell.borrow().as_ref().unwrap()(num));
// Make sure it returned data of the correct size, then copy the data.
assert_eq!(in_vec.len(), num_input as usize);
assert_eq!(out_vec.len(), num_output as usize);
unsafe {
copy_nonoverlapping(in_vec.as_ptr(), input, in_vec.len());
copy_nonoverlapping(out_vec.as_ptr(), output, out_vec.len());
}
}
unsafe {
// Put the callback into the thread-local container.
CALLBACK.with(|cell| *cell.borrow_mut() = Some(cb));
let raw = fann_create_train_from_callback(
num_data,
num_input,
num_output,
Some(raw_callback),
);
// Remove it from the thread-local container to free the memory.
CALLBACK.with(|cell| *cell.borrow_mut() = None);
FannError::check_no_error(raw as *mut fann_error)?;
Ok(TrainData { raw })
}
}
/// Write the training data to the file at `path`.
///
/// # Errors
///
/// Fails if `to_filename` cannot convert `path`, if `FannError::check_no_error` reports an
/// error on the data set after the call, or with `FannErrorType::CantSaveFile` when
/// `fann_save_train` returns `-1`.
pub fn save<P: AsRef<Path>>(&self, path: P) -> FannResult<()> {
    let c_path = to_filename(path)?;
    // `c_path` outlives the call, so the pointer handed to FANN stays valid throughout.
    let status = unsafe { fann_save_train(self.raw, c_path.as_ptr()) };
    // An error recorded on the data set takes precedence over the return code.
    unsafe {
        FannError::check_no_error(self.raw as *mut fann_error)?;
    }
    if status != -1 {
        return Ok(());
    }
    Err(FannError {
        error_type: FannErrorType::CantSaveFile,
        error_str: "Error saving training data".to_owned(),
    })
}
/// Combine `data1` and `data2` into a newly created data set.
///
/// # Errors
///
/// Fails if `FannError::check_no_error` rejects the pointer returned by
/// `fann_merge_train_data`.
pub fn merge(data1: &TrainData, data2: &TrainData) -> FannResult<TrainData> {
    let (first, second) = (data1.raw, data2.raw);
    let raw = unsafe { fann_merge_train_data(first, second) };
    // Only wrap the merged pointer once the error check has accepted it.
    let checked = unsafe { FannError::check_no_error(raw as *mut fann_error) };
    checked.map(|()| TrainData { raw })
}
/// Create a new data set from `length` samples of this one, starting at position `pos`.
///
/// # Errors
///
/// Fails if `FannError::check_no_error` rejects the pointer returned by
/// `fann_subset_train_data`.
pub fn subset(&self, pos: c_uint, length: c_uint) -> FannResult<TrainData> {
    let source = self.raw;
    let raw = unsafe { fann_subset_train_data(source, pos, length) };
    // Only wrap the new pointer once the error check has accepted it.
    let checked = unsafe { FannError::check_no_error(raw as *mut fann_error) };
    checked.map(|()| TrainData { raw })
}
/// Number of training patterns in this data set, as reported by `fann_length_train_data`.
pub fn length(&self) -> c_uint {
    let data = self.raw;
    // `data` is the pointer this wrapper stores; this type releases it only in `Drop`.
    unsafe { fann_length_train_data(data) }
}
/// Number of input values per training pattern, as reported by `fann_num_input_train_data`.
pub fn num_input(&self) -> c_uint {
    let data = self.raw;
    // `data` is the pointer this wrapper stores; this type releases it only in `Drop`.
    unsafe { fann_num_input_train_data(data) }
}
/// Number of output values per training pattern, as reported by `fann_num_output_train_data`.
pub fn num_output(&self) -> c_uint {
    let data = self.raw;
    // `data` is the pointer this wrapper stores; this type releases it only in `Drop`.
    unsafe { fann_num_output_train_data(data) }
}
/// Apply the scaling parameters previously calculated for `fann` to the inputs and outputs
/// of this training data.
///
/// # Errors
///
/// After the call, the network is checked for an error first and the data set second; the
/// first error found is returned.
pub fn scale_for(&mut self, fann: &Fann) -> FannResult<()> {
    let (net, data) = (fann.raw, self.raw);
    unsafe {
        fann_scale_train(net, data);
        FannError::check_no_error(net as *mut fann_error)
            .and_then(|()| FannError::check_no_error(data as *mut fann_error))
    }
}
/// Undo scaling of the inputs and outputs of this training data, using the parameters
/// previously calculated for `fann`.
///
/// # Errors
///
/// After the call, the network is checked for an error first and the data set second; the
/// first error found is returned.
pub fn descale_for(&mut self, fann: &Fann) -> FannResult<()> {
    let (net, data) = (fann.raw, self.raw);
    unsafe {
        fann_descale_train(net, data);
        FannError::check_no_error(net as *mut fann_error)
            .and_then(|()| FannError::check_no_error(data as *mut fann_error))
    }
}
/// Rescale the input values of the training data to the range `new_min..=new_max`.
///
/// # Errors
///
/// Returns whatever error `FannError::check_no_error` finds on the data set afterwards.
pub fn scale_input(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
    let data = self.raw;
    unsafe { fann_scale_input_train_data(data, new_min, new_max) };
    // The C function returns nothing; failures are read back from the data set itself.
    unsafe { FannError::check_no_error(data as *mut fann_error) }
}
/// Rescale the output values of the training data to the range `new_min..=new_max`.
///
/// # Errors
///
/// Returns whatever error `FannError::check_no_error` finds on the data set afterwards.
pub fn scale_output(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
    let data = self.raw;
    unsafe { fann_scale_output_train_data(data, new_min, new_max) };
    // The C function returns nothing; failures are read back from the data set itself.
    unsafe { FannError::check_no_error(data as *mut fann_error) }
}
/// Rescale both the input and the output values of the training data to the range
/// `new_min..=new_max`.
///
/// # Errors
///
/// Returns whatever error `FannError::check_no_error` finds on the data set afterwards.
pub fn scale(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
    let data = self.raw;
    unsafe { fann_scale_train_data(data, new_min, new_max) };
    // The C function returns nothing; failures are read back from the data set itself.
    unsafe { FannError::check_no_error(data as *mut fann_error) }
}
/// Randomize the order of the training patterns.
///
/// Shuffling is recommended for incremental training; batch training is not affected by it.
pub fn shuffle(&mut self) {
    let data = self.raw;
    // `data` is the pointer this wrapper stores; this type releases it only in `Drop`.
    unsafe { fann_shuffle_train_data(data) };
}
/// Get a pointer to the underlying raw `fann_train_data` structure.
///
/// The returned pointer is the one stored in this `TrainData`; the same pointer is passed to
/// `fann_destroy_train` when the `TrainData` is dropped.
///
/// # Safety
///
/// NOTE(review): the body only copies the pointer out. The function is presumably marked
/// `unsafe` because a caller could free the pointer or keep using it after this `TrainData`
/// has been dropped; confirm the intended contract.
pub unsafe fn get_raw(&self) -> *mut fann_train_data {
self.raw
}
// TODO: save_to_fixed?
}
impl Clone for TrainData {
    /// Duplicate the training data through `fann_duplicate_train_data`.
    ///
    /// # Panics
    ///
    /// Panics if `FannError::check_no_error` rejects the duplicated pointer, because `Clone`
    /// has no way to report an error.
    fn clone(&self) -> TrainData {
        let raw = unsafe { fann_duplicate_train_data(self.raw) };
        let checked = unsafe { FannError::check_no_error(raw as *mut fann_error) };
        match checked {
            Ok(()) => TrainData { raw },
            Err(_) => panic!("Unable to clone TrainData."),
        }
    }
}
impl Drop for TrainData {
    /// Release the underlying C structure by passing it to `fann_destroy_train`.
    fn drop(&mut self) {
        let raw = self.raw;
        // `drop` runs once per value, so the pointer is handed to FANN for destruction
        // exactly once by this wrapper.
        unsafe { fann_destroy_train(raw) };
    }
}