Skip to main content

mnemo_core/query/
poisoning.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::anomaly::outlier::score_embedding_outlier;
5use crate::error::Result;
6use crate::model::agent_profile::AgentProfile;
7use crate::model::memory::{MemoryRecord, SourceType};
8use crate::query::MnemoEngine;
9use crate::storage::MemoryFilter;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct AnomalyCheckResult {
13    pub is_anomalous: bool,
14    pub score: f32,
15    pub reasons: Vec<String>,
16}
17
/// Tuning knobs for [`check_for_anomaly`].
///
/// Every knob is off by default, so `PoisoningPolicy::default()` keeps the
/// v0.3.2 detector unchanged: lexical markers plus profile-based heuristics
/// only. Attach a policy to an engine with
/// [`crate::query::MnemoEngine::with_poisoning_policy`].
///
/// # Added in v0.3.3
///
/// Setting a threshold through [`PoisoningPolicy::with_outlier_threshold`]
/// turns on the embedding-space z-score detector from
/// [`crate::anomaly::outlier`]. For a record that carries an embedding and
/// whose agent has a trained
/// [`crate::model::embedding_baseline::EmbeddingBaseline`],
/// `check_for_anomaly` adds `OUTLIER_SCORE_CONTRIBUTION` to the score when
/// the embedding is flagged as lying at least `threshold` standard
/// deviations away from the baseline.
#[derive(Clone, Debug, Default)]
pub struct PoisoningPolicy {
    /// Outlier z-score cut-off. `None` (the default) disables the embedding
    /// check; `Some(3.0)` is a sensible first value to try.
    pub outlier_threshold: Option<f32>,
}
35    /// z-score threshold above which an embedding is considered an
36    /// outlier. `None` disables the check. `Some(3.0)` is a reasonable
37    /// starting point.
38    pub outlier_threshold: Option<f32>,
39}
40
41/// How much an outlier-flagged record contributes to the anomaly score.
42/// Chosen so that an outlier alone crosses the `>= 0.5` anomalous
43/// threshold without stacking profile-based signals.
44pub const OUTLIER_SCORE_CONTRIBUTION: f32 = 0.5;
45
46impl PoisoningPolicy {
47    /// Enable the z-score outlier gate with the supplied threshold.
48    ///
49    /// Example: `PoisoningPolicy::default().with_outlier_threshold(3.0)`
50    pub fn with_outlier_threshold(mut self, threshold: f32) -> Self {
51        self.outlier_threshold = Some(threshold);
52        self
53    }
54}
55
56/// One row returned by [`replay_quarantine`].
57#[non_exhaustive]
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct QuarantineReplayEntry {
60    pub id: Uuid,
61    pub agent_id: String,
62    pub content: String,
63    pub reason: String,
64    pub created_at: String,
65    pub source_type: SourceType,
66    pub tags: Vec<String>,
67}
68
69/// Indirect-injection markers that commonly appear in MINJA-style attacks
70/// when an LLM is tricked into writing a memory it shouldn't. When present
71/// in records NOT written by a tool call (i.e. injected via a retrieved
72/// web/document fragment) they are a strong signal of injection.
73const SELF_REFERENTIAL_INSTRUCTION_MARKERS: &[&str] = &[
74    "remember this",
75    "remember the following",
76    "in the future, always",
77    "from now on, you",
78    "from now on, always",
79    "as the system, i",
80    "as your system prompt",
81    "permanently remember",
82    "never forget that",
83    "always respond with",
84    "always answer by",
85    "whenever asked about",
86    "when you are asked",
87];
88
89/// Sources we trust to carry instruction-like content by design.
90fn is_trusted_source(st: SourceType) -> bool {
91    matches!(
92        st,
93        SourceType::ToolOutput
94            | SourceType::System
95            | SourceType::UserInput
96            | SourceType::Human
97            | SourceType::ModelResponse
98    )
99}
100
101/// `source:<label>` tag marking a record as coming from an indirect
102/// injection vector (a web page, a document, a third-party email).
103fn looks_like_indirect_ingest(record: &MemoryRecord) -> bool {
104    record.tags.iter().any(|t| {
105        let lower = t.to_lowercase();
106        lower == "source:web"
107            || lower == "source:document"
108            || lower == "source:email"
109            || lower == "source:third_party"
110            || lower == "source:retrieved"
111    }) || matches!(
112        record.source_type,
113        SourceType::Retrieval | SourceType::Import
114    )
115}
116
117/// Self-referential instruction marker check — MINJA-class indirect
118/// injection signal. Returns a (matched_marker, is_suspicious_source) pair.
119fn check_self_referential_injection(record: &MemoryRecord) -> Option<&'static str> {
120    let lower = record.content.to_lowercase();
121    let matched = SELF_REFERENTIAL_INSTRUCTION_MARKERS
122        .iter()
123        .find(|p| lower.contains(**p))
124        .copied()?;
125    // Self-referential phrasing from a trusted source (tool output, user
126    // input, human) is legitimate; only flag it when the record arrived
127    // via an indirect ingest path.
128    if is_trusted_source(record.source_type) && !looks_like_indirect_ingest(record) {
129        return None;
130    }
131    Some(matched)
132}
133
134/// Detect common prompt injection patterns in memory content.
135///
136/// These patterns attempt to override AI agent instructions when the
137/// memory is recalled and included in an LLM context.
138fn contains_prompt_injection_patterns(content: &str) -> bool {
139    let lower = content.to_lowercase();
140    let patterns = [
141        "ignore all previous instructions",
142        "ignore previous instructions",
143        "disregard all prior",
144        "disregard previous",
145        "override system prompt",
146        "you are now in",
147        "new instructions:",
148        "system: you are",
149        "<<sys>>",
150        "[system]",
151        "```system",
152    ];
153    patterns.iter().any(|p| lower.contains(p))
154}
155
156/// Check a newly inserted memory record for anomaly indicators.
157///
158/// Scoring:
159/// - Importance deviation >0.4 from agent mean → +0.3
160/// - Content length >5x or <0.1x agent average → +0.3
161/// - High-frequency burst (>3x normal rate in last minute) → +0.4
162/// - Prompt injection patterns in content → +0.5
163/// - Total score >= 0.5 → anomalous
164pub async fn check_for_anomaly(
165    engine: &MnemoEngine,
166    record: &MemoryRecord,
167) -> Result<AnomalyCheckResult> {
168    let profile = engine.storage.get_agent_profile(&record.agent_id).await?;
169
170    let mut score: f32 = 0.0;
171    let mut reasons = Vec::new();
172
173    if let Some(ref profile) = profile {
174        // Check importance outlier
175        let importance_deviation = (record.importance as f64 - profile.avg_importance).abs();
176        if importance_deviation > 0.4 {
177            score += 0.3;
178            reasons.push(format!(
179                "importance {:.2} deviates {:.2} from agent average {:.2}",
180                record.importance, importance_deviation, profile.avg_importance
181            ));
182        }
183
184        // Check content length outlier
185        let content_len = record.content.len() as f64;
186        if profile.avg_content_length > 0.0 {
187            let ratio = content_len / profile.avg_content_length;
188            if !(0.1..=5.0).contains(&ratio) {
189                score += 0.3;
190                reasons.push(format!(
191                    "content length {} is {:.1}x agent average {:.0}",
192                    record.content.len(),
193                    ratio,
194                    profile.avg_content_length
195                ));
196            }
197        }
198
199        // Check high-frequency burst: compare recent write count to expected rate
200        // If agent has N memories over their lifetime, average rate = N / hours_active
201        // A burst is >3x that rate in the last minute
202        // Simplified: if total_memories > 10 and a new memory comes in very quickly
203        // We approximate by checking if total_memories suggests rapid growth
204        if profile.total_memories > 10 {
205            // Parse last_updated to get time window
206            if let Ok(last_updated) = chrono::DateTime::parse_from_rfc3339(&profile.last_updated)
207                && let Ok(created) = chrono::DateTime::parse_from_rfc3339(&record.created_at)
208            {
209                let seconds_since_update = (created - last_updated).num_seconds().max(1);
210                // If profile was updated less than 1 second ago, it's a burst
211                if seconds_since_update < 1 {
212                    score += 0.4;
213                    reasons.push("high-frequency burst detected".to_string());
214                }
215            }
216        }
217    }
218    // If no profile exists yet, we can't detect anomalies — treat as normal
219
220    // Check for prompt injection patterns in content
221    if contains_prompt_injection_patterns(&record.content) {
222        score += 0.5;
223        reasons.push("content contains prompt injection patterns".to_string());
224    }
225
226    // MINJA-class: self-referential instruction phrasing in a record that
227    // arrived through an indirect-ingest path (retrieved doc, web page,
228    // tagged `source:*`). Strong signal — see arXiv:2503.03704.
229    if let Some(marker) = check_self_referential_injection(record) {
230        score += 0.6;
231        reasons.push(format!(
232            "self-referential injection marker '{marker}' in indirectly-ingested record"
233        ));
234    }
235
236    // v0.3.3: embedding-space z-score outlier gate. Runs only when a
237    // baseline has been trained for this agent AND the operator set
238    // `PoisoningPolicy::outlier_threshold`. Catches semantic drift that
239    // lexical markers miss — e.g. an adversarial rewrite that preserves
240    // meaning but pushes the vector off-distribution.
241    if let Some(threshold) = engine.poisoning_policy.outlier_threshold
242        && record.embedding.is_some()
243        && let Some(baseline) = engine
244            .storage
245            .get_embedding_baseline(&record.agent_id)
246            .await?
247    {
248        let out = score_embedding_outlier(record, &baseline, threshold);
249        if out.is_outlier {
250            score += OUTLIER_SCORE_CONTRIBUTION;
251            reasons.push(format!(
252                "embedding z-score {:.2} >= threshold {:.2} (baseline n={}, {} dims >3σ)",
253                out.z_score, out.threshold, out.baseline_n, out.dims_flagged
254            ));
255        }
256    }
257
258    Ok(AnomalyCheckResult {
259        is_anomalous: score >= 0.5,
260        score,
261        reasons,
262    })
263}
264
265/// List every quarantined memory for `agent_id` with `created_at >= since`.
266/// Returns them in chronological order so operators can walk a review
267/// queue deterministically.
268pub async fn replay_quarantine(
269    engine: &MnemoEngine,
270    agent_id: &str,
271    since: Option<&str>,
272) -> Result<Vec<QuarantineReplayEntry>> {
273    let filter = MemoryFilter {
274        agent_id: Some(agent_id.to_string()),
275        // Quarantined records may be soft-deleted if an operator later
276        // hard-purged them via `forget_subject`; we still want visibility.
277        include_deleted: true,
278        ..Default::default()
279    };
280    let records = engine
281        .storage
282        .list_memories(&filter, super::MAX_BATCH_QUERY_LIMIT, 0)
283        .await?;
284    let mut out: Vec<QuarantineReplayEntry> = records
285        .into_iter()
286        .filter(|r| r.quarantined)
287        .filter(|r| match since {
288            None => true,
289            Some(cutoff) => r.created_at.as_str() >= cutoff,
290        })
291        .map(|r| QuarantineReplayEntry {
292            id: r.id,
293            agent_id: r.agent_id,
294            content: r.content,
295            reason: r
296                .quarantine_reason
297                .unwrap_or_else(|| "unspecified".to_string()),
298            created_at: r.created_at,
299            source_type: r.source_type,
300            tags: r.tags,
301        })
302        .collect();
303    out.sort_by(|a, b| a.created_at.cmp(&b.created_at));
304    Ok(out)
305}
306
307/// Mark a memory as quarantined with a reason.
308pub async fn quarantine_memory(engine: &MnemoEngine, id: Uuid, reason: &str) -> Result<()> {
309    if let Some(mut record) = engine.storage.get_memory(id).await? {
310        record.quarantined = true;
311        record.quarantine_reason = Some(reason.to_string());
312        record.updated_at = chrono::Utc::now().to_rfc3339();
313        engine.storage.update_memory(&record).await?;
314    }
315    Ok(())
316}
317
318/// Update the agent profile with statistics from the new memory.
319pub async fn update_agent_profile(engine: &MnemoEngine, record: &MemoryRecord) -> Result<()> {
320    let now = chrono::Utc::now().to_rfc3339();
321    let existing = engine.storage.get_agent_profile(&record.agent_id).await?;
322
323    let profile = match existing {
324        Some(mut p) => {
325            // Incremental mean update
326            let n = p.total_memories as f64;
327            p.avg_importance = (p.avg_importance * n + record.importance as f64) / (n + 1.0);
328            p.avg_content_length =
329                (p.avg_content_length * n + record.content.len() as f64) / (n + 1.0);
330            p.total_memories += 1;
331            p.last_updated = now;
332            p
333        }
334        None => AgentProfile {
335            agent_id: record.agent_id.clone(),
336            avg_importance: record.importance as f64,
337            avg_content_length: record.content.len() as f64,
338            total_memories: 1,
339            last_updated: now,
340        },
341    };
342
343    engine
344        .storage
345        .insert_or_update_agent_profile(&profile)
346        .await?;
347    Ok(())
348}
349
350#[cfg(test)]
351mod tests {
352    use super::*;
353
354    #[test]
355    fn test_anomaly_result_default() {
356        let result = AnomalyCheckResult {
357            is_anomalous: false,
358            score: 0.0,
359            reasons: vec![],
360        };
361        assert!(!result.is_anomalous);
362        assert_eq!(result.score, 0.0);
363    }
364}