#![forbid(unsafe_code)]
#![deny(missing_docs)]
#![deny(unused_must_use)]
#![deny(unused_mut)]
use async_std::fs::{File, OpenOptions};
use async_std::io::prelude::SeekExt;
use async_std::io::{ReadExt, SeekFrom, WriteExt};
use async_std::prelude::Future;
use std::collections::HashMap;
use std::path::PathBuf;
use std::pin::Pin;
use std::str::from_utf8;
use anyhow::{bail, Context, Error, Result};
/// Signature of a synchronous upgrade routine.
///
/// The routine receives the [`VersionedFile`] to migrate along with the version the file is
/// being upgraded from (`initial_version`) and the version it is being upgraded to
/// (`upgraded_version`), and reports success or failure through its `Result`.
pub type UpgradeFunc =
fn(file: VersionedFile, initial_version: u8, upgraded_version: u8) -> Result<(), Error>;
/// Boxed, type-erased upgrade callback.
///
/// Calling it with a [`VersionedFile`] and two version numbers yields a pinned, boxed future
/// that resolves to `Result<(), Error>`. This is the type stored in the `process` field of
/// [`Upgrade`].
pub type WrappedUpgradeFunc =
Box<dyn Fn(VersionedFile, u8, u8) -> Pin<Box<dyn Future<Output = Result<(), Error>>>>>;
/// Adapts a function returning a future into a [`WrappedUpgradeFunc`].
///
/// The returned closure forwards its three arguments to `f` and pins the resulting future on
/// the heap so that upgrade functions with different concrete future types can be stored
/// behind the same boxed type.
pub fn wrap_upgrade_process<T>(f: fn(VersionedFile, u8, u8) -> T) -> WrappedUpgradeFunc
where
    T: Future<Output = Result<(), Error>> + 'static,
{
    Box::new(move |file, initial_version, upgraded_version| {
        Box::pin(f(file, initial_version, upgraded_version))
    })
}
/// Describes a single upgrade step that takes a file from one version to another.
pub struct Upgrade {
/// Version the file must currently have for this upgrade to be applied.
pub initial_version: u8,
/// Version recorded for the file once the upgrade has completed.
pub updated_version: u8,
/// Callback that performs the upgrade; it is invoked with the file, `initial_version` and
/// `updated_version`.
pub process: WrappedUpgradeFunc,
}
/// Handle to a file that starts with a 4096 byte header.
///
/// The methods on this type translate offsets so that callers address the data that follows
/// the header, with offset 0 being the first byte after it.
#[derive(Debug)]
pub struct VersionedFile {
// Underlying file handle.
file: File,
// Absolute offset in the underlying file (header included) where the next read or write is
// expected to happen.
cursor: u64,
// Set while an operation that may move the underlying file position is in flight. If it is
// still set when the next operation starts, the position is re-synchronized with `cursor`.
needs_seek: bool,
}
impl VersionedFile {
async fn fix_seek(&mut self) -> Result<(), Error> {
if self.needs_seek {
match self.file.seek(SeekFrom::Start(self.cursor)).await {
Ok(_) => self.needs_seek = false,
Err(e) => bail!(format!(
"unable to set file cursor to correct position: {}",
e
)),
};
}
Ok(())
}
pub async fn len(&mut self) -> Result<u64, Error> {
let md = self
.file
.metadata()
.await
.context("unable to get metadata for file")?;
Ok(md.len() - 4096)
}
pub async fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), Error> {
self.fix_seek().await?;
self.needs_seek = true;
match self.file.read_exact(buf).await {
Ok(_) => {}
Err(e) => {
match self.file.seek(SeekFrom::Start(self.cursor)).await {
Ok(_) => self.needs_seek = false,
Err(_) => {}
};
bail!(format!("{}", e));
}
};
self.cursor += buf.len() as u64;
self.needs_seek = false;
Ok(())
}
pub async fn seek(&mut self, pos: SeekFrom) -> Result<u64, Error> {
self.fix_seek().await?;
self.needs_seek = true;
match pos {
SeekFrom::Start(x) => {
let new_pos = self
.file
.seek(SeekFrom::Start(x + 4096))
.await
.context("versioned file seek failed")?;
self.needs_seek = false;
self.cursor = new_pos;
return Ok(new_pos - 4096);
}
SeekFrom::End(x) => {
let size = self.len().await.context("unable to get file len")?;
if x + (size as i64) < 0 {
self.needs_seek = false;
bail!("cannot seek to a position before the start of the file");
}
let new_pos = self.file.seek(pos).await.context("seek failed")?;
self.needs_seek = false;
self.cursor = new_pos;
return Ok(new_pos - 4096);
}
SeekFrom::Current(x) => {
if x + (self.cursor as i64) < 4096 {
self.needs_seek = false;
bail!("cannot seek to a position before the start of the file");
}
let new_pos = self.file.seek(pos).await.context("seek failed")?;
self.needs_seek = false;
self.cursor = new_pos;
return Ok(new_pos - 4096);
}
}
}
pub async fn set_len(&mut self, new_len: u64) -> Result<(), Error> {
self.file
.set_len(new_len + 4096)
.await
.context("unable to adjust file length")?;
self.seek(SeekFrom::End(0))
.await
.context("unable to seek to new end of file")?;
Ok(())
}
pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), Error> {
self.fix_seek().await?;
self.needs_seek = true;
match self.file.write_all(buf).await {
Ok(_) => {}
Err(e) => {
match self.file.seek(SeekFrom::Start(self.cursor)).await {
Ok(_) => self.needs_seek = false,
Err(_) => {}
};
bail!(format!("{}", e));
}
};
self.file.flush().await.context("unable to flush file")?;
self.cursor += buf.len() as u64;
self.needs_seek = false;
Ok(())
}
}
/// Renders `version` as a three character, zero padded decimal string (for example `7`
/// becomes `"007"`).
///
/// # Errors
///
/// Version 0 is rejected.
fn version_to_str(version: u8) -> Result<String, Error> {
    match version {
        0 => bail!("version is not allowed to be 0"),
        // A u8 has at most three decimal digits, so a width of three always yields exactly
        // three characters.
        v => Ok(format!("{:03}", v)),
    }
}
/// Writes a fresh 4096 byte header to `file` and rewinds the file to its start.
///
/// The header consists of the three digit version, a newline, the identifier and another
/// newline, followed by zero bytes up to 4096 bytes.
///
/// # Panics
///
/// Panics if the textual part of the header exceeds 256 bytes, or if the file is not exactly
/// 4096 bytes long after the header has been written.
async fn new_file_header(
    file: &mut File,
    expected_identifier: &str,
    latest_version: u8,
) -> Result<(), Error> {
    let version_string =
        version_to_str(latest_version).context("unable to convert version to ascii string")?;

    // Build the textual prefix first, then zero-pad it to the full header size.
    let mut full_header = format!("{}\n{}\n", version_string, expected_identifier).into_bytes();
    assert!(
        full_header.len() <= 256,
        "developer error: metadata_bytes should be guaranteed to have len below 256"
    );
    full_header.resize(4096, 0);

    file.write_all(&full_header)
        .await
        .context("unable to write initial metadata")?;
    file.flush()
        .await
        .context("unable to flush file after writing header")?;

    // Confirm that the file now holds exactly the header and nothing else.
    let written_len = file
        .metadata()
        .await
        .context("unable to get updated file metadata")?
        .len();
    assert!(
        written_len == 4096,
        "developer error: file did not initialize with 4096 bytes: {}",
        written_len
    );

    file.seek(SeekFrom::Start(0))
        .await
        .context("unable to seek back to beginning of file")?;
    Ok(())
}
/// Checks that `upgrade_paths` forms a valid upgrade graph ending at `latest_version`.
///
/// # Errors
///
/// Fails if `latest_version` or any starting version is 0, if an upgrade does not lead to a
/// strictly higher version, if two upgrades start from the same version, if an upgrade leads
/// beyond `latest_version`, or if some starting version cannot reach `latest_version` by
/// following the upgrades.
fn verify_upgrade_paths(upgrade_paths: &Vec<Upgrade>, latest_version: u8) -> Result<(), Error> {
    if latest_version == 0 {
        bail!("version 0 is not allowed for a VersionedFile");
    }

    // Maps every starting version to the version its (single) upgrade leads to.
    let mut version_routes: HashMap<u8, u8> = HashMap::new();
    for path in upgrade_paths {
        let (from, to) = (path.initial_version, path.updated_version);
        if from >= to {
            bail!("upgrade paths must always lead to a higher version number");
        }
        if version_routes.contains_key(&from) {
            bail!("upgrade paths can only have one upgrade for each version");
        }
        if to > latest_version {
            bail!("upgrade paths lead beyond the latest version");
        }
        if from == 0 {
            bail!("version 0 is not allowed for a VersionedFile");
        }
        version_routes.insert(from, to);
    }

    // Walk the chain from every starting version. Each hop strictly increases the version and
    // never exceeds `latest_version`, so every walk terminates: either at `latest_version` or
    // at a version that has no outgoing upgrade.
    for &start in version_routes.keys() {
        let mut current = start;
        while current != latest_version {
            match version_routes.get(&current) {
                Some(&next) => current = next,
                None => {
                    bail!("update graph is incomplete, not all nodes lead to the latest version")
                }
            }
        }
    }
    Ok(())
}
/// Runs the upgrade `u` against the file at `filepath` and then records the new version in
/// the file header.
///
/// The upgrade callback receives a [`VersionedFile`] positioned at the first byte after the
/// header. Once it has finished, the file is reopened and the three digit version at the
/// start of the header is overwritten with `u.updated_version`.
///
/// # Errors
///
/// Fails if the target version cannot be rendered, if the file cannot be opened or seeked, if
/// the upgrade callback fails, or if the new version cannot be written and flushed.
async fn perform_file_upgrade(filepath: &PathBuf, u: &Upgrade) -> Result<(), Error> {
    // Render the new version up front so that a bad version is rejected before the upgrade
    // callback gets a chance to modify the file.
    let updated_version_str =
        version_to_str(u.updated_version).context("upgrade path has bad version")?;

    let file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(filepath)
        .await
        .context("unable to open versioned file for update")?;
    let mut versioned_file = VersionedFile {
        file,
        cursor: 4096,
        needs_seek: false,
    };
    versioned_file
        .seek(SeekFrom::Start(0))
        .await
        .context("unable to seek past file header before upgrade")?;

    // The callback takes ownership of the handle, so the file is reopened afterwards.
    (u.process)(versioned_file, u.initial_version, u.updated_version)
        .await
        .with_context(|| {
            format!(
                "unable to complete file upgrade from version {} to {}",
                u.initial_version, u.updated_version
            )
        })?;

    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(filepath)
        .await
        .context("unable to open versioned file for update")?;
    file.seek(SeekFrom::Start(0))
        .await
        .context("unable to seek to beginning of file")?;
    file.write_all(updated_version_str.as_bytes())
        .await
        .context("unable to write updated version to file header")?;
    // Flush explicitly: relying on the implicit flush when the handle is dropped would
    // discard any write error and could leave the old version in the header.
    file.flush()
        .await
        .context("unable to flush updated version to file header")?;
    Ok(())
}
/// Opens the versioned file at `filepath`, creating it if necessary, and upgrades it to
/// `latest_version` using `upgrades`.
///
/// A new (empty) file gets a header containing `latest_version` and `expected_identifier`.
/// For an existing file the identifier stored in the header must match `expected_identifier`
/// exactly; the stored version is then brought up to `latest_version` by applying the
/// matching upgrades one after another.
///
/// The returned [`VersionedFile`] is positioned at the first byte after the header.
///
/// # Errors
///
/// Fails if the path or identifier is not ASCII, if the identifier is longer than 251 bytes,
/// on any I/O failure, if the header cannot be parsed or carries a different identifier, if
/// `upgrades` is not a valid upgrade graph, or if no upgrade exists for the file's version.
pub async fn open_file(
    filepath: &PathBuf,
    expected_identifier: &str,
    latest_version: u8,
    upgrades: &Vec<Upgrade>,
) -> Result<VersionedFile, Error> {
    let path_str = filepath.to_str().context("could not stringify path")?;
    if !path_str.is_ascii() {
        bail!("path should be valid ascii");
    }
    if expected_identifier.len() > 251 {
        bail!("the identifier of a versioned file cannot exceed 251 bytes");
    }
    if !expected_identifier.is_ascii() {
        bail!("the identifier must be ascii");
    }

    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .create(true)
        .open(filepath)
        .await
        .context("unable to open versioned file")?;
    let file_metadata = file
        .metadata()
        .await
        .context("unable to read versioned file metadata")?;
    if file_metadata.len() == 0 {
        new_file_header(&mut file, expected_identifier, latest_version)
            .await
            .context("unable to write new file header")?;
    }

    let mut header = vec![0; 4096];
    file.read_exact(&mut header)
        .await
        .context("unable to read file header")?;

    // The header layout is "<3 digit version>\n<identifier>\n". The identifier is at most 251
    // bytes, so `identifier_end` is always a valid index into the 4096 byte header.
    let identifier_end = 4 + expected_identifier.len();
    let header_identifier = from_utf8(&header[4..identifier_end])
        .context("the on-disk file identifier could not be parsed")?;
    // Requiring the terminating newline prevents an expected identifier that is merely a
    // prefix of the on-disk identifier from being accepted.
    if header_identifier != expected_identifier || header[identifier_end] != b'\n' {
        bail!("the file does not have the correct identifier");
    }

    verify_upgrade_paths(upgrades, latest_version).context("upgrade paths are invalid")?;

    let version_str = from_utf8(&header[..3]).context("the on-disk version could not be parsed")?;
    let mut version: u8 = version_str
        .parse()
        .context("unable to parse on-disk version")?;

    // Each upgrade opens the file on its own, so release this handle first.
    drop(file);
    while version != latest_version {
        let upgrade = upgrades
            .iter()
            .find(|candidate| candidate.initial_version == version)
            .context("no viable upgrade path exists for file")?;
        perform_file_upgrade(filepath, upgrade)
            .await
            .context("unable to complete file upgrade")?;
        version = upgrade.updated_version;
    }

    let mut file = OpenOptions::new()
        .read(true)
        .write(true)
        .open(filepath)
        .await
        .context("unable to open versioned file for update")?;
    file.seek(SeekFrom::Start(4096))
        .await
        .context("unable to seek to beginning of file after header")?;
    Ok(VersionedFile {
        file,
        cursor: 4096,
        needs_seek: false,
    })
}
#[cfg(test)]
mod tests {
use super::*;
use testdir::testdir;
async fn stub_upgrade(_: VersionedFile, _: u8, _: u8) -> Result<(), Error> {
Ok(())
}
// Upgrade from version 1 to version 2: expects the file to contain "test_data" and replaces
// the contents with "test".
async fn smoke_upgrade_1_2(
    mut vf: VersionedFile,
    initial_version: u8,
    updated_version: u8,
) -> Result<(), Error> {
    match (initial_version, updated_version) {
        (1, 2) => {}
        _ => bail!("this upgrade is intended to take the file from version 1 to version 2"),
    }

    let current_len = vf.len().await.unwrap();
    if current_len != 9 {
        bail!("file is wrong len");
    }
    let mut old_contents = [0u8; 9];
    vf.read_exact(&mut old_contents)
        .await
        .context("unable to read old file contents")?;
    if old_contents != *b"test_data" {
        bail!(format!("file appears corrupt: {:?}", old_contents));
    }

    // Drop the old contents, then write the replacement from the start of the data area.
    vf.set_len(0).await.unwrap();
    vf.write_all(b"test")
        .await
        .context("unable to write new data after deleting old data")?;
    Ok(())
}
// Upgrade from version 2 to version 3: expects the file to contain "test" and replaces the
// contents with "testtest".
async fn smoke_upgrade_2_3(
    mut vf: VersionedFile,
    initial_version: u8,
    updated_version: u8,
) -> Result<(), Error> {
    match (initial_version, updated_version) {
        (2, 3) => {}
        _ => bail!("this upgrade is intended to take the file from version 2 to version 3"),
    }

    let current_len = vf.len().await.unwrap();
    if current_len != 4 {
        bail!("file is wrong len");
    }
    let mut old_contents = [0u8; 4];
    vf.read_exact(&mut old_contents)
        .await
        .context("unable to read old file contents")?;
    if old_contents != *b"test" {
        bail!("file appears corrupt");
    }

    // Drop the old contents, then write the replacement from the start of the data area.
    vf.set_len(0).await.unwrap();
    vf.write_all(b"testtest")
        .await
        .context("unable to write new data after deleting old data")?;
    Ok(())
}
// Upgrade from version 3 to version 4: expects the file to contain "testtest" and replaces
// the contents with "testtesttest".
async fn smoke_upgrade_3_4(
    mut vf: VersionedFile,
    initial_version: u8,
    updated_version: u8,
) -> Result<(), Error> {
    if initial_version != 3 || updated_version != 4 {
        // The message names the versions this upgrade actually handles.
        bail!("this upgrade is intended to take the file from version 3 to version 4");
    }
    if vf.len().await.unwrap() != 8 {
        bail!("file is wrong len");
    }
    let mut buf = [0u8; 8];
    vf.read_exact(&mut buf)
        .await
        .context("unable to read old file contents")?;
    if &buf != b"testtest" {
        bail!("file appears corrupt");
    }

    // Drop the old contents, then write the replacement from the start of the data area.
    let new_data = b"testtesttest";
    vf.set_len(0).await.unwrap();
    vf.write_all(new_data)
        .await
        .context("unable to write new data after deleting old data")?;
    Ok(())
}
// End-to-end exercise of creating, reopening, upgrading, reading, writing and seeking in a
// versioned file.
#[async_std::test]
async fn smoke_test() {
    let dir = testdir!();
    let test_dat = dir.join("test.dat");

    // Version 0 is not allowed.
    open_file(&test_dat, "versioned_file::test.dat", 0, &Vec::new())
        .await
        .context("unable to create versioned file")
        .unwrap_err();
    // Creating the file and opening it again with the same parameters both work.
    open_file(&test_dat, "versioned_file::test.dat", 1, &Vec::new())
        .await
        .context("unable to create versioned file")
        .unwrap();
    open_file(&test_dat, "versioned_file::test.dat", 1, &Vec::new())
        .await
        .context("unable to create versioned file")
        .unwrap();
    // A different identifier is rejected.
    open_file(&test_dat, "bad_versioned_file::test.dat", 1, &Vec::new())
        .await
        .context("unable to create versioned file")
        .unwrap_err();
    // Non-ASCII paths are rejected.
    let invalid_name = dir.join("❄️");
    open_file(&invalid_name, "versioned_file::test.dat", 1, &Vec::new())
        .await
        .context("unable to create versioned file")
        .unwrap_err();
    // Non-ASCII identifiers are rejected.
    let invalid_id = dir.join("invalid_identifier.dat");
    open_file(&invalid_id, "versioned_file::test.dat::❄️", 1, &Vec::new())
        .await
        .context("unable to create versioned file")
        .unwrap_err();

    // Data written through one handle can be read back through a fresh one.
    let mut file = open_file(&test_dat, "versioned_file::test.dat", 1, &Vec::new())
        .await
        .unwrap();
    file.write_all(b"test_data").await.unwrap();
    let mut file = open_file(&test_dat, "versioned_file::test.dat", 1, &Vec::new())
        .await
        .unwrap();
    if file.len().await.unwrap() != 9 {
        panic!("file has unexpected len");
    }
    let mut buf = [0u8; 9];
    file.read_exact(&mut buf).await.unwrap();
    if &buf != b"test_data" {
        panic!("data read does not match data written");
    }
    open_file(&test_dat, "versioned_file::test.dat", 1, &Vec::new())
        .await
        .unwrap();

    // A single upgrade step from version 1 to version 2.
    let mut upgrade_chain = vec![Upgrade {
        initial_version: 1,
        updated_version: 2,
        process: wrap_upgrade_process(smoke_upgrade_1_2),
    }];
    let mut file = open_file(&test_dat, "versioned_file::test.dat", 2, &upgrade_chain)
        .await
        .unwrap();
    if file.len().await.unwrap() != 4 {
        panic!("file has wrong len");
    }
    let mut buf = [0u8; 4];
    file.read_exact(&mut buf).await.unwrap();
    if &buf != b"test" {
        panic!("data read does not match data written");
    }
    // Opening a file that is already at the latest version works as well.
    open_file(&test_dat, "versioned_file::test.dat", 2, &upgrade_chain)
        .await
        .unwrap();

    // Two upgrade steps applied during a single open: 2 -> 3 -> 4.
    upgrade_chain.push(Upgrade {
        initial_version: 2,
        updated_version: 3,
        process: wrap_upgrade_process(smoke_upgrade_2_3),
    });
    upgrade_chain.push(Upgrade {
        initial_version: 3,
        updated_version: 4,
        process: wrap_upgrade_process(smoke_upgrade_3_4),
    });
    let mut file = open_file(&test_dat, "versioned_file::test.dat", 4, &upgrade_chain)
        .await
        .unwrap();
    if file.len().await.unwrap() != 12 {
        panic!("file has wrong len");
    }
    let mut buf = [0u8; 12];
    file.read_exact(&mut buf).await.unwrap();
    if &buf != b"testtesttest" {
        panic!("data read does not match data written");
    }

    // Overwrite the tail of the file and extend it, then walk the cursor back to the start.
    let mut file = open_file(&test_dat, "versioned_file::test.dat", 4, &upgrade_chain)
        .await
        .unwrap();
    file.seek(SeekFrom::End(-5)).await.unwrap();
    file.write_all(b"NOVELLA").await.unwrap();
    file.seek(SeekFrom::Current(-3)).await.unwrap();
    file.seek(SeekFrom::Current(-4)).await.unwrap();
    file.seek(SeekFrom::Current(-7)).await.unwrap();
    let mut buf = [0u8; 14];
    file.read_exact(&mut buf).await.unwrap();
    if &buf != b"testtesNOVELLA" {
        panic!(
            "read data has unexpected result: {} || {}",
            std::str::from_utf8(&buf).unwrap(),
            buf[0]
        );
    }

    // A rejected seek relative to the end must leave the cursor untouched.
    file.seek(SeekFrom::Current(-2)).await.unwrap();
    file.seek(SeekFrom::End(-15)).await.unwrap_err();
    let mut buf = [0u8; 2];
    file.read_exact(&mut buf).await.unwrap();
    if &buf != b"LA" {
        panic!("seek_end error changed file cursor");
    }

    // A rejected seek relative to the current position must leave the cursor untouched.
    file.seek(SeekFrom::Current(-2)).await.unwrap();
    file.seek(SeekFrom::Current(-13)).await.unwrap_err();
    file.read_exact(&mut buf).await.unwrap();
    if &buf != b"LA" {
        panic!("seek_current error changed file cursor");
    }
}
// Checks `verify_upgrade_paths` against a series of valid and invalid upgrade graphs.
#[test]
fn test_verify_upgrade_paths() {
    // Builds a list of stub upgrades from `(initial_version, updated_version)` pairs.
    fn chain(routes: &[(u8, u8)]) -> Vec<Upgrade> {
        routes
            .iter()
            .map(|&(initial_version, updated_version)| Upgrade {
                initial_version,
                updated_version,
                process: wrap_upgrade_process(stub_upgrade),
            })
            .collect()
    }

    // Without any upgrades every latest version except 0 is acceptable.
    verify_upgrade_paths(&chain(&[]), 0).unwrap_err();
    verify_upgrade_paths(&chain(&[]), 1).unwrap();
    verify_upgrade_paths(&chain(&[]), 2).unwrap();
    verify_upgrade_paths(&chain(&[]), 255).unwrap();

    // A single step that reaches the latest version.
    verify_upgrade_paths(&chain(&[(1, 2)]), 2).unwrap();
    // An upgrade has to lead to a higher version.
    verify_upgrade_paths(&chain(&[(2, 2)]), 2).unwrap_err();
    // The chain stops short of the latest version.
    verify_upgrade_paths(&chain(&[(1, 2)]), 3).unwrap_err();
    // Two consecutive steps that reach the latest version.
    verify_upgrade_paths(&chain(&[(1, 2), (2, 3)]), 3).unwrap();
    // Two upgrades starting from the same version.
    verify_upgrade_paths(&chain(&[(1, 2), (2, 3), (1, 3)]), 3).unwrap_err();
    // Several versions may upgrade straight to the latest version.
    verify_upgrade_paths(&chain(&[(1, 3), (2, 3)]), 3).unwrap();
    // Upgrades that lead beyond the latest version.
    verify_upgrade_paths(&chain(&[(1, 3), (2, 3)]), 2).unwrap_err();
    // A larger graph in which every version reaches the latest one.
    verify_upgrade_paths(&chain(&[(1, 3), (2, 3), (3, 6), (4, 6), (5, 6)]), 6).unwrap();
    // The same graph with the upgrades listed in a different order.
    verify_upgrade_paths(&chain(&[(5, 6), (2, 3), (3, 6), (1, 3), (4, 6)]), 6).unwrap();
    // A graph that must be rejected for latest version 6.
    verify_upgrade_paths(&chain(&[(2, 5), (6, 7), (3, 6), (1, 4), (4, 6)]), 6).unwrap_err();
}
// Checks the zero padded rendering of versions and the rejection of version 0.
#[test]
fn test_version_to_str() {
    version_to_str(0).unwrap_err();

    let expectations: [(u8, &str); 5] = [
        (1, "001"),
        (2, "002"),
        (9, "009"),
        (39, "039"),
        (139, "139"),
    ];
    for &(version, expected) in expectations.iter() {
        assert!(
            version_to_str(version).unwrap() == expected,
            "{} failed",
            version
        );
    }
}
}