use crate::transform::{Mapper, MapperDataset};
use crate::{Dataset, InMemDataset};
use encoding_rs::{GB18030, GBK, UTF_8, UTF_16BE, UTF_16LE};
use globwalk::{self, DirEntry};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io::Read;
use std::path::{Path, PathBuf};
use thiserror::Error;
/// File extensions accepted by the default classification constructor.
const SUPPORTED_FILES: [&str; 1] = ["txt"];
/// Decoded text content together with the path of the file it was read from.
#[derive(Debug, Clone, PartialEq)]
pub struct TextData {
    /// The decoded text content of the file.
    pub text: String,
    /// The source file path, as a display string.
    pub text_path: String,
}
/// A single classification sample: the loaded text plus its label index.
#[derive(Debug, Clone, PartialEq)]
pub struct TextDatasetItem {
    /// The text content and source path.
    pub text: TextData,
    /// The label index of the sample's class (see `TextFolderDataset`).
    pub label: usize,
}
/// Internal, not-yet-loaded sample: a file path and its class name.
/// The file is only read when the item is mapped (see `PathToTextDatasetItem`).
#[derive(Debug, Clone)]
struct TextDatasetItemRaw {
    // Path to the text file on disk.
    text_path: PathBuf,
    // Class name (folder name or user-supplied label), not yet an index.
    label: String,
}
impl TextDatasetItemRaw {
    /// Builds a raw dataset entry from a path-like value and its class name.
    fn new<P: AsRef<Path>>(text_path: P, label: String) -> TextDatasetItemRaw {
        let text_path = text_path.as_ref().to_path_buf();
        TextDatasetItemRaw { text_path, label }
    }
}
/// Mapper state: resolves a class name to its label index when loading items.
struct PathToTextDatasetItem {
    // Class name -> label index.
    classes: HashMap<String, usize>,
}
fn parse_text_content(text_path: &PathBuf) -> String {
let mut file = fs::File::open(text_path).unwrap();
let mut bytes = Vec::new();
file.read_to_end(&mut bytes).unwrap();
if bytes.starts_with(&[0xEF, 0xBB, 0xBF]) && bytes.len() >= 3 {
let (result, _, had_errors) = UTF_8.decode(&bytes[3..]);
if !had_errors {
return result.into_owned();
}
}
let (result, _, had_errors) = UTF_8.decode(&bytes);
if !had_errors {
return result.into_owned();
}
if bytes.starts_with(&[0xFF, 0xFE]) && bytes.len() >= 2 {
let (result, had_errors) = UTF_16LE.decode_with_bom_removal(&bytes[2..]);
if !had_errors {
return result.into_owned();
}
}
if bytes.starts_with(&[0xFE, 0xFF]) && bytes.len() >= 2 {
let (result, had_errors) = UTF_16BE.decode_with_bom_removal(&bytes[2..]);
if !had_errors {
return result.into_owned();
}
}
let (result, _, had_errors) = GB18030.decode(&bytes);
if !had_errors {
return result.into_owned();
}
let (result, _, had_errors) = GBK.decode(&bytes);
if !had_errors {
return result.into_owned();
}
String::from_utf8_lossy(&bytes).to_string()
}
impl Mapper<TextDatasetItemRaw, TextDatasetItem> for PathToTextDatasetItem {
    /// Converts a raw (path, class-name) item into a fully loaded dataset item,
    /// reading and decoding the file contents on the fly.
    fn map(&self, item: &TextDatasetItemRaw) -> TextDatasetItem {
        let text = TextData {
            text: parse_text_content(&item.text_path),
            text_path: item.text_path.display().to_string(),
        };
        // The class map is built from these same labels, so the lookup cannot fail.
        TextDatasetItem {
            text,
            label: *self.classes.get(&item.label).unwrap(),
        }
    }
}
/// Errors that can occur while building a [`TextFolderDataset`].
#[derive(Error, Debug)]
pub enum TextLoaderError {
    /// Catch-all for unexpected failures (e.g. glob construction).
    #[error("unknown: `{0}`")]
    Unknown(String),
    /// Filesystem/path resolution failure.
    #[error("I/O error: `{0}`")]
    IOError(String),
    /// A file extension that is not in the supported list.
    #[error("Invalid file extension: `{0}`")]
    InvalidFileExtensionError(String),
    /// Text decoding failure.
    #[error("Encoding error: `{0}`")]
    EncodingError(String),
}
/// In-memory dataset of raw (path, label) items, mapped on access into fully
/// loaded [`TextDatasetItem`]s.
type TextDatasetMapper =
    MapperDataset<InMemDataset<TextDatasetItemRaw>, PathToTextDatasetItem, TextDatasetItemRaw>;
