1use std::borrow::BorrowMut;
14use std::cell::RefCell;
15use std::fs;
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
19use byteorder::{BigEndian, ByteOrder};
20use ents::{
21 DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
22 Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
23};
24use heed::types::{Bytes, Str};
25use heed::{Database, Env, EnvOpenOptions, RwTxn};
26use snowflaked::Generator;
27
/// Page size for `find_edges`: at most this many edges are returned per call,
/// and `EdgeQueryResult::has_more` reports whether further matching edges exist.
const MAX_EDGES: usize = 100;
30
/// Handle to a heed (LMDB) environment that stores entities and the edges
/// between them.
///
/// All clones share the same snowflake id generator (it lives behind an
/// `Arc<Mutex<_>>`), so ids handed out through any clone come from one sequence.
#[derive(Clone)]
pub struct HeedEnv {
    /// The underlying heed environment.
    env: Env,
    /// Entity records: big-endian `u64` id -> entity serialized as JSON.
    entities: Database<heed::types::U64<BigEndian>, Str>,
    /// Edge index. Keys are `source (8 bytes BE) || sort_key || dest (8 bytes BE)`
    /// as built by `make_edge_key`; values are always empty.
    edges: Database<Bytes, Bytes>,
    /// Snowflake generator used to allocate new entity ids.
    id_generator: Arc<Mutex<Generator>>,
}
39
40impl HeedEnv {
41 pub fn open<P: AsRef<Path>>(
47 path: P,
48 map_size: Option<usize>,
49 ) -> Result<Self, DatabaseError> {
50 let path = path.as_ref();
51 fs::create_dir_all(path).map_err(|e| DatabaseError::Other {
52 source: Box::new(e),
53 })?;
54
55 let env = unsafe {
56 EnvOpenOptions::new()
57 .map_size(map_size.unwrap_or(1024 * 1024 * 1024)) .max_dbs(2)
59 .open(path)
60 }
61 .map_err(|e| DatabaseError::Other {
62 source: Box::new(e),
63 })?;
64
65 let mut wtxn = env.write_txn().map_err(|e| DatabaseError::Other {
67 source: Box::new(e),
68 })?;
69
70 let entities: Database<heed::types::U64<BigEndian>, Str> = env
71 .create_database(&mut wtxn, Some("entities"))
72 .map_err(|e| DatabaseError::Other {
73 source: Box::new(e),
74 })?;
75
76 let edges: Database<Bytes, Bytes> = env
77 .create_database(&mut wtxn, Some("edges"))
78 .map_err(|e| DatabaseError::Other {
79 source: Box::new(e),
80 })?;
81
82 wtxn.commit().map_err(|e| DatabaseError::Other {
83 source: Box::new(e),
84 })?;
85
86 let id_generator = Generator::new(0);
89
90 Ok(Self {
91 env,
92 entities,
93 edges,
94 id_generator: Arc::new(Mutex::new(id_generator)),
95 })
96 }
97
98 pub fn write_txn(&self) -> Result<Txn<'_>, DatabaseError> {
100 let txn = self.env.write_txn().map_err(|e| DatabaseError::Other {
101 source: Box::new(e),
102 })?;
103 Ok(Txn {
104 txn: RefCell::new(txn),
105 env: self,
106 })
107 }
108
109 fn next_id(&self) -> Result<Id, DatabaseError> {
111 let mut generator =
112 self.id_generator.lock().map_err(|e| DatabaseError::Other {
113 source: Box::new(std::io::Error::other(format!(
114 "Failed to lock ID generator: {}",
115 e
116 ))),
117 })?;
118 Ok(generator.generate())
119 }
120}
121
/// A read-write transaction over a [`HeedEnv`].
///
/// The heed `RwTxn` sits in a `RefCell` so the `&self` trait methods
/// (`ReadEnt`, `Transactional`, `QueryEdge`) can take short mutable borrows
/// for writes. Overlapping a `borrow_mut` with a live `borrow` panics, so
/// methods release read borrows before writing. Writes become durable only
/// through `Transactional::commit`, which consumes the `Txn`.
pub struct Txn<'env> {
    /// The underlying heed write transaction.
    txn: RefCell<RwTxn<'env>>,
    /// The environment whose databases and id generator this transaction uses.
    env: &'env HeedEnv,
}
130
131impl<'env> Txn<'env> {
132 fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
134 let id = self.env.next_id()?;
135 let mut wtxn = self.txn.borrow_mut();
136
137 let data_json =
138 serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
139 DatabaseError::Other {
140 source: Box::new(e),
141 }
142 })?;
143
144 self.env
145 .entities
146 .put(&mut wtxn, &id, &data_json)
147 .map_err(|e| DatabaseError::Other {
148 source: Box::new(e),
149 })?;
150
151 Ok(id)
152 }
153
154 fn update_internal(
156 &self,
157 id: Id,
158 ent: Box<dyn Ent>,
159 expected_last_updated: Option<u64>,
160 ) -> Result<bool, DatabaseError> {
161 if let Some(expected) = expected_last_updated {
163 if let Some(current) = self.get(id)? {
164 if current.last_updated() != expected {
165 return Ok(false);
166 }
167 } else {
168 return Ok(false);
169 }
170 }
171
172 let data_json =
173 serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
174 source: Box::new(e),
175 })?;
176
177 self.env
178 .entities
179 .put(&mut self.txn.borrow_mut(), &id, &data_json)
180 .map_err(|e| DatabaseError::Other {
181 source: Box::new(e),
182 })?;
183
184 Ok(true)
185 }
186
187 fn delete_edge(
188 &self,
189 source: Id,
190 sort_key: &[u8],
191 dest: Id,
192 ) -> Result<(), DatabaseError> {
193 let key = make_edge_key(source, sort_key, dest);
194 self.env
195 .edges
196 .delete(&mut self.txn.borrow_mut(), &key)
197 .map_err(|e| DatabaseError::Other {
198 source: Box::new(e),
199 })?;
200 Ok(())
201 }
202}
203
204impl<'env> ReadEnt for Txn<'env> {
205 fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
206 let txn = self.txn.borrow();
207 match self.env.entities.get(&txn, &id).map_err(|e| {
208 DatabaseError::Other {
209 source: Box::new(e),
210 }
211 })? {
212 Some(data_json) => {
213 let mut ent = serde_json::from_str::<Box<dyn Ent>>(data_json)
214 .map_err(|e| DatabaseError::Other {
215 source: Box::new(e),
216 })?;
217 ent.set_id(id);
218 Ok(Some(ent))
219 }
220 None => Ok(None),
221 }
222 }
223}
224
impl<'env> Transactional for Txn<'env> {
    /// Inserts `ent` under a fresh id, assigns that id to it, then lets it
    /// register its edges via `setup_edges(self)` and returns the id.
    ///
    /// NOTE(review): `setup_edges` presumably calls back into `create_edge` on
    /// this transaction — confirm in `ents`.
    fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
        let id = self.insert(&ent)?;
        ent.set_id(id);
        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
            source: Box::new(e),
        })?;
        Ok(id)
    }

    /// Deletes the entity `id` together with every edge whose `dest` is `id`.
    ///
    /// Edge keys are ordered by source, so finding edges by destination
    /// requires scanning the whole `edges` database (O(total edges)).
    /// NOTE(review): edges where `id` is the *source* are left in place;
    /// confirm that is intended.
    fn delete(&self, id: Id) -> Result<(), DatabaseError> {
        // Collect the keys first: the iterator holds a shared borrow of the
        // RefCell'd txn, which must be released before the `borrow_mut` calls
        // below (overlapping them would panic).
        let to_delete: Vec<Vec<u8>> = {
            let txn = self.txn.borrow();
            let iter = self.env.edges.iter(&txn).map_err(|e| {
                DatabaseError::Other {
                    source: Box::new(e),
                }
            })?;

            let mut keys = Vec::new();
            for result in iter {
                let (key, _) = result.map_err(|e| DatabaseError::Other {
                    source: Box::new(e),
                })?;
                let (_, _, dest) = parse_edge_key(key);
                if dest == id {
                    keys.push(key.to_vec());
                }
            }
            keys
        };

        for key in to_delete {
            self.env
                .edges
                .delete(&mut self.txn.borrow_mut(), &key)
                .map_err(|e| DatabaseError::Other {
                    source: Box::new(e),
                })?;
        }

        self.env
            .entities
            .delete(&mut self.txn.borrow_mut(), &id)
            .map_err(|e| DatabaseError::Other {
                source: Box::new(e),
            })?;

        Ok(())
    }

    /// Records `edge` in the edge index as a key-only entry (empty value).
    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
        let key = make_edge_key(edge.source, &edge.sort_key, edge.dest);
        self.env
            .edges
            .put(&mut self.txn.borrow_mut(), &key, &[])
            .map_err(|e| DatabaseError::Other {
                source: Box::new(e),
            })?;
        Ok(())
    }

    /// Applies `mutator` to the entity and persists it with an optimistic
    /// concurrency check. If the entity's edges changed, they are re-indexed.
    ///
    /// The entity's `last_updated` value is captured *before* mutation. The
    /// write succeeds only if the stored copy still has that value (see
    /// `update_internal`). Otherwise `Ok(false)` is returned and neither the
    /// entity nor its edges are written. `ent0` itself has already been
    /// mutated and `mark_updated` called on it by that point.
    ///
    /// The edge draft is taken before and after mutation. If the two drafts
    /// are equal, only the entity record is rewritten. Otherwise the update
    /// proceeds in three steps:
    /// - Both drafts are resolved with `check(self)`. NOTE(review): this
    ///   presumably validates/looks up the edge sources — confirm in `ents`.
    /// - Each resulting edge gets this entity as its `dest`.
    /// - Only if the entity write succeeds, the old edges are deleted and the
    ///   new ones inserted.
    ///
    /// # Errors
    /// Failures from `mark_updated`, either `check`, or LMDB are returned as
    /// `DatabaseError`s.
    fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
        &self,
        mut ent0: B,
        mutator: F,
    ) -> Result<bool, DatabaseError> {
        let ent = ent0.borrow_mut();
        // Snapshot everything that describes the pre-mutation state.
        let draft0 = T::EdgeProvider::draft(ent);
        // NOTE(review): `ent_id` is captured before `mutator` runs, but the
        // entity is written under `ent.id()` read afterwards; the two only
        // differ if `mutator` changes the id.
        let ent_id = ent.id();
        let expected_last_updated = ent.last_updated();

        mutator(ent);
        ent.mark_updated().map_err(|e| DatabaseError::Other {
            source: Box::new(e),
        })?;

        let draft1 = T::EdgeProvider::draft(ent);

        // Edges unchanged: only the entity record needs rewriting.
        if draft0 == draft1 {
            return self.update_internal(
                ent.id(),
                dyn_clone::clone_box(ent),
                Some(expected_last_updated),
            );
        }

        // Resolve both edge sets before writing anything, so a failing
        // `check` aborts the update without a partial write.
        let edge0 = draft0
            .check(self)
            .map(|edges| {
                edges
                    .into_iter()
                    .map(|edge| edge.with_dest(ent_id))
                    .collect::<Vec<_>>()
            })
            .map_err(|e| DatabaseError::Other {
                source: Box::new(e),
            })?;
        let edge1 = draft1
            .check(self)
            .map(|edges| {
                edges
                    .into_iter()
                    .map(|edge| edge.with_dest(ent_id))
                    .collect::<Vec<_>>()
            })
            .map_err(|e| DatabaseError::Other {
                source: Box::new(e),
            })?;

        let updated = self.update_internal(
            ent.id(),
            dyn_clone::clone_box(ent),
            Some(expected_last_updated),
        )?;

        // Re-index edges only if the optimistic write went through. Deleting
        // before creating means an edge present in both sets ends up stored.
        if updated {
            for edge in edge0 {
                self.delete_edge(edge.source, &edge.sort_key, edge.dest)?;
            }

            for edge in edge1 {
                self.create_edge(edge)?;
            }
        }

        Ok(updated)
    }

    /// Consumes the transaction and commits every write made through it.
    fn commit(self) -> Result<(), DatabaseError> {
        self.txn
            .into_inner()
            .commit()
            .map_err(|e| DatabaseError::Other {
                source: Box::new(e),
            })
    }
}
369
370impl<'env> QueryEdge for Txn<'env> {
371 fn find_edges(
372 &self,
373 source: Id,
374 query: EdgeQuery,
375 ) -> Result<EdgeQueryResult, DatabaseError> {
376 let txn = self.txn.borrow();
377 {
378 let txn: &heed::RoTxn<'_> = &txn;
379 let edges_db: &Database<Bytes, Bytes> = &self.env.edges;
380 let mut results = Vec::new();
381
382 let mut prefix = [0u8; 8];
384 BigEndian::write_u64(&mut prefix, source);
385
386 let iter = edges_db.prefix_iter(txn, &prefix).map_err(|e| {
388 DatabaseError::Other {
389 source: Box::new(e),
390 }
391 })?;
392
393 let mut all_edges: Vec<Edge> = Vec::new();
395
396 for result in iter {
397 let (key, _) = result.map_err(|e| DatabaseError::Other {
398 source: Box::new(e),
399 })?;
400
401 let (src, sort_key, dest) = parse_edge_key(key);
402 if src != source {
403 break; }
405
406 if !query.edge_names.is_empty()
408 && !query.edge_names.contains(&sort_key)
409 {
410 continue;
411 }
412
413 all_edges.push(Edge::new(src, sort_key.to_vec(), dest));
414 }
415
416 match query.order {
418 SortOrder::Asc => {
419 all_edges.sort_by(|a, b| {
420 (&a.sort_key, a.dest).cmp(&(&b.sort_key, b.dest))
421 });
422 }
423 SortOrder::Desc => {
424 all_edges.sort_by(|a, b| {
425 (&b.sort_key, b.dest).cmp(&(&a.sort_key, a.dest))
426 });
427 }
428 }
429
430 for edge in all_edges {
432 if let Some(ref cursor) = query.cursor {
433 let edge_key = (edge.sort_key.as_slice(), edge.dest);
434 let cursor_key = (cursor.sort_key, cursor.destination);
435
436 match query.order {
437 SortOrder::Asc => {
438 if edge_key <= cursor_key {
439 continue;
440 }
441 }
442 SortOrder::Desc => {
443 if edge_key >= cursor_key {
444 continue;
445 }
446 }
447 }
448 }
449
450 results.push(edge);
451
452 if results.len() > MAX_EDGES {
453 break;
454 }
455 }
456
457 let has_more = results.len() > MAX_EDGES;
458 if has_more {
459 results.truncate(MAX_EDGES);
460 }
461
462 Ok(EdgeQueryResult {
463 edges: results,
464 has_more,
465 })
466 }
467 }
468}
469
470fn make_edge_key(source: Id, sort_key: &[u8], dest: Id) -> Vec<u8> {
472 let mut key = Vec::with_capacity(8 + sort_key.len() + 8);
473 let mut buf = [0u8; 8];
474
475 BigEndian::write_u64(&mut buf, source);
476 key.extend_from_slice(&buf);
477
478 key.extend_from_slice(sort_key);
479
480 BigEndian::write_u64(&mut buf, dest);
481 key.extend_from_slice(&buf);
482
483 key
484}
485
486fn parse_edge_key(key: &[u8]) -> (Id, &[u8], Id) {
488 let source = BigEndian::read_u64(&key[0..8]);
489 let dest = BigEndian::read_u64(&key[key.len() - 8..]);
490 let sort_key = &key[8..key.len() - 8];
491 (source, sort_key, dest)
492}
493
494impl ents::TransactionProvider for HeedEnv {
495 type Tx<'a> = Txn<'a>;
496
497 fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
498 where
499 F: for<'b> FnOnce(Self::Tx<'b>) -> R,
500 {
501 Ok(func(self.write_txn()?))
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// A key with a non-empty sort key survives an encode/decode round trip.
    #[test]
    fn test_edge_key_roundtrip() {
        let source = 12345u64;
        let sort_key = b"test_edge";
        let dest = 67890u64;

        let key = make_edge_key(source, sort_key, dest);
        let (parsed_source, parsed_sort_key, parsed_dest) = parse_edge_key(&key);

        assert_eq!(parsed_source, source);
        assert_eq!(parsed_sort_key, sort_key);
        assert_eq!(parsed_dest, dest);
    }

    /// The minimal 16-byte key (empty sort key) round-trips too.
    #[test]
    fn test_edge_key_roundtrip_empty_sort_key() {
        let key = make_edge_key(7, b"", 9);
        assert_eq!(key.len(), 16);

        let (source, sort_key, dest) = parse_edge_key(&key);
        assert_eq!(source, 7);
        assert!(sort_key.is_empty());
        assert_eq!(dest, 9);
    }

    /// Pins the on-disk layout: BE source, raw sort key, BE dest.
    #[test]
    fn test_edge_key_layout_is_big_endian() {
        let key = make_edge_key(0x0102_0304_0506_0708, b"xy", 0x1112_1314_1516_1718);
        let expected: [u8; 18] = [
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, b'x', b'y', 0x11, 0x12, 0x13,
            0x14, 0x15, 0x16, 0x17, 0x18,
        ];
        assert_eq!(key, expected);
    }

    /// Keys group by source first, then compare sort key and dest bytes.
    #[test]
    fn test_edge_key_ordering() {
        let key1 = make_edge_key(1, b"a", 10);
        let key2 = make_edge_key(1, b"a", 20);
        let key3 = make_edge_key(1, b"b", 10);
        let key4 = make_edge_key(2, b"a", 10);

        assert!(key1 < key2);
        assert!(key2 < key3);
        assert!(key3 < key4);
    }

    /// When one sort key is a prefix of another, raw key byte order compares
    /// the shorter key's dest bytes against the longer key's extra sort-key
    /// bytes. The result disagrees with `(sort_key, dest)` order, which is
    /// why `find_edges` sorts in memory.
    #[test]
    fn test_byte_order_differs_from_tuple_order_for_prefix_sort_keys() {
        let short = make_edge_key(1, b"a", u64::MAX);
        let long = make_edge_key(1, b"ab", 0);

        assert!(long < short);
        assert!((&b"a"[..], u64::MAX) < (&b"ab"[..], 0u64));
    }
}