Skip to main content

mnemo_core/query/
remember.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::{Error, Result};
5use crate::hash::{compute_chain_hash, compute_content_hash};
6use crate::model::event::{AgentEvent, EventType};
7use crate::model::memory::{ConsolidationState, MemoryRecord, MemoryType, Scope, SourceType};
8use crate::model::relation::Relation;
9use crate::query::MnemoEngine;
10#[allow(unused_imports)]
11use base64::Engine as _;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct RememberRequest {
15    pub content: String,
16    pub agent_id: Option<String>,
17    pub memory_type: Option<MemoryType>,
18    pub scope: Option<Scope>,
19    pub importance: Option<f32>,
20    pub tags: Option<Vec<String>>,
21    pub metadata: Option<serde_json::Value>,
22    pub source_type: Option<SourceType>,
23    pub source_id: Option<String>,
24    pub org_id: Option<String>,
25    pub thread_id: Option<String>,
26    pub ttl_seconds: Option<u64>,
27    pub related_to: Option<Vec<String>>,
28    pub decay_rate: Option<f32>,
29    pub created_by: Option<String>,
30}
31
32impl RememberRequest {
33    pub fn new(content: String) -> Self {
34        Self {
35            content,
36            agent_id: None,
37            memory_type: None,
38            scope: None,
39            importance: None,
40            tags: None,
41            metadata: None,
42            source_type: None,
43            source_id: None,
44            org_id: None,
45            thread_id: None,
46            ttl_seconds: None,
47            related_to: None,
48            decay_rate: None,
49            created_by: None,
50        }
51    }
52}
53
54#[non_exhaustive]
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct RememberResponse {
57    pub id: Uuid,
58    pub content_hash: String,
59}
60
61impl RememberResponse {
62    pub fn new(id: Uuid, content_hash: String) -> Self {
63        Self { id, content_hash }
64    }
65}
66
67pub async fn execute(engine: &MnemoEngine, request: RememberRequest) -> Result<RememberResponse> {
68    // Validate
69    if request.content.trim().is_empty() {
70        return Err(Error::Validation("content cannot be empty".to_string()));
71    }
72
73    let resolved_tier = request.memory_type.unwrap_or(MemoryType::Episodic);
74
75    // Tier-specific importance enforcement:
76    // Procedural memories (system prompts, tool definitions) carry an
77    // importance floor so they never decay below the recall threshold.
78    let mut importance = request.importance.unwrap_or(0.5);
79    if resolved_tier == MemoryType::Procedural && importance < engine.procedural_importance_floor {
80        importance = engine.procedural_importance_floor;
81    }
82    if !(0.0..=1.0).contains(&importance) {
83        return Err(Error::Validation(
84            "importance must be between 0.0 and 1.0".to_string(),
85        ));
86    }
87
88    let agent_id = request
89        .agent_id
90        .unwrap_or_else(|| engine.default_agent_id.clone());
91    super::validate_agent_id(&agent_id)?;
92    let org_id = request.org_id.or_else(|| engine.default_org_id.clone());
93    let now = chrono::Utc::now();
94    let now_str = now.to_rfc3339();
95    let id = Uuid::now_v7();
96
97    // Compute embedding
98    let embedding = engine.embedding.embed(&request.content).await?;
99
100    // Compute content hash
101    let content_hash = compute_content_hash(&request.content, &agent_id, &now_str);
102
103    // Chain linking: look up prev_hash
104    // NOTE: Concurrent writes for the same agent_id may race on prev_hash lookup.
105    // DuckDB mode serializes via Arc<Mutex<Connection>>. PostgreSQL deployments
106    // should rely on verify_chain() to detect any broken links.
107    let prev_hash_raw = engine
108        .storage
109        .get_latest_memory_hash(&agent_id, request.thread_id.as_deref())
110        .await?;
111    let prev_hash = Some(compute_chain_hash(&content_hash, prev_hash_raw.as_deref()));
112
113    // Compute expires_at from ttl_seconds. Working-tier memories get an
114    // automatic TTL so they can't outlive their session — caller-supplied
115    // ttl_seconds still wins.
116    let effective_ttl = request.ttl_seconds.or_else(|| {
117        if resolved_tier == MemoryType::Working {
118            Some(engine.ttl_working_seconds)
119        } else {
120            None
121        }
122    });
123    let expires_at =
124        effective_ttl.map(|ttl| (now + chrono::Duration::seconds(ttl as i64)).to_rfc3339());
125
126    let mut record = MemoryRecord {
127        id,
128        agent_id: agent_id.clone(),
129        content: request.content,
130        memory_type: resolved_tier,
131        scope: request.scope.unwrap_or(Scope::Private),
132        importance,
133        tags: request.tags.unwrap_or_default(),
134        metadata: request
135            .metadata
136            .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
137        embedding: Some(embedding.clone()),
138        content_hash: content_hash.clone(),
139        prev_hash,
140        source_type: request.source_type.unwrap_or(SourceType::Agent),
141        source_id: request.source_id,
142        consolidation_state: ConsolidationState::Raw,
143        access_count: 0,
144        org_id,
145        thread_id: request.thread_id,
146        created_at: now_str.clone(),
147        updated_at: now_str,
148        last_accessed_at: None,
149        expires_at,
150        deleted_at: None,
151        decay_rate: request.decay_rate,
152        created_by: request.created_by,
153        version: 1,
154        prev_version_id: None,
155        quarantined: false,
156        quarantine_reason: None,
157        decay_function: None,
158    };
159
160    // Encrypt content if encryption is configured (after embedding, before storage)
161    if let Some(ref enc) = engine.encryption {
162        let encrypted = enc.encrypt(record.content.as_bytes())?;
163        record.content =
164            base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &encrypted);
165    }
166
167    // Store in database
168    engine.storage.insert_memory(&record).await?;
169
170    // Add to vector index
171    engine.index.add(id, &embedding)?;
172
173    // Add to full-text index if available
174    if let Some(ref ft) = engine.full_text {
175        ft.add(id, &record.content)?;
176        ft.commit()?;
177    }
178
179    // Check for anomaly and update agent profile
180    let anomaly_result = super::poisoning::check_for_anomaly(engine, &record).await?;
181    if anomaly_result.is_anomalous {
182        super::poisoning::quarantine_memory(engine, id, &anomaly_result.reasons.join("; ")).await?;
183        tracing::warn!(
184            memory_id = %id,
185            score = anomaly_result.score,
186            reasons = ?anomaly_result.reasons,
187            "Memory quarantined due to anomaly detection"
188        );
189    }
190    super::poisoning::update_agent_profile(engine, &record).await?;
191
192    // Create relations if specified
193    if let Some(ref related_ids) = request.related_to {
194        for target_str in related_ids {
195            if let Ok(target_id) = Uuid::parse_str(target_str) {
196                let relation = Relation {
197                    id: Uuid::now_v7(),
198                    source_id: id,
199                    target_id,
200                    relation_type: "related_to".to_string(),
201                    weight: 1.0,
202                    metadata: serde_json::Value::Object(serde_json::Map::new()),
203                    created_at: record.created_at.clone(),
204                };
205                if let Err(e) = engine.storage.insert_relation(&relation).await {
206                    tracing::error!(relation_id = %relation.id, error = %e, "failed to insert relation");
207                }
208            }
209        }
210    }
211
212    // Emit MemoryWrite event with hash chain linking (fire-and-forget)
213    let prev_event_hash = match engine
214        .storage
215        .get_latest_event_hash(&agent_id, record.thread_id.as_deref())
216        .await
217    {
218        Ok(hash) => hash,
219        Err(e) => {
220            tracing::warn!(error = %e, "failed to get latest event hash, starting new chain segment");
221            None
222        }
223    };
224    let event_prev_hash = Some(compute_chain_hash(
225        &content_hash,
226        prev_event_hash.as_deref(),
227    ));
228    let mut event = AgentEvent {
229        id: Uuid::now_v7(),
230        agent_id: record.agent_id.clone(),
231        thread_id: record.thread_id.clone(),
232        run_id: None,
233        parent_event_id: None,
234        event_type: EventType::MemoryWrite,
235        payload: serde_json::json!({"memory_id": id.to_string()}),
236        trace_id: None,
237        span_id: None,
238        model: None,
239        tokens_input: None,
240        tokens_output: None,
241        latency_ms: None,
242        cost_usd: None,
243        timestamp: record.created_at.clone(),
244        logical_clock: 0,
245        content_hash: content_hash.clone(),
246        prev_hash: event_prev_hash,
247        embedding: None,
248    };
249    // Optionally embed the event payload
250    if engine.embed_events
251        && let Ok(emb) = engine.embedding.embed(&event.payload.to_string()).await
252    {
253        event.embedding = Some(emb);
254    }
255    if let Err(e) = engine.storage.insert_event(&event).await {
256        tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
257    }
258
259    // Put in cache if configured
260    if let Some(ref cache) = engine.cache {
261        cache.put(record);
262    }
263
264    let hash_hex = hex::encode(&content_hash);
265
266    Ok(RememberResponse {
267        id,
268        content_hash: hash_hex,
269    })
270}