use crate::{MapStorageErr, Storage};
use codemem_core::{CodememError, Edge, GraphNode, NodeKind, RelationshipType};
use rusqlite::{params, OptionalExtension};
use std::collections::HashMap;
/// Raw, undecoded column values of one `graph_edges` row, in the order the
/// edge queries in this file `SELECT` them:
/// `(id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)`.
///
/// `relationship` and `properties` are still strings (a `RelationshipType`
/// name and a JSON object respectively). The three integers are Unix
/// timestamps in whole seconds, as written by `insert_graph_edge`.
/// `edge_from_row` turns this into an `Edge`.
pub(crate) type EdgeRow = (
String,
String,
String,
String,
f64,
String,
i64,
Option<i64>,
Option<i64>,
);
/// Decodes a raw [`EdgeRow`] into an [`Edge`].
///
/// Returns `None`, after logging a warning, when:
/// - the stored relationship string does not parse as a [`RelationshipType`], or
/// - `created_at` is outside the range `chrono` can represent.
///
/// Malformed `properties` JSON is logged and replaced by an empty map.
/// Out-of-range `valid_from` / `valid_to` timestamps decode to `None`.
pub(crate) fn edge_from_row(row: EdgeRow) -> Option<Edge> {
    let (id, src, dst, rel_str, weight, props_str, created_ts, valid_from_ts, valid_to_ts) = row;

    let relationship: RelationshipType = match rel_str.parse() {
        Ok(r) => r,
        Err(_) => {
            tracing::warn!(
                edge_id = %id,
                relationship = %rel_str,
                "Dropping edge with unrecognized relationship type"
            );
            return None;
        }
    };

    // Corrupt property JSON should not cost us the whole edge, but it should
    // not disappear silently either.
    let properties: HashMap<String, serde_json::Value> = match serde_json::from_str(&props_str) {
        Ok(props) => props,
        Err(e) => {
            tracing::warn!(
                edge_id = %id,
                error = %e,
                "Ignoring malformed edge properties JSON"
            );
            HashMap::new()
        }
    };

    // `DateTime::from_timestamp` already yields a UTC timestamp, so no
    // timezone conversion is needed. An unrepresentable creation time means
    // the row is corrupt: drop the edge, but say so.
    let created_at = match chrono::DateTime::from_timestamp(created_ts, 0) {
        Some(dt) => dt,
        None => {
            tracing::warn!(
                edge_id = %id,
                created_at = created_ts,
                "Dropping edge with out-of-range created_at timestamp"
            );
            return None;
        }
    };
    let valid_from = valid_from_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0));
    let valid_to = valid_to_ts.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0));

    Some(Edge {
        id,
        src,
        dst,
        relationship,
        weight,
        properties,
        created_at,
        valid_from,
        valid_to,
    })
}
/// Reads the nine edge columns of `row` into an [`EdgeRow`] without decoding them.
///
/// Expects the column order used by the edge `SELECT`s in this file:
/// `id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to`.
pub(crate) fn extract_edge_tuple(row: &rusqlite::Row<'_>) -> rusqlite::Result<EdgeRow> {
    // The string/timestamp columns are read first, followed by the
    // identifiers and the weight.
    let relationship: String = row.get(3)?;
    let properties: String = row.get(5)?;
    let created_at: i64 = row.get(6)?;
    let valid_from: Option<i64> = row.get(7)?;
    let valid_to: Option<i64> = row.get(8)?;
    let id: String = row.get(0)?;
    let src: String = row.get(1)?;
    let dst: String = row.get(2)?;
    let weight: f64 = row.get(4)?;

    Ok((
        id,
        src,
        dst,
        relationship,
        weight,
        properties,
        created_at,
        valid_from,
        valid_to,
    ))
}
impl Storage {
    /// Stores the embedding vector for `memory_id`, replacing any existing row
    /// (`INSERT OR REPLACE`).
    ///
    /// The vector is encoded as a flat blob of little-endian `f32` values. This
    /// is the format [`Storage::get_embedding`] decodes.
    ///
    /// # Errors
    /// Returns an error if no connection is available or the write fails.
    pub fn store_embedding(&self, memory_id: &str, embedding: &[f32]) -> Result<(), CodememError> {
        let conn = self.conn()?;
        let blob: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
        conn.execute(
            "INSERT OR REPLACE INTO memory_embeddings (memory_id, embedding) VALUES (?1, ?2)",
            params![memory_id, blob],
        )
        .storage_err()?;
        Ok(())
    }

