Skip to main content

ents_heed/
lib.rs

//! LMDB-based entity storage implementation using the heed crate.
//!
//! This module provides an LMDB (via heed) implementation of the entity storage traits,
//! mirroring the functionality of ents-sqlite but using LMDB as the underlying store.
//!
//! # Storage Layout
//!
//! The implementation uses two LMDB databases:
//! - `entities`: Maps entity IDs (big-endian `u64` keys) to serialized entity JSON
//! - `edges`: Maps composite keys (source, sort_key, dest) to empty values
//!
//! Entity IDs are not persisted in a metadata table; they are produced by an
//! in-process snowflake ID generator owned by `HeedEnv`.
12
13use std::borrow::BorrowMut;
14use std::cell::RefCell;
15use std::fs;
16use std::path::Path;
17use std::sync::{Arc, Mutex};
18
19use byteorder::{BigEndian, ByteOrder};
20use ents::{
21    DatabaseError, Edge, EdgeDraft, EdgeQuery, EdgeQueryResult, EdgeValue, Ent,
22    Id, IncomingEdgeProvider, QueryEdge, ReadEnt, SortOrder, Transactional,
23};
24use heed::types::{Bytes, Str};
25use heed::{Database, Env, EnvOpenOptions, RwTxn};
26use snowflaked::Generator;
27
/// Page size for `find_edges`: a single call returns at most this many edges,
/// and reports through `has_more` whether further matching edges exist.
const MAX_EDGES: usize = 100;
30
31/// LMDB environment wrapper that manages the databases.
32#[derive(Clone)]
33pub struct HeedEnv {
34    env: Env,
35    entities: Database<heed::types::U64<BigEndian>, Str>,
36    edges: Database<Bytes, Bytes>,
37    id_generator: Arc<Mutex<Generator>>,
38}
39
40impl HeedEnv {
41    /// Opens or creates an LMDB environment at the given path.
42    ///
43    /// # Arguments
44    /// * `path` - Directory path for the LMDB environment
45    /// * `map_size` - Maximum size of the database in bytes (default: 1GB)
46    pub fn open<P: AsRef<Path>>(
47        path: P,
48        map_size: Option<usize>,
49    ) -> Result<Self, DatabaseError> {
50        let path = path.as_ref();
51        fs::create_dir_all(path).map_err(|e| DatabaseError::Other {
52            source: Box::new(e),
53        })?;
54
55        let env = unsafe {
56            EnvOpenOptions::new()
57                .map_size(map_size.unwrap_or(1024 * 1024 * 1024)) // 1GB default
58                .max_dbs(2)
59                .open(path)
60        }
61        .map_err(|e| DatabaseError::Other {
62            source: Box::new(e),
63        })?;
64
65        // Create or open the databases
66        let mut wtxn = env.write_txn().map_err(|e| DatabaseError::Other {
67            source: Box::new(e),
68        })?;
69
70        let entities: Database<heed::types::U64<BigEndian>, Str> = env
71            .create_database(&mut wtxn, Some("entities"))
72            .map_err(|e| DatabaseError::Other {
73                source: Box::new(e),
74            })?;
75
76        let edges: Database<Bytes, Bytes> = env
77            .create_database(&mut wtxn, Some("edges"))
78            .map_err(|e| DatabaseError::Other {
79                source: Box::new(e),
80            })?;
81
82        wtxn.commit().map_err(|e| DatabaseError::Other {
83            source: Box::new(e),
84        })?;
85
86        // Initialize snowflake ID generator
87        // Using node_id 0, can be configured if needed for distributed systems
88        let id_generator = Generator::new(0);
89
90        Ok(Self {
91            env,
92            entities,
93            edges,
94            id_generator: Arc::new(Mutex::new(id_generator)),
95        })
96    }
97
98    /// Begins a read-write transaction.
99    pub fn write_txn(&self) -> Result<Txn<'_>, DatabaseError> {
100        let txn = self.env.write_txn().map_err(|e| DatabaseError::Other {
101            source: Box::new(e),
102        })?;
103        Ok(Txn {
104            txn: RefCell::new(txn),
105            env: self,
106        })
107    }
108
109    /// Allocates the next entity ID using snowflake algorithm.
110    fn next_id(&self) -> Result<Id, DatabaseError> {
111        let mut generator =
112            self.id_generator.lock().map_err(|e| DatabaseError::Other {
113                source: Box::new(std::io::Error::other(format!(
114                    "Failed to lock ID generator: {}",
115                    e
116                ))),
117            })?;
118        Ok(generator.generate())
119    }
120}
121
122/// A read-write transaction wrapper.
123///
124/// Uses interior mutability via RefCell to satisfy the Transactional trait's
125/// requirement for &self methods while still allowing mutation.
126pub struct Txn<'env> {
127    txn: RefCell<RwTxn<'env>>,
128    env: &'env HeedEnv,
129}
130
131impl<'env> Txn<'env> {
132    /// Inserts an entity and returns its assigned ID.
133    fn insert<E: Ent>(&self, ent: &E) -> Result<Id, DatabaseError> {
134        let id = self.env.next_id()?;
135        let mut wtxn = self.txn.borrow_mut();
136
137        let data_json =
138            serde_json::to_string(&(ent as &dyn Ent)).map_err(|e| {
139                DatabaseError::Other {
140                    source: Box::new(e),
141                }
142            })?;
143
144        self.env
145            .entities
146            .put(&mut wtxn, &id, &data_json)
147            .map_err(|e| DatabaseError::Other {
148                source: Box::new(e),
149            })?;
150
151        Ok(id)
152    }
153
154    /// Internal update that writes entity with optional CAS check.
155    fn update_internal(
156        &self,
157        id: Id,
158        ent: Box<dyn Ent>,
159        expected_last_updated: Option<u64>,
160    ) -> Result<bool, DatabaseError> {
161        // If CAS check is needed, verify current last_updated
162        if let Some(expected) = expected_last_updated {
163            if let Some(current) = self.get(id)? {
164                if current.last_updated() != expected {
165                    return Ok(false);
166                }
167            } else {
168                return Ok(false);
169            }
170        }
171
172        let data_json =
173            serde_json::to_string(&ent).map_err(|e| DatabaseError::Other {
174                source: Box::new(e),
175            })?;
176
177        self.env
178            .entities
179            .put(&mut self.txn.borrow_mut(), &id, &data_json)
180            .map_err(|e| DatabaseError::Other {
181                source: Box::new(e),
182            })?;
183
184        Ok(true)
185    }
186
187    fn delete_edge(
188        &self,
189        source: Id,
190        sort_key: &[u8],
191        dest: Id,
192    ) -> Result<(), DatabaseError> {
193        let key = make_edge_key(source, sort_key, dest);
194        self.env
195            .edges
196            .delete(&mut self.txn.borrow_mut(), &key)
197            .map_err(|e| DatabaseError::Other {
198                source: Box::new(e),
199            })?;
200        Ok(())
201    }
202}
203
204impl<'env> ReadEnt for Txn<'env> {
205    fn get(&self, id: Id) -> Result<Option<Box<dyn Ent>>, DatabaseError> {
206        let txn = self.txn.borrow();
207        match self.env.entities.get(&txn, &id).map_err(|e| {
208            DatabaseError::Other {
209                source: Box::new(e),
210            }
211        })? {
212            Some(data_json) => {
213                let mut ent = serde_json::from_str::<Box<dyn Ent>>(data_json)
214                    .map_err(|e| DatabaseError::Other {
215                    source: Box::new(e),
216                })?;
217                ent.set_id(id);
218                Ok(Some(ent))
219            }
220            None => Ok(None),
221        }
222    }
223}
224
225impl<'env> Transactional for Txn<'env> {
226    fn create<E: Ent>(&self, mut ent: E) -> Result<Id, DatabaseError> {
227        let id = self.insert(&ent)?;
228        ent.set_id(id);
229        ent.setup_edges(self).map_err(|e| DatabaseError::Other {
230            source: Box::new(e),
231        })?;
232        Ok(id)
233    }
234
235    fn delete(&self, id: Id) -> Result<(), DatabaseError> {
236        // Delete edges where this entity is the destination
237        // We need to scan all edges and delete matching ones
238        let to_delete: Vec<Vec<u8>> = {
239            let txn = self.txn.borrow();
240            let iter = self.env.edges.iter(&txn).map_err(|e| {
241                DatabaseError::Other {
242                    source: Box::new(e),
243                }
244            })?;
245
246            let mut keys = Vec::new();
247            for result in iter {
248                let (key, _) = result.map_err(|e| DatabaseError::Other {
249                    source: Box::new(e),
250                })?;
251                let (_, _, dest) = parse_edge_key(key);
252                if dest == id {
253                    keys.push(key.to_vec());
254                }
255            }
256            keys
257        };
258
259        for key in to_delete {
260            self.env
261                .edges
262                .delete(&mut self.txn.borrow_mut(), &key)
263                .map_err(|e| DatabaseError::Other {
264                    source: Box::new(e),
265                })?;
266        }
267
268        // Delete the entity
269        self.env
270            .entities
271            .delete(&mut self.txn.borrow_mut(), &id)
272            .map_err(|e| DatabaseError::Other {
273                source: Box::new(e),
274            })?;
275
276        Ok(())
277    }
278
279    fn create_edge(&self, edge: EdgeValue) -> Result<(), DatabaseError> {
280        let key = make_edge_key(edge.source, &edge.sort_key, edge.dest);
281        self.env
282            .edges
283            .put(&mut self.txn.borrow_mut(), &key, &[])
284            .map_err(|e| DatabaseError::Other {
285                source: Box::new(e),
286            })?;
287        Ok(())
288    }
289
290    fn update<T: Ent, F: FnOnce(&mut T), B: BorrowMut<T>>(
291        &self,
292        mut ent0: B,
293        mutator: F,
294    ) -> Result<bool, DatabaseError> {
295        let ent = ent0.borrow_mut();
296        let draft0 = T::EdgeProvider::draft(ent);
297        let ent_id = ent.id();
298        let expected_last_updated = ent.last_updated();
299
300        mutator(ent);
301        ent.mark_updated().map_err(|e| DatabaseError::Other {
302            source: Box::new(e),
303        })?;
304
305        let draft1 = T::EdgeProvider::draft(ent);
306
307        // Optimization: if drafts are equal, no edge changes needed
308        if draft0 == draft1 {
309            return self.update_internal(
310                ent.id(),
311                dyn_clone::clone_box(ent),
312                Some(expected_last_updated),
313            );
314        }
315
316        let edge0 = draft0
317            .check(self)
318            .map(|edges| {
319                edges
320                    .into_iter()
321                    .map(|edge| edge.with_dest(ent_id))
322                    .collect::<Vec<_>>()
323            })
324            .map_err(|e| DatabaseError::Other {
325                source: Box::new(e),
326            })?;
327        let edge1 = draft1
328            .check(self)
329            .map(|edges| {
330                edges
331                    .into_iter()
332                    .map(|edge| edge.with_dest(ent_id))
333                    .collect::<Vec<_>>()
334            })
335            .map_err(|e| DatabaseError::Other {
336                source: Box::new(e),
337            })?;
338
339        let updated = self.update_internal(
340            ent.id(),
341            dyn_clone::clone_box(ent),
342            Some(expected_last_updated),
343        )?;
344
345        if updated {
346            // Remove old edges if they existed
347            for edge in edge0 {
348                self.delete_edge(edge.source, &edge.sort_key, edge.dest)?;
349            }
350
351            // Create new edges if they exist
352            for edge in edge1 {
353                self.create_edge(edge)?;
354            }
355        }
356
357        Ok(updated)
358    }
359
360    fn commit(self) -> Result<(), DatabaseError> {
361        self.txn
362            .into_inner()
363            .commit()
364            .map_err(|e| DatabaseError::Other {
365                source: Box::new(e),
366            })
367    }
368}
369
370impl<'env> QueryEdge for Txn<'env> {
371    fn find_edges(
372        &self,
373        source: Id,
374        query: EdgeQuery,
375    ) -> Result<EdgeQueryResult, DatabaseError> {
376        let txn = self.txn.borrow();
377        {
378            let txn: &heed::RoTxn<'_> = &txn;
379            let edges_db: &Database<Bytes, Bytes> = &self.env.edges;
380            let mut results = Vec::new();
381
382            // Create the prefix for this source
383            let mut prefix = [0u8; 8];
384            BigEndian::write_u64(&mut prefix, source);
385
386            // Get iterator
387            let iter = edges_db.prefix_iter(txn, &prefix).map_err(|e| {
388                DatabaseError::Other {
389                    source: Box::new(e),
390                }
391            })?;
392
393            // Collect all matching edges
394            let mut all_edges: Vec<Edge> = Vec::new();
395
396            for result in iter {
397                let (key, _) = result.map_err(|e| DatabaseError::Other {
398                    source: Box::new(e),
399                })?;
400
401                let (src, sort_key, dest) = parse_edge_key(key);
402                if src != source {
403                    break; // Past our prefix
404                }
405
406                // Apply edge name filter if specified
407                if !query.edge_names.is_empty()
408                    && !query.edge_names.contains(&sort_key)
409                {
410                    continue;
411                }
412
413                all_edges.push(Edge::new(src, sort_key.to_vec(), dest));
414            }
415
416            // Sort based on order
417            match query.order {
418                SortOrder::Asc => {
419                    all_edges.sort_by(|a, b| {
420                        (&a.sort_key, a.dest).cmp(&(&b.sort_key, b.dest))
421                    });
422                }
423                SortOrder::Desc => {
424                    all_edges.sort_by(|a, b| {
425                        (&b.sort_key, b.dest).cmp(&(&a.sort_key, a.dest))
426                    });
427                }
428            }
429
430            // Apply cursor filter, collecting one extra to detect has_more
431            for edge in all_edges {
432                if let Some(ref cursor) = query.cursor {
433                    let edge_key = (edge.sort_key.as_slice(), edge.dest);
434                    let cursor_key = (cursor.sort_key, cursor.destination);
435
436                    match query.order {
437                        SortOrder::Asc => {
438                            if edge_key <= cursor_key {
439                                continue;
440                            }
441                        }
442                        SortOrder::Desc => {
443                            if edge_key >= cursor_key {
444                                continue;
445                            }
446                        }
447                    }
448                }
449
450                results.push(edge);
451
452                if results.len() > MAX_EDGES {
453                    break;
454                }
455            }
456
457            let has_more = results.len() > MAX_EDGES;
458            if has_more {
459                results.truncate(MAX_EDGES);
460            }
461
462            Ok(EdgeQueryResult {
463                edges: results,
464                has_more,
465            })
466        }
467    }
468}
469
470/// Creates a composite key for an edge: source (8 bytes) + sort_key + dest (8 bytes)
471fn make_edge_key(source: Id, sort_key: &[u8], dest: Id) -> Vec<u8> {
472    let mut key = Vec::with_capacity(8 + sort_key.len() + 8);
473    let mut buf = [0u8; 8];
474
475    BigEndian::write_u64(&mut buf, source);
476    key.extend_from_slice(&buf);
477
478    key.extend_from_slice(sort_key);
479
480    BigEndian::write_u64(&mut buf, dest);
481    key.extend_from_slice(&buf);
482
483    key
484}
485
486/// Parses a composite edge key into (source, sort_key, dest)
487fn parse_edge_key(key: &[u8]) -> (Id, &[u8], Id) {
488    let source = BigEndian::read_u64(&key[0..8]);
489    let dest = BigEndian::read_u64(&key[key.len() - 8..]);
490    let sort_key = &key[8..key.len() - 8];
491    (source, sort_key, dest)
492}
493
494impl ents::TransactionProvider for HeedEnv {
495    type Tx<'a> = Txn<'a>;
496
497    fn execute<R, F>(&self, func: F) -> Result<R, DatabaseError>
498    where
499        F: for<'b> FnOnce(Self::Tx<'b>) -> R,
500    {
501        Ok(func(self.write_txn()?))
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// `parse_edge_key` must exactly invert `make_edge_key`.
    #[test]
    fn test_edge_key_roundtrip() {
        let source: u64 = 12345;
        let dest: u64 = 67890;
        let sort_key: &[u8] = b"test_edge";

        let encoded = make_edge_key(source, sort_key, dest);

        assert_eq!(parse_edge_key(&encoded), (source, sort_key, dest));
    }

    /// Encoded keys order by source first, then sort key, then destination.
    /// This holds for the equal-length sort keys used here.
    #[test]
    fn test_edge_key_ordering() {
        let ascending = [
            make_edge_key(1, b"a", 10),
            make_edge_key(1, b"a", 20), // same source and sort key, larger dest
            make_edge_key(1, b"b", 10), // same source, larger sort key
            make_edge_key(2, b"a", 10), // larger source
        ];
        for pair in ascending.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }
}