steel-memory 0.1.2

A spatial memory palace for AI agents with semantic vector search, knowledge graphs, and MCP tools
Documentation
use std::path::{Path, PathBuf};
use rusqlite::{Connection, params};
use crate::types::{Drawer, SearchResult};

/// Drawer store backed by a single SQLite database file.
///
/// Only the database path is held here; each method opens its own connection
/// through `open_conn` for the duration of the call.
pub struct VectorStorage {
    /// Path of the SQLite database file handed to `open_conn`.
    db_path: PathBuf,
}

impl VectorStorage {
    /// Opens the database at `db_path` through `open_conn` and makes sure the
    /// `drawers` table exists (`CREATE TABLE IF NOT EXISTS`).
    ///
    /// The connection used for this setup is dropped when the function
    /// returns; the resulting value only remembers the path.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be opened or the schema
    /// batch fails to execute.
    pub fn new(db_path: &Path) -> anyhow::Result<Self> {
        let setup_conn = open_conn(db_path)?;
        setup_conn.execute_batch(
            "PRAGMA journal_mode=WAL;
             CREATE TABLE IF NOT EXISTS drawers (
                 id TEXT PRIMARY KEY,
                 content TEXT NOT NULL,
                 wing TEXT NOT NULL,
                 room TEXT NOT NULL,
                 source_file TEXT NOT NULL DEFAULT 'mcp',
                 source_mtime INTEGER NOT NULL DEFAULT 0,
                 chunk_index INTEGER NOT NULL DEFAULT 0,
                 added_by TEXT NOT NULL DEFAULT 'mcp',
                 filed_at TEXT NOT NULL,
                 hall TEXT DEFAULT '',
                 topic TEXT DEFAULT '',
                 drawer_type TEXT DEFAULT '',
                 agent TEXT DEFAULT '',
                 date TEXT DEFAULT '',
                 importance REAL DEFAULT 3.0,
                 vector TEXT NOT NULL
             );",
        )?;

        let storage = Self {
            db_path: db_path.to_owned(),
        };
        Ok(storage)
    }

