use std::io::{Read, Write};
use xxhash_rust::xxh3::xxh3_128;
use crate::{ConfigError, ConfigLocation, ConfigType};
/// Number of bytes in the hash prefix at the start of every binary config file.
///
/// This is the size of the little-endian byte encoding of the `u128` returned by
/// `xxh3_128`, which is what `prepare_serialized_data` writes in front of the payload.
const HASH_BYTE_LENGTH: usize = 16;
pub use bitcode::{Decode, DecodeOwned, Encode};
/// Loads a configuration of type `T` from a binary (`bitcode`) config file, verifying the
/// stored xxh3-128 hash of the payload before decoding it.
///
/// The file path is resolved from `app_name`, `config_name` and `location`. If no file
/// exists there yet, `T::default()` is written to it and returned.
///
/// # Errors
///
/// * [`ConfigError::CorruptedHashSector`] if the file is too short to contain the hash.
/// * [`ConfigError::HashMismatch`] if the stored hash does not match the payload.
/// * [`ConfigError::Bitcode`] if the payload cannot be decoded as `T`.
/// * Path-resolution and I/O failures, converted into [`ConfigError`].
///
/// When `reset_conf_on_err` is `true`, the first three cases overwrite the file with
/// `T::default()` and return that value instead of an error.
pub fn load_bin<T>(
    app_name: impl AsRef<str>,
    config_name: Option<&str>,
    location: impl AsRef<ConfigLocation>,
    reset_conf_on_err: bool,
) -> Result<T, ConfigError>
where
    T: Default + Encode,
    for<'de> T: Decode<'de>,
{
    // Integrity checking stays enabled for the regular loader.
    let skip_hash_check = false;
    let (app_name, location) = (app_name.as_ref(), location.as_ref());
    load_bin_internal::<T>(
        app_name,
        config_name,
        location,
        reset_conf_on_err,
        skip_hash_check,
    )
}
/// Loads a configuration of type `T` from a binary (`bitcode`) config file **without**
/// verifying the stored hash.
///
/// The file must still be long enough to hold the hash prefix, which is skipped. The rest
/// is decoded as-is. This allows reading a file whose payload no longer matches its hash.
/// If no file exists yet, `T::default()` is written to it and returned.
///
/// # Errors
///
/// * [`ConfigError::CorruptedHashSector`] if the file is too short to contain the hash.
/// * [`ConfigError::Bitcode`] if the payload cannot be decoded as `T`.
/// * Path-resolution and I/O failures, converted into [`ConfigError`].
///
/// When `reset_conf_on_err` is `true`, the first two cases overwrite the file with
/// `T::default()` and return that value instead of an error.
pub fn load_bin_skip_check<T>(
    app_name: impl AsRef<str>,
    config_name: Option<&str>,
    location: impl AsRef<ConfigLocation>,
    reset_conf_on_err: bool,
) -> Result<T, ConfigError>
where
    T: Default + Encode,
    for<'de> T: Decode<'de>,
{
    // The only difference from `load_bin`: the hash comparison is bypassed.
    let skip_hash_check = true;
    let (app_name, location) = (app_name.as_ref(), location.as_ref());
    load_bin_internal::<T>(
        app_name,
        config_name,
        location,
        reset_conf_on_err,
        skip_hash_check,
    )
}
/// Shared implementation of `load_bin` and `load_bin_skip_check`.
///
/// File layout, as produced by `prepare_serialized_data`: a `HASH_BYTE_LENGTH`-byte
/// little-endian xxh3-128 hash of the payload, followed by the bitcode-encoded payload.
///
/// * If the file does not exist, `T::default()` is written to it and returned.
/// * If the file is shorter than the hash prefix, the hash does not match (only checked
///   when `skip_hash_check` is `false`), or the payload fails to decode, then:
///   - with `reset_conf_on_err`, the file is overwritten with `T::default()` and that
///     value is returned;
///   - otherwise `CorruptedHashSector`, `HashMismatch` or `Bitcode` is returned,
///     respectively.
///
/// Path-resolution and I/O failures are propagated as [`ConfigError`].
fn load_bin_internal<T>(
    app_name: &str,
    config_name: Option<&str>,
    location: &ConfigLocation,
    reset_conf_on_err: bool,
    skip_hash_check: bool,
) -> Result<T, ConfigError>
where
    T: Default + Encode,
    for<'de> T: Decode<'de>,
{
    let config_file_path =
        crate::config_location(app_name, config_name, ConfigType::Bin.as_str(), location)?;

    // Writes the hash-prefixed encoding of `T::default()` to the config file and returns it.
    let save_default_conf = || -> Result<T, ConfigError> {
        let default_config = T::default();
        let full_data = prepare_serialized_data(&default_config);
        // Write straight to the `File` instead of through a `BufWriter`. A `BufWriter` holds
        // small payloads until it is dropped, and its drop discards flush errors, so a
        // failed write (e.g. a full disk) would otherwise be reported as success.
        std::fs::File::create(&config_file_path)?.write_all(&full_data)?;
        Ok(default_config)
    };

    if !config_file_path.try_exists()? {
        return save_default_conf();
    }

    // A single `read_to_end` gains nothing from a `BufReader`.
    let mut data = Vec::new();
    std::fs::File::open(&config_file_path)?.read_to_end(&mut data)?;

    // Without a complete hash prefix there is no payload boundary to rely on, so this is
    // rejected even when the hash check itself is skipped.
    if data.len() < HASH_BYTE_LENGTH {
        return if reset_conf_on_err {
            save_default_conf()
        } else {
            Err(ConfigError::CorruptedHashSector)
        };
    }

    if !skip_hash_check {
        let (hash_from_file, hash_from_data) = get_hash_from_file_and_data(&data);
        if hash_from_file != hash_from_data {
            return if reset_conf_on_err {
                save_default_conf()
            } else {
                Err(ConfigError::HashMismatch)
            };
        }
    }

    match bitcode::decode::<T>(&data[HASH_BYTE_LENGTH..]) {
        Ok(config) => Ok(config),
        Err(_) if reset_conf_on_err => save_default_conf(),
        Err(err) => Err(ConfigError::Bitcode(err)),
    }
}
/// Serializes `data` with `bitcode` and stores it, prefixed with its xxh3-128 hash, in the
/// binary config file resolved from `app_name`, `config_name` and `location`.
///
/// Any existing file at that path is truncated and overwritten.
///
/// # Errors
///
/// Returns a [`ConfigError`] if the path cannot be resolved, or if creating or writing
/// the file fails. Write failures are reported rather than swallowed on drop.
pub fn store_bin<T>(
    app_name: impl AsRef<str>,
    config_name: Option<&str>,
    location: impl AsRef<ConfigLocation>,
    data: &T,
) -> Result<(), ConfigError>
where
    T: Encode,
{
    let config_file_path = crate::config_location(
        app_name.as_ref(),
        config_name,
        ConfigType::Bin.as_str(),
        location.as_ref(),
    )?;
    let full_data = prepare_serialized_data(data);
    // Write the complete in-memory buffer directly to the `File`. Going through a
    // `BufWriter` would defer small writes to its drop, which silently discards errors.
    std::fs::File::create(config_file_path)?.write_all(&full_data)?;
    Ok(())
}
/// Splits raw config-file bytes into the stored hash prefix and a freshly computed
/// xxh3-128 hash of the remaining payload.
///
/// Returns `(stored_hash, computed_hash)`, both `HASH_BYTE_LENGTH` bytes long. The computed
/// hash uses little-endian byte order, matching what `prepare_serialized_data` writes.
///
/// # Panics
///
/// Panics if `data` is shorter than `HASH_BYTE_LENGTH`. `load_bin_internal` checks the
/// length before calling this.
fn get_hash_from_file_and_data(data: &[u8]) -> (&[u8], Vec<u8>) {
    let (stored_hash, payload) = data.split_at(HASH_BYTE_LENGTH);
    let computed_hash: Vec<u8> = xxh3_128(payload).to_le_bytes().into();
    assert_eq!(computed_hash.len(), HASH_BYTE_LENGTH);
    (stored_hash, computed_hash)
}
/// Encodes `data` with `bitcode` and returns the on-disk representation: the
/// little-endian xxh3-128 hash of the encoded bytes (`HASH_BYTE_LENGTH` bytes),
/// followed by the encoded bytes themselves.
fn prepare_serialized_data<T>(data: &T) -> Vec<u8>
where
    T: bitcode::Encode,
{
    let payload = bitcode::encode(data);
    let checksum = xxh3_128(&payload).to_le_bytes();

    let mut framed = Vec::with_capacity(HASH_BYTE_LENGTH + payload.len());
    framed.extend_from_slice(&checksum);
    framed.extend_from_slice(&payload);
    framed
}
#[cfg(test)]
mod tests {
    use std::io::Seek;

