Skip to main content

agent_office/storage/
postgres.rs

1use crate::domain::{Edge, GraphQuery, Node, NodeId, Properties};
2use crate::storage::{EdgeDirection, GraphStorage, Result, StorageError, SearchQuery, SearchResults};
3use async_trait::async_trait;
4use sqlx::{Pool, Postgres, Row};
5
6pub struct PostgresStorage {
7    pool: Pool<Postgres>,
8}
9
10impl PostgresStorage {
11    pub fn new(pool: Pool<Postgres>) -> Self {
12        Self { pool }
13    }
14
15    pub async fn setup_tables(&self) -> Result<()> {
16        // Execute each statement separately since SQLx doesn't support multiple statements in one query
17        
18        // Drop existing tables
19        sqlx::query("DROP TABLE IF EXISTS edges CASCADE")
20            .execute(&self.pool)
21            .await
22            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
23            
24        sqlx::query("DROP TABLE IF EXISTS nodes CASCADE")
25            .execute(&self.pool)
26            .await
27            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
28
29        // Create nodes table
30        sqlx::query(
31            r#"
32            CREATE TABLE nodes (
33                id UUID PRIMARY KEY,
34                node_type VARCHAR(255) NOT NULL,
35                properties JSONB NOT NULL DEFAULT '{}',
36                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
37                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
38            )
39            "#
40        )
41        .execute(&self.pool)
42        .await
43        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
44
45        // Create edges table
46        sqlx::query(
47            r#"
48            CREATE TABLE edges (
49                id UUID PRIMARY KEY,
50                edge_type VARCHAR(255) NOT NULL,
51                from_node_id UUID NOT NULL,
52                to_node_id UUID NOT NULL,
53                properties JSONB NOT NULL DEFAULT '{}',
54                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
55                FOREIGN KEY (from_node_id) REFERENCES nodes(id) ON DELETE CASCADE,
56                FOREIGN KEY (to_node_id) REFERENCES nodes(id) ON DELETE CASCADE
57            )
58            "#
59        )
60        .execute(&self.pool)
61        .await
62        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
63
64        // Create indexes for basic lookups
65        sqlx::query("CREATE INDEX idx_nodes_type ON nodes(node_type)")
66            .execute(&self.pool)
67            .await
68            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
69            
70        sqlx::query("CREATE INDEX idx_nodes_properties ON nodes USING GIN(properties)")
71            .execute(&self.pool)
72            .await
73            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
74        
75        // Create time-based indexes for recency searches
76        sqlx::query("CREATE INDEX idx_nodes_created_at ON nodes(created_at DESC)")
77            .execute(&self.pool)
78            .await
79            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
80            
81        sqlx::query("CREATE INDEX idx_nodes_updated_at ON nodes(updated_at DESC)")
82            .execute(&self.pool)
83            .await
84            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
85        
86        // Create composite indexes for time range + type queries
87        sqlx::query("CREATE INDEX idx_nodes_type_created ON nodes(node_type, created_at DESC)")
88            .execute(&self.pool)
89            .await
90            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
91            
92        sqlx::query("CREATE INDEX idx_nodes_type_updated ON nodes(node_type, updated_at DESC)")
93            .execute(&self.pool)
94            .await
95            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
96            
97        sqlx::query("CREATE INDEX idx_edges_type ON edges(edge_type)")
98            .execute(&self.pool)
99            .await
100            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
101            
102        sqlx::query("CREATE INDEX idx_edges_from ON edges(from_node_id)")
103            .execute(&self.pool)
104            .await
105            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
106            
107        sqlx::query("CREATE INDEX idx_edges_to ON edges(to_node_id)")
108            .execute(&self.pool)
109            .await
110            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
111            
112        sqlx::query("CREATE INDEX idx_edges_from_type ON edges(from_node_id, edge_type)")
113            .execute(&self.pool)
114            .await
115            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
116            
117        sqlx::query("CREATE INDEX idx_edges_to_type ON edges(to_node_id, edge_type)")
118            .execute(&self.pool)
119            .await
120            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
121        
122        // Create indexes on edge timestamps
123        sqlx::query("CREATE INDEX idx_edges_created_at ON edges(created_at DESC)")
124            .execute(&self.pool)
125            .await
126            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
127
128        Ok(())
129    }
130
131    /// Migrate schedules table - creates if not exists (preserves existing data)
132    pub async fn migrate_schedules_table(&self) -> Result<()> {
133        // Create schedules table if not exists
134        sqlx::query(
135            r#"
136            CREATE TABLE IF NOT EXISTS schedules (
137                id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
138                agent_id VARCHAR(255) NOT NULL,
139                cron_expression VARCHAR(255) NOT NULL,
140                action TEXT NOT NULL,
141                is_active BOOLEAN NOT NULL DEFAULT true,
142                last_fired_at TIMESTAMP WITH TIME ZONE,
143                created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
144                updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW()
145            )
146            "#,
147        )
148        .execute(&self.pool)
149        .await
150        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
151
152        // Create indexes
153        sqlx::query(
154            "CREATE INDEX IF NOT EXISTS idx_schedules_agent_id ON schedules(agent_id)"
155        )
156        .execute(&self.pool)
157        .await
158        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
159
160        sqlx::query(
161            "CREATE INDEX IF NOT EXISTS idx_schedules_active ON schedules(is_active)"
162        )
163        .execute(&self.pool)
164        .await
165        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
166
167        Ok(())
168    }
169    
170    /// Helper function to convert properties JSONB to searchable text
171    #[allow(dead_code)]
172    fn properties_to_search_text(properties: &Properties) -> String {
173        let mut texts = Vec::new();
174        for (_, value) in properties {
175            if let serde_json::Value::String(s) = serde_json::to_value(value).unwrap_or_default() {
176                texts.push(s);
177            }
178        }
179        texts.join(" ")
180    }
181}
182
183#[async_trait]
184impl GraphStorage for PostgresStorage {
185    async fn create_node(&self, node: &Node) -> Result<Node> {
186        let properties_json = serde_json::to_value(&node.properties)
187            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
188
189        sqlx::query(
190            r#"
191            INSERT INTO nodes (id, node_type, properties, created_at, updated_at)
192            VALUES ($1, $2, $3, $4, $5)
193            "#
194        )
195        .bind(node.id)
196        .bind(&node.node_type)
197        .bind(properties_json)
198        .bind(node.created_at)
199        .bind(node.updated_at)
200        .execute(&self.pool)
201        .await
202        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
203
204        Ok(node.clone())
205    }
206
207    async fn get_node(&self, id: NodeId) -> Result<Node> {
208        let row = sqlx::query(
209            r#"
210            SELECT id, node_type, properties, created_at, updated_at
211            FROM nodes
212            WHERE id = $1
213            "#
214        )
215        .bind(id)
216        .fetch_optional(&self.pool)
217        .await
218        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
219
220        match row {
221            Some(row) => {
222                let properties_json: serde_json::Value = row.try_get("properties")
223                    .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
224                let properties = serde_json::from_value(properties_json)
225                    .map_err(|e| StorageError::SerializationError(e.to_string()))?;
226
227                Ok(Node {
228                    id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
229                    node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
230                    properties,
231                    created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
232                    updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
233                })
234            }
235            None => Err(StorageError::NodeNotFound(id)),
236        }
237    }
238
239    async fn update_node(&self, node: &Node) -> Result<Node> {
240        let properties_json = serde_json::to_value(&node.properties)
241            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
242
243        let result = sqlx::query(
244            r#"
245            UPDATE nodes
246            SET node_type = $2, properties = $3, updated_at = $4
247            WHERE id = $1
248            "#
249        )
250        .bind(node.id)
251        .bind(&node.node_type)
252        .bind(properties_json)
253        .bind(node.updated_at)
254        .execute(&self.pool)
255        .await
256        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
257
258        if result.rows_affected() == 0 {
259            return Err(StorageError::NodeNotFound(node.id));
260        }
261
262        Ok(node.clone())
263    }
264
265    async fn delete_node(&self, id: NodeId) -> Result<()> {
266        let result = sqlx::query("DELETE FROM nodes WHERE id = $1")
267            .bind(id)
268            .execute(&self.pool)
269            .await
270            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
271
272        if result.rows_affected() == 0 {
273            return Err(StorageError::NodeNotFound(id));
274        }
275
276        Ok(())
277    }
278
279    async fn query_nodes(&self, query: &GraphQuery) -> Result<Vec<Node>> {
280        let mut sql = String::from("SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1");
281        
282        // Handle node_types with IN clause instead of ANY for better compatibility
283        if let Some(ref types) = query.node_types {
284            if types.len() == 1 {
285                // Single type - use direct equality
286                sql.push_str(&format!(" AND node_type = '{}'", types[0]));
287            } else if !types.is_empty() {
288                // Multiple types - use IN clause
289                let type_list: Vec<String> = types.iter()
290                    .map(|t| format!("'{}'", t.replace("'", "''")))
291                    .collect();
292                sql.push_str(&format!(" AND node_type IN ({})", type_list.join(", ")));
293            }
294        }
295
296        sql.push_str(" ORDER BY created_at DESC");
297
298        if let Some(limit) = query.limit {
299            sql.push_str(&format!(" LIMIT {}", limit));
300        }
301
302        let rows = sqlx::query(&sql)
303            .fetch_all(&self.pool)
304            .await
305            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
306
307        let mut nodes = Vec::new();
308        for row in rows {
309            let properties_json: serde_json::Value = row.try_get("properties")
310                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
311            let properties = serde_json::from_value(properties_json)
312                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
313
314            nodes.push(Node {
315                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
316                node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
317                properties,
318                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
319                updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
320            });
321        }
322
323        Ok(nodes)
324    }
325
326    async fn create_edge(&self, edge: &Edge) -> Result<Edge> {
327        let properties_json = serde_json::to_value(&edge.properties)
328            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
329
330        sqlx::query(
331            r#"
332            INSERT INTO edges (id, edge_type, from_node_id, to_node_id, properties, created_at)
333            VALUES ($1, $2, $3, $4, $5, $6)
334            "#
335        )
336        .bind(edge.id)
337        .bind(&edge.edge_type)
338        .bind(edge.from_node_id)
339        .bind(edge.to_node_id)
340        .bind(properties_json)
341        .bind(edge.created_at)
342        .execute(&self.pool)
343        .await
344        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
345
346        Ok(edge.clone())
347    }
348
349    async fn get_edges_from(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
350        let rows = if let Some(et) = edge_type {
351            sqlx::query(
352                r#"
353                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
354                FROM edges
355                WHERE from_node_id = $1 AND edge_type = $2
356                ORDER BY created_at DESC
357                "#
358            )
359            .bind(node_id)
360            .bind(et)
361            .fetch_all(&self.pool)
362            .await
363        } else {
364            sqlx::query(
365                r#"
366                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
367                FROM edges
368                WHERE from_node_id = $1
369                ORDER BY created_at DESC
370                "#
371            )
372            .bind(node_id)
373            .fetch_all(&self.pool)
374            .await
375        }
376        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
377
378        let mut edges = Vec::new();
379        for row in rows {
380            let properties_json: serde_json::Value = row.try_get("properties")
381                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
382            let properties = serde_json::from_value(properties_json)
383                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
384
385            edges.push(Edge {
386                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
387                edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
388                from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
389                to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
390                properties,
391                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
392            });
393        }
394
395        Ok(edges)
396    }
397
398    async fn get_edges_to(&self, node_id: NodeId, edge_type: Option<&str>) -> Result<Vec<Edge>> {
399        let rows = if let Some(et) = edge_type {
400            sqlx::query(
401                r#"
402                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
403                FROM edges
404                WHERE to_node_id = $1 AND edge_type = $2
405                ORDER BY created_at DESC
406                "#
407            )
408            .bind(node_id)
409            .bind(et)
410            .fetch_all(&self.pool)
411            .await
412        } else {
413            sqlx::query(
414                r#"
415                SELECT id, edge_type, from_node_id, to_node_id, properties, created_at
416                FROM edges
417                WHERE to_node_id = $1
418                ORDER BY created_at DESC
419                "#
420            )
421            .bind(node_id)
422            .fetch_all(&self.pool)
423            .await
424        }
425        .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
426
427        let mut edges = Vec::new();
428        for row in rows {
429            let properties_json: serde_json::Value = row.try_get("properties")
430                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
431            let properties = serde_json::from_value(properties_json)
432                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
433
434            edges.push(Edge {
435                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
436                edge_type: row.try_get("edge_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
437                from_node_id: row.try_get("from_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
438                to_node_id: row.try_get("to_node_id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
439                properties,
440                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
441            });
442        }
443
444        Ok(edges)
445    }
446
447    async fn get_neighbors(
448        &self,
449        node_id: NodeId,
450        edge_type: Option<&str>,
451        direction: EdgeDirection,
452    ) -> Result<Vec<Node>> {
453        let mut neighbors = Vec::new();
454
455        match direction {
456            EdgeDirection::Outgoing => {
457                let edges = self.get_edges_from(node_id, edge_type).await?;
458                for edge in edges {
459                    if let Ok(node) = self.get_node(edge.to_node_id).await {
460                        neighbors.push(node);
461                    }
462                }
463            }
464            _ => {}
465        }
466
467        match direction {
468            EdgeDirection::Incoming => {
469                let edges = self.get_edges_to(node_id, edge_type).await?;
470                for edge in edges {
471                    if let Ok(node) = self.get_node(edge.from_node_id).await {
472                        neighbors.push(node);
473                    }
474                }
475            }
476            _ => {}
477        }
478
479        Ok(neighbors)
480    }
481
482    async fn search_nodes(&self, query: &SearchQuery) -> Result<SearchResults<Node>> {
483        let offset = query.offset;
484        let limit = query.limit;
485        
486        // Build the SQL query
487        let mut sql = String::from(
488            "SELECT id, node_type, properties, created_at, updated_at FROM nodes WHERE 1=1"
489        );
490        
491        // Add node type filters
492        if !query.node_types.is_empty() {
493            let types: Vec<String> = query.node_types.iter()
494                .map(|t| format!("'{}'", t.replace("'", "''")))
495                .collect();
496            sql.push_str(&format!(" AND node_type IN ({})", types.join(", ")));
497        }
498        
499        // Add text search filter (case-insensitive LIKE on properties)
500        if let Some(ref search_text) = query.search_text {
501            let escaped = search_text.replace("'", "''").replace("%", "\\%").replace("_", "\\_");
502            sql.push_str(&format!(
503                " AND properties::text ILIKE '%{}%'",
504                escaped
505            ));
506        }
507        
508        // Add time range filters
509        if let Some(after) = query.created_after {
510            sql.push_str(&format!(" AND created_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
511        }
512        if let Some(before) = query.created_before {
513            sql.push_str(&format!(" AND created_at <= '{}'", before.format("%Y-%m-%d %H:%M:%S")));
514        }
515        if let Some(after) = query.updated_after {
516            sql.push_str(&format!(" AND updated_at >= '{}'", after.format("%Y-%m-%d %H:%M:%S")));
517        }
518        
519        // Add property filters
520        for (key, value) in &query.property_filters {
521            let escaped_key = key.replace("'", "''");
522            let escaped_value = value.replace("'", "''");
523            sql.push_str(&format!(
524                " AND properties->>'{}' = '{}'",
525                escaped_key, escaped_value
526            ));
527        }
528        
529        // Add ordering by updated_at
530        sql.push_str(" ORDER BY updated_at DESC");
531        
532        // Add limit and offset
533        sql.push_str(&format!(" LIMIT {} OFFSET {}", limit + 1, offset));
534        
535        // Execute query
536        let rows = sqlx::query(&sql)
537            .fetch_all(&self.pool)
538            .await
539            .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
540        
541        // Parse results
542        let mut nodes = Vec::new();
543        
544        for row in rows.into_iter().take(limit) {
545            let properties_json: serde_json::Value = row.try_get("properties")
546                .map_err(|e| StorageError::DatabaseError(e.to_string()))?;
547            let properties: Properties = serde_json::from_value(properties_json)
548                .map_err(|e| StorageError::SerializationError(e.to_string()))?;
549            
550            nodes.push(Node {
551                id: row.try_get("id").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
552                node_type: row.try_get("node_type").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
553                properties,
554                created_at: row.try_get("created_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
555                updated_at: row.try_get("updated_at").map_err(|e| StorageError::DatabaseError(e.to_string()))?,
556            });
557        }
558        
559        Ok(SearchResults {
560            items: nodes,
561        })
562    }
563}