#[cfg(target_os = "linux")]
mod linux;
#[cfg(target_os = "macos")]
mod macos;
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
mod other;
#[cfg(target_os = "windows")]
mod windows;
#[cfg(target_os = "linux")]
use linux as platform;
#[cfg(target_os = "macos")]
use macos as platform;
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
use other as platform;
#[cfg(target_os = "windows")]
use windows as platform;
use std::fs::{File, OpenOptions};
use std::io::{self, Seek, SeekFrom, Write};
use std::path::Path;
use super::writeback::WritebackPipeline;
/// Default writeback chunk size in bytes (32 MiB), used when
/// `FREEMKV_WRITEBACK_CHUNK_MIB` does not yield a usable value.
const WRITEBACK_CHUNK_BYTES_DEFAULT: u64 = 32 * 1024 * 1024;

/// Returns the writeback chunk size in bytes.
///
/// The size can be overridden with the `FREEMKV_WRITEBACK_CHUNK_MIB`
/// environment variable, interpreted as a whole number of MiB (surrounding
/// whitespace is ignored). It falls back to [`WRITEBACK_CHUNK_BYTES_DEFAULT`]
/// when the variable is:
/// - unset or not valid Unicode;
/// - not a positive integer;
/// - so large that converting it to bytes would overflow `u64`.
fn writeback_chunk_bytes() -> u64 {
    const MIB: u64 = 1024 * 1024;
    std::env::var("FREEMKV_WRITEBACK_CHUNK_MIB")
        .ok()
        .and_then(|raw| raw.trim().parse::<u64>().ok())
        .filter(|&mib| mib > 0)
        // A plain multiply would panic in debug builds (and silently wrap in
        // release) for huge values; treat overflow as "invalid" instead.
        .and_then(|mib| mib.checked_mul(MIB))
        .unwrap_or(WRITEBACK_CHUNK_BYTES_DEFAULT)
}
/// A `File` writer that feeds its write/seek progress into a
/// `WritebackPipeline`.
///
/// Constructed via `new`/`create`/`create_with_size_hint`/`open`; implements
/// `Write` and `Seek` by delegating to the inner file while keeping `pos` and
/// the pipeline in step.
// NOTE(review): field order determines drop order (`file` before `pipeline`);
// the pipeline is constructed from `&file`, so don't reorder without checking
// what `WritebackPipeline` holds onto.
pub(crate) struct WritebackFile {
/// The underlying file all I/O is delegated to.
file: File,
/// Receives progress (`note_progress`), seek (`handle_seek`) and
/// end-of-life (`finalize`) notifications.
pipeline: WritebackPipeline,
/// Tracked stream position of `file`: initialized from `stream_position()`
/// and updated after every write and seek.
pos: u64,
}
impl WritebackFile {
    /// Wraps an already-open `file`.
    ///
    /// Writeback tracking starts at the file's current stream position. The
    /// chunk size comes from `writeback_chunk_bytes()`.
    ///
    /// # Errors
    /// Returns any error from querying the file's stream position.
    pub(crate) fn new(mut file: File) -> io::Result<Self> {
        let start = file.stream_position()?;
        Ok(Self {
            // Built before `file` is moved into the struct below.
            pipeline: WritebackPipeline::new(&file, start, writeback_chunk_bytes()),
            file,
            pos: start,
        })
    }

    /// Creates (or truncates) the file at `path` and wraps it.
    ///
    /// # Errors
    /// Propagates failures from `File::create` and [`WritebackFile::new`].
    // Only exercised by this file's tests as far as is visible here; the lint
    // allowance keeps non-test builds warning-free.
    #[allow(dead_code)]
    pub(crate) fn create(path: &Path) -> io::Result<Self> {
        File::create(path).and_then(Self::new)
    }

    /// Creates (or truncates) the file at `path` and asks the platform layer
    /// to preallocate `size_bytes` for it. It then wraps the file.
    ///
    /// # Errors
    /// Propagates failures from `File::create` and [`WritebackFile::new`].
    pub(crate) fn create_with_size_hint(path: &Path, size_bytes: u64) -> io::Result<Self> {
        let file = File::create(path)?;
        // NOTE(review): nothing from `preallocate` is propagated here; presumably
        // it is a best-effort hint — confirm against the platform module.
        platform::preallocate(&file, size_bytes);
        Self::new(file)
    }

    /// Opens an existing file at `path` for writing and wraps it.
    ///
    /// The file is not truncated, and it is not created if missing.
    ///
    /// # Errors
    /// Propagates failures from opening the file and from [`WritebackFile::new`].
    pub(crate) fn open(path: &Path) -> io::Result<Self> {
        OpenOptions::new().write(true).open(path).and_then(Self::new)
    }

