Skip to main content

hirn_engine/tools/
toolkit.rs

1//! [`MemoryToolkit`] — 6-function agent API wrapping [`HirnDB`].
2
3use std::sync::Arc;
4
5use hirn_core::episodic::EpisodicRecord;
6use hirn_core::error::{HirnError, HirnResult};
7use hirn_core::id::MemoryId;
8use hirn_core::types::{AgentId, EventType};
9
10use crate::db::HirnDB;
11use crate::graph::EdgeId;
12use crate::graph_store::GraphStore;
13use crate::policy::Action;
14
15use super::types::{
16    EdgeInfo, IntrospectionResult, LinkRequest, RecallOptions, RecallRecord, StoreRequest,
17    UpdateRequest,
18};
19
20/// Agent-facing toolkit with 6 self-editing memory operations.
21///
22/// Each method validates input, enforces Cedar policies via the agent's
23/// identity, and delegates to [`HirnDB`]. Designed to be the single
24/// abstraction layer between protocol adapters (MCP, gRPC) and the engine.
25#[derive(Clone)]
26pub struct MemoryToolkit {
27    db: Arc<HirnDB>,
28}
29
30impl MemoryToolkit {
31    /// Create a new toolkit wrapping the given database.
32    pub fn new(db: Arc<HirnDB>) -> Self {
33        Self { db }
34    }
35
36    /// Access the underlying database (for advanced operations).
37    pub fn db(&self) -> &HirnDB {
38        &self.db
39    }
40
41    // ── 1. Store ────────────────────────────────────────────────────────
42
43    /// Store a new memory with RPE-gated admission.
44    ///
45    /// Validates content, enforces `Action::Remember` policy, then delegates
46    /// to `HirnDB::remember()`.
47    pub async fn store(&self, agent_id: AgentId, request: StoreRequest) -> HirnResult<MemoryId> {
48        // Input validation.
49        if request.content.is_empty() {
50            return Err(HirnError::InvalidInput("content must not be empty".into()));
51        }
52        if request.content.len() > 1_000_000 {
53            return Err(HirnError::InvalidInput("content exceeds 1MB limit".into()));
54        }
55        if let Some(imp) = request.importance {
56            if !(0.0..=1.0).contains(&imp) {
57                return Err(HirnError::InvalidInput(
58                    "importance must be between 0.0 and 1.0".into(),
59                ));
60            }
61        }
62
63        let ns = request.namespace.unwrap_or_default();
64
65        // Cedar enforcement.
66        self.db
67            .enforce(agent_id.as_str(), Action::Remember, "default", ns.as_str())
68            .await?;
69
70        // Build record.
71        let mut builder = EpisodicRecord::builder()
72            .content(&request.content)
73            .event_type(request.event_type.unwrap_or(EventType::Observation))
74            .agent_id(agent_id)
75            .namespace(ns);
76
77        if let Some(imp) = request.importance {
78            builder = builder.importance(imp);
79        }
80        if let Some(emb) = request.embedding {
81            builder = builder.embedding(emb);
82        }
83        if let Some(meta) = request.metadata {
84            for (k, v) in &meta {
85                let v_len = match v {
86                    hirn_core::metadata::MetadataValue::String(s) => s.len(),
87                    _ => 0, // non-string variants are bounded by type
88                };
89                if k.len() > 256 || v_len > 10_000 {
90                    return Err(HirnError::InvalidInput(
91                        "metadata key must be ≤256 bytes and value ≤10,000 bytes".into(),
92                    ));
93                }
94            }
95            for (k, v) in meta {
96                builder = builder.metadata_entry(k, v);
97            }
98        }
99
100        let record = builder
101            .build()
102            .map_err(|e| HirnError::InvalidInput(format!("failed to build record: {e}")))?;
103
104        self.db.remember(record).await
105    }
106
107    // ── 2. Recall ───────────────────────────────────────────────────────
108
109    /// Recall memories matching a natural-language query.
110    ///
111    /// Uses `RecallBuilder` directly with proper agent identity for Cedar enforcement.
112    pub async fn recall(
113        &self,
114        agent_id: AgentId,
115        query: &str,
116        options: RecallOptions,
117    ) -> HirnResult<Vec<RecallRecord>> {
118        if query.is_empty() {
119            return Err(HirnError::InvalidInput("query must not be empty".into()));
120        }
121
122        let ns = options.namespace.unwrap_or_default();
123
124        // Embed the query text.
125        let embedding = self.db.embed_text(query).await?;
126
127        // Build recall via RecallBuilder — passes agent_id so Cedar enforcement
128        // inside execute_with_diagnostics() uses the correct identity.
129        let limit = options.limit.unwrap_or(10);
130        let builder = self
131            .db
132            .recall(embedding)
133            .agent_id(agent_id.as_str())
134            .namespace(ns)
135            .limit(limit)
136            .query_text(query)
137            .hybrid(true);
138
139        let results = builder.execute().await?;
140
141        Ok(results
142            .into_iter()
143            .map(|r| {
144                let id = r.record.id();
145                let content = match &r.record {
146                    hirn_core::record::MemoryRecord::Episodic(e) => e.content.clone(),
147                    hirn_core::record::MemoryRecord::Semantic(s) => s.description.clone(),
148                    hirn_core::record::MemoryRecord::Procedural(p) => p.description.clone(),
149                    hirn_core::record::MemoryRecord::Working(w) => w.content.clone(),
150                };
151                RecallRecord {
152                    id,
153                    content,
154                    score: f64::from(r.composite_score),
155                    metadata: Default::default(),
156                }
157            })
158            .collect())
159    }
160
161    // ── 3. Update ───────────────────────────────────────────────────────
162
163    /// Update an existing memory's content and/or metadata.
164    ///
165    /// Enforces `Action::Remember` (writes require store permission).
166    pub async fn update(&self, agent_id: AgentId, request: UpdateRequest) -> HirnResult<()> {
167        if request.content.is_none() && request.metadata.is_none() && request.importance.is_none() {
168            return Err(HirnError::InvalidInput(
169                "at least one of content, metadata, or importance must be provided".into(),
170            ));
171        }
172        if let Some(ref c) = request.content {
173            if c.is_empty() {
174                return Err(HirnError::InvalidInput("content must not be empty".into()));
175            }
176        }
177
178        // Read the record to find its namespace for Cedar enforcement.
179        let existing = self.db.resolve_active_episodic_head(request.id).await?;
180        let ns = existing.namespace;
181
182        self.db
183            .enforce(agent_id.as_str(), Action::Remember, "default", ns.as_str())
184            .await?;
185
186        let content = request.content.clone();
187        let metadata = request.metadata.clone();
188        let importance = request.importance;
189
190        self.db
191            .update_episode(existing.id, move |rec| {
192                if let Some(c) = content {
193                    rec.content = c;
194                }
195                if let Some(meta) = metadata {
196                    rec.metadata.extend(meta);
197                }
198                if let Some(imp) = importance {
199                    rec.importance = imp;
200                }
201            })
202            .await
203    }
204
205    // ── 4. Delete ───────────────────────────────────────────────────────
206
207    /// Soft-delete (archive) a memory.
208    ///
209    /// Sets the archived flag. Does not permanently remove the record.
210    pub async fn delete(&self, agent_id: AgentId, id: MemoryId) -> HirnResult<()> {
211        // Read to find namespace for policy.
212        let existing = self.db.resolve_active_episodic_head(id).await?;
213        let ns = existing.namespace;
214
215        self.db
216            .enforce(agent_id.as_str(), Action::Forget, "default", ns.as_str())
217            .await?;
218
219        self.db.archive_episode(existing.id).await
220    }
221
222    // ── 5. Link ─────────────────────────────────────────────────────────
223
224    /// Create a graph edge between two memories.
225    pub async fn link(&self, agent_id: AgentId, request: LinkRequest) -> HirnResult<EdgeId> {
226        // Default namespace for policy — links cross namespace boundaries.
227        self.db
228            .enforce(agent_id.as_str(), Action::Connect, "default", "default")
229            .await?;
230
231        let weight = request.weight.unwrap_or(0.5);
232        let metadata = request.metadata.unwrap_or_default();
233
234        self.db
235            .connect_with(
236                request.source_id,
237                request.target_id,
238                request.relation,
239                weight,
240                metadata,
241            )
242            .await
243    }
244
245    // ── 6. Introspect ───────────────────────────────────────────────────
246
247    /// Return memory statistics and optionally graph neighborhood for a memory.
248    pub async fn introspect(
249        &self,
250        agent_id: AgentId,
251        id: Option<MemoryId>,
252    ) -> HirnResult<IntrospectionResult> {
253        self.db
254            .enforce(agent_id.as_str(), Action::Recall, "default", "default")
255            .await?;
256
257        let stats = self.db.stats().await?;
258
259        let edges = if let Some(memory_id) = id {
260            let graph = self.db.cached_graph();
261            let node_edges = graph.get_edges(memory_id).await?;
262            node_edges
263                .into_iter()
264                .map(|e| EdgeInfo {
265                    source: e.source,
266                    target: e.target,
267                    relation: e.relation.clone(),
268                    weight: e.weight,
269                })
270                .collect()
271        } else {
272            Vec::new()
273        };
274
275        Ok(IntrospectionResult {
276            total_memories: stats.total_count,
277            episodic_count: stats.episodic_count,
278            semantic_count: stats.semantic_count,
279            procedural_count: stats.procedural_count,
280            working_count: stats.working_count,
281            edge_count: stats.edge_count,
282            edges,
283        })
284    }
285}