#[cfg(target_os = "linux")]
mod linux;
#[cfg(target_os = "macos")]
mod macos;
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
mod other;
#[cfg(target_os = "windows")]
mod windows;
#[cfg(target_os = "linux")]
use linux as platform;
#[cfg(target_os = "macos")]
use macos as platform;
#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
use other as platform;
#[cfg(target_os = "windows")]
use windows as platform;
use std::collections::VecDeque;
use std::fs::{File, OpenOptions};
use std::io::{self, Seek, SeekFrom, Write};
use std::path::Path;
use std::sync::mpsc::{SyncSender, sync_channel};
use std::sync::{Arc, Condvar, Mutex};
use std::thread::{self, JoinHandle};
use super::writeback::WritebackPipeline;
/// Chunk size, in bytes, passed to `WritebackPipeline::new` by the writer thread.
const WRITEBACK_CHUNK_BYTES: u64 = 32 * 1024 * 1024;
/// Budget of queued `Cmd::Write` payload bytes; producers block in
/// `push_command` while adding more would exceed it (unless nothing is queued).
const RING_CAPACITY_BYTES: usize = 128 * 1024 * 1024;
/// Capacity of the writer's `ActiveClusterBuffer`; defined as the writeback chunk size.
const ACTIVE_CLUSTER_WINDOW_BYTES: u64 = WRITEBACK_CHUNK_BYTES;
/// Name given to the background writer thread.
const WRITER_THREAD_NAME: &str = "freemkv-writeback-writer";
/// Largest payload carried by one `Cmd::Write`; bigger caller buffers are split.
const MAX_WRITE_CHUNK_BYTES: usize = RING_CAPACITY_BYTES;
/// Commands queued by `WritebackFile` and executed in order by the writer thread.
enum Cmd {
/// Write these bytes at the writer's current file position.
Write(Vec<u8>),
/// Reposition the file. `WritebackFile::seek` always sends `SeekFrom::Start`.
Seek(SeekFrom),
/// Ordering marker only; the writer performs no work for it.
Flush,
/// Finalize the pipeline, durably sync the file, and report the result on `done`.
SyncAll { done: SyncSender<io::Result<()>> },
/// Finalize the pipeline, mark the writer gone, signal `done`, and stop the thread.
Finish { done: SyncSender<()> },
}
/// Mutable queue state shared between the producer and the writer thread;
/// always accessed through `Shared::state`.
struct RingState {
/// Pending commands in submission order.
queue: VecDeque<Cmd>,
/// Sum of the byte charges of queued commands (only writes carry a charge).
bytes_inflight: usize,
/// Kind of the first I/O error the writer hit; once set, every later
/// `push_command` fails with it.
sticky_error: Option<io::ErrorKind>,
/// Set when the writer thread has stopped taking commands.
writer_gone: bool,
}
/// Synchronization bundle shared (via `Arc`) by `WritebackFile` and the writer thread.
struct Shared {
/// Queue and status flags.
state: Mutex<RingState>,
/// Producers wait on this for queue budget; the writer notifies it after each
/// dequeue and whenever it publishes an error or marks itself gone.
space_available: Condvar,
/// The writer waits on this for commands; producers notify it after each push.
work_available: Condvar,
}
impl Shared {
fn new() -> Self {
Self {
state: Mutex::new(RingState {
queue: VecDeque::new(),
bytes_inflight: 0,
sticky_error: None,
writer_gone: false,
}),
space_available: Condvar::new(),
work_available: Condvar::new(),
}
}
}
/// Sliding record of the most recently written contiguous byte range.
///
/// The buffer mirrors file offsets `[lo, lo + data.len())` and retains at
/// most `cap` bytes; when it overflows, the oldest bytes are dropped from the
/// front and `lo` advances accordingly.
struct ActiveClusterBuffer {
    /// Maximum number of bytes retained.
    cap: u64,
    /// File offset corresponding to `data[0]`.
    lo: u64,
    /// Retained bytes, lowest offset first.
    data: VecDeque<u8>,
}

impl ActiveClusterBuffer {
    /// Creates an empty window anchored at offset 0 that retains up to `cap` bytes.
    fn new(cap: u64) -> Self {
        Self {
            cap,
            lo: 0,
            data: VecDeque::with_capacity(cap as usize),
        }
    }

