use std::fs::File;
use std::io::{Read, Write};
use crate::errors::error::{SurrealError, SurrealErrorStatus};
use crate::storage::header::Header;
use crate::{safe_eject, safe_eject_internal};
/// An in-memory `.surml` file: a [`Header`] followed by the raw model bytes.
///
/// The serialized layout produced by [`SurMlFile::to_bytes`] and consumed by
/// [`SurMlFile::from_bytes`] / [`SurMlFile::from_file`] is:
///
/// ```text
/// [ header length N: 4 bytes, big-endian ][ header: N bytes ][ model: remaining bytes ]
/// ```
pub struct SurMlFile {
    /// The file header, serialized ahead of the model bytes.
    pub header: Header,
    /// The raw model bytes stored after the header (the tests store ONNX models here).
    pub model: Vec<u8>,
}

impl SurMlFile {
    /// Wraps `model` together with a header created by [`Header::fresh`].
    pub fn fresh(model: Vec<u8>) -> Self {
        Self {
            header: Header::fresh(),
            model,
        }
    }

    /// Builds a file from an existing `header` and the raw `model` bytes.
    pub fn new(header: Header, model: Vec<u8>) -> Self {
        Self { header, model }
    }

    /// Parses a `.surml` file from an in-memory byte buffer.
    ///
    /// # Errors
    ///
    /// Returns a `BadRequest` error if the buffer is shorter than the 4-byte length
    /// prefix, or if the prefix claims more header bytes than the buffer contains.
    /// Errors from [`Header::from_bytes`] are propagated unchanged.
    pub fn from_bytes(bytes: Vec<u8>) -> Result<Self, SurrealError> {
        if bytes.len() < 4 {
            return Err(SurrealError::new(
                "Not enough bytes to read".to_string(),
                SurrealErrorStatus::BadRequest,
            ));
        }
        let mut len_prefix = [0u8; 4];
        len_prefix.copy_from_slice(&bytes[..4]);
        let header_len = u32::from_be_bytes(len_prefix) as usize;
        // `checked_add` guards against `usize` overflow on 32-bit targets, where a
        // corrupt prefix near `u32::MAX` would otherwise wrap and panic when slicing.
        let header_end = match header_len.checked_add(4) {
            Some(end) if end <= bytes.len() => end,
            _ => return Err(Self::truncated_header_error()),
        };
        let header = Header::from_bytes(bytes[4..header_end].to_vec())?;
        let model = bytes[header_end..].to_vec();
        Ok(Self { header, model })
    }

    /// Reads and parses a `.surml` file from disk.
    ///
    /// # Errors
    ///
    /// Returns a `NotFound` error if the file cannot be opened, and a `BadRequest`
    /// error if reading fails or the file is shorter than its length prefix
    /// claims. Errors from [`Header::from_bytes`] are propagated unchanged.
    pub fn from_file(file_path: &str) -> Result<Self, SurrealError> {
        let mut file = safe_eject!(File::open(file_path), SurrealErrorStatus::NotFound);
        let mut len_prefix = [0u8; 4];
        safe_eject!(file.read_exact(&mut len_prefix), SurrealErrorStatus::BadRequest);
        let header_len = u64::from(u32::from_be_bytes(len_prefix));
        // Read the header through `take` instead of pre-allocating `header_len` bytes:
        // a corrupt prefix could otherwise demand a buffer of up to 4 GiB for a file
        // that is only a few bytes long.
        let mut header_buffer = Vec::new();
        safe_eject!(
            file.by_ref().take(header_len).read_to_end(&mut header_buffer),
            SurrealErrorStatus::BadRequest
        );
        if header_buffer.len() as u64 != header_len {
            return Err(Self::truncated_header_error());
        }
        // Read the rest from the `File` itself (not through a `Take` adaptor) so that
        // `File`'s own `read_to_end` implementation is used for the model bytes.
        let mut model = Vec::new();
        safe_eject!(file.read_to_end(&mut model), SurrealErrorStatus::BadRequest);
        let header = Header::from_bytes(header_buffer)?;
        Ok(Self { header, model })
    }