    /// Writes `drawer` and its embedding `vector` to the `drawers` table.
    ///
    /// The vector is stored as a JSON array in the `vector` column. The
    /// statement is `INSERT OR REPLACE`, so an existing row with the same
    /// `id` is overwritten.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be opened, the vector cannot
    /// be serialized, or the insert fails.
    pub fn add_drawer(&self, drawer: &Drawer, vector: &[f32]) -> anyhow::Result<()> {
        let conn = open_conn(&self.db_path)?;
        let encoded_vector = serde_json::to_string(vector)?;
        conn.execute(
            "INSERT OR REPLACE INTO drawers
             (id, content, wing, room, source_file, source_mtime, chunk_index, added_by,
              filed_at, hall, topic, drawer_type, agent, date, importance, vector)
             VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16)",
            params![
                drawer.id,
                drawer.content,
                drawer.wing,
                drawer.room,
                drawer.source_file,
                drawer.source_mtime,
                drawer.chunk_index,
                drawer.added_by,
                drawer.filed_at,
                drawer.hall,
                drawer.topic,
                drawer.drawer_type,
                drawer.agent,
                drawer.date,
                drawer.importance,
                encoded_vector,
            ],
        )?;
        Ok(())
    }

    /// Deletes the drawer whose primary key equals `id`.
    ///
    /// Returns the number of rows the `DELETE` statement affected.
    pub fn delete_drawer(&self, id: &str) -> anyhow::Result<usize> {
        let removed = open_conn(&self.db_path)?
            .execute("DELETE FROM drawers WHERE id = ?1", params![id])?;
        Ok(removed)
    }

    pub fn search(
        &self,
        query_vec: &[f32],
        limit: usize,
        wing: Option<&str>,
        room: Option<&str>,
    ) -> anyhow::Result<Vec<SearchResult>> {
        let conn = open_conn(&self.db_path)?;
        let mut stmt = conn.prepare(
            "SELECT id, content, wing, room, source_file, source_mtime, chunk_index,
                    added_by, filed_at, hall, topic, drawer_type, agent, date, importance, vector
             FROM drawers",
        )?;
        let rows = stmt.query_map([], |row| {
            Ok((
                row_to_drawer(row)?,
                row.get::<_, String>(15)?,
            ))
        })?;

        let mut results: Vec<SearchResult> = Vec::new();
        for row in rows {
            let (drawer, vec_json) = row?;
            if wing.is_some_and(|w| drawer.wing != w) {
                continue;
            }
            if room.is_some_and(|r| drawer.room != r) {
                continue;
            }
            let stored_vec: Vec<f32> = serde_json::from_str(&vec_json).unwrap_or_default();
            let sim = cosine_similarity(query_vec, &stored_vec);
            results.push(SearchResult { drawer, similarity: sim });
        }
        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
        results.truncate(limit);
        Ok(results)
    }

    pub fn get_all(
        &self,
        wing: Option<&str>,
        room: Option<&str>,
        limit: usize,
    ) -> anyhow::Result<Vec<Drawer>> {
        let conn = open_conn(&self.db_path)?;
        let base = "SELECT id,content,wing,room,source_file,source_mtime,chunk_index,added_by,\
                    filed_at,hall,topic,drawer_type,agent,date,importance,vector FROM drawers";
        let mut conditions: Vec<String> = Vec::new();
        let mut param_values: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();

        if let Some(w) = wing {
            conditions.push(format!("wing=?{}", param_values.len() + 1));
            param_values.push(Box::new(w.to_string()));
        }
        if let Some(r) = room {
            conditions.push(format!("room=?{}", param_values.len() + 1));
            param_values.push(Box::new(r.to_string()));
        }
        let limit_idx = param_values.len() + 1;
        param_values.push(Box::new(limit as i64));

        let sql = if conditions.is_empty() {
            format!("{} ORDER BY importance DESC LIMIT ?{}", base, limit_idx)
        } else {
            format!("{} WHERE {} ORDER BY importance DESC LIMIT ?{}", base, conditions.join(" AND "), limit_idx)
        };

        let mut stmt = conn.prepare(&sql)?;
        let refs: Vec<&dyn rusqlite::ToSql> = param_values.iter().map(|p| p.as_ref()).collect();
        let drawers = stmt.query_map(refs.as_slice(), row_to_drawer)?
            .filter_map(|r| r.ok())
            .collect();
        Ok(drawers)
    }

    /// Returns the total number of rows in the `drawers` table.
    pub fn count(&self) -> anyhow::Result<usize> {
        let conn = open_conn(&self.db_path)?;
        let total = conn.query_row("SELECT COUNT(*) FROM drawers", [], |row| {
            row.get::<_, i64>(0)
        })?;
        Ok(total as usize)
    }

    /// Returns every wing together with its drawer count, largest count first.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be opened, the statement
    /// cannot be prepared, or any row cannot be decoded (rows are never
    /// silently skipped).
    pub fn list_wings(&self) -> anyhow::Result<Vec<(String, usize)>> {
        let conn = open_conn(&self.db_path)?;
        let mut stmt = conn.prepare("SELECT wing, COUNT(*) as cnt FROM drawers GROUP BY wing ORDER BY cnt DESC")?;
        let wings = stmt
            .query_map([], |row| {
                let wing: String = row.get(0)?;
                let drawers: i64 = row.get(1)?;
                // COUNT(*) is never negative, so the cast cannot change sign.
                Ok((wing, drawers as usize))
            })?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(wings)
    }

    /// Returns `(wing, room, drawer_count)` triples, largest count first.
    ///
    /// * `wing` - when `Some`, only rooms of that wing are listed; `None`
    ///   lists the rooms of every wing.
    ///
    /// # Errors
    /// Returns an error if the connection cannot be opened, the statement
    /// cannot be prepared, or any row cannot be decoded (rows are never
    /// silently skipped).
    pub fn list_rooms(&self, wing: Option<&str>) -> anyhow::Result<Vec<(String, String, usize)>> {
        let conn = open_conn(&self.db_path)?;
        // A `None` wing binds NULL, which switches the predicate off, so the
        // filtered and unfiltered listings share one statement and one mapper.
        let mut stmt = conn.prepare(
            "SELECT wing, room, COUNT(*) as cnt FROM drawers \
             WHERE ?1 IS NULL OR wing=?1 \
             GROUP BY wing, room ORDER BY cnt DESC",
        )?;
        let rooms = stmt
            .query_map(params![wing], |row| {
                let wing: String = row.get(0)?;
                let room: String = row.get(1)?;
                let drawers: i64 = row.get(2)?;
                // COUNT(*) is never negative, so the cast cannot change sign.
                Ok((wing, room, drawers as usize))
            })?
            .collect::<Result<Vec<_>, _>>()?;
        Ok(rooms)
    }

    /// Builds a nested JSON object of the form
    /// `{ wing: { room: drawer_count, ... }, ... }` from `list_rooms(None)`.
    pub fn get_taxonomy(&self) -> anyhow::Result<serde_json::Value> {
        let mut taxonomy = serde_json::Map::new();
        for (wing, room, count) in self.list_rooms(None)? {
            let slot = taxonomy
                .entry(wing)
                .or_insert_with(|| serde_json::Value::Object(serde_json::Map::new()));
            // Entries are only ever created as objects above, so this branch
            // is always taken; the check just avoids an unwrap.
            if let Some(rooms_of_wing) = slot.as_object_mut() {
                rooms_of_wing.insert(room, serde_json::Value::from(count));
            }
        }
        Ok(serde_json::Value::Object(taxonomy))
    }
}

