use std::time::Instant;
use clap::{crate_authors, crate_description, crate_name, crate_version, Arg, Command};
use cleora::configuration;
use cleora::configuration::Configuration;
use cleora::configuration::OutputFormat;
use cleora::persistence::entity::InMemoryEntityMappingPersistor;
use cleora::pipeline::{build_graphs, train};
use env_logger::Env;
use std::fs;
use std::sync::Arc;
#[macro_use]
extern crate log;
fn main() {
let env = Env::default()
.filter_or("MY_LOG_LEVEL", "info")
.write_style_or("MY_LOG_STYLE", "always");
env_logger::init_from_env(env);
let now = Instant::now();
let matches = Command::new(crate_name!())
.version(crate_version!())
.author(crate_authors!())
.about(crate_description!())
.arg(
Arg::new("inputs")
.multiple_values(true)
.help("Input files paths")
.takes_value(true),
)
.arg(
Arg::new("input")
.short('i')
.long("input")
.help("Deprecated. Use positional args for input files")
.takes_value(true),
)
.arg(
Arg::new("file-type")
.short('t')
.long("type")
.possible_values(&["tsv", "json"])
.help("Input file type")
.takes_value(true),
)
.arg(
Arg::new("output-dir")
.short('o')
.long("output-dir")
.help("Output directory for files with embeddings")
.takes_value(true),
)
.arg(
Arg::new("dimension")
.short('d')
.long("dimension")
.required(true)
.help("Embedding dimension size")
.takes_value(true),
)
.arg(
Arg::new("number-of-iterations")
.short('n')
.long("number-of-iterations")
.required(true)
.help("Max number of iterations")
.takes_value(true),
)
.arg(
Arg::new("seed")
.short('s')
.long("seed")
.help("Seed (integer) for embedding initialization")
.takes_value(true),
)
.arg(
Arg::new("columns")
.short('c')
.long("columns")
.required(true)
.help(
"Column names (max 12), with modifiers: [transient::, reflexive::, complex::]",
)
.takes_value(true),
)
.arg(
Arg::new("relation-name")
.short('r')
.long("relation-name")
.default_value("emb")
.help("Name of the relation, for output filename generation")
.takes_value(true),
)
.arg(
Arg::new("prepend-field-name")
.short('p')
.long("prepend-field-name")
.possible_values(&["0", "1"])
.default_value("0")
.help("Prepend field name to entity in output")
.takes_value(true),
)
.arg(
Arg::new("log-every-n")
.short('l')
.long("log-every-n")
.default_value("10000")
.help("Log output every N lines")
.takes_value(true),
)
.arg(
Arg::new("in-memory-embedding-calculation")
.short('e')
.long("in-memory-embedding-calculation")
.possible_values(&["0", "1"])
.default_value("1")
.help("Calculate embeddings in memory or with memory-mapped files")
.takes_value(true),
)
.arg(
Arg::new("output-format")
.short('f')
.help("Output format. One of: textfile|numpy")
.possible_values(&["textfile", "numpy"])
.default_value("textfile")
.takes_value(true),
)
.get_matches();
info!("Reading args...");
let input: Vec<String> = {
let named_arg = matches.value_of("input");
let position_args = match matches.values_of("inputs") {
None => vec![],
Some(values) => values.into_iter().collect(),
};
position_args
.into_iter()
.chain(named_arg.into_iter())
.map(|s| s.to_string())
.collect()
};
if input.is_empty() {
panic!("Missing input files")
}
let file_type = match matches.value_of("file-type") {
Some(type_name) => match type_name {
"tsv" => configuration::FileType::Tsv,
"json" => configuration::FileType::Json,
_ => panic!("Invalid file type {}", type_name),
},
None => configuration::FileType::Tsv,
};
let output_dir = matches.value_of("output-dir").map(|s| s.to_string());
if let Some(output_dir) = output_dir.as_ref() {
fs::create_dir_all(output_dir).expect("Can't create output directory");
}
let dimension: u16 = matches.value_of("dimension").unwrap().parse().unwrap();
let max_iter: u8 = matches
.value_of("number-of-iterations")
.unwrap()
.parse()
.unwrap();
let seed: Option<i64> = matches.value_of("seed").map(|s| s.parse().unwrap());
let relation_name = matches.value_of("relation-name").unwrap();
let prepend_field_name = {
let value: u8 = matches
.value_of("prepend-field-name")
.unwrap()
.parse()
.unwrap();
value == 1
};
let log_every: u32 = matches.value_of("log-every-n").unwrap().parse().unwrap();
let in_memory_embedding_calculation = {
let value: u8 = matches
.value_of("in-memory-embedding-calculation")
.unwrap()
.parse()
.unwrap();
value == 1
};
let columns = {
let cols_str = matches.value_of("columns").unwrap();
let cols_str_separated: Vec<&str> = cols_str.split(' ').collect();
match configuration::extract_fields(cols_str_separated) {
Ok(cols) => match configuration::validate_fields(cols) {
Ok(validated_cols) => validated_cols,
Err(msg) => panic!("Invalid column fields. Message: {}", msg),
},
Err(msg) => panic!("Parsing problem. Message: {}", msg),
}
};
let output_format = match matches.value_of("output-format").unwrap() {
"textfile" => OutputFormat::TextFile,
"numpy" => OutputFormat::Numpy,
_ => panic!("unsupported output format"),
};
let config = Configuration {
produce_entity_occurrence_count: true,
embeddings_dimension: dimension,
max_number_of_iteration: max_iter,
seed,
prepend_field: prepend_field_name,
log_every_n: log_every,
in_memory_embedding_calculation,
input,
file_type,
output_dir,
output_format,
relation_name: relation_name.to_string(),
columns,
};
dbg!(&config);
info!("Starting calculation...");
let in_memory_entity_mapping_persistor = InMemoryEntityMappingPersistor::default();
let in_memory_entity_mapping_persistor = Arc::new(in_memory_entity_mapping_persistor);
let sparse_matrices = build_graphs(&config, in_memory_entity_mapping_persistor.clone());
info!(
"Finished Sparse Matrices calculation in {} sec",
now.elapsed().as_secs()
);
train(config, in_memory_entity_mapping_persistor, sparse_matrices);
info!("Finished in {} sec", now.elapsed().as_secs());
}