use curl::easy::Easy;
use flate2::read::GzDecoder;
use linfa_preprocessing::tf_idf_vectorization::TfIdfVectorizer;
use linfa_preprocessing::CountVectorizer;
use std::path::Path;
use tar::Archive;
/// Downloads the 20 Newsgroups "bydate" archive over HTTP and unpacks it
/// into `./20news/`.
///
/// The tarball is fetched fully into memory, gunzipped, and extracted.
/// Panics on any network or I/O failure — acceptable for a one-shot
/// benchmark setup step.
fn download_20news_bydate() {
    let mut buffer = Vec::new();
    let mut handle = Easy::new();
    handle
        .url("http://qwone.com/~jason/20Newsgroups/20news-bydate.tar.gz")
        .unwrap();
    // Scope the transfer so its mutable borrow of `buffer` ends before the
    // buffer is read back below.
    {
        let mut transfer = handle.transfer();
        transfer
            .write_function(|chunk| {
                buffer.extend_from_slice(chunk);
                Ok(chunk.len())
            })
            .unwrap();
        transfer.perform().unwrap();
    }
    // Decompress the gzip stream and unpack the contained tar archive.
    let decoder = GzDecoder::new(buffer.as_slice());
    Archive::new(decoder).unpack("./20news/").unwrap();
}
/// Returns the paths of the test-split documents for four fixed newsgroups,
/// downloading and unpacking the dataset first if it is not already on disk.
fn load_20news_bydate() -> Vec<std::path::PathBuf> {
    let desired_targets = [
        "alt.atheism",
        "talk.religion.misc",
        "comp.graphics",
        "sci.space",
    ];
    match load_test_set(&desired_targets) {
        Ok(paths) => paths,
        // Dataset missing or unreadable: fetch it once, then retry.
        Err(_) => {
            download_20news_bydate();
            load_test_set(&desired_targets).unwrap()
        }
    }
}
/// Collects the paths of all regular files found in the immediate
/// sub-directories of `path`.
///
/// `desired_targets` filters sub-directories by name; an empty slice means
/// "accept every sub-directory". Sub-directory names that are not valid
/// UTF-8 can never match a `&str` target and are skipped.
///
/// # Errors
/// Returns any I/O error raised while reading the directory tree (previously
/// `file_type()` failures panicked via `unwrap` even though this function
/// already returns `Result`; they are now propagated with `?`).
fn load_set(
    path: &'static str,
    desired_targets: &[&str],
) -> Result<Vec<std::path::PathBuf>, std::io::Error> {
    // Borrow the target names instead of allocating owned `String`s.
    let desired_targets: std::collections::HashSet<&str> =
        desired_targets.iter().copied().collect();
    let mut file_paths = Vec::new();
    for sub_dir in std::fs::read_dir(Path::new(path))? {
        let sub_dir = sub_dir?;
        if !sub_dir.file_type()?.is_dir() {
            continue;
        }
        if !desired_targets.is_empty() {
            let name = sub_dir.file_name();
            // Non-UTF-8 names cannot match any `&str` target; skip rather
            // than panic (the old `into_string().unwrap()` aborted here).
            match name.to_str() {
                Some(name) if desired_targets.contains(name) => {}
                _ => continue,
            }
        }
        for sub_file in std::fs::read_dir(sub_dir.path())? {
            let sub_file = sub_file?;
            if sub_file.file_type()?.is_file() {
                file_paths.push(sub_file.path());
            }
        }
    }
    Ok(file_paths)
}
/// Returns the file paths of the training split (`20news-bydate-train`),
/// restricted to the newsgroup directories named in `desired_targets`.
/// Currently unused by these benchmarks — hence the leading underscore.
fn _load_train_set(desired_targets: &[&str]) -> Result<Vec<std::path::PathBuf>, std::io::Error> {
    load_set("./20news/20news-bydate-train", desired_targets)
}
/// Returns the file paths of the test split (`20news-bydate-test`),
/// restricted to the newsgroup directories named in `desired_targets`.
fn load_test_set(desired_targets: &[&str]) -> Result<Vec<std::path::PathBuf>, std::io::Error> {
    load_set("./20news/20news-bydate-test", desired_targets)
}
/// Benchmark body: fit a `CountVectorizer` on the downloaded test documents.
fn iai_fit_vectorizer() {
    let file_names = load_20news_bydate();
    // Name the parameter set before fitting; the configuration matches the
    // other vectorizer benchmarks in this file.
    let params = CountVectorizer::params()
        .document_frequency(0.05, 0.75)
        .n_gram_range(1, 2);
    params
        .fit_files(
            &file_names,
            encoding::all::ISO_8859_1,
            encoding::DecoderTrap::Strict,
        )
        .unwrap();
}
/// Benchmark body: fit a `TfIdfVectorizer` on the downloaded test documents.
fn iai_fit_tf_idf() {
    let file_names = load_20news_bydate();
    // Same document-frequency / n-gram configuration as the count vectorizer
    // benchmarks, so timings are directly comparable.
    let params = TfIdfVectorizer::default()
        .document_frequency(0.05, 0.75)
        .n_gram_range(1, 2);
    params
        .fit_files(
            &file_names,
            encoding::all::ISO_8859_1,
            encoding::DecoderTrap::Strict,
        )
        .unwrap();
}
/// Benchmark body: fit a `CountVectorizer`, then transform the same files.
fn iai_fit_transform_vectorizer() {
    let file_names = load_20news_bydate();
    let fitted = CountVectorizer::params()
        .document_frequency(0.05, 0.75)
        .n_gram_range(1, 2)
        .fit_files(
            &file_names,
            encoding::all::ISO_8859_1,
            encoding::DecoderTrap::Strict,
        )
        .unwrap();
    // The transformed matrix is discarded; only the work itself is measured.
    fitted.transform_files(
        &file_names,
        encoding::all::ISO_8859_1,
        encoding::DecoderTrap::Strict,
    );
}
/// Benchmark body: fit a `TfIdfVectorizer`, then transform the same files.
fn iai_fit_transform_tf_idf() {
    let file_names = load_20news_bydate();
    let fitted = TfIdfVectorizer::default()
        .document_frequency(0.05, 0.75)
        .n_gram_range(1, 2)
        .fit_files(
            &file_names,
            encoding::all::ISO_8859_1,
            encoding::DecoderTrap::Strict,
        )
        .unwrap();
    // The transformed matrix is discarded; only the work itself is measured.
    fitted.transform_files(
        &file_names,
        encoding::all::ISO_8859_1,
        encoding::DecoderTrap::Strict,
    );
}
// Generates the benchmark harness entry point.
//
// For each listed function this macro emits a thin wrapper in the
// `iai_wrappers` module that calls the function through `iai::black_box`
// (so the optimizer cannot elide the benchmarked work), then emits a
// `main` that:
//   1. downloads/unpacks the dataset once up front,
//   2. hands the named wrappers to `iai::runner`,
//   3. removes the dataset directory afterwards, ignoring any error.
macro_rules! iai_main {
    ( $( $func_name:ident ),+ $(,)* ) => {
        mod iai_wrappers {
            $(
                pub fn $func_name() {
                    let _ = iai::black_box(super::$func_name());
                }
            )+
        }
        fn main() {
            load_20news_bydate();
            let benchmarks : &[&(&'static str, fn())]= &[
                $(
                    &(stringify!($func_name), iai_wrappers::$func_name),
                )+
            ];
            iai::runner(benchmarks);
            std::fs::remove_dir_all("./20news").unwrap_or(());
        }
    }
}
// Register the four vectorizer benchmarks and generate `main`.
iai_main!(
    iai_fit_vectorizer,
    iai_fit_transform_vectorizer,
    iai_fit_tf_idf,
    iai_fit_transform_tf_idf
);