Skip to main content

agentic_memory/format/
writer.rs

1//! Writes .amem files from in-memory graph.
2
3use std::io::Write;
4use std::path::Path;
5
6use crate::graph::MemoryGraph;
7use crate::types::error::AmemResult;
8use crate::types::header::{FileHeader, HEADER_SIZE};
9use crate::types::{Edge, EventType, AMEM_MAGIC, FORMAT_VERSION};
10
11use super::compression::compress_content;
12
/// Size of a single node record on disk: 72 bytes.
///
/// The edge table offset is derived from this value, so it must match the
/// number of bytes `write_node_record` emits for one node.
const NODE_RECORD_SIZE: u64 = 72;

/// Size of a single edge record on disk: 32 bytes.
///
/// The content block offset and every node's edge offset are derived from
/// this value, so it must match the number of bytes `write_edge_record`
/// emits for one edge.
const EDGE_RECORD_SIZE: u64 = 32;
18
/// Writer for .amem binary files.
///
/// Holds only the feature vector dimension: it is recorded in the file
/// header and is the length every node's feature vector is padded to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AmemWriter {
    /// Number of `f32` components written per node in the feature vector block.
    dimension: usize,
}
23
24impl AmemWriter {
25    /// Create a new writer with the given feature vector dimension.
26    pub fn new(dimension: usize) -> Self {
27        Self { dimension }
28    }
29
30    /// Write a complete MemoryGraph to an .amem file.
31    pub fn write_to_file(&self, graph: &MemoryGraph, path: &Path) -> AmemResult<()> {
32        let file = std::fs::File::create(path)?;
33        let mut writer = std::io::BufWriter::new(file);
34        self.write_to(graph, &mut writer)
35    }
36
37    /// Write a complete MemoryGraph to any writer.
38    pub fn write_to(&self, graph: &MemoryGraph, writer: &mut impl Write) -> AmemResult<()> {
39        let nodes = graph.nodes();
40        // Sort edges by source_id for correct edge offset computation
41        let mut edges_sorted: Vec<Edge> = graph.edges().to_vec();
42        edges_sorted.sort_by(|a, b| {
43            a.source_id
44                .cmp(&b.source_id)
45                .then(a.target_id.cmp(&b.target_id))
46        });
47        let edges = &edges_sorted[..];
48        let node_count = nodes.len() as u64;
49        let edge_count = edges.len() as u64;
50
51        // Step 1: Compress all node contents and record offsets
52        let mut compressed_contents: Vec<Vec<u8>> = Vec::with_capacity(nodes.len());
53        let mut content_offsets: Vec<u64> = Vec::with_capacity(nodes.len());
54        let mut content_total_size: u64 = 0;
55
56        for node in nodes {
57            let compressed = compress_content(&node.content)?;
58            content_offsets.push(content_total_size);
59            content_total_size += compressed.len() as u64;
60            compressed_contents.push(compressed);
61        }
62
63        // Step 2: Calculate edge offsets per node
64        // Edges are sorted by source_id. We need to compute the offset for each node's edges.
65        let mut edge_offsets: Vec<(u64, u16)> = vec![(0, 0); nodes.len()];
66        {
67            let mut edge_idx = 0usize;
68            for node in nodes {
69                let start = edge_idx;
70                while edge_idx < edges.len() && edges[edge_idx].source_id == node.id {
71                    edge_idx += 1;
72                }
73                let count = edge_idx - start;
74                // Find the node's position in the sorted nodes list
75                if let Some(pos) = nodes.iter().position(|n| n.id == node.id) {
76                    edge_offsets[pos] = ((start as u64) * EDGE_RECORD_SIZE, count as u16);
77                }
78            }
79        }
80
81        // Step 3: Calculate section offsets
82        let node_table_offset = HEADER_SIZE;
83        let edge_table_offset = node_table_offset + node_count * NODE_RECORD_SIZE;
84        let content_block_offset = edge_table_offset + edge_count * EDGE_RECORD_SIZE;
85        let feature_vec_offset = content_block_offset + content_total_size;
86
87        // Step 4: Write header
88        let header = FileHeader {
89            magic: AMEM_MAGIC,
90            version: FORMAT_VERSION,
91            dimension: self.dimension as u32,
92            node_count,
93            edge_count,
94            node_table_offset,
95            edge_table_offset,
96            content_block_offset,
97            feature_vec_offset,
98        };
99        header.write_to(writer)?;
100
101        // Step 5: Write node table
102        for (i, node) in nodes.iter().enumerate() {
103            write_node_record(
104                writer,
105                node,
106                content_offsets[i],
107                compressed_contents[i].len() as u32,
108                edge_offsets[i].0,
109                edge_offsets[i].1,
110            )?;
111        }
112
113        // Step 6: Write edge table
114        for edge in edges {
115            write_edge_record(writer, edge)?;
116        }
117
118        // Step 7: Write content block
119        for compressed in &compressed_contents {
120            writer.write_all(compressed)?;
121        }
122
123        // Step 8: Write feature vector block
124        for node in nodes {
125            for &val in &node.feature_vec {
126                writer.write_all(&val.to_le_bytes())?;
127            }
128            // Pad if feature vec is shorter than dimension
129            let remaining = self.dimension.saturating_sub(node.feature_vec.len());
130            for _ in 0..remaining {
131                writer.write_all(&0.0f32.to_le_bytes())?;
132            }
133        }
134
135        // Step 9: Write index block
136        self.write_indexes(writer, graph)?;
137
138        writer.flush()?;
139        Ok(())
140    }
141
142    fn write_indexes(&self, writer: &mut impl Write, graph: &MemoryGraph) -> AmemResult<()> {
143        // Type Index (tag 0x01)
144        {
145            let mut buf: Vec<u8> = Vec::new();
146            let type_idx = graph.type_index();
147            for event_type_val in 0u8..=5 {
148                if let Some(et) = EventType::from_u8(event_type_val) {
149                    let ids = type_idx.get(et);
150                    if !ids.is_empty() {
151                        buf.push(event_type_val);
152                        buf.extend_from_slice(&(ids.len() as u64).to_le_bytes());
153                        for &id in ids {
154                            buf.extend_from_slice(&id.to_le_bytes());
155                        }
156                    }
157                }
158            }
159            writer.write_all(&[0x01u8])?;
160            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
161            writer.write_all(&buf)?;
162        }
163
164        // Temporal Index (tag 0x02)
165        {
166            let temporal_idx = graph.temporal_index();
167            let entries = temporal_idx.entries();
168            let mut buf: Vec<u8> = Vec::new();
169            buf.extend_from_slice(&(entries.len() as u64).to_le_bytes());
170            for &(created_at, node_id) in entries {
171                buf.extend_from_slice(&created_at.to_le_bytes());
172                buf.extend_from_slice(&node_id.to_le_bytes());
173            }
174            writer.write_all(&[0x02u8])?;
175            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
176            writer.write_all(&buf)?;
177        }
178
179        // Session Index (tag 0x03)
180        {
181            let session_idx = graph.session_index();
182            let inner = session_idx.inner();
183            let mut buf: Vec<u8> = Vec::new();
184            buf.extend_from_slice(&(inner.len() as u32).to_le_bytes());
185            let mut session_ids: Vec<u32> = inner.keys().copied().collect();
186            session_ids.sort_unstable();
187            for sid in session_ids {
188                let ids = session_idx.get_session(sid);
189                buf.extend_from_slice(&sid.to_le_bytes());
190                buf.extend_from_slice(&(ids.len() as u64).to_le_bytes());
191                for &id in ids {
192                    buf.extend_from_slice(&id.to_le_bytes());
193                }
194            }
195            writer.write_all(&[0x03u8])?;
196            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
197            writer.write_all(&buf)?;
198        }
199
200        // Cluster Map (tag 0x04)
201        {
202            let cluster = graph.cluster_map();
203            let mut buf: Vec<u8> = Vec::new();
204            buf.extend_from_slice(&(cluster.cluster_count() as u32).to_le_bytes());
205            buf.extend_from_slice(&(cluster.dimension() as u32).to_le_bytes());
206            for i in 0..cluster.cluster_count() {
207                if let Some(centroid) = cluster.centroid(i) {
208                    for &val in centroid {
209                        buf.extend_from_slice(&val.to_le_bytes());
210                    }
211                }
212                let members = cluster.get_cluster(i);
213                buf.extend_from_slice(&(members.len() as u64).to_le_bytes());
214                for &id in members {
215                    buf.extend_from_slice(&id.to_le_bytes());
216                }
217            }
218            writer.write_all(&[0x04u8])?;
219            writer.write_all(&(buf.len() as u64).to_le_bytes())?;
220            writer.write_all(&buf)?;
221        }
222
223        Ok(())
224    }
225}
226
227/// Write a single 72-byte node record.
228fn write_node_record(
229    writer: &mut impl Write,
230    node: &crate::types::CognitiveEvent,
231    content_offset: u64,
232    content_length: u32,
233    edge_offset: u64,
234    edge_count: u16,
235) -> AmemResult<()> {
236    writer.write_all(&node.id.to_le_bytes())?; // 8 bytes
237    writer.write_all(&[node.event_type as u8])?; // 1 byte
238    writer.write_all(&[0u8; 3])?; // 3 bytes padding
239    writer.write_all(&node.created_at.to_le_bytes())?; // 8 bytes
240    writer.write_all(&node.session_id.to_le_bytes())?; // 4 bytes
241    writer.write_all(&node.confidence.to_le_bytes())?; // 4 bytes
242    writer.write_all(&node.access_count.to_le_bytes())?; // 4 bytes
243    writer.write_all(&node.last_accessed.to_le_bytes())?; // 8 bytes
244    writer.write_all(&node.decay_score.to_le_bytes())?; // 4 bytes
245    writer.write_all(&content_offset.to_le_bytes())?; // 8 bytes
246    writer.write_all(&content_length.to_le_bytes())?; // 4 bytes
247    writer.write_all(&edge_offset.to_le_bytes())?; // 8 bytes
248    writer.write_all(&edge_count.to_le_bytes())?; // 2 bytes
249    writer.write_all(&[0u8; 6])?; // 6 bytes padding
250                                  // Total: 8+1+3+8+4+4+4+8+4+8+4+8+2+6 = 72
251    Ok(())
252}
253
254/// Write a single 32-byte edge record.
255fn write_edge_record(writer: &mut impl Write, edge: &crate::types::Edge) -> AmemResult<()> {
256    writer.write_all(&edge.source_id.to_le_bytes())?; // 8 bytes
257    writer.write_all(&edge.target_id.to_le_bytes())?; // 8 bytes
258    writer.write_all(&[edge.edge_type as u8])?; // 1 byte
259    writer.write_all(&[0u8; 3])?; // 3 bytes padding
260    writer.write_all(&edge.weight.to_le_bytes())?; // 4 bytes
261    writer.write_all(&edge.created_at.to_le_bytes())?; // 8 bytes
262                                                       // Total: 8+8+1+3+4+8 = 32
263    Ok(())
264}