#[cfg(all(feature = "mesalock_sgx", not(target_env = "sgx")))]
use std::prelude::v1::*;
use crate::decision_tree::{Data, DataVec, ValueType, VALUE_TYPE_UNKNOWN};
cfg_if! {
if #[cfg(all(feature = "mesalock_sgx", not(target_env = "sgx")))] {
use std::collections::HashMap;
use std::error::Error;
use std::untrusted::fs::File;
use std::io::{BufRead, BufReader, Seek, SeekFrom};
} else {
use std::collections::HashMap;
use std::error::Error;
#[cfg(not(feature = "mesalock_sgx"))]
use std::fs::File;
#[cfg(feature = "mesalock_sgx")]
use std::untrusted::fs::File;
use std::io::{BufRead, BufReader, Seek, SeekFrom};
}
}
use regex::Regex;
use serde_derive::{Deserialize, Serialize};
/// On-disk layout of a data file.
#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)]
pub enum FileFormat {
    /// Dense rows of delimiter-separated values, one column of which is the label.
    CSV,
    /// Sparse rows: a label token followed by `index:value` tokens.
    TXT,
}
/// Parsing options for a data file, consumed by `load`, `load_csv` and `load_txt`.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct InputFormat {
    /// Selects the parser: `load_csv` for `CSV`, `load_txt` for `TXT`.
    pub ftype: FileFormat,
    /// CSV only: when true the first line is skipped as a header row.
    pub header: bool,
    /// CSV only: index of the column that holds the label.
    pub label_idx: usize,
    /// CSV only: when true, fields that fail to parse are stored as `VALUE_TYPE_UNKNOWN`.
    pub enable_unknown_value: bool,
    /// Separator between fields (CSV) or tokens (TXT). The spelling of the field name is
    /// kept as-is because it is part of the public API.
    pub delimeter: char,
    /// TXT only: length of every feature vector; `index:value` pairs whose index is not
    /// below this size are dropped.
    pub feature_size: usize,
}
impl InputFormat {
pub fn csv_format() -> InputFormat {
InputFormat {
ftype: FileFormat::CSV,
header: false,
label_idx: 0,
enable_unknown_value: false,
delimeter: ',',
feature_size: 0,
}
}
pub fn txt_format() -> InputFormat {
InputFormat {
ftype: FileFormat::TXT,
header: false,
label_idx: 0,
enable_unknown_value: false,
delimeter: '\t',
feature_size: 0,
}
}
pub fn to_string(&self) -> String {
let mut s = String::from("");
s.push_str(&format!(
"File type: {}\n",
match self.ftype {
FileFormat::CSV => "CSV",
FileFormat::TXT => "TXT",
}
));
match self.ftype {
FileFormat::CSV => {
s.push_str(&format!("Has header: {}\n", self.header));
s.push_str(&format!("Label index: {}\n", self.label_idx));
}
FileFormat::TXT => {
s.push_str(&format!("Feature size: {}\n", self.feature_size));
}
}
s.push_str(&format!("Delemeter: [{}]", self.delimeter));
s
}
pub fn set_feature_size(&mut self, size: usize) {
self.feature_size = size;
}
pub fn set_label_index(&mut self, idx: usize) {
self.label_idx = idx;
}
pub fn set_delimeter(&mut self, delim: char) {
self.delimeter = delim;
}
}
/// Fold step for building a character histogram: bumps the tally for `word`
/// (starting it at 1 if unseen) and hands the map back so it can be threaded
/// through `Iterator::fold`.
fn count(mut hash_map: HashMap<char, u32>, word: char) -> HashMap<char, u32> {
    hash_map.entry(word).and_modify(|n| *n += 1).or_insert(1);
    hash_map
}
/// Guess the [`InputFormat`] of the data file at `file_name` from its first two lines.
///
/// * If the first line contains a `:` the file is treated as TXT (`index:value`
///   tokens), otherwise as CSV.
/// * The delimiter is inferred from the characters left on the second line once every
///   number (CSV) or `index:value` pair (TXT) has been removed. The format's default
///   (`,` for CSV, tab for TXT) is kept when it occurs often enough to separate all the
///   values. Otherwise the most frequent leftover character is used, with ties broken
///   towards the smaller character. A single-line file is sampled from its first line.
/// * For CSV, `header` is set when the first line still contains an ASCII letter after
///   numbers are removed.
///
/// `feature_size` is not inferred (it stays 0), so for TXT files the caller must call
/// [`InputFormat::set_feature_size`] before loading.
///
/// # Panics
///
/// Panics if the file cannot be opened or read.
pub fn infer(file_name: &str) -> InputFormat {
    let file = File::open(file_name).expect("infer: cannot open input file");
    let mut reader = BufReader::new(file);

    let mut first_line = String::new();
    reader
        .read_line(&mut first_line)
        .expect("infer: cannot read first line");

    let mut input_format = if first_line.contains(':') {
        InputFormat::txt_format()
    } else {
        InputFormat::csv_format()
    };

    // The decimal point is escaped so it cannot swallow a delimiter, exponents may be
    // signed or upper-case, and there is no `(,\d+)*` thousands group: that group used
    // to consume entire comma-separated rows of integers, leaving no delimiter to count.
    let value_re = match input_format.ftype {
        FileFormat::CSV => Regex::new(r"[+-]?\d+(\.\d+)?([eE][+-]?\d+)?"),
        FileFormat::TXT => Regex::new(r"\d+:[+-]?\d+(\.\d+)?([eE][+-]?\d+)?"),
    }
    .expect("hard-coded value regex is valid");

    let mut second_line = String::new();
    let read = reader
        .read_line(&mut second_line)
        .expect("infer: cannot read second line");
    // `read_line` reports end-of-file as Ok(0) rather than an error, so a one-line
    // file is detected here and sampled from its only line.
    let sample: &str = if read == 0 { &first_line } else { &second_line };

    // Every value on the sample line is one match; `n` values need at least `n - 1`
    // separators (TXT rows have one more, because the label is not matched).
    let value_count = value_re.find_iter(sample).count();
    let leftovers = value_re.replace_all(sample, "");
    // Line terminators are always left over and are never a field delimiter.
    let cnt = leftovers
        .chars()
        .filter(|&c| c != '\n' && c != '\r')
        .fold(HashMap::new(), count);

    let default_delim = match input_format.ftype {
        FileFormat::CSV => ',',
        FileFormat::TXT => '\t',
    };
    let default_fits = cnt
        .get(&default_delim)
        .map_or(false, |&n| n as usize + 1 >= value_count);

    input_format.delimeter = if default_fits {
        default_delim
    } else {
        // Most frequent leftover character; ties go to the smaller char so the result
        // does not depend on HashMap iteration order.
        cnt.iter()
            .max_by(|a, b| a.1.cmp(b.1).then_with(|| b.0.cmp(a.0)))
            .map(|(&c, _)| c)
            .unwrap_or(default_delim)
    };

    if let FileFormat::CSV = input_format.ftype {
        let first_line_after = value_re.replace_all(&first_line, "");
        input_format.header = first_line_after.chars().any(|c| c.is_ascii_alphabetic());
    }
    input_format
}
pub fn load_csv(file: &mut File, input_format: InputFormat) -> Result<DataVec, Box<Error>> {
file.seek(SeekFrom::Start(0))?;
let mut dv = Vec::new();
let mut reader = BufReader::new(file);
let mut l = String::new();
if input_format.header {
reader.read_line(&mut l).unwrap_or(0);
}
let mut v: Vec<ValueType>;
for line in reader.lines() {
let content = line?;
if input_format.enable_unknown_value {
v = content
.split(input_format.delimeter)
.map(|x| x.parse::<ValueType>().unwrap_or(VALUE_TYPE_UNKNOWN))
.collect();
} else {
v = content
.split(input_format.delimeter)
.map(|x| x.parse::<ValueType>().unwrap())
.collect();
}
dv.push(Data {
label: v.swap_remove(input_format.label_idx),
feature: v,
target: 0.0,
weight: 1.0,
residual: 0.0,
initial_guess: 0.0,
})
}
Ok(dv)
}
pub fn load_txt(file: &mut File, input_format: InputFormat) -> Result<DataVec, Box<Error>> {
file.seek(SeekFrom::Start(0))?;
let mut dv = Vec::new();
let reader = BufReader::new(file);
let mut label: ValueType = 0.0;
let mut idx: usize = 0;
let mut val: ValueType = 0.0;
for line in reader.lines() {
let mut v: Vec<ValueType> = vec![VALUE_TYPE_UNKNOWN; input_format.feature_size];
for token in line.unwrap().split(input_format.delimeter) {
let splited_token: Vec<&str> = token.split(':').collect();
if splited_token.len() == 2 {
let mut err = false;
match splited_token[0 as usize].parse::<usize>() {
Ok(kk) => {
idx = kk;
}
Err(_) => err = true,
}
match splited_token[1 as usize].parse::<ValueType>() {
Ok(vv) => {
val = vv;
}
Err(_) => err = true,
}
if idx >= input_format.feature_size {
err = true;
}
if !err {
v[idx] = val;
}
}
if splited_token.len() == 1 {
label = splited_token[0 as usize].parse::<ValueType>().unwrap();
} else {
}
}
dv.push(Data {
label,
feature: v,
target: 0.0,
weight: 1.0,
residual: 0.0,
initial_guess: 0.0,
});
}
Ok(dv)
}
pub fn load(file_name: &str, input_format: InputFormat) -> Result<DataVec, Box<Error>> {
let mut file = File::open(file_name.to_string())?;
match input_format.ftype {
FileFormat::CSV => load_csv(&mut file, input_format),
FileFormat::TXT => load_txt(&mut file, input_format),
}
}