use super::label::Label;
use crate::observation::{new_observation_vector, Observation, ObservationType};
use crate::Data;
use glob::glob;
use lazy_static::lazy_static;
use regex::Regex;
use std::collections::hash_map::HashMap;
use std::fs;
use std::io::prelude::*;
use std::path::PathBuf;
/// Reads the file at `path` through `Data::from_file` and turns its `data`
/// field into a vector of observations via `new_observation_vector`.
///
/// # Errors
///
/// * Propagates the error produced by `Data::from_file`.
/// * Any error reported by `new_observation_vector` is discarded and replaced
///   by a fixed message.
pub(super) fn get_observation(
    path: PathBuf,
    n_step: usize,
    observation_type: &ObservationType,
) -> Result<Vec<Box<dyn Observation>>, String> {
    let loaded = Data::from_file(path)?;
    match new_observation_vector(loaded.data, observation_type, n_step) {
        Ok(observations) => Ok(observations),
        // The underlying cause is intentionally not forwarded to the caller.
        Err(_) => Err(String::from(
            "the observation cannot be converted into a int",
        )),
    }
}
/// Reads the file at `path` and deserializes its JSON content into a `Label`.
///
/// # Errors
///
/// Returns the textual form of the underlying error when the file cannot be
/// opened, cannot be read as a UTF-8 string, or is not valid JSON for `Label`.
pub(super) fn get_label(path: PathBuf) -> Result<Label, String> {
    let mut file = fs::File::open(path).map_err(|e| e.to_string())?;

    let mut content = String::new();
    file.read_to_string(&mut content)
        .map_err(|e| e.to_string())?;

    serde_json::from_str::<Label>(content.as_str()).map_err(|e| e.to_string())
}
/// Lists the paths found by `get_all_path` for the `**/dataset*.json` filter,
/// using `path` as the root.
///
/// # Errors
///
/// Forwards any error returned by `get_all_path`.
pub(super) fn get_all_data_path(path: PathBuf) -> Result<Vec<PathBuf>, String> {
    let filter = "**/dataset*.json".to_owned();
    get_all_path(path, filter)
}
/// Lists the paths found by `get_all_path` for the `**/label-*.json` filter,
/// using `path` as the root.
///
/// # Errors
///
/// Forwards any error returned by `get_all_path`.
pub(super) fn get_all_label_data_path(path: PathBuf) -> Result<Vec<PathBuf>, String> {
    let filter = "**/label-*.json".to_owned();
    get_all_path(path, filter)
}
/// Pairs every observation file with the label file that carries the same id.
///
/// The id of a path is the first number (optionally with a decimal part)
/// matched by `([0-9]*[.])?[0-9]+` in the *whole* path string, so digits in a
/// directory name are found before digits in the file name.
///
/// Returns a map from id to `(observation_path, label_path)`:
///
/// * when several labels share an id, the first one in `label_paths` is used;
/// * when several observations share an id, the last one in
///   `observation_paths` wins;
/// * observations without a matching label are left out;
/// * paths that are not valid UTF-8 or contain no number are ignored instead
///   of causing a panic.
pub(super) fn associate_training_path(
    observation_paths: Vec<PathBuf>,
    label_paths: Vec<PathBuf>,
) -> HashMap<String, (PathBuf, PathBuf)> {
    lazy_static! {
        static ref RE: Regex =
            Regex::new("([0-9]*[.])?[0-9]+").expect("hard-coded id pattern is a valid regex");
    }

    /// Extracts the id of `path`, or `None` when the path has no usable id.
    fn extract_id(path: &std::path::Path) -> Option<&str> {
        path.to_str()
            .and_then(|text| RE.find(text))
            .map(|found| found.as_str())
    }

    // Index the labels once so each observation is resolved with a single
    // lookup instead of re-running the regex over every label.
    let mut label_by_id: HashMap<&str, &PathBuf> = HashMap::with_capacity(label_paths.len());
    for label in &label_paths {
        if let Some(id) = extract_id(label) {
            // Keep the first label seen for an id.
            label_by_id.entry(id).or_insert(label);
        }
    }

    let mut map_associate: HashMap<String, (PathBuf, PathBuf)> = HashMap::new();
    for path in observation_paths {
        // Resolve to owned values first so `path` can be moved into the map.
        let matched = extract_id(&path).and_then(|id| {
            label_by_id
                .get(id)
                .map(|label| (id.to_owned(), (*label).clone()))
        });
        if let Some((id, label)) = matched {
            map_associate.insert(id, (path, label));
        }
    }
    map_associate
}
/// Collects every path matching the glob pattern formed by the root `path`
/// followed by `filter`.
///
/// A path separator is inserted between the two parts when both are non-empty
/// and neither the end of `path` nor the start of `filter` already provides
/// one, so `data` and `data/` behave the same for a filter such as
/// `**/dataset*.json`. In every other case the two parts are concatenated
/// unchanged.
///
/// # Errors
///
/// Returns the textual form of the error when:
///
/// * `path` is not valid UTF-8 and therefore cannot be used in a pattern;
/// * the resulting glob pattern is invalid;
/// * an entry cannot be read while walking the matches (the first such error
///   aborts the walk).
pub(super) fn get_all_path(path: PathBuf, filter: String) -> Result<Vec<PathBuf>, String> {
    let root_path = path
        .to_str()
        .ok_or_else(|| format!("the path {} is not valid UTF-8", path.display()))?;

    // Plain concatenation is only correct when one side already supplies the
    // separator (or one side is empty); otherwise `data` + `**/x` would give
    // the pattern `data**/x`.
    let already_joined = root_path.is_empty()
        || filter.is_empty()
        || root_path.ends_with(std::path::is_separator)
        || filter.starts_with(std::path::is_separator);
    let expr = if already_joined {
        format!("{}{}", root_path, filter)
    } else {
        format!("{}{}{}", root_path, std::path::MAIN_SEPARATOR, filter)
    };

    glob(&expr)
        .map_err(|e| e.to_string())?
        .map(|entry| entry.map_err(|e| e.to_string()))
        .collect()
}
/// Flattens the segments of `label` into one state value per step.
///
/// For each entry of `label.data`, the entry's `state` is emitted once for
/// every value produced by walking from `from` up to (excluding)
/// `to + n_step` in increments of `n_step`. Entries are processed in order and
/// their outputs are concatenated.
///
/// # Panics
///
/// Panics if `n_step` is zero and `label.data` is not empty, because
/// `step_by` rejects a zero step.
pub(super) fn expend_states(label: &Label, n_step: usize) -> Vec<usize> {
    label
        .data
        .iter()
        .flat_map(|segment| {
            (segment.from..segment.to + n_step)
                .step_by(n_step)
                .map(move |_| segment.state)
        })
        .collect()
}