use std::cmp::min;
use std::fs::File;
use std::io::{self, Seek, SeekFrom, Write};
use fallocate;
use FallocateMode;
/// An object whose byte ranges can be deallocated ("hole punched").
///
/// The blanket [`WriteZeroes`] implementation in this module treats a
/// successful punch as having zeroed the range, and treats any error as a cue
/// to fall back to writing zero bytes explicitly.
pub trait PunchHole {
    /// Deallocates the `length` bytes that begin at byte `offset`.
    ///
    /// # Errors
    ///
    /// Returns an error if the range could not be deallocated.
    fn punch_hole(&mut self, offset: u64, length: u64) -> io::Result<()>;
}
impl PunchHole for File {
    /// Deallocates `length` bytes starting at `offset` by calling this crate's
    /// `fallocate` wrapper with `FallocateMode::PunchHole`.
    ///
    /// NOTE(review): the `true` argument presumably requests that the file
    /// size be kept unchanged — confirm against `fallocate`'s signature.
    ///
    /// # Errors
    ///
    /// Any `fallocate` failure is converted into an `io::Error` carrying the
    /// same raw OS error number (`errno`).
    fn punch_hole(&mut self, offset: u64, length: u64) -> io::Result<()> {
        fallocate(self, FallocateMode::PunchHole, true, offset, length)
            .map_err(|e| io::Error::from_raw_os_error(e.errno()))
    }
}
/// Writing a run of zero bytes at the current position of a stream.
pub trait WriteZeroes {
    /// Writes `length` zero bytes at the current position of `self`.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Errors
    ///
    /// Returns any I/O error encountered while producing the zeroes.
    fn write_zeroes(&mut self, length: usize) -> io::Result<usize>;
}
impl<T: PunchHole + Seek + Write> WriteZeroes for T {
    /// Writes `length` zero bytes at the current position and advances the
    /// position by `length`.
    ///
    /// A hole is punched first. If `punch_hole` fails for any reason, the
    /// zeroes are written explicitly from a zero-filled buffer of at most
    /// 64 KiB instead.
    ///
    /// Returns `Ok(length)` on success; a zero-length request returns `Ok(0)`
    /// without touching `self`.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` if `length` exceeds `i64::MAX` and so cannot be used
    ///   as a relative seek offset.
    /// * `WriteZero` if, on the fallback path, the writer stops accepting
    ///   bytes (`write` returns `Ok(0)`).
    /// * Any other error returned by `seek` or `write`.
    fn write_zeroes(&mut self, length: usize) -> io::Result<usize> {
        // Upper bound on the zero buffer allocated by the fallback path.
        const MAX_BUF_SIZE: usize = 0x10000;

        if length == 0 {
            return Ok(0);
        }

        // `usize` is at most 64 bits wide, so this cast is negative exactly
        // when `length > i64::MAX`; such a length would otherwise wrap into a
        // backwards seek below.
        let seek_len = length as i64;
        if seek_len < 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "write_zeroes length too large",
            ));
        }

        let offset = self.seek(SeekFrom::Current(0))?;
        // A punch_hole error is deliberately not propagated: it only means we
        // must fall back to writing the zeroes explicitly below.
        if self.punch_hole(offset, length as u64).is_ok() {
            // Leave the position just past the zeroed range, as a write would.
            self.seek(SeekFrom::Current(seek_len))?;
            return Ok(length);
        }

        let buf = vec![0u8; min(length, MAX_BUF_SIZE)];
        let mut remaining = length;
        while remaining > 0 {
            let chunk = min(remaining, buf.len());
            // write_all retries on `Interrupted` and reports a writer that
            // accepts nothing (`Ok(0)`) as `WriteZero` instead of spinning.
            self.write_all(&buf[..chunk])?;
            remaining -= chunk;
        }
        Ok(length)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::cmp::min;
    use std::fs::{File, OpenOptions};
    use std::io::{self, Read, Seek, SeekFrom, Write};
    use std::path::PathBuf;
    use TempDir;

    /// Asserts that every byte of `bytes` equals `expected`, reporting the
    /// offset (within `bytes`) of the first mismatch.
    fn assert_all(bytes: &[u8], expected: u8) {
        if let Some(i) = bytes.iter().position(|&b| b != expected) {
            panic!("byte {} is {:#x}, expected {:#x}", i, bytes[i], expected);
        }
    }

    /// Creates `file` in a fresh temporary directory, opened read/write and
    /// sized to `len` bytes. The `TempDir` is returned alongside the file so
    /// it is not dropped while the file is still in use.
    fn open_test_file(len: u64) -> (TempDir, File) {
        let tempdir = TempDir::new("/tmp/write_zeroes_test").unwrap();
        let mut path = PathBuf::from(tempdir.as_path().unwrap());
        path.push("file");
        let f = OpenOptions::new()
            .read(true)
            .write(true)
            .create(true)
            .open(&path)
            .unwrap();
        f.set_len(len).unwrap();
        (tempdir, f)
    }

    /// In-memory stream whose `punch_hole` always fails and whose `write`
    /// accepts at most `max_write` bytes per call, forcing `write_zeroes`
    /// onto its explicit-write fallback with short writes.
    struct NoHoleStream {
        data: Vec<u8>,
        pos: usize,
        max_write: usize,
    }

    impl PunchHole for NoHoleStream {
        fn punch_hole(&mut self, _offset: u64, _length: u64) -> io::Result<()> {
            Err(io::Error::new(io::ErrorKind::Other, "punch_hole unsupported"))
        }
    }

    impl Write for NoHoleStream {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            let n = min(buf.len(), self.max_write);
            let end = self.pos + n;
            if self.data.len() < end {
                self.data.resize(end, 0);
            }
            self.data[self.pos..end].copy_from_slice(&buf[..n]);
            self.pos = end;
            Ok(n)
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    impl Seek for NoHoleStream {
        fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
            let new_pos = match pos {
                SeekFrom::Start(p) => p as i64,
                SeekFrom::Current(d) => self.pos as i64 + d,
                SeekFrom::End(d) => self.data.len() as i64 + d,
            };
            if new_pos < 0 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "seek before start",
                ));
            }
            self.pos = new_pos as usize;
            Ok(self.pos as u64)
        }
    }

    #[test]
    fn simple_test() {
        let (_tempdir, mut f) = open_test_file(16384);

        // Fill [1234, 1234 + 5678) with a non-zero pattern.
        let orig_data = [0x55u8; 5678];
        f.seek(SeekFrom::Start(1234)).unwrap();
        f.write_all(&orig_data).unwrap();

        let mut readback = [0u8; 16384];
        f.seek(SeekFrom::Start(0)).unwrap();
        f.read_exact(&mut readback).unwrap();
        assert_all(&readback[0..1234], 0);
        assert_all(&readback[1234..(1234 + 5678)], 0x55);
        assert_all(&readback[(1234 + 5678)..], 0);

        // Zero a range lying strictly inside the pattern.
        f.seek(SeekFrom::Start(2345)).unwrap();
        f.write_zeroes(4321).expect("write_zeroes failed");
        assert_eq!(f.seek(SeekFrom::Current(0)).unwrap(), 2345 + 4321);

        f.seek(SeekFrom::Start(0)).unwrap();
        f.read_exact(&mut readback).unwrap();
        assert_all(&readback[0..1234], 0);
        assert_all(&readback[1234..2345], 0x55);
        assert_all(&readback[2345..(2345 + 4321)], 0);
        assert_all(&readback[(2345 + 4321)..(1234 + 5678)], 0x55);
        assert_all(&readback[(1234 + 5678)..], 0);
    }

    #[test]
    fn large_write_zeroes() {
        let (_tempdir, mut f) = open_test_file(16384);

        // Writing 0x20000 bytes from offset 0 grows the file past its
        // initial 16384-byte length.
        let orig_data = vec![0x55u8; 0x20000];
        f.seek(SeekFrom::Start(0)).unwrap();
        f.write_all(&orig_data).unwrap();

        // 0x10001 exceeds the 0x10000-byte zero buffer that write_zeroes uses
        // when punching a hole fails.
        f.seek(SeekFrom::Start(0)).unwrap();
        f.write_zeroes(0x10001).expect("write_zeroes failed");
        assert_eq!(f.seek(SeekFrom::Current(0)).unwrap(), 0x10001);

        let mut readback = vec![0u8; 0x20000];
        f.seek(SeekFrom::Start(0)).unwrap();
        f.read_exact(&mut readback).unwrap();
        assert_all(&readback[0..0x10001], 0);
        assert_all(&readback[0x10001..0x20000], 0x55);
    }

    #[test]
    fn fallback_short_writes() {
        let mut s = NoHoleStream {
            data: vec![0x55u8; 100],
            pos: 10,
            max_write: 7,
        };
        let written = s.write_zeroes(50).expect("write_zeroes failed");
        assert_eq!(written, 50);
        assert_eq!(s.pos, 60);
        assert_all(&s.data[..10], 0x55);
        assert_all(&s.data[10..60], 0);
        assert_all(&s.data[60..], 0x55);
    }
}