use parse_code;
use std::fs;
use std::io::{self, Read};
use std::path;
use zstd_safe;
/// Train a dictionary from continuous sample data.
///
/// `sample_data` is the concatenation of all samples; `sample_sizes` gives
/// the length of each individual sample, in order. The trained dictionary
/// is returned as a byte vector of at most `max_size` bytes.
///
/// # Errors
///
/// Returns an error if `sample_sizes` does not sum to `sample_data.len()`,
/// or if the underlying zstd trainer reports a failure.
pub fn from_continuous(
    sample_data: &[u8],
    sample_sizes: &[usize],
    max_size: usize,
) -> io::Result<Vec<u8>> {
    // Sanity check: the declared sizes must exactly partition the data.
    if sample_sizes.iter().sum::<usize>() != sample_data.len() {
        return Err(io::Error::new(
            io::ErrorKind::Other,
            "sample sizes don't add up".to_string(),
        ));
    }

    // Zero-initialize the output buffer instead of `set_len` on
    // uninitialized memory: handing uninitialized bytes to the trainer
    // through `&mut [u8]` violates `Vec::set_len`'s safety contract (UB).
    // Zeroing `max_size` bytes is cheap relative to dictionary training.
    let mut result = vec![0u8; max_size];
    let written = parse_code(zstd_safe::train_from_buffer(
        &mut result,
        sample_data,
        sample_sizes,
    ))?;
    // Keep only the bytes the trainer actually produced.
    result.truncate(written);
    Ok(result)
}
/// Train a dictionary from a slice of individual samples.
///
/// Concatenates the samples into one continuous buffer, records each
/// sample's length, and delegates to [`from_continuous`].
pub fn from_samples<S: AsRef<[u8]>>(
    samples: &[S],
    max_size: usize,
) -> io::Result<Vec<u8>> {
    let mut data = Vec::new();
    let mut sizes = Vec::new();
    for sample in samples {
        let bytes = sample.as_ref();
        data.extend_from_slice(bytes);
        sizes.push(bytes.len());
    }
    from_continuous(&data, &sizes, max_size)
}
/// Train a dictionary from the contents of the given files.
///
/// Each file becomes one training sample; everything is read into a single
/// continuous buffer and handed to [`from_continuous`].
///
/// # Errors
///
/// Propagates any I/O error from opening or reading a file, and any error
/// from the dictionary trainer itself.
pub fn from_files<I, P>(filenames: I, max_size: usize) -> io::Result<Vec<u8>>
where
    P: AsRef<path::Path>,
    I: IntoIterator<Item = P>,
{
    let mut data = Vec::new();
    let mut sizes = Vec::new();
    for name in filenames {
        // Append this file's bytes to the shared buffer and remember
        // how many were added, so samples can be delimited later.
        let size = fs::File::open(name)?.read_to_end(&mut data)?;
        sizes.push(size);
    }
    from_continuous(&data, &sizes, max_size)
}
#[cfg(test)]
mod tests {
    use std::fs;
    use std::io;
    use std::io::Read;

    /// Round-trip test: train a dictionary on this crate's own sources,
    /// then compress and decompress each file with it and compare.
    #[test]
    fn test_dict_training() {
        // Collect every .rs file under src/ as a training sample.
        let paths: Vec<_> = fs::read_dir("src")
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .filter(|path| path.to_str().unwrap().ends_with(".rs"))
            .collect();

        let dict = super::from_files(&paths, 4000).unwrap();

        for path in paths {
            let mut content = Vec::new();
            fs::File::open(path)
                .unwrap()
                .read_to_end(&mut content)
                .unwrap();

            // Compress the file using the trained dictionary; the scope
            // ensures the encoder is dropped (and flushed) before reading
            // the compressed buffer.
            let mut compressed = Vec::new();
            {
                let encoder = ::stream::Encoder::with_dictionary(
                    &mut compressed,
                    1,
                    &dict,
                ).unwrap();
                io::copy(&mut &content[..], &mut encoder.auto_finish())
                    .unwrap();
            }

            // Decompress with the same dictionary and verify the
            // round trip is lossless.
            let mut decompressed = Vec::new();
            let mut decoder = ::stream::Decoder::with_dictionary(
                &compressed[..],
                &dict[..],
            ).unwrap();
            io::copy(&mut decoder, &mut decompressed).unwrap();

            assert_eq!(&content, &decompressed);
        }
    }
}