1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::error::{Error, Result};
5use crate::store::Store;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(rename_all = "snake_case")]
9pub enum RelationType {
10 DerivedFrom,
11 Supersedes,
12 ConflictsWith,
13 RelatedTo,
14}
15
16impl std::fmt::Display for RelationType {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 Self::DerivedFrom => write!(f, "derived_from"),
20 Self::Supersedes => write!(f, "supersedes"),
21 Self::ConflictsWith => write!(f, "conflicts_with"),
22 Self::RelatedTo => write!(f, "related_to"),
23 }
24 }
25}
26
27impl std::str::FromStr for RelationType {
28 type Err = Error;
29 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
30 match s {
31 "derived_from" => Ok(Self::DerivedFrom),
32 "supersedes" => Ok(Self::Supersedes),
33 "conflicts_with" => Ok(Self::ConflictsWith),
34 "related_to" => Ok(Self::RelatedTo),
35 other => Err(Error::InvalidInput(format!("unknown relation type: {other}"))),
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize)]
41pub struct Relation {
42 pub id: i64,
43 pub source_id: i64,
44 pub target_id: i64,
45 pub relation_type: RelationType,
46 pub created_at: String,
47}
48
49impl Store {
50 pub fn add_relation(
51 &self,
52 source_id: i64,
53 target_id: i64,
54 rel_type: RelationType,
55 ) -> Result<i64> {
56 if source_id == target_id {
57 return Err(Error::InvalidInput(
58 "cannot create self-referential relation".to_string(),
59 ));
60 }
61 self.get(source_id)?;
63 self.get(target_id)?;
64
65 self.conn().execute(
66 "INSERT OR IGNORE INTO relations (source_id, target_id, relation_type) VALUES (?1, ?2, ?3)",
67 params![source_id, target_id, rel_type.to_string()],
68 )?;
69 Ok(self.conn().last_insert_rowid())
70 }
71
72 pub fn get_relations(&self, memory_id: i64) -> Result<Vec<Relation>> {
73 let mut stmt = self.conn().prepare(
74 "SELECT id, source_id, target_id, relation_type, created_at
75 FROM relations WHERE source_id = ?1 OR target_id = ?1",
76 )?;
77 let results = stmt
78 .query_map(params![memory_id], |row| {
79 let rt: String = row.get(3)?;
80 Ok(Relation {
81 id: row.get(0)?,
82 source_id: row.get(1)?,
83 target_id: row.get(2)?,
84 relation_type: rt.parse().unwrap_or(RelationType::RelatedTo),
85 created_at: row.get(4)?,
86 })
87 })?
88 .collect::<std::result::Result<Vec<_>, _>>()?;
89 Ok(results)
90 }
91
92 pub fn superseded_ids(&self) -> Result<std::collections::HashSet<i64>> {
93 let mut stmt = self
94 .conn()
95 .prepare("SELECT DISTINCT target_id FROM relations WHERE relation_type = 'supersedes'")?;
96 let ids: Vec<i64> = stmt
97 .query_map([], |row| row.get(0))?
98 .collect::<std::result::Result<Vec<_>, _>>()?;
99 Ok(ids.into_iter().collect())
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a throwaway in-memory store for a single test.
    fn fresh_store() -> Store {
        Store::open_in_memory().unwrap()
    }

    /// Saves an explicit memory under the `/test` scope and returns its id.
    fn make_memory(store: &Store, key: &str, value: &str) -> i64 {
        use crate::types::{SaveParams, SourceType};

        let request = SaveParams {
            key: key.to_string(),
            value: value.to_string(),
            scope: Some("/test".to_string()),
            source_type: Some(SourceType::Explicit),
            tags: None,
            session_id: None,
            source_ref: None,
            source_commit: None,
        };
        store.save(request).unwrap().id()
    }

    #[test]
    fn add_and_get_relation() {
        let store = fresh_store();
        let a = make_memory(&store, "mem/a", "value a");
        let b = make_memory(&store, "mem/b", "value b");

        let relation_id = store.add_relation(a, b, RelationType::RelatedTo).unwrap();
        assert!(relation_id > 0);

        let found = store.get_relations(a).unwrap();
        assert_eq!(found.len(), 1);
        let relation = &found[0];
        assert_eq!(relation.source_id, a);
        assert_eq!(relation.target_id, b);
        assert_eq!(relation.relation_type, RelationType::RelatedTo);
    }

    #[test]
    fn get_relations_returns_both_directions() {
        let store = fresh_store();
        let a = make_memory(&store, "mem/a", "value a");
        let b = make_memory(&store, "mem/b", "value b");
        let c = make_memory(&store, "mem/c", "value c");

        // `a` is the source of the first relation and the target of the second.
        store.add_relation(a, b, RelationType::DerivedFrom).unwrap();
        store.add_relation(c, a, RelationType::Supersedes).unwrap();

        assert_eq!(store.get_relations(a).unwrap().len(), 2);
    }

    #[test]
    fn self_referential_relation_rejected() {
        let store = fresh_store();
        let a = make_memory(&store, "mem/a", "value a");

        let outcome = store.add_relation(a, a, RelationType::RelatedTo);
        assert!(matches!(outcome, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn duplicate_relation_is_ignored() {
        let store = fresh_store();
        let a = make_memory(&store, "mem/a", "value a");
        let b = make_memory(&store, "mem/b", "value b");

        // Inserting the identical relation twice must not produce a second row.
        for _ in 0..2 {
            store.add_relation(a, b, RelationType::RelatedTo).unwrap();
        }

        assert_eq!(store.get_relations(a).unwrap().len(), 1);
    }

    #[test]
    fn superseded_ids_returns_correct_set() {
        let store = fresh_store();
        let replacement = make_memory(&store, "mem/a", "value a");
        let replaced_b = make_memory(&store, "mem/b", "value b");
        let replaced_c = make_memory(&store, "mem/c", "value c");

        for target in [replaced_b, replaced_c] {
            store
                .add_relation(replacement, target, RelationType::Supersedes)
                .unwrap();
        }

        let superseded = store.superseded_ids().unwrap();
        assert!(superseded.contains(&replaced_b));
        assert!(superseded.contains(&replaced_c));
        assert!(!superseded.contains(&replacement));
    }

    #[test]
    fn relation_type_roundtrip() {
        let cases = [
            ("derived_from", RelationType::DerivedFrom),
            ("supersedes", RelationType::Supersedes),
            ("conflicts_with", RelationType::ConflictsWith),
            ("related_to", RelationType::RelatedTo),
        ];
        for (text, variant) in cases {
            let parsed = text.parse::<RelationType>().unwrap();
            assert_eq!(parsed, variant);
            assert_eq!(parsed.to_string(), text);
        }
        assert!("unknown".parse::<RelationType>().is_err());
    }
}
223}