// agentic_memory/engine/write.rs

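//! Memory formation pipeline — the write engine.
//!
//! [`WriteEngine`] ingests events and edges into a `MemoryGraph`, records
//! corrections, compresses sessions into episode nodes, and runs decay scoring.
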
use std::collections::HashSet;

use crate::graph::MemoryGraph;
use crate::types::{
    now_micros, AmemError, AmemResult, CognitiveEvent, CognitiveEventBuilder, Edge, EdgeType,
    EventType,
};

use super::decay::calculate_decay;

/// Result of a [`WriteEngine::ingest`] call.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IngestResult {
    /// IDs of the nodes created from the ingested events, in input order.
    pub new_node_ids: Vec<u64>,
    /// Number of edges inserted.
    pub new_edge_count: usize,
    /// IDs of pre-existing nodes whose `access_count` and `last_accessed` were
    /// updated because one of the newly created nodes has an edge to them.
    pub touched_node_ids: Vec<u64>,
}

/// Report from running decay calculations.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DecayReport {
    /// Number of nodes whose `decay_score` was changed by the run.
    pub nodes_decayed: usize,
    /// Nodes whose `decay_score` is below 0.1 after the run, whether or not it
    /// changed (candidates for archival).
    pub low_importance_nodes: Vec<u64>,
}

/// The write engine orchestrates memory formation.
///
/// It is configured only with the feature-vector dimension used for the nodes
/// it synthesizes itself.
#[derive(Debug, Clone)]
pub struct WriteEngine {
    /// Length of the zero-filled feature vector given to the correction and
    /// episode nodes this engine creates.
    dimension: usize,
}

36impl WriteEngine {
37    /// Create a new write engine.
38    pub fn new(dimension: usize) -> Self {
39        Self { dimension }
40    }
41
42    /// Process a batch of new cognitive events and integrate them into the graph.
43    pub fn ingest(
44        &self,
45        graph: &mut MemoryGraph,
46        events: Vec<CognitiveEvent>,
47        edges: Vec<Edge>,
48    ) -> AmemResult<IngestResult> {
49        let mut new_node_ids = Vec::with_capacity(events.len());
50        let mut touched_node_ids = Vec::new();
51
52        // Step 1-3: Validate and add all events
53        for event in events {
54            let id = graph.add_node(event)?;
55            new_node_ids.push(id);
56        }
57
58        // Step 4-5: Validate and add all edges
59        let mut new_edge_count = 0;
60        for edge in edges {
61            graph.add_edge(edge)?;
62            new_edge_count += 1;
63        }
64
65        // Ensure adjacency is rebuilt after bulk edge insertion
66        graph.ensure_adjacency();
67
68        // Step 8: Touch referenced nodes (nodes that existing edges point to)
69        let new_id_set: std::collections::HashSet<u64> = new_node_ids.iter().copied().collect();
70        for edge in graph.edges() {
71            // If a new node has an edge to an existing node, touch that existing node
72            if new_id_set.contains(&edge.source_id)
73                && !new_id_set.contains(&edge.target_id)
74                && !touched_node_ids.contains(&edge.target_id)
75            {
76                touched_node_ids.push(edge.target_id);
77            }
78        }
79
80        for &id in &touched_node_ids {
81            if let Some(node) = graph.get_node_mut(id) {
82                node.access_count += 1;
83                node.last_accessed = now_micros();
84            }
85        }
86
87        Ok(IngestResult {
88            new_node_ids,
89            new_edge_count,
90            touched_node_ids,
91        })
92    }
93
94    /// Record a correction: marks old node as superseded, adds new node.
95    pub fn correct(
96        &self,
97        graph: &mut MemoryGraph,
98        old_node_id: u64,
99        new_content: &str,
100        session_id: u32,
101    ) -> AmemResult<u64> {
102        // Verify old node exists
103        if graph.get_node(old_node_id).is_none() {
104            return Err(AmemError::NodeNotFound(old_node_id));
105        }
106
107        // Create new correction node
108        let event = CognitiveEventBuilder::new(EventType::Correction, new_content)
109            .session_id(session_id)
110            .confidence(1.0)
111            .feature_vec(vec![0.0; self.dimension])
112            .build();
113
114        let new_id = graph.add_node(event)?;
115
116        // Create SUPERSEDES edge from new to old
117        let edge = Edge::new(new_id, old_node_id, EdgeType::Supersedes, 1.0);
118        graph.add_edge(edge)?;
119
120        // Ensure adjacency is rebuilt
121        graph.ensure_adjacency();
122
123        // Reduce old node's confidence to 0.0
124        if let Some(old_node) = graph.get_node_mut(old_node_id) {
125            old_node.confidence = 0.0;
126        }
127
128        Ok(new_id)
129    }
130
131    /// Compress a session into an episode node.
132    pub fn compress_session(
133        &self,
134        graph: &mut MemoryGraph,
135        session_id: u32,
136        summary: &str,
137    ) -> AmemResult<u64> {
138        // Find all nodes in this session
139        let session_node_ids: Vec<u64> = graph.session_index().get_session(session_id).to_vec();
140
141        // Create episode node
142        let event = CognitiveEventBuilder::new(EventType::Episode, summary)
143            .session_id(session_id)
144            .confidence(1.0)
145            .feature_vec(vec![0.0; self.dimension])
146            .build();
147
148        let episode_id = graph.add_node(event)?;
149
150        // Create PART_OF edges from each session node to the episode
151        for &node_id in &session_node_ids {
152            let edge = Edge::new(node_id, episode_id, EdgeType::PartOf, 1.0);
153            graph.add_edge(edge)?;
154        }
155
156        // Ensure adjacency is rebuilt
157        graph.ensure_adjacency();
158
159        Ok(episode_id)
160    }
161
162    /// Touch a node (update access_count and last_accessed).
163    pub fn touch(&self, graph: &mut MemoryGraph, node_id: u64) -> AmemResult<()> {
164        let node = graph
165            .get_node_mut(node_id)
166            .ok_or(AmemError::NodeNotFound(node_id))?;
167        node.access_count += 1;
168        node.last_accessed = now_micros();
169        Ok(())
170    }
171
172    /// Run decay calculations across all nodes.
173    pub fn run_decay(&self, graph: &mut MemoryGraph, current_time: u64) -> AmemResult<DecayReport> {
174        let mut nodes_decayed = 0;
175        let mut low_importance_nodes = Vec::new();
176
177        // Collect node IDs first to avoid borrow issues
178        let node_ids: Vec<u64> = graph.nodes().iter().map(|n| n.id).collect();
179
180        for id in node_ids {
181            if let Some(node) = graph.get_node_mut(id) {
182                let new_score = calculate_decay(node, current_time);
183                if (new_score - node.decay_score).abs() > f32::EPSILON {
184                    node.decay_score = new_score;
185                    nodes_decayed += 1;
186                }
187                if new_score < 0.1 {
188                    low_importance_nodes.push(id);
189                }
190            }
191        }
192
193        Ok(DecayReport {
194            nodes_decayed,
195            low_importance_nodes,
196        })
197    }
198}