1use anyhow::{Context, Result};
2use tokio_postgres::{Config as PgConfig, NoTls};
3use tracing::{error, info};
4use url::Url;
5
/// Prepares a PostgreSQL database (with the pgvector extension) for the memory store.
///
/// Only the connection URL is stored; each operation opens the connections it needs.
pub struct DatabaseSetup {
    /// Connection URL using the `postgres://` or `postgresql://` scheme.
    database_url: String,
}
10
11impl DatabaseSetup {
12 pub fn new(database_url: String) -> Self {
13 Self { database_url }
14 }
15
16 pub async fn setup(&self) -> Result<()> {
18 info!("๐๏ธ Starting database setup...");
19
20 let db_info = self.parse_database_url()?;
22 info!("Database: {} on {}:{}", db_info.database, db_info.host, db_info.port);
23
24 self.check_postgresql_running(&db_info).await?;
26
27 self.ensure_database_exists(&db_info).await?;
29
30 self.check_pgvector_availability(&db_info).await?;
32
33 self.install_pgvector_extension().await?;
35
36 self.run_migrations().await?;
38
39 self.verify_setup().await?;
41
42 info!("โ
Database setup completed successfully!");
43 Ok(())
44 }
45
46 fn parse_database_url(&self) -> Result<DatabaseInfo> {
48 let url = Url::parse(&self.database_url)
49 .context("Invalid database URL format")?;
50
51 if url.scheme() != "postgresql" && url.scheme() != "postgres" {
52 return Err(anyhow::anyhow!("Database URL must use postgresql:// or postgres:// scheme"));
53 }
54
55 let host = url.host_str()
56 .ok_or_else(|| anyhow::anyhow!("Database URL missing host"))?
57 .to_string();
58
59 let port = url.port().unwrap_or(5432);
60
61 let username = url.username();
62 if username.is_empty() {
63 return Err(anyhow::anyhow!("Database URL missing username"));
64 }
65
66 let password = url.password().unwrap_or("");
67
68 let database = url.path().trim_start_matches('/');
69 if database.is_empty() {
70 return Err(anyhow::anyhow!("Database URL missing database name"));
71 }
72
73 Ok(DatabaseInfo {
74 host,
75 port,
76 username: username.to_string(),
77 password: password.to_string(),
78 database: database.to_string(),
79 })
80 }
81
82 async fn check_postgresql_running(&self, db_info: &DatabaseInfo) -> Result<()> {
84 info!("๐ Checking PostgreSQL connectivity...");
85
86 let system_url = format!(
88 "postgresql://{}:{}@{}:{}/postgres",
89 db_info.username, db_info.password, db_info.host, db_info.port
90 );
91
92 let config: PgConfig = system_url.parse()
93 .context("Failed to parse system database URL")?;
94
95 match config.connect(NoTls).await {
96 Ok((client, connection)) => {
97 tokio::spawn(async move {
98 if let Err(e) = connection.await {
99 error!("System database connection error: {}", e);
100 }
101 });
102
103 client.query("SELECT version()", &[]).await
105 .context("Failed to query PostgreSQL version")?;
106
107 info!("โ
PostgreSQL is running and accessible");
108 Ok(())
109 }
110 Err(e) => {
111 error!("โ Cannot connect to PostgreSQL: {}", e);
112 info!("๐ก Please ensure PostgreSQL is installed and running");
113 info!("๐ก Common solutions:");
114 info!(" - Start PostgreSQL: brew services start postgresql");
115 info!(" - Or: sudo systemctl start postgresql");
116 info!(" - Check connection details in DATABASE_URL");
117 Err(anyhow::anyhow!("PostgreSQL is not accessible: {}", e))
118 }
119 }
120 }
121
122 async fn ensure_database_exists(&self, db_info: &DatabaseInfo) -> Result<()> {
124 info!("๐ Checking if database '{}' exists...", db_info.database);
125
126 let system_url = format!(
128 "postgresql://{}:{}@{}:{}/postgres",
129 db_info.username, db_info.password, db_info.host, db_info.port
130 );
131
132 let config: PgConfig = system_url.parse()?;
133 let (client, connection) = config.connect(NoTls).await?;
134
135 tokio::spawn(async move {
136 if let Err(e) = connection.await {
137 error!("System database connection error: {}", e);
138 }
139 });
140
141 let rows = client
143 .query(
144 "SELECT 1 FROM pg_database WHERE datname = $1",
145 &[&db_info.database],
146 )
147 .await?;
148
149 if rows.is_empty() {
150 info!("๐ Database '{}' does not exist, creating...", db_info.database);
151
152 let create_query = format!("CREATE DATABASE \"{}\"", db_info.database);
154 client.execute(&create_query, &[]).await
155 .context("Failed to create database")?;
156
157 info!("โ
Database '{}' created successfully", db_info.database);
158 } else {
159 info!("โ
Database '{}' already exists", db_info.database);
160 }
161
162 Ok(())
163 }
164
165 async fn check_pgvector_availability(&self, _db_info: &DatabaseInfo) -> Result<()> {
167 info!("๐ Checking pgvector extension availability...");
168
169 let config: PgConfig = self.database_url.parse()?;
170 let (client, connection) = config.connect(NoTls).await?;
171
172 tokio::spawn(async move {
173 if let Err(e) = connection.await {
174 error!("Database connection error: {}", e);
175 }
176 });
177
178 let rows = client
180 .query(
181 "SELECT name, default_version FROM pg_available_extensions WHERE name = 'vector'",
182 &[],
183 )
184 .await?;
185
186 if rows.is_empty() {
187 error!("โ pgvector extension is not available");
188 info!("๐ก Please install pgvector extension:");
189 info!(" ๐ On macOS (Homebrew): brew install pgvector");
190 info!(" ๐ On Ubuntu/Debian: apt install postgresql-15-pgvector");
191 info!(" ๐ From source: https://github.com/pgvector/pgvector");
192 return Err(anyhow::anyhow!("pgvector extension not available"));
193 } else {
194 let row = &rows[0];
195 let version: String = row.get(1);
196 info!("โ
pgvector extension available (version: {})", version);
197 }
198
199 Ok(())
200 }
201
202 async fn install_pgvector_extension(&self) -> Result<()> {
204 info!("๐ง Installing pgvector extension...");
205
206 let config: PgConfig = self.database_url.parse()?;
207 let (client, connection) = config.connect(NoTls).await?;
208
209 tokio::spawn(async move {
210 if let Err(e) = connection.await {
211 error!("Database connection error: {}", e);
212 }
213 });
214
215 let rows = client
217 .query(
218 "SELECT extname FROM pg_extension WHERE extname = 'vector'",
219 &[],
220 )
221 .await?;
222
223 if !rows.is_empty() {
224 info!("โ
pgvector extension already installed");
225 return Ok(());
226 }
227
228 client
230 .execute("CREATE EXTENSION vector", &[])
231 .await
232 .context("Failed to install pgvector extension")?;
233
234 info!("โ
pgvector extension installed successfully");
235
236 client
238 .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
239 .await
240 .context("pgvector extension installation verification failed")?;
241
242 info!("โ
pgvector extension verification passed");
243 Ok(())
244 }
245
246 async fn run_migrations(&self) -> Result<()> {
248 info!("๐ Running database migrations...");
249
250 let config: PgConfig = self.database_url.parse()?;
251 let (client, connection) = config.connect(NoTls).await?;
252
253 tokio::spawn(async move {
254 if let Err(e) = connection.await {
255 error!("Database connection error: {}", e);
256 }
257 });
258
259 let tracking_exists = client
261 .query(
262 "SELECT 1 FROM information_schema.tables WHERE table_name = 'migration_history'",
263 &[],
264 )
265 .await?;
266
267 if tracking_exists.is_empty() {
268 info!("๐ Creating migration tracking table...");
269 client
270 .execute(
271 r#"
272 CREATE TABLE migration_history (
273 id SERIAL PRIMARY KEY,
274 migration_name VARCHAR(255) NOT NULL UNIQUE,
275 applied_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
276 )
277 "#,
278 &[],
279 )
280 .await?;
281 }
282
283 let tables_exist = client
285 .query(
286 "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name IN ('memories', 'memory_tiers')",
287 &[],
288 )
289 .await?;
290
291 if tables_exist.len() < 2 {
292 info!("๐ Creating main schema tables...");
293 self.create_main_schema(&client).await?;
294 } else {
295 info!("โ
Main schema tables already exist");
296 }
297
298 Ok(())
299 }
300
301 async fn create_main_schema(&self, client: &tokio_postgres::Client) -> Result<()> {
303 info!("๐ Creating main database schema...");
304
305 client
307 .execute(
308 r#"
309 CREATE TABLE IF NOT EXISTS memories (
310 id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
311 content TEXT NOT NULL,
312 embedding VECTOR(768),
313 metadata JSONB DEFAULT '{}',
314 tier VARCHAR(20) NOT NULL DEFAULT 'working',
315 importance_score FLOAT DEFAULT 0.0,
316 access_count INTEGER DEFAULT 0,
317 last_accessed TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
318 created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
319 updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
320 )
321 "#,
322 &[],
323 )
324 .await
325 .context("Failed to create memories table")?;
326
327 client
329 .execute(
330 "CREATE INDEX IF NOT EXISTS memories_embedding_idx ON memories USING hnsw (embedding vector_cosine_ops)",
331 &[],
332 )
333 .await
334 .context("Failed to create embedding index")?;
335
336 client
338 .execute(
339 "CREATE INDEX IF NOT EXISTS memories_tier_idx ON memories (tier)",
340 &[],
341 )
342 .await
343 .context("Failed to create tier index")?;
344
345 client
346 .execute(
347 "CREATE INDEX IF NOT EXISTS memories_last_accessed_idx ON memories (last_accessed DESC)",
348 &[],
349 )
350 .await
351 .context("Failed to create last_accessed index")?;
352
353 client
354 .execute(
355 "CREATE INDEX IF NOT EXISTS memories_importance_idx ON memories (importance_score DESC)",
356 &[],
357 )
358 .await
359 .context("Failed to create importance index")?;
360
361 client
363 .execute(
364 r#"
365 CREATE TABLE IF NOT EXISTS memory_tiers (
366 id SERIAL PRIMARY KEY,
367 tier_name VARCHAR(20) NOT NULL UNIQUE,
368 max_capacity INTEGER,
369 current_count INTEGER DEFAULT 0,
370 retention_days INTEGER,
371 created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
372 )
373 "#,
374 &[],
375 )
376 .await
377 .context("Failed to create memory_tiers table")?;
378
379 client
381 .execute(
382 r#"
383 INSERT INTO memory_tiers (tier_name, max_capacity, retention_days)
384 VALUES
385 ('working', 1000, 7),
386 ('warm', 10000, 30),
387 ('cold', NULL, NULL)
388 ON CONFLICT (tier_name) DO NOTHING
389 "#,
390 &[],
391 )
392 .await
393 .context("Failed to insert default tiers")?;
394
395 client
397 .execute(
398 "INSERT INTO migration_history (migration_name) VALUES ('001_initial_schema') ON CONFLICT (migration_name) DO NOTHING",
399 &[],
400 )
401 .await?;
402
403 info!("โ
Main database schema created successfully");
404 Ok(())
405 }
406
407 async fn verify_setup(&self) -> Result<()> {
409 info!("๐ Verifying database setup...");
410
411 let config: PgConfig = self.database_url.parse()?;
412 let (client, connection) = config.connect(NoTls).await?;
413
414 tokio::spawn(async move {
415 if let Err(e) = connection.await {
416 error!("Database connection error: {}", e);
417 }
418 });
419
420 let required_tables = ["memories", "memory_tiers", "migration_history"];
422 for table in &required_tables {
423 let rows = client
424 .query(
425 "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
426 &[table],
427 )
428 .await?;
429
430 if rows.is_empty() {
431 return Err(anyhow::anyhow!("Required table '{}' not found", table));
432 }
433 }
434
435 client
437 .query("SELECT vector_dims('[1,2,3]'::vector)", &[])
438 .await
439 .context("pgvector extension not working")?;
440
441 info!("๐งช Testing vector operations...");
443
444 let test_vector = vec![0.1f32; 768].iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",");
446 client
447 .execute(
448 &format!("INSERT INTO memories (content, embedding) VALUES ($1, '[{}]'::vector) ON CONFLICT DO NOTHING", test_vector),
449 &[&"Setup test memory"],
450 )
451 .await
452 .context("Failed to insert test memory")?;
453
454 let query_vector = vec![0.1f32; 768].iter().map(|f| f.to_string()).collect::<Vec<_>>().join(",");
456 client
457 .query(
458 &format!("SELECT content FROM memories ORDER BY embedding <-> '[{}]'::vector LIMIT 1", query_vector),
459 &[],
460 )
461 .await
462 .context("Failed to perform vector similarity search")?;
463
464 client
466 .execute(
467 "DELETE FROM memories WHERE content = 'Setup test memory'",
468 &[],
469 )
470 .await?;
471
472 info!("โ
Database setup verification passed");
473 Ok(())
474 }
475
476 pub async fn health_check(&self) -> Result<DatabaseHealth> {
478 let config: PgConfig = self.database_url.parse()?;
479 let (client, connection) = config.connect(NoTls).await?;
480
481 tokio::spawn(async move {
482 if let Err(e) = connection.await {
483 error!("Database connection error: {}", e);
484 }
485 });
486
487 let mut health = DatabaseHealth::default();
488
489 match client.query("SELECT 1", &[]).await {
491 Ok(_) => health.connectivity = true,
492 Err(e) => {
493 health.connectivity = false;
494 health.issues.push(format!("Connectivity failed: {}", e));
495 }
496 }
497
498 match client.query("SELECT 1 FROM pg_extension WHERE extname = 'vector'", &[]).await {
500 Ok(rows) => {
501 health.pgvector_installed = !rows.is_empty();
502 if !health.pgvector_installed {
503 health.issues.push("pgvector extension not installed".to_string());
504 }
505 }
506 Err(e) => {
507 health.issues.push(format!("Failed to check pgvector: {}", e));
508 }
509 }
510
511 let required_tables = ["memories", "memory_tiers"];
513 let mut tables_found = 0;
514 for table in &required_tables {
515 match client
516 .query(
517 "SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1",
518 &[table],
519 )
520 .await
521 {
522 Ok(rows) => {
523 if rows.is_empty() {
524 health.issues.push(format!("Table '{}' missing", table));
525 } else {
526 tables_found += 1;
527 }
528 }
529 Err(e) => {
530 health.issues.push(format!("Failed to check table {}: {}", table, e));
531 }
532 }
533 }
534 health.schema_ready = tables_found == required_tables.len();
535
536 match client.query("SELECT COUNT(*) FROM memories", &[]).await {
538 Ok(rows) => {
539 let count: i64 = rows[0].get(0);
540 health.memory_count = count as usize;
541 }
542 Err(e) => {
543 health.issues.push(format!("Failed to get memory count: {}", e));
544 }
545 }
546
547 Ok(health)
548 }
549}
550
/// Components of the database URL, as produced by `DatabaseSetup::parse_database_url`.
///
/// `Debug` is implemented by hand rather than derived, so that formatting this
/// value (for example in a log line) never prints the password.
struct DatabaseInfo {
    host: String,
    port: u16,
    username: String,
    password: String,
    database: String,
}

impl std::fmt::Debug for DatabaseInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only reveal whether a password is present, never its value.
        let password = if self.password.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("DatabaseInfo")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("username", &self.username)
            .field("password", &password)
            .field("database", &self.database)
            .finish()
    }
}
560
/// Outcome of `DatabaseSetup::health_check`: one flag per check, the memory
/// count, and the problems recorded while running the checks.
#[derive(Debug, Default)]
pub struct DatabaseHealth {
    /// A trivial query succeeded on the target database.
    pub connectivity: bool,
    /// The `vector` extension appears in `pg_extension`.
    pub pgvector_installed: bool,
    /// Both `memories` and `memory_tiers` exist in the `public` schema.
    pub schema_ready: bool,
    /// Row count of `memories` (left at 0 if it could not be read).
    pub memory_count: usize,
    /// One message per failed or inconclusive check.
    pub issues: Vec<String>,
}

impl DatabaseHealth {
    /// True only when every check passed and no issue was recorded.
    pub fn is_healthy(&self) -> bool {
        self.connectivity && self.pgvector_installed && self.schema_ready && self.issues.is_empty()
    }

    /// One-line status summary.
    ///
    /// When unhealthy, the summary lists the recorded issues. If a check failed
    /// without recording an issue (e.g. a `Default` value), the failed checks
    /// are named instead, so the summary is never a bare "Issues: ".
    pub fn status_summary(&self) -> String {
        if self.is_healthy() {
            return format!("✅ Healthy ({} memories)", self.memory_count);
        }
        if !self.issues.is_empty() {
            return format!("❌ Issues: {}", self.issues.join(", "));
        }

        // Not healthy and no issues recorded, so at least one flag is false.
        let failed: Vec<&str> = [
            (self.connectivity, "no connectivity"),
            (self.pgvector_installed, "pgvector not installed"),
            (self.schema_ready, "schema not ready"),
        ]
        .iter()
        .filter(|(passed, _)| !*passed)
        .map(|&(_, label)| label)
        .collect();
        format!("❌ Issues: {}", failed.join(", "))
    }
}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `url` with a fresh `DatabaseSetup`.
    fn parse(url: &str) -> anyhow::Result<DatabaseInfo> {
        DatabaseSetup::new(url.to_string()).parse_database_url()
    }

    #[test]
    fn test_parse_database_url() {
        let info = parse("postgresql://user:pass@localhost:5432/testdb").unwrap();
        assert_eq!(info.host, "localhost");
        assert_eq!(info.port, 5432);
        assert_eq!(info.username, "user");
        assert_eq!(info.password, "pass");
        assert_eq!(info.database, "testdb");
    }

    #[test]
    fn test_parse_database_url_default_port() {
        let info = parse("postgresql://user:pass@localhost/testdb").unwrap();
        assert_eq!(info.port, 5432);
    }

    #[test]
    fn test_parse_postgres_scheme_and_custom_port() {
        let info = parse("postgres://admin@db.example.com:6543/app").unwrap();
        assert_eq!(info.host, "db.example.com");
        assert_eq!(info.port, 6543);
        assert_eq!(info.username, "admin");
        assert_eq!(info.database, "app");
    }

    #[test]
    fn test_parse_database_url_without_password() {
        let info = parse("postgresql://user@localhost/testdb").unwrap();
        assert_eq!(info.password, "");
    }

    #[test]
    fn test_parse_invalid_database_url() {
        assert!(parse("invalid-url").is_err());
        assert!(parse("http://localhost/db").is_err());
    }

    #[test]
    fn test_parse_database_url_missing_username() {
        assert!(parse("postgresql://localhost/testdb").is_err());
    }

    #[test]
    fn test_parse_database_url_missing_database() {
        assert!(parse("postgresql://user:pass@localhost").is_err());
        assert!(parse("postgresql://user:pass@localhost/").is_err());
    }

    #[test]
    fn test_default_health_is_not_healthy() {
        assert!(!DatabaseHealth::default().is_healthy());
    }

    #[test]
    fn test_all_checks_passing_is_healthy() {
        let health = DatabaseHealth {
            connectivity: true,
            pgvector_installed: true,
            schema_ready: true,
            ..Default::default()
        };
        assert!(health.is_healthy());
    }
}