use std::{path::PathBuf, sync::Mutex};
use flate2::read::GzDecoder;
use serde::{Deserialize, Serialize};
use tar::Archive;
use crate::InMemDataset;
use crate::network::downloader;
/// Remote location of the gzipped tarball that holds the AG News CSV splits.
const AG_NEWS_URL: &str =
    "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz";
/// A single record from the AG News CSV files.
///
/// The CSV files are read without a header row (see `AgNewsDataset::parse_csv`),
/// so each record's columns are matched to these fields by position:
/// label, then title, then content.
#[derive(Deserialize, Serialize, Debug, Clone, PartialEq, Eq)]
pub struct AgNewsItem {
    /// Class label exactly as it appears in the CSV's first column (e.g. `"3"`).
    /// It is kept as a string and not parsed into a number.
    pub label: String,
    /// Title text from the CSV's second column.
    pub title: String,
    /// Article text from the CSV's third column.
    pub content: String,
}
/// Handle to the AG News dataset cached on disk.
///
/// Building one with `AgNewsDataset::new` ensures the archive has been downloaded
/// and extracted. The train and test splits are then parsed from their CSV files
/// on every call to `train` / `test`.
#[derive(Debug, Clone)]
pub struct AgNewsDataset {
    /// Directory that holds the extracted `train.csv` and `test.csv` files.
    agnews_dir: PathBuf,
}
/// Process-wide lock held for the whole of `AgNewsDataset::download`.
///
/// It stops concurrent callers in the same process (e.g. tests running in
/// parallel) from fetching and unpacking the archive at the same time. It
/// provides no coordination between separate processes.
static DOWNLOAD_LOCK: Mutex<()> = Mutex::new(());
impl AgNewsDataset {
    /// Creates a handle to the AG News dataset.
    ///
    /// On first use, this downloads the archive and extracts it under
    /// `<cache dir>/burn-dataset/ag_news_csv`.
    ///
    /// # Panics
    ///
    /// Panics if the user cache directory cannot be determined or the archive
    /// cannot be extracted or moved into place. It may also panic if the
    /// download itself fails.
    pub fn new() -> Self {
        Self {
            agnews_dir: Self::download(),
        }
    }

    /// Ensures the AG News CSV files are present in the cache and returns the
    /// path of the extracted `ag_news_csv` directory.
    ///
    /// The archive is unpacked into a staging directory first. It is renamed to
    /// its final location only after extraction succeeds. A failed or
    /// interrupted extraction therefore never leaves a half-populated
    /// `ag_news_csv` directory that later calls would treat as a complete cache.
    fn download() -> PathBuf {
        // The lock guards `()`, so poisoning (an earlier caller panicked while
        // holding it) leaves no broken state behind. Recover the guard so one
        // failed download does not make every later caller panic too.
        let _lock = DOWNLOAD_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        let cache_dir = dirs::cache_dir()
            .expect("Could not get cache directory")
            .join("burn-dataset");
        let agnews_dir = cache_dir.join("ag_news_csv");

        if !agnews_dir.exists() {
            let bytes = downloader::download_file_as_bytes(AG_NEWS_URL, "ag_news_csv.tgz");
            let mut archive = Archive::new(GzDecoder::new(&bytes[..]));

            // Start from an empty staging directory in case an earlier attempt
            // was interrupted partway through unpacking into it.
            let staging_dir = cache_dir.join("ag_news_csv.partial");
            if staging_dir.exists() {
                std::fs::remove_dir_all(&staging_dir)
                    .expect("Failed to remove stale AG News staging directory");
            }
            archive
                .unpack(&staging_dir)
                .expect("Failed to extract AG News archive");

            // The archive's top-level `ag_news_csv` directory is moved into its
            // final location in one rename, which stays within `cache_dir`.
            std::fs::rename(staging_dir.join("ag_news_csv"), &agnews_dir)
                .expect("Failed to move extracted AG News files into the cache");

            // Best-effort cleanup: the dataset is already in place, so leftover
            // staging files are harmless and must not fail the call.
            let _ = std::fs::remove_dir_all(&staging_dir);
        }

        agnews_dir
    }