    /// Finalizes the writeback pipeline, then performs the platform's durable
    /// sync of the underlying file.
    ///
    /// # Errors
    /// Propagates the error returned by `platform::durable_sync`.
    pub(crate) fn sync_all(&mut self) -> io::Result<()> {
        self.pipeline.finalize();
        platform::durable_sync(&self.file)
    }
}
impl Write for WritebackFile {
    /// Performs one underlying `write`, advances `pos` by the bytes accepted,
    /// and reports the new position to the pipeline.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let n = self.file.write(buf)?;
        self.pos += n as u64;
        self.pipeline.note_progress(self.pos);
        Ok(n)
    }

    /// Writes all of `buf`, keeping `pos` in step with the file after every
    /// partial write.
    ///
    /// Delegating to `File::write_all` would lose track of bytes that reached
    /// the file before a later chunk failed. `pos` would then lag the file's
    /// real offset, and `seek` compares against `pos` to decide whether the
    /// position moved. Progress is reported to the pipeline once per call,
    /// including any bytes written before an error.
    ///
    /// # Errors
    /// Returns the first non-`Interrupted` error from the underlying file.
    /// Returns `ErrorKind::WriteZero` if the file stops accepting bytes.
    fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
        let mut outcome = Ok(());
        while !buf.is_empty() {
            match self.file.write(buf) {
                Ok(0) => {
                    outcome = Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    ));
                    break;
                }
                Ok(n) => {
                    self.pos += n as u64;
                    buf = &buf[n..];
                }
                // Retry, matching the standard `write_all` contract.
                Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                Err(e) => {
                    outcome = Err(e);
                    break;
                }
            }
        }
        self.pipeline.note_progress(self.pos);
        outcome
    }

    /// Flushes the underlying file; the pipeline is not involved.
    fn flush(&mut self) -> io::Result<()> {
        self.file.flush()
    }
}
impl Seek for WritebackFile {
    /// Seeks the underlying file and returns the offset the file reports.
    ///
    /// If that offset differs from the tracked position, it also logs the jump
    /// and notifies the pipeline.
    fn seek(&mut self, from: SeekFrom) -> io::Result<u64> {
        let to_pos = self.file.seek(from)?;
        // A seek that lands where we already are leaves the pipeline alone.
        // This includes `stream_position()`, which seeks by `Current(0)`.
        if to_pos == self.pos {
            return Ok(to_pos);
        }
        let from_pos = self.pos;
        // Signed distance is for the log line only; `wrapping_sub` keeps the
        // subtraction from panicking on extreme offsets.
        let delta: i64 = (to_pos as i64).wrapping_sub(from_pos as i64);
        tracing::debug!(
            target: "mux",
            "WritebackFile seek from={from_pos} to={to_pos} delta={delta}"
        );
        self.pipeline.handle_seek(to_pos);
        self.pos = to_pos;
        Ok(to_pos)
    }
}
impl Drop for WritebackFile {
    /// Finalizes the writeback pipeline when the writer goes out of scope.
    ///
    /// This runs even if `sync_all` was never called.
    // NOTE(review): after an explicit `sync_all` this finalizes a second time;
    // presumably `finalize` is idempotent — confirm in the pipeline module.
    fn drop(&mut self) {
        let Self { pipeline, .. } = self;
        pipeline.finalize();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the full on-disk contents of `path`.
    fn read_back(path: &Path) -> Vec<u8> {
        std::fs::read(path).unwrap()
    }

    /// Dropping the writer without an explicit sync still leaves the bytes in the file.
    #[test]
    fn write_then_drop_persists_bytes() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("a.bin");
        let mut writer = WritebackFile::create(&path).unwrap();
        writer.write_all(b"hello world").unwrap();
        drop(writer);
        assert_eq!(read_back(&path), b"hello world");
    }

    /// After `sync_all`, every byte written so far is visible on disk.
    #[test]
    fn sync_all_drains_and_flushes() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("b.bin");
        let mut writer = WritebackFile::create(&path).unwrap();
        let chunk = [0x5au8; 1024];
        (0..32)
            .try_for_each(|_| writer.write_all(&chunk))
            .unwrap();
        writer.sync_all().unwrap();
        let contents = read_back(&path);
        assert_eq!(contents.len(), 32 * 1024);
        assert!(contents.iter().all(|&byte| byte == 0x5a));
        drop(writer);
    }

    /// Seeking backwards and overwriting patches only the targeted range.
    #[test]
    fn seek_then_patch_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("c.bin");
        let mut writer = WritebackFile::create(&path).unwrap();
        writer.write_all(&[b'A'; 4096]).unwrap();
        writer.seek(SeekFrom::Start(1000)).unwrap();
        writer.write_all(b"PATCHED!").unwrap();
        writer.sync_all().unwrap();
        drop(writer);
        let contents = read_back(&path);
        assert_eq!(contents.len(), 4096);
        assert_eq!(&contents[1000..1008], b"PATCHED!");
        assert_eq!(contents[999], b'A');
        assert_eq!(contents[1008], b'A');
    }

    /// Interleaved flushes don't reorder or drop data.
    #[test]
    fn flush_is_observed_in_order() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("f.bin");
        let mut writer = WritebackFile::create(&path).unwrap();
        for part in [&b"one"[..], &b"two"[..]] {
            writer.write_all(part).unwrap();
            writer.flush().unwrap();
        }
        writer.write_all(b"three").unwrap();
        writer.sync_all().unwrap();
        drop(writer);
        assert_eq!(read_back(&path), b"onetwothree");
    }
}