1use crate::{Error, Result};
4use serde::{Deserialize, Serialize};
5use sqlx::SqlitePool;
6use std::collections::{HashMap, HashSet};
7use std::str::FromStr;
8use tracing::debug;
9
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
12#[serde(rename_all = "lowercase")]
13pub enum RelationType {
14 Uses,
16 Extends,
18 Conflicts,
20 Requires,
22}
23
24impl FromStr for RelationType {
25 type Err = Error;
26
27 fn from_str(s: &str) -> Result<Self> {
28 match s.to_lowercase().as_str() {
29 "uses" => Ok(RelationType::Uses),
30 "extends" => Ok(RelationType::Extends),
31 "conflicts" => Ok(RelationType::Conflicts),
32 "requires" => Ok(RelationType::Requires),
33 _ => Err(Error::InvalidRelationType(s.to_string())),
34 }
35 }
36}
37
38impl RelationType {
39 pub fn as_str(&self) -> &'static str {
41 match self {
42 RelationType::Uses => "uses",
43 RelationType::Extends => "extends",
44 RelationType::Conflicts => "conflicts",
45 RelationType::Requires => "requires",
46 }
47 }
48
49 pub fn all() -> &'static [RelationType] {
51 &[
52 RelationType::Uses,
53 RelationType::Extends,
54 RelationType::Conflicts,
55 RelationType::Requires,
56 ]
57 }
58}
59
60impl std::fmt::Display for RelationType {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 write!(f, "{}", self.as_str())
63 }
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
68pub struct Relation {
69 pub from_id: String,
70 pub to_id: String,
71 pub relation_type: RelationType,
72 pub metadata: Option<String>,
73 pub created_at: i64,
74}
75
76#[derive(Clone)]
78pub struct GraphOperations {
79 pool: SqlitePool,
80}
81
82impl GraphOperations {
83 pub(crate) fn new(pool: SqlitePool) -> Self {
85 Self { pool }
86 }
87
88 pub async fn create_relation(
117 &self,
118 from_id: &str,
119 to_id: &str,
120 relation_type: RelationType,
121 metadata: Option<String>,
122 ) -> Result<()> {
123 debug!(
124 "Creating relation: {} -[{}]-> {}",
125 from_id, relation_type, to_id
126 );
127
128 if self.would_create_cycle(from_id, to_id).await? {
130 return Err(Error::CircularDependency {
131 from: from_id.to_string(),
132 to: to_id.to_string(),
133 });
134 }
135
136 let created_at = chrono::Utc::now().timestamp();
137
138 sqlx::query(
139 r#"
140 INSERT OR REPLACE INTO relations (from_id, to_id, relation_type, metadata, created_at)
141 VALUES (?, ?, ?, ?, ?)
142 "#,
143 )
144 .bind(from_id)
145 .bind(to_id)
146 .bind(relation_type.as_str())
147 .bind(&metadata)
148 .bind(created_at)
149 .execute(&self.pool)
150 .await?;
151
152 debug!("Created relation successfully");
153 Ok(())
154 }
155
156 pub async fn delete_relation(
158 &self,
159 from_id: &str,
160 to_id: &str,
161 relation_type: RelationType,
162 ) -> Result<()> {
163 debug!(
164 "Deleting relation: {} -[{}]-> {}",
165 from_id, relation_type, to_id
166 );
167
168 sqlx::query(
169 r#"
170 DELETE FROM relations
171 WHERE from_id = ? AND to_id = ? AND relation_type = ?
172 "#,
173 )
174 .bind(from_id)
175 .bind(to_id)
176 .bind(relation_type.as_str())
177 .execute(&self.pool)
178 .await?;
179
180 Ok(())
181 }
182
183 pub async fn get_outgoing(&self, from_id: &str) -> Result<Vec<Relation>> {
185 debug!("Getting outgoing relations for: {}", from_id);
186
187 let rows: Vec<(String, String, String, Option<String>, i64)> = sqlx::query_as(
188 r#"
189 SELECT from_id, to_id, relation_type, metadata, created_at
190 FROM relations
191 WHERE from_id = ?
192 ORDER BY created_at DESC
193 "#,
194 )
195 .bind(from_id)
196 .fetch_all(&self.pool)
197 .await?;
198
199 let mut relations = Vec::with_capacity(rows.len());
200 for (from_id, to_id, relation_type, metadata, created_at) in rows {
201 relations.push(Relation {
202 from_id,
203 to_id,
204 relation_type: RelationType::from_str(&relation_type)?,
205 metadata,
206 created_at,
207 });
208 }
209
210 Ok(relations)
211 }
212
213 pub async fn get_incoming(&self, to_id: &str) -> Result<Vec<Relation>> {
215 debug!("Getting incoming relations for: {}", to_id);
216
217 let rows: Vec<(String, String, String, Option<String>, i64)> = sqlx::query_as(
218 r#"
219 SELECT from_id, to_id, relation_type, metadata, created_at
220 FROM relations
221 WHERE to_id = ?
222 ORDER BY created_at DESC
223 "#,
224 )
225 .bind(to_id)
226 .fetch_all(&self.pool)
227 .await?;
228
229 let mut relations = Vec::with_capacity(rows.len());
230 for (from_id, to_id, relation_type, metadata, created_at) in rows {
231 relations.push(Relation {
232 from_id,
233 to_id,
234 relation_type: RelationType::from_str(&relation_type)?,
235 metadata,
236 created_at,
237 });
238 }
239
240 Ok(relations)
241 }
242
243 pub async fn get_all_relations(&self, id: &str) -> Result<Vec<Relation>> {
245 debug!("Getting all relations for: {}", id);
246
247 let rows: Vec<(String, String, String, Option<String>, i64)> = sqlx::query_as(
248 r#"
249 SELECT from_id, to_id, relation_type, metadata, created_at
250 FROM relations
251 WHERE from_id = ? OR to_id = ?
252 ORDER BY created_at DESC
253 "#,
254 )
255 .bind(id)
256 .bind(id)
257 .fetch_all(&self.pool)
258 .await?;
259
260 let mut relations = Vec::with_capacity(rows.len());
261 for (from_id, to_id, relation_type, metadata, created_at) in rows {
262 relations.push(Relation {
263 from_id,
264 to_id,
265 relation_type: RelationType::from_str(&relation_type)?,
266 metadata,
267 created_at,
268 });
269 }
270
271 Ok(relations)
272 }
273
274 pub async fn get_dependencies(&self, id: &str) -> Result<Vec<String>> {
276 debug!("Getting dependencies for: {}", id);
277
278 let rows: Vec<(String,)> = sqlx::query_as(
279 r#"
280 SELECT DISTINCT to_id
281 FROM relations
282 WHERE from_id = ? AND relation_type IN ('uses', 'requires', 'extends')
283 "#,
284 )
285 .bind(id)
286 .fetch_all(&self.pool)
287 .await?;
288
289 Ok(rows.into_iter().map(|(id,)| id).collect())
290 }
291
292 pub async fn get_dependents(&self, id: &str) -> Result<Vec<String>> {
294 debug!("Getting dependents for: {}", id);
295
296 let rows: Vec<(String,)> = sqlx::query_as(
297 r#"
298 SELECT DISTINCT from_id
299 FROM relations
300 WHERE to_id = ? AND relation_type IN ('uses', 'requires', 'extends')
301 "#,
302 )
303 .bind(id)
304 .fetch_all(&self.pool)
305 .await?;
306
307 Ok(rows.into_iter().map(|(id,)| id).collect())
308 }
309
310 async fn would_create_cycle(&self, from_id: &str, to_id: &str) -> Result<bool> {
312 let reachable = self.get_reachable_nodes(to_id).await?;
316 Ok(reachable.contains(from_id))
317 }
318
319 async fn get_reachable_nodes(&self, start_id: &str) -> Result<HashSet<String>> {
321 let mut reachable = HashSet::new();
322 let mut to_visit = vec![start_id.to_string()];
323
324 while let Some(current) = to_visit.pop() {
325 if reachable.contains(¤t) {
326 continue;
327 }
328
329 reachable.insert(current.clone());
330
331 let deps = self.get_dependencies(¤t).await?;
332 for dep in deps {
333 if !reachable.contains(&dep) {
334 to_visit.push(dep);
335 }
336 }
337 }
338
339 Ok(reachable)
340 }
341
342 pub async fn build_graph(&self) -> Result<HashMap<String, Vec<String>>> {
344 debug!("Building full dependency graph");
345
346 let rows: Vec<(String, String)> = sqlx::query_as(
347 r#"
348 SELECT DISTINCT from_id, to_id
349 FROM relations
350 WHERE relation_type IN ('uses', 'requires', 'extends')
351 "#,
352 )
353 .fetch_all(&self.pool)
354 .await?;
355
356 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
357
358 for (from_id, to_id) in rows {
359 graph.entry(from_id).or_default().push(to_id);
360 }
361
362 Ok(graph)
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Database, Expertise, Scope, StorageOperations};
    use tempfile::TempDir;

    /// Opens a fresh database file inside a new temporary directory.
    ///
    /// The `TempDir` is handed back so the caller keeps the directory alive for the test.
    async fn setup_db() -> (Database, TempDir) {
        let dir = TempDir::new().unwrap();
        let db = Database::open(&dir.path().join("test.db")).await.unwrap();
        (db, dir)
    }

    /// Stores a personal-scope expertise, version `1.0.0`, for every id in `ids`.
    async fn seed(db: &Database, ids: &[&str]) {
        for &id in ids {
            let mut expertise = Expertise::new(id, "1.0.0");
            expertise.metadata.scope = Scope::Personal;
            db.storage().create(expertise).await.unwrap();
        }
    }

    /// Creates `from -[kind]-> to` without metadata, panicking on failure.
    async fn link(db: &Database, from: &str, to: &str, kind: RelationType) {
        db.graph()
            .create_relation(from, to, kind, None)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_create_relation() {
        let (db, _dir) = setup_db().await;
        seed(&db, &["exp-1", "exp-2"]).await;

        link(&db, "exp-1", "exp-2", RelationType::Uses).await;

        let outgoing = db.graph().get_outgoing("exp-1").await.unwrap();
        assert_eq!(outgoing.len(), 1);
        let only = &outgoing[0];
        assert_eq!(only.to_id, "exp-2");
        assert_eq!(only.relation_type, RelationType::Uses);
    }

    #[tokio::test]
    async fn test_circular_dependency_detection() {
        let (db, _dir) = setup_db().await;
        seed(&db, &["exp-1", "exp-2", "exp-3"]).await;

        // exp-1 -> exp-2 -> exp-3, so exp-3 -> exp-1 would close the loop.
        link(&db, "exp-1", "exp-2", RelationType::Uses).await;
        link(&db, "exp-2", "exp-3", RelationType::Uses).await;

        let outcome = db
            .graph()
            .create_relation("exp-3", "exp-1", RelationType::Uses, None)
            .await;

        assert!(matches!(outcome, Err(Error::CircularDependency { .. })));
    }

    #[tokio::test]
    async fn test_get_dependencies() {
        let (db, _dir) = setup_db().await;
        seed(&db, &["exp-1", "exp-2", "exp-3"]).await;

        link(&db, "exp-1", "exp-2", RelationType::Uses).await;
        link(&db, "exp-1", "exp-3", RelationType::Requires).await;

        // Query order is unspecified, so compare after sorting.
        let mut deps = db.graph().get_dependencies("exp-1").await.unwrap();
        deps.sort();
        assert_eq!(deps, ["exp-2", "exp-3"]);
    }

    #[tokio::test]
    async fn test_get_dependents() {
        let (db, _dir) = setup_db().await;
        seed(&db, &["exp-1", "exp-2", "exp-3"]).await;

        link(&db, "exp-2", "exp-1", RelationType::Uses).await;
        link(&db, "exp-3", "exp-1", RelationType::Requires).await;

        // Query order is unspecified, so compare after sorting.
        let mut dependents = db.graph().get_dependents("exp-1").await.unwrap();
        dependents.sort();
        assert_eq!(dependents, ["exp-2", "exp-3"]);
    }

    #[tokio::test]
    async fn test_delete_relation() {
        let (db, _dir) = setup_db().await;
        seed(&db, &["exp-1", "exp-2"]).await;

        link(&db, "exp-1", "exp-2", RelationType::Uses).await;
        db.graph()
            .delete_relation("exp-1", "exp-2", RelationType::Uses)
            .await
            .unwrap();

        let outgoing = db.graph().get_outgoing("exp-1").await.unwrap();
        assert!(outgoing.is_empty());
    }
}
496}