Skip to main content

agentic_memory/format/
mmap.rs

1//! Memory-mapped file access for .amem files.
2
3use std::path::Path;
4
5use memmap2::Mmap;
6
7use crate::graph::MemoryGraph;
8use crate::index::cosine_similarity;
9use crate::types::error::{AmemError, AmemResult};
10use crate::types::header::FileHeader;
11use crate::types::{CognitiveEvent, Edge, EdgeType, EventType};
12
13use super::compression::decompress_content;
14
/// One hit produced by a similarity search over stored feature vectors.
///
/// Implements `PartialEq` so results can be compared directly (for example in
/// assertions). `Eq` is intentionally not derived because `similarity` is an
/// `f32`.
#[derive(Debug, Clone, PartialEq)]
pub struct SimilarityMatch {
    /// ID of the node whose feature vector matched.
    pub node_id: u64,
    /// Cosine similarity between the query and that node's feature vector.
    pub similarity: f32,
}
23
24/// Read-only memory-mapped access to an .amem file.
25pub struct MmapReader {
26    mmap: Mmap,
27    header: FileHeader,
28}
29
30impl MmapReader {
31    /// Open an .amem file for memory-mapped read access.
32    pub fn open(path: &Path) -> AmemResult<Self> {
33        let file = std::fs::File::open(path)?;
34        let mmap = unsafe { Mmap::map(&file)? };
35
36        if mmap.len() < 64 {
37            return Err(AmemError::Truncated);
38        }
39
40        let header = FileHeader::read_from(&mut std::io::Cursor::new(&mmap[..64]))?;
41
42        Ok(Self { mmap, header })
43    }
44
45    /// Get the file header.
46    pub fn header(&self) -> &FileHeader {
47        &self.header
48    }
49
50    /// Read a single node record by ID (O(1) access).
51    pub fn read_node(&self, id: u64) -> AmemResult<CognitiveEvent> {
52        if id >= self.header.node_count {
53            return Err(AmemError::NodeNotFound(id));
54        }
55
56        let offset = self.header.node_table_offset as usize + (id as usize) * 72;
57        if offset + 72 > self.mmap.len() {
58            return Err(AmemError::Truncated);
59        }
60
61        let record = &self.mmap[offset..offset + 72];
62        let mut event = parse_node_record_mmap(record)?;
63
64        // Read content
65        event.content = self.read_content_internal(record)?;
66
67        // Read feature vec
68        event.feature_vec = self.read_feature_vec_internal(id)?;
69
70        Ok(event)
71    }
72
73    /// Read a node's content (decompress from content block).
74    pub fn read_content(&self, id: u64) -> AmemResult<String> {
75        if id >= self.header.node_count {
76            return Err(AmemError::NodeNotFound(id));
77        }
78
79        let offset = self.header.node_table_offset as usize + (id as usize) * 72;
80        if offset + 72 > self.mmap.len() {
81            return Err(AmemError::Truncated);
82        }
83
84        let record = &self.mmap[offset..offset + 72];
85        self.read_content_internal(record)
86    }
87
88    fn read_content_internal(&self, node_record: &[u8]) -> AmemResult<String> {
89        let content_offset = u64::from_le_bytes(node_record[44..52].try_into().unwrap());
90        let content_length = u32::from_le_bytes(node_record[52..56].try_into().unwrap());
91
92        if content_length == 0 {
93            return Ok(String::new());
94        }
95
96        let start = self.header.content_block_offset as usize + content_offset as usize;
97        let end = start + content_length as usize;
98        if end > self.mmap.len() {
99            return Err(AmemError::Truncated);
100        }
101
102        decompress_content(&self.mmap[start..end])
103    }
104
105    /// Read a node's feature vector.
106    pub fn read_feature_vec(&self, id: u64) -> AmemResult<Vec<f32>> {
107        self.read_feature_vec_internal(id)
108    }
109
110    fn read_feature_vec_internal(&self, id: u64) -> AmemResult<Vec<f32>> {
111        if id >= self.header.node_count {
112            return Err(AmemError::NodeNotFound(id));
113        }
114
115        let dim = self.header.dimension as usize;
116        let offset = self.header.feature_vec_offset as usize + (id as usize) * dim * 4;
117        if offset + dim * 4 > self.mmap.len() {
118            return Err(AmemError::Truncated);
119        }
120
121        let mut vec = Vec::with_capacity(dim);
122        for i in 0..dim {
123            let byte_offset = offset + i * 4;
124            let bytes: [u8; 4] = self.mmap[byte_offset..byte_offset + 4].try_into().unwrap();
125            vec.push(f32::from_le_bytes(bytes));
126        }
127        Ok(vec)
128    }
129
130    /// Read all edges from a node.
131    pub fn read_edges(&self, id: u64) -> AmemResult<Vec<Edge>> {
132        if id >= self.header.node_count {
133            return Err(AmemError::NodeNotFound(id));
134        }
135
136        let node_offset = self.header.node_table_offset as usize + (id as usize) * 72;
137        if node_offset + 72 > self.mmap.len() {
138            return Err(AmemError::Truncated);
139        }
140
141        let record = &self.mmap[node_offset..node_offset + 72];
142        let edge_offset = u64::from_le_bytes(record[56..64].try_into().unwrap());
143        let edge_count = u16::from_le_bytes(record[64..66].try_into().unwrap());
144
145        let mut edges = Vec::with_capacity(edge_count as usize);
146        let edge_base = self.header.edge_table_offset as usize + edge_offset as usize;
147
148        for i in 0..edge_count as usize {
149            let offset = edge_base + i * 32;
150            if offset + 32 > self.mmap.len() {
151                return Err(AmemError::Truncated);
152            }
153            let data = &self.mmap[offset..offset + 32];
154            let source_id = u64::from_le_bytes(data[0..8].try_into().unwrap());
155            let target_id = u64::from_le_bytes(data[8..16].try_into().unwrap());
156            let edge_type = EdgeType::from_u8(data[16]).ok_or(AmemError::Corrupt(offset as u64))?;
157            let weight = f32::from_le_bytes(data[20..24].try_into().unwrap());
158            let created_at = u64::from_le_bytes(data[24..32].try_into().unwrap());
159            edges.push(Edge {
160                source_id,
161                target_id,
162                edge_type,
163                weight,
164                created_at,
165            });
166        }
167
168        Ok(edges)
169    }
170
171    /// Read the full graph into memory.
172    pub fn read_full_graph(&self) -> AmemResult<MemoryGraph> {
173        let dimension = self.header.dimension as usize;
174        let node_count = self.header.node_count as usize;
175        let edge_count = self.header.edge_count as usize;
176
177        let mut nodes = Vec::with_capacity(node_count);
178        for id in 0..node_count as u64 {
179            nodes.push(self.read_node(id)?);
180        }
181
182        let mut edges = Vec::with_capacity(edge_count);
183        let edge_base = self.header.edge_table_offset as usize;
184        for i in 0..edge_count {
185            let offset = edge_base + i * 32;
186            if offset + 32 > self.mmap.len() {
187                return Err(AmemError::Truncated);
188            }
189            let data = &self.mmap[offset..offset + 32];
190            let source_id = u64::from_le_bytes(data[0..8].try_into().unwrap());
191            let target_id = u64::from_le_bytes(data[8..16].try_into().unwrap());
192            let edge_type = EdgeType::from_u8(data[16]).ok_or(AmemError::Corrupt(offset as u64))?;
193            let weight = f32::from_le_bytes(data[20..24].try_into().unwrap());
194            let created_at = u64::from_le_bytes(data[24..32].try_into().unwrap());
195            edges.push(Edge {
196                source_id,
197                target_id,
198                edge_type,
199                weight,
200                created_at,
201            });
202        }
203
204        MemoryGraph::from_parts(nodes, edges, dimension)
205    }
206
207    /// Compute cosine similarity between a query and a node's feature vector.
208    pub fn similarity_to(&self, id: u64, query: &[f32]) -> AmemResult<f32> {
209        let vec = self.read_feature_vec_internal(id)?;
210        Ok(cosine_similarity(query, &vec))
211    }
212
213    /// Batch similarity: scan all feature vectors and return top-k matches.
214    pub fn batch_similarity(
215        &self,
216        query: &[f32],
217        top_k: usize,
218        min_similarity: f32,
219    ) -> AmemResult<Vec<SimilarityMatch>> {
220        let dim = self.header.dimension as usize;
221        let node_count = self.header.node_count as usize;
222
223        let mut matches: Vec<SimilarityMatch> = Vec::new();
224
225        for id in 0..node_count {
226            let offset = self.header.feature_vec_offset as usize + id * dim * 4;
227            if offset + dim * 4 > self.mmap.len() {
228                break;
229            }
230
231            // Read feature vector directly from mmap
232            let mut vec = Vec::with_capacity(dim);
233            let mut is_zero = true;
234            for j in 0..dim {
235                let byte_offset = offset + j * 4;
236                let bytes: [u8; 4] = self.mmap[byte_offset..byte_offset + 4].try_into().unwrap();
237                let val = f32::from_le_bytes(bytes);
238                if val != 0.0 {
239                    is_zero = false;
240                }
241                vec.push(val);
242            }
243
244            if is_zero {
245                continue;
246            }
247
248            let sim = cosine_similarity(query, &vec);
249            if sim >= min_similarity {
250                matches.push(SimilarityMatch {
251                    node_id: id as u64,
252                    similarity: sim,
253                });
254            }
255        }
256
257        matches.sort_by(|a, b| {
258            b.similarity
259                .partial_cmp(&a.similarity)
260                .unwrap_or(std::cmp::Ordering::Equal)
261        });
262        matches.truncate(top_k);
263        Ok(matches)
264    }
265}
266
267/// Parse a node record from mmap bytes (without content/feature vec).
268fn parse_node_record_mmap(data: &[u8]) -> AmemResult<CognitiveEvent> {
269    let id = u64::from_le_bytes(data[0..8].try_into().unwrap());
270    let event_type = EventType::from_u8(data[8]).ok_or(AmemError::Corrupt(0))?;
271    let created_at = u64::from_le_bytes(data[12..20].try_into().unwrap());
272    let session_id = u32::from_le_bytes(data[20..24].try_into().unwrap());
273    let confidence = f32::from_le_bytes(data[24..28].try_into().unwrap());
274    let access_count = u32::from_le_bytes(data[28..32].try_into().unwrap());
275    let last_accessed = u64::from_le_bytes(data[32..40].try_into().unwrap());
276    let decay_score = f32::from_le_bytes(data[40..44].try_into().unwrap());
277
278    Ok(CognitiveEvent {
279        id,
280        event_type,
281        created_at,
282        session_id,
283        confidence,
284        access_count,
285        last_accessed,
286        decay_score,
287        content: String::new(),
288        feature_vec: Vec::new(),
289    })
290}