    /// Discards all retained bytes and re-anchors the (now empty) window at `new_lo`.
    fn reset(&mut self, new_lo: u64) {
        self.data.clear();
        self.lo = new_lo;
    }

    /// Offset one past the last retained byte (equals `lo` when empty).
    fn hi(&self) -> u64 {
        self.lo + self.data.len() as u64
    }

    /// Returns whether `pos` lies in `[lo, hi]`.
    ///
    /// The upper bound is inclusive: the append position itself counts as
    /// inside the window.
    fn contains(&self, pos: u64) -> bool {
        pos >= self.lo && pos <= self.hi()
    }

    /// Records that `bytes` were written at file offset `start`.
    ///
    /// * `start` inside `[lo, hi]`: the overlapping part overwrites retained
    ///   bytes in place and any remainder is appended (a plain append is the
    ///   `start == hi` case with zero overlap).
    /// * otherwise: the window is discarded and re-anchored at `start`.
    ///
    /// Afterwards the window is trimmed from the front to at most `cap` bytes.
    fn push(&mut self, start: u64, bytes: &[u8]) {
        if self.contains(start) {
            let offset = (start - self.lo) as usize;
            // Number of incoming bytes that land on already-retained bytes.
            let overlap = bytes.len().min(self.data.len() - offset);
            for (dst, &src) in self
                .data
                .range_mut(offset..offset + overlap)
                .zip(bytes)
            {
                *dst = src;
            }
            self.data.extend(bytes[overlap..].iter().copied());
        } else {
            self.data.clear();
            self.lo = start;
            self.data.extend(bytes.iter().copied());
        }

        // Trim the oldest bytes in one step instead of popping them one by one.
        let len = self.data.len() as u64;
        if len > self.cap {
            // `excess <= len`, which came from a `usize`, so the cast is lossless.
            let excess = (len - self.cap) as usize;
            self.data.drain(..excess);
            self.lo += excess as u64;
        }
    }
}
/// Producer-side handle: `Write`/`Seek` calls are turned into commands that a
/// dedicated writer thread applies to the underlying file.
pub(crate) struct WritebackFile {
/// Queue and condition variables shared with the writer thread.
shared: Arc<Shared>,
/// Join handle of the writer thread; taken (and joined) in `Drop`.
writer: Option<JoinHandle<()>>,
/// Logical position as seen by the caller, i.e. after all queued commands
/// have been applied.
muxer_pos: u64,
}
impl WritebackFile {
pub(crate) fn new(mut file: File) -> io::Result<Self> {
let pos = file.stream_position()?;
Ok(Self::spawn(file, pos))
}
#[allow(dead_code)]
pub(crate) fn create(path: &Path) -> io::Result<Self> {
let file = File::create(path)?;
Self::new(file)
}
pub(crate) fn create_with_size_hint(path: &Path, size_bytes: u64) -> io::Result<Self> {
let file = File::create(path)?;
platform::preallocate(&file, size_bytes);
Self::new(file)
}
pub(crate) fn open(path: &Path) -> io::Result<Self> {
let file = OpenOptions::new().write(true).open(path)?;
Self::new(file)
}
fn spawn(file: File, start_pos: u64) -> Self {
let shared = Arc::new(Shared::new());
let shared_w = Arc::clone(&shared);
let writer = thread::Builder::new()
.name(WRITER_THREAD_NAME.into())
.spawn(move || {
writer_thread_main(file, start_pos, shared_w);
})
.expect("writer thread spawn");
Self {
shared,
writer: Some(writer),
muxer_pos: start_pos,
}
}
pub(crate) fn sync_all(&mut self) -> io::Result<()> {
let (tx, rx) = sync_channel::<io::Result<()>>(0);
self.push_command(Cmd::SyncAll { done: tx }, 0)?;
match rx.recv() {
Ok(r) => r,
Err(_) => {
Err(io::Error::from(io::ErrorKind::BrokenPipe))
}
}
}
fn push_command(&mut self, cmd: Cmd, bytes_charge: usize) -> io::Result<()> {
let mut guard = self.shared.state.lock().unwrap();
if let Some(kind) = guard.sticky_error {
return Err(io::Error::from(kind));
}
if guard.writer_gone {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
while bytes_charge > 0
&& guard.bytes_inflight + bytes_charge > RING_CAPACITY_BYTES
&& guard.bytes_inflight > 0
{
guard = self.shared.space_available.wait(guard).unwrap();
if let Some(kind) = guard.sticky_error {
return Err(io::Error::from(kind));
}
if guard.writer_gone {
return Err(io::Error::from(io::ErrorKind::BrokenPipe));
}
}
guard.queue.push_back(cmd);
guard.bytes_inflight += bytes_charge;
drop(guard);
self.shared.work_available.notify_one();
Ok(())
}
}
impl Write for WritebackFile {
    /// Copies `buf` into one or more `Cmd::Write` commands (each at most
    /// `MAX_WRITE_CHUNK_BYTES` long) and queues them. The whole buffer is
    /// always accepted, so this returns `buf.len()` on success.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        // An empty buffer yields no pieces, so nothing is queued for it.
        for piece in buf.chunks(MAX_WRITE_CHUNK_BYTES) {
            self.push_command(Cmd::Write(piece.to_vec()), piece.len())?;
        }
        let accepted = buf.len();
        self.muxer_pos += accepted as u64;
        Ok(accepted)
    }

