codex_memory/
database_setup.rs

1use anyhow::{Context, Result};
2use tokio_postgres::{Config as PgConfig, NoTls};
3use tracing::{error, info};
4use url::Url;
5
/// Provisions and health-checks the PostgreSQL database (pgvector extension plus the
/// `memories` schema) identified by a single connection URL.
pub struct DatabaseSetup {
    /// Connection URL of the target database, e.g.
    /// `postgresql://user:pass@localhost:5432/testdb`.
    database_url: String,
}
10
11impl DatabaseSetup {
12    pub fn new(database_url: String) -> Self {
13        Self { database_url }
14    }
15
16    /// Complete database setup process
17    pub async fn setup(&self) -> Result<()> {
18        info!("๐Ÿ—„๏ธ  Starting database setup...");
19
20        // 1. Parse and validate the database URL
21        let db_info = self.parse_database_url()?;
22        info!(
23            "Database: {} on {}:{}",
24            db_info.database, db_info.host, db_info.port
25        );
26
27        // 2. Check if PostgreSQL is running
28        self.check_postgresql_running(&db_info).await?;
29
30        // 3. Check if the database exists, create if not
31        self.ensure_database_exists(&db_info).await?;
32
33        // 4. Check for pgvector extension availability
34        self.check_pgvector_availability(&db_info).await?;
35
36        // 5. Install pgvector extension
37        self.install_pgvector_extension().await?;
38
39        // 6. Run migrations
40        self.run_migrations().await?;
41
42        // 7. Verify setup
43        self.verify_setup().await?;
44
45        info!("โœ… Database setup completed successfully!");
46        Ok(())
47    }
48
49    /// Parse database URL and extract connection info
50    fn parse_database_url(&self) -> Result<DatabaseInfo> {
51        let url = Url::parse(&self.database_url).context("Invalid database URL format")?;
52
53        if url.scheme() != "postgresql" && url.scheme() != "postgres" {
54            return Err(anyhow::anyhow!(
55                "Database URL must use postgresql:// or postgres:// scheme"
56            ));
57        }
58
59        let host = url
60            .host_str()
61            .ok_or_else(|| anyhow::anyhow!("Database URL missing host"))?
62            .to_string();
63
64        let port = url.port().unwrap_or(5432);
65
66        let username = url.username();
67        if username.is_empty() {
68            return Err(anyhow::anyhow!("Database URL missing username"));
69        }
70
71        let password = url.password().unwrap_or("");
72
73        let database = url.path().trim_start_matches('/');
74        if database.is_empty() {
75            return Err(anyhow::anyhow!("Database URL missing database name"));
76        }
77
78        Ok(DatabaseInfo {
79            host,
80            port,
81            username: username.to_string(),
82            password: password.to_string(),
83            database: database.to_string(),
84        })
85    }
86
87    /// Check if PostgreSQL is running and accessible
88    async fn check_postgresql_running(&self, db_info: &DatabaseInfo) -> Result<()> {
89        info!("๐Ÿ” Checking PostgreSQL connectivity...");
90
91        // Try to connect to the 'postgres' system database first
92        let system_url = format!(
93            "postgresql://{}:{}@{}:{}/postgres",
94            db_info.username, db_info.password, db_info.host, db_info.port
95        );
96
97        let config: PgConfig = system_url
98            .parse()
99            .context("Failed to parse system database URL")?;
100
101        match config.connect(NoTls).await {
102            Ok((client, connection)) => {
103                tokio::spawn(async move {
104                    if let Err(e) = connection.await {
105                        error!("System database connection error: {}", e);
106                    }
107                });
108
109                // Test basic connectivity
110                client
111                    .query("SELECT version()", &[])
112                    .await
113                    .context("Failed to query PostgreSQL version")?;
114
115                info!("โœ… PostgreSQL is running and accessible");
116                Ok(())
117            }
118            Err(e) => {
119                error!("โŒ Cannot connect to PostgreSQL: {}", e);
120                info!("๐Ÿ’ก Please ensure PostgreSQL is installed and running");
121                info!("๐Ÿ’ก Common solutions:");
122                info!("   - Start PostgreSQL: brew services start postgresql");
123                info!("   - Or: sudo systemctl start postgresql");
124                info!("   - Check connection details in DATABASE_URL");
125                Err(anyhow::anyhow!("PostgreSQL is not accessible: {}", e))
126            }
127        }
128    }
129
130    /// Ensure the target database exists, create if necessary
131    async fn ensure_database_exists(&self, db_info: &DatabaseInfo) -> Result<()> {
132        info!("๐Ÿ” Checking if database '{}' exists...", db_info.database);
133
134        // Connect to system database to check/create target database
135        let system_url = format!(
136            "postgresql://{}:{}@{}:{}/postgres",
137            db_info.username, db_info.password, db_info.host, db_info.port
138        );
139
140        let config: PgConfig = system_url.parse()?;
141        let (client, connection) = config.connect(NoTls).await?;
142
143        tokio::spawn(async move {
144            if let Err(e) = connection.await {
145                error!("System database connection error: {}", e);
146            }
147        });
148
149        // Check if database exists
150        let rows = client
151            .query(
152                "SELECT 1 FROM pg_database WHERE datname = $1",
153                &[&db_info.database],
154            )
155            .await?;
156
157        if rows.is_empty() {
158            info!(
159                "๐Ÿ“‹ Database '{}' does not exist, creating...",
160                db_info.database
161            );
162
163            // Create the database
164            let create_query = format!("CREATE DATABASE \"{}\"", db_info.database);
165            client
166                .execute(&create_query, &[])
167                .await
168                .context("Failed to create database")?;
169
170            info!("โœ… Database '{}' created successfully", db_info.database);
171        } else {
172            info!("โœ… Database '{}' already exists", db_info.database);
173        }
174
175        Ok(())
176    }
177
178    /// Check if pgvector extension is available
179    async fn check_pgvector_availability(&self, _db_info: &DatabaseInfo) -> Result<()> {
180        info!("๐Ÿ” Checking pgvector extension availability...");
181
182        let config: PgConfig = self.database_url.parse()?;
183        let (client, connection) = config.connect(NoTls).await?;
184
185        tokio::spawn(async move {
186            if let Err(e) = connection.await {
187                error!("Database connection error: {}", e);
188            }
189        });
190
191        // Check if pgvector is available in pg_available_extensions
192        let rows = client
193            .query(
194                "SELECT name, default_version FROM pg_available_extensions WHERE name = 'vector'",
195                &[],
196            )
197            .await?;
198
199        if rows.is_empty() {
200            error!("โŒ pgvector extension is not available");
201            info!("๐Ÿ’ก Please install pgvector extension:");
202            info!("   ๐Ÿ“‹ On macOS (Homebrew): brew install pgvector");
203            info!("   ๐Ÿ“‹ On Ubuntu/Debian: apt install postgresql-15-pgvector");
204            info!("   ๐Ÿ“‹ From source: https://github.com/pgvector/pgvector");
205            return Err(anyhow::anyhow!("pgvector extension not available"));
206        } else {
207            let row = &rows[0];
208            let version: String = row.get(1);
209            info!("โœ… pgvector extension available (version: {})", version);
210        }
211
212        Ok(())
213    }
214
215    /// Install pgvector extension in the database
216    async fn install_pgvector_extension(&self) -> Result<()> {
217        info!("๐Ÿ”ง Installing pgvector extension...");
218
219        let config: PgConfig = self.database_url.parse()?;
220        let (client, connection) = config.connect(NoTls).await?;
221
222        tokio::spawn(async move {
223            if let Err(e) = connection.await {
224                error!("Database connection error: {}", e);
225            }
226        });
227
228        // Check if extension is already installed
229        let rows = client
230            .query(
231                "SELECT extname FROM pg_extension WHERE extname = 'vector'",
232                &[],
233            )
234            .await?;
235
236        if !rows.is_empty() {
237            info!("โœ… pgvector extension already installed");
238            return Ok(());
239        }
240
241        // Install the extension
242        client
243            .execute("CREATE EXTENSION vector", &[])
244            .await
245            .context("Failed to install pgvector extension")?;
246
247        info!("โœ… pgvector extension installed successfully");
248
249        // Verify installation by testing basic functionality
250        client
251            .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
252            .await
253            .context("pgvector extension installation verification failed")?;
254
255        info!("โœ… pgvector extension verification passed");
256        Ok(())
257    }
258
259    /// Run database migrations
260    async fn run_migrations(&self) -> Result<()> {
261        info!("๐Ÿ“‹ Running database migrations...");
262
263        let config: PgConfig = self.database_url.parse()?;
264        let (client, connection) = config.connect(NoTls).await?;
265
266        tokio::spawn(async move {
267            if let Err(e) = connection.await {
268                error!("Database connection error: {}", e);
269            }
270        });
271
272        // Check if migration tracking table exists
273        let tracking_exists = client
274            .query(
275                "SELECT 1 FROM information_schema.tables WHERE table_name = 'migration_history'",
276                &[],
277            )
278            .await?;
279
280        if tracking_exists.is_empty() {
281            info!("๐Ÿ“‹ Creating migration tracking table...");
282            client
283                .execute(
284                    r#"
285                    CREATE TABLE migration_history (
286                        id SERIAL PRIMARY KEY,
287                        migration_name VARCHAR(255) NOT NULL UNIQUE,
288                        applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
289                    )
290                    "#,
291                    &[],
292                )
293                .await?;
294        }
295
296        // Check if main tables exist
297        let tables_exist = client
298            .query(
299                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('memories', 'memory_tiers')",
300                &[],
301            )
302            .await?;
303
304        if tables_exist.len() < 2 {
305            info!("๐Ÿ“‹ Creating main schema tables...");
306            self.create_main_schema(&client).await?;
307        } else {
308            info!("โœ… Main schema tables already exist");
309        }
310
311        Ok(())
312    }
313
314    /// Create the main database schema
315    async fn create_main_schema(&self, client: &tokio_postgres::Client) -> Result<()> {
316        info!("๐Ÿ“‹ Creating main database schema...");
317
318        // Create memories table
319        client
320            .execute(
321                r#"
322                CREATE TABLE IF NOT EXISTS memories (
323                    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
324                    content TEXT NOT NULL,
325                    embedding VECTOR(768),
326                    metadata JSONB DEFAULT '{}',
327                    tier VARCHAR(20) NOT NULL DEFAULT 'working',
328                    importance_score FLOAT DEFAULT 0.0,
329                    access_count INTEGER DEFAULT 0,
330                    last_accessed TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
331                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
332                    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
333                )
334                "#,
335                &[],
336            )
337            .await
338            .context("Failed to create memories table")?;
339
340        // Create index on embedding for vector similarity search
341        client
342            .execute(
343                "CREATE INDEX IF NOT EXISTS memories_embedding_idx ON memories USING hnsw (embedding vector_cosine_ops)",
344                &[],
345            )
346            .await
347            .context("Failed to create embedding index")?;
348
349        // Create indexes for common queries
350        client
351            .execute(
352                "CREATE INDEX IF NOT EXISTS memories_tier_idx ON memories (tier)",
353                &[],
354            )
355            .await
356            .context("Failed to create tier index")?;
357
358        client
359            .execute(
360                "CREATE INDEX IF NOT EXISTS memories_last_accessed_idx ON memories (last_accessed DESC)",
361                &[],
362            )
363            .await
364            .context("Failed to create last_accessed index")?;
365
366        client
367            .execute(
368                "CREATE INDEX IF NOT EXISTS memories_importance_idx ON memories (importance_score DESC)",
369                &[],
370            )
371            .await
372            .context("Failed to create importance index")?;
373
374        // Create memory_tiers table for tier management
375        client
376            .execute(
377                r#"
378                CREATE TABLE IF NOT EXISTS memory_tiers (
379                    id SERIAL PRIMARY KEY,
380                    tier_name VARCHAR(20) NOT NULL UNIQUE,
381                    max_capacity INTEGER,
382                    current_count INTEGER DEFAULT 0,
383                    retention_days INTEGER,
384                    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
385                )
386                "#,
387                &[],
388            )
389            .await
390            .context("Failed to create memory_tiers table")?;
391
392        // Insert default tiers
393        client
394            .execute(
395                r#"
396                INSERT INTO memory_tiers (tier_name, max_capacity, retention_days)
397                VALUES 
398                    ('working', 1000, 7),
399                    ('warm', 10000, 30),
400                    ('cold', NULL, NULL)
401                ON CONFLICT (tier_name) DO NOTHING
402                "#,
403                &[],
404            )
405            .await
406            .context("Failed to insert default tiers")?;
407
408        // Record migration
409        client
410            .execute(
411                "INSERT INTO migration_history (migration_name) VALUES ('001_initial_schema') ON CONFLICT (migration_name) DO NOTHING",
412                &[],
413            )
414            .await?;
415
416        info!("โœ… Main database schema created successfully");
417        Ok(())
418    }
419
420    /// Verify the complete database setup
421    async fn verify_setup(&self) -> Result<()> {
422        info!("๐Ÿ” Verifying database setup...");
423
424        let config: PgConfig = self.database_url.parse()?;
425        let (client, connection) = config.connect(NoTls).await?;
426
427        tokio::spawn(async move {
428            if let Err(e) = connection.await {
429                error!("Database connection error: {}", e);
430            }
431        });
432
433        // Check that all required tables exist
434        let required_tables = ["memories", "memory_tiers", "migration_history"];
435        for table in &required_tables {
436            let rows = client
437                .query(
438                    "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
439                    &[table],
440                )
441                .await?;
442
443            if rows.is_empty() {
444                return Err(anyhow::anyhow!("Required table '{}' not found", table));
445            }
446        }
447
448        // Check that pgvector extension is working
449        client
450            .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
451            .await
452            .context("pgvector extension not working")?;
453
454        // Test inserting and querying a sample memory
455        info!("๐Ÿงช Testing vector operations...");
456
457        // Insert a test memory using 768-dimensional vector (matching schema)
458        let test_vector = vec![0.1f32; 768]
459            .iter()
460            .map(|f| f.to_string())
461            .collect::<Vec<_>>()
462            .join(",");
463        client
464            .execute(
465                &format!("INSERT INTO memories (content, embedding) VALUES ($1, '[{test_vector}]'::vector) ON CONFLICT DO NOTHING"),
466                &[&"Setup test memory"],
467            )
468            .await
469            .context("Failed to insert test memory")?;
470
471        // Test vector similarity search using 768-dimensional vector
472        let query_vector = vec![0.1f32; 768]
473            .iter()
474            .map(|f| f.to_string())
475            .collect::<Vec<_>>()
476            .join(",");
477        client
478            .query(
479                &format!("SELECT content FROM memories ORDER BY embedding <-> '[{query_vector}]'::vector LIMIT 1"),
480                &[],
481            )
482            .await
483            .context("Failed to perform vector similarity search")?;
484
485        // Clean up test data
486        client
487            .execute(
488                "DELETE FROM memories WHERE content = 'Setup test memory'",
489                &[],
490            )
491            .await?;
492
493        info!("โœ… Database setup verification passed");
494        Ok(())
495    }
496
497    /// Quick database health check
498    pub async fn health_check(&self) -> Result<DatabaseHealth> {
499        let config: PgConfig = self.database_url.parse()?;
500        let (client, connection) = config.connect(NoTls).await?;
501
502        tokio::spawn(async move {
503            if let Err(e) = connection.await {
504                error!("Database connection error: {}", e);
505            }
506        });
507
508        let mut health = DatabaseHealth::default();
509
510        // Check basic connectivity
511        match client.query("SELECT 1", &[]).await {
512            Ok(_) => health.connectivity = true,
513            Err(e) => {
514                health.connectivity = false;
515                health.issues.push(format!("Connectivity failed: {e}"));
516            }
517        }
518
519        // Check pgvector extension
520        match client
521            .query("SELECT 1 FROM pg_extension WHERE extname = 'vector'", &[])
522            .await
523        {
524            Ok(rows) => {
525                health.pgvector_installed = !rows.is_empty();
526                if !health.pgvector_installed {
527                    health
528                        .issues
529                        .push("pgvector extension not installed".to_string());
530                }
531            }
532            Err(e) => {
533                health.issues.push(format!("Failed to check pgvector: {e}"));
534            }
535        }
536
537        // Check required tables
538        let required_tables = ["memories", "memory_tiers"];
539        let mut tables_found = 0;
540        for table in &required_tables {
541            match client
542                .query(
543                    "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
544                    &[table],
545                )
546                .await
547            {
548                Ok(rows) => {
549                    if rows.is_empty() {
550                        health.issues.push(format!("Table '{table}' missing"));
551                    } else {
552                        tables_found += 1;
553                    }
554                }
555                Err(e) => {
556                    health.issues.push(format!("Failed to check table {table}: {e}"));
557                }
558            }
559        }
560        health.schema_ready = tables_found == required_tables.len();
561
562        // Get memory count
563        match client.query("SELECT COUNT(*) FROM memories", &[]).await {
564            Ok(rows) => {
565                let count: i64 = rows[0].get(0);
566                health.memory_count = count as usize;
567            }
568            Err(e) => {
569                health
570                    .issues
571                    .push(format!("Failed to get memory count: {e}"));
572            }
573        }
574
575        Ok(health)
576    }
577}
578
/// Connection components extracted from a database URL by `parse_database_url`.
///
/// `Debug` is implemented by hand so that formatting this value (e.g. in a log
/// line) never reveals the password.
struct DatabaseInfo {
    /// Host name or address as written in the URL.
    host: String,
    /// TCP port; 5432 when the URL omits it.
    port: u16,
    /// Login role as written in the URL (percent-encoding is not decoded).
    username: String,
    /// Password as written in the URL (percent-encoding is not decoded); empty
    /// when the URL has none.
    password: String,
    /// Target database name: the URL path without its leading `/`.
    database: String,
}

