use libc::{c_uint, c_void};
use parse_code;
use std::fs;
use std::io::{self, Read};
use std::path;
use zstd_sys;
pub fn from_continuous(sample_data: &[u8], sample_sizes: &[usize],
max_size: usize)
-> io::Result<Vec<u8>> {
if sample_sizes.iter().sum::<usize>() != sample_data.len() {
return Err(io::Error::new(io::ErrorKind::Other,
"sample sizes don't add up".to_string()));
}
let mut result = Vec::with_capacity(max_size);
unsafe {
let result_ptr = result.as_mut_ptr() as *mut c_void;
let sample_ptr = sample_data.as_ptr() as *const c_void;
let code = zstd_sys::ZDICT_trainFromBuffer(result_ptr,
result.capacity(),
sample_ptr,
sample_sizes.as_ptr(),
sample_sizes.len() as
c_uint);
let written = try!(parse_code(code));
result.set_len(written);
}
Ok(result)
}
/// Trains a dictionary from a list of separate samples.
///
/// The samples are copied one after another into a single buffer, their
/// lengths are recorded in the same order, and both are handed to
/// `from_continuous` together with `max_size`.
pub fn from_samples<S: AsRef<[u8]>>(samples: &[S], max_size: usize)
                                    -> io::Result<Vec<u8>> {
    let mut lengths = Vec::with_capacity(samples.len());
    let mut joined = Vec::new();

    for sample in samples {
        let bytes: &[u8] = sample.as_ref();
        lengths.push(bytes.len());
        joined.extend_from_slice(bytes);
    }

    from_continuous(&joined, &lengths, max_size)
}
pub fn from_files<I, P>(filenames: I, max_size: usize) -> io::Result<Vec<u8>>
where P: AsRef<path::Path>,
I: IntoIterator<Item = P>
{
let mut buffer = Vec::new();
let mut sizes = Vec::new();
for filename in filenames {
let mut file = try!(fs::File::open(filename));
let len = try!(file.read_to_end(&mut buffer));
sizes.push(len);
}
from_continuous(&buffer, &sizes, max_size)
}
#[cfg(test)]
mod tests {
    use std::fs;
    use std::io::{self, Read};

    /// Trains a dictionary on the `.rs` files found in `src`, then checks
    /// that each of those files survives a compress/decompress round trip
    /// using that dictionary.
    #[test]
    fn test_dict_training() {
        // Collect every entry of `src` whose path ends with ".rs".
        let mut paths = Vec::new();
        for entry in fs::read_dir("src").unwrap() {
            let path = entry.unwrap().path();
            if path.to_str().unwrap().ends_with(".rs") {
                paths.push(path);
            }
        }

        let dict = super::from_files(&paths, 4000).unwrap();

        for path in paths {
            let mut original = Vec::new();
            fs::File::open(path)
                .unwrap()
                .read_to_end(&mut original)
                .unwrap();

            // Compress with the dictionary; the encoder is scoped so that it
            // is dropped before the compressed bytes are read back.
            let mut compressed = Vec::new();
            {
                let mut encoder =
                    ::stream::Encoder::with_dictionary(&mut compressed,
                                                       1,
                                                       &dict)
                        .unwrap()
                        .auto_finish();
                io::copy(&mut &original[..], &mut encoder).unwrap();
            }

            // Decompress with the same dictionary.
            let mut restored = Vec::new();
            let mut decoder =
                ::stream::Decoder::with_dictionary(&compressed[..],
                                                   &dict[..])
                    .unwrap();
            io::copy(&mut decoder, &mut restored).unwrap();

            assert_eq!(&original, &restored);
        }
    }
}