1use crate::domain::{Edge, EdgeId, GraphQuery, Node, NodeId, Properties};
2use crate::storage::{EdgeDirection, GraphStorage, Result, StorageError, SearchQuery, SearchResults, OrderBy, OrderDirection};
3use async_trait::async_trait;
4use sqlx::{Pool, Postgres, Row};
5
/// [`GraphStorage`] backend that keeps nodes and edges in PostgreSQL, accessed through `sqlx`.
pub struct PostgresStorage {
    /// Connection pool that every query issued by this type is executed against.
    pool: Pool<Postgres>,
}
9
10impl PostgresStorage {
11 pub fn new(pool: Pool<Postgres>) -> Self {
12 Self { pool }
13 }
14
15 pub async fn setup_tables(&self) -> Result<()> {
16 sqlx::query("DROP TABLE IF EXISTS edges CASCADE")
20 .execute(&self.pool)
21 .await
22 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
23
24 sqlx::query("DROP TABLE IF EXISTS nodes CASCADE")
25 .execute(&self.pool)
26 .await
27 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
28
29 sqlx::query(
31 r#"
32 CREATE TABLE nodes (
33 id UUID PRIMARY KEY,
34 node_type VARCHAR(255) NOT NULL,
35 properties JSONB NOT NULL DEFAULT '{}',
36 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
37 updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
38 )
39 "#
40 )
41 .execute(&self.pool)
42 .await
43 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
44
45 sqlx::query(
47 r#"
48 CREATE TABLE edges (
49 id UUID PRIMARY KEY,
50 edge_type VARCHAR(255) NOT NULL,
51 from_node_id UUID NOT NULL,
52 to_node_id UUID NOT NULL,
53 properties JSONB NOT NULL DEFAULT '{}',
54 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
55 FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE,
56 FOREIGN KEY (to_node_id) REFERENCES nodes(id) ON DELETE CASCADE
57 )
58 "#
59 )
60 .execute(&self.pool)
61 .await
62 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
63
64 sqlx::query("CREATE INDEX idx_nodes_type ON nodes(node_type)")
66 .execute(&self.pool)
67 .await
68 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
69
70 sqlx::query("CREATE INDEX idx_nodes_properties ON nodes USING GIN(properties)")
71 .execute(&self.pool)
72 .await
73 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
74
75 sqlx::query("CREATE INDEX idx_nodes_created_at ON nodes(created_at DESC)")
77 .execute(&self.pool)
78 .await
79 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
80
81 sqlx::query("CREATE INDEX idx_nodes_updated_at ON nodes(updated_at DESC)")
82 .execute(&self.pool)
83 .await
84 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
85
86 sqlx::query("CREATE INDEX idx_nodes_type_created ON nodes(node_type, created_at DESC)")
88 .execute(&self.pool)
89 .await
90 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
91
92 sqlx::query("CREATE INDEX idx_nodes_type_updated ON nodes(node_type, updated_at DESC)")
93 .execute(&self.pool)
94 .await
95 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
96
97 sqlx::query("CREATE INDEX idx_edges_type ON edges(edge_type)")
98 .execute(&self.pool)
99 .await
100 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
101
102 sqlx::query("CREATE INDEX idx_edges_from ON edges(from_node_id)")
103 .execute(&self.pool)
104 .await
105 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
106
107 sqlx::query("CREATE INDEX idx_edges_to ON edges(to_node_id)")
108 .execute(&self.pool)
109 .await
110 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
111
112 sqlx::query("CREATE INDEX idx_edges_from_type ON edges(from_node_id, edge_type)")
113 .execute(&self.pool)
114 .await
115 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
116
117 sqlx::query("CREATE INDEX idx_edges_to_type ON edges(to_node_id, edge_type)")
118 .execute(&self.pool)
119 .await
120 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
121
122 sqlx::query("CREATE INDEX idx_edges_created_at ON edges(created_at DESC)")
124 .execute(&self.pool)
125 .await
126 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
127
128 Ok(())
129 }
130
131 fn properties_to_search_text(properties: &Properties) -> String {
133 let mut texts = Vec::new();
134 for (_, value) in properties {
135 if let serde_json::Value::String(s) = serde_json::to_value(value).unwrap_or_default() {
136 texts.push(s);
137 }
138 }
139 texts.join(" ")
140 }
141}
142
143#[async_trait]
144impl GraphStorage for PostgresStorage {
145 async fn create_node(&self, node: &Node) -> Result<Node> {
146 let properties_json = serde_json::to_value(&node.properties)
147 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
148
149 sqlx::query(
150 r#"
151 INSERT INTO nodes (id, node_type, properties, created_at, updated_at)
152 VALUES ($1, $2, $3, $4, $5)
153 "#
154 )
155 .bind(node.id)
156 .bind(&node.node_type)
157 .bind(properties_json)
158 .bind(node.created_at)
159 .bind(node.updated_at)
160 .execute(&self.pool)
161 .await
162 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
163
164 Ok(node.clone())
165 }
166
167 async fn get_node(&self, id: NodeId) -> Result<Node> {
168 let row = sqlx::query(
169 r#"
170 SELECT id, node_type, properties, created_at, updated_at
171 FROM nodes
172 WHERE id = $1
173 "#
174 )
175 .bind(id)
176 .fetch_optional(&self.pool)
177 .await
178 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
179
180 match row {
181 Some(row) => {
182 let properties_json: serde_json::Value = row.try_get("properties")
183 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
184 let properties = serde_json::from_value(properties_json)
185 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
186
187 Ok(Node {
188 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
189 node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
190 properties,
191 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
192 updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
193 })
194 }
195 None => Err(StorageError::NodeNotFound(id)),
196 }
197 }
198
199 async fn update_node(&self, node: &Node) -> Result<Node> {
200 let properties_json = serde_json::to_value(&node.properties)
201 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
202
203 let result = sqlx::query(
204 r#"
205 UPDATE nodes
206 SET node_type = $2, properties = $3, updated_at = $4
207 WHERE id = $1
208 "#
209 )
210 .bind(node.id)
211 .bind(&node.node_type)
212 .bind(properties_json)
213 .bind(node.updated_at)
214 .execute(&self.pool)
215 .await
216 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
217
218 if result.rows_affected() == 0 {
219 return Err(StorageError::NodeNotFound(node.id));
220 }
221
222 Ok(node.clone())
223 }
224
225 async fn delete_node(&self, id: NodeId) -> Result<()> {
226 let result = sqlx::query("DELETE FROM nodes WHERE id = $1")
227 .bind(id)
228 .execute(&self.pool)
229 .await
230 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
231
232 if result.rows_affected() == 0 {
233 return Err(StorageError::NodeNotFound(id));
234 }
235
236 Ok(())
237 }
238
239 async fn query_nodes(&self, query: &GraphQuery) -> Result<Vec<Node>> {
240 let mut sql = String::from("SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1");
241
242 if let Some(ref types) = query.node_types {
244 if types.len() == 1 {
245 sql.push_str(&format!(" AND node_type = '{}'", types[0]));
247 } else if !types.is_empty() {
248 let type_list: Vec<String> = types.iter()
250 .map(|t| format!("'{}'", t.replace("'", "''")))
251 .collect();
252 sql.push_str(&format!(" AND node_type IN ({})", type_list.join(", ")));
253 }
254 }
255
256 sql.push_str(" ORDER BY created_at DESC");
257
258 if let Some(limit) = query.limit {
259 sql.push_str(&format!(" LIMIT {}", limit));
260 }
261
262 let rows = sqlx::query(&sql)
263 .fetch_all(&self.pool)
264 .await
265 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
266
267 let mut nodes = Vec::new();
268 for row in rows {
269 let properties_json: serde_json::Value = row.try_get("properties")
270 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
271 let properties = serde_json::from_value(properties_json)
272 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
273
274 nodes.push(Node {
275 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
276 node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
277 properties,
278 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
279 updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
280 });
281 }
282
283 Ok(nodes)
284 }
285
286 async fn create_edge(&self, edge: &Edge) -> Result<Edge> {
287 let properties_json = serde_json::to_value(&edge.properties)
288 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
289
290 sqlx::query(
291 r#"
292 INSERT INTO edges (id, edge_type, from_node_id, to_node_id, properties, created_at)
293 VALUES ($1, $2, $3, $4, $5, $6)
294 "#
295 )
296 .bind(edge.id)
297 .bind(&edge.edge_type)
298 .bind(edge.from_node_id)
299 .bind(edge.to_node_id)
300 .bind(properties_json)
301 .bind(edge.created_at)
302 .execute(&self.pool)
303 .await
304 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
305
306 Ok(edge.clone())
307 }
308
309 async fn get_edge(&self, id: EdgeId) -> Result<Edge> {
310 let row = sqlx::query(
311 r#"
312 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
313 FROM edges
314 WHERE id = $1
315 "#
316 )
317 .bind(id)
318 .fetch_optional(&self.pool)
319 .await
320 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
321
322 match row {
323 Some(row) => {
324 let properties_json: serde_json::Value = row.try_get("properties")
325 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
326 let properties = serde_json::from_value(properties_json)
327 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
328
329 Ok(Edge {
330 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
331 edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
332 from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
333 to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
334 properties,
335 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
336 })
337 }
338 None => Err(StorageError::EdgeNotFound(id)),
339 }
340 }
341
342 async fn delete_edge(&self, id: EdgeId) -> Result<()> {
343 let result = sqlx::query("DELETE FROM edges WHERE id = $1")
344 .bind(id)
345 .execute(&self.pool)
346 .await
347 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
348
349 if result.rows_affected() == 0 {
350 return Err(StorageError::EdgeNotFound(id));
351 }
352
353 Ok(())
354 }
355
356 async fn get_edges_from(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
357 let rows = if let Some(et) = edge_type {
358 sqlx::query(
359 r#"
360 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
361 FROM edges
362 WHERE from_node_id = $1 AND edge_type = $2
363 ORDER BY created_at DESC
364 "#
365 )
366 .bind(node_id)
367 .bind(et)
368 .fetch_all(&self.pool)
369 .await
370 } else {
371 sqlx::query(
372 r#"
373 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
374 FROM edges
375 WHERE from_node_id = $1
376 ORDER BY created_at DESC
377 "#
378 )
379 .bind(node_id)
380 .fetch_all(&self.pool)
381 .await
382 }
383 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
384
385 let mut edges = Vec::new();
386 for row in rows {
387 let properties_json: serde_json::Value = row.try_get("properties")
388 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
389 let properties = serde_json::from_value(properties_json)
390 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
391
392 edges.push(Edge {
393 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
394 edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
395 from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
396 to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
397 properties,
398 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
399 });
400 }
401
402 Ok(edges)
403 }
404
405 async fn get_edges_to(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
406 let rows = if let Some(et) = edge_type {
407 sqlx::query(
408 r#"
409 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
410 FROM edges
411 WHERE to_node_id = $1 AND edge_type = $2
412 ORDER BY created_at DESC
413 "#
414 )
415 .bind(node_id)
416 .bind(et)
417 .fetch_all(&self.pool)
418 .await
419 } else {
420 sqlx::query(
421 r#"
422 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
423 FROM edges
424 WHERE to_node_id = $1
425 ORDER BY created_at DESC
426 "#
427 )
428 .bind(node_id)
429 .fetch_all(&self.pool)
430 .await
431 }
432 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
433
434 let mut edges = Vec::new();
435 for row in rows {
436 let properties_json: serde_json::Value = row.try_get("properties")
437 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
438 let properties = serde_json::from_value(properties_json)
439 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
440
441 edges.push(Edge {
442 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
443 edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
444 from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
445 to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
446 properties,
447 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
448 });
449 }
450
451 Ok(edges)
452 }
453
454 async fn get_neighbors(
455 &self,
456 node_id: NodeId,
457 edge_type: Option<&str>,
458 direction: EdgeDirection,
459 ) -> Result<Vec<Node>> {
460 let mut neighbors = Vec::new();
461
462 match direction {
463 EdgeDirection::Outgoing | EdgeDirection::Both => {
464 let edges = self.get_edges_from(node_id, edge_type).await?;
465 for edge in edges {
466 if let Ok(node) = self.get_node(edge.to_node_id).await {
467 neighbors.push(node);
468 }
469 }
470 }
471 _ => {}
472 }
473
474 match direction {
475 EdgeDirection::Incoming | EdgeDirection::Both => {
476 let edges = self.get_edges_to(node_id, edge_type).await?;
477 for edge in edges {
478 if let Ok(node) = self.get_node(edge.from_node_id).await {
479 neighbors.push(node);
480 }
481 }
482 }
483 _ => {}
484 }
485
486 Ok(neighbors)
487 }
488
489 async fn search_nodes(&self, query: &SearchQuery) -> Result<SearchResults<Node>> {
490 let offset = query.offset;
491 let limit = query.limit;
492
493 let mut sql = String::from(
495 "SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1"
496 );
497
498 if !query.node_types.is_empty() {
500 let types: Vec<String> = query.node_types.iter()
501 .map(|t| format!("'{}'", t.replace("'", "''")))
502 .collect();
503 sql.push_str(&format!(" AND node_type IN ({})", types.join(", ")));
504 }
505
506 if let Some(ref search_text) = query.search_text {
508 let escaped = search_text.replace("'", "''").replace("%", "\\%").replace("_", "\\_");
509 sql.push_str(&format!(
510 " AND properties::text ILIKE '%{}%'",
511 escaped
512 ));
513 }
514
515 if let Some(after) = query.created_after {
517 sql.push_str(&format!(" AND created_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
518 }
519 if let Some(before) = query.created_before {
520 sql.push_str(&format!(" AND created_at <= '{}'", before.format("%Y-%m-%d %H:%M:%S")));
521 }
522 if let Some(after) = query.updated_after {
523 sql.push_str(&format!(" AND updated_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
524 }
525
526 for (key, value) in &query.property_filters {
528 let escaped_key = key.replace("'", "''");
529 let escaped_value = value.replace("'", "''");
530 sql.push_str(&format!(
531 " AND properties->>'{}' = '{}'",
532 escaped_key, escaped_value
533 ));
534 }
535
536 let order_col = match query.order_by {
538 OrderBy::CreatedAt => "created_at",
539 OrderBy::UpdatedAt => "updated_at",
540 OrderBy::Relevance => "updated_at", };
542 let order_dir = match query.order_direction {
543 OrderDirection::Asc => "ASC",
544 OrderDirection::Desc => "DESC",
545 };
546 sql.push_str(&format!(" ORDER BY {} {}", order_col, order_dir));
547
548 sql.push_str(&format!(" LIMIT {} OFFSET {}", limit + 1, offset));
550
551 let rows = sqlx::query(&sql)
553 .fetch_all(&self.pool)
554 .await
555 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
556
557 let mut nodes = Vec::new();
559 let has_more = rows.len() > limit;
560 let row_count = std::cmp::min(rows.len(), limit);
561
562 for row in rows.into_iter().take(limit) {
563 let properties_json: serde_json::Value = row.try_get("properties")
564 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
565 let properties: Properties = serde_json::from_value(properties_json)
566 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
567
568 nodes.push(Node {
569 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
570 node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
571 properties,
572 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
573 updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
574 });
575 }
576
577 let count_query = sql.replace(&format!(" LIMIT {} OFFSET {}", limit + 1, offset), "");
579 let count_sql = format!("SELECT COUNT(*) FROM ({}) as count_query", count_query);
580 let count_row = sqlx::query(&count_sql)
581 .fetch_one(&self.pool)
582 .await
583 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
584 let total_count: i64 = count_row.try_get(0)
585 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
586
587 Ok(SearchResults {
588 items: nodes,
589 total_count: total_count as usize,
590 returned_count: row_count,
591 has_more,
592 limit,
593 offset,
594 })
595 }
596
597 async fn count_nodes(&self, query: &SearchQuery) -> Result<usize> {
598 let mut sql = String::from("SELECT COUNT(*) FROM nodes WHERE 1=1");
600
601 if !query.node_types.is_empty() {
603 let types: Vec<String> = query.node_types.iter()
604 .map(|t| format!("'{}'", t.replace("'", "''")))
605 .collect();
606 sql.push_str(&format!(" AND node_type IN ({})", types.join(", ")));
607 }
608
609 if let Some(ref search_text) = query.search_text {
611 let escaped = search_text.replace("'", "''").replace("%", "\\%").replace("_", "\\_");
612 sql.push_str(&format!(
613 " AND properties::text ILIKE '%{}%'",
614 escaped
615 ));
616 }
617
618 if let Some(after) = query.created_after {
620 sql.push_str(&format!(" AND created_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
621 }
622 if let Some(before) = query.created_before {
623 sql.push_str(&format!(" AND created_at <= '{}'", before.format("%Y-%m-%d %H:%M:%S")));
624 }
625 if let Some(after) = query.updated_after {
626 sql.push_str(&format!(" AND updated_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
627 }
628
629 for (key, value) in &query.property_filters {
631 let escaped_key = key.replace("'", "''");
632 let escaped_value = value.replace("'", "''");
633 sql.push_str(&format!(
634 " AND properties->>'{}' = '{}'",
635 escaped_key, escaped_value
636 ));
637 }
638
639 let row = sqlx::query(&sql)
641 .fetch_one(&self.pool)
642 .await
643 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
644
645 let count: i64 = row.try_get(0)
646 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
647
648 Ok(count as usize)
649 }
650}