Skip to main content

agent_office/storage/
postgres.rs

1use crate::domain::{Edge, GraphQuery, Node, NodeId, Properties};
2use crate::storage::{EdgeDirection, GraphStorage, Result, StorageError, SearchQuery, SearchResults};
3use async_trait::async_trait;
4use sqlx::{Pool, Postgres, Row};
5
6pub struct PostgresStorage {
7    pool: Pool<Postgres>,
8}
9
10impl PostgresStorage {
11    pub fn new(pool: Pool<Postgres>) -> Self {
12        Self { pool }
13    }
14
15    pub async fn setup_tables(&self) -> Result<()> {
16        // Execute each statement separately since SQLx doesn't support multiple statements in one query
17        
18        // Drop existing tables
19        sqlx::query("DROP TABLE IF EXISTS edges CASCADE")
20            .execute(&self.pool)
21            .await
22            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
23            
24        sqlx::query("DROP TABLE IF EXISTS nodes CASCADE")
25            .execute(&self.pool)
26            .await
27            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
28
29        // Create nodes table
30        sqlx::query(
31            r#"
32            CREATE TABLE nodes (
33                id UUID PRIMARY KEY,
34                node_type VARCHAR(255) NOT NULL,
35                properties JSONB NOT NULL DEFAULT '{}',
36                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
37                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
38            )
39            "#
40        )
41        .execute(&self.pool)
42        .await
43        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
44
45        // Create edges table
46        sqlx::query(
47            r#"
48            CREATE TABLE edges (
49                id UUID PRIMARY KEY,
50                edge_type VARCHAR(255) NOT NULL,
51                from_node_id UUID NOT NULL,
52                to_node_id UUID NOT NULL,
53                properties JSONB NOT NULL DEFAULT '{}',
54                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
55                FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE,
56                FOREIGN KEY (to_node_id) REFERENCES nodes(id) ON DELETE CASCADE
57            )
58            "#
59        )
60        .execute(&self.pool)
61        .await
62        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
63
64        // Create indexes for basic lookups
65        sqlx::query("CREATE INDEX idx_nodes_type ON nodes(node_type)")
66            .execute(&self.pool)
67            .await
68            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
69            
70        sqlx::query("CREATE INDEX idx_nodes_properties ON nodes USING GIN(properties)")
71            .execute(&self.pool)
72            .await
73            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
74        
75        // Create time-based indexes for recency searches
76        sqlx::query("CREATE INDEX idx_nodes_created_at ON nodes(created_at DESC)")
77            .execute(&self.pool)
78            .await
79            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
80            
81        sqlx::query("CREATE INDEX idx_nodes_updated_at ON nodes(updated_at DESC)")
82            .execute(&self.pool)
83            .await
84            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
85        
86        // Create composite indexes for time range + type queries
87        sqlx::query("CREATE INDEX idx_nodes_type_created ON nodes(node_type, created_at DESC)")
88            .execute(&self.pool)
89            .await
90            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
91            
92        sqlx::query("CREATE INDEX idx_nodes_type_updated ON nodes(node_type, updated_at DESC)")
93            .execute(&self.pool)
94            .await
95            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
96            
97        sqlx::query("CREATE INDEX idx_edges_type ON edges(edge_type)")
98            .execute(&self.pool)
99            .await
100            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
101            
102        sqlx::query("CREATE INDEX idx_edges_from ON edges(from_node_id)")
103            .execute(&self.pool)
104            .await
105            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
106            
107        sqlx::query("CREATE INDEX idx_edges_to ON edges(to_node_id)")
108            .execute(&self.pool)
109            .await
110            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
111            
112        sqlx::query("CREATE INDEX idx_edges_from_type ON edges(from_node_id, edge_type)")
113            .execute(&self.pool)
114            .await
115            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
116            
117        sqlx::query("CREATE INDEX idx_edges_to_type ON edges(to_node_id, edge_type)")
118            .execute(&self.pool)
119            .await
120            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
121        
122        // Create indexes on edge timestamps
123        sqlx::query("CREATE INDEX idx_edges_created_at ON edges(created_at DESC)")
124            .execute(&self.pool)
125            .await
126            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
127
128        Ok(())
129    }
130    
131    /// Helper function to convert properties JSONB to searchable text
132    #[allow(dead_code)]
133    fn properties_to_search_text(properties: &Properties) -> String {
134        let mut texts = Vec::new();
135        for (_, value) in properties {
136            if let serde_json::Value::String(s) = serde_json::to_value(value).unwrap_or_default() {
137                texts.push(s);
138            }
139        }
140        texts.join(" ")
141    }
142}
143
144#[async_trait]
145impl GraphStorage for PostgresStorage {
146    async fn create_node(&self, node: &Node) -> Result<Node> {
147        let properties_json = serde_json::to_value(&node.properties)
148            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
149
150        sqlx::query(
151            r#"
152            INSERT INTO nodes (id, node_type, properties, created_at, updated_at)
153            VALUES ($1, $2, $3, $4, $5)
154            "#
155        )
156        .bind(node.id)
157        .bind(&node.node_type)
158        .bind(properties_json)
159        .bind(node.created_at)
160        .bind(node.updated_at)
161        .execute(&self.pool)
162        .await
163        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
164
165        Ok(node.clone())
166    }
167
168    async fn get_node(&self, id: NodeId) -> Result<Node> {
169        let row = sqlx::query(
170            r#"
171            SELECT id, node_type, properties, created_at, updated_at
172            FROM nodes
173            WHERE id = $1
174            "#
175        )
176        .bind(id)
177        .fetch_optional(&self.pool)
178        .await
179        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
180
181        match row {
182            Some(row) => {
183                let properties_json: serde_json::Value = row.try_get("properties")
184                    .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
185                let properties = serde_json::from_value(properties_json)
186                    .map_err(|e| StorageError::SerializationError(e.to_string()))?;
187
188                Ok(Node {
189                    id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
190                    node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
191                    properties,
192                    created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
193                    updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
194                })
195            }
196            None => Err(StorageError::NodeNotFound(id)),
197        }
198    }
199
200    async fn update_node(&self, node: &Node) -> Result<Node> {
201        let properties_json = serde_json::to_value(&node.properties)
202            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
203
204        let result = sqlx::query(
205            r#"
206            UPDATE nodes
207            SET node_type = $2, properties = $3, updated_at = $4
208            WHERE id = $1
209            "#
210        )
211        .bind(node.id)
212        .bind(&node.node_type)
213        .bind(properties_json)
214        .bind(node.updated_at)
215        .execute(&self.pool)
216        .await
217        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
218
219        if result.rows_affected() == 0 {
220            return Err(StorageError::NodeNotFound(node.id));
221        }
222
223        Ok(node.clone())
224    }
225
226    async fn delete_node(&self, id: NodeId) -> Result<()> {
227        let result = sqlx::query("DELETE FROM nodes WHERE id = $1")
228            .bind(id)
229            .execute(&self.pool)
230            .await
231            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
232
233        if result.rows_affected() == 0 {
234            return Err(StorageError::NodeNotFound(id));
235        }
236
237        Ok(())
238    }
239
240    async fn query_nodes(&self, query: &GraphQuery) -> Result<Vec<Node>> {
241        let mut sql = String::from("SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1");
242        
243        // Handle node_types with IN clause instead of ANY for better compatibility
244        if let Some(ref types) = query.node_types {
245            if types.len() == 1 {
246                // Single type - use direct equality
247                sql.push_str(&format!(" AND node_type = '{}'", types[0]));
248            } else if !types.is_empty() {
249                // Multiple types - use IN clause
250                let type_list: Vec<String> = types.iter()
251                    .map(|t| format!("'{}'", t.replace("'", "''")))
252                    .collect();
253                sql.push_str(&format!(" AND node_type IN ({})", type_list.join(", ")));
254            }
255        }
256
257        sql.push_str(" ORDER BY created_at DESC");
258
259        if let Some(limit) = query.limit {
260            sql.push_str(&format!(" LIMIT {}", limit));
261        }
262
263        let rows = sqlx::query(&sql)
264            .fetch_all(&self.pool)
265            .await
266            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
267
268        let mut nodes = Vec::new();
269        for row in rows {
270            let properties_json: serde_json::Value = row.try_get("properties")
271                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
272            let properties = serde_json::from_value(properties_json)
273                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
274
275            nodes.push(Node {
276                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
277                node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
278                properties,
279                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
280                updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
281            });
282        }
283
284        Ok(nodes)
285    }
286
287    async fn create_edge(&self, edge: &Edge) -> Result<Edge> {
288        let properties_json = serde_json::to_value(&edge.properties)
289            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
290
291        sqlx::query(
292            r#"
293            INSERT INTO edges (id, edge_type, from_node_id, to_node_id, properties, created_at)
294            VALUES ($1, $2, $3, $4, $5, $6)
295            "#
296        )
297        .bind(edge.id)
298        .bind(&edge.edge_type)
299        .bind(edge.from_node_id)
300        .bind(edge.to_node_id)
301        .bind(properties_json)
302        .bind(edge.created_at)
303        .execute(&self.pool)
304        .await
305        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
306
307        Ok(edge.clone())
308    }
309
310    async fn get_edges_from(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
311        let rows = if let Some(et) = edge_type {
312            sqlx::query(
313                r#"
314                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
315                FROM edges
316                WHERE from_node_id = $1 AND edge_type = $2
317                ORDER BY created_at DESC
318                "#
319            )
320            .bind(node_id)
321            .bind(et)
322            .fetch_all(&self.pool)
323            .await
324        } else {
325            sqlx::query(
326                r#"
327                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
328                FROM edges
329                WHERE from_node_id = $1
330                ORDER BY created_at DESC
331                "#
332            )
333            .bind(node_id)
334            .fetch_all(&self.pool)
335            .await
336        }
337        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
338
339        let mut edges = Vec::new();
340        for row in rows {
341            let properties_json: serde_json::Value = row.try_get("properties")
342                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
343            let properties = serde_json::from_value(properties_json)
344                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
345
346            edges.push(Edge {
347                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
348                edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
349                from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
350                to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
351                properties,
352                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
353            });
354        }
355
356        Ok(edges)
357    }
358
359    async fn get_edges_to(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
360        let rows = if let Some(et) = edge_type {
361            sqlx::query(
362                r#"
363                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
364                FROM edges
365                WHERE to_node_id = $1 AND edge_type = $2
366                ORDER BY created_at DESC
367                "#
368            )
369            .bind(node_id)
370            .bind(et)
371            .fetch_all(&self.pool)
372            .await
373        } else {
374            sqlx::query(
375                r#"
376                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
377                FROM edges
378                WHERE to_node_id = $1
379                ORDER BY created_at DESC
380                "#
381            )
382            .bind(node_id)
383            .fetch_all(&self.pool)
384            .await
385        }
386        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
387
388        let mut edges = Vec::new();
389        for row in rows {
390            let properties_json: serde_json::Value = row.try_get("properties")
391                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
392            let properties = serde_json::from_value(properties_json)
393                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
394
395            edges.push(Edge {
396                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
397                edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
398                from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
399                to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
400                properties,
401                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
402            });
403        }
404
405        Ok(edges)
406    }
407
408    async fn get_neighbors(
409        &self,
410        node_id: NodeId,
411        edge_type: Option<&str>,
412        direction: EdgeDirection,
413    ) -> Result<Vec<Node>> {
414        let mut neighbors = Vec::new();
415
416        match direction {
417            EdgeDirection::Outgoing => {
418                let edges = self.get_edges_from(node_id, edge_type).await?;
419                for edge in edges {
420                    if let Ok(node) = self.get_node(edge.to_node_id).await {
421                        neighbors.push(node);
422                    }
423                }
424            }
425            _ => {}
426        }
427
428        match direction {
429            EdgeDirection::Incoming => {
430                let edges = self.get_edges_to(node_id, edge_type).await?;
431                for edge in edges {
432                    if let Ok(node) = self.get_node(edge.from_node_id).await {
433                        neighbors.push(node);
434                    }
435                }
436            }
437            _ => {}
438        }
439
440        Ok(neighbors)
441    }
442
443    async fn search_nodes(&self, query: &SearchQuery) -> Result<SearchResults<Node>> {
444        let offset = query.offset;
445        let limit = query.limit;
446        
447        // Build the SQL query
448        let mut sql = String::from(
449            "SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1"
450        );
451        
452        // Add node type filters
453        if !query.node_types.is_empty() {
454            let types: Vec<String> = query.node_types.iter()
455                .map(|t| format!("'{}'", t.replace("'", "''")))
456                .collect();
457            sql.push_str(&format!(" AND node_type IN ({})", types.join(", ")));
458        }
459        
460        // Add text search filter (case-insensitive LIKE on properties)
461        if let Some(ref search_text) = query.search_text {
462            let escaped = search_text.replace("'", "''").replace("%", "\\%").replace("_", "\\_");
463            sql.push_str(&format!(
464                " AND properties::text ILIKE '%{}%'",
465                escaped
466            ));
467        }
468        
469        // Add time range filters
470        if let Some(after) = query.created_after {
471            sql.push_str(&format!(" AND created_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
472        }
473        if let Some(before) = query.created_before {
474            sql.push_str(&format!(" AND created_at <= '{}'", before.format("%Y-%m-%d %H:%M:%S")));
475        }
476        if let Some(after) = query.updated_after {
477            sql.push_str(&format!(" AND updated_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
478        }
479        
480        // Add property filters
481        for (key, value) in &query.property_filters {
482            let escaped_key = key.replace("'", "''");
483            let escaped_value = value.replace("'", "''");
484            sql.push_str(&format!(
485                " AND properties->>'{}' = '{}'",
486                escaped_key, escaped_value
487            ));
488        }
489        
490        // Add ordering by updated_at
491        sql.push_str(" ORDER BY updated_at DESC");
492        
493        // Add limit and offset
494        sql.push_str(&format!(" LIMIT {} OFFSET {}", limit + 1, offset));
495        
496        // Execute query
497        let rows = sqlx::query(&sql)
498            .fetch_all(&self.pool)
499            .await
500            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
501        
502        // Parse results
503        let mut nodes = Vec::new();
504        
505        for row in rows.into_iter().take(limit) {
506            let properties_json: serde_json::Value = row.try_get("properties")
507                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
508            let properties: Properties = serde_json::from_value(properties_json)
509                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
510            
511            nodes.push(Node {
512                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
513                node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
514                properties,
515                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
516                updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
517            });
518        }
519        
520        Ok(SearchResults {
521            items: nodes,
522        })
523    }
524}