1use std::cmp::Ordering;
2use std::env;
3use std::fs;
4use std::io::{Error as IoError, ErrorKind};
5use std::path::{Path, PathBuf};
6
7use anyhow::{Context, Result, bail};
8use chrono::{DateTime, Utc};
9use rusqlite::{Connection, params};
10use serde::Serialize;
11
12use crate::embedding::{blend, cosine_similarity, decode_embedding, embed_text, encode_embedding};
13use crate::expiration::{fingerprint_for_condition, is_expired, validate_expiration};
14use crate::model::{ExpirationCondition, MemoryMode, normalize_tags};
15
/// Schema version stamped into SQLite's `user_version` pragma by the newest migration.
const SCHEMA_VERSION: i64 = 2;
/// Minimum cosine similarity at which an existing memory is penalized (its
/// `negative_score` grows) when a new memory is stored.
const SIMILAR_MEMORY_THRESHOLD: f32 = 0.72;
/// Environment variable that names the current session.
const SESSION_ENV: &str = "MII_MEMORY_SESSION";
/// Environment variable that names a parent session; session refs are nested beneath it.
const SESSION_PARENT_ENV: &str = "MII_MEMORY_SESSION_PARENT";
/// Session id variable consulted as a fallback after `MII_MEMORY_SESSION`.
const MCP_SESSION_ENV: &str = "MCP_SESSION_ID";
21
/// SQLite-backed store for tagged, embedded memories and one-shot session alerts.
pub struct MemoryStore {
    // Single connection used for every query; its schema is migrated when the store is built.
    connection: Connection,
}
25
/// Input for [`MemoryStore::set`]; it is trimmed, normalized and validated before storage.
#[derive(Debug, Clone)]
pub struct SetMemory {
    /// Memory text; must be non-empty after trimming.
    pub content: String,
    /// Scope of the memory (global, workspace or session).
    pub mode: MemoryMode,
    /// Scope reference; inferred from the mode when omitted (e.g. current directory for
    /// workspace memories, session env vars for session memories).
    pub mode_ref: Option<String>,
    /// Tags; at least one must remain after normalization.
    pub tags: Vec<String>,
    /// Optional expiration rule; must be paired with `expiration_value`.
    pub expiration_condition: Option<ExpirationCondition>,
    /// Argument for `expiration_condition`; must not be set without a condition.
    pub expiration_value: Option<String>,
    /// Free-form metadata stored verbatim (also searched as text by `browse`).
    pub metadata: Option<String>,
}
36
/// Parameters for [`MemoryStore::get`].
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// Free-text query, embedded for semantic scoring and matched as a substring.
    pub query: String,
    /// Tags every returned memory must carry (each also adds a fixed score bonus).
    pub positive_tags: Vec<String>,
    /// Tags that penalize, but do not exclude, memories carrying them.
    pub negative_tags: Vec<String>,
    /// Page size; values below 1 are raised to 1.
    pub limit: usize,
    /// Number of ranked results to skip before the page starts.
    pub offset: usize,
    /// Restrict results to this mode when set.
    pub mode: Option<MemoryMode>,
    /// Restrict results to this scope reference (session refs match along their lineage).
    pub mode_ref: Option<String>,
}
47
/// A one-shot message queued for a session; returned and deleted by `get_alerts`.
#[derive(Debug, Clone, Serialize)]
pub struct Alert {
    /// Session the alert was stored for (possibly prefixed with a configured parent).
    pub session_ref: String,
    /// Trimmed alert text.
    pub content: String,
}
53
54impl Default for SearchOptions {
55 fn default() -> Self {
56 Self {
57 query: String::new(),
58 positive_tags: Vec::new(),
59 negative_tags: Vec::new(),
60 limit: 10,
61 offset: 0,
62 mode: None,
63 mode_ref: None,
64 }
65 }
66}
67
/// One ranked hit returned by [`MemoryStore::get`].
#[derive(Debug, Clone, Serialize)]
pub struct MemorySearchResult {
    pub id: i64,
    pub content: String,
    pub mode: MemoryMode,
    pub mode_ref: Option<String>,
    pub tags: Vec<String>,
    /// Ranking score computed for this query (semantic + text bonuses + stored scores).
    pub score: f32,
    /// Stored retrieval reward as loaded, before this retrieval is recorded.
    pub positive_score: f32,
    /// Stored similarity penalty accumulated as similar memories were added.
    pub negative_score: f32,
    /// Retrieval count as loaded, before this retrieval is recorded.
    pub usage_count: i64,
    pub metadata: Option<String>,
    pub created_at: DateTime<Utc>,
}
82
/// A tag and the number of non-expired memories carrying it (see `list_tags`).
#[derive(Debug, Clone, Serialize)]
pub struct TagSummary {
    pub tag: String,
    pub count: i64,
}
88
/// A memory row as listed by [`MemoryStore::browse`].
#[derive(Debug, Clone, Serialize)]
pub struct MemoryEntry {
    pub id: i64,
    pub content: String,
    pub mode: MemoryMode,
    pub mode_ref: Option<String>,
    pub tags: Vec<String>,
    pub positive_score: f32,
    pub negative_score: f32,
    pub usage_count: i64,
    pub metadata: Option<String>,
    pub expiration_condition: Option<ExpirationCondition>,
    pub expiration_value: Option<String>,
    pub created_at: DateTime<Utc>,
    /// Text-filter relevance; `Some` only when `browse` received a non-blank text filter,
    /// and omitted from serialized output otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub relevance: Option<f32>,
}
106
/// Parameters for [`MemoryStore::browse`].
#[derive(Debug, Default, Clone)]
pub struct BrowseOptions {
    /// Optional text filter (substring and semantic); blank text is ignored.
    pub text: Option<String>,
    /// Tags every listed memory must carry.
    pub tags: Vec<String>,
    /// Restrict to this mode when set.
    pub mode: Option<MemoryMode>,
    /// Page size; `0` means the default of 50.
    pub limit: usize,
    /// Number of sorted entries to skip.
    pub offset: usize,
}
115
/// Cheap summary of store contents returned by [`MemoryStore::signature`].
///
/// NOTE(review): presumably compared between calls to detect that memories or alerts
/// changed — confirm with callers.
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct StoreSignature {
    /// Number of rows in `memories`.
    pub memory_count: i64,
    /// Largest memory id, or 0 when there are none.
    pub max_memory_id: i64,
    /// Largest `updated_at` string in `memories`, if any.
    pub last_updated_at: Option<String>,
    /// Number of pending alerts.
    pub alert_count: i64,
    /// Largest alert id, or 0 when there are none.
    pub max_alert_id: i64,
}
124
125impl MemoryStore {
126 pub fn open(path: impl AsRef<Path>) -> Result<Self> {
127 let path = path.as_ref();
128 if let Some(parent) = path
129 .parent()
130 .filter(|parent| !parent.as_os_str().is_empty())
131 {
132 fs::create_dir_all(parent).with_context(|| {
133 format!("failed to create database directory {}", parent.display())
134 })?;
135 }
136
137 let connection = Connection::open(path)
138 .with_context(|| format!("failed to open database {}", path.display()))?;
139 Self::from_connection(connection)
140 }
141
142 pub fn in_memory() -> Result<Self> {
143 Self::from_connection(Connection::open_in_memory()?)
144 }
145
146 fn from_connection(connection: Connection) -> Result<Self> {
147 let mut store = Self { connection };
148 store.migrate()?;
149 Ok(store)
150 }
151
152 pub fn set(&mut self, input: SetMemory) -> Result<i64> {
153 let input = normalize_set_memory(input)?;
154 let now = Utc::now();
155 let created_at = now.to_rfc3339();
156 let content_embedding =
157 embed_text(&input.content).context("failed to embed memory content")?;
158 let tag_text = input.tags.join(" ");
159 let tag_embedding = embed_text(&tag_text).context("failed to embed memory tags")?;
160 let combined_embedding = blend(&content_embedding, &tag_embedding);
161 let file_fingerprint = fingerprint_for_condition(
162 input.expiration_condition,
163 input.expiration_value.as_deref(),
164 )?;
165 let related_updates = self.similar_memory_updates(&combined_embedding, now)?;
166
167 let transaction = self.connection.transaction()?;
168 transaction.execute(
169 "INSERT INTO memories (
170 content, mode, mode_ref, tags_json, expiration_condition, expiration_value,
171 metadata, content_embedding, tag_embedding, combined_embedding,
172 positive_score, negative_score, usage_count, created_at, updated_at,
173 file_fingerprint
174 ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, 0.0, 0.0, 0, ?11, ?11, ?12)",
175 params![
176 input.content,
177 input.mode.as_str(),
178 input.mode_ref,
179 serde_json::to_string(&input.tags)?,
180 input.expiration_condition.map(ExpirationCondition::as_str),
181 input.expiration_value,
182 input.metadata,
183 encode_embedding(&content_embedding),
184 encode_embedding(&tag_embedding),
185 encode_embedding(&combined_embedding),
186 created_at,
187 file_fingerprint,
188 ],
189 )?;
190 let id = transaction.last_insert_rowid();
191
192 for tag in input.tags {
193 transaction.execute(
194 "INSERT OR IGNORE INTO memory_tags (memory_id, tag) VALUES (?1, ?2)",
195 params![id, tag],
196 )?;
197 }
198
199 for (memory_id, penalty) in related_updates {
200 transaction.execute(
201 "UPDATE memories
202 SET negative_score = negative_score + ?1, updated_at = ?2
203 WHERE id = ?3",
204 params![penalty, created_at, memory_id],
205 )?;
206 }
207
208 transaction.commit()?;
209 Ok(id)
210 }
211
212 pub fn get(&mut self, options: SearchOptions) -> Result<Vec<MemorySearchResult>> {
213 let options = normalize_search_options(options)?;
214 let now = Utc::now();
215 let query_embedding = embed_text(&options.query).context("failed to embed memory query")?;
216 let query_lower = options.query.to_ascii_lowercase();
217 let mut scored = Vec::new();
218
219 for memory in self.load_memories()? {
220 if !memory.matches_scope(options.mode, options.mode_ref.as_deref()) {
221 continue;
222 }
223
224 if memory.is_expired(now) {
225 continue;
226 }
227
228 if !options
229 .positive_tags
230 .iter()
231 .all(|tag| memory.tags.iter().any(|memory_tag| memory_tag == tag))
232 {
233 continue;
234 }
235
236 let score = score_memory(&memory, &query_embedding, &query_lower, &options);
237 scored.push((memory, score));
238 }
239
240 scored.sort_by(|(left_memory, left_score), (right_memory, right_score)| {
241 right_score
242 .partial_cmp(left_score)
243 .unwrap_or(Ordering::Equal)
244 .then_with(|| right_memory.id.cmp(&left_memory.id))
245 });
246
247 let returned = scored
248 .into_iter()
249 .skip(options.offset)
250 .take(options.limit)
251 .enumerate()
252 .map(|(rank, (memory, score))| (rank, memory, score))
253 .collect::<Vec<_>>();
254
255 self.record_retrievals(&returned, now)?;
256
257 Ok(returned
258 .into_iter()
259 .map(|(_, memory, score)| MemorySearchResult {
260 id: memory.id,
261 content: memory.content,
262 mode: memory.mode,
263 mode_ref: memory.mode_ref,
264 tags: memory.tags,
265 score,
266 positive_score: memory.positive_score,
267 negative_score: memory.negative_score,
268 usage_count: memory.usage_count,
269 metadata: memory.metadata,
270 created_at: memory.created_at,
271 })
272 .collect())
273 }
274
275 pub fn list_tags(&self, filter: Option<&str>) -> Result<Vec<TagSummary>> {
276 let now = Utc::now();
277 let mut summaries = std::collections::BTreeMap::<String, i64>::new();
278 let filter = filter.map(str::trim).filter(|filter| !filter.is_empty());
279 let filter_lower = filter.map(str::to_ascii_lowercase);
280 let filter_embedding = filter
281 .map(embed_text)
282 .transpose()
283 .context("failed to embed tag filter")?;
284
285 for memory in self.load_memories()? {
286 if memory.is_expired(now) {
287 continue;
288 }
289
290 for tag in memory.tags {
291 if let Some(filter_lower) = &filter_lower {
292 let tag_matches_text = tag.contains(filter_lower);
293 let tag_matches_embedding = if let Some(filter_embedding) = &filter_embedding {
294 let tag_embedding =
295 embed_text(&tag).context("failed to embed memory tag")?;
296 cosine_similarity(filter_embedding, &tag_embedding) >= 0.2
297 } else {
298 false
299 };
300
301 if !tag_matches_text && !tag_matches_embedding {
302 continue;
303 }
304 }
305
306 *summaries.entry(tag).or_default() += 1;
307 }
308 }
309
310 Ok(summaries
311 .into_iter()
312 .map(|(tag, count)| TagSummary { tag, count })
313 .collect())
314 }
315
316 pub fn set_alert(
317 &mut self,
318 session_ref: impl Into<String>,
319 content: impl Into<String>,
320 ) -> Result<i64> {
321 let session_ref = session_ref_with_configured_parent(session_ref.into())?;
322 let content = normalize_required_text(content.into(), "alert content")?;
323
324 self.connection.execute(
325 "INSERT INTO alerts (session_ref, content) VALUES (?1, ?2)",
326 params![session_ref, content],
327 )?;
328
329 Ok(self.connection.last_insert_rowid())
330 }
331
332 pub fn get_alerts(&mut self, session_ref: impl Into<String>) -> Result<Vec<Alert>> {
333 let session_ref = session_ref_with_configured_parent(session_ref.into())?;
334 let transaction = self.connection.transaction()?;
335 let alerts = {
336 let mut statement =
337 transaction.prepare("SELECT id, session_ref, content FROM alerts ORDER BY id")?;
338 let rows = statement.query_map([], |row| {
339 Ok((
340 row.get::<_, i64>(0)?,
341 Alert {
342 session_ref: row.get(1)?,
343 content: row.get(2)?,
344 },
345 ))
346 })?;
347
348 rows.filter_map(|row| match row {
349 Ok((id, alert)) if session_refs_share_lineage(&session_ref, &alert.session_ref) => {
350 Some(Ok((id, alert)))
351 }
352 Ok(_) => None,
353 Err(error) => Some(Err(error)),
354 })
355 .collect::<Result<Vec<_>, _>>()?
356 };
357
358 for (id, _) in &alerts {
359 transaction.execute("DELETE FROM alerts WHERE id = ?1", params![id])?;
360 }
361
362 transaction.commit()?;
363 Ok(alerts.into_iter().map(|(_, alert)| alert).collect())
364 }
365
366 pub fn browse(&self, options: BrowseOptions) -> Result<Vec<MemoryEntry>> {
367 let now = Utc::now();
368 let text_filter = options
369 .text
370 .as_deref()
371 .map(str::trim)
372 .filter(|text| !text.is_empty())
373 .map(str::to_string);
374 let lowered_text = text_filter.as_deref().map(str::to_ascii_lowercase);
375 let query_embedding = text_filter
376 .as_deref()
377 .map(embed_text)
378 .transpose()
379 .context("failed to embed explorer query")?;
380 let tag_filter = normalize_tags(&options.tags);
381 let limit = if options.limit == 0 {
382 50
383 } else {
384 options.limit
385 };
386
387 let mut entries: Vec<(MemoryEntry, f32, DateTime<Utc>)> = Vec::new();
388 for memory in self.load_memories()? {
389 if memory.is_expired(now) {
390 continue;
391 }
392
393 if let Some(mode) = options.mode
394 && memory.mode != mode
395 {
396 continue;
397 }
398
399 if !tag_filter
400 .iter()
401 .all(|tag| memory.tags.iter().any(|memory_tag| memory_tag == tag))
402 {
403 continue;
404 }
405
406 let mut relevance: Option<f32> = None;
407 if let Some(text) = &lowered_text {
408 let content_lower = memory.content.to_ascii_lowercase();
409 let content_match = content_lower.contains(text);
410 let tag_match = memory
411 .tags
412 .iter()
413 .any(|tag| tag.to_ascii_lowercase().contains(text));
414 let metadata_match = memory
415 .metadata
416 .as_deref()
417 .is_some_and(|metadata| metadata.to_ascii_lowercase().contains(text));
418
419 let semantic = query_embedding
420 .as_deref()
421 .map(|embedding| cosine_similarity(embedding, &memory.combined_embedding))
422 .unwrap_or(0.0);
423 let text_bonus = if content_match {
424 0.2
425 } else if tag_match || metadata_match {
426 0.1
427 } else {
428 0.0
429 };
430 let score = semantic + text_bonus;
431
432 if !content_match && !tag_match && !metadata_match && semantic < 0.25 {
433 continue;
434 }
435
436 relevance = Some(score.clamp(0.0, 1.2));
437 }
438
439 let created_at = memory.created_at;
440 let entry = MemoryEntry {
441 id: memory.id,
442 content: memory.content,
443 mode: memory.mode,
444 mode_ref: memory.mode_ref,
445 tags: memory.tags,
446 positive_score: memory.positive_score,
447 negative_score: memory.negative_score,
448 usage_count: memory.usage_count,
449 metadata: memory.metadata,
450 expiration_condition: memory.expiration_condition,
451 expiration_value: memory.expiration_value,
452 created_at,
453 relevance,
454 };
455 let sort_score = relevance.unwrap_or(0.0);
456 entries.push((entry, sort_score, created_at));
457 }
458
459 if text_filter.is_some() {
460 entries.sort_by(|left, right| {
461 right
462 .1
463 .partial_cmp(&left.1)
464 .unwrap_or(Ordering::Equal)
465 .then_with(|| right.2.cmp(&left.2))
466 });
467 } else {
468 entries.sort_by(|left, right| {
469 right
470 .2
471 .cmp(&left.2)
472 .then_with(|| right.0.id.cmp(&left.0.id))
473 });
474 }
475
476 Ok(entries
477 .into_iter()
478 .skip(options.offset)
479 .take(limit)
480 .map(|(entry, _, _)| entry)
481 .collect())
482 }
483
484 pub fn signature(&self) -> Result<StoreSignature> {
485 let (memory_count, max_memory_id, last_updated_at) = self.connection.query_row(
486 "SELECT COUNT(*), COALESCE(MAX(id), 0), MAX(updated_at) FROM memories",
487 [],
488 |row| {
489 Ok((
490 row.get::<_, i64>(0)?,
491 row.get::<_, i64>(1)?,
492 row.get::<_, Option<String>>(2)?,
493 ))
494 },
495 )?;
496 let (alert_count, max_alert_id) = self.connection.query_row(
497 "SELECT COUNT(*), COALESCE(MAX(id), 0) FROM alerts",
498 [],
499 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)),
500 )?;
501
502 Ok(StoreSignature {
503 memory_count,
504 max_memory_id,
505 last_updated_at,
506 alert_count,
507 max_alert_id,
508 })
509 }
510
511 fn migrate(&mut self) -> Result<()> {
512 self.connection.pragma_update(None, "foreign_keys", "ON")?;
513 let version: i64 = self
514 .connection
515 .pragma_query_value(None, "user_version", |row| row.get(0))?;
516
517 if version > SCHEMA_VERSION {
518 bail!(
519 "database schema version {version} is newer than this binary supports ({SCHEMA_VERSION})"
520 );
521 }
522
523 if version == 0 {
524 let transaction = self.connection.transaction()?;
525 transaction.execute_batch(
526 "CREATE TABLE IF NOT EXISTS memories (
527 id INTEGER PRIMARY KEY AUTOINCREMENT,
528 content TEXT NOT NULL,
529 mode TEXT NOT NULL,
530 mode_ref TEXT,
531 tags_json TEXT NOT NULL,
532 expiration_condition TEXT,
533 expiration_value TEXT,
534 metadata TEXT,
535 content_embedding TEXT NOT NULL,
536 tag_embedding TEXT NOT NULL,
537 combined_embedding TEXT NOT NULL,
538 positive_score REAL NOT NULL DEFAULT 0.0,
539 negative_score REAL NOT NULL DEFAULT 0.0,
540 usage_count INTEGER NOT NULL DEFAULT 0,
541 created_at TEXT NOT NULL,
542 updated_at TEXT NOT NULL,
543 file_fingerprint TEXT
544 );
545
546 CREATE TABLE IF NOT EXISTS memory_tags (
547 memory_id INTEGER NOT NULL,
548 tag TEXT NOT NULL,
549 PRIMARY KEY (memory_id, tag),
550 FOREIGN KEY (memory_id) REFERENCES memories(id) ON DELETE CASCADE
551 );
552
553 CREATE INDEX IF NOT EXISTS idx_memories_scope ON memories(mode, mode_ref);
554 CREATE INDEX IF NOT EXISTS idx_memory_tags_tag ON memory_tags(tag);
555 PRAGMA user_version = 1;",
556 )?;
557 transaction.commit()?;
558 }
559
560 let version: i64 = self
561 .connection
562 .pragma_query_value(None, "user_version", |row| row.get(0))?;
563
564 if version == 1 {
565 let transaction = self.connection.transaction()?;
566 transaction.execute_batch(
567 "CREATE TABLE IF NOT EXISTS alerts (
568 id INTEGER PRIMARY KEY AUTOINCREMENT,
569 session_ref TEXT NOT NULL,
570 content TEXT NOT NULL
571 );
572
573 CREATE INDEX IF NOT EXISTS idx_alerts_session_ref ON alerts(session_ref, id);
574 PRAGMA user_version = 2;",
575 )?;
576 transaction.commit()?;
577 }
578
579 Ok(())
580 }
581
582 fn load_memories(&self) -> Result<Vec<MemoryRecord>> {
583 let mut statement = self.connection.prepare(
584 "SELECT
585 id, content, mode, mode_ref, tags_json, expiration_condition, expiration_value,
586 metadata, combined_embedding, positive_score, negative_score, usage_count,
587 created_at, file_fingerprint
588 FROM memories",
589 )?;
590
591 let rows = statement.query_map([], |row| {
592 let mode: String = row.get(2)?;
593 let expiration_condition: Option<String> = row.get(5)?;
594 let created_at: String = row.get(12)?;
595
596 Ok(MemoryRecord {
597 id: row.get(0)?,
598 content: row.get(1)?,
599 mode: mode
600 .parse()
601 .map_err(|error| conversion_error(error, "mode"))?,
602 mode_ref: row.get(3)?,
603 tags: serde_json::from_str::<Vec<String>>(&row.get::<_, String>(4)?)
604 .unwrap_or_default(),
605 expiration_condition: expiration_condition
606 .as_deref()
607 .map(str::parse::<ExpirationCondition>)
608 .transpose()
609 .map_err(|error| conversion_error(error, "expiration_condition"))?,
610 expiration_value: row.get(6)?,
611 metadata: row.get(7)?,
612 combined_embedding: decode_embedding(&row.get::<_, String>(8)?),
613 positive_score: row.get(9)?,
614 negative_score: row.get(10)?,
615 usage_count: row.get(11)?,
616 created_at: DateTime::parse_from_rfc3339(&created_at)
617 .map(|datetime| datetime.with_timezone(&Utc))
618 .map_err(|error| rusqlite::Error::ToSqlConversionFailure(Box::new(error)))?,
619 file_fingerprint: row.get(13)?,
620 })
621 })?;
622
623 rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
624 }
625
626 fn similar_memory_updates(
627 &self,
628 combined_embedding: &[f32],
629 now: DateTime<Utc>,
630 ) -> Result<Vec<(i64, f32)>> {
631 let mut updates = Vec::new();
632
633 for memory in self.load_memories()? {
634 if memory.is_expired(now) {
635 continue;
636 }
637
638 let similarity = cosine_similarity(combined_embedding, &memory.combined_embedding);
639 if similarity >= SIMILAR_MEMORY_THRESHOLD {
640 updates.push((memory.id, similarity));
641 }
642 }
643
644 Ok(updates)
645 }
646
647 fn record_retrievals(
648 &mut self,
649 returned: &[(usize, MemoryRecord, f32)],
650 now: DateTime<Utc>,
651 ) -> Result<()> {
652 if returned.is_empty() {
653 return Ok(());
654 }
655
656 let updated_at = now.to_rfc3339();
657 let transaction = self.connection.transaction()?;
658 for (rank, memory, _) in returned {
659 let gain = 1.0_f32 / (*rank as f32 + 1.0);
660 transaction.execute(
661 "UPDATE memories
662 SET positive_score = positive_score + ?1,
663 usage_count = usage_count + 1,
664 updated_at = ?2
665 WHERE id = ?3",
666 params![gain, updated_at, memory.id],
667 )?;
668 }
669 transaction.commit()?;
670 Ok(())
671 }
672}
673
674fn conversion_error(error: anyhow::Error, field: &'static str) -> rusqlite::Error {
675 rusqlite::Error::ToSqlConversionFailure(Box::new(IoError::new(
676 ErrorKind::InvalidData,
677 format!("invalid {field}: {error}"),
678 )))
679}
680
/// Default database location: the relative file `.mii-memory.db`, which resolves against
/// the current working directory when opened.
pub fn default_database_path() -> PathBuf {
    Path::new(".mii-memory.db").to_path_buf()
}
684
685pub fn infer_mode_ref(mode: MemoryMode, explicit: Option<String>) -> Result<Option<String>> {
686 match mode {
687 MemoryMode::Global => Ok(None),
688 MemoryMode::Workspace => {
689 if let Some(explicit) = normalize_optional_text(explicit) {
690 return Ok(Some(explicit));
691 }
692
693 Ok(Some(
694 env::current_dir()
695 .context("failed to infer workspace mode_ref from current directory")?
696 .to_string_lossy()
697 .into_owned(),
698 ))
699 }
700 MemoryMode::Session => Ok(Some(infer_session_ref(explicit)?)),
701 }
702}
703
704pub fn infer_session_ref(explicit: Option<String>) -> Result<String> {
705 let session_ref = normalize_optional_text(explicit)
706 .or_else(|| env_text(SESSION_ENV))
707 .or_else(|| env_text(MCP_SESSION_ENV))
708 .unwrap_or_else(|| "default".to_string());
709
710 session_ref_with_configured_parent(session_ref)
711}
712
713pub fn infer_mcp_session_ref(generated: String) -> Result<String> {
714 mcp_session_ref(
715 generated,
716 env_text(SESSION_ENV),
717 env_text(SESSION_PARENT_ENV),
718 )
719}
720
721fn normalize_set_memory(mut input: SetMemory) -> Result<SetMemory> {
722 input.content = input.content.trim().to_string();
723 if input.content.is_empty() {
724 bail!("memory content cannot be empty");
725 }
726
727 input.tags = normalize_tags(&input.tags);
728 if input.tags.is_empty() {
729 bail!("at least one tag is required");
730 }
731
732 input.mode_ref = infer_mode_ref(input.mode, input.mode_ref)?;
733
734 match (
735 input.expiration_condition,
736 input.expiration_value.as_deref(),
737 ) {
738 (Some(condition), Some(value)) => validate_expiration(condition, value)?,
739 (Some(condition), None) => bail!("expiration condition {condition} requires a value"),
740 (None, Some(_)) => bail!("expiration value was provided without an expiration condition"),
741 (None, None) => {}
742 }
743
744 Ok(input)
745}
746
747fn normalize_required_text(mut value: String, field: &'static str) -> Result<String> {
748 value = value.trim().to_string();
749 if value.is_empty() {
750 bail!("{field} cannot be empty");
751 }
752
753 Ok(value)
754}
755
/// Trims an optional string, mapping absent or whitespace-only input to `None`.
fn normalize_optional_text(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_owned();
    (!trimmed.is_empty()).then_some(trimmed)
}
761
762fn env_text(name: &str) -> Option<String> {
763 env::var(name)
764 .ok()
765 .and_then(|value| normalize_optional_text(Some(value)))
766}
767
768fn session_ref_with_configured_parent(session_ref: String) -> Result<String> {
769 session_ref_with_parent(session_ref, env_text(SESSION_PARENT_ENV))
770}
771
772fn mcp_session_ref(
773 generated: String,
774 configured_session: Option<String>,
775 parent_ref: Option<String>,
776) -> Result<String> {
777 let session_ref = normalize_optional_text(configured_session).unwrap_or(generated);
778
779 session_ref_with_parent(session_ref, parent_ref)
780}
781
782fn session_ref_with_parent(session_ref: String, parent_ref: Option<String>) -> Result<String> {
783 let session_ref = normalize_required_text(session_ref, "session_ref")?;
784 let Some(parent_ref) = normalize_optional_text(parent_ref) else {
785 return Ok(session_ref);
786 };
787
788 if session_ref == parent_ref || session_ref_is_ancestor(&parent_ref, &session_ref) {
789 return Ok(session_ref);
790 }
791
792 Ok(format!("{parent_ref}/{session_ref}"))
793}
794
795fn normalize_search_options(mut options: SearchOptions) -> Result<SearchOptions> {
796 options.query = options.query.trim().to_string();
797 options.positive_tags = normalize_tags(&options.positive_tags);
798 options.negative_tags = normalize_tags(&options.negative_tags);
799 options.mode_ref = options
800 .mode_ref
801 .map(|mode_ref| mode_ref.trim().to_string())
802 .filter(|mode_ref| !mode_ref.is_empty());
803 if options.mode == Some(MemoryMode::Session) {
804 options.mode_ref = Some(infer_session_ref(options.mode_ref)?);
805 }
806 options.limit = options.limit.max(1);
807 Ok(options)
808}
809
810fn session_refs_share_lineage(requested_ref: &str, stored_ref: &str) -> bool {
811 requested_ref == stored_ref
812 || session_ref_is_ancestor(requested_ref, stored_ref)
813 || session_ref_is_ancestor(stored_ref, requested_ref)
814}
815
/// True when `descendant` is `ancestor` followed by a `/` and anything (strict ancestry on
/// path segments; equal refs are not ancestors of each other).
fn session_ref_is_ancestor(ancestor: &str, descendant: &str) -> bool {
    match descendant.strip_prefix(ancestor) {
        Some(rest) => rest.starts_with('/'),
        None => false,
    }
}
821
822fn score_memory(
823 memory: &MemoryRecord,
824 query_embedding: &[f32],
825 query_lower: &str,
826 options: &SearchOptions,
827) -> f32 {
828 let semantic = cosine_similarity(query_embedding, &memory.combined_embedding) * 10.0;
829 let content_lower = memory.content.to_ascii_lowercase();
830 let text_bonus = if !query_lower.is_empty() && content_lower.contains(query_lower) {
831 2.0
832 } else {
833 0.0
834 };
835 let tag_text_bonus =
836 if !query_lower.is_empty() && memory.tags.iter().any(|tag| tag.contains(query_lower)) {
837 1.0
838 } else {
839 0.0
840 };
841 let positive_tag_bonus = options.positive_tags.len() as f32 * 0.35;
842 let negative_tag_penalty = options
843 .negative_tags
844 .iter()
845 .filter(|negative_tag| memory.tags.iter().any(|tag| tag == *negative_tag))
846 .count() as f32
847 * 4.0;
848
849 semantic + text_bonus + tag_text_bonus + positive_tag_bonus + memory.positive_score
850 - memory.negative_score
851 - negative_tag_penalty
852}
853
/// Internal, fully decoded row of the `memories` table as produced by `load_memories`.
#[derive(Debug, Clone)]
struct MemoryRecord {
    id: i64,
    content: String,
    mode: MemoryMode,
    mode_ref: Option<String>,
    // Decoded from `tags_json`; empty if that JSON was malformed.
    tags: Vec<String>,
    expiration_condition: Option<ExpirationCondition>,
    expiration_value: Option<String>,
    metadata: Option<String>,
    // Blend of the content and tag embeddings computed at insert time.
    combined_embedding: Vec<f32>,
    positive_score: f32,
    negative_score: f32,
    usage_count: i64,
    created_at: DateTime<Utc>,
    // Value from `fingerprint_for_condition` at insert time; passed to the expiration check.
    file_fingerprint: Option<String>,
}
871
872impl MemoryRecord {
873 fn matches_scope(&self, mode: Option<MemoryMode>, mode_ref: Option<&str>) -> bool {
874 if mode.is_some_and(|mode| self.mode != mode) {
875 return false;
876 }
877
878 if let Some(mode_ref) = mode_ref {
879 return self.mode_ref.as_deref().is_some_and(|stored_ref| {
880 if self.mode == MemoryMode::Session {
881 session_refs_share_lineage(mode_ref, stored_ref)
882 } else {
883 stored_ref == mode_ref
884 }
885 });
886 }
887
888 true
889 }
890
891 fn is_expired(&self, now: DateTime<Utc>) -> bool {
892 is_expired(
893 self.expiration_condition,
894 self.expiration_value.as_deref(),
895 self.created_at,
896 self.usage_count,
897 self.file_fingerprint.as_deref(),
898 now,
899 )
900 }
901}
902
// Unit tests for the store. Tests that need real embeddings are gated on the
// `has_embedded_embeddings` cfg (presumably set by the build when embeddings are bundled —
// confirm). Alert and session-ref tests assume `MII_MEMORY_SESSION_PARENT` is unset, since
// `set_alert`/`get_alerts` read it.
#[cfg(test)]
mod tests {
    use super::*;
    // Only needed by the file-pristine test, which appends to a tracked file.
    #[cfg(has_embedded_embeddings)]
    use std::io::Write;