impl std::fmt::Debug for DatabaseInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DatabaseInfo")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("database", &self.database)
            .finish()
    }
}
588
/// Result of `DatabaseSetup::health_check`: one flag per check plus the
/// collected problem descriptions.
#[derive(Default, Debug)]
pub struct DatabaseHealth {
    /// `true` when a trivial `SELECT 1` succeeded.
    pub connectivity: bool,
    /// `true` when the `vector` extension is installed in the database.
    pub pgvector_installed: bool,
    /// `true` when every required table was found in the `public` schema.
    pub schema_ready: bool,
    /// Number of rows in `memories`; stays 0 if counting failed.
    pub memory_count: usize,
    /// Human-readable description of each failed check.
    pub issues: Vec<String>,
}
598
599impl DatabaseHealth {
600    pub fn is_healthy(&self) -> bool {
601        self.connectivity && self.pgvector_installed && self.schema_ready && self.issues.is_empty()
602    }
603
604    pub fn status_summary(&self) -> String {
605        if self.is_healthy() {
606            format!("โœ… Healthy ({} memories)", self.memory_count)
607        } else {
608            format!("โŒ Issues: {}", self.issues.join(", "))
609        }
610    }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_database_url() {
        let setup = DatabaseSetup::new("postgresql://user:pass@localhost:5432/testdb".to_string());

        let info = setup
            .parse_database_url()
            .expect("Failed to parse valid database URL");
        assert_eq!(info.host, "localhost");
        assert_eq!(info.port, 5432);
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "pass");
        assert_eq!(info.database, "testdb");
    }

    #[test]
    fn test_parse_database_url_default_port() {
        let setup = DatabaseSetup::new("postgresql://user:pass@localhost/testdb".to_string());

        let info = setup
            .parse_database_url()
            .expect("Failed to parse database URL with default port");
        assert_eq!(info.port, 5432); // Should default to 5432
    }

    /// The short `postgres://` scheme is accepted, and a missing password becomes "".
    #[test]
    fn test_parse_database_url_postgres_scheme_without_password() {
        let setup =
            DatabaseSetup::new("postgres://user@db.example.com:6543/memories".to_string());

        let info = setup
            .parse_database_url()
            .expect("postgres:// URL without password should parse");
        assert_eq!(info.host, "db.example.com");
        assert_eq!(info.port, 6543);
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "");
        assert_eq!(info.database, "memories");
    }

    #[test]
    fn test_parse_invalid_database_url() {
        let setup = DatabaseSetup::new("invalid-url".to_string());
        assert!(setup.parse_database_url().is_err());

        let setup = DatabaseSetup::new("http://localhost/db".to_string());
        assert!(setup.parse_database_url().is_err());
    }

    /// URLs lacking a username or a database name are rejected.
    #[test]
    fn test_parse_database_url_missing_components() {
        let setup = DatabaseSetup::new("postgresql://localhost/testdb".to_string());
        assert!(setup.parse_database_url().is_err());

        // A bare `/` path leaves no database name.
        let setup = DatabaseSetup::new("postgresql://user:pass@localhost:5432/".to_string());
        assert!(setup.parse_database_url().is_err());
    }
}