Skip to main content

agentic_memory/format/
writer.rs

1//! Writes .amem files from in-memory graph.
2
3use std::io::Write;
4use std::path::Path;
5
6use crate::engine::tokenizer::Tokenizer;
7use crate::graph::MemoryGraph;
8use crate::index::{DocLengths, TermIndex};
9use crate::types::error::AmemResult;
10use crate::types::header::{feature_flags, FileHeader, HEADER_SIZE};
11use crate::types::{Edge, EventType, AMEM_MAGIC, FORMAT_VERSION};
12
13use super::compression::compress_content;
14
/// On-disk size of one node-table record, in bytes (72).
///
/// Spelled as the sum of the field widths emitted by `write_node_record`,
/// in the order that function writes them (padding included).
const NODE_RECORD_SIZE: u64 = 8 + 1 + 3 + 8 + 4 + 4 + 4 + 8 + 4 + 8 + 4 + 8 + 2 + 6;

/// On-disk size of one edge-table record, in bytes (32).
///
/// Spelled as the sum of the field widths emitted by `write_edge_record`,
/// in the order that function writes them (padding included).
const EDGE_RECORD_SIZE: u64 = 8 + 8 + 1 + 3 + 4 + 8;
20
/// Writer for .amem binary files.
///
/// The writer itself holds no I/O state; it only remembers the feature-vector
/// dimension to record in the file header and to pad each node's vector to.
#[derive(Debug, Clone)]
pub struct AmemWriter {
    /// Number of `f32` components written per node in the feature-vector block.
    dimension: usize,
}
25
26impl AmemWriter {
27    /// Create a new writer with the given feature vector dimension.
28    pub fn new(dimension: usize) -> Self {
29        Self { dimension }
30    }
31
32    /// Write a complete MemoryGraph to an .amem file.
33    pub fn write_to_file(&self, graph: &MemoryGraph, path: &Path) -> AmemResult<()> {
34        let file = std::fs::File::create(path)?;
35        let mut writer = std::io::BufWriter::new(file);
36        self.write_to(graph, &mut writer)
37    }
38
39    /// Write a complete MemoryGraph to any writer.
40    pub fn write_to(&self, graph: &MemoryGraph, writer: &mut impl Write) -> AmemResult<()> {
41        let nodes = graph.nodes();
42        // Sort edges by source_id for correct edge offset computation
43        let mut edges_sorted: Vec<Edge> = graph.edges().to_vec();
44        edges_sorted.sort_by(|a, b| {
45            a.source_id
46                .cmp(&b.source_id)
47                .then(a.target_id.cmp(&b.target_id))
48        });
49        let edges = &edges_sorted[..];
50        let node_count = nodes.len() as u64;
51        let edge_count = edges.len() as u64;
52
53        // Step 1: Compress all node contents and record offsets
54        let mut compressed_contents: Vec<Vec<u8>> = Vec::with_capacity(nodes.len());
55        let mut content_offsets: Vec<u64> = Vec::with_capacity(nodes.len());
56        let mut content_total_size: u64 = 0;
57
58        for node in nodes {
59            let compressed = compress_content(&node.content)?;
60            content_offsets.push(content_total_size);
61            content_total_size += compressed.len() as u64;
62            compressed_contents.push(compressed);
63        }
64
65        // Step 2: Calculate edge offsets per node
66        // Edges are sorted by source_id. We need to compute the offset for each node's edges.
67        let mut edge_offsets: Vec<(u64, u16)> = vec![(0, 0); nodes.len()];
68        {
69            let mut edge_idx = 0usize;
70            for node in nodes {
71                let start = edge_idx;
72                while edge_idx < edges.len() && edges[edge_idx].source_id == node.id {
73                    edge_idx += 1;
74                }
75                let count = edge_idx - start;
76                // Find the node's position in the sorted nodes list
77                if let Some(pos) = nodes.iter().position(|n| n.id == node.id) {
78                    edge_offsets[pos] = ((start as u64) * EDGE_RECORD_SIZE, count as u16);
79                }
80            }
81        }
82
83        // Step 3: Calculate section offsets
84        let node_table_offset = HEADER_SIZE;
85        let edge_table_offset = node_table_offset + node_count * NODE_RECORD_SIZE;
86        let content_block_offset = edge_table_offset + edge_count * EDGE_RECORD_SIZE;
87        let feature_vec_offset = content_block_offset + content_total_size;
88
89        // Compute feature flags: if we have nodes, we build BM25 indexes
90        let mut flags: u32 = 0;
91        if graph.node_count() > 0 {
92            flags |= feature_flags::HAS_TERM_INDEX | feature_flags::HAS_DOC_LENGTHS;
93        }
94
95        // Step 4: Write header
96        let header = FileHeader {
97            magic: AMEM_MAGIC,
98            version: FORMAT_VERSION,
99            dimension: self.dimension as u32,
100            flags,
101            node_count,
102            edge_count,
103            node_table_offset,
104            edge_table_offset,
105            content_block_offset,
106            feature_vec_offset,
107        };
108        header.write_to(writer)?;
109
110        // Step 5: Write node table
111        for (i, node) in nodes.iter().enumerate() {
112            write_node_record(
113                writer,
114                node,
115                content_offsets[i],
116                compressed_contents[i].len() as u32,
117                edge_offsets[i].0,
118                edge_offsets[i].1,
119            )?;
120        }
121
122        // Step 6: Write edge table
123        for edge in edges {
124            write_edge_record(writer, edge)?;
125        }
126
127        // Step 7: Write content block
128        for compressed in &compressed_contents {
129            writer.write_all(compressed)?;
130        }
131
132        // Step 8: Write feature vector block
133        for node in nodes {
134            for &val in &node.feature_vec {
135                writer.write_all(&val.to_le_bytes())?;
136            }
137            // Pad if feature vec is shorter than dimension
138            let remaining = self.dimension.saturating_sub(node.feature_vec.len());
139            for _ in 0..remaining {
140                writer.write_all(&0.0f32.to_le_bytes())?;
141            }
142        }
143
144        // Step 9: Write index block
145        self.write_indexes(writer, graph)?;
146
147        writer.flush()?;
148        Ok(())
149    }
150
151    fn write_indexes(&self, writer: &mut impl Write, graph: &MemoryGraph) -> AmemResult<()> {
152        // Type Index (tag 0x01)
153        {
154            let mut buf: Vec<u8> = Vec::new();
155            let type_idx = graph.type_index();
156            for event_type_val in 0u8..=5 {
157                if let Some(et) = EventType::from_u8(event_type_val) {
158                    let ids = type_idx.get(et);
159                    if !ids.is_empty() {
160                        buf.push(event_type_val);
161                        buf.extend_from_slice(&(ids.len() as u64).to_le_bytes());
162                        for &id in ids {
163                            buf.extend_from_slice(&id.to_le_bytes());
164                        }
165                    }
166                }
167            }
168            writer.write_all(&[0x01u8])?;
169            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
170            writer.write_all(&buf)?;
171        }
172
173        // Temporal Index (tag 0x02)
174        {
175            let temporal_idx = graph.temporal_index();
176            let entries = temporal_idx.entries();
177            let mut buf: Vec<u8> = Vec::new();
178            buf.extend_from_slice(&(entries.len() as u64).to_le_bytes());
179            for &(created_at, node_id) in entries {
180                buf.extend_from_slice(&created_at.to_le_bytes());
181                buf.extend_from_slice(&node_id.to_le_bytes());
182            }
183            writer.write_all(&[0x02u8])?;
184            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
185            writer.write_all(&buf)?;
186        }
187
188        // Session Index (tag 0x03)
189        {
190            let session_idx = graph.session_index();
191            let inner = session_idx.inner();
192            let mut buf: Vec<u8> = Vec::new();
193            buf.extend_from_slice(&(inner.len() as u32).to_le_bytes());
194            let mut session_ids: Vec<u32> = inner.keys().copied().collect();
195            session_ids.sort_unstable();
196            for sid in session_ids {
197                let ids = session_idx.get_session(sid);
198                buf.extend_from_slice(&sid.to_le_bytes());
199                buf.extend_from_slice(&(ids.len() as u64).to_le_bytes());
200                for &id in ids {
201                    buf.extend_from_slice(&id.to_le_bytes());
202                }
203            }
204            writer.write_all(&[0x03u8])?;
205            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
206            writer.write_all(&buf)?;
207        }
208
209        // Cluster Map (tag 0x04)
210        {
211            let cluster = graph.cluster_map();
212            let mut buf: Vec<u8> = Vec::new();
213            buf.extend_from_slice(&(cluster.cluster_count() as u32).to_le_bytes());
214            buf.extend_from_slice(&(cluster.dimension() as u32).to_le_bytes());
215            for i in 0..cluster.cluster_count() {
216                if let Some(centroid) = cluster.centroid(i) {
217                    for &val in centroid {
218                        buf.extend_from_slice(&val.to_le_bytes());
219                    }
220                }
221                let members = cluster.get_cluster(i);
222                buf.extend_from_slice(&(members.len() as u64).to_le_bytes());
223                for &id in members {
224                    buf.extend_from_slice(&id.to_le_bytes());
225                }
226            }
227            writer.write_all(&[0x04u8])?;
228            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
229            writer.write_all(&buf)?;
230        }
231
232        // Term Index (tag 0x05) — NEW
233        if graph.node_count() > 0 {
234            let tokenizer = Tokenizer::new();
235            let term_index = TermIndex::build(graph, &tokenizer);
236            let buf = term_index.to_bytes();
237            writer.write_all(&[0x05u8])?;
238            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
239            writer.write_all(&buf)?;
240
241            // Doc Lengths (tag 0x06) — NEW
242            let doc_lengths = DocLengths::build(graph, &tokenizer);
243            let buf = doc_lengths.to_bytes();
244            writer.write_all(&[0x06u8])?;
245            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
246            writer.write_all(&buf)?;
247        }
248
249        Ok(())
250    }
251}
252
253/// Write a single 72-byte node record.
254fn write_node_record(
255    writer: &mut impl Write,
256    node: &crate::types::CognitiveEvent,
257    content_offset: u64,
258    content_length: u32,
259    edge_offset: u64,
260    edge_count: u16,
261) -> AmemResult<()> {
262    writer.write_all(&node.id.to_le_bytes())?; // 8 bytes
263    writer.write_all(&[node.event_type as u8])?; // 1 byte
264    writer.write_all(&[0u8; 3])?; // 3 bytes padding
265    writer.write_all(&node.created_at.to_le_bytes())?; // 8 bytes
266    writer.write_all(&node.session_id.to_le_bytes())?; // 4 bytes
267    writer.write_all(&node.confidence.to_le_bytes())?; // 4 bytes
268    writer.write_all(&node.access_count.to_le_bytes())?; // 4 bytes
269    writer.write_all(&node.last_accessed.to_le_bytes())?; // 8 bytes
270    writer.write_all(&node.decay_score.to_le_bytes())?; // 4 bytes
271    writer.write_all(&content_offset.to_le_bytes())?; // 8 bytes
272    writer.write_all(&content_length.to_le_bytes())?; // 4 bytes
273    writer.write_all(&edge_offset.to_le_bytes())?; // 8 bytes
274    writer.write_all(&edge_count.to_le_bytes())?; // 2 bytes
275    writer.write_all(&[0u8; 6])?; // 6 bytes padding
276                                  // Total: 8+1+3+8+4+4+4+8+4+8+4+8+2+6 = 72
277    Ok(())
278}
279
280/// Write a single 32-byte edge record.
281fn write_edge_record(writer: &mut impl Write, edge: &crate::types::Edge) -> AmemResult<()> {
282    writer.write_all(&edge.source_id.to_le_bytes())?; // 8 bytes
283    writer.write_all(&edge.target_id.to_le_bytes())?; // 8 bytes
284    writer.write_all(&[edge.edge_type as u8])?; // 1 byte
285    writer.write_all(&[0u8; 3])?; // 3 bytes padding
286    writer.write_all(&edge.weight.to_le_bytes())?; // 4 bytes
287    writer.write_all(&edge.created_at.to_le_bytes())?; // 8 bytes
288                                                       // Total: 8+8+1+3+4+8 = 32
289    Ok(())
290}