// codex_memory/database_setup.rs

use std::fmt;

use anyhow::{Context, Result};
use tokio_postgres::{Config as PgConfig, NoTls};
use tracing::{error, info};
use url::Url;

/// Provisions the PostgreSQL backend for the memory store.
///
/// It checks connectivity, creates the database, installs pgvector, runs the
/// schema migrations and performs health checks, all driven from one
/// connection string.
pub struct DatabaseSetup {
    // Raw connection string as supplied by the caller (e.g. the DATABASE_URL value).
    database_url: String,
}

11impl DatabaseSetup {
12    pub fn new(database_url: String) -> Self {
13        Self { database_url }
14    }
15
16    /// Complete database setup process
17    pub async fn setup(&self) -> Result<()> {
18        info!("๐Ÿ—„๏ธ  Starting database setup...");
19
20        // 1. Parse and validate the database URL
21        let db_info = self.parse_database_url()?;
22        info!("Database: {} on {}:{}", db_info.database, db_info.host, db_info.port);
23
24        // 2. Check if PostgreSQL is running
25        self.check_postgresql_running(&db_info).await?;
26
27        // 3. Check if the database exists, create if not
28        self.ensure_database_exists(&db_info).await?;
29
30        // 4. Check for pgvector extension availability
31        self.check_pgvector_availability(&db_info).await?;
32
33        // 5. Install pgvector extension
34        self.install_pgvector_extension().await?;
35
36        // 6. Run migrations
37        self.run_migrations().await?;
38
39        // 7. Verify setup
40        self.verify_setup().await?;
41
42        info!("โœ… Database setup completed successfully!");
43        Ok(())
44    }
45
46    /// Parse database URL and extract connection info
47    fn parse_database_url(&self) -> Result<DatabaseInfo> {
48        let url = Url::parse(&self.database_url)
49            .context("Invalid database URL format")?;
50
51        if url.scheme() != "postgresql" && url.scheme() != "postgres" {
52            return Err(anyhow::anyhow!("Database URL must use postgresql:// or postgres:// scheme"));
53        }
54
55        let host = url.host_str()
56            .ok_or_else(|| anyhow::anyhow!("Database URL missing host"))?
57            .to_string();
58
59        let port = url.port().unwrap_or(5432);
60
61        let username = url.username();
62        if username.is_empty() {
63            return Err(anyhow::anyhow!("Database URL missing username"));
64        }
65
66        let password = url.password().unwrap_or("");
67
68        let database = url.path().trim_start_matches('/');
69        if database.is_empty() {
70            return Err(anyhow::anyhow!("Database URL missing database name"));
71        }
72
73        Ok(DatabaseInfo {
74            host,
75            port,
76            username: username.to_string(),
77            password: password.to_string(),
78            database: database.to_string(),
79        })
80    }
81
82    /// Check if PostgreSQL is running and accessible
83    async fn check_postgresql_running(&self, db_info: &DatabaseInfo) -> Result<()> {
84        info!("๐Ÿ” Checking PostgreSQL connectivity...");
85
86        // Try to connect to the 'postgres' system database first
87        let system_url = format!(
88            "postgresql://{}:{}@{}:{}/postgres",
89            db_info.username, db_info.password, db_info.host, db_info.port
90        );
91
92        let config: PgConfig = system_url.parse()
93            .context("Failed to parse system database URL")?;
94
95        match config.connect(NoTls).await {
96            Ok((client, connection)) => {
97                tokio::spawn(async move {
98                    if let Err(e) = connection.await {
99                        error!("System database connection error: {}", e);
100                    }
101                });
102
103                // Test basic connectivity
104                client.query("SELECT version()", &[]).await
105                    .context("Failed to query PostgreSQL version")?;
106
107                info!("โœ… PostgreSQL is running and accessible");
108                Ok(())
109            }
110            Err(e) => {
111                error!("โŒ Cannot connect to PostgreSQL: {}", e);
112                info!("๐Ÿ’ก Please ensure PostgreSQL is installed and running");
113                info!("๐Ÿ’ก Common solutions:");
114                info!("   - Start PostgreSQL: brew services start postgresql");
115                info!("   - Or: sudo systemctl start postgresql");
116                info!("   - Check connection details in DATABASE_URL");
117                Err(anyhow::anyhow!("PostgreSQL is not accessible: {}", e))
118            }
119        }
120    }
121
122    /// Ensure the target database exists, create if necessary
123    async fn ensure_database_exists(&self, db_info: &DatabaseInfo) -> Result<()> {
124        info!("๐Ÿ” Checking if database '{}' exists...", db_info.database);
125
126        // Connect to system database to check/create target database
127        let system_url = format!(
128            "postgresql://{}:{}@{}:{}/postgres",
129            db_info.username, db_info.password, db_info.host, db_info.port
130        );
131
132        let config: PgConfig = system_url.parse()?;
133        let (client, connection) = config.connect(NoTls).await?;
134
135        tokio::spawn(async move {
136            if let Err(e) = connection.await {
137                error!("System database connection error: {}", e);
138            }
139        });
140
141        // Check if database exists
142        let rows = client
143            .query(
144                "SELECT 1 FROM pg_database WHERE datname = $1",
145                &[&db_info.database],
146            )
147            .await?;
148
149        if rows.is_empty() {
150            info!("๐Ÿ“‹ Database '{}' does not exist, creating...", db_info.database);
151            
152            // Create the database
153            let create_query = format!("CREATE DATABASE \"{}\"", db_info.database);
154            client.execute(&create_query, &[]).await
155                .context("Failed to create database")?;
156
157            info!("โœ… Database '{}' created successfully", db_info.database);
158        } else {
159            info!("โœ… Database '{}' already exists", db_info.database);
160        }
161
162        Ok(())
163    }
164
165    /// Check if pgvector extension is available
166    async fn check_pgvector_availability(&self, _db_info: &DatabaseInfo) -> Result<()> {
167        info!("๐Ÿ” Checking pgvector extension availability...");
168
169        let config: PgConfig = self.database_url.parse()?;
170        let (client, connection) = config.connect(NoTls).await?;
171
172        tokio::spawn(async move {
173            if let Err(e) = connection.await {
174                error!("Database connection error: {}", e);
175            }
176        });
177
178        // Check if pgvector is available in pg_available_extensions
179        let rows = client
180            .query(
181                "SELECT name, default_version FROM pg_available_extensions WHERE name = 'vector'",
182                &[],
183            )
184            .await?;
185
186        if rows.is_empty() {
187            error!("โŒ pgvector extension is not available");
188            info!("๐Ÿ’ก Please install pgvector extension:");
189            info!("   ๐Ÿ“‹ On macOS (Homebrew): brew install pgvector");
190            info!("   ๐Ÿ“‹ On Ubuntu/Debian: apt install postgresql-15-pgvector");
191            info!("   ๐Ÿ“‹ From source: https://github.com/pgvector/pgvector");
192            return Err(anyhow::anyhow!("pgvector extension not available"));
193        } else {
194            let row = &rows[0];
195            let version: String = row.get(1);
196            info!("โœ… pgvector extension available (version: {})", version);
197        }
198
199        Ok(())
200    }
201
202    /// Install pgvector extension in the database
203    async fn install_pgvector_extension(&self) -> Result<()> {
204        info!("๐Ÿ”ง Installing pgvector extension...");
205
206        let config: PgConfig = self.database_url.parse()?;
207        let (client, connection) = config.connect(NoTls).await?;
208
209        tokio::spawn(async move {
210            if let Err(e) = connection.await {
211                error!("Database connection error: {}", e);
212            }
213        });
214
215        // Check if extension is already installed
216        let rows = client
217            .query(
218                "SELECT extname FROM pg_extension WHERE extname = 'vector'",
219                &[],
220            )
221            .await?;
222
223        if !rows.is_empty() {
224            info!("โœ… pgvector extension already installed");
225            return Ok(());
226        }
227
228        // Install the extension
229        client
230            .execute("CREATE EXTENSION vector", &[])
231            .await
232            .context("Failed to install pgvector extension")?;
233
234        info!("โœ… pgvector extension installed successfully");
235
236        // Verify installation by testing basic functionality
237        client
238            .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
239            .await
240            .context("pgvector extension installation verification failed")?;
241
242        info!("โœ… pgvector extension verification passed");
243        Ok(())
244    }
245
246    /// Run database migrations
247    async fn run_migrations(&self) -> Result<()> {
248        info!("๐Ÿ“‹ Running database migrations...");
249
250        let config: PgConfig = self.database_url.parse()?;
251        let (client, connection) = config.connect(NoTls).await?;
252
253        tokio::spawn(async move {
254            if let Err(e) = connection.await {
255                error!("Database connection error: {}", e);
256            }
257        });
258
259        // Check if migration tracking table exists
260        let tracking_exists = client
261            .query(
262                "SELECT 1 FROM information_schema.tables WHERE table_name = 'migration_history'",
263                &[],
264            )
265            .await?;
266
267        if tracking_exists.is_empty() {
268            info!("๐Ÿ“‹ Creating migration tracking table...");
269            client
270                .execute(
271                    r#"
272                    CREATE TABLE migration_history (
273                        id SERIAL PRIMARY KEY,
274                        migration_name VARCHAR(255) NOT NULL UNIQUE,
275                        applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
276                    )
277                    "#,
278                    &[],
279                )
280                .await?;
281        }
282
283        // Check if main tables exist
284        let tables_exist = client
285            .query(
286                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('memories', 'memory_tiers')",
287                &[],
288            )
289            .await?;
290
291        if tables_exist.len() < 2 {
292            info!("๐Ÿ“‹ Creating main schema tables...");
293            self.create_main_schema(&client).await?;
294        } else {
295            info!("โœ… Main schema tables already exist");
296        }
297
298        Ok(())
299    }
300
301    /// Create the main database schema
302    async fn create_main_schema(&self, client: &tokio_postgres::Client) -> Result<()> {
303        info!("๐Ÿ“‹ Creating main database schema...");
304
305        // Create memories table
306        client
307            .execute(
308                r#"
309                CREATE TABLE IF NOT EXISTS memories (
310                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
311                    content TEXT NOT NULL,
312                    embedding VECTOR(768),
313                    metadata JSONB DEFAULT '{}',
314                    tier VARCHAR(20) NOT NULL DEFAULT 'working',
315                    importance_score FLOAT DEFAULT 0.0,
316                    access_count INTEGER DEFAULT 0,
317                    last_accessed TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
318                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
319                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
320                )
321                "#,
322                &[],
323            )
324            .await
325            .context("Failed to create memories table")?;
326
327        // Create index on embedding for vector similarity search
328        client
329            .execute(
330                "CREATE INDEX IF NOT EXISTS memories_embedding_idx ON memories USING hnsw (embedding vector_cosine_ops)",
331                &[],
332            )
333            .await
334            .context("Failed to create embedding index")?;
335
336        // Create indexes for common queries
337        client
338            .execute(
339                "CREATE INDEX IF NOT EXISTS memories_tier_idx ON memories (tier)",
340                &[],
341            )
342            .await
343            .context("Failed to create tier index")?;
344
345        client
346            .execute(
347                "CREATE INDEX IF NOT EXISTS memories_last_accessed_idx ON memories (last_accessed DESC)",
348                &[],
349            )
350            .await
351            .context("Failed to create last_accessed index")?;
352
353        client
354            .execute(
355                "CREATE INDEX IF NOT EXISTS memories_importance_idx ON memories (importance_score DESC)",
356                &[],
357            )
358            .await
359            .context("Failed to create importance index")?;
360
361        // Create memory_tiers table for tier management
362        client
363            .execute(
364                r#"
365                CREATE TABLE IF NOT EXISTS memory_tiers (
366                    id SERIAL PRIMARY KEY,
367                    tier_name VARCHAR(20) NOT NULL UNIQUE,
368                    max_capacity INTEGER,
369                    current_count INTEGER DEFAULT 0,
370                    retention_days INTEGER,
371                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
372                )
373                "#,
374                &[],
375            )
376            .await
377            .context("Failed to create memory_tiers table")?;
378
379        // Insert default tiers
380        client
381            .execute(
382                r#"
383                INSERT INTO memory_tiers (tier_name, max_capacity, retention_days)
384                VALUES 
385                    ('working', 1000, 7),
386                    ('warm', 10000, 30),
387                    ('cold', NULL, NULL)
388                ON CONFLICT (tier_name) DO NOTHING
389                "#,
390                &[],
391            )
392            .await
393            .context("Failed to insert default tiers")?;
394
395        // Record migration
396        client
397            .execute(
398                "INSERT INTO migration_history (migration_name) VALUES ('001_initial_schema') ON CONFLICT (migration_name) DO NOTHING",
399                &[],
400            )
401            .await?;
402
403        info!("โœ… Main database schema created successfully");
404        Ok(())
405    }
406
407    /// Verify the complete database setup
408    async fn verify_setup(&self) -> Result<()> {
409        info!("๐Ÿ” Verifying database setup...");
410
411        let config: PgConfig = self.database_url.parse()?;
412        let (client, connection) = config.connect(NoTls).await?;
413
414        tokio::spawn(async move {
415            if let Err(e) = connection.await {
416                error!("Database connection error: {}", e);
417            }
418        });
419
420        // Check that all required tables exist
421        let required_tables = ["memories", "memory_tiers", "migration_history"];
422        for table in &required_tables {
423            let rows = client
424                .query(
425                    "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
426                    &[table],
427                )
428                .await?;
429
430            if rows.is_empty() {
431                return Err(anyhow::anyhow!("Required table '{}' not found", table));
432            }
433        }
434
435        // Check that pgvector extension is working
436        client
437            .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
438            .await
439            .context("pgvector extension not working")?;
440
441        // Test inserting and querying a sample memory
442        info!("๐Ÿงช Testing vector operations...");
443        
444        // Insert a test memory using 768-dimensional vector (matching schema)
445        let test_vector = vec![0.1f32; 768].iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",");
446        client
447            .execute(
448                &format!("INSERT INTO memories (content, embedding) VALUES ($1, '[{}]'::vector) ON CONFLICT DO NOTHING", test_vector),
449                &[&"Setup test memory"],
450            )
451            .await
452            .context("Failed to insert test memory")?;
453
454        // Test vector similarity search using 768-dimensional vector
455        let query_vector = vec![0.1f32; 768].iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",");
456        client
457            .query(
458                &format!("SELECT content FROM memories ORDER BY embedding <-> '[{}]'::vector LIMIT 1", query_vector),
459                &[],
460            )
461            .await
462            .context("Failed to perform vector similarity search")?;
463
464        // Clean up test data
465        client
466            .execute(
467                "DELETE FROM memories WHERE content = 'Setup test memory'",
468                &[],
469            )
470            .await?;
471
472        info!("โœ… Database setup verification passed");
473        Ok(())
474    }
475
476    /// Quick database health check
477    pub async fn health_check(&self) -> Result<DatabaseHealth> {
478        let config: PgConfig = self.database_url.parse()?;
479        let (client, connection) = config.connect(NoTls).await?;
480
481        tokio::spawn(async move {
482            if let Err(e) = connection.await {
483                error!("Database connection error: {}", e);
484            }
485        });
486
487        let mut health = DatabaseHealth::default();
488
489        // Check basic connectivity
490        match client.query("SELECT 1", &[]).await {
491            Ok(_) => health.connectivity = true,
492            Err(e) => {
493                health.connectivity = false;
494                health.issues.push(format!("Connectivity failed: {}", e));
495            }
496        }
497
498        // Check pgvector extension
499        match client.query("SELECT 1 FROM pg_extension WHERE extname = 'vector'", &[]).await {
500            Ok(rows) => {
501                health.pgvector_installed = !rows.is_empty();
502                if !health.pgvector_installed {
503                    health.issues.push("pgvector extension not installed".to_string());
504                }
505            }
506            Err(e) => {
507                health.issues.push(format!("Failed to check pgvector: {}", e));
508            }
509        }
510
511        // Check required tables
512        let required_tables = ["memories", "memory_tiers"];
513        let mut tables_found = 0;
514        for table in &required_tables {
515            match client
516                .query(
517                    "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
518                    &[table],
519                )
520                .await
521            {
522                Ok(rows) => {
523                    if rows.is_empty() {
524                        health.issues.push(format!("Table '{}' missing", table));
525                    } else {
526                        tables_found += 1;
527                    }
528                }
529                Err(e) => {
530                    health.issues.push(format!("Failed to check table {}: {}", table, e));
531                }
532            }
533        }
534        health.schema_ready = tables_found == required_tables.len();
535
536        // Get memory count
537        match client.query("SELECT COUNT(*) FROM memories", &[]).await {
538            Ok(rows) => {
539                let count: i64 = rows[0].get(0);
540                health.memory_count = count as usize;
541            }
542            Err(e) => {
543                health.issues.push(format!("Failed to get memory count: {}", e));
544            }
545        }
546
547        Ok(health)
548    }
549}
550
/// Connection details extracted from the database URL by
/// `DatabaseSetup::parse_database_url`.
///
/// `port` is 5432 when the URL omits it. `Debug` is implemented by hand so
/// that the password never ends up in logs or panic messages.
struct DatabaseInfo {
    host: String,
    port: u16,
    username: String,
    password: String,
    database: String,
}

