codex_memory/setup.rs

use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio_postgres::{Config as PgConfig, NoTls};
use tracing::{error, info, warn};

use crate::config::Config;
use crate::embedding::SimpleEmbedder;

/// Available embedding models with their configurations
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

/// Setup manager for the Agentic Memory System
pub struct SetupManager {
    client: Client,
    config: Config,
}

/// Ollama API response structures
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}
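
// Illustrative `/api/tags` payload as the structs above deserialize it
// (shape only; any extra fields Ollama returns are ignored by serde):
//   {"models":[{"name":"nomic-embed-text:latest","size":274302450}]}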

#[derive(Debug, Serialize)]
struct OllamaPullRequest {
    name: String,
}

#[derive(Debug, Deserialize)]
struct OllamaPullResponse {
    status: String,
    #[serde(default)]
    completed: Option<u64>,
    #[serde(default)]
    total: Option<u64>,
}
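
// Illustrative shape of the JSON Lines that `/api/pull` streams back, one
// object per line (field names as deserialized above; exact status strings
// vary by Ollama version):
//   {"status":"downloading","completed":104857600,"total":274301056}
//   {"status":"verifying sha256"}
//   {"status":"success"}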

impl SetupManager {
    pub fn new(config: Config) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to create HTTP client");

        Self { client, config }
    }

    /// Run complete setup process
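    ///
    /// A minimal usage sketch (marked `ignore`: it needs a running Ollama and
    /// Postgres, and the crate path is assumed from this file's location):
    ///
    /// ```ignore
    /// # use codex_memory::{config::Config, setup::SetupManager};
    /// # async fn demo(config: Config) -> anyhow::Result<()> {
    /// let setup = SetupManager::new(config);
    /// setup.run_setup().await?;
    /// # Ok(())
    /// # }
    /// ```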
    pub async fn run_setup(&self) -> Result<()> {
        info!("🚀 Starting Agentic Memory System setup...");

        // 1. Check Ollama connectivity
        self.check_ollama_connectivity().await?;

        // 2. Detect and pull embedding models
        let available_models = self.detect_embedding_models().await?;
        let selected_model = self.ensure_embedding_model(available_models).await?;

        // 3. Update configuration with selected model
        let mut updated_config = self.config.clone();
        updated_config.embedding.model = selected_model.name.clone();

        // 4. Test embedding generation
        self.test_embedding_generation(&updated_config).await?;

        // 5. Setup database
        self.setup_database().await?;

        // 6. Run comprehensive health checks
        self.run_health_checks(&updated_config).await?;

        info!("✅ Setup completed successfully!");
        info!(
            "Selected embedding model: {} ({}D)",
            selected_model.name, selected_model.dimensions
        );

        Ok(())
    }

    /// Check if Ollama is running and accessible
    async fn check_ollama_connectivity(&self) -> Result<()> {
        info!(
            "🔍 Checking Ollama connectivity at {}",
            self.config.embedding.base_url
        );

        let response = self
            .client
            .get(format!("{}/api/tags", self.config.embedding.base_url))
            .send()
            .await
            .context("Failed to connect to Ollama. Is it running and accessible?")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Ollama returned error status: {}",
                response.status()
            ));
        }

        info!("✅ Ollama is running and accessible");
        Ok(())
    }

    /// Detect available embedding models on Ollama
    async fn detect_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>> {
        info!("🔍 Detecting available embedding models...");

        let response = self
            .client
            .get(format!("{}/api/tags", self.config.embedding.base_url))
            .send()
            .await?;

        let models_response: OllamaModelsResponse = response.json().await?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = self.classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        if embedding_models.is_empty() {
            warn!("No embedding models found on Ollama");
        } else {
            info!("Found {} embedding models:", embedding_models.len());
            for model in &embedding_models {
                info!(
                    "  - {} ({}D) {}",
                    model.name,
                    model.dimensions,
                    if model.preferred {
                        "⭐ RECOMMENDED"
                    } else {
                        ""
                    }
                );
            }
        }

        Ok(embedding_models)
    }

    /// Classify a model name as an embedding model and return its info
    fn classify_embedding_model(&self, model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // Define known embedding models with their properties
        let known_models = [
            (
                "nomic-embed-text",
                768,
                "High-quality text embeddings",
                true,
            ),
            (
                "mxbai-embed-large",
                1024,
                "Large multilingual embeddings",
                true,
            ),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            (
                "all-mpnet-base-v2",
                768,
                "Sentence transformer embeddings",
                false,
            ),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        for (pattern, dimensions, description, preferred) in known_models {
            // Patterns are lowercase, so matching against the lowercased name
            // already covers the original-case check.
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Check if it's likely an embedding model based on common patterns
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768, // Default assumption
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }

    /// Ensure a suitable embedding model is available, pulling if necessary
    async fn ensure_embedding_model(
        &self,
        available_models: Vec<EmbeddingModelInfo>,
    ) -> Result<EmbeddingModelInfo> {
        info!("🎯 Selecting embedding model...");

        // If we have a preferred model available, use it
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            info!("✅ Using preferred model: {}", preferred.name);
            return Ok(preferred.clone());
        }

        // Otherwise, fall back to the first available model
        if !available_models.is_empty() {
            let selected = available_models[0].clone();
            info!("✅ Using available model: {}", selected.name);
            return Ok(selected);
        }

        // No embedding models available, try to pull recommended ones
        info!("📥 No embedding models found. Attempting to pull recommended models...");

        let recommended_models = [
            ("nomic-embed-text", 768, "High-quality text embeddings"),
            ("mxbai-embed-large", 1024, "Large multilingual embeddings"),
            ("all-minilm", 384, "Compact sentence embeddings"),
        ];

        for (model_name, dimensions, description) in recommended_models {
            info!("📥 Attempting to pull model: {}", model_name);

            match self.pull_model(model_name).await {
                Ok(_) => {
                    info!("✅ Successfully pulled model: {}", model_name);
                    return Ok(EmbeddingModelInfo {
                        name: model_name.to_string(),
                        dimensions,
                        description: description.to_string(),
                        preferred: true,
                    });
                }
                Err(e) => {
                    warn!("Failed to pull model {}: {}", model_name, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!(
            "Failed to find or pull any suitable embedding model. Please pull one manually, e.g. 'ollama pull nomic-embed-text'"
        ))
    }

    /// Pull a model from Ollama
    async fn pull_model(&self, model_name: &str) -> Result<()> {
        info!("📥 Pulling model: {}", model_name);

        let request = OllamaPullRequest {
            name: model_name.to_string(),
        };

        let response = self
            .client
            .post(format!("{}/api/pull", self.config.embedding.base_url))
            // Large models can take far longer than the client's default
            // 120s timeout, so give the pull request a generous override.
            .timeout(Duration::from_secs(3600))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!(
                "Failed to pull model {}: HTTP {} - {}",
                model_name,
                status,
                error_text
            ));
        }

        // Read the response body. Note that `text()` buffers the entire
        // stream, so the progress lines below are only logged once the
        // download has finished rather than in real time.
        let lines = response.text().await?;

        // Ollama returns JSONL (JSON Lines) for streaming responses
        for line in lines.lines() {
            if line.trim().is_empty() {
                continue;
            }

            match serde_json::from_str::<OllamaPullResponse>(line) {
                Ok(pull_response) => match pull_response.status.as_str() {
                    "downloading" => {
                        if let (Some(completed), Some(total)) =
                            (pull_response.completed, pull_response.total)
                        {
                            let progress = (completed as f64 / total as f64) * 100.0;
                            info!(
                                "  📊 Downloading: {:.1}% ({}/{})",
                                progress, completed, total
                            );
                        }
                    }
                    "verifying sha256" => {
                        info!("  🔍 Verifying checksum...");
                    }
                    "success" => {
                        info!("  ✅ Pull completed successfully");
                        return Ok(());
                    }
                    status => {
                        info!("  📦 Status: {}", status);
                    }
                },
                Err(_) => {
                    // Sometimes Ollama sends non-JSON status lines
                    if line.contains("success") {
                        info!("  ✅ Pull completed successfully");
                        return Ok(());
                    }
                    info!("  📦 {}", line);
                }
            }
        }

        Ok(())
    }

    /// Test embedding generation with the selected model
    async fn test_embedding_generation(&self, config: &Config) -> Result<()> {
        info!("🧪 Testing embedding generation...");

        let embedder = SimpleEmbedder::new_ollama(
            config.embedding.base_url.clone(),
            config.embedding.model.clone(),
        );

        let test_text = "This is a test sentence for embedding generation.";

        match embedder.generate_embedding(test_text).await {
            Ok(embedding) => {
                info!("✅ Embedding generation successful!");
                info!("  📊 Embedding dimensions: {}", embedding.len());
                info!(
                    "  📊 Sample values: [{:.4}, {:.4}, {:.4}, ...]",
                    embedding.first().unwrap_or(&0.0),
                    embedding.get(1).unwrap_or(&0.0),
                    embedding.get(2).unwrap_or(&0.0)
                );
                Ok(())
            }
            Err(e) => {
                error!("❌ Embedding generation failed: {}", e);
                Err(e)
            }
        }
    }

    /// Setup database with required extensions and tables
    async fn setup_database(&self) -> Result<()> {
        info!("🗄️  Setting up database...");

        // Parse the database URL
        let db_config: PgConfig = self
            .config
            .database_url
            .parse()
            .context("Invalid database URL")?;

        // Connect to the database
        let (client, connection) = db_config
            .connect(NoTls)
            .await
            .context("Failed to connect to database")?;

        // Spawn the connection
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        // Check if pgvector extension is available
        info!("🔍 Checking for pgvector extension...");

        let extension_check = client
            .query(
                "SELECT 1 FROM pg_available_extensions WHERE name = 'vector'",
                &[],
            )
            .await?;

        if extension_check.is_empty() {
            warn!("⚠️  pgvector extension is not available in this PostgreSQL instance");
            warn!("   Please install pgvector: https://github.com/pgvector/pgvector");
            return Err(anyhow::anyhow!("pgvector extension not available"));
        }

        // Enable pgvector extension
        info!("🔧 Enabling pgvector extension...");
        client
            .execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
            .await
            .context("Failed to enable pgvector extension")?;

        // Check if our tables exist
        info!("🔍 Checking database schema...");
        let table_check = client
            .query(
                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'memories'",
                &[]
            )
            .await?;

        if table_check.is_empty() {
            info!("📋 Database schema not found; migrations are required");
            // Migrations are handled by the separate migration binary, not here
            warn!("⚠️  Please run database migrations: cargo run --bin migration");
        } else {
            info!("✅ Database schema is ready");
        }

        Ok(())
    }

    /// Run comprehensive health checks
    pub async fn run_health_checks(&self, config: &Config) -> Result<()> {
        info!("🩺 Running comprehensive health checks...");

        let mut checks_passed = 0;
        let mut total_checks = 0;

        // Check 1: Ollama connectivity
        total_checks += 1;
        match self.check_ollama_connectivity().await {
            Ok(_) => {
                info!("  ✅ Ollama connectivity");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Ollama connectivity: {}", e);
            }
        }

        // Check 2: Embedding model availability
        total_checks += 1;
        let embedder = SimpleEmbedder::new_ollama(
            config.embedding.base_url.clone(),
            config.embedding.model.clone(),
        );

        match embedder.generate_embedding("health check").await {
            Ok(_) => {
                info!("  ✅ Embedding generation");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Embedding generation: {}", e);
            }
        }

        // Check 3: Database connectivity
        total_checks += 1;
        match self.check_database_connectivity().await {
            Ok(_) => {
                info!("  ✅ Database connectivity");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Database connectivity: {}", e);
            }
        }

        // Check 4: pgvector extension
        total_checks += 1;
        match self.check_pgvector_extension().await {
            Ok(_) => {
                info!("  ✅ pgvector extension");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ pgvector extension: {}", e);
            }
        }

        // Summary
        info!(
            "📊 Health check summary: {}/{} checks passed",
            checks_passed, total_checks
        );

        if checks_passed == total_checks {
            info!("🎉 All health checks passed! System is ready.");
            Ok(())
        } else {
            Err(anyhow::anyhow!(
                "Some health checks failed. Please address the issues above."
            ))
        }
    }

    /// Check database connectivity
    async fn check_database_connectivity(&self) -> Result<()> {
        let db_config: PgConfig = self.config.database_url.parse()?;
        let (client, connection) = db_config.connect(NoTls).await?;

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        // Simple connectivity test
        client.query("SELECT 1", &[]).await?;
        Ok(())
    }

    /// Check pgvector extension
    async fn check_pgvector_extension(&self) -> Result<()> {
        let db_config: PgConfig = self.config.database_url.parse()?;
        let (client, connection) = db_config.connect(NoTls).await?;

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        // Check if pgvector is installed and functional
        client
            .query("SELECT vector_dims(vector '[1,2,3]')", &[])
            .await
            .context("pgvector extension not available or not working")?;

        Ok(())
    }

    /// List available models for user selection
    pub async fn list_available_models(&self) -> Result<()> {
        info!("📋 Available embedding models:");

        let available_models = self.detect_embedding_models().await?;

        if available_models.is_empty() {
            info!("  No embedding models currently available");
            info!("  Recommended models to pull:");
            info!("    ollama pull nomic-embed-text");
            info!("    ollama pull mxbai-embed-large");
            info!("    ollama pull all-minilm");
        } else {
            for model in available_models {
                let icon = if model.preferred { "⭐" } else { "  " };
                info!(
                    "{} {} ({}D) - {}",
                    icon, model.name, model.dimensions, model.description
                );
            }
        }

        Ok(())
    }

    /// Quick health check without setup
    pub async fn quick_health_check(&self) -> Result<()> {
        info!("🏥 Running quick health check...");

        // Check Ollama
        match self.check_ollama_connectivity().await {
            Ok(_) => info!("✅ Ollama: Running"),
            Err(_) => info!("❌ Ollama: Not accessible"),
        }

        // Check database
        match self.check_database_connectivity().await {
            Ok(_) => info!("✅ Database: Connected"),
            Err(_) => info!("❌ Database: Connection failed"),
        }

        // Check embedding model
        let embedder = SimpleEmbedder::new_ollama(
            self.config.embedding.base_url.clone(),
            self.config.embedding.model.clone(),
        );

        match embedder.generate_embedding("test").await {
            Ok(_) => info!("✅ Embeddings: Working"),
            Err(_) => info!("❌ Embeddings: Failed"),
        }

        Ok(())
    }
}

/// Create a sample .env file with default configuration
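///
/// A small usage sketch (writes `.env.example` into the current working
/// directory):
///
/// ```ignore
/// codex_memory::setup::create_sample_env_file()?;
/// // then copy it into place:  cp .env.example .env
/// ```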
pub fn create_sample_env_file() -> Result<()> {
    let env_content = r#"# Agentic Memory System Configuration

# Database Configuration
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/codex_memory

# Embedding Configuration
EMBEDDING_PROVIDER=ollama
EMBEDDING_MODEL=nomic-embed-text
EMBEDDING_BASE_URL=http://localhost:11434
EMBEDDING_TIMEOUT_SECONDS=60

# Server Configuration
HTTP_PORT=8080
LOG_LEVEL=info

# Memory Tier Configuration
WORKING_TIER_LIMIT=1000
WARM_TIER_LIMIT=10000
WORKING_TO_WARM_DAYS=7
WARM_TO_COLD_DAYS=30
IMPORTANCE_THRESHOLD=0.7

# Operational Configuration
MAX_DB_CONNECTIONS=10
REQUEST_TIMEOUT_SECONDS=30
ENABLE_METRICS=true
"#;

    std::fs::write(".env.example", env_content).context("Failed to create .env.example file")?;

    info!("📋 Created .env.example file with default configuration");
    info!("   Copy this to .env and modify as needed");

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_embedding_model() {
        let setup = SetupManager::new(Config::default());

        // Known models
        let nomic = setup.classify_embedding_model("nomic-embed-text").unwrap();
        assert_eq!(nomic.dimensions, 768);
        assert!(nomic.preferred);

        let mxbai = setup.classify_embedding_model("mxbai-embed-large").unwrap();
        assert_eq!(mxbai.dimensions, 1024);
        assert!(mxbai.preferred);

        // Unknown embedding model
        let unknown = setup
            .classify_embedding_model("custom-embed-model")
            .unwrap();
        assert_eq!(unknown.dimensions, 768); // Default
        assert!(!unknown.preferred);

        // Non-embedding model
        let non_embed = setup.classify_embedding_model("llama2");
        assert!(non_embed.is_none());
    }

    #[test]
    fn test_known_models_classification() {
        let setup = SetupManager::new(Config::default());

        let test_cases = [
            ("nomic-embed-text", true, 768),
            ("all-minilm", false, 384),
            ("bge-base-en", false, 768),
            ("e5-large", false, 1024),
        ];

        for (model_name, expected_preferred, expected_dims) in test_cases {
            let result = setup.classify_embedding_model(model_name);
            assert!(
                result.is_some(),
                "Should classify {model_name} as embedding model"
            );

            let info = result.unwrap();
            assert_eq!(info.preferred, expected_preferred);
            assert_eq!(info.dimensions, expected_dims);
        }
    }
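
    // A sketch of the fallback path: names that merely contain "embed",
    // "sentence", or "vector" (but match no known model) are accepted with
    // the 768-dimension default. The model name here is made up for the test.
    #[test]
    fn test_pattern_based_fallback() {
        let setup = SetupManager::new(Config::default());

        let fallback = setup
            .classify_embedding_model("acme-sentence-encoder")
            .unwrap();
        assert_eq!(fallback.dimensions, 768);
        assert!(!fallback.preferred);
        assert_eq!(fallback.description, "Detected embedding model");
    }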
}