    use super::*;
    use crate::get_configuration_path;
    use ConfigLocation::{Cache, Config, Cwd, LocalData};

    #[derive(Default, PartialEq, Debug, Clone, Encode, Decode)]
    struct TestConfig {
        test: String,
        test_vec: Vec<u8>,
    }

    // Same leading field kinds as `TestConfig` plus two extra integers, so a stored
    // `TestConfig` should fail to decode as this type.
    #[derive(Default, Clone, Debug, Decode, Encode)]
    struct TestConfig2 {
        strings: String,
        vecs: Vec<u8>,
        num_1: i32,
        num_2: i32,
    }

    /// Removes a binary config file left behind by an earlier test run.
    ///
    /// Tests that assert a freshly created default config must start without a file at
    /// that path. Otherwise the value stored by the previous run is loaded instead and
    /// the test fails on every run after the first.
    fn remove_stale_config(app_name: &str, config_name: Option<&str>) {
        if let Ok(path) = get_configuration_path(app_name, config_name, ConfigType::Bin, Config)
        {
            match std::fs::remove_file(path) {
                Ok(()) => {}
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
                Err(err) => panic!("failed to remove stale test config: {err}"),
            }
        }
    }

    #[test]
    fn read_default_config_bin() {
        let config = load_bin::<String>(
            "test-binconf-read_default_config-string-bin",
            None,
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, String::from(""));

        let test_config = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2, 3, 4, 5],
        };
        remove_stale_config("test-binconf-read_default_config-struct-bin", None);
        let config: TestConfig = load_bin(
            "test-binconf-read_default_config-struct-bin",
            None,
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, TestConfig::default());

        store_bin(
            "test-binconf-read_default_config-struct-bin",
            None::<&str>,
            Config,
            &test_config,
        )
        .unwrap();
        let config: TestConfig = load_bin(
            "test-binconf-read_default_config-struct-bin",
            None,
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, test_config);
    }

    #[test]
    fn config_with_name_bin() {
        let config = load_bin::<String>(
            "test-binconf-config_with_name-string-bin",
            Some("test-config.bin"),
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, String::from(""));

        let test_config = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2, 3, 4, 5],
        };
        remove_stale_config(
            "test-binconf-config_with_name-struct-bin",
            Some("test-config.bin"),
        );
        let config: TestConfig = load_bin(
            "test-binconf-config_with_name-struct-bin",
            Some("test-config.bin"),
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, TestConfig::default());

        store_bin(
            "test-binconf-config_with_name-struct-bin",
            Some("test-config.bin"),
            Config,
            &test_config,
        )
        .unwrap();
        let config: TestConfig = load_bin(
            "test-binconf-config_with_name-struct-bin",
            Some("test-config.bin"),
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, test_config);
    }

    #[test]
    fn returns_error_on_invalid_config_bin() {
        let data = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2],
        };
        store_bin(
            "test-binconf-returns_error_on_invalid_config-bin",
            None,
            Config,
            &data,
        )
        .unwrap();
        let config = load_bin::<TestConfig2>(
            "test-binconf-returns_error_on_invalid_config-bin",
            None,
            Config,
            false,
        );
        assert!(config.is_err());
    }

    #[test]
    fn save_config_user_config_bin() {
        let data = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2, 3, 4, 5],
        };
        store_bin(
            "test-binconf-save_config_user_config-bin",
            None,
            Config,
            &data,
        )
        .unwrap();
        let config: TestConfig = load_bin(
            "test-binconf-save_config_user_config-bin",
            None,
            Config,
            false,
        )
        .unwrap();
        assert_eq!(config, data);
    }

    #[test]
    fn save_config_user_cache_bin() {
        let data = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2, 3, 4, 5],
        };
        store_bin(
            "test-binconf-save_config_user_cache-bin",
            None,
            Cache,
            &data,
        )
        .unwrap();
        let config: TestConfig = load_bin(
            "test-binconf-save_config_user_cache-bin",
            None,
            Cache,
            false,
        )
        .unwrap();
        assert_eq!(config, data);
    }

    #[test]
    fn save_config_user_local_data_bin() {
        let data = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2, 3, 4, 5],
        };
        store_bin(
            "test-binconf-save_config_user_local_data-bin",
            None,
            LocalData,
            &data,
        )
        .unwrap();
        let config: TestConfig = load_bin(
            "test-binconf-save_config_user_local_data-bin",
            None,
            LocalData,
            false,
        )
        .unwrap();
        assert_eq!(config, data);
    }

    #[test]
    fn save_config_user_cwd_bin() {
        let data = TestConfig {
            test: String::from("test"),
            test_vec: vec![1, 2, 3, 4, 5],
        };
        store_bin("test-binconf-save_config_user_cwd-bin", None, Cwd, &data).unwrap();
        let config: TestConfig =
            load_bin("test-binconf-save_config_user_cwd-bin", None, Cwd, false).unwrap();
        assert_eq!(config, data);
    }

    /// Corrupts the last payload byte on disk. The checked loader must reject the file,
    /// while the unchecked loader still decodes the modified payload.
    #[test]
    fn load_config_fallback() {
        let data = String::from("test of corrupted data");
        store_bin("test-binconf-load_config_fallback-bin", None, Config, &data).unwrap();
        assert_eq!(
            load_bin::<String>("test-binconf-load_config_fallback-bin", None, Config, false)
                .unwrap(),
            data
        );

        let mut file = std::fs::OpenOptions::new()
            .write(true)
            .read(true)
            .open(
                get_configuration_path(
                    "test-binconf-load_config_fallback-bin",
                    None,
                    ConfigType::Bin,
                    Config,
                )
                .unwrap(),
            )
            .unwrap();
        let mut new_data = Vec::new();
        file.read_to_end(&mut new_data).unwrap();
        // Replace the trailing 'a' of "data" with 'o' (0x6F); the length is unchanged.
        if let Some(last) = new_data.last_mut() {
            *last = 0x6F;
        }
        file.seek(std::io::SeekFrom::Start(0)).unwrap();
        file.write_all(&new_data[..]).unwrap();

        assert!(
            load_bin::<String>("test-binconf-load_config_fallback-bin", None, Config, false)
                .is_err()
        );
        let corrupted_data = load_bin_skip_check::<String>(
            "test-binconf-load_config_fallback-bin",
            None,
            Config,
            true,
        )
        .unwrap();
        assert_eq!(corrupted_data, "test of corrupted dato");
    }
}