extern crate fann_sys;
use super::{to_filename, Fann};
use error::{FannError, FannErrorType, FannResult};
use fann_sys::*;
use libc::c_uint;
use std::cell::RefCell;
use std::path::Path;
use std::ptr::copy_nonoverlapping;
/// Signature of a user-supplied generator for `TrainData::from_callback`: it is called with the
/// index of a training pair and returns that pair's `(input, output)` vectors.
pub type TrainCallback = dyn Fn(c_uint) -> (Vec<fann_type>, Vec<fann_type>);

thread_local! {
    /// Per-thread slot holding the callback currently driven by `TrainData::from_callback`.
    /// FANN invokes a plain C function pointer, so the boxed Rust closure is parked here for
    /// the duration of the call.
    static CALLBACK: RefCell<Option<Box<TrainCallback>>> = RefCell::new(None);
}
/// Owned handle to a FANN training set (`fann_train_data`).
///
/// The underlying C object is duplicated on `clone` and released when the handle is dropped.
pub struct TrainData {
    /// Pointer obtained from FANN; this value is its sole owner.
    raw: *mut fann_train_data,
}
impl TrainData {
pub fn from_file<P: AsRef<Path>>(path: P) -> FannResult<TrainData> {
let filename = to_filename(path)?;
unsafe {
let raw = fann_read_train_from_file(filename.as_ptr());
FannError::check_no_error(raw as *mut fann_error)?;
Ok(TrainData { raw })
}
}
pub fn from_callback(
num_data: c_uint,
num_input: c_uint,
num_output: c_uint,
cb: Box<TrainCallback>,
) -> FannResult<TrainData> {
extern "C" fn raw_callback(
num: c_uint,
num_input: c_uint,
num_output: c_uint,
input: *mut fann_type,
output: *mut fann_type,
) {
let (in_vec, out_vec) = CALLBACK.with(|cell| cell.borrow().as_ref().unwrap()(num));
assert_eq!(in_vec.len(), num_input as usize);
assert_eq!(out_vec.len(), num_output as usize);
unsafe {
copy_nonoverlapping(in_vec.as_ptr(), input, in_vec.len());
copy_nonoverlapping(out_vec.as_ptr(), output, out_vec.len());
}
}
unsafe {
CALLBACK.with(|cell| *cell.borrow_mut() = Some(cb));
let raw = fann_create_train_from_callback(
num_data,
num_input,
num_output,
Some(raw_callback),
);
CALLBACK.with(|cell| *cell.borrow_mut() = None);
FannError::check_no_error(raw as *mut fann_error)?;
Ok(TrainData { raw })
}
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> FannResult<()> {
let filename = to_filename(path)?;
unsafe {
let result = fann_save_train(self.raw, filename.as_ptr());
FannError::check_no_error(self.raw as *mut fann_error)?;
if result == -1 {
Err(FannError {
error_type: FannErrorType::CantSaveFile,
error_str: "Error saving training data".to_owned(),
})
} else {
Ok(())
}
}
}
pub fn merge(data1: &TrainData, data2: &TrainData) -> FannResult<TrainData> {
unsafe {
let raw = fann_merge_train_data(data1.raw, data2.raw);
FannError::check_no_error(raw as *mut fann_error)?;
Ok(TrainData { raw })
}
}
pub fn subset(&self, pos: c_uint, length: c_uint) -> FannResult<TrainData> {
unsafe {
let raw = fann_subset_train_data(self.raw, pos, length);
FannError::check_no_error(raw as *mut fann_error)?;
Ok(TrainData { raw })
}
}
pub fn length(&self) -> c_uint {
unsafe { fann_length_train_data(self.raw) }
}
pub fn num_input(&self) -> c_uint {
unsafe { fann_num_input_train_data(self.raw) }
}
pub fn num_output(&self) -> c_uint {
unsafe { fann_num_output_train_data(self.raw) }
}
pub fn scale_for(&mut self, fann: &Fann) -> FannResult<()> {
unsafe {
fann_scale_train(fann.raw, self.raw);
FannError::check_no_error(fann.raw as *mut fann_error)?;
FannError::check_no_error(self.raw as *mut fann_error)
}
}
pub fn descale_for(&mut self, fann: &Fann) -> FannResult<()> {
unsafe {
fann_descale_train(fann.raw, self.raw);
FannError::check_no_error(fann.raw as *mut fann_error)?;
FannError::check_no_error(self.raw as *mut fann_error)
}
}
pub fn scale_input(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
unsafe {
fann_scale_input_train_data(self.raw, new_min, new_max);
FannError::check_no_error(self.raw as *mut fann_error)
}
}
pub fn scale_output(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
unsafe {
fann_scale_output_train_data(self.raw, new_min, new_max);
FannError::check_no_error(self.raw as *mut fann_error)
}
}
pub fn scale(&mut self, new_min: fann_type, new_max: fann_type) -> FannResult<()> {
unsafe {
fann_scale_train_data(self.raw, new_min, new_max);
FannError::check_no_error(self.raw as *mut fann_error)
}
}
pub fn shuffle(&mut self) {
unsafe {
fann_shuffle_train_data(self.raw);
}
}
pub unsafe fn get_raw(&self) -> *mut fann_train_data {
self.raw
}
}
impl Clone for TrainData {
    /// Deep-copies the training set with `fann_duplicate_train_data`.
    ///
    /// # Panics
    ///
    /// Panics if `FannError::check_no_error` reports an error for the duplicate.
    fn clone(&self) -> TrainData {
        let copy = unsafe { fann_duplicate_train_data(self.raw) };
        let status = unsafe { FannError::check_no_error(copy as *mut fann_error) };
        match status {
            Ok(()) => TrainData { raw: copy },
            Err(_) => panic!("Unable to clone TrainData."),
        }
    }
}
impl Drop for TrainData {
    /// Releases the underlying FANN training set with `fann_destroy_train`.
    fn drop(&mut self) {
        let owned = self.raw;
        unsafe { fann_destroy_train(owned) };
    }
}