    /// Same queueing as [`write`](Self::write); since `write` never accepts a
    /// partial buffer, no retry loop is needed.
    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
        self.write(buf).map(|_| ())
    }

    /// Queues a `Cmd::Flush` marker; does not wait for the writer.
    fn flush(&mut self) -> io::Result<()> {
        self.push_command(Cmd::Flush, 0)
    }
}
impl Seek for WritebackFile {
    /// Resolves `from` against the caller-visible position and queues an
    /// absolute seek for the writer thread.
    ///
    /// # Errors
    /// * `InvalidInput` if a relative seek overflows or lands before offset 0.
    /// * `Unsupported` for `SeekFrom::End`.
    /// * Any error from queueing the command.
    fn seek(&mut self, from: SeekFrom) -> io::Result<u64> {
        let target = match from {
            SeekFrom::Start(absolute) => absolute,
            SeekFrom::Current(delta) => {
                let moved = (self.muxer_pos as i64)
                    .checked_add(delta)
                    .filter(|candidate| *candidate >= 0)
                    .ok_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))?;
                moved as u64
            }
            SeekFrom::End(_) => return Err(io::Error::from(io::ErrorKind::Unsupported)),
        };
        // The writer only ever receives absolute offsets.
        self.push_command(Cmd::Seek(SeekFrom::Start(target)), 0)?;
        self.muxer_pos = target;
        Ok(target)
    }
}
impl Drop for WritebackFile {
    /// Asks the writer thread to finish and waits for it to exit.
    ///
    /// If the writer thread panicked, the panic is logged and then re-raised
    /// on this thread unless this thread is already unwinding.
    fn drop(&mut self) {
        let (tx, rx) = sync_channel::<()>(0);
        {
            // A poisoned lock must not stop shutdown; use the state as-is.
            let mut guard = self
                .shared
                .state
                .lock()
                .unwrap_or_else(|poison| poison.into_inner());
            if !guard.writer_gone {
                guard.queue.push_back(Cmd::Finish { done: tx });
                self.shared.work_available.notify_one();
            }
        }
        // Do not block on `rx`: the queued sender lives inside `self.shared`,
        // so if the writer thread has already died the channel would never
        // disconnect and `recv()` would hang forever. Dropping the receiver
        // makes the writer's rendezvous `send` return immediately (its result
        // is ignored there), and `join()` below already waits until the
        // writer has handled `Finish` and exited.
        drop(rx);
        if let Some(jh) = self.writer.take() {
            if let Err(panic) = jh.join() {
                tracing::error!(
                    target: "mux",
                    "WritebackFile writer thread panicked during Drop; data may be lost"
                );
                // Re-raising while already unwinding would abort the process.
                if !std::thread::panicking() {
                    std::panic::resume_unwind(panic);
                }
            }
        }
    }
}
/// Entry point of the writer thread: builds the writer state around `file`
/// and runs the command loop until it stops.
///
/// However the loop ends — including by panic — the shared state is marked
/// `writer_gone`, any still-queued commands are discarded, and blocked
/// producers are woken.
fn writer_thread_main(file: File, start_pos: u64, shared: Arc<Shared>) {
    /// Runs the exit bookkeeping when dropped, so it also happens on unwind.
    struct ExitGuard(Arc<Shared>);

    impl Drop for ExitGuard {
        fn drop(&mut self) {
            let mut state = self
                .0
                .state
                .lock()
                .unwrap_or_else(|poison| poison.into_inner());
            state.writer_gone = true;
            // Nothing will ever execute the remaining commands. Dropping them
            // also drops any reply senders they carry, so a thread blocked on
            // the matching receiver observes a disconnect instead of hanging.
            let orphaned = std::mem::take(&mut state.queue);
            state.bytes_inflight = 0;
            drop(state);
            drop(orphaned);
            self.0.space_available.notify_all();
        }
    }

    // Declared first so it is dropped last, after the writer state.
    let _exit_guard = ExitGuard(Arc::clone(&shared));

    let pipeline = WritebackPipeline::new(&file, start_pos, WRITEBACK_CHUNK_BYTES);
    let mut state = WriterState {
        file,
        pipeline,
        pos: start_pos,
        active: ActiveClusterBuffer::new(ACTIVE_CLUSTER_WINDOW_BYTES),
        shared: Arc::clone(&shared),
    };
    state.run();
}
/// State owned by the writer thread.
struct WriterState {
/// The file being written.
file: File,
/// Writeback helper that is told about write progress, seeks, and finalization.
pipeline: WritebackPipeline,
/// File position the writer believes the file cursor is at.
pos: u64,
/// Window of recently written bytes, used to classify seeks as in- or out-of-window.
active: ActiveClusterBuffer,
/// Queue and condition variables shared with the producer.
shared: Arc<Shared>,
}
impl WriterState {
    /// Command loop: executes queued commands in order until `Finish`.
    ///
    /// Write and seek failures do not stop the loop; their error kind is
    /// published so the producer sees it on its next call.
    fn run(&mut self) {
        loop {
            let cmd = match self.dequeue() {
                Some(c) => c,
                None => {
                    self.mark_writer_gone();
                    return;
                }
            };
            match cmd {
                Cmd::Write(buf) => {
                    if let Err(e) = self.do_write(&buf) {
                        self.publish_error(e.kind());
                    }
                }
                Cmd::Seek(from) => {
                    if let Err(e) = self.do_seek(from) {
                        self.publish_error(e.kind());
                    }
                }
                // Writes go straight to the file, so there is nothing to flush.
                Cmd::Flush => {}
                Cmd::SyncAll { done } => {
                    let r = self.do_sync_all();
                    // The requester may have gone away; that is not an error here.
                    let _ = done.send(r);
                }
                Cmd::Finish { done } => {
                    self.pipeline.finalize();
                    self.mark_writer_gone();
                    let _ = done.send(());
                    return;
                }
            }
        }
    }