    /// Serializes the file as `[4-byte big-endian header length][header][model]`.
    pub fn to_bytes(&self) -> Vec<u8> {
        let (header_len, header_bytes) = self.header.to_bytes();
        let mut combined = Vec::with_capacity(4 + header_bytes.len() + self.model.len());
        combined.extend_from_slice(&i32::to_be_bytes(header_len));
        combined.extend_from_slice(&header_bytes);
        combined.extend_from_slice(&self.model);
        combined
    }

    /// Writes the serialized file to `file_path`, creating or truncating it.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created or fully written.
    pub fn write(&self, file_path: &str) -> Result<(), SurrealError> {
        let combined = self.to_bytes();
        let mut file = safe_eject_internal!(File::create(file_path));
        // `write_all` loops until every byte is written; a single `write` call may
        // perform a short write and silently truncate the model on disk.
        safe_eject_internal!(file.write_all(&combined));
        Ok(())
    }

    /// Error returned when the length prefix promises more header bytes than exist.
    fn truncated_header_error() -> SurrealError {
        SurrealError::new(
            "Not enough bytes to read for header, maybe the file format is not correct"
                .to_string(),
            SurrealErrorStatus::BadRequest,
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads a fixture file (e.g. an ONNX model from `./stash`) into memory.
    fn read_fixture(path: &str) -> Vec<u8> {
        let mut file = File::open(path).unwrap();
        let mut model_bytes = Vec::new();
        file.read_to_end(&mut model_bytes).unwrap();
        model_bytes
    }

    #[test]
    fn test_write() {
        let mut header = Header::fresh();
        header.add_column(String::from("squarefoot"));
        header.add_column(String::from("num_floors"));
        header.add_output(String::from("house_price"), None);
        let model_bytes = read_fixture("./stash/linear_test.onnx");
        let surml_file = SurMlFile::new(header, model_bytes);
        surml_file.write("./stash/test.surml").unwrap();
        let loaded = SurMlFile::from_file("./stash/test.surml").unwrap();
        // The model bytes must survive a write/read round trip unchanged.
        assert_eq!(loaded.model, surml_file.model);
    }

    #[test]
    fn test_write_forrest() {
        let header = Header::fresh();
        let model_bytes = read_fixture("./stash/forrest_test.onnx");
        let surml_file = SurMlFile::new(header, model_bytes);
        surml_file.write("./stash/forrest.surml").unwrap();
        let loaded = SurMlFile::from_file("./stash/forrest.surml").unwrap();
        assert_eq!(loaded.model, surml_file.model);
    }

    #[test]
    fn test_bytes_round_trip() {
        let surml_file = SurMlFile::fresh(vec![1, 2, 3, 4, 5]);
        let loaded = SurMlFile::from_bytes(surml_file.to_bytes()).unwrap();
        assert_eq!(loaded.model, vec![1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_empty_buffer() {
        let bytes = vec![0u8; 0];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("should have error with loading an empty buffer"),
            Err(error) => {
                assert_eq!(error.status, SurrealErrorStatus::BadRequest);
                assert_eq!(error.to_string(), "Not enough bytes to read");
            }
        }
    }

    #[test]
    fn test_truncated_header() {
        // The length prefix claims a 10-byte header, but only 2 bytes follow it.
        let bytes = vec![0u8, 0, 0, 10, 1, 2];
        match SurMlFile::from_bytes(bytes) {
            Ok(_) => panic!("should have error with a truncated header"),
            Err(error) => {
                assert_eq!(error.status, SurrealErrorStatus::BadRequest);
                assert_eq!(
                    error.to_string(),
                    "Not enough bytes to read for header, maybe the file format is not correct"
                );
            }
        }
    }
}