    /// Loads the embedding vector stored for `memory_id`.
    ///
    /// Returns `Ok(None)` when no embedding row exists. A blob whose length is
    /// not a multiple of four bytes is logged as corrupt, and its trailing
    /// bytes are ignored.
    ///
    /// # Errors
    /// Returns an error if no connection is available or the query fails.
    pub fn get_embedding(&self, memory_id: &str) -> Result<Option<Vec<f32>>, CodememError> {
        let conn = self.conn()?;
        let blob: Option<Vec<u8>> = conn
            .query_row(
                "SELECT embedding FROM memory_embeddings WHERE memory_id = ?1",
                params![memory_id],
                |row| row.get(0),
            )
            .optional()
            .storage_err()?;
        Ok(blob.map(|bytes| decode_embedding(memory_id, &bytes)))
    }

    /// Writes `node` to `graph_nodes` using `INSERT OR REPLACE`.
    ///
    /// The payload is serialized as JSON. `valid_from` / `valid_to` are stored
    /// as whole Unix seconds, so sub-second precision is not persisted.
    ///
    /// NOTE(review): `INSERT OR REPLACE` deletes the conflicting row before
    /// inserting. If `graph_edges` has `ON DELETE CASCADE` foreign keys to this
    /// table, re-inserting a node could drop its edges. Confirm this against
    /// the schema.
    ///
    /// # Errors
    /// Returns an error if the payload cannot be serialized, no connection is
    /// available, or the write fails.
    pub fn insert_graph_node(&self, node: &GraphNode) -> Result<(), CodememError> {
        let conn = self.conn()?;
        let payload_json = serde_json::to_string(&node.payload)?;
        conn.execute(
            "INSERT OR REPLACE INTO graph_nodes (id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
            params![
                node.id,
                node.kind.to_string(),
                node.label,
                payload_json,
                node.centrality,
                node.memory_id,
                node.namespace,
                node.valid_from.map(|dt| dt.timestamp()),
                node.valid_to.map(|dt| dt.timestamp()),
            ],
        )
        .storage_err()?;
        Ok(())
    }

    /// Fetches the graph node with the given `id`, or `Ok(None)` if absent.
    ///
    /// # Errors
    /// Returns an error if the query fails. Returns [`CodememError::Storage`]
    /// if the stored `kind` does not parse as a [`NodeKind`]. Unlike
    /// [`Storage::all_graph_nodes`], an explicit lookup does not skip such rows.
    pub fn get_graph_node(&self, id: &str) -> Result<Option<GraphNode>, CodememError> {
        let conn = self.conn()?;
        let raw = conn
            .query_row(
                "SELECT id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to FROM graph_nodes WHERE id = ?1",
                params![id],
                GraphNodeRow::from_row,
            )
            .optional()
            .storage_err()?;
        raw.map(|raw| {
            let kind: NodeKind = raw
                .kind
                .parse()
                .map_err(|e: CodememError| CodememError::Storage(e.to_string()))?;
            Ok(raw.into_node(kind))
        })
        .transpose()
    }

    /// Deletes the node with `id` from `graph_nodes`.
    ///
    /// Returns `true` if a row was deleted. This statement does not touch
    /// `graph_edges`; see [`Storage::delete_graph_edges_for_node`].
    ///
    /// # Errors
    /// Returns an error if no connection is available or the delete fails.
    pub fn delete_graph_node(&self, id: &str) -> Result<bool, CodememError> {
        let conn = self.conn()?;
        let rows = conn
            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
            .storage_err()?;
        Ok(rows > 0)
    }

    /// Loads every row of `graph_nodes`.
    ///
    /// Rows whose `kind` does not parse as a [`NodeKind`] are logged and
    /// skipped.
    ///
    /// # Errors
    /// Returns an error if the query fails or any row cannot be read.
    pub fn all_graph_nodes(&self) -> Result<Vec<GraphNode>, CodememError> {
        let conn = self.conn()?;
        let mut stmt = conn
            .prepare("SELECT id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to FROM graph_nodes")
            .storage_err()?;
        let rows = stmt
            .query_map([], GraphNodeRow::from_row)
            .storage_err()?;

        let mut nodes = Vec::new();
        for row_result in rows {
            let raw = row_result.storage_err()?;
            let kind: NodeKind = match raw.kind.parse() {
                Ok(k) => k,
                Err(_) => {
                    tracing::warn!(
                        node_id = %raw.id,
                        kind = %raw.kind,
                        "Skipping graph node with unrecognized kind"
                    );
                    continue;
                }
            };
            nodes.push(raw.into_node(kind));
        }
        Ok(nodes)
    }