/// A text classification dataset built from a folder structure where each
/// subfolder name is a class, or from explicit (path, label) pairs.
pub struct TextFolderDataset {
    // Lazily-mapping backing dataset; files are read on `get`.
    dataset: TextDatasetMapper,
}
impl Dataset<TextDatasetItem> for TextFolderDataset {
    /// Returns the item at `index`, or `None` when out of bounds.
    fn get(&self, index: usize) -> Option<TextDatasetItem> {
        self.dataset.get(index)
    }
    /// Returns the number of items in the dataset.
    fn len(&self) -> usize {
        self.dataset.len()
    }
}
impl TextFolderDataset {
    /// Creates a text classification dataset from a root folder, accepting the
    /// default supported extensions. The class of each file is the name of its
    /// parent folder.
    ///
    /// # Errors
    /// Returns an error if the folder walk fails.
    pub fn new_classification<P: AsRef<Path>>(root: P) -> Result<Self, TextLoaderError> {
        TextFolderDataset::new_classification_with(root, &SUPPORTED_FILES)
    }

    /// Creates a text classification dataset from a root folder, restricted to
    /// the given file extensions.
    ///
    /// # Errors
    /// Returns an error if an extension is unsupported or the folder walk fails.
    pub fn new_classification_with<P, S>(root: P, extensions: &[S]) -> Result<Self, TextLoaderError>
    where
        P: AsRef<Path>,
        S: AsRef<str>,
    {
        // Validate every extension up front and build a single glob such as `*.{txt}`.
        let pattern = format!(
            "*.{{{}}}",
            extensions
                .iter()
                .map(Self::check_extension)
                .collect::<Result<Vec<_>, _>>()?
                .join(",")
        );
        let walker = globwalk::GlobWalkerBuilder::from_patterns(root.as_ref(), &[pattern])
            .follow_links(true)
            // Sort by path so item order (and label discovery) is deterministic.
            .sort_by(|p1: &DirEntry, p2: &DirEntry| p1.path().cmp(p2.path()))
            .build()
            .map_err(|err| TextLoaderError::Unknown(format!("{err:?}")))?
            .filter_map(Result::ok);

        let mut items = Vec::new();
        let mut classes = HashSet::new();
        for text in walker {
            let text_path = text.path();
            // The class name is the file's parent folder name.
            let label = text_path
                .parent()
                .ok_or_else(|| {
                    TextLoaderError::IOError("Could not resolve text parent folder".to_string())
                })?
                .file_name()
                .ok_or_else(|| {
                    TextLoaderError::IOError(
                        "Could not resolve text parent folder name".to_string(),
                    )
                })?
                .to_string_lossy()
                .into_owned();
            classes.insert(label.clone());
            items.push(TextDatasetItemRaw::new(text_path, label))
        }

        // Sort class names so label indices are stable across runs.
        let mut classes = classes.into_iter().collect::<Vec<_>>();
        classes.sort();
        Self::with_items(items, &classes)
    }

