Skip to main content

mii_memory/
store.rs

1use std::cmp::Ordering;
2use std::env;
3use std::fs;
4use std::io::{Error as IoError, ErrorKind};
5use std::path::{Path, PathBuf};
6
7use anyhow::{Context, Result, bail};
8use chrono::{DateTime, Utc};
9use rusqlite::{Connection, params};
10use serde::Serialize;
11
12use crate::embedding::{blend, cosine_similarity, decode_embedding, embed_text, encode_embedding};
13use crate::expiration::{fingerprint_for_condition, is_expired, validate_expiration};
14use crate::model::{ExpirationCondition, MemoryMode, normalize_tags};
15
/// Highest schema version this binary understands; `migrate` refuses to open a
/// database whose `user_version` is greater.
const SCHEMA_VERSION: i64 = 2;
/// Cosine-similarity cutoff at or above which an existing memory is treated as
/// similar to a newly stored one and receives a negative-score penalty.
const SIMILAR_MEMORY_THRESHOLD: f32 = 0.72;
/// Environment variable consulted first for a session reference when none is given explicitly.
const SESSION_ENV: &str = "MII_MEMORY_SESSION";
/// Environment variable naming the parent session that session references are nested under.
const SESSION_PARENT_ENV: &str = "MII_MEMORY_SESSION_PARENT";
/// Environment variable consulted for a session reference after `MII_MEMORY_SESSION`.
const MCP_SESSION_ENV: &str = "MCP_SESSION_ID";
21
/// Store for memories and session alerts, backed by a single SQLite connection.
pub struct MemoryStore {
    /// Connection that every query and transaction in this store runs on.
    connection: Connection,
}
25
/// Input for [`MemoryStore::set`]: one memory to be stored.
#[derive(Debug, Clone)]
pub struct SetMemory {
    /// Memory text. It is trimmed before storage and must not be empty.
    pub content: String,
    /// Scope the memory belongs to.
    pub mode: MemoryMode,
    /// Scope reference; before storage it is replaced by the result of
    /// `infer_mode_ref(mode, mode_ref)`.
    pub mode_ref: Option<String>,
    /// Tags for the memory. They are passed through `normalize_tags`, and at
    /// least one must remain afterwards.
    pub tags: Vec<String>,
    /// Optional expiration rule; must be supplied together with `expiration_value`.
    pub expiration_condition: Option<ExpirationCondition>,
    /// Value for the expiration rule; rejected when no condition is given.
    pub expiration_value: Option<String>,
    /// Free-form metadata stored alongside the memory unchanged.
    pub metadata: Option<String>,
}
36
/// Parameters for [`MemoryStore::get`].
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// Search text; trimmed before use.
    pub query: String,
    /// Tags a memory must all carry to be considered.
    pub positive_tags: Vec<String>,
    /// Tags that lower a memory's score (4.0 per matching tag) without excluding it.
    pub negative_tags: Vec<String>,
    /// Maximum number of results; values below 1 are raised to 1.
    pub limit: usize,
    /// Number of ranked results to skip before the first returned one.
    pub offset: usize,
    /// When set, only memories stored with this mode are considered.
    pub mode: Option<MemoryMode>,
    /// When set, only memories whose scope reference matches are considered.
    /// For `MemoryMode::Session` the reference is resolved via `infer_session_ref`.
    pub mode_ref: Option<String>,
}
47
/// An alert addressed to a session, as stored in the `alerts` table.
#[derive(Debug, Clone, Serialize)]
pub struct Alert {
    /// Session reference the alert was stored under.
    pub session_ref: String,
    /// Alert text.
    pub content: String,
}
53
54impl Default for SearchOptions {
55    fn default() -> Self {
56        Self {
57            query: String::new(),
58            positive_tags: Vec::new(),
59            negative_tags: Vec::new(),
60            limit: 10,
61            offset: 0,
62            mode: None,
63            mode_ref: None,
64        }
65    }
66}
67
/// One ranked hit returned by [`MemoryStore::get`].
#[derive(Debug, Clone, Serialize)]
pub struct MemorySearchResult {
    /// Row id of the memory.
    pub id: i64,
    /// Stored memory text.
    pub content: String,
    /// Scope the memory was stored with.
    pub mode: MemoryMode,
    /// Scope reference the memory was stored with, if any.
    pub mode_ref: Option<String>,
    /// Tags stored with the memory.
    pub tags: Vec<String>,
    /// Ranking score computed for this query; higher ranks first.
    pub score: f32,
    /// Accumulated positive score as loaded, before this retrieval was recorded.
    pub positive_score: f32,
    /// Accumulated negative score as loaded.
    pub negative_score: f32,
    /// Retrieval count as loaded, before this retrieval was recorded.
    pub usage_count: i64,
    /// Free-form metadata stored with the memory.
    pub metadata: Option<String>,
    /// Creation timestamp of the memory.
    pub created_at: DateTime<Utc>,
}
82
/// A tag together with how often it occurs, as returned by [`MemoryStore::list_tags`].
#[derive(Debug, Clone, Serialize)]
pub struct TagSummary {
    /// The tag text.
    pub tag: String,
    /// Number of occurrences of the tag across non-expired memories.
    pub count: i64,
}
88
/// A memory as listed by [`MemoryStore::browse`].
#[derive(Debug, Clone, Serialize)]
pub struct MemoryEntry {
    /// Row id of the memory.
    pub id: i64,
    /// Stored memory text.
    pub content: String,
    /// Scope the memory was stored with.
    pub mode: MemoryMode,
    /// Scope reference the memory was stored with, if any.
    pub mode_ref: Option<String>,
    /// Tags stored with the memory.
    pub tags: Vec<String>,
    /// Accumulated positive score.
    pub positive_score: f32,
    /// Accumulated negative score.
    pub negative_score: f32,
    /// Number of recorded retrievals.
    pub usage_count: i64,
    /// Free-form metadata stored with the memory.
    pub metadata: Option<String>,
    /// Expiration rule stored with the memory, if any.
    pub expiration_condition: Option<ExpirationCondition>,
    /// Value belonging to the expiration rule, if any.
    pub expiration_value: Option<String>,
    /// Creation timestamp of the memory.
    pub created_at: DateTime<Utc>,
    /// Relevance to the browse text filter; `None` (and omitted from the
    /// serialized output) when no text filter was given.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub relevance: Option<f32>,
}
106
/// Parameters for [`MemoryStore::browse`].
#[derive(Debug, Default, Clone)]
pub struct BrowseOptions {
    /// Optional text filter; a blank string is treated as no filter.
    pub text: Option<String>,
    /// Tags a memory must all carry to be listed (after `normalize_tags`).
    pub tags: Vec<String>,
    /// When set, only memories stored with this mode are listed.
    pub mode: Option<MemoryMode>,
    /// Maximum number of entries; `0` selects the built-in page size of 50.
    pub limit: usize,
    /// Number of sorted entries to skip before the first returned one.
    pub offset: usize,
}
115
/// Summary of the store's contents produced by [`MemoryStore::signature`];
/// two signatures can be compared for equality.
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct StoreSignature {
    /// Number of rows in `memories`.
    pub memory_count: i64,
    /// Largest memory id, or 0 when there are none.
    pub max_memory_id: i64,
    /// Largest `updated_at` value in `memories`, if any row exists.
    pub last_updated_at: Option<String>,
    /// Number of rows in `alerts`.
    pub alert_count: i64,
    /// Largest alert id, or 0 when there are none.
    pub max_alert_id: i64,
}
124
125impl MemoryStore {
126    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
127        let path = path.as_ref();
128        if let Some(parent) = path
129            .parent()
130            .filter(|parent| !parent.as_os_str().is_empty())
131        {
132            fs::create_dir_all(parent).with_context(|| {
133                format!("failed to create database directory {}", parent.display())
134            })?;
135        }
136
137        let connection = Connection::open(path)
138            .with_context(|| format!("failed to open database {}", path.display()))?;
139        Self::from_connection(connection)
140    }
141
142    pub fn in_memory() -> Result<Self> {
143        Self::from_connection(Connection::open_in_memory()?)
144    }
145
146    fn from_connection(connection: Connection) -> Result<Self> {
147        let mut store = Self { connection };
148        store.migrate()?;
149        Ok(store)
150    }
151
152    pub fn set(&mut self, input: SetMemory) -> Result<i64> {
153        let input = normalize_set_memory(input)?;
154        let now = Utc::now();
155        let created_at = now.to_rfc3339();
156        let content_embedding =
157            embed_text(&input.content).context("failed to embed memory content")?;
158        let tag_text = input.tags.join(" ");
159        let tag_embedding = embed_text(&tag_text).context("failed to embed memory tags")?;
160        let combined_embedding = blend(&content_embedding, &tag_embedding);
161        let file_fingerprint = fingerprint_for_condition(
162            input.expiration_condition,
163            input.expiration_value.as_deref(),
164        )?;
165        let related_updates = self.similar_memory_updates(&combined_embedding, now)?;
166
167        let transaction = self.connection.transaction()?;
168        transaction.execute(
169            "INSERT INTO memories (
170                content, mode, mode_ref, tags_json, expiration_condition, expiration_value,
171                metadata, content_embedding, tag_embedding, combined_embedding,
172                positive_score, negative_score, usage_count, created_at, updated_at,
173                file_fingerprint
174            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, 0.0, 0.0, 0, ?11, ?11, ?12)",
175            params![
176                input.content,
177                input.mode.as_str(),
178                input.mode_ref,
179                serde_json::to_string(&input.tags)?,
180                input.expiration_condition.map(ExpirationCondition::as_str),
181                input.expiration_value,
182                input.metadata,
183                encode_embedding(&content_embedding),
184                encode_embedding(&tag_embedding),
185                encode_embedding(&combined_embedding),
186                created_at,
187                file_fingerprint,
188            ],
189        )?;
190        let id = transaction.last_insert_rowid();
191
192        for tag in input.tags {
193            transaction.execute(
194                "INSERT OR IGNORE INTO memory_tags (memory_id, tag) VALUES (?1, ?2)",
195                params![id, tag],
196            )?;
197        }
198
199        for (memory_id, penalty) in related_updates {
200            transaction.execute(
201                "UPDATE memories
202                 SET negative_score = negative_score + ?1, updated_at = ?2
203                 WHERE id = ?3",
204                params![penalty, created_at, memory_id],
205            )?;
206        }
207
208        transaction.commit()?;
209        Ok(id)
210    }
211
212    pub fn get(&mut self, options: SearchOptions) -> Result<Vec<MemorySearchResult>> {
213        let options = normalize_search_options(options)?;
214        let now = Utc::now();
215        let query_embedding = embed_text(&options.query).context("failed to embed memory query")?;
216        let query_lower = options.query.to_ascii_lowercase();
217        let mut scored = Vec::new();
218
219        for memory in self.load_memories()? {
220            if !memory.matches_scope(options.mode, options.mode_ref.as_deref()) {
221                continue;
222            }
223
224            if memory.is_expired(now) {
225                continue;
226            }
227
228            if !options
229                .positive_tags
230                .iter()
231                .all(|tag| memory.tags.iter().any(|memory_tag| memory_tag == tag))
232            {
233                continue;
234            }
235
236            let score = score_memory(&memory, &query_embedding, &query_lower, &options);
237            scored.push((memory, score));
238        }
239
240        scored.sort_by(|(left_memory, left_score), (right_memory, right_score)| {
241            right_score
242                .partial_cmp(left_score)
243                .unwrap_or(Ordering::Equal)
244                .then_with(|| right_memory.id.cmp(&left_memory.id))
245        });
246
247        let returned = scored
248            .into_iter()
249            .skip(options.offset)
250            .take(options.limit)
251            .enumerate()
252            .map(|(rank, (memory, score))| (rank, memory, score))
253            .collect::<Vec<_>>();
254
255        self.record_retrievals(&returned, now)?;
256
257        Ok(returned
258            .into_iter()
259            .map(|(_, memory, score)| MemorySearchResult {
260                id: memory.id,
261                content: memory.content,
262                mode: memory.mode,
263                mode_ref: memory.mode_ref,
264                tags: memory.tags,
265                score,
266                positive_score: memory.positive_score,
267                negative_score: memory.negative_score,
268                usage_count: memory.usage_count,
269                metadata: memory.metadata,
270                created_at: memory.created_at,
271            })
272            .collect())
273    }
274
275    pub fn list_tags(&self, filter: Option<&str>) -> Result<Vec<TagSummary>> {
276        let now = Utc::now();
277        let mut summaries = std::collections::BTreeMap::<String, i64>::new();
278        let filter = filter.map(str::trim).filter(|filter| !filter.is_empty());
279        let filter_lower = filter.map(str::to_ascii_lowercase);
280        let filter_embedding = filter
281            .map(embed_text)
282            .transpose()
283            .context("failed to embed tag filter")?;
284
285        for memory in self.load_memories()? {
286            if memory.is_expired(now) {
287                continue;
288            }
289
290            for tag in memory.tags {
291                if let Some(filter_lower) = &filter_lower {
292                    let tag_matches_text = tag.contains(filter_lower);
293                    let tag_matches_embedding = if let Some(filter_embedding) = &filter_embedding {
294                        let tag_embedding =
295                            embed_text(&tag).context("failed to embed memory tag")?;
296                        cosine_similarity(filter_embedding, &tag_embedding) >= 0.2
297                    } else {
298                        false
299                    };
300
301                    if !tag_matches_text && !tag_matches_embedding {
302                        continue;
303                    }
304                }
305
306                *summaries.entry(tag).or_default() += 1;
307            }
308        }
309
310        Ok(summaries
311            .into_iter()
312            .map(|(tag, count)| TagSummary { tag, count })
313            .collect())
314    }
315
316    pub fn set_alert(
317        &mut self,
318        session_ref: impl Into<String>,
319        content: impl Into<String>,
320    ) -> Result<i64> {
321        let session_ref = session_ref_with_configured_parent(session_ref.into())?;
322        let content = normalize_required_text(content.into(), "alert content")?;
323
324        self.connection.execute(
325            "INSERT INTO alerts (session_ref, content) VALUES (?1, ?2)",
326            params![session_ref, content],
327        )?;
328
329        Ok(self.connection.last_insert_rowid())
330    }
331
332    pub fn get_alerts(&mut self, session_ref: impl Into<String>) -> Result<Vec<Alert>> {
333        let session_ref = session_ref_with_configured_parent(session_ref.into())?;
334        let transaction = self.connection.transaction()?;
335        let alerts = {
336            let mut statement =
337                transaction.prepare("SELECT id, session_ref, content FROM alerts ORDER BY id")?;
338            let rows = statement.query_map([], |row| {
339                Ok((
340                    row.get::<_, i64>(0)?,
341                    Alert {
342                        session_ref: row.get(1)?,
343                        content: row.get(2)?,
344                    },
345                ))
346            })?;
347
348            rows.filter_map(|row| match row {
349                Ok((id, alert)) if session_refs_share_lineage(&session_ref, &alert.session_ref) => {
350                    Some(Ok((id, alert)))
351                }
352                Ok(_) => None,
353                Err(error) => Some(Err(error)),
354            })
355            .collect::<Result<Vec<_>, _>>()?
356        };
357
358        for (id, _) in &alerts {
359            transaction.execute("DELETE FROM alerts WHERE id = ?1", params![id])?;
360        }
361
362        transaction.commit()?;
363        Ok(alerts.into_iter().map(|(_, alert)| alert).collect())
364    }
365
366    pub fn browse(&self, options: BrowseOptions) -> Result<Vec<MemoryEntry>> {
367        let now = Utc::now();
368        let text_filter = options
369            .text
370            .as_deref()
371            .map(str::trim)
372            .filter(|text| !text.is_empty())
373            .map(str::to_string);
374        let lowered_text = text_filter.as_deref().map(str::to_ascii_lowercase);
375        let query_embedding = text_filter
376            .as_deref()
377            .map(embed_text)
378            .transpose()
379            .context("failed to embed explorer query")?;
380        let tag_filter = normalize_tags(&options.tags);
381        let limit = if options.limit == 0 {
382            50
383        } else {
384            options.limit
385        };
386
387        let mut entries: Vec<(MemoryEntry, f32, DateTime<Utc>)> = Vec::new();
388        for memory in self.load_memories()? {
389            if memory.is_expired(now) {
390                continue;
391            }
392
393            if let Some(mode) = options.mode
394                && memory.mode != mode
395            {
396                continue;
397            }
398
399            if !tag_filter
400                .iter()
401                .all(|tag| memory.tags.iter().any(|memory_tag| memory_tag == tag))
402            {
403                continue;
404            }
405
406            let mut relevance: Option<f32> = None;
407            if let Some(text) = &lowered_text {
408                let content_lower = memory.content.to_ascii_lowercase();
409                let content_match = content_lower.contains(text);
410                let tag_match = memory
411                    .tags
412                    .iter()
413                    .any(|tag| tag.to_ascii_lowercase().contains(text));
414                let metadata_match = memory
415                    .metadata
416                    .as_deref()
417                    .is_some_and(|metadata| metadata.to_ascii_lowercase().contains(text));
418
419                let semantic = query_embedding
420                    .as_deref()
421                    .map(|embedding| cosine_similarity(embedding, &memory.combined_embedding))
422                    .unwrap_or(0.0);
423                let text_bonus = if content_match {
424                    0.2
425                } else if tag_match || metadata_match {
426                    0.1
427                } else {
428                    0.0
429                };
430                let score = semantic + text_bonus;
431
432                if !content_match && !tag_match && !metadata_match && semantic < 0.25 {
433                    continue;
434                }
435
436                relevance = Some(score.clamp(0.0, 1.2));
437            }
438
439            let created_at = memory.created_at;
440            let entry = MemoryEntry {
441                id: memory.id,
442                content: memory.content,
443                mode: memory.mode,
444                mode_ref: memory.mode_ref,
445                tags: memory.tags,
446                positive_score: memory.positive_score,
447                negative_score: memory.negative_score,
448                usage_count: memory.usage_count,
449                metadata: memory.metadata,
450                expiration_condition: memory.expiration_condition,
451                expiration_value: memory.expiration_value,
452                created_at,
453                relevance,
454            };
455            let sort_score = relevance.unwrap_or(0.0);
456            entries.push((entry, sort_score, created_at));
457        }
458
459        if text_filter.is_some() {
460            entries.sort_by(|left, right| {
461                right
462                    .1
463                    .partial_cmp(&left.1)
464                    .unwrap_or(Ordering::Equal)
465                    .then_with(|| right.2.cmp(&left.2))
466            });
467        } else {
468            entries.sort_by(|left, right| {
469                right
470                    .2
471                    .cmp(&left.2)
472                    .then_with(|| right.0.id.cmp(&left.0.id))
473            });
474        }
475
476        Ok(entries
477            .into_iter()
478            .skip(options.offset)
479            .take(limit)
480            .map(|(entry, _, _)| entry)
481            .collect())
482    }
483
484    pub fn signature(&self) -> Result<StoreSignature> {
485        let (memory_count, max_memory_id, last_updated_at) = self.connection.query_row(
486            "SELECT COUNT(*), COALESCE(MAX(id), 0), MAX(updated_at) FROM memories",
487            [],
488            |row| {
489                Ok((
490                    row.get::<_, i64>(0)?,
491                    row.get::<_, i64>(1)?,
492                    row.get::<_, Option<String>>(2)?,
493                ))
494            },
495        )?;
496        let (alert_count, max_alert_id) = self.connection.query_row(
497            "SELECT COUNT(*), COALESCE(MAX(id), 0) FROM alerts",
498            [],
499            |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)),
500        )?;
501
502        Ok(StoreSignature {
503            memory_count,
504            max_memory_id,
505            last_updated_at,
506            alert_count,
507            max_alert_id,
508        })
509    }
510
    /// Brings the database schema up to `SCHEMA_VERSION`, one version step at a time.
    ///
    /// The current version is read from SQLite's `user_version` pragma. Version 0
    /// creates the `memories` and `memory_tags` tables and their indexes; version 1
    /// adds the `alerts` table. Each step runs in its own transaction and sets
    /// `user_version` at the end of its batch.
    ///
    /// # Errors
    /// Fails when the database reports a version newer than `SCHEMA_VERSION`,
    /// or when any pragma or statement fails.
    fn migrate(&mut self) -> Result<()> {
        // Enabled on this connection before any step; `memory_tags` declares
        // ON DELETE CASCADE against `memories`.
        self.connection.pragma_update(None, "foreign_keys", "ON")?;
        let version: i64 = self
            .connection
            .pragma_query_value(None, "user_version", |row| row.get(0))?;

        // Refuse to touch a database written by a newer binary.
        if version > SCHEMA_VERSION {
            bail!(
                "database schema version {version} is newer than this binary supports ({SCHEMA_VERSION})"
            );
        }

        // Step 0 -> 1: base tables for memories and their tags.
        if version == 0 {
            let transaction = self.connection.transaction()?;
            transaction.execute_batch(
                "CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    content TEXT NOT NULL,
                    mode TEXT NOT NULL,
                    mode_ref TEXT,
                    tags_json TEXT NOT NULL,
                    expiration_condition TEXT,
                    expiration_value TEXT,
                    metadata TEXT,
                    content_embedding TEXT NOT NULL,
                    tag_embedding TEXT NOT NULL,
                    combined_embedding TEXT NOT NULL,
                    positive_score REAL NOT NULL DEFAULT 0.0,
                    negative_score REAL NOT NULL DEFAULT 0.0,
                    usage_count INTEGER NOT NULL DEFAULT 0,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    file_fingerprint TEXT
                );

                CREATE TABLE IF NOT EXISTS memory_tags (
                    memory_id INTEGER NOT NULL,
                    tag TEXT NOT NULL,
                    PRIMARY KEY (memory_id, tag),
                    FOREIGN KEY (memory_id) REFERENCES memories(id) ON DELETE CASCADE
                );

                CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(mode, mode_ref);
                CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag);
                PRAGMA user_version = 1;",
            )?;
            transaction.commit()?;
        }

        // Re-read the version so a database that was just created continues
        // straight into the next step.
        let version: i64 = self
            .connection
            .pragma_query_value(None, "user_version", |row| row.get(0))?;

        // Step 1 -> 2: session alerts.
        if version == 1 {
            let transaction = self.connection.transaction()?;
            transaction.execute_batch(
                "CREATE TABLE IF NOT EXISTS alerts (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_ref TEXT NOT NULL,
                    content TEXT NOT NULL
                );

                CREATE INDEX IF NOT EXISTS idx_alerts_session_ref ON alerts(session_ref, id);
                PRAGMA user_version = 2;",
            )?;
            transaction.commit()?;
        }

        Ok(())
    }