    /// Writes `edge` to `graph_edges` using `INSERT OR REPLACE`.
    ///
    /// Properties are serialized as JSON. Timestamps are stored as whole Unix
    /// seconds.
    ///
    /// # Errors
    /// Returns an error if the properties cannot be serialized, no connection
    /// is available, or the write fails.
    pub fn insert_graph_edge(&self, edge: &Edge) -> Result<(), CodememError> {
        let conn = self.conn()?;
        let props_json = serde_json::to_string(&edge.properties)?;
        conn.execute(
            "INSERT OR REPLACE INTO graph_edges (id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)",
            params![
                edge.id,
                edge.src,
                edge.dst,
                edge.relationship.to_string(),
                edge.weight,
                props_json,
                edge.created_at.timestamp(),
                edge.valid_from.map(|dt| dt.timestamp()),
                edge.valid_to.map(|dt| dt.timestamp()),
            ],
        )
        .storage_err()?;
        Ok(())
    }

    /// Returns every edge whose `src` or `dst` is `node_id`.
    ///
    /// Rows that cannot be read or decoded are logged and skipped.
    ///
    /// # Errors
    /// Returns an error if the statement cannot be prepared or executed.
    pub fn get_edges_for_node(&self, node_id: &str) -> Result<Vec<Edge>, CodememError> {
        let conn = self.conn()?;
        let mut stmt = conn
            .prepare(
                "SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges WHERE src = ?1 OR dst = ?1",
            )
            .storage_err()?;
        let rows = stmt
            .query_map(params![node_id], extract_edge_tuple)
            .storage_err()?;
        let edges = collect_edges(rows);
        Ok(edges)
    }

    /// Returns every row of `graph_edges`.
    ///
    /// Rows that cannot be read or decoded are logged and skipped.
    ///
    /// # Errors
    /// Returns an error if the statement cannot be prepared or executed.
    pub fn all_graph_edges(&self) -> Result<Vec<Edge>, CodememError> {
        let conn = self.conn()?;
        let mut stmt = conn
            .prepare("SELECT id, src, dst, relationship, weight, properties, created_at, valid_from, valid_to FROM graph_edges")
            .storage_err()?;
        let rows = stmt.query_map([], extract_edge_tuple).storage_err()?;
        let edges = collect_edges(rows);
        Ok(edges)
    }

    /// Deletes the edge with `edge_id`. Returns `true` if a row was deleted.
    ///
    /// # Errors
    /// Returns an error if no connection is available or the delete fails.
    pub fn delete_graph_edge(&self, edge_id: &str) -> Result<bool, CodememError> {
        let conn = self.conn()?;
        let rows = conn
            .execute("DELETE FROM graph_edges WHERE id = ?1", params![edge_id])
            .storage_err()?;
        Ok(rows > 0)
    }

    /// Deletes every edge whose `src` or `dst` is `node_id`.
    ///
    /// Returns the number of deleted rows.
    ///
    /// # Errors
    /// Returns an error if no connection is available or the delete fails.
    pub fn delete_graph_edges_for_node(&self, node_id: &str) -> Result<usize, CodememError> {
        let conn = self.conn()?;
        let rows = conn
            .execute(
                "DELETE FROM graph_edges WHERE src = ?1 OR dst = ?1",
                params![node_id],
            )
            .storage_err()?;
        Ok(rows)
    }

    /// Returns edges whose source and destination nodes are both in `namespace`.
    ///
    /// This is equivalent to
    /// [`Storage::graph_edges_for_namespace_with_cross`] with
    /// `include_cross_namespace = false`.
    pub fn graph_edges_for_namespace(&self, namespace: &str) -> Result<Vec<Edge>, CodememError> {
        self.graph_edges_for_namespace_with_cross(namespace, false)
    }

    /// Returns edges attached to nodes in `namespace`.
    ///
    /// With `include_cross_namespace = false`, both endpoints must be in
    /// `namespace`. With `true`, at least one endpoint must be. Both endpoints
    /// are `INNER JOIN`ed against `graph_nodes`, so edges referencing a missing
    /// node are never returned. Rows that cannot be read or decoded are logged
    /// and skipped.
    ///
    /// # Errors
    /// Returns an error if the statement cannot be prepared or executed.
    pub fn graph_edges_for_namespace_with_cross(
        &self,
        namespace: &str,
        include_cross_namespace: bool,
    ) -> Result<Vec<Edge>, CodememError> {
        let conn = self.conn()?;
        // Only one of two fixed strings is interpolated; the namespace itself
        // is always bound as a parameter.
        let condition = if include_cross_namespace {
            "gs.namespace = ?1 OR gd.namespace = ?1"
        } else {
            "gs.namespace = ?1 AND gd.namespace = ?1"
        };
        let sql = format!(
            "SELECT e.id, e.src, e.dst, e.relationship, e.weight, e.properties, e.created_at, e.valid_from, e.valid_to
             FROM graph_edges e
             INNER JOIN graph_nodes gs ON e.src = gs.id
             INNER JOIN graph_nodes gd ON e.dst = gd.id
             WHERE {condition}"
        );
        let mut stmt = conn.prepare(&sql).storage_err()?;
        let rows = stmt
            .query_map(params![namespace], extract_edge_tuple)
            .storage_err()?;
        let edges = collect_edges(rows);
        Ok(edges)
    }

