1use crate::domain::{Edge, GraphQuery, Node, NodeId, Properties};
2use crate::storage::{EdgeDirection, GraphStorage, Result, StorageError, SearchQuery, SearchResults};
3use async_trait::async_trait;
4use sqlx::{Pool, Postgres, Row};
5
6pub struct PostgresStorage {
7 pool: Pool<Postgres>,
8}
9
10impl PostgresStorage {
11 pub fn new(pool: Pool<Postgres>) -> Self {
12 Self { pool }
13 }
14
15 pub async fn setup_tables(&self) -> Result<()> {
16 sqlx::query("DROP TABLE IF EXISTS edges CASCADE")
20 .execute(&self.pool)
21 .await
22 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
23
24 sqlx::query("DROP TABLE IF EXISTS nodes CASCADE")
25 .execute(&self.pool)
26 .await
27 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
28
29 sqlx::query(
31 r#"
32 CREATE TABLE nodes (
33 id UUID PRIMARY KEY,
34 node_type VARCHAR(255) NOT NULL,
35 properties JSONB NOT NULL DEFAULT '{}',
36 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
37 updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
38 )
39 "#
40 )
41 .execute(&self.pool)
42 .await
43 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
44
45 sqlx::query(
47 r#"
48 CREATE TABLE edges (
49 id UUID PRIMARY KEY,
50 edge_type VARCHAR(255) NOT NULL,
51 from_node_id UUID NOT NULL,
52 to_node_id UUID NOT NULL,
53 properties JSONB NOT NULL DEFAULT '{}',
54 created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
55 FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE,
56 FOREIGN KEY (to_node_id) REFERENCES nodes(id) ON DELETE CASCADE
57 )
58 "#
59 )
60 .execute(&self.pool)
61 .await
62 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
63
64 sqlx::query("CREATE INDEX idx_nodes_type ON nodes(node_type)")
66 .execute(&self.pool)
67 .await
68 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
69
70 sqlx::query("CREATE INDEX idx_nodes_properties ON nodes USING GIN(properties)")
71 .execute(&self.pool)
72 .await
73 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
74
75 sqlx::query("CREATE INDEX idx_nodes_created_at ON nodes(created_at DESC)")
77 .execute(&self.pool)
78 .await
79 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
80
81 sqlx::query("CREATE INDEX idx_nodes_updated_at ON nodes(updated_at DESC)")
82 .execute(&self.pool)
83 .await
84 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
85
86 sqlx::query("CREATE INDEX idx_nodes_type_created ON nodes(node_type, created_at DESC)")
88 .execute(&self.pool)
89 .await
90 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
91
92 sqlx::query("CREATE INDEX idx_nodes_type_updated ON nodes(node_type, updated_at DESC)")
93 .execute(&self.pool)
94 .await
95 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
96
97 sqlx::query("CREATE INDEX idx_edges_type ON edges(edge_type)")
98 .execute(&self.pool)
99 .await
100 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
101
102 sqlx::query("CREATE INDEX idx_edges_from ON edges(from_node_id)")
103 .execute(&self.pool)
104 .await
105 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
106
107 sqlx::query("CREATE INDEX idx_edges_to ON edges(to_node_id)")
108 .execute(&self.pool)
109 .await
110 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
111
112 sqlx::query("CREATE INDEX idx_edges_from_type ON edges(from_node_id, edge_type)")
113 .execute(&self.pool)
114 .await
115 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
116
117 sqlx::query("CREATE INDEX idx_edges_to_type ON edges(to_node_id, edge_type)")
118 .execute(&self.pool)
119 .await
120 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
121
122 sqlx::query("CREATE INDEX idx_edges_created_at ON edges(created_at DESC)")
124 .execute(&self.pool)
125 .await
126 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
127
128 Ok(())
129 }
130
131 #[allow(dead_code)]
133 fn properties_to_search_text(properties: &Properties) -> String {
134 let mut texts = Vec::new();
135 for (_, value) in properties {
136 if let serde_json::Value::String(s) = serde_json::to_value(value).unwrap_or_default() {
137 texts.push(s);
138 }
139 }
140 texts.join(" ")
141 }
142}
143
144#[async_trait]
145impl GraphStorage for PostgresStorage {
146 async fn create_node(&self, node: &Node) -> Result<Node> {
147 let properties_json = serde_json::to_value(&node.properties)
148 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
149
150 sqlx::query(
151 r#"
152 INSERT INTO nodes (id, node_type, properties, created_at, updated_at)
153 VALUES ($1, $2, $3, $4, $5)
154 "#
155 )
156 .bind(node.id)
157 .bind(&node.node_type)
158 .bind(properties_json)
159 .bind(node.created_at)
160 .bind(node.updated_at)
161 .execute(&self.pool)
162 .await
163 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
164
165 Ok(node.clone())
166 }
167
168 async fn get_node(&self, id: NodeId) -> Result<Node> {
169 let row = sqlx::query(
170 r#"
171 SELECT id, node_type, properties, created_at, updated_at
172 FROM nodes
173 WHERE id = $1
174 "#
175 )
176 .bind(id)
177 .fetch_optional(&self.pool)
178 .await
179 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
180
181 match row {
182 Some(row) => {
183 let properties_json: serde_json::Value = row.try_get("properties")
184 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
185 let properties = serde_json::from_value(properties_json)
186 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
187
188 Ok(Node {
189 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
190 node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
191 properties,
192 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
193 updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
194 })
195 }
196 None => Err(StorageError::NodeNotFound(id)),
197 }
198 }
199
200 async fn update_node(&self, node: &Node) -> Result<Node> {
201 let properties_json = serde_json::to_value(&node.properties)
202 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
203
204 let result = sqlx::query(
205 r#"
206 UPDATE nodes
207 SET node_type = $2, properties = $3, updated_at = $4
208 WHERE id = $1
209 "#
210 )
211 .bind(node.id)
212 .bind(&node.node_type)
213 .bind(properties_json)
214 .bind(node.updated_at)
215 .execute(&self.pool)
216 .await
217 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
218
219 if result.rows_affected() == 0 {
220 return Err(StorageError::NodeNotFound(node.id));
221 }
222
223 Ok(node.clone())
224 }
225
226 async fn delete_node(&self, id: NodeId) -> Result<()> {
227 let result = sqlx::query("DELETE FROM nodes WHERE id = $1")
228 .bind(id)
229 .execute(&self.pool)
230 .await
231 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
232
233 if result.rows_affected() == 0 {
234 return Err(StorageError::NodeNotFound(id));
235 }
236
237 Ok(())
238 }
239
240 async fn query_nodes(&self, query: &GraphQuery) -> Result<Vec<Node>> {
241 let mut sql = String::from("SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1");
242
243 if let Some(ref types) = query.node_types {
245 if types.len() == 1 {
246 sql.push_str(&format!(" AND node_type = '{}'", types[0]));
248 } else if !types.is_empty() {
249 let type_list: Vec<String> = types.iter()
251 .map(|t| format!("'{}'", t.replace("'", "''")))
252 .collect();
253 sql.push_str(&format!(" AND node_type IN ({})", type_list.join(", ")));
254 }
255 }
256
257 sql.push_str(" ORDER BY created_at DESC");
258
259 if let Some(limit) = query.limit {
260 sql.push_str(&format!(" LIMIT {}", limit));
261 }
262
263 let rows = sqlx::query(&sql)
264 .fetch_all(&self.pool)
265 .await
266 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
267
268 let mut nodes = Vec::new();
269 for row in rows {
270 let properties_json: serde_json::Value = row.try_get("properties")
271 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
272 let properties = serde_json::from_value(properties_json)
273 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
274
275 nodes.push(Node {
276 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
277 node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
278 properties,
279 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
280 updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
281 });
282 }
283
284 Ok(nodes)
285 }
286
287 async fn create_edge(&self, edge: &Edge) -> Result<Edge> {
288 let properties_json = serde_json::to_value(&edge.properties)
289 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
290
291 sqlx::query(
292 r#"
293 INSERT INTO edges (id, edge_type, from_node_id, to_node_id, properties, created_at)
294 VALUES ($1, $2, $3, $4, $5, $6)
295 "#
296 )
297 .bind(edge.id)
298 .bind(&edge.edge_type)
299 .bind(edge.from_node_id)
300 .bind(edge.to_node_id)
301 .bind(properties_json)
302 .bind(edge.created_at)
303 .execute(&self.pool)
304 .await
305 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
306
307 Ok(edge.clone())
308 }
309
310 async fn get_edges_from(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
311 let rows = if let Some(et) = edge_type {
312 sqlx::query(
313 r#"
314 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
315 FROM edges
316 WHERE from_node_id = $1 AND edge_type = $2
317 ORDER BY created_at DESC
318 "#
319 )
320 .bind(node_id)
321 .bind(et)
322 .fetch_all(&self.pool)
323 .await
324 } else {
325 sqlx::query(
326 r#"
327 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
328 FROM edges
329 WHERE from_node_id = $1
330 ORDER BY created_at DESC
331 "#
332 )
333 .bind(node_id)
334 .fetch_all(&self.pool)
335 .await
336 }
337 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
338
339 let mut edges = Vec::new();
340 for row in rows {
341 let properties_json: serde_json::Value = row.try_get("properties")
342 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
343 let properties = serde_json::from_value(properties_json)
344 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
345
346 edges.push(Edge {
347 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
348 edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
349 from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
350 to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
351 properties,
352 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
353 });
354 }
355
356 Ok(edges)
357 }
358
359 async fn get_edges_to(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
360 let rows = if let Some(et) = edge_type {
361 sqlx::query(
362 r#"
363 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
364 FROM edges
365 WHERE to_node_id = $1 AND edge_type = $2
366 ORDER BY created_at DESC
367 "#
368 )
369 .bind(node_id)
370 .bind(et)
371 .fetch_all(&self.pool)
372 .await
373 } else {
374 sqlx::query(
375 r#"
376 SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
377 FROM edges
378 WHERE to_node_id = $1
379 ORDER BY created_at DESC
380 "#
381 )
382 .bind(node_id)
383 .fetch_all(&self.pool)
384 .await
385 }
386 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
387
388 let mut edges = Vec::new();
389 for row in rows {
390 let properties_json: serde_json::Value = row.try_get("properties")
391 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
392 let properties = serde_json::from_value(properties_json)
393 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
394
395 edges.push(Edge {
396 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
397 edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
398 from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
399 to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
400 properties,
401 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
402 });
403 }
404
405 Ok(edges)
406 }
407
408 async fn get_neighbors(
409 &self,
410 node_id: NodeId,
411 edge_type: Option<&str>,
412 direction: EdgeDirection,
413 ) -> Result<Vec<Node>> {
414 let mut neighbors = Vec::new();
415
416 match direction {
417 EdgeDirection::Outgoing => {
418 let edges = self.get_edges_from(node_id, edge_type).await?;
419 for edge in edges {
420 if let Ok(node) = self.get_node(edge.to_node_id).await {
421 neighbors.push(node);
422 }
423 }
424 }
425 _ => {}
426 }
427
428 match direction {
429 EdgeDirection::Incoming => {
430 let edges = self.get_edges_to(node_id, edge_type).await?;
431 for edge in edges {
432 if let Ok(node) = self.get_node(edge.from_node_id).await {
433 neighbors.push(node);
434 }
435 }
436 }
437 _ => {}
438 }
439
440 Ok(neighbors)
441 }
442
443 async fn search_nodes(&self, query: &SearchQuery) -> Result<SearchResults<Node>> {
444 let offset = query.offset;
445 let limit = query.limit;
446
447 let mut sql = String::from(
449 "SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1"
450 );
451
452 if !query.node_types.is_empty() {
454 let types: Vec<String> = query.node_types.iter()
455 .map(|t| format!("'{}'", t.replace("'", "''")))
456 .collect();
457 sql.push_str(&format!(" AND node_type IN ({})", types.join(", ")));
458 }
459
460 if let Some(ref search_text) = query.search_text {
462 let escaped = search_text.replace("'", "''").replace("%", "\\%").replace("_", "\\_");
463 sql.push_str(&format!(
464 " AND properties::text ILIKE '%{}%'",
465 escaped
466 ));
467 }
468
469 if let Some(after) = query.created_after {
471 sql.push_str(&format!(" AND created_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
472 }
473 if let Some(before) = query.created_before {
474 sql.push_str(&format!(" AND created_at <= '{}'", before.format("%Y-%m-%d %H:%M:%S")));
475 }
476 if let Some(after) = query.updated_after {
477 sql.push_str(&format!(" AND updated_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
478 }
479
480 for (key, value) in &query.property_filters {
482 let escaped_key = key.replace("'", "''");
483 let escaped_value = value.replace("'", "''");
484 sql.push_str(&format!(
485 " AND properties->>'{}' = '{}'",
486 escaped_key, escaped_value
487 ));
488 }
489
490 sql.push_str(" ORDER BY updated_at DESC");
492
493 sql.push_str(&format!(" LIMIT {} OFFSET {}", limit + 1, offset));
495
496 let rows = sqlx::query(&sql)
498 .fetch_all(&self.pool)
499 .await
500 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
501
502 let mut nodes = Vec::new();
504
505 for row in rows.into_iter().take(limit) {
506 let properties_json: serde_json::Value = row.try_get("properties")
507 .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
508 let properties: Properties = serde_json::from_value(properties_json)
509 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
510
511 nodes.push(Node {
512 id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
513 node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
514 properties,
515 created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
516 updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
517 });
518 }
519
520 Ok(SearchResults {
521 items: nodes,
522 })
523 }
524}