581
582    fn load_memories(&self) -> Result<Vec<MemoryRecord>> {
583        let mut statement = self.connection.prepare(
584            "SELECT
585                id, content, mode, mode_ref, tags_json, expiration_condition, expiration_value,
586                metadata, combined_embedding, positive_score, negative_score, usage_count,
587                created_at, file_fingerprint
588             FROM memories",
589        )?;
590
591        let rows = statement.query_map([], |row| {
592            let mode: String = row.get(2)?;
593            let expiration_condition: Option<String> = row.get(5)?;
594            let created_at: String = row.get(12)?;
595
596            Ok(MemoryRecord {
597                id: row.get(0)?,
598                content: row.get(1)?,
599                mode: mode
600                    .parse()
601                    .map_err(|error| conversion_error(error, "mode"))?,
602                mode_ref: row.get(3)?,
603                tags: serde_json::from_str::<Vec<String>>(&row.get::<_, String>(4)?)
604                    .unwrap_or_default(),
605                expiration_condition: expiration_condition
606                    .as_deref()
607                    .map(str::parse::<ExpirationCondition>)
608                    .transpose()
609                    .map_err(|error| conversion_error(error, "expiration_condition"))?,
610                expiration_value: row.get(6)?,
611                metadata: row.get(7)?,
612                combined_embedding: decode_embedding(&row.get::<_, String>(8)?),
613                positive_score: row.get(9)?,
614                negative_score: row.get(10)?,
615                usage_count: row.get(11)?,
616                created_at: DateTime::parse_from_rfc3339(&created_at)
617                    .map(|datetime| datetime.with_timezone(&Utc))
618                    .map_err(|error| rusqlite::Error::ToSqlConversionFailure(Box::new(error)))?,
619                file_fingerprint: row.get(13)?,
620            })
621        })?;
622
623        rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
624    }
625
626    fn similar_memory_updates(
627        &self,
628        combined_embedding: &[f32],
629        now: DateTime<Utc>,
630    ) -> Result<Vec<(i64, f32)>> {
631        let mut updates = Vec::new();
632
633        for memory in self.load_memories()? {
634            if memory.is_expired(now) {
635                continue;
636            }
637
638            let similarity = cosine_similarity(combined_embedding, &memory.combined_embedding);
639            if similarity >= SIMILAR_MEMORY_THRESHOLD {
640                updates.push((memory.id, similarity));
641            }
642        }
643
644        Ok(updates)
645    }
646
647    fn record_retrievals(
648        &mut self,
649        returned: &[(usize, MemoryRecord, f32)],
650        now: DateTime<Utc>,
651    ) -> Result<()> {
652        if returned.is_empty() {
653            return Ok(());
654        }
655
656        let updated_at = now.to_rfc3339();
657        let transaction = self.connection.transaction()?;
658        for (rank, memory, _) in returned {
659            let gain = 1.0_f32 / (*rank as f32 + 1.0);
660            transaction.execute(
661                "UPDATE memories
662                 SET positive_score = positive_score + ?1,
663                     usage_count = usage_count + 1,
664                     updated_at = ?2
665                 WHERE id = ?3",
666                params![gain, updated_at, memory.id],
667            )?;
668        }
669        transaction.commit()?;
670        Ok(())
671    }
672}
673
674fn conversion_error(error: anyhow::Error, field: &'static str) -> rusqlite::Error {
675    rusqlite::Error::ToSqlConversionFailure(Box::new(IoError::new(
676        ErrorKind::InvalidData,
677        format!("invalid {field}: {error}"),
678    )))
679}
680
/// Returns the default database location: `.mii-memory.db`, as a relative path.
pub fn default_database_path() -> PathBuf {
    Path::new(".mii-memory.db").to_path_buf()
}
684
685pub fn infer_mode_ref(mode: MemoryMode, explicit: Option<String>) -> Result<Option<String>> {
686    match mode {
687        MemoryMode::Global => Ok(None),
688        MemoryMode::Workspace => {
689            if let Some(explicit) = normalize_optional_text(explicit) {
690                return Ok(Some(explicit));
691            }
692
693            Ok(Some(
694                env::current_dir()
695                    .context("failed to infer workspace mode_ref from current directory")?
696                    .to_string_lossy()
697                    .into_owned(),
698            ))
699        }
700        MemoryMode::Session => Ok(Some(infer_session_ref(explicit)?)),
701    }
702}
703
704pub fn infer_session_ref(explicit: Option<String>) -> Result<String> {
705    let session_ref = normalize_optional_text(explicit)
706        .or_else(|| env_text(SESSION_ENV))
707        .or_else(|| env_text(MCP_SESSION_ENV))
708        .unwrap_or_else(|| "default".to_string());
709
710    session_ref_with_configured_parent(session_ref)
711}
712
713pub fn infer_mcp_session_ref(generated: String) -> Result<String> {
714    mcp_session_ref(
715        generated,
716        env_text(SESSION_ENV),
717        env_text(SESSION_PARENT_ENV),
718    )
719}
720
721fn normalize_set_memory(mut input: SetMemory) -> Result<SetMemory> {
722    input.content = input.content.trim().to_string();
723    if input.content.is_empty() {
724        bail!("memory content cannot be empty");
725    }
726
727    input.tags = normalize_tags(&input.tags);
728    if input.tags.is_empty() {
729        bail!("at least one tag is required");
730    }
731
732    input.mode_ref = infer_mode_ref(input.mode, input.mode_ref)?;
733
734    match (
735        input.expiration_condition,
736        input.expiration_value.as_deref(),
737    ) {
738        (Some(condition), Some(value)) => validate_expiration(condition, value)?,
739        (Some(condition), None) => bail!("expiration condition {condition} requires a value"),
740        (None, Some(_)) => bail!("expiration value was provided without an expiration condition"),
741        (None, None) => {}
742    }
743
744    Ok(input)
745}
746
747fn normalize_required_text(mut value: String, field: &'static str) -> Result<String> {
748    value = value.trim().to_string();
749    if value.is_empty() {
750        bail!("{field} cannot be empty");
751    }
752
753    Ok(value)
754}
755
/// Trims an optional string, mapping a missing or blank value to `None`.
fn normalize_optional_text(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_string();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}
761
762fn env_text(name: &str) -> Option<String> {
763    env::var(name)
764        .ok()
765        .and_then(|value| normalize_optional_text(Some(value)))
766}
767
768fn session_ref_with_configured_parent(session_ref: String) -> Result<String> {
769    session_ref_with_parent(session_ref, env_text(SESSION_PARENT_ENV))
770}
771
772fn mcp_session_ref(
773    generated: String,
774    configured_session: Option<String>,
775    parent_ref: Option<String>,
776) -> Result<String> {
777    let session_ref = normalize_optional_text(configured_session).unwrap_or(generated);
778
779    session_ref_with_parent(session_ref, parent_ref)
780}
781
782fn session_ref_with_parent(session_ref: String, parent_ref: Option<String>) -> Result<String> {
783    let session_ref = normalize_required_text(session_ref, "session_ref")?;
784    let Some(parent_ref) = normalize_optional_text(parent_ref) else {
785        return Ok(session_ref);
786    };
787
788    if session_ref == parent_ref || session_ref_is_ancestor(&parent_ref, &session_ref) {
789        return Ok(session_ref);
790    }
791
792    Ok(format!("{parent_ref}/{session_ref}"))
793}
794
795fn normalize_search_options(mut options: SearchOptions) -> Result<SearchOptions> {
796    options.query = options.query.trim().to_string();
797    options.positive_tags = normalize_tags(&options.positive_tags);
798    options.negative_tags = normalize_tags(&options.negative_tags);
799    options.mode_ref = options
800        .mode_ref
801        .map(|mode_ref| mode_ref.trim().to_string())
802        .filter(|mode_ref| !mode_ref.is_empty());
803    if options.mode == Some(MemoryMode::Session) {
804        options.mode_ref = Some(infer_session_ref(options.mode_ref)?);
805    }
806    options.limit = options.limit.max(1);
807    Ok(options)
808}
809
810fn session_refs_share_lineage(requested_ref: &str, stored_ref: &str) -> bool {
811    requested_ref == stored_ref
812        || session_ref_is_ancestor(requested_ref, stored_ref)
813        || session_ref_is_ancestor(stored_ref, requested_ref)
814}
815
/// Tells whether `descendant` lies strictly below `ancestor`, i.e. it starts
/// with `ancestor` immediately followed by `/`. A reference is not its own
/// ancestor.
fn session_ref_is_ancestor(ancestor: &str, descendant: &str) -> bool {
    match descendant.strip_prefix(ancestor) {
        Some(remainder) => remainder.starts_with('/'),
        None => false,
    }
}
821
822fn score_memory(
823    memory: &MemoryRecord,
824    query_embedding: &[f32],
825    query_lower: &str,
826    options: &SearchOptions,
827) -> f32 {
828    let semantic = cosine_similarity(query_embedding, &memory.combined_embedding) * 10.0;
829    let content_lower = memory.content.to_ascii_lowercase();
830    let text_bonus = if !query_lower.is_empty() && content_lower.contains(query_lower) {
831        2.0
832    } else {
833        0.0
834    };
835    let tag_text_bonus =
836        if !query_lower.is_empty() && memory.tags.iter().any(|tag| tag.contains(query_lower)) {
837            1.0
838        } else {
839            0.0
840        };
841    let positive_tag_bonus = options.positive_tags.len() as f32 * 0.35;
842    let negative_tag_penalty = options
843        .negative_tags
844        .iter()
845        .filter(|negative_tag| memory.tags.iter().any(|tag| tag == *negative_tag))
846        .count() as f32
847        * 4.0;
848
849    semantic + text_bonus + tag_text_bonus + positive_tag_bonus + memory.positive_score
850        - memory.negative_score
851        - negative_tag_penalty
852}
853
/// One row of the `memories` table as loaded by `MemoryStore::load_memories`.
#[derive(Debug, Clone)]
struct MemoryRecord {
    /// Row id.
    id: i64,
    /// Stored memory text.
    content: String,
    /// Scope parsed from the `mode` column.
    mode: MemoryMode,
    /// Scope reference, if any.
    mode_ref: Option<String>,
    /// Tags decoded from `tags_json`; empty when that JSON cannot be parsed.
    tags: Vec<String>,
    /// Expiration rule parsed from the `expiration_condition` column, if set.
    expiration_condition: Option<ExpirationCondition>,
    /// Value belonging to the expiration rule, if set.
    expiration_value: Option<String>,
    /// Free-form metadata, if any.
    metadata: Option<String>,
    /// Embedding decoded from the `combined_embedding` column; used for similarity.
    combined_embedding: Vec<f32>,
    /// Accumulated positive score.
    positive_score: f32,
    /// Accumulated negative score.
    negative_score: f32,
    /// Number of recorded retrievals.
    usage_count: i64,
    /// Creation timestamp parsed from RFC 3339 text.
    created_at: DateTime<Utc>,
    /// Fingerprint stored at creation time; passed to the expiration check.
    file_fingerprint: Option<String>,
}
871
872impl MemoryRecord {
873    fn matches_scope(&self, mode: Option<MemoryMode>, mode_ref: Option<&str>) -> bool {
874        if mode.is_some_and(|mode| self.mode != mode) {
875            return false;
876        }
877
878        if let Some(mode_ref) = mode_ref {
879            return self.mode_ref.as_deref().is_some_and(|stored_ref| {
880                if self.mode == MemoryMode::Session {
881                    session_refs_share_lineage(mode_ref, stored_ref)
882                } else {
883                    stored_ref == mode_ref
884                }
885            });
886        }
887
888        true
889    }
890
891    fn is_expired(&self, now: DateTime<Utc>) -> bool {
892        is_expired(
893            self.expiration_condition,
894            self.expiration_value.as_deref(),
895            self.created_at,
896            self.usage_count,
897            self.file_fingerprint.as_deref(),
898            now,
899        )
900    }
901}
902
903#[cfg(test)]
904mod tests {
905    use super::*;
906    #[cfg(has_embedded_embeddings)]
907    use std::io::Write;
908
909    fn memory(content: &str, tags: &[&str]) -> SetMemory {
910        SetMemory {
911            content: content.to_string(),
912            mode: MemoryMode::Global,
913            mode_ref: None,
914            tags: tags.iter().map(|tag| tag.to_string()).collect(),
915            expiration_condition: None,
916            expiration_value: None,
917            metadata: None,
918        }
919    }
920
921    #[cfg(has_embedded_embeddings)]
922    fn session_memory(content: &str, session_ref: &str) -> SetMemory {
923        let mut input = memory(content, &["lineage"]);
924        input.mode = MemoryMode::Session;
925        input.mode_ref = Some(session_ref.to_string());
926        input
927    }
928
929    #[cfg(has_embedded_embeddings)]
930    #[test]
931    fn set_get_and_list_tags_round_trip() -> Result<()> {
932        let mut store = MemoryStore::in_memory()?;
933        store.set(memory("Rust sqlite memory backend", &["rust", "sqlite"]))?;
934
935        let results = store.get(SearchOptions {
936            query: "sqlite backend".to_string(),
937            positive_tags: vec!["rust".to_string()],
938            ..SearchOptions::default()
939        })?;
940
941        assert_eq!(results.len(), 1);
942        assert_eq!(results[0].content, "Rust sqlite memory backend");
943
944        let tags = store.list_tags(Some("sql"))?;
945        assert_eq!(tags[0].tag, "sqlite");
946        Ok(())
947    }
948
949    #[cfg(not(has_embedded_embeddings))]
950    #[test]
951    fn set_requires_embeddings_when_not_configured() -> Result<()> {
952        let mut store = MemoryStore::in_memory()?;
953        let error = store
954            .set(memory("Rust sqlite memory backend", &["rust", "sqlite"]))
955            .unwrap_err();
956
957        assert!(
958            error
959                .chain()
960                .any(|cause| cause.to_string().contains("--embeddings <PATH>"))
961        );
962        Ok(())
963    }
964
965    #[cfg(has_embedded_embeddings)]
966    #[test]
967    fn usage_expiration_hides_memory_after_limit() -> Result<()> {
968        let mut store = MemoryStore::in_memory()?;
969        let mut input = memory("single use memory", &["temporary"]);
970        input.expiration_condition = Some(ExpirationCondition::Usage);
971        input.expiration_value = Some("1".to_string());
972        store.set(input)?;
973
974        let first = store.get(SearchOptions {
975            query: "single".to_string(),
976            ..SearchOptions::default()
977        })?;
978        let second = store.get(SearchOptions {
979            query: "single".to_string(),
980            ..SearchOptions::default()
981        })?;
982
983        assert_eq!(first.len(), 1);
984        assert!(second.is_empty());
985        Ok(())
986    }
987
988    #[cfg(has_embedded_embeddings)]
989    #[test]
990    fn file_pristine_expiration_tracks_changes() -> Result<()> {
991        let directory = tempfile::tempdir()?;
992        let file_path = directory.path().join("tracked.txt");
993        fs::write(&file_path, "first")?;
994
995        let mut store = MemoryStore::in_memory()?;
996        let mut input = memory("tracked file state", &["file"]);
997        input.expiration_condition = Some(ExpirationCondition::FilePristine);
998        input.expiration_value = Some(file_path.to_string_lossy().into_owned());
999        store.set(input)?;
1000
1001        assert_eq!(
1002            store
1003                .get(SearchOptions {
1004                    query: "tracked".to_string(),
1005                    ..SearchOptions::default()
1006                })?
1007                .len(),
1008            1
1009        );
1010
1011        let mut file = fs::OpenOptions::new().append(true).open(&file_path)?;
1012        writeln!(file, "changed")?;
1013
1014        assert!(
1015            store
1016                .get(SearchOptions {
1017                    query: "tracked".to_string(),
1018                    ..SearchOptions::default()
1019                })?
1020                .is_empty()
1021        );
1022        Ok(())
1023    }
1024
1025    #[test]
1026    fn alerts_are_session_scoped_and_one_shot() -> Result<()> {
1027        let mut store = MemoryStore::in_memory()?;
1028        store.set_alert("session-a", "remember the summary")?;
1029        store.set_alert("session-b", "other session")?;
1030
1031        let first = store.get_alerts("session-a")?;
1032        let second = store.get_alerts("session-a")?;
1033        let other = store.get_alerts("session-b")?;
1034
1035        assert_eq!(first.len(), 1);
1036        assert_eq!(first[0].content, "remember the summary");
1037        assert!(second.is_empty());
1038        assert_eq!(other.len(), 1);
1039        assert_eq!(other[0].content, "other session");
1040        Ok(())
1041    }
1042
1043    #[cfg(has_embedded_embeddings)]
1044    #[test]
1045    fn session_memories_follow_sub_session_lineage() -> Result<()> {
1046        let mut store = MemoryStore::in_memory()?;
1047        store.set(session_memory("lineage parent note", "parent"))?;
1048        store.set(session_memory("lineage child note", "parent/child"))?;
1049        store.set(session_memory(
1050            "lineage grandchild note",
1051            "parent/child/grandchild",
1052        ))?;
1053        store.set(session_memory("lineage sibling note", "parent/sibling"))?;
1054
1055        let child_results = store.get(SearchOptions {
1056            query: "lineage".to_string(),
1057            mode: Some(MemoryMode::Session),
1058            mode_ref: Some("parent/child".to_string()),
1059            limit: 10,
1060            ..SearchOptions::default()
1061        })?;
1062        let child_contents = child_results
1063            .iter()
1064            .map(|result| result.content.as_str())
1065            .collect::<Vec<_>>();
1066        assert!(child_contents.contains(&"lineage parent note"));
1067        assert!(child_contents.contains(&"lineage child note"));
1068        assert!(child_contents.contains(&"lineage grandchild note"));
1069        assert!(!child_contents.contains(&"lineage sibling note"));
1070
1071        let parent_results = store.get(SearchOptions {
1072            query: "lineage child".to_string(),
1073            mode: Some(MemoryMode::Session),
1074            mode_ref: Some("parent".to_string()),
1075            limit: 10,
1076            ..SearchOptions::default()
1077        })?;
1078        assert!(
1079            parent_results
1080                .iter()
1081                .any(|result| result.content == "lineage child note")
1082        );
1083        Ok(())
1084    }
1085
1086    #[test]
1087    fn session_parent_prefix_is_applied_once() -> Result<()> {
1088        assert_eq!(
1089            session_ref_with_parent("child".to_string(), Some("parent".to_string()))?,
1090            "parent/child"
1091        );
1092        assert_eq!(
1093            session_ref_with_parent("parent/child".to_string(), Some("parent".to_string()))?,
1094            "parent/child"
1095        );
1096        assert_eq!(
1097            session_ref_with_parent("parent".to_string(), Some("parent".to_string()))?,
1098            "parent"
1099        );
1100        assert_eq!(
1101            session_ref_with_parent("other".to_string(), Some("parent/child".to_string()))?,
1102            "parent/child/other"
1103        );
1104        Ok(())
1105    }
1106
1107    #[test]
1108    fn mcp_session_ref_uses_configured_session_or_generated_fallback() -> Result<()> {
1109        assert_eq!(
1110            mcp_session_ref(
1111                "generated".to_string(),
1112                Some("configured".to_string()),
1113                None,
1114            )?,
1115            "configured"
1116        );
1117        assert_eq!(
1118            mcp_session_ref("generated".to_string(), None, None)?,
1119            "generated"
1120        );
1121        assert_eq!(
1122            mcp_session_ref(
1123                "generated".to_string(),
1124                Some("   ".to_string()),
1125                Some("parent".to_string()),
1126            )?,
1127            "parent/generated"
1128        );
1129        assert_eq!(
1130            mcp_session_ref(
1131                "generated".to_string(),
1132                Some("configured".to_string()),
1133                Some("parent".to_string()),
1134            )?,
1135            "parent/configured"
1136        );
1137        Ok(())
1138    }
1139
1140    #[test]
1141    fn alerts_follow_sub_session_lineage_and_remain_one_shot() -> Result<()> {
1142        let mut store = MemoryStore::in_memory()?;
1143        store.set_alert("parent", "parent alert")?;
1144        store.set_alert("parent/child", "child alert")?;
1145        store.set_alert("parent/sibling", "sibling alert")?;
1146
1147        let child_alerts = store.get_alerts("parent/child")?;
1148        let child_contents = child_alerts
1149            .iter()
1150            .map(|alert| alert.content.as_str())
1151            .collect::<Vec<_>>();
1152        assert_eq!(child_contents, vec!["parent alert", "child alert"]);
1153        assert!(store.get_alerts("parent/child")?.is_empty());
1154
1155        let parent_alerts = store.get_alerts("parent")?;
1156        let parent_contents = parent_alerts
1157            .iter()
1158            .map(|alert| alert.content.as_str())
1159            .collect::<Vec<_>>();
1160        assert_eq!(parent_contents, vec!["sibling alert"]);
1161        assert!(store.get_alerts("parent")?.is_empty());
1162        Ok(())
1163    }
1164
1165    #[test]
1166    fn default_database_path_matches_spec() {
1167        assert_eq!(default_database_path(), PathBuf::from(".mii-memory.db"));
1168    }
1169}