    /// Records `root_path` as the root of `namespace`.
    ///
    /// An existing entry is updated in place, and `updated_at` is set to
    /// SQLite's `datetime('now')`.
    ///
    /// # Errors
    /// Returns an error if no connection is available or the write fails.
    pub fn set_namespace_root(&self, namespace: &str, root_path: &str) -> Result<(), CodememError> {
        let conn = self.conn()?;
        conn.execute(
            "INSERT INTO namespace_roots (namespace, root_path, updated_at)
             VALUES (?1, ?2, datetime('now'))
             ON CONFLICT(namespace) DO UPDATE SET root_path = ?2, updated_at = datetime('now')",
            params![namespace, root_path],
        )
        .storage_err()?;
        Ok(())
    }

    /// Returns the root path recorded for `namespace`, or `Ok(None)` if none
    /// is stored.
    ///
    /// # Errors
    /// Returns an error if no connection is available or the query fails.
    pub fn get_namespace_root(&self, namespace: &str) -> Result<Option<String>, CodememError> {
        let conn = self.conn()?;
        conn.query_row(
            "SELECT root_path FROM namespace_roots WHERE namespace = ?1",
            params![namespace],
            |row| row.get(0),
        )
        .optional()
        .storage_err()
    }
}

/// Decodes a little-endian `f32` blob written by [`Storage::store_embedding`].
///
/// A blob whose length is not a multiple of four is corrupt. The trailing
/// bytes are logged and ignored, so callers still get the decodable prefix.
fn decode_embedding(memory_id: &str, bytes: &[u8]) -> Vec<f32> {
    let chunks = bytes.chunks_exact(4);
    if !chunks.remainder().is_empty() {
        tracing::warn!(
            memory_id = %memory_id,
            blob_len = bytes.len(),
            "Embedding blob length is not a multiple of 4; ignoring trailing bytes"
        );
    }
    chunks
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect()
}

/// Drains edge query results into decoded [`Edge`]s.
///
/// Rows that fail to read are logged and skipped, as are rows that
/// `edge_from_row` rejects. One bad row never fails the whole query.
fn collect_edges<I>(rows: I) -> Vec<Edge>
where
    I: Iterator<Item = rusqlite::Result<EdgeRow>>,
{
    rows.filter_map(|r| match r {
        Ok(v) => Some(v),
        Err(e) => {
            tracing::warn!("Failed to process edge row: {e}");
            None
        }
    })
    .filter_map(edge_from_row)
    .collect()
}

/// Raw column values of one `graph_nodes` row, before `kind` and `payload`
/// are parsed.
struct GraphNodeRow {
    id: String,
    kind: String,
    label: String,
    payload: String,
    centrality: f64,
    memory_id: Option<String>,
    namespace: Option<String>,
    valid_from: Option<i64>,
    valid_to: Option<i64>,
}

impl GraphNodeRow {
    /// Reads a row selected with the column order
    /// `id, kind, label, payload, centrality, memory_id, namespace, valid_from, valid_to`.
    fn from_row(row: &rusqlite::Row<'_>) -> rusqlite::Result<Self> {
        Ok(Self {
            id: row.get(0)?,
            kind: row.get(1)?,
            label: row.get(2)?,
            payload: row.get(3)?,
            centrality: row.get(4)?,
            memory_id: row.get(5)?,
            namespace: row.get(6)?,
            valid_from: row.get(7)?,
            valid_to: row.get(8)?,
        })
    }

    /// Builds a [`GraphNode`] using an already-parsed `kind`.
    ///
    /// Malformed payload JSON is logged and replaced by an empty map.
    /// Out-of-range timestamps decode to `None`.
    fn into_node(self, kind: NodeKind) -> GraphNode {
        let Self {
            id,
            label,
            payload,
            centrality,
            memory_id,
            namespace,
            valid_from,
            valid_to,
            ..
        } = self;
        let payload: HashMap<String, serde_json::Value> = match serde_json::from_str(&payload) {
            Ok(p) => p,
            Err(e) => {
                tracing::warn!(
                    node_id = %id,
                    error = %e,
                    "Ignoring malformed graph node payload JSON"
                );
                HashMap::new()
            }
        };
        GraphNode {
            id,
            kind,
            label,
            payload,
            centrality,
            memory_id,
            namespace,
            valid_from: valid_from.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
            valid_to: valid_to.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0)),
        }
    }
}
#[cfg(test)]
#[path = "tests/graph_persistence_tests.rs"]
mod tests;