    /// Blocks until a command is available and returns it.
    ///
    /// Releases the bytes charged for a dequeued `Write` and wakes producers
    /// waiting for budget. As written this always returns `Some`.
    fn dequeue(&mut self) -> Option<Cmd> {
        let mut guard = self.shared.state.lock().unwrap();
        loop {
            if let Some(cmd) = guard.queue.pop_front() {
                if let Cmd::Write(ref buf) = cmd {
                    guard.bytes_inflight = guard.bytes_inflight.saturating_sub(buf.len());
                }
                drop(guard);
                self.shared.space_available.notify_all();
                return Some(cmd);
            }
            guard = self.shared.work_available.wait(guard).unwrap();
        }
    }

    /// Writes `buf` at the current position, then advances `pos`, reports
    /// progress to the pipeline, and records the bytes in the active window.
    fn do_write(&mut self, buf: &[u8]) -> io::Result<()> {
        let start = self.pos;
        self.file.write_all(buf)?;
        self.pos += buf.len() as u64;
        self.pipeline.note_progress(self.pos);
        self.active.push(start, buf);
        Ok(())
    }

    /// Repositions the file.
    ///
    /// * Relative/end seeks are passed to the file; the window is reset and
    ///   the pipeline told about the resulting position.
    /// * An absolute seek to the current position is a no-op.
    /// * An absolute seek inside the active window moves the cursor only.
    /// * Any other absolute seek notifies the pipeline first, then moves the
    ///   cursor and resets the window to the target.
    fn do_seek(&mut self, from: SeekFrom) -> io::Result<()> {
        let target = match from {
            SeekFrom::Start(n) => n,
            SeekFrom::Current(_) | SeekFrom::End(_) => {
                let p = self.file.seek(from)?;
                self.pos = p;
                self.active.reset(p);
                self.pipeline.handle_seek(p);
                return Ok(());
            }
        };
        if target == self.pos {
            return Ok(());
        }
        if self.active.contains(target) {
            tracing::trace!(
                target: "mux",
                "WritebackFile in-window seek pos={} -> {} window=[{},{}]",
                self.pos,
                target,
                self.active.lo,
                self.active.hi(),
            );
            self.file.seek(SeekFrom::Start(target))?;
            self.pos = target;
        } else {
            tracing::debug!(
                target: "mux",
                "WritebackFile out-of-window seek pos={} -> {} window=[{},{}]",
                self.pos,
                target,
                self.active.lo,
                self.active.hi(),
            );
            self.pipeline.handle_seek(target);
            self.file.seek(SeekFrom::Start(target))?;
            self.pos = target;
            self.active.reset(target);
        }
        Ok(())
    }