    /// Creates a text classification dataset from explicit (path, label) pairs.
    /// Label indices follow the order of `classes`.
    ///
    /// # Errors
    /// Returns an error if a path has a missing, non-UTF-8, or unsupported
    /// extension.
    pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
        items: Vec<(P, String)>,
        classes: &[S],
    ) -> Result<Self, TextLoaderError> {
        let items = items
            .into_iter()
            .map(|(path, label)| {
                let path = path.as_ref();
                // Reject missing or non-UTF-8 extensions instead of panicking.
                let extension = path.extension().and_then(|ext| ext.to_str()).ok_or_else(
                    || TextLoaderError::InvalidFileExtensionError(path.display().to_string()),
                )?;
                Self::check_extension(&extension)?;
                Ok(TextDatasetItemRaw::new(path, label))
            })
            .collect::<Result<Vec<_>, _>>()?;
        Self::with_items(items, classes)
    }

    /// Builds the dataset from raw items, mapping each class name to its label
    /// index (its position in `classes`).
    fn with_items<S: AsRef<str>>(
        items: Vec<TextDatasetItemRaw>,
        classes: &[S],
    ) -> Result<Self, TextLoaderError> {
        let dataset = InMemDataset::new(items);
        // Class name -> label index, keyed by position in `classes`.
        let classes_map: HashMap<_, _> = classes
            .iter()
            .enumerate()
            .map(|(idx, cls)| (cls.as_ref().to_string(), idx))
            .collect();
        let mapper = PathToTextDatasetItem {
            classes: classes_map,
        };
        let dataset = MapperDataset::new(dataset, mapper);
        Ok(Self { dataset })
    }

    /// Validates that a file extension is supported, returning it as an owned
    /// `String` on success.
    fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, TextLoaderError> {
        let extension = extension.as_ref();
        if SUPPORTED_FILES.contains(&extension) {
            Ok(extension.to_string())
        } else {
            Err(TextLoaderError::InvalidFileExtensionError(
                extension.to_string(),
            ))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    // Root of the on-disk test fixtures: one subfolder per class.
    const TEXT_ROOT: &str = "tests/data/text_folder";

    /// Walks the fixture folder and checks that both classes are discovered,
    /// labels follow sorted class-name order ("negative" = 0, "positive" = 1),
    /// and file contents are actually loaded.
    #[test]
    fn test_text_folder_dataset() {
        let dataset = TextFolderDataset::new_classification(TEXT_ROOT).unwrap();
        assert_eq!(dataset.len(), 4);
        // Out-of-bounds access returns `None` rather than panicking.
        assert_eq!(dataset.get(4), None);
        let mut found_positive = false;
        let mut found_negative = false;
        for i in 0..dataset.len() {
            let item = dataset.get(i).unwrap();
            if item.label == 0 {
                found_negative = true;
                assert!(!item.text.text.is_empty());
                // The stored path should point inside the class subfolder.
                assert!(item.text.text_path.contains("negative"));
            } else if item.label == 1 {
                found_positive = true;
                assert!(!item.text.text.is_empty());
                assert!(item.text.text_path.contains("positive"));
            }
        }
        assert!(found_positive);
        assert!(found_negative);
    }

    /// An unsupported extension must be rejected before walking the folder.
    #[test]
    fn test_text_folder_dataset_with_invalid_extension() {
        let result = TextFolderDataset::new_classification_with(TEXT_ROOT, &["invalid"]);
        assert!(result.is_err());
    }

    /// Builds a dataset from explicit (path, label) pairs; label indices here
    /// follow the order of `classes` ("positive" = 0, "negative" = 1), not
    /// sorted order.
    #[test]
    fn test_text_folder_dataset_with_items() {
        let root = Path::new(TEXT_ROOT);
        let items = vec![
            (
                root.join("positive").join("sample1.txt"),
                "positive".to_string(),
            ),
            (
                root.join("negative").join("sample2.txt"),
                "negative".to_string(),
            ),
        ];
        let classes = vec!["positive", "negative"];
        let dataset = TextFolderDataset::new_classification_with_items(items, &classes).unwrap();
        assert_eq!(dataset.len(), 2);
        assert_eq!(dataset.get(2), None);
        let item0 = dataset.get(0).unwrap();
        let item1 = dataset.get(1).unwrap();
        assert!(compare_item(
            &item0,
            &(
                "This is a positive text sample for testing the text folder dataset functionality."
                    .to_string(),
                0
            )
        ));
        assert_eq!(item1.label, 1);
        assert!(item1.text.text_path.contains("negative"));
        // The negative sample contains Chinese text and must decode intact.
        assert!(compare_item(
            &item1,
            &(
                "另一个负面文本样本,用以确保数据集能够处理同一类别中的多个文件。".to_string(),
                1
            )
        ));
    }

    /// Returns `true` when the item's text and label both match `target`.
    fn compare_item(item: &TextDatasetItem, target: &(String, usize)) -> bool {
        item.text.text == target.0 && item.label == target.1
    }
}