sequoia_git/
persistent_set.rs1use std::{
23 collections::BTreeSet,
24 io::{
25 Seek,
26 SeekFrom,
27 Write,
28 },
29 path::Path,
30};
31use buffered_reader::{BufferedReader};
32
/// Size in bytes of a single [`Value`].
const VALUE_BYTES: usize = 32;

/// An element of a [`Set`]: an opaque, fixed-size byte string.
///
/// Values compare lexicographically by their bytes (the standard ordering
/// of byte arrays), which is the order they are kept in on disk.
pub type Value = [u8; VALUE_BYTES];
35
36#[cfg(unix)]
38type File = buffered_reader::File<'static, ()>;
39#[cfg(not(unix))]
40type File = buffered_reader::File<()>;
41
42pub struct Set {
43 header: Header,
44 store: File,
45 scratch: BTreeSet<Value>,
46}
47
48impl Set {
54 #[allow(dead_code)]
56 fn len(&self) -> usize {
57 usize::try_from(self.header.entries).expect("representable")
58 + self.scratch.len()
59 }
61
62 pub fn contains(&mut self, value: &Value) -> Result<bool> {
64 Ok(self.stored_values()?.binary_search(value).is_ok()
65 || self.scratch.contains(value))
66 }
67
68 pub fn insert(&mut self, value: Value) {
70 self.scratch.insert(value);
74 }
75
76 fn stored_values(&mut self) -> Result<&[Value]> {
77 let entries = self.header.entries as usize;
78 let bytes = self.store.data_hard(entries * VALUE_BYTES)?;
79 unsafe {
80 Ok(std::slice::from_raw_parts(bytes.as_ptr() as *const Value,
81 entries))
82 }
83 }
84
85 pub fn read<P: AsRef<Path>>(path: P, context: &str) -> Result<Self> {
86 assert_eq!(VALUE_BYTES, std::mem::size_of::<Value>());
90 assert_eq!(std::mem::size_of::<[Value; 2]>(),
91 2 * VALUE_BYTES,
92 "values are unpadded");
93
94 let context: [u8; CONTEXT_BYTES] = context.as_bytes()
95 .try_into()
96 .map_err(|_| Error::BadContext)?;
97
98 let (header, reader) = match File::open(path) {
99 Ok(mut f) => {
100 let header = Header::read(&mut f, context)?;
101 (header, f)
102 },
103 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
104 let t = tempfile::NamedTempFile::new()?;
105 let f = File::open(t.path())?;
108 (Header::new(context), f)
109 },
110 Err(e) => return Err(e.into()),
111 };
112
113 Ok(Set {
117 header,
118 store: reader,
119 scratch: Default::default(),
120 })
121 }
122
123 pub fn write<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
124 if self.scratch.is_empty() {
126 return Ok(());
127 }
128
129 let mut sink = tempfile::NamedTempFile::new_in(
130 path.as_ref().parent().ok_or(Error::BadPath)?)?;
131
132 let mut h = self.header.clone();
134 h.entries = 0; h.write(&mut sink)?;
136
137 let mut entries = 0;
139 let scratch = std::mem::replace(&mut self.scratch, Default::default());
140 let mut stored = self.stored_values()?;
141
142 for new in scratch.iter() {
143 let p = stored.partition_point(|v| v < new);
144
145 let before = &stored[..p];
146 let before_bytes = unsafe {
147 std::slice::from_raw_parts(before.as_ptr() as *const u8,
148 before.len() * VALUE_BYTES)
149 };
150 sink.write_all(before_bytes)?;
151 entries += p;
152
153 if before.is_empty() || &before[p - 1] != new {
155 sink.write_all(new)?;
156 entries += 1;
157 }
158
159 stored = &stored[p..];
161 }
162
163 {
165 let stored_bytes = unsafe {
166 std::slice::from_raw_parts(stored.as_ptr() as *const u8,
167 stored.len() * VALUE_BYTES)
168 };
169 sink.write_all(stored_bytes)?;
170 entries += stored.len();
171 }
172
173 self.scratch = scratch;
175
176 sink.as_file_mut().seek(SeekFrom::Start(0))?;
179 h.entries = entries.try_into().map_err(|_| Error::TooManyEntries)?;
180 h.write(&mut sink)?;
181 sink.flush()?;
182
183 sink.persist(path).map_err(|pe| pe.error)?;
184 Ok(())
185 }
186}
187
/// Length in bytes of the context tag stored in every set file header.
///
/// [`Set::read`] requires the caller's context string to be exactly this
/// many bytes long.
const CONTEXT_BYTES: usize = 12;
189
190#[derive(Debug, Clone)]
191struct Header {
192 version: u8,
193 context: [u8; CONTEXT_BYTES],
194 entries: u32,
195}
196
197impl Header {
198 const MAGIC: &'static [u8; 15] = b"StoredSortedSet";
199
200 fn new(context: [u8; CONTEXT_BYTES]) -> Self {
201 Header {
202 version: 1,
203 context,
204 entries: 0,
205 }
206 }
207
208 fn read(reader: &mut File, context: [u8; CONTEXT_BYTES]) -> Result<Self> {
209 let m = reader.data_consume_hard(Self::MAGIC.len())?;
210 if &m[..Self::MAGIC.len()] != &Self::MAGIC[..] {
211 return Err(Error::BadMagic);
212 }
213 let v = reader.data_consume_hard(1)?;
214 let version = v[0];
215 if version != 1 {
216 return Err(Error::UnsupportedVersion(version));
217 }
218
219 let c = &reader.data_consume_hard(context.len())?[..context.len()];
220 if &c[..] != &context[..] {
221 return Err(Error::BadContext);
222 }
223
224 let e = &reader.data_consume_hard(4)?[..4];
225 let entries =
226 u32::from_be_bytes(e.try_into().expect("we read 4 bytes"));
227
228 Ok(Header {
229 version,
230 context,
231 entries,
232 })
233 }
234
235 fn write(&self, sink: &mut dyn Write) -> Result<()> {
236 sink.write_all(Self::MAGIC)?;
237 sink.write_all(&[self.version])?;
238 sink.write_all(&self.context)?;
239 sink.write_all(&self.entries.to_be_bytes())?;
240 Ok(())
241 }
242}
243
244#[derive(thiserror::Error, Debug)]
246pub enum Error {
247 #[error("Bad magic read from file")]
248 BadMagic,
249 #[error("Unsupported version: {0}")]
250 UnsupportedVersion(u8),
251 #[error("Bad context read from file")]
252 BadContext,
253 #[error("Too many entries")]
254 TooManyEntries,
255 #[error("Bad path")]
256 BadPath,
257 #[error("Io error")]
258 Io(#[from] std::io::Error),
259}
260
261pub type Result<T> = ::std::result::Result<T, Error>;