codex_memory/
setup.rs

use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio_postgres::{Config as PgConfig, NoTls};
use tracing::{error, info, warn};

use crate::config::Config;
use crate::embedding::SimpleEmbedder;

/// Available embedding models with their configurations
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

/// Setup manager for the Agentic Memory System
pub struct SetupManager {
    client: Client,
    config: Config,
}

/// Ollama API response structures
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Debug, Serialize)]
struct OllamaPullRequest {
    name: String,
}

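// A single progress line from Ollama's streaming `/api/pull` response looks
// roughly like this (field names follow the struct below; real output may
// carry extra fields such as a digest, which serde simply ignores):
//
//   {"status":"downloading","completed":104857600,"total":1885434240}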
#[derive(Debug, Deserialize)]
struct OllamaPullResponse {
    status: String,
    #[serde(default)]
    completed: Option<u64>,
    #[serde(default)]
    total: Option<u64>,
}

impl SetupManager {
    pub fn new(config: Config) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to create HTTP client");

        Self { client, config }
    }

    /// Run complete setup process
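    ///
    /// Minimal driver sketch (hypothetical; assumes the caller has already
    /// built a `Config`, e.g. from environment variables):
    ///
    /// ```ignore
    /// let setup = SetupManager::new(config);
    /// setup.run_setup().await?;
    /// ```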
    pub async fn run_setup(&self) -> Result<()> {
        info!("🚀 Starting Agentic Memory System setup...");

        // 1. Check Ollama connectivity
        self.check_ollama_connectivity().await?;

        // 2. Detect and pull embedding models
        let available_models = self.detect_embedding_models().await?;
        let selected_model = self.ensure_embedding_model(available_models).await?;

        // 3. Update configuration with selected model
        let mut updated_config = self.config.clone();
        updated_config.embedding.model = selected_model.name.clone();

        // 4. Test embedding generation
        self.test_embedding_generation(&updated_config).await?;

        // 5. Setup database
        self.setup_database().await?;

        // 6. Run comprehensive health checks
        self.run_health_checks(&updated_config).await?;

        info!("✅ Setup completed successfully!");
        info!("Selected embedding model: {} ({}D)", selected_model.name, selected_model.dimensions);

        Ok(())
    }

    /// Check if Ollama is running and accessible
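    /// (roughly equivalent to hitting `GET <base_url>/api/tags` by hand,
    /// e.g. with curl).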
    async fn check_ollama_connectivity(&self) -> Result<()> {
        info!("🔍 Checking Ollama connectivity at {}", self.config.embedding.base_url);

        let response = self
            .client
            .get(format!("{}/api/tags", self.config.embedding.base_url))
            .send()
            .await
            .context("Failed to connect to Ollama. Is it running and accessible?")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Ollama returned error status: {}",
                response.status()
            ));
        }

        info!("✅ Ollama is running and accessible");
        Ok(())
    }

    /// Detect available embedding models on Ollama
    async fn detect_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>> {
        info!("🔍 Detecting available embedding models...");

        let response = self
            .client
            .get(format!("{}/api/tags", self.config.embedding.base_url))
            .send()
            .await?;

        let models_response: OllamaModelsResponse = response.json().await?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = self.classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        if embedding_models.is_empty() {
            warn!("No embedding models found on Ollama");
        } else {
            info!("Found {} embedding models:", embedding_models.len());
            for model in &embedding_models {
                info!(
                    "  - {} ({}D) {}",
                    model.name,
                    model.dimensions,
                    if model.preferred { "⭐ RECOMMENDED" } else { "" }
                );
            }
        }

        Ok(embedding_models)
    }

    /// Classify a model name as an embedding model and return its info
    fn classify_embedding_model(&self, model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // Define known embedding models with their properties
        let known_models = [
            ("nomic-embed-text", 768, "High-quality text embeddings", true),
            ("mxbai-embed-large", 1024, "Large multilingual embeddings", true),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            ("all-mpnet-base-v2", 768, "Sentence transformer embeddings", false),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        // Matching against the lowercased name also covers mixed-case
        // variants, since every pattern above is lowercase.
        for (pattern, dimensions, description, preferred) in known_models {
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Check if it's likely an embedding model based on common patterns
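        // Note: the 768-dimension fallback below is only a guess; the actual
        // width is reported by `test_embedding_generation`, so verify it
        // after setup.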
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768, // Default assumption
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }

    /// Ensure a suitable embedding model is available, pulling if necessary
    async fn ensure_embedding_model(
        &self,
        available_models: Vec<EmbeddingModelInfo>,
    ) -> Result<EmbeddingModelInfo> {
        info!("🎯 Selecting embedding model...");

        // If we have a preferred model available, use it
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            info!("✅ Using preferred model: {}", preferred.name);
            return Ok(preferred.clone());
        }

        // If we have any available model, use the first one
        if !available_models.is_empty() {
            let selected = available_models[0].clone();
            info!("✅ Using available model: {}", selected.name);
            return Ok(selected);
        }

        // No embedding models available, try to pull recommended ones
        info!("📥 No embedding models found. Attempting to pull recommended models...");

        let recommended_models = [
            ("nomic-embed-text", 768, "High-quality text embeddings"),
            ("mxbai-embed-large", 1024, "Large multilingual embeddings"),
            ("all-minilm", 384, "Compact sentence embeddings"),
        ];
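        // Note: any model pulled from this list is returned with
        // `preferred: true`, even ones the classification table above marks
        // as non-preferred (e.g. all-minilm).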

        for (model_name, dimensions, description) in recommended_models {
            info!("📥 Attempting to pull model: {}", model_name);

            match self.pull_model(model_name).await {
                Ok(_) => {
                    info!("✅ Successfully pulled model: {}", model_name);
                    return Ok(EmbeddingModelInfo {
                        name: model_name.to_string(),
                        dimensions,
                        description: description.to_string(),
                        preferred: true,
                    });
                }
                Err(e) => {
                    warn!("Failed to pull model {}: {}", model_name, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!(
            "Failed to find or pull any suitable embedding models. Please manually pull an embedding model using 'ollama pull nomic-embed-text'"
        ))
    }

    /// Pull a model from Ollama
    async fn pull_model(&self, model_name: &str) -> Result<()> {
        info!("📥 Pulling model: {}", model_name);

        let request = OllamaPullRequest {
            name: model_name.to_string(),
        };

        let response = self
            .client
            .post(format!("{}/api/pull", self.config.embedding.base_url))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!(
                "Failed to pull model {}: HTTP {} - {}",
                model_name,
                status,
                error_text
            ));
        }

        // Read the response body. Note that `text()` buffers the entire
        // stream, so the progress lines below are reported after the pull
        // finishes rather than live.
        let lines = response.text().await?;

        // Ollama returns JSONL (JSON Lines) for streaming responses
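        // For live progress one could instead enable reqwest's `stream`
        // feature and consume `response.bytes_stream()` chunk by chunk
        // (sketch, assuming `futures_util` is available as a dependency):
        //
        //   use futures_util::StreamExt;
        //   let mut stream = response.bytes_stream();
        //   while let Some(chunk) = stream.next().await {
        //       let chunk = chunk?;
        //       // accumulate bytes, split on b'\n', parse complete lines
        //   }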
        for line in lines.lines() {
            if line.trim().is_empty() {
                continue;
            }

            match serde_json::from_str::<OllamaPullResponse>(line) {
                Ok(pull_response) => {
                    match pull_response.status.as_str() {
                        "downloading" => {
                            if let (Some(completed), Some(total)) =
                                (pull_response.completed, pull_response.total)
                            {
                                let progress = (completed as f64 / total as f64) * 100.0;
                                info!("  📊 Downloading: {:.1}% ({}/{})", progress, completed, total);
                            }
                        }
                        "verifying sha256" => {
                            info!("  🔐 Verifying checksum...");
                        }
                        "success" => {
                            info!("  ✅ Pull completed successfully");
                            return Ok(());
                        }
                        status => {
                            info!("  📦 Status: {}", status);
                        }
                    }
                }
                Err(_) => {
                    // Sometimes Ollama sends non-JSON status lines
                    if line.contains("success") {
                        info!("  ✅ Pull completed successfully");
                        return Ok(());
                    }
                    info!("  📦 {}", line);
                }
            }
        }

        Ok(())
    }

    /// Test embedding generation with the selected model
    async fn test_embedding_generation(&self, config: &Config) -> Result<()> {
        info!("🧪 Testing embedding generation...");

        let embedder = SimpleEmbedder::new_ollama(
            config.embedding.base_url.clone(),
            config.embedding.model.clone(),
        );

        let test_text = "This is a test sentence for embedding generation.";

        match embedder.generate_embedding(test_text).await {
            Ok(embedding) => {
                info!("✅ Embedding generation successful!");
                info!("  📊 Embedding dimensions: {}", embedding.len());
                info!(
                    "  📊 Sample values: [{:.4}, {:.4}, {:.4}, ...]",
                    embedding.first().unwrap_or(&0.0),
                    embedding.get(1).unwrap_or(&0.0),
                    embedding.get(2).unwrap_or(&0.0)
                );
                Ok(())
            }
            Err(e) => {
                error!("❌ Embedding generation failed: {}", e);
                Err(e)
            }
        }
    }

    /// Setup database with required extensions and tables
    async fn setup_database(&self) -> Result<()> {
        info!("🗄️  Setting up database...");

        // Parse the database URL
        let db_config: PgConfig = self
            .config
            .database_url
            .parse()
            .context("Invalid database URL")?;

        // Connect to the database
        let (client, connection) = db_config
            .connect(NoTls)
            .await
            .context("Failed to connect to database")?;

        // Spawn the connection
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        // Check if pgvector extension is available
        info!("🔍 Checking for pgvector extension...");

        let extension_check = client
            .query("SELECT 1 FROM pg_available_extensions WHERE name = 'vector'", &[])
            .await?;

        if extension_check.is_empty() {
            warn!("⚠️  pgvector extension is not available in this PostgreSQL instance");
            warn!("   Please install pgvector: https://github.com/pgvector/pgvector");
            return Err(anyhow::anyhow!("pgvector extension not available"));
        }

        // Enable pgvector extension
        info!("🔧 Enabling pgvector extension...");
        client
            .execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
            .await
            .context("Failed to enable pgvector extension")?;
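        // Note: CREATE EXTENSION generally requires elevated privileges
        // (superuser, or a role allowed to install the extension), so this
        // step can fail even when basic connectivity succeeds.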

        // Check if our tables exist
        info!("🔍 Checking database schema...");
        let table_check = client
            .query(
                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'memories'",
                &[],
            )
            .await?;

        if table_check.is_empty() {
            info!("📋 Database schema not found; migrations are required.");
            // Migrations are handled by the separate migration binary rather
            // than being run automatically here.
            warn!("⚠️  Please run database migrations: cargo run --bin migration");
        } else {
            info!("✅ Database schema is ready");
        }

        Ok(())
    }

    /// Run comprehensive health checks
    pub async fn run_health_checks(&self, config: &Config) -> Result<()> {
        info!("🩺 Running comprehensive health checks...");

        let mut checks_passed = 0;
        let mut total_checks = 0;

        // Check 1: Ollama connectivity
        total_checks += 1;
        match self.check_ollama_connectivity().await {
            Ok(_) => {
                info!("  ✅ Ollama connectivity");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Ollama connectivity: {}", e);
            }
        }

        // Check 2: Embedding model availability
        total_checks += 1;
        let embedder = SimpleEmbedder::new_ollama(
            config.embedding.base_url.clone(),
            config.embedding.model.clone(),
        );

        match embedder.generate_embedding("health check").await {
            Ok(_) => {
                info!("  ✅ Embedding generation");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Embedding generation: {}", e);
            }
        }

        // Check 3: Database connectivity
        total_checks += 1;
        match self.check_database_connectivity().await {
            Ok(_) => {
                info!("  ✅ Database connectivity");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Database connectivity: {}", e);
            }
        }

        // Check 4: pgvector extension
        total_checks += 1;
        match self.check_pgvector_extension().await {
            Ok(_) => {
                info!("  ✅ pgvector extension");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ pgvector extension: {}", e);
            }
        }

        // Summary
        info!("📊 Health check summary: {}/{} checks passed", checks_passed, total_checks);

        if checks_passed == total_checks {
            info!("🎉 All health checks passed! System is ready.");
            Ok(())
        } else {
            Err(anyhow::anyhow!(
                "Some health checks failed. Please address the issues above."
            ))
        }
    }

    /// Check database connectivity
    async fn check_database_connectivity(&self) -> Result<()> {
        let db_config: PgConfig = self.config.database_url.parse()?;
        let (client, connection) = db_config.connect(NoTls).await?;

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        // Simple connectivity test
        client.query("SELECT 1", &[]).await?;
        Ok(())
    }

    /// Check pgvector extension
    async fn check_pgvector_extension(&self) -> Result<()> {
        let db_config: PgConfig = self.config.database_url.parse()?;
        let (client, connection) = db_config.connect(NoTls).await?;

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        // Check if pgvector is installed and functional
        client
            .query("SELECT vector_dims(vector '[1,2,3]')", &[])
            .await
            .context("pgvector extension not available or not working")?;

        Ok(())
    }

    /// List available models for user selection
    pub async fn list_available_models(&self) -> Result<()> {
        info!("📋 Available embedding models:");

        let available_models = self.detect_embedding_models().await?;

        if available_models.is_empty() {
            info!("  No embedding models currently available");
            info!("  Recommended models to pull:");
            info!("    ollama pull nomic-embed-text");
            info!("    ollama pull mxbai-embed-large");
            info!("    ollama pull all-minilm");
        } else {
            for model in available_models {
                let icon = if model.preferred { "⭐" } else { "  " };
                info!("{} {} ({}D) - {}", icon, model.name, model.dimensions, model.description);
            }
        }

        Ok(())
    }

    /// Quick health check without setup
    pub async fn quick_health_check(&self) -> Result<()> {
        info!("🏥 Running quick health check...");

        // Check Ollama
        match self.check_ollama_connectivity().await {
            Ok(_) => info!("✅ Ollama: Running"),
            Err(_) => info!("❌ Ollama: Not accessible"),
        }

        // Check database
        match self.check_database_connectivity().await {
            Ok(_) => info!("✅ Database: Connected"),
            Err(_) => info!("❌ Database: Connection failed"),
        }

        // Check embedding model
        let embedder = SimpleEmbedder::new_ollama(
            self.config.embedding.base_url.clone(),
            self.config.embedding.model.clone(),
        );

        match embedder.generate_embedding("test").await {
            Ok(_) => info!("✅ Embeddings: Working"),
            Err(_) => info!("❌ Embeddings: Failed"),
        }

        Ok(())
    }
}

/// Create a sample .env file with default configuration
pub fn create_sample_env_file() -> Result<()> {
    let env_content = r#"# Agentic Memory System Configuration

# Database Configuration
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/codex_memory

# Embedding Configuration
EMBEDDING_PROVIDER=ollama
EMBEDDING_MODEL=nomic-embed-text
EMBEDDING_BASE_URL=http://192.168.1.110:11434
EMBEDDING_TIMEOUT_SECONDS=60

# Server Configuration
HTTP_PORT=8080
LOG_LEVEL=info

# Memory Tier Configuration
WORKING_TIER_LIMIT=1000
WARM_TIER_LIMIT=10000
WORKING_TO_WARM_DAYS=7
WARM_TO_COLD_DAYS=30
IMPORTANCE_THRESHOLD=0.7

# Operational Configuration
MAX_DB_CONNECTIONS=10
REQUEST_TIMEOUT_SECONDS=30
ENABLE_METRICS=true
"#;

    std::fs::write(".env.example", env_content)
        .context("Failed to create .env.example file")?;

    info!("📋 Created .env.example file with default configuration");
    info!("   Copy this to .env and modify as needed");

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_embedding_model() {
        let setup = SetupManager::new(Config::default());

        // Known models
        let nomic = setup.classify_embedding_model("nomic-embed-text").unwrap();
        assert_eq!(nomic.dimensions, 768);
        assert!(nomic.preferred);

        let mxbai = setup.classify_embedding_model("mxbai-embed-large").unwrap();
        assert_eq!(mxbai.dimensions, 1024);
        assert!(mxbai.preferred);

        // Unknown embedding model
        let unknown = setup.classify_embedding_model("custom-embed-model").unwrap();
        assert_eq!(unknown.dimensions, 768); // Default
        assert!(!unknown.preferred);

        // Non-embedding model
        let non_embed = setup.classify_embedding_model("llama2");
        assert!(non_embed.is_none());
    }
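
    // Additional coverage for the generic fallback patterns ("embed",
    // "sentence", "vector"); the model names here are hypothetical and only
    // exercise the heuristic branch.
    #[test]
    fn test_fallback_pattern_detection() {
        let setup = SetupManager::new(Config::default());

        let sentence = setup.classify_embedding_model("my-sentence-transformer").unwrap();
        assert_eq!(sentence.dimensions, 768);
        assert!(!sentence.preferred);

        let vector = setup.classify_embedding_model("fancy-vector-model").unwrap();
        assert_eq!(vector.dimensions, 768);
        assert!(!vector.preferred);
    }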

    #[test]
    fn test_known_models_classification() {
        let setup = SetupManager::new(Config::default());

        let test_cases = [
            ("nomic-embed-text", true, 768),
            ("all-minilm", false, 384),
            ("bge-base-en", false, 768),
            ("e5-large", false, 1024),
        ];

        for (model_name, expected_preferred, expected_dims) in test_cases {
            let result = setup.classify_embedding_model(model_name);
            assert!(result.is_some(), "Should classify {} as embedding model", model_name);

            let info = result.unwrap();
            assert_eq!(info.preferred, expected_preferred);
            assert_eq!(info.dimensions, expected_dims);
        }
    }
}
656}