1use rusqlite::params;
2use serde::{Deserialize, Serialize};
3
4use crate::error::{Error, Result};
5use crate::store::Store;
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
8#[serde(rename_all = "snake_case")]
9pub enum RelationType {
10 DerivedFrom,
11 Supersedes,
12 ConflictsWith,
13 RelatedTo,
14}
15
16impl std::fmt::Display for RelationType {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 Self::DerivedFrom => write!(f, "derived_from"),
20 Self::Supersedes => write!(f, "supersedes"),
21 Self::ConflictsWith => write!(f, "conflicts_with"),
22 Self::RelatedTo => write!(f, "related_to"),
23 }
24 }
25}
26
27impl std::str::FromStr for RelationType {
28 type Err = Error;
29 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
30 match s {
31 "derived_from" => Ok(Self::DerivedFrom),
32 "supersedes" => Ok(Self::Supersedes),
33 "conflicts_with" => Ok(Self::ConflictsWith),
34 "related_to" => Ok(Self::RelatedTo),
35 other => Err(Error::InvalidInput(format!(
36 "unknown relation type: {other}"
37 ))),
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize)]
43pub struct Relation {
44 pub id: i64,
45 pub source_id: i64,
46 pub target_id: i64,
47 pub relation_type: RelationType,
48 pub created_at: String,
49}
50
51impl Store {
52 pub fn add_relation(
53 &self,
54 source_id: i64,
55 target_id: i64,
56 rel_type: RelationType,
57 ) -> Result<i64> {
58 if source_id == target_id {
59 return Err(Error::InvalidInput(
60 "cannot create self-referential relation".to_string(),
61 ));
62 }
63 self.get(source_id)?;
65 self.get(target_id)?;
66
67 self.conn().execute(
68 "INSERT OR IGNORE INTO relations (source_id, target_id, relation_type) VALUES (?1, ?2, ?3)",
69 params![source_id, target_id, rel_type.to_string()],
70 )?;
71 Ok(self.conn().last_insert_rowid())
72 }
73
74 pub fn get_relations(&self, memory_id: i64) -> Result<Vec<Relation>> {
75 let mut stmt = self.conn().prepare(
76 "SELECT id, source_id, target_id, relation_type, created_at
77 FROM relations WHERE source_id = ?1 OR target_id = ?1",
78 )?;
79 let results = stmt
80 .query_map(params![memory_id], |row| {
81 let rt: String = row.get(3)?;
82 Ok(Relation {
83 id: row.get(0)?,
84 source_id: row.get(1)?,
85 target_id: row.get(2)?,
86 relation_type: rt.parse().unwrap_or(RelationType::RelatedTo),
87 created_at: row.get(4)?,
88 })
89 })?
90 .collect::<std::result::Result<Vec<_>, _>>()?;
91 Ok(results)
92 }
93
94 pub fn superseded_ids(&self) -> Result<std::collections::HashSet<i64>> {
95 let mut stmt = self.conn().prepare(
96 "SELECT DISTINCT target_id FROM relations WHERE relation_type = 'supersedes'",
97 )?;
98 let ids: Vec<i64> = stmt
99 .query_map([], |row| row.get(0))?
100 .collect::<std::result::Result<Vec<_>, _>>()?;
101 Ok(ids.into_iter().collect())
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{SaveParams, SourceType};

    /// Saves an explicit memory under the `/test` scope and returns its id.
    fn make_memory(store: &Store, key: &str, value: &str) -> i64 {
        let params = SaveParams {
            key: key.to_string(),
            value: value.to_string(),
            scope: Some("/test".to_string()),
            source_type: Some(SourceType::Explicit),
            tags: None,
            source_ref: None,
            source_commit: None,
        };
        store.save(params).unwrap().id()
    }

    /// Opens a fresh in-memory store and creates one memory per name, in order,
    /// with key `mem/<name>` and value `value <name>`. Returns the store and the ids.
    fn seeded_store<const N: usize>(names: [&str; N]) -> (Store, [i64; N]) {
        let store = Store::open_in_memory().unwrap();
        let ids = names.map(|name| {
            make_memory(&store, &format!("mem/{name}"), &format!("value {name}"))
        });
        (store, ids)
    }

    #[test]
    fn add_and_get_relation() {
        let (store, [a, b]) = seeded_store(["a", "b"]);

        let rel_id = store.add_relation(a, b, RelationType::RelatedTo).unwrap();
        assert!(rel_id > 0);

        let found = store.get_relations(a).unwrap();
        assert_eq!(found.len(), 1);
        let rel = &found[0];
        assert_eq!(rel.source_id, a);
        assert_eq!(rel.target_id, b);
        assert_eq!(rel.relation_type, RelationType::RelatedTo);
    }

    #[test]
    fn get_relations_returns_both_directions() {
        let (store, [a, b, c]) = seeded_store(["a", "b", "c"]);

        store.add_relation(a, b, RelationType::DerivedFrom).unwrap();
        store.add_relation(c, a, RelationType::Supersedes).unwrap();

        assert_eq!(store.get_relations(a).unwrap().len(), 2);
    }

    #[test]
    fn self_referential_relation_rejected() {
        let (store, [a]) = seeded_store(["a"]);

        let outcome = store.add_relation(a, a, RelationType::RelatedTo);
        assert!(matches!(outcome, Err(Error::InvalidInput(_))));
    }

    #[test]
    fn duplicate_relation_is_ignored() {
        let (store, [a, b]) = seeded_store(["a", "b"]);

        for _ in 0..2 {
            store.add_relation(a, b, RelationType::RelatedTo).unwrap();
        }

        assert_eq!(store.get_relations(a).unwrap().len(), 1);
    }

    #[test]
    fn superseded_ids_returns_correct_set() {
        let (store, [a, b, c]) = seeded_store(["a", "b", "c"]);

        for target in [b, c] {
            store.add_relation(a, target, RelationType::Supersedes).unwrap();
        }

        let superseded = store.superseded_ids().unwrap();
        assert!(superseded.contains(&b));
        assert!(superseded.contains(&c));
        assert!(!superseded.contains(&a));
    }

    #[test]
    fn relation_type_roundtrip() {
        let cases = [
            ("derived_from", RelationType::DerivedFrom),
            ("supersedes", RelationType::Supersedes),
            ("conflicts_with", RelationType::ConflictsWith),
            ("related_to", RelationType::RelatedTo),
        ];
        for (text, variant) in cases {
            let parsed = text.parse::<RelationType>().unwrap();
            assert_eq!(parsed, variant);
            assert_eq!(parsed.to_string(), text);
        }
        assert!("unknown".parse::<RelationType>().is_err());
    }
}