use std::io::{Seek, Write};
use crate::compression_trait::{CompressionAlgo, DataLoaderTrait, Meta};
use crate::Endian;
use crate::{compression_trait::CompressionTrait, CHUNK_BUFF};
use flate2::write::{DeflateEncoder, GzEncoder, ZlibEncoder};
use indicatif::ProgressBar;
/// Builds the per-tensor header material for a compressed save.
///
/// Returns a pair of parallel vectors, one entry per tensor in `tensors`:
/// - a JSON fragment (with `FASTTENSOR_PLACEHOLD` / `place_holder0` markers
///   to be patched later by the writer), and
/// - a chunking layout `(name, num_chunks, lines_per_chunk, remaining_lines,
///   chunk_buffer_bytes)` describing how the tensor body is split into
///   buffers of at most `CHUNK_BUFF` bytes.
///
/// NOTE(review): assumes every tensor has rank >= 1 and a non-zero last
/// dimension (`shape().last().unwrap()` and a division by it) — confirm
/// upstream validation.
fn generate_header_compressed(
    tensors: &[Meta],
) -> (Vec<String>, Vec<(String, usize, usize, usize, usize)>) {
    // One JSON fragment per tensor; placeholders are substituted once the
    // actual file offsets are known.
    let infos: Vec<String> = tensors
        .iter()
        .map(|meta| {
            format!(
                r#""{}":{{"begin":FASTTENSOR_PLACEHOLD,"name":"{}","shape":{:?},"strides":{:?},"size":{},"dtype":"{}","compress_algo":{},"endian":{},"indices":place_holder0}}"#,
                meta.name,
                meta.name,
                meta.data_saver.shape().to_vec(),
                meta.data_saver.shape().to_strides().to_vec(),
                meta.data_saver.size(),
                meta.data_saver.dtype(),
                serde_json::to_string(&meta.compression_algo).expect("Failed to serialize compression_algo"),
                serde_json::to_string(&meta.endian).expect("Failed to serialize endian")
            )
        })
        .collect();

    // Chunk layout: a "line" is one contiguous innermost row; chunks are
    // whole numbers of lines sized to stay under CHUNK_BUFF.
    let layouts: Vec<(String, usize, usize, usize, usize)> = tensors
        .iter()
        .map(|meta| {
            let saver = &meta.data_saver;
            let last_dim = *saver.shape().last().unwrap() as usize;
            let outer = saver.size() / last_dim; // number of innermost rows
            let inner = last_dim * saver.mem_size(); // bytes per row
            let (num_chunks, num_lines, remain, buffer_size);
            if saver.size() * saver.mem_size() < CHUNK_BUFF {
                // Whole tensor fits into a single buffer.
                num_chunks = 1;
                num_lines = outer;
                remain = 0;
                buffer_size = num_lines * inner;
            } else {
                // Largest whole-line buffer strictly below CHUNK_BUFF.
                let mut bytes = ((CHUNK_BUFF - 1) / inner) * inner;
                let mut lines = bytes / inner;
                if lines == 0 {
                    // A single row exceeds the budget; take one row anyway.
                    lines = 1;
                    bytes = inner;
                }
                num_lines = lines;
                buffer_size = bytes;
                remain = outer % lines;
                num_chunks = outer / lines;
            }
            (meta.name.clone(), num_chunks, num_lines, remain, buffer_size)
        })
        .collect();

    (infos, layouts)
}
/// Serializes every tensor in `to_saves` to `path` in the FASTTENSOR
/// container format.
///
/// File layout:
///   1. 30 bytes: the literal `"FASTTENSOR"` followed by a 20-character
///      decimal offset of the trailing JSON header (written as the
///      placeholder `"FASTTENSORFASTTENSORHPLACEHOLD"` first, then patched
///      in once all data has been written).
///   2. Per-tensor compressed chunks, in tensor order. Each tensor is walked
///      row-by-row via its strides so non-contiguous layouts are linearized.
///   3. A JSON header: one object per tensor, with per-chunk
///      `(start_line, file_offset, compressed_len, raw_len)` tuples patched
///      in for `place_holder0`.
///
/// # Errors
/// Propagates any I/O error from file creation, writing, or seeking.
/// Panics if the parent directory cannot be created, or (NOTE(review)) if a
/// tensor has rank 0 — `shape().len() - 1` underflows; confirm upstream
/// validation.
pub fn save(path: std::path::PathBuf, mut to_saves: Vec<Meta>) -> std::io::Result<()> {
    // Progress bar tracks the total element count across all tensors.
    let total_size: usize = to_saves
        .iter()
        .map(|meta| meta.data_saver.size())
        .sum::<usize>();
    let pb = ProgressBar::new(total_size as u64);
    pb.set_style(
        indicatif::ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {pos:>7}/{len:7} {msg}")
            .unwrap(),
    );
    // header.0: JSON fragments (with placeholders); header.1: chunk layouts.
    let mut header = generate_header_compressed(&to_saves);
    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent).expect("Failed to create directory");
    }
    let mut file = std::fs::File::create(path)?;
    // Reserve the 30-byte magic+offset slot; patched after the data pass.
    file.write_all("FASTTENSORFASTTENSORHPLACEHOLD".as_bytes())?;
    // Running absolute file offset; chunk attributes record positions
    // relative to the start of the file (header included).
    let mut len_so_far: usize = "FASTTENSORFASTTENSORHPLACEHOLD".len();
    let mut attrs: Vec<Vec<(usize, usize, usize, usize)>> = vec![];
    for (idx, meta) in to_saves.iter_mut().enumerate() {
        let last_stride: i64 = *meta.data_saver.strides().last().unwrap() as i64;
        // prg: odometer over all but the innermost dimension, used to step
        // the data_saver's base offset one row at a time.
        let mut prg: Vec<i64> = vec![0; meta.data_saver.shape().len() - 1];
        // shape holds max index per dimension (dim size - 1).
        let mut shape: Vec<i64> = meta.data_saver.shape().iter().map(|x| *x as i64).collect();
        shape.iter_mut().for_each(|x: &mut i64| {
            *x -= 1;
        });
        let inner_loop_size: usize = *meta.data_saver.shape().last().unwrap() as usize;
        let line_num: usize = header.1[idx].2;
        let num_chunks: usize = header.1[idx].1;
        let buffer_size: usize = header.1[idx].4;
        let mut attributes = vec![];
        let mut chunk = vec![0u8; buffer_size];
        let unpack = get_unpack_closure(meta.endian);
        // Full-sized chunks: num_chunks buffers of line_num rows each.
        for k in 0..num_chunks {
            for j in 0..line_num {
                // Serialize one innermost row, element by element, honoring
                // the innermost stride.
                for i in 0..inner_loop_size {
                    let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
                    let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
                    unpack(
                        &mut meta.data_saver,
                        ((i as i64) * last_stride) as isize,
                        &mut chunk[start..end],
                    );
                }
                pb.inc(inner_loop_size as u64);
                // Advance the odometer: bump the innermost-but-one dimension
                // that has room, resetting (and rewinding the offset of) any
                // dimension that overflows.
                for h in (0..shape.len() - 1).rev() {
                    if prg[h] < shape[h] {
                        prg[h] += 1;
                        meta.data_saver
                            .offset(meta.data_saver.strides()[h] as isize);
                        break;
                    } else {
                        prg[h] = 0;
                        meta.data_saver
                            .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
                    }
                }
            }
            compress_data(
                &meta,
                &chunk,
                &mut file,
                &mut attributes,
                &mut len_so_far,
                k,
                line_num,
            )?;
        }
        // Tail chunk: the rows that did not fill a whole buffer.
        let remain_outer: usize = header.1[idx].3;
        let mut remain_chunk: Vec<u8> =
            vec![0u8; remain_outer * inner_loop_size * meta.data_saver.mem_size()];
        for j in 0..remain_outer {
            for i in 0..inner_loop_size {
                let start = (j * inner_loop_size + i) * meta.data_saver.mem_size();
                let end = (j * inner_loop_size + i + 1) * meta.data_saver.mem_size();
                unpack(
                    &mut meta.data_saver,
                    ((i as i64) * last_stride) as isize,
                    &mut remain_chunk[start..end],
                );
            }
            pb.inc(inner_loop_size as u64);
            for h in (0..shape.len() - 1).rev() {
                if prg[h] < shape[h] {
                    prg[h] += 1;
                    meta.data_saver
                        .offset(meta.data_saver.strides()[h] as isize);
                    break;
                } else {
                    prg[h] = 0;
                    meta.data_saver
                        .offset(-(meta.data_saver.strides()[h] * (shape[h] as i64)) as isize);
                }
            }
        }
        compress_data(
            &meta,
            &remain_chunk,
            &mut file,
            &mut attributes,
            &mut len_so_far,
            num_chunks,
            line_num,
        )?;
        // Patch this tensor's chunk table into its JSON fragment.
        header.0[idx] = header.0[idx].replace("place_holder0", &format!("{:?}", attributes));
        attrs.push(attributes);
    }
    // All data written: the current position is where the JSON header goes.
    // (stream_position() is the idiomatic form of seek(SeekFrom::Current(0)).)
    let current_pos = file.stream_position()?;
    file.seek(std::io::SeekFrom::Start(0))?;
    file.write_all(format!("FASTTENSOR{:20}", current_pos).as_bytes())?;
    file.seek(std::io::SeekFrom::Start(current_pos))?;
    // Compute each tensor's "begin" offset.
    // NOTE(review): each tensor's last attribute offset (`.1`) is already
    // absolute (it includes the 30-byte header and all preceding tensors'
    // data), so `header_len += .1 + .2` looks like it double-counts for the
    // second and later tensors — verify against the reader side before
    // changing.
    let mut header_len = "FASTTENSORFASTTENSORHPLACEHOLD".len();
    let length: Vec<usize> = attrs
        .iter()
        .map(|x| {
            let tmp = header_len;
            header_len += x.last().unwrap().1 + x.last().unwrap().2;
            tmp
        })
        .collect();
    header.0.iter_mut().zip(length).for_each(|(x, y)| {
        *x = x.replace("FASTTENSOR_PLACEHOLD", &format!("{:20}", y));
    });
    let to_write = format!(r#"{{{}}}"#, header.0.join(","));
    // Turn Rust tuple Debug output into JSON arrays.
    // NOTE(review): this global replacement would also rewrite parentheses
    // inside tensor names — assumes names contain none.
    let to_write = to_write.replace("(", "[");
    let to_write = to_write.replace(")", "]");
    let bytes = to_write.as_bytes();
    file.write_all(bytes)?;
    Ok(())
}
/// Encodes `chunk` with the tensor's configured algorithm, writes the result
/// to `file`, and appends a `(first_line, file_offset, compressed_len,
/// raw_len)` record to `attributes`, advancing `len_so_far` by the number of
/// bytes actually written.
fn compress_data(
    meta: &Meta,
    chunk: &[u8],
    file: &mut std::fs::File,
    attributes: &mut Vec<(usize, usize, usize, usize)>,
    len_so_far: &mut usize,
    k: usize,
    line_num: usize,
) -> std::io::Result<()> {
    // Shared tail: write the encoded bytes and record where they landed.
    fn record_and_write(
        file: &mut std::fs::File,
        attributes: &mut Vec<(usize, usize, usize, usize)>,
        len_so_far: &mut usize,
        first_line: usize,
        raw_len: usize,
        encoded: &[u8],
    ) -> std::io::Result<()> {
        file.write_all(encoded)?;
        attributes.push((first_line, *len_so_far, encoded.len(), raw_len));
        *len_so_far += encoded.len();
        Ok(())
    }

    let first_line = k * line_num;
    match meta.compression_algo {
        CompressionAlgo::Gzip => {
            let mut encoder =
                GzEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
            encoder.write_all_data(chunk)?;
            encoder.flush_all()?;
            let encoded = encoder.finish_all()?;
            record_and_write(file, attributes, len_so_far, first_line, chunk.len(), &encoded)?;
        }
        CompressionAlgo::Deflate => {
            let mut encoder =
                DeflateEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
            encoder.write_all_data(chunk)?;
            encoder.flush_all()?;
            let encoded = encoder.finish_all()?;
            record_and_write(file, attributes, len_so_far, first_line, chunk.len(), &encoded)?;
        }
        CompressionAlgo::Zlib => {
            let mut encoder =
                ZlibEncoder::new(Vec::new(), flate2::Compression::new(meta.compression_level));
            encoder.write_all_data(chunk)?;
            encoder.flush_all()?;
            let encoded = encoder.finish_all()?;
            record_and_write(file, attributes, len_so_far, first_line, chunk.len(), &encoded)?;
        }
        // No compression: the raw chunk bytes are written as-is.
        CompressionAlgo::NoCompression => {
            record_and_write(file, attributes, len_so_far, first_line, chunk.len(), chunk)?;
        }
    }
    Ok(())
}
/// Returns an endianness-specific serializer: the returned callable reads
/// one element from the data saver at the given offset and fills the
/// destination slice with its little-, big-, or native-endian bytes.
fn get_unpack_closure(endian: Endian) -> impl Fn(&mut Box<dyn DataLoaderTrait>, isize, &mut [u8]) {
    // All three closures capture nothing, so they coerce to one plain fn
    // pointer type, which itself implements `Fn` and satisfies the opaque
    // return type.
    let fill: fn(&mut Box<dyn DataLoaderTrait>, isize, &mut [u8]) = match endian {
        Endian::Little => |saver, off, dst| saver.fill_le_bytes_slice(off, dst),
        Endian::Big => |saver, off, dst| saver.fill_be_bytes_slice(off, dst),
        Endian::Native => |saver, off, dst| saver.fill_ne_bytes_slice(off, dst),
    };
    fill
}