// A derived Debug would print the plaintext password; redact it instead.
impl fmt::Debug for DatabaseInfo {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("DatabaseInfo")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .finish()
    }
}

/// Snapshot of database readiness, as produced by [`DatabaseSetup::health_check`].
#[derive(Debug, Default)]
pub struct DatabaseHealth {
    /// A trivial query succeeded on the configured database.
    pub connectivity: bool,
    /// The `vector` extension is present in `pg_extension`.
    pub pgvector_installed: bool,
    /// Both `memories` and `memory_tiers` exist in the `public` schema.
    pub schema_ready: bool,
    /// Row count of `memories` (left at 0 if it could not be read).
    pub memory_count: usize,
    /// Human-readable descriptions of every check that failed.
    pub issues: Vec<String>,
}

impl DatabaseHealth {
    /// Returns `true` only when every check passed and no issue was recorded.
    pub fn is_healthy(&self) -> bool {
        self.connectivity && self.pgvector_installed && self.schema_ready && self.issues.is_empty()
    }

    /// One-line, human-readable status for display.
    ///
    /// A healthy value reports the memory count; otherwise the recorded issues are listed.
    /// If the flags say unhealthy but no issue was recorded (e.g. a default or
    /// hand-built value), the failing checks are named instead of printing an empty list.
    pub fn status_summary(&self) -> String {
        if self.is_healthy() {
            return format!("✅ Healthy ({} memories)", self.memory_count);
        }

        if !self.issues.is_empty() {
            return format!("❌ Issues: {}", self.issues.join(", "));
        }

        let failed: Vec<&str> = [
            (self.connectivity, "no connectivity"),
            (self.pgvector_installed, "pgvector not installed"),
            (self.schema_ready, "schema not ready"),
        ]
        .iter()
        .filter(|check| !check.0)
        .map(|check| check.1)
        .collect();
        format!("❌ Issues: {}", failed.join(", "))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_database_url() {
        let setup = DatabaseSetup::new(
            "postgresql://user:pass@localhost:5432/testdb".to_string()
        );

        let info = setup.parse_database_url().unwrap();
        assert_eq!(info.host, "localhost");
        assert_eq!(info.port, 5432);
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "pass");
        assert_eq!(info.database, "testdb");
    }

