// memory_core/store/relations.rs

1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::error::{Error, Result};
5use crate::store::Store;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(rename_all = "snake_case")]
9pub enum RelationType {
10    DerivedFrom,
11    Supersedes,
12    ConflictsWith,
13    RelatedTo,
14}
15
16impl std::fmt::Display for RelationType {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            Self::DerivedFrom => write!(f, "derived_from"),
20            Self::Supersedes => write!(f, "supersedes"),
21            Self::ConflictsWith => write!(f, "conflicts_with"),
22            Self::RelatedTo => write!(f, "related_to"),
23        }
24    }
25}
26
27impl std::str::FromStr for RelationType {
28    type Err = Error;
29    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
30        match s {
31            "derived_from" => Ok(Self::DerivedFrom),
32            "supersedes" => Ok(Self::Supersedes),
33            "conflicts_with" => Ok(Self::ConflictsWith),
34            "related_to" => Ok(Self::RelatedTo),
35            other => Err(Error::InvalidInput(format!("unknown relation type: {other}"))),
36        }
37    }
38}
39
40#[derive(Debug, Clone, Serialize)]
41pub struct Relation {
42    pub id: i64,
43    pub source_id: i64,
44    pub target_id: i64,
45    pub relation_type: RelationType,
46    pub created_at: String,
47}
48
49impl Store {
50    pub fn add_relation(
51        &self,
52        source_id: i64,
53        target_id: i64,
54        rel_type: RelationType,
55    ) -> Result<i64> {
56        if source_id == target_id {
57            return Err(Error::InvalidInput(
58                "cannot create self-referential relation".to_string(),
59            ));
60        }
61        // Verify both memories exist
62        self.get(source_id)?;
63        self.get(target_id)?;
64
65        self.conn().execute(
66            "INSERT OR IGNORE INTO relations (source_id, target_id, relation_type) VALUES (?1, ?2, ?3)",
67            params![source_id, target_id, rel_type.to_string()],
68        )?;
69        Ok(self.conn().last_insert_rowid())
70    }
71
72    pub fn get_relations(&self, memory_id: i64) -> Result<Vec<Relation>> {
73        let mut stmt = self.conn().prepare(
74            "SELECT id, source_id, target_id, relation_type, created_at
75             FROM relations WHERE source_id = ?1 OR target_id = ?1",
76        )?;
77        let results = stmt
78            .query_map(params![memory_id], |row| {
79                let rt: String = row.get(3)?;
80                Ok(Relation {
81                    id: row.get(0)?,
82                    source_id: row.get(1)?,
83                    target_id: row.get(2)?,
84                    relation_type: rt.parse().unwrap_or(RelationType::RelatedTo),
85                    created_at: row.get(4)?,
86                })
87            })?
88            .collect::<std::result::Result<Vec<_>, _>>()?;
89        Ok(results)
90    }
91
92    pub fn superseded_ids(&self) -> Result<std::collections::HashSet<i64>> {
93        let mut stmt = self
94            .conn()
95            .prepare("SELECT DISTINCT target_id FROM relations WHERE relation_type = 'supersedes'")?;
96        let ids: Vec<i64> = stmt
97            .query_map([], |row| row.get(0))?
98            .collect::<std::result::Result<Vec<_>, _>>()?;
99        Ok(ids.into_iter().collect())
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Saves an explicit memory with scope `/test` and returns its id.
    fn make_memory(store: &Store, key: &str, value: &str) -> i64 {
        use crate::types::SaveParams;
        store
            .save(SaveParams {
                key: key.to_string(),
                value: value.to_string(),
                scope: Some("/test".to_string()),
                source_type: Some(crate::types::SourceType::Explicit),
                tags: None,
                session_id: None,
                source_ref: None,
                source_commit: None,
            })
            .unwrap()
            .id()
    }

    #[test]
    fn add_and_get_relation() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");

        let rel_id = store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();
        assert!(rel_id > 0);

        let relations = store.get_relations(id_a).unwrap();
        assert_eq!(relations.len(), 1);
        assert_eq!(relations[0].source_id, id_a);
        assert_eq!(relations[0].target_id, id_b);
        assert_eq!(relations[0].relation_type, RelationType::RelatedTo);
    }

    #[test]
    fn get_relations_returns_both_directions() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");
        let id_c = make_memory(&store, "mem/c", "value c");

        store
            .add_relation(id_a, id_b, RelationType::DerivedFrom)
            .unwrap();
        store
            .add_relation(id_c, id_a, RelationType::Supersedes)
            .unwrap();

        let relations = store.get_relations(id_a).unwrap();
        assert_eq!(relations.len(), 2);
    }

    #[test]
    fn get_relations_for_unrelated_memory_is_empty() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");
        let id_c = make_memory(&store, "mem/c", "value c");

        store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();

        assert!(store.get_relations(id_c).unwrap().is_empty());
    }

    #[test]
    fn self_referential_relation_rejected() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");

        let err = store
            .add_relation(id_a, id_a, RelationType::RelatedTo)
            .unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    #[test]
    fn duplicate_relation_is_ignored() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");

        store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();
        // INSERT OR IGNORE skips the duplicate row instead of failing, so the
        // second call still succeeds and no second row is stored.
        store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();

        let relations = store.get_relations(id_a).unwrap();
        assert_eq!(relations.len(), 1);
    }

    #[test]
    fn superseded_ids_returns_correct_set() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");
        let id_c = make_memory(&store, "mem/c", "value c");

        store
            .add_relation(id_a, id_b, RelationType::Supersedes)
            .unwrap();
        store
            .add_relation(id_a, id_c, RelationType::Supersedes)
            .unwrap();

        let superseded = store.superseded_ids().unwrap();
        assert!(superseded.contains(&id_b));
        assert!(superseded.contains(&id_c));
        assert!(!superseded.contains(&id_a));
    }

    #[test]
    fn superseded_ids_ignores_other_relation_types() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");
        let id_c = make_memory(&store, "mem/c", "value c");

        store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();
        store
            .add_relation(id_a, id_c, RelationType::ConflictsWith)
            .unwrap();

        assert!(store.superseded_ids().unwrap().is_empty());
    }

    #[test]
    fn relation_type_roundtrip() {
        for (s, expected) in [
            ("derived_from", RelationType::DerivedFrom),
            ("supersedes", RelationType::Supersedes),
            ("conflicts_with", RelationType::ConflictsWith),
            ("related_to", RelationType::RelatedTo),
        ] {
            let parsed: RelationType = s.parse().unwrap();
            assert_eq!(parsed, expected);
            assert_eq!(parsed.to_string(), s);
        }
        assert!("unknown".parse::<RelationType>().is_err());
    }
}