    /// Builds a global-mode memory with the given content and tags, and no expiration or
    /// metadata.
    fn memory(content: &str, tags: &[&str]) -> SetMemory {
        SetMemory {
            content: content.to_string(),
            mode: MemoryMode::Global,
            mode_ref: None,
            tags: tags.iter().map(|tag| tag.to_string()).collect(),
            expiration_condition: None,
            expiration_value: None,
            metadata: None,
        }
    }

    /// Builds a session-mode memory tagged `lineage` for the given session ref.
    #[cfg(has_embedded_embeddings)]
    fn session_memory(content: &str, session_ref: &str) -> SetMemory {
        let mut input = memory(content, &["lineage"]);
        input.mode = MemoryMode::Session;
        input.mode_ref = Some(session_ref.to_string());
        input
    }

    // A stored memory is found by query + required tag, and its tags are listed by a
    // substring filter.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn set_get_and_list_tags_round_trip() -> Result<()> {
        let mut store = MemoryStore::in_memory()?;
        store.set(memory("Rust sqlite memory backend", &["rust", "sqlite"]))?;

        let results = store.get(SearchOptions {
            query: "sqlite backend".to_string(),
            positive_tags: vec!["rust".to_string()],
            ..SearchOptions::default()
        })?;

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].content, "Rust sqlite memory backend");

        let tags = store.list_tags(Some("sql"))?;
        assert_eq!(tags[0].tag, "sqlite");
        Ok(())
    }

    // Without bundled embeddings, `set` must fail with an error that tells the user how to
    // supply them.
    #[cfg(not(has_embedded_embeddings))]
    #[test]
    fn set_requires_embeddings_when_not_configured() -> Result<()> {
        let mut store = MemoryStore::in_memory()?;
        let error = store
            .set(memory("Rust sqlite memory backend", &["rust", "sqlite"]))
            .unwrap_err();

        assert!(
            error
                .chain()
                .any(|cause| cause.to_string().contains("--embeddings <PATH>"))
        );
        Ok(())
    }

    // A usage limit of 1: the first `get` returns (and records) the memory, so the second
    // finds it expired.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn usage_expiration_hides_memory_after_limit() -> Result<()> {
        let mut store = MemoryStore::in_memory()?;
        let mut input = memory("single use memory", &["temporary"]);
        input.expiration_condition = Some(ExpirationCondition::Usage);
        input.expiration_value = Some("1".to_string());
        store.set(input)?;

        let first = store.get(SearchOptions {
            query: "single".to_string(),
            ..SearchOptions::default()
        })?;
        let second = store.get(SearchOptions {
            query: "single".to_string(),
            ..SearchOptions::default()
        })?;

        assert_eq!(first.len(), 1);
        assert!(second.is_empty());
        Ok(())
    }

    // A memory tied to an unchanged file is visible; after the file is modified it is
    // treated as expired.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn file_pristine_expiration_tracks_changes() -> Result<()> {
        let directory = tempfile::tempdir()?;
        let file_path = directory.path().join("tracked.txt");
        fs::write(&file_path, "first")?;

        let mut store = MemoryStore::in_memory()?;
        let mut input = memory("tracked file state", &["file"]);
        input.expiration_condition = Some(ExpirationCondition::FilePristine);
        input.expiration_value = Some(file_path.to_string_lossy().into_owned());
        store.set(input)?;

        assert_eq!(
            store
                .get(SearchOptions {
                    query: "tracked".to_string(),
                    ..SearchOptions::default()
                })?
                .len(),
            1
        );

        let mut file = fs::OpenOptions::new().append(true).open(&file_path)?;
        writeln!(file, "changed")?;

        assert!(
            store
                .get(SearchOptions {
                    query: "tracked".to_string(),
                    ..SearchOptions::default()
                })?
                .is_empty()
        );
        Ok(())
    }

    // Unrelated sessions do not see each other's alerts, and each alert is returned once.
    #[test]
    fn alerts_are_session_scoped_and_one_shot() -> Result<()> {
        let mut store = MemoryStore::in_memory()?;
        store.set_alert("session-a", "remember the summary")?;
        store.set_alert("session-b", "other session")?;

        let first = store.get_alerts("session-a")?;
        let second = store.get_alerts("session-a")?;
        let other = store.get_alerts("session-b")?;

        assert_eq!(first.len(), 1);
        assert_eq!(first[0].content, "remember the summary");
        assert!(second.is_empty());
        assert_eq!(other.len(), 1);
        assert_eq!(other[0].content, "other session");
        Ok(())
    }

    // A child session sees its ancestors' and descendants' memories but not a sibling's;
    // a parent sees its children's.
    #[cfg(has_embedded_embeddings)]
    #[test]
    fn session_memories_follow_sub_session_lineage() -> Result<()> {
        let mut store = MemoryStore::in_memory()?;
        store.set(session_memory("lineage parent note", "parent"))?;
        store.set(session_memory("lineage child note", "parent/child"))?;
        store.set(session_memory(
            "lineage grandchild note",
            "parent/child/grandchild",
        ))?;
        store.set(session_memory("lineage sibling note", "parent/sibling"))?;

        let child_results = store.get(SearchOptions {
            query: "lineage".to_string(),
            mode: Some(MemoryMode::Session),
            mode_ref: Some("parent/child".to_string()),
            limit: 10,
            ..SearchOptions::default()
        })?;
        let child_contents = child_results
            .iter()
            .map(|result| result.content.as_str())
            .collect::<Vec<_>>();
        assert!(child_contents.contains(&"lineage parent note"));
        assert!(child_contents.contains(&"lineage child note"));
        assert!(child_contents.contains(&"lineage grandchild note"));
        assert!(!child_contents.contains(&"lineage sibling note"));

        let parent_results = store.get(SearchOptions {
            query: "lineage child".to_string(),
            mode: Some(MemoryMode::Session),
            mode_ref: Some("parent".to_string()),
            limit: 10,
            ..SearchOptions::default()
        })?;
        assert!(
            parent_results
                .iter()
                .any(|result| result.content == "lineage child note")
        );
        Ok(())
    }

    // The parent prefix is added only when the ref is not already the parent or beneath it.
    #[test]
    fn session_parent_prefix_is_applied_once() -> Result<()> {
        assert_eq!(
            session_ref_with_parent("child".to_string(), Some("parent".to_string()))?,
            "parent/child"
        );
        assert_eq!(
            session_ref_with_parent("parent/child".to_string(), Some("parent".to_string()))?,
            "parent/child"
        );
        assert_eq!(
            session_ref_with_parent("parent".to_string(), Some("parent".to_string()))?,
            "parent"
        );
        assert_eq!(
            session_ref_with_parent("other".to_string(), Some("parent/child".to_string()))?,
            "parent/child/other"
        );
        Ok(())
    }

    // A non-blank configured session overrides the generated id; a blank one falls back to
    // it. Either way the parent prefix applies.
    #[test]
    fn mcp_session_ref_uses_configured_session_or_generated_fallback() -> Result<()> {
        assert_eq!(
            mcp_session_ref(
                "generated".to_string(),
                Some("configured".to_string()),
                None,
            )?,
            "configured"
        );
        assert_eq!(
            mcp_session_ref("generated".to_string(), None, None)?,
            "generated"
        );
        assert_eq!(
            mcp_session_ref(
                "generated".to_string(),
                Some("  ".to_string()),
                Some("parent".to_string()),
            )?,
            "parent/generated"
        );
        assert_eq!(
            mcp_session_ref(
                "generated".to_string(),
                Some("configured".to_string()),
                Some("parent".to_string()),
            )?,
            "parent/configured"
        );
        Ok(())
    }

    // Fetching as the child drains the parent's and the child's alerts (in id order). The
    // parent then receives only the sibling's remaining alert. Each fetch is one-shot.
    #[test]
    fn alerts_follow_sub_session_lineage_and_remain_one_shot() -> Result<()> {
        let mut store = MemoryStore::in_memory()?;
        store.set_alert("parent", "parent alert")?;
        store.set_alert("parent/child", "child alert")?;
        store.set_alert("parent/sibling", "sibling alert")?;

        let child_alerts = store.get_alerts("parent/child")?;
        let child_contents = child_alerts
            .iter()
            .map(|alert| alert.content.as_str())
            .collect::<Vec<_>>();
        assert_eq!(child_contents, vec!["parent alert", "child alert"]);
        assert!(store.get_alerts("parent/child")?.is_empty());

        let parent_alerts = store.get_alerts("parent")?;
        let parent_contents = parent_alerts
            .iter()
            .map(|alert| alert.content.as_str())
            .collect::<Vec<_>>();
        assert_eq!(parent_contents, vec!["sibling alert"]);
        assert!(store.get_alerts("parent")?.is_empty());
        Ok(())
    }

    // Guards the documented default database file name.
    #[test]
    fn default_database_path_matches_spec() {
        assert_eq!(default_database_path(), PathBuf::from(".mii-memory.db"));
    }
}