    #[test]
    fn test_parse_database_url_default_port() {
        let setup = DatabaseSetup::new(
            "postgresql://user:pass@localhost/testdb".to_string()
        );

        let info = setup.parse_database_url().unwrap();
        assert_eq!(info.port, 5432); // Should default to 5432
    }

    #[test]
    fn test_parse_database_url_postgres_scheme_without_password() {
        // The short `postgres://` scheme is accepted, and a missing password becomes "".
        let setup = DatabaseSetup::new("postgres://user@localhost/testdb".to_string());

        let info = setup.parse_database_url().unwrap();
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "");
        assert_eq!(info.database, "testdb");
    }

    #[test]
    fn test_parse_database_url_missing_username() {
        let setup = DatabaseSetup::new("postgresql://localhost:5432/testdb".to_string());
        assert!(setup.parse_database_url().is_err());
    }

    #[test]
    fn test_parse_database_url_missing_database() {
        for url in [
            "postgresql://user:pass@localhost:5432",
            "postgresql://user:pass@localhost:5432/",
        ] {
            let setup = DatabaseSetup::new(url.to_string());
            assert!(setup.parse_database_url().is_err(), "expected an error for {}", url);
        }
    }

    #[test]
    fn test_parse_invalid_database_url() {
        let setup = DatabaseSetup::new("invalid-url".to_string());
        assert!(setup.parse_database_url().is_err());

        let setup = DatabaseSetup::new("http://localhost/db".to_string());
        assert!(setup.parse_database_url().is_err());
    }
}
621}