use std::{
collections::BTreeSet,
io::{
Seek,
SeekFrom,
Write,
},
path::Path,
};
use buffered_reader::{BufferedReader};
/// Size in bytes of a single value stored in the set.
const VALUE_BYTES: usize = 32;
/// A member of the set: a fixed-size byte string of `VALUE_BYTES` bytes.
pub type Value = [u8; VALUE_BYTES];
/// Reader type through which the stored part of the set is accessed.
type File = buffered_reader::File<'static, ()>;
/// A set of fixed-size values made of two parts: values read from a
/// file, and values inserted in memory since the file was read.
///
/// The stored part is looked up with a binary search, so it is
/// expected to be sorted; `write` merges both parts into a new file.
pub struct Set {
/// Header read from the file (or a fresh one if there was no file);
/// `header.entries` is the number of values in `store`.
header: Header,
/// Reader positioned after the header; the stored values are read
/// from here.
store: File,
/// Values inserted in memory that have not been written out yet.
scratch: BTreeSet<Value>,
}
impl Set {
#[allow(dead_code)]
fn len(&self) -> usize {
usize::try_from(self.header.entries).expect("representable")
+ self.scratch.len()
}
pub fn contains(&mut self, value: &Value) -> Result<bool> {
Ok(self.stored_values()?.binary_search(value).is_ok()
|| self.scratch.contains(value))
}
pub fn insert(&mut self, value: Value) {
self.scratch.insert(value);
}
fn stored_values(&mut self) -> Result<&[Value]> {
let entries = self.header.entries as usize;
let bytes = self.store.data_hard(entries * VALUE_BYTES)?;
unsafe {
Ok(std::slice::from_raw_parts(bytes.as_ptr() as *const Value,
entries))
}
}
pub fn read<P: AsRef<Path>>(path: P, context: &str) -> Result<Self> {
assert_eq!(VALUE_BYTES, std::mem::size_of::<Value>());
assert_eq!(std::mem::size_of::<[Value; 2]>(),
2 * VALUE_BYTES,
"values are unpadded");
let context: [u8; CONTEXT_BYTES] = context.as_bytes()
.try_into()
.map_err(|_| Error::BadContext)?;
let (header, reader) = match File::open(path) {
Ok(mut f) => {
let header = Header::read(&mut f, context)?;
(header, f)
},
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
let t = tempfile::NamedTempFile::new()?;
let f = File::open(t.path())?;
(Header::new(context), f)
},
Err(e) => return Err(e.into()),
};
Ok(Set {
header,
store: reader,
scratch: Default::default(),
})
}
pub fn write<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
if self.scratch.is_empty() {
return Ok(());
}
let mut sink = tempfile::NamedTempFile::new_in(
path.as_ref().parent().ok_or(Error::BadPath)?)?;
let mut h = self.header.clone();
h.entries = 0; h.write(&mut sink)?;
let mut entries = 0;
let scratch = std::mem::replace(&mut self.scratch, Default::default());
let mut stored = self.stored_values()?;
for new in scratch.iter() {
let p = stored.partition_point(|v| v < new);
let before = &stored[..p];
let before_bytes = unsafe {
std::slice::from_raw_parts(before.as_ptr() as *const u8,
before.len() * VALUE_BYTES)
};
sink.write_all(before_bytes)?;
entries += p;
if before.is_empty() || &before[p - 1] != new {
sink.write_all(new)?;
entries += 1;
}
stored = &stored[p..];
}
{
let stored_bytes = unsafe {
std::slice::from_raw_parts(stored.as_ptr() as *const u8,
stored.len() * VALUE_BYTES)
};
sink.write_all(stored_bytes)?;
entries += stored.len();
}
self.scratch = scratch;
sink.as_file_mut().seek(SeekFrom::Start(0))?;
h.entries = entries.try_into().map_err(|_| Error::TooManyEntries)?;
h.write(&mut sink)?;
sink.flush()?;
sink.persist(path).map_err(|pe| pe.error)?;
Ok(())
}
}
/// Length in bytes of the context field in the file header.
const CONTEXT_BYTES: usize = 12;
/// The header at the start of a stored set.
///
/// On disk it is preceded by the magic bytes and consists of the
/// version, the context, and the entry count as a big-endian `u32`.
#[derive(Debug, Clone)]
struct Header {
/// Format version; only version 1 is accepted when reading.
version: u8,
/// Context bytes; a file is only accepted if they match the context
/// the caller asked for.
context: [u8; CONTEXT_BYTES],
/// Number of values that follow the header.
entries: u32,
}
impl Header {
    /// Magic bytes identifying the file format.
    const MAGIC: &'static [u8; 15] = b"StoredSortedSet";

    /// Creates a version 1 header for an empty set with the given
    /// context.
    fn new(context: [u8; CONTEXT_BYTES]) -> Self {
        Self { version: 1, context, entries: 0 }
    }

    /// Parses a header from `reader`, consuming it.
    ///
    /// The magic, the version, and the context are checked in that
    /// order; the context has to equal `context`.
    ///
    /// # Errors
    ///
    /// Returns [`Error::BadMagic`], [`Error::UnsupportedVersion`], or
    /// [`Error::BadContext`] if the respective check fails, and
    /// [`Error::Io`] if the reader runs out of data.
    fn read(reader: &mut File, context: [u8; CONTEXT_BYTES]) -> Result<Self> {
        let magic_len = Self::MAGIC.len();
        let magic = reader.data_consume_hard(magic_len)?;
        if magic[..magic_len] != Self::MAGIC[..] {
            return Err(Error::BadMagic);
        }

        let version = reader.data_consume_hard(1)?[0];
        if version != 1 {
            return Err(Error::UnsupportedVersion(version));
        }

        let found = reader.data_consume_hard(CONTEXT_BYTES)?;
        if found[..CONTEXT_BYTES] != context[..] {
            return Err(Error::BadContext);
        }

        // The entry count is a big-endian 32-bit integer.
        let raw = reader.data_consume_hard(4)?;
        let mut count = [0u8; 4];
        count.copy_from_slice(&raw[..4]);

        Ok(Self {
            version,
            context,
            entries: u32::from_be_bytes(count),
        })
    }

    /// Serializes the header to `sink`: magic, version, context, and
    /// the entry count in big-endian byte order.
    fn write(&self, sink: &mut dyn Write) -> Result<()> {
        let version = [self.version];
        let count = self.entries.to_be_bytes();

        for field in [
            &Self::MAGIC[..],
            &version[..],
            &self.context[..],
            &count[..],
        ] {
            sink.write_all(field)?;
        }
        Ok(())
    }
}
/// Errors returned when reading or writing a stored set.
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// The file does not start with the expected magic bytes.
#[error("Bad magic read from file")]
BadMagic,
/// The file declares a format version other than 1; the payload is
/// the version found.
#[error("Unsupported version: {0}")]
UnsupportedVersion(u8),
/// The context in the file differs from the requested one, or the
/// requested context is not exactly `CONTEXT_BYTES` bytes long.
#[error("Bad context read from file")]
BadContext,
/// The number of entries is too large to be represented.
#[error("Too many entries")]
TooManyEntries,
/// The path to write to has no parent directory.
#[error("Bad path")]
BadPath,
/// An underlying I/O operation failed.
#[error("Io error")]
Io(#[from] std::io::Error),
}
/// Result type whose error is this module's `Error`.
pub type Result<T> = ::std::result::Result<T, Error>;