/// Opens a SQLite connection to `path` and applies the per-connection
/// pragmas (WAL journal mode, foreign-key enforcement) before handing it back.
fn open_conn(path: &Path) -> anyhow::Result<Connection> {
    const SESSION_PRAGMAS: &str = "PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;";

    let connection = Connection::open(path)?;
    connection.execute_batch(SESSION_PRAGMAS)?;
    Ok(connection)
}

/// Decodes the first fifteen columns of `row` into a `Drawer`.
///
/// Columns are read by position, so the query must select them in the order
/// `id, content, wing, room, source_file, source_mtime, chunk_index, added_by,
/// filed_at, hall, topic, drawer_type, agent, date, importance`.
fn row_to_drawer(row: &rusqlite::Row<'_>) -> rusqlite::Result<Drawer> {
    let id = row.get(0)?;
    let content = row.get(1)?;
    let wing = row.get(2)?;
    let room = row.get(3)?;
    let source_file = row.get(4)?;
    let source_mtime = row.get(5)?;
    let chunk_index = row.get(6)?;
    let added_by = row.get(7)?;
    let filed_at = row.get(8)?;
    let hall = row.get(9)?;
    let topic = row.get(10)?;
    let drawer_type = row.get(11)?;
    let agent = row.get(12)?;
    let date = row.get(13)?;
    let importance = row.get(14)?;

    Ok(Drawer {
        id,
        content,
        wing,
        room,
        source_file,
        source_mtime,
        chunk_index,
        added_by,
        filed_at,
        hall,
        topic,
        drawer_type,
        agent,
        date,
        importance,
    })
}

/// Cosine similarity of `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` (never NaN) when the vectors cannot be meaningfully compared:
/// * the slices differ in length,
/// * either vector is empty or has zero magnitude,
/// * a NaN or infinite component makes the result undefined.
///
/// Sums are accumulated in `f64` so that very large or very small `f32`
/// components do not overflow or underflow the intermediate squares.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Vectors of different dimension live in different spaces; pairing a
    // truncated dot product with full-length norms would yield a bogus score.
    if a.len() != b.len() {
        return 0.0;
    }

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }

    let cosine = dot / (sq_a.sqrt() * sq_b.sqrt());
    if cosine.is_nan() {
        // Only reachable with NaN/infinite components.
        return 0.0;
    }
    // Rounding can push the ratio marginally outside the valid range.
    cosine.clamp(-1.0, 1.0) as f32
}