1use anyhow::{Context, Result};
2use tokio_postgres::{Config as PgConfig, NoTls};
3use tracing::{error, info};
4use url::Url;
5
/// Bootstraps and health-checks the PostgreSQL + pgvector database named by `database_url`.
pub struct DatabaseSetup {
    /// Connection URL of the form `postgresql://user[:password]@host[:port]/database`
    /// (`postgres://` is also accepted). It may carry the password, so avoid logging it.
    database_url: String,
}
10
11impl DatabaseSetup {
12 pub fn new(database_url: String) -> Self {
13 Self { database_url }
14 }
15
16 pub async fn setup(&self) -> Result<()> {
18 info!("๐๏ธ Starting database setup...");
19
20 let db_info = self.parse_database_url()?;
22 info!(
23 "Database: {} on {}:{}",
24 db_info.database, db_info.host, db_info.port
25 );
26
27 self.check_postgresql_running(&db_info).await?;
29
30 self.ensure_database_exists(&db_info).await?;
32
33 self.check_pgvector_availability(&db_info).await?;
35
36 self.install_pgvector_extension().await?;
38
39 self.run_migrations().await?;
41
42 self.verify_setup().await?;
44
45 info!("โ
Database setup completed successfully!");
46 Ok(())
47 }
48
49 fn parse_database_url(&self) -> Result<DatabaseInfo> {
51 let url = Url::parse(&self.database_url).context("Invalid database URL format")?;
52
53 if url.scheme() != "postgresql" && url.scheme() != "postgres" {
54 return Err(anyhow::anyhow!(
55 "Database URL must use postgresql:// or postgres:// scheme"
56 ));
57 }
58
59 let host = url
60 .host_str()
61 .ok_or_else(|| anyhow::anyhow!("Database URL missing host"))?
62 .to_string();
63
64 let port = url.port().unwrap_or(5432);
65
66 let username = url.username();
67 if username.is_empty() {
68 return Err(anyhow::anyhow!("Database URL missing username"));
69 }
70
71 let password = url.password().unwrap_or("");
72
73 let database = url.path().trim_start_matches('/');
74 if database.is_empty() {
75 return Err(anyhow::anyhow!("Database URL missing database name"));
76 }
77
78 Ok(DatabaseInfo {
79 host,
80 port,
81 username: username.to_string(),
82 password: password.to_string(),
83 database: database.to_string(),
84 })
85 }
86
87 async fn check_postgresql_running(&self, db_info: &DatabaseInfo) -> Result<()> {
89 info!("๐ Checking PostgreSQL connectivity...");
90
91 let system_url = format!(
93 "postgresql://{}:{}@{}:{}/postgres",
94 db_info.username, db_info.password, db_info.host, db_info.port
95 );
96
97 let config: PgConfig = system_url
98 .parse()
99 .context("Failed to parse system database URL")?;
100
101 match config.connect(NoTls).await {
102 Ok((client, connection)) => {
103 tokio::spawn(async move {
104 if let Err(e) = connection.await {
105 error!("System database connection error: {}", e);
106 }
107 });
108
109 client
111 .query("SELECT version()", &[])
112 .await
113 .context("Failed to query PostgreSQL version")?;
114
115 info!("โ
PostgreSQL is running and accessible");
116 Ok(())
117 }
118 Err(e) => {
119 error!("โ Cannot connect to PostgreSQL: {}", e);
120 info!("๐ก Please ensure PostgreSQL is installed and running");
121 info!("๐ก Common solutions:");
122 info!(" - Start PostgreSQL: brew services start postgresql");
123 info!(" - Or: sudo systemctl start postgresql");
124 info!(" - Check connection details in DATABASE_URL");
125 Err(anyhow::anyhow!("PostgreSQL is not accessible: {}", e))
126 }
127 }
128 }
129
130 async fn ensure_database_exists(&self, db_info: &DatabaseInfo) -> Result<()> {
132 info!("๐ Checking if database '{}' exists...", db_info.database);
133
134 let system_url = format!(
136 "postgresql://{}:{}@{}:{}/postgres",
137 db_info.username, db_info.password, db_info.host, db_info.port
138 );
139
140 let config: PgConfig = system_url.parse()?;
141 let (client, connection) = config.connect(NoTls).await?;
142
143 tokio::spawn(async move {
144 if let Err(e) = connection.await {
145 error!("System database connection error: {}", e);
146 }
147 });
148
149 let rows = client
151 .query(
152 "SELECT 1 FROM pg_database WHERE datname = $1",
153 &[&db_info.database],
154 )
155 .await?;
156
157 if rows.is_empty() {
158 info!(
159 "๐ Database '{}' does not exist, creating...",
160 db_info.database
161 );
162
163 let create_query = format!("CREATE DATABASE \"{}\"", db_info.database);
165 client
166 .execute(&create_query, &[])
167 .await
168 .context("Failed to create database")?;
169
170 info!("โ
Database '{}' created successfully", db_info.database);
171 } else {
172 info!("โ
Database '{}' already exists", db_info.database);
173 }
174
175 Ok(())
176 }
177
178 async fn check_pgvector_availability(&self, _db_info: &DatabaseInfo) -> Result<()> {
180 info!("๐ Checking pgvector extension availability...");
181
182 let config: PgConfig = self.database_url.parse()?;
183 let (client, connection) = config.connect(NoTls).await?;
184
185 tokio::spawn(async move {
186 if let Err(e) = connection.await {
187 error!("Database connection error: {}", e);
188 }
189 });
190
191 let rows = client
193 .query(
194 "SELECT name, default_version FROM pg_available_extensions WHERE name = 'vector'",
195 &[],
196 )
197 .await?;
198
199 if rows.is_empty() {
200 error!("โ pgvector extension is not available");
201 info!("๐ก Please install pgvector extension:");
202 info!(" ๐ On macOS (Homebrew): brew install pgvector");
203 info!(" ๐ On Ubuntu/Debian: apt install postgresql-15-pgvector");
204 info!(" ๐ From source: https://github.com/pgvector/pgvector");
205 return Err(anyhow::anyhow!("pgvector extension not available"));
206 } else {
207 let row = &rows[0];
208 let version: String = row.get(1);
209 info!("โ
pgvector extension available (version: {})", version);
210 }
211
212 Ok(())
213 }
214
215 async fn install_pgvector_extension(&self) -> Result<()> {
217 info!("๐ง Installing pgvector extension...");
218
219 let config: PgConfig = self.database_url.parse()?;
220 let (client, connection) = config.connect(NoTls).await?;
221
222 tokio::spawn(async move {
223 if let Err(e) = connection.await {
224 error!("Database connection error: {}", e);
225 }
226 });
227
228 let rows = client
230 .query(
231 "SELECT extname FROM pg_extension WHERE extname = 'vector'",
232 &[],
233 )
234 .await?;
235
236 if !rows.is_empty() {
237 info!("โ
pgvector extension already installed");
238 return Ok(());
239 }
240
241 client
243 .execute("CREATE EXTENSION vector", &[])
244 .await
245 .context("Failed to install pgvector extension")?;
246
247 info!("โ
pgvector extension installed successfully");
248
249 client
251 .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
252 .await
253 .context("pgvector extension installation verification failed")?;
254
255 info!("โ
pgvector extension verification passed");
256 Ok(())
257 }
258
259 async fn run_migrations(&self) -> Result<()> {
261 info!("๐ Running database migrations...");
262
263 let config: PgConfig = self.database_url.parse()?;
264 let (client, connection) = config.connect(NoTls).await?;
265
266 tokio::spawn(async move {
267 if let Err(e) = connection.await {
268 error!("Database connection error: {}", e);
269 }
270 });
271
272 let tracking_exists = client
274 .query(
275 "SELECT 1 FROM information_schema.tables WHERE table_name = 'migration_history'",
276 &[],
277 )
278 .await?;
279
280 if tracking_exists.is_empty() {
281 info!("๐ Creating migration tracking table...");
282 client
283 .execute(
284 r#"
285 CREATE TABLE migration_history (
286 id SERIAL PRIMARY KEY,
287 migration_name VARCHAR(255) NOT NULL UNIQUE,
288 applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
289 )
290 "#,
291 &[],
292 )
293 .await?;
294 }
295
296 let tables_exist = client
298 .query(
299 "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('memories', 'memory_tiers')",
300 &[],
301 )
302 .await?;
303
304 if tables_exist.len() < 2 {
305 info!("๐ Creating main schema tables...");
306 self.create_main_schema(&client).await?;
307 } else {
308 info!("โ
Main schema tables already exist");
309 }
310
311 Ok(())
312 }
313
314 async fn create_main_schema(&self, client: &tokio_postgres::Client) -> Result<()> {
316 info!("๐ Creating main database schema...");
317
318 client
320 .execute(
321 r#"
322 CREATE TABLE IF NOT EXISTS memories (
323 id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
324 content TEXT NOT NULL,
325 embedding VECTOR(768),
326 metadata JSONB DEFAULT '{}',
327 tier VARCHAR(20) NOT NULL DEFAULT 'working',
328 importance_score FLOAT DEFAULT 0.0,
329 access_count INTEGER DEFAULT 0,
330 last_accessed TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
331 created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
332 updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
333 )
334 "#,
335 &[],
336 )
337 .await
338 .context("Failed to create memories table")?;
339
340 client
342 .execute(
343 "CREATE INDEX IF NOT EXISTS memories_embedding_idx ON memories USING hnsw (embedding vector_cosine_ops)",
344 &[],
345 )
346 .await
347 .context("Failed to create embedding index")?;
348
349 client
351 .execute(
352 "CREATE INDEX IF NOT EXISTS memories_tier_idx ON memories (tier)",
353 &[],
354 )
355 .await
356 .context("Failed to create tier index")?;
357
358 client
359 .execute(
360 "CREATE INDEX IF NOT EXISTS memories_last_accessed_idx ON memories (last_accessed DESC)",
361 &[],
362 )
363 .await
364 .context("Failed to create last_accessed index")?;
365
366 client
367 .execute(
368 "CREATE INDEX IF NOT EXISTS memories_importance_idx ON memories (importance_score DESC)",
369 &[],
370 )
371 .await
372 .context("Failed to create importance index")?;
373
374 client
376 .execute(
377 r#"
378 CREATE TABLE IF NOT EXISTS memory_tiers (
379 id SERIAL PRIMARY KEY,
380 tier_name VARCHAR(20) NOT NULL UNIQUE,
381 max_capacity INTEGER,
382 current_count INTEGER DEFAULT 0,
383 retention_days INTEGER,
384 created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
385 )
386 "#,
387 &[],
388 )
389 .await
390 .context("Failed to create memory_tiers table")?;
391
392 client
394 .execute(
395 r#"
396 INSERT INTO memory_tiers (tier_name, max_capacity, retention_days)
397 VALUES
398 ('working', 1000, 7),
399 ('warm', 10000, 30),
400 ('cold', NULL, NULL)
401 ON CONFLICT (tier_name) DO NOTHING
402 "#,
403 &[],
404 )
405 .await
406 .context("Failed to insert default tiers")?;
407
408 client
410 .execute(
411 "INSERT INTO migration_history (migration_name) VALUES ('001_initial_schema') ON CONFLICT (migration_name) DO NOTHING",
412 &[],
413 )
414 .await?;
415
416 info!("โ
Main database schema created successfully");
417 Ok(())
418 }
419
420 async fn verify_setup(&self) -> Result<()> {
422 info!("๐ Verifying database setup...");
423
424 let config: PgConfig = self.database_url.parse()?;
425 let (client, connection) = config.connect(NoTls).await?;
426
427 tokio::spawn(async move {
428 if let Err(e) = connection.await {
429 error!("Database connection error: {}", e);
430 }
431 });
432
433 let required_tables = ["memories", "memory_tiers", "migration_history"];
435 for table in &required_tables {
436 let rows = client
437 .query(
438 "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
439 &[table],
440 )
441 .await?;
442
443 if rows.is_empty() {
444 return Err(anyhow::anyhow!("Required table '{}' not found", table));
445 }
446 }
447
448 client
450 .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
451 .await
452 .context("pgvector extension not working")?;
453
454 info!("๐งช Testing vector operations...");
456
457 let test_vector = vec![0.1f32; 768]
459 .iter()
460 .map(|f| f.to_string())
461 .collect::<Vec<_>>()
462 .join(",");
463 client
464 .execute(
465 &format!("INSERT INTO memories (content, embedding) VALUES ($1, '[{test_vector}]'::vector) ON CONFLICT DO NOTHING"),
466 &[&"Setup test memory"],
467 )
468 .await
469 .context("Failed to insert test memory")?;
470
471 let query_vector = vec![0.1f32; 768]
473 .iter()
474 .map(|f| f.to_string())
475 .collect::<Vec<_>>()
476 .join(",");
477 client
478 .query(
479 &format!("SELECT content FROM memories ORDER BY embedding <-> '[{query_vector}]'::vector LIMIT 1"),
480 &[],
481 )
482 .await
483 .context("Failed to perform vector similarity search")?;
484
485 client
487 .execute(
488 "DELETE FROM memories WHERE content = 'Setup test memory'",
489 &[],
490 )
491 .await?;
492
493 info!("โ
Database setup verification passed");
494 Ok(())
495 }
496
497 pub async fn health_check(&self) -> Result<DatabaseHealth> {
499 let config: PgConfig = self.database_url.parse()?;
500 let (client, connection) = config.connect(NoTls).await?;
501
502 tokio::spawn(async move {
503 if let Err(e) = connection.await {
504 error!("Database connection error: {}", e);
505 }
506 });
507
508 let mut health = DatabaseHealth::default();
509
510 match client.query("SELECT 1", &[]).await {
512 Ok(_) => health.connectivity = true,
513 Err(e) => {
514 health.connectivity = false;
515 health.issues.push(format!("Connectivity failed: {e}"));
516 }
517 }
518
519 match client
521 .query("SELECT 1 FROM pg_extension WHERE extname = 'vector'", &[])
522 .await
523 {
524 Ok(rows) => {
525 health.pgvector_installed = !rows.is_empty();
526 if !health.pgvector_installed {
527 health
528 .issues
529 .push("pgvector extension not installed".to_string());
530 }
531 }
532 Err(e) => {
533 health.issues.push(format!("Failed to check pgvector: {e}"));
534 }
535 }
536
537 let required_tables = ["memories", "memory_tiers"];
539 let mut tables_found = 0;
540 for table in &required_tables {
541 match client
542 .query(
543 "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
544 &[table],
545 )
546 .await
547 {
548 Ok(rows) => {
549 if rows.is_empty() {
550 health.issues.push(format!("Table '{table}' missing"));
551 } else {
552 tables_found += 1;
553 }
554 }
555 Err(e) => {
556 health.issues.push(format!("Failed to check table {table}: {e}"));
557 }
558 }
559 }
560 health.schema_ready = tables_found == required_tables.len();
561
562 match client.query("SELECT COUNT(*) FROM memories", &[]).await {
564 Ok(rows) => {
565 let count: i64 = rows[0].get(0);
566 health.memory_count = count as usize;
567 }
568 Err(e) => {
569 health
570 .issues
571 .push(format!("Failed to get memory count: {e}"));
572 }
573 }
574
575 Ok(health)
576 }
577}
578
/// Connection components extracted from the database URL by `parse_database_url`.
///
/// NOTE(review): the derived `Debug` prints `password` in clear text — avoid `{:?}` in logs.
#[derive(Debug)]
struct DatabaseInfo {
    /// Host taken from the URL authority; parsing fails if it is absent.
    host: String,
    /// Port from the URL, or 5432 when the URL does not specify one.
    port: u16,
    /// User name from the URL; parsing rejects URLs without one.
    username: String,
    /// Password from the URL, or an empty string when the URL has none.
    password: String,
    /// URL path with the leading `/` removed; parsing rejects an empty name.
    database: String,
}
588
/// Outcome of `DatabaseSetup::health_check`: one flag per check plus collected issue messages.
///
/// `Default` yields every check failed, zero memories and no issues.
#[derive(Debug, Default)]
pub struct DatabaseHealth {
    /// `SELECT 1` succeeded on the application database.
    pub connectivity: bool,
    /// The `vector` extension appears in `pg_extension`.
    pub pgvector_installed: bool,
    /// Both `memories` and `memory_tiers` exist in the `public` schema.
    pub schema_ready: bool,
    /// Row count of `memories`; stays 0 when the count query fails.
    pub memory_count: usize,
    /// Human-readable description of each failed check.
    pub issues: Vec<String>,
}
598
599impl DatabaseHealth {
600 pub fn is_healthy(&self) -> bool {
601 self.connectivity && self.pgvector_installed && self.schema_ready && self.issues.is_empty()
602 }
603
604 pub fn status_summary(&self) -> String {
605 if self.is_healthy() {
606 format!("โ
Healthy ({} memories)", self.memory_count)
607 } else {
608 format!("โ Issues: {}", self.issues.join(", "))
609 }
610 }
611}
612
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_database_url() {
        let setup = DatabaseSetup::new("postgresql://user:pass@localhost:5432/testdb".to_string());

        let info = setup
            .parse_database_url()
            .expect("Failed to parse valid database URL");
        assert_eq!(info.host, "localhost");
        assert_eq!(info.port, 5432);
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "pass");
        assert_eq!(info.database, "testdb");
    }

    #[test]
    fn test_parse_database_url_default_port() {
        let setup = DatabaseSetup::new("postgresql://user:pass@localhost/testdb".to_string());

        let info = setup
            .parse_database_url()
            .expect("Failed to parse database URL with default port");
        assert_eq!(info.port, 5432);
    }

    #[test]
    fn test_parse_database_url_accepts_postgres_scheme_and_custom_port() {
        let setup = DatabaseSetup::new(
            "postgres://app:secret@db.example.com:6543/memories".to_string(),
        );

        let info = setup
            .parse_database_url()
            .expect("postgres:// scheme should be accepted");
        assert_eq!(info.host, "db.example.com");
        assert_eq!(info.port, 6543);
        assert_eq!(info.username, "app");
        assert_eq!(info.password, "secret");
        assert_eq!(info.database, "memories");
    }

    #[test]
    fn test_parse_database_url_without_password() {
        let setup = DatabaseSetup::new("postgresql://user@localhost/testdb".to_string());

        let info = setup
            .parse_database_url()
            .expect("A URL without a password should parse");
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "");
    }

    #[test]
    fn test_parse_invalid_database_url() {
        let setup = DatabaseSetup::new("invalid-url".to_string());
        assert!(setup.parse_database_url().is_err());

        let setup = DatabaseSetup::new("http://localhost/db".to_string());
        assert!(setup.parse_database_url().is_err());
    }

    #[test]
    fn test_parse_database_url_missing_username() {
        let setup = DatabaseSetup::new("postgresql://localhost/testdb".to_string());

        let err = setup
            .parse_database_url()
            .expect_err("A URL without a user name must be rejected");
        assert!(err.to_string().contains("username"));
    }

    #[test]
    fn test_parse_database_url_missing_database() {
        for url in [
            "postgresql://user:pass@localhost/",
            "postgresql://user:pass@localhost",
        ] {
            let setup = DatabaseSetup::new(url.to_string());
            let err = setup
                .parse_database_url()
                .expect_err("A URL without a database name must be rejected");
            assert!(err.to_string().contains("database name"), "url: {url}");
        }
    }
}