    /// Parses a header-less AG News CSV file into an in-memory dataset.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read or a record fails to deserialize.
    fn parse_csv(file_path: &str) -> InMemDataset<AgNewsItem> {
        let mut builder = csv::ReaderBuilder::new();
        // The AG News CSVs have no header row, so the first line is data.
        builder.has_headers(false);
        InMemDataset::from_csv(file_path, &builder).expect("Failed to parse CSV file")
    }

    /// Loads the training split from `train.csv`.
    ///
    /// # Panics
    ///
    /// Panics if the cached path is not valid UTF-8, or if the file cannot be
    /// read or parsed.
    pub fn train(&self) -> InMemDataset<AgNewsItem> {
        self.load_split("train.csv")
    }

    /// Loads the test split from `test.csv`.
    ///
    /// # Panics
    ///
    /// Panics if the cached path is not valid UTF-8, or if the file cannot be
    /// read or parsed.
    pub fn test(&self) -> InMemDataset<AgNewsItem> {
        self.load_split("test.csv")
    }

    /// Parses the CSV file `file_name` from the extracted dataset directory.
    fn load_split(&self, file_name: &str) -> InMemDataset<AgNewsItem> {
        let file_path = self.agnews_dir.join(file_name);
        let file_path = file_path
            .to_str()
            .expect("AG News cache path is not valid UTF-8");
        Self::parse_csv(file_path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dataset;

    /// Expected number of records in `train.csv`.
    const TRAIN_DATASET_LEN: usize = 120000;
    /// Expected number of records in `test.csv`.
    const TEST_DATASET_LEN: usize = 7600;

    #[test]
    fn test_agnews_download() {
        let agnews_dir = AgNewsDataset::download();
        assert!(agnews_dir.exists());
    }

    #[test]
    fn test_agnews_len() {
        let agnews = AgNewsDataset::new();
        let train_dataset = agnews.train();
        let test_dataset = agnews.test();
        assert_eq!(train_dataset.len(), TRAIN_DATASET_LEN);
        assert_eq!(test_dataset.len(), TEST_DATASET_LEN);
    }

    #[test]
    fn test_agnews_first_and_last_item() {
        let agnews = AgNewsDataset::new();

        let train_dataset = agnews.train();
        let first_item = train_dataset.get(0).expect("train split is empty");
        let last_item = train_dataset
            .get(train_dataset.len() - 1)
            .expect("train split has no last item");
        assert_item(
            &first_item,
            "3",
            "Wall St. Bears Claw Back Into the Black (Reuters)",
            "Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.",
        );
        assert_item(
            &last_item,
            "2",
            "Nets get Carter from Raptors",
            "INDIANAPOLIS -- All-Star Vince Carter was traded by the Toronto Raptors to the New Jersey Nets for Alonzo Mourning, Eric Williams, Aaron Williams, and a pair of first-round draft picks yesterday.",
        );

        let test_dataset = agnews.test();
        let first_item = test_dataset.get(0).expect("test split is empty");
        let last_item = test_dataset
            .get(test_dataset.len() - 1)
            .expect("test split has no last item");
        assert_item(
            &first_item,
            "3",
            "Fears for T N pension after talks",
            "Unions representing workers at Turner Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
        );
        assert_item(
            &last_item,
            "3",
            "EBay gets into rentals",
            "EBay plans to buy the apartment and home rental service Rent.com for \\$415 million, adding to its already exhaustive breadth of offerings.",
        );
    }

    /// Asserts each field of `item` separately. A failure then names the field
    /// that differs and shows both values. `#[track_caller]` makes the panic
    /// point at the calling line instead of this helper.
    #[track_caller]
    fn assert_item(item: &AgNewsItem, label: &str, title: &str, content: &str) {
        assert_eq!(item.label, label, "label mismatch");
        assert_eq!(item.title, title, "title mismatch");
        assert_eq!(item.content, content, "content mismatch");
    }
}