    /// Finalizes the pipeline and durably syncs the file.
    ///
    /// # Errors
    /// If an earlier write or seek failed (a sticky error is recorded), that
    /// error is returned even when the sync itself succeeds: the barrier must
    /// not report success for data that never reached the file. Otherwise
    /// returns the result of the platform sync.
    fn do_sync_all(&mut self) -> io::Result<()> {
        self.pipeline.finalize();
        let synced = platform::durable_sync(&self.file);
        // Read after the sync so every command ordered before this barrier
        // has already had the chance to publish its failure.
        let earlier_failure = self.shared.state.lock().unwrap().sticky_error;
        match earlier_failure {
            Some(kind) => Err(io::Error::from(kind)),
            None => synced,
        }
    }

    /// Records `kind` as the sticky error (the first one wins) and wakes
    /// producers blocked on queue budget so they can observe it.
    fn publish_error(&self, kind: io::ErrorKind) {
        let mut guard = self.shared.state.lock().unwrap();
        if guard.sticky_error.is_none() {
            guard.sticky_error = Some(kind);
        }
        drop(guard);
        self.shared.space_available.notify_all();
    }

    /// Flags the writer as stopped and wakes producers blocked on queue budget.
    fn mark_writer_gone(&self) {
        let mut guard = self.shared.state.lock().unwrap();
        guard.writer_gone = true;
        drop(guard);
        self.shared.space_available.notify_all();
    }
}
// Unit tests for `WritebackFile` (against real temp files) and for
// `ActiveClusterBuffer` (pure in-memory checks).
#[cfg(test)]
mod tests {
use super::*;
use std::io::Read;
// Reads the whole file at `path` into memory for comparison.
fn read_back(path: &Path) -> Vec<u8> {
let mut f = File::open(path).unwrap();
let mut v = Vec::new();
f.read_to_end(&mut v).unwrap();
v
}
// Dropping the handle without an explicit sync must still get the bytes to the file.
#[test]
fn write_then_drop_persists_bytes() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("a.bin");
{
let mut w = WritebackFile::create(&p).unwrap();
w.write_all(b"hello world").unwrap();
}
assert_eq!(read_back(&p), b"hello world");
}
// After `sync_all` returns, every previously queued write is visible in the file.
#[test]
fn sync_all_blocks_until_ring_drains() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("b.bin");
let mut w = WritebackFile::create(&p).unwrap();
for _ in 0..32 {
w.write_all(&[0x5au8; 1024]).unwrap();
}
w.sync_all().unwrap();
let bytes = read_back(&p);
assert_eq!(bytes.len(), 32 * 1024);
assert!(bytes.iter().all(|&b| b == 0x5a));
drop(w);
}
// Seek back within the active window, overwrite a few bytes, and check the
// neighbours are untouched.
#[test]
fn in_window_seek_then_patch_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("c.bin");
let mut w = WritebackFile::create(&p).unwrap();
let big = vec![b'A'; 4096];
w.write_all(&big).unwrap();
w.seek(SeekFrom::Start(1000)).unwrap();
w.write_all(b"PATCHED!").unwrap();
w.sync_all().unwrap();
drop(w);
let bytes = read_back(&p);
assert_eq!(bytes.len(), 4096);
assert_eq!(&bytes[1000..1008], b"PATCHED!");
assert_eq!(bytes[999], b'A');
assert_eq!(bytes[1008], b'A');
}
// Write 33 MiB (more than the 32 MiB window) so that offset 100 has fallen
// out of the active window, then patch there.
#[test]
fn out_of_window_seek_then_patch_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("d.bin");
let mut w = WritebackFile::create(&p).unwrap();
let chunk = vec![b'A'; 1024 * 1024];
for _ in 0..33 {
w.write_all(&chunk).unwrap();
}
w.seek(SeekFrom::Start(100)).unwrap();
w.write_all(b"OUTSIDE!").unwrap();
w.sync_all().unwrap();
drop(w);
let bytes = read_back(&p);
assert_eq!(bytes.len(), 33 * 1024 * 1024);
assert_eq!(&bytes[100..108], b"OUTSIDE!");
assert_eq!(bytes[99], b'A');
assert_eq!(bytes[108], b'A');
}
// Push 2.5x the ring capacity through 1 MiB writes; the test only checks that
// all bytes arrive (it does not observe the producer actually blocking).
#[test]
fn backpressure_blocks_when_ring_full() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("e.bin");
let mut w = WritebackFile::create(&p).unwrap();
let total = RING_CAPACITY_BYTES.saturating_mul(2) + (RING_CAPACITY_BYTES / 2);
let chunk = vec![0u8; 1024 * 1024];
let mut written = 0;
while written < total {
let take = (total - written).min(chunk.len());
w.write_all(&chunk[..take]).unwrap();
written += take;
}
w.sync_all().unwrap();
drop(w);
let meta = std::fs::metadata(&p).unwrap();
assert_eq!(meta.len() as usize, total);
}
// Flush markers interleaved with writes must not reorder or drop data.
#[test]
fn flush_is_observed_in_order() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("f.bin");
let mut w = WritebackFile::create(&p).unwrap();
w.write_all(b"one").unwrap();
w.flush().unwrap();
w.write_all(b"two").unwrap();
w.flush().unwrap();
w.write_all(b"three").unwrap();
w.sync_all().unwrap();
drop(w);
assert_eq!(read_back(&p), b"onetwothree");
}
// Appends grow the window until `cap`, after which the front is trimmed;
// `contains` is inclusive at both ends.
#[test]
fn active_cluster_buffer_contiguous_append_and_trim() {
let mut b = ActiveClusterBuffer::new(8);
b.push(0, b"abcd");
assert_eq!(b.lo, 0);
assert_eq!(b.hi(), 4);
b.push(4, b"efgh");
assert_eq!(b.lo, 0);
assert_eq!(b.hi(), 8);
b.push(8, b"ij");
assert_eq!(b.lo, 2);
assert_eq!(b.hi(), 10);
assert!(b.contains(2));
assert!(b.contains(10));
assert!(!b.contains(1));
assert!(!b.contains(11));
}
// A push that starts inside the window overwrites in place without moving the bounds.
#[test]
fn active_cluster_buffer_in_window_patch() {
let mut b = ActiveClusterBuffer::new(16);
b.push(100, b"AAAAAAAA");
assert!(b.contains(104));
b.push(102, b"BB");
let collected: Vec<u8> = b.data.iter().copied().collect();
assert_eq!(collected, b"AABBAAAA");
assert_eq!(b.lo, 100);
assert_eq!(b.hi(), 108);
}
// A push outside the window discards it and re-anchors at the new offset.
#[test]
fn active_cluster_buffer_non_contiguous_reseats() {
let mut b = ActiveClusterBuffer::new(16);
b.push(0, b"abcd");
b.push(1000, b"XYZ");
assert_eq!(b.lo, 1000);
assert_eq!(b.hi(), 1003);
}
// NOTE(review): despite its name, this test does not provoke a writer-thread
// panic. It opens the file read-only so the writer's write fails, and checks
// that the failure reaches the producer as an error from a later `write_all`.
// The sleep gives the writer time to process the queued byte and publish.
#[test]
fn writer_thread_panic_surfaces_on_drop() {
let dir = tempfile::tempdir().unwrap();
let p = dir.path().join("ro.bin");
std::fs::write(&p, b"seed").unwrap();
let f = OpenOptions::new().read(true).open(&p).unwrap();
let mut w = WritebackFile::new(f).unwrap();
let mut saw_error = false;
for _ in 0..32 {
match w.write_all(b"x") {
Ok(()) => {
std::thread::sleep(std::time::Duration::from_millis(10));
}
Err(_) => {
saw_error = true;
break;
}
}
}
assert!(
saw_error,
"expected the writer to publish an error on a read-only fd"
);
drop(w);
}
}