Skip to main content

memory_core/store/
relations.rs

1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::error::{Error, Result};
5use crate::store::Store;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(rename_all = "snake_case")]
9pub enum RelationType {
10    DerivedFrom,
11    Supersedes,
12    ConflictsWith,
13    RelatedTo,
14}
15
16impl std::fmt::Display for RelationType {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            Self::DerivedFrom => write!(f, "derived_from"),
20            Self::Supersedes => write!(f, "supersedes"),
21            Self::ConflictsWith => write!(f, "conflicts_with"),
22            Self::RelatedTo => write!(f, "related_to"),
23        }
24    }
25}
26
27impl std::str::FromStr for RelationType {
28    type Err = Error;
29    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
30        match s {
31            "derived_from" => Ok(Self::DerivedFrom),
32            "supersedes" => Ok(Self::Supersedes),
33            "conflicts_with" => Ok(Self::ConflictsWith),
34            "related_to" => Ok(Self::RelatedTo),
35            other => Err(Error::InvalidInput(format!(
36                "unknown relation type: {other}"
37            ))),
38        }
39    }
40}
41
/// One row of the `relations` table: a directed, typed edge from the memory
/// `source_id` to the memory `target_id`.
///
/// Field order here is the field order of the derived `Serialize` output.
#[derive(Debug, Clone, Serialize)]
pub struct Relation {
    /// Row id of the relation (`relations.id`).
    pub id: i64,
    /// Id of the memory the relation points from.
    pub source_id: i64,
    /// Id of the memory the relation points to.
    pub target_id: i64,
    /// Kind of the relation.
    pub relation_type: RelationType,
    /// Value of the `created_at` column, read as text. NOTE(review): the
    /// format is whatever the schema default produces (presumably SQLite
    /// `CURRENT_TIMESTAMP`) — TODO confirm.
    pub created_at: String,
}
50
51impl Store {
52    pub fn add_relation(
53        &self,
54        source_id: i64,
55        target_id: i64,
56        rel_type: RelationType,
57    ) -> Result<i64> {
58        if source_id == target_id {
59            return Err(Error::InvalidInput(
60                "cannot create self-referential relation".to_string(),
61            ));
62        }
63        // Verify both memories exist
64        self.get(source_id)?;
65        self.get(target_id)?;
66
67        self.conn().execute(
68            "INSERT OR IGNORE INTO relations (source_id, target_id, relation_type) VALUES (?1, ?2, ?3)",
69            params![source_id, target_id, rel_type.to_string()],
70        )?;
71        Ok(self.conn().last_insert_rowid())
72    }
73
74    pub fn get_relations(&self, memory_id: i64) -> Result<Vec<Relation>> {
75        let mut stmt = self.conn().prepare(
76            "SELECT id, source_id, target_id, relation_type, created_at
77             FROM relations WHERE source_id = ?1 OR target_id = ?1",
78        )?;
79        let results = stmt
80            .query_map(params![memory_id], |row| {
81                let rt: String = row.get(3)?;
82                Ok(Relation {
83                    id: row.get(0)?,
84                    source_id: row.get(1)?,
85                    target_id: row.get(2)?,
86                    relation_type: rt.parse().unwrap_or(RelationType::RelatedTo),
87                    created_at: row.get(4)?,
88                })
89            })?
90            .collect::<std::result::Result<Vec<_>, _>>()?;
91        Ok(results)
92    }
93
94    pub fn superseded_ids(&self) -> Result<std::collections::HashSet<i64>> {
95        let mut stmt = self.conn().prepare(
96            "SELECT DISTINCT target_id FROM relations WHERE relation_type = 'supersedes'",
97        )?;
98        let ids: Vec<i64> = stmt
99            .query_map([], |row| row.get(0))?
100            .collect::<std::result::Result<Vec<_>, _>>()?;
101        Ok(ids.into_iter().collect())
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;

    /// Saves a memory under scope `/test` with the given key and value and
    /// returns its id. Panics if the save fails.
    fn make_memory(store: &Store, key: &str, value: &str) -> i64 {
        use crate::types::SaveParams;
        store
            .save(SaveParams {
                key: key.to_string(),
                value: value.to_string(),
                scope: Some("/test".to_string()),
                source_type: Some(crate::types::SourceType::Explicit),
                tags: None,
                source_ref: None,
                source_commit: None,
            })
            .unwrap()
            .id()
    }

    /// A stored relation is returned with its endpoints and type intact.
    #[test]
    fn add_and_get_relation() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");

        let rel_id = store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();
        assert!(rel_id > 0);

        let relations = store.get_relations(id_a).unwrap();
        assert_eq!(relations.len(), 1);
        assert_eq!(relations[0].source_id, id_a);
        assert_eq!(relations[0].target_id, id_b);
        assert_eq!(relations[0].relation_type, RelationType::RelatedTo);
    }

    /// `get_relations` includes relations where the memory is the source
    /// (a -> b) as well as the target (c -> a).
    #[test]
    fn get_relations_returns_both_directions() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");
        let id_c = make_memory(&store, "mem/c", "value c");

        store
            .add_relation(id_a, id_b, RelationType::DerivedFrom)
            .unwrap();
        store
            .add_relation(id_c, id_a, RelationType::Supersedes)
            .unwrap();

        let relations = store.get_relations(id_a).unwrap();
        assert_eq!(relations.len(), 2);
    }

    /// A memory cannot be related to itself.
    #[test]
    fn self_referential_relation_rejected() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");

        let err = store
            .add_relation(id_a, id_a, RelationType::RelatedTo)
            .unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    /// Adding the same relation twice succeeds but stores only one row.
    #[test]
    fn duplicate_relation_is_ignored() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");

        store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();
        // The identical second insert is dropped by INSERT OR IGNORE. SQLite
        // leaves last_insert_rowid() unchanged for an ignored insert (it does
        // not become 0); this test checks only that no second row is stored.
        store
            .add_relation(id_a, id_b, RelationType::RelatedTo)
            .unwrap();

        let relations = store.get_relations(id_a).unwrap();
        assert_eq!(relations.len(), 1);
    }

    /// Only targets of `supersedes` relations are reported, not sources.
    #[test]
    fn superseded_ids_returns_correct_set() {
        let store = Store::open_in_memory().unwrap();
        let id_a = make_memory(&store, "mem/a", "value a");
        let id_b = make_memory(&store, "mem/b", "value b");
        let id_c = make_memory(&store, "mem/c", "value c");

        store
            .add_relation(id_a, id_b, RelationType::Supersedes)
            .unwrap();
        store
            .add_relation(id_a, id_c, RelationType::Supersedes)
            .unwrap();

        let superseded = store.superseded_ids().unwrap();
        assert!(superseded.contains(&id_b));
        assert!(superseded.contains(&id_c));
        assert!(!superseded.contains(&id_a));
    }

    /// `FromStr` and `Display` are inverses for every variant, and unknown
    /// names are rejected.
    #[test]
    fn relation_type_roundtrip() {
        for (s, expected) in [
            ("derived_from", RelationType::DerivedFrom),
            ("supersedes", RelationType::Supersedes),
            ("conflicts_with", RelationType::ConflictsWith),
            ("related_to", RelationType::RelatedTo),
        ] {
            let parsed: RelationType = s.parse().unwrap();
            assert_eq!(parsed, expected);
            assert_eq!(parsed.to_string(), s);
        }
        assert!("unknown".parse::<RelationType>().is_err());
    }
}