use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tokio_postgres::{Config as PgConfig, NoTls};
use tracing::{error, info, warn};

use crate::config::Config;
use crate::embedding::SimpleEmbedder;

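/// Metadata describing an embedding model available through Ollama.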
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    pub name: String,
    pub dimensions: usize,
    pub description: String,
    pub preferred: bool,
}

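/// Orchestrates first-run setup for the memory system: verifies Ollama
/// connectivity, selects (or pulls) an embedding model, prepares the
/// database, and runs health checks.
///
/// A minimal usage sketch (assumes a loaded `Config`; marked `ignore`
/// because it needs a running Ollama and PostgreSQL):
///
/// ```ignore
/// let setup = SetupManager::new(config);
/// setup.run_setup().await?;
/// ```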
pub struct SetupManager {
    client: Client,
    config: Config,
}

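// Minimal wire types for Ollama's `/api/tags` and `/api/pull` endpoints;
// only the fields this module actually reads are modeled.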
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    #[allow(dead_code)]
    size: u64,
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}

#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}

#[derive(Debug, Serialize)]
struct OllamaPullRequest {
    name: String,
}

#[derive(Debug, Deserialize)]
struct OllamaPullResponse {
    status: String,
    #[serde(default)]
    completed: Option<u64>,
    #[serde(default)]
    total: Option<u64>,
}

impl SetupManager {
    pub fn new(config: Config) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(120))
            .build()
            .expect("Failed to create HTTP client");

        Self { client, config }
    }

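    /// Runs the full setup flow end to end; any failing step aborts setup
    /// with an error before later steps execute.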
    pub async fn run_setup(&self) -> Result<()> {
        info!("🚀 Starting Agentic Memory System setup...");

        // Step 1: Verify Ollama is reachable.
        self.check_ollama_connectivity().await?;

        // Step 2: Detect installed embedding models and select one.
        let available_models = self.detect_embedding_models().await?;
        let selected_model = self.ensure_embedding_model(available_models).await?;

        // Step 3: Record the selected model in the working config.
        let mut updated_config = self.config.clone();
        updated_config.embedding.model = selected_model.name.clone();

        // Step 4: Prove embedding generation works with that model.
        self.test_embedding_generation(&updated_config).await?;

        // Step 5: Prepare the database and pgvector extension.
        self.setup_database().await?;

        // Step 6: Run the full health-check suite.
        self.run_health_checks(&updated_config).await?;

        info!("✅ Setup completed successfully!");
        info!(
            "Selected embedding model: {} ({}D)",
            selected_model.name, selected_model.dimensions
        );

        Ok(())
    }

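    /// Confirms the Ollama HTTP API is reachable via its `/api/tags` endpoint.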
    async fn check_ollama_connectivity(&self) -> Result<()> {
        info!(
            "🔍 Checking Ollama connectivity at {}",
            self.config.embedding.base_url
        );

        let response = self
            .client
            .get(format!("{}/api/tags", self.config.embedding.base_url))
            .send()
            .await
            .context("Failed to connect to Ollama. Is it running and accessible?")?;

        if !response.status().is_success() {
            return Err(anyhow::anyhow!(
                "Ollama returned error status: {}",
                response.status()
            ));
        }

        info!("✅ Ollama is running and accessible");
        Ok(())
    }

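    /// Lists the models Ollama reports and keeps those that look like
    /// embedding models (see `classify_embedding_model`).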
    async fn detect_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>> {
        info!("🔍 Detecting available embedding models...");

        let response = self
            .client
            .get(format!("{}/api/tags", self.config.embedding.base_url))
            .send()
            .await?;

        let models_response: OllamaModelsResponse = response.json().await?;

        let mut embedding_models = Vec::new();

        for model in models_response.models {
            if let Some(model_info) = self.classify_embedding_model(&model.name) {
                embedding_models.push(model_info);
            }
        }

        if embedding_models.is_empty() {
            warn!("No embedding models found on Ollama");
        } else {
            info!("Found {} embedding models:", embedding_models.len());
            for model in &embedding_models {
                info!(
                    "  - {} ({}D) {}",
                    model.name,
                    model.dimensions,
                    if model.preferred {
                        "⭐ RECOMMENDED"
                    } else {
                        ""
                    }
                );
            }
        }

        Ok(embedding_models)
    }

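    /// Classifies a model name against a table of known embedding models,
    /// falling back to a name-based heuristic. The 768-dimension default for
    /// unrecognized models is an assumption and may not match the model's
    /// actual output size.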
    fn classify_embedding_model(&self, model_name: &str) -> Option<EmbeddingModelInfo> {
        let name_lower = model_name.to_lowercase();

        // Known embedding models: (name pattern, dimensions, description, preferred).
        let known_models = [
            (
                "nomic-embed-text",
                768,
                "High-quality text embeddings",
                true,
            ),
            (
                "mxbai-embed-large",
                1024,
                "Large multilingual embeddings",
                true,
            ),
            ("all-minilm", 384, "Compact sentence embeddings", false),
            (
                "all-mpnet-base-v2",
                768,
                "Sentence transformer embeddings",
                false,
            ),
            ("bge-small-en", 384, "BGE small English embeddings", false),
            ("bge-base-en", 768, "BGE base English embeddings", false),
            ("bge-large-en", 1024, "BGE large English embeddings", false),
            ("e5-small", 384, "E5 small embeddings", false),
            ("e5-base", 768, "E5 base embeddings", false),
            ("e5-large", 1024, "E5 large embeddings", false),
        ];

        // All patterns are lowercase, so matching against the lowercased
        // name covers every case.
        for (pattern, dimensions, description, preferred) in known_models {
            if name_lower.contains(pattern) {
                return Some(EmbeddingModelInfo {
                    name: model_name.to_string(),
                    dimensions,
                    description: description.to_string(),
                    preferred,
                });
            }
        }

        // Heuristic fallback: treat anything that advertises itself as an
        // embedding model as one, assuming 768 dimensions.
        if name_lower.contains("embed")
            || name_lower.contains("sentence")
            || name_lower.contains("vector")
        {
            return Some(EmbeddingModelInfo {
                name: model_name.to_string(),
                dimensions: 768,
                description: "Detected embedding model".to_string(),
                preferred: false,
            });
        }

        None
    }

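    /// Selects a model in priority order: a preferred installed model, then
    /// any installed embedding model, then an attempt to pull one of the
    /// recommended models.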
    async fn ensure_embedding_model(
        &self,
        available_models: Vec<EmbeddingModelInfo>,
    ) -> Result<EmbeddingModelInfo> {
        info!("🎯 Selecting embedding model...");

        // Prefer a recommended model if one is already installed.
        if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
            info!("✅ Using preferred model: {}", preferred.name);
            return Ok(preferred.clone());
        }

        // Otherwise fall back to any installed embedding model.
        if !available_models.is_empty() {
            let selected = available_models[0].clone();
            info!("✅ Using available model: {}", selected.name);
            return Ok(selected);
        }

        // Nothing installed: try to pull the recommended models in order.
        info!("📥 No embedding models found. Attempting to pull recommended models...");

        let recommended_models = [
            ("nomic-embed-text", 768, "High-quality text embeddings"),
            ("mxbai-embed-large", 1024, "Large multilingual embeddings"),
            ("all-minilm", 384, "Compact sentence embeddings"),
        ];

        for (model_name, dimensions, description) in recommended_models {
            info!("📥 Attempting to pull model: {}", model_name);

            match self.pull_model(model_name).await {
                Ok(_) => {
                    info!("✅ Successfully pulled model: {}", model_name);
                    return Ok(EmbeddingModelInfo {
                        name: model_name.to_string(),
                        dimensions,
                        description: description.to_string(),
                        preferred: true,
                    });
                }
                Err(e) => {
                    warn!("Failed to pull model {}: {}", model_name, e);
                    continue;
                }
            }
        }

        Err(anyhow::anyhow!(
            "Failed to find or pull any suitable embedding models. Please manually pull an embedding model using 'ollama pull nomic-embed-text'"
        ))
    }

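    /// Pulls a model through Ollama's `/api/pull` endpoint, which streams
    /// newline-delimited JSON progress records. Note that `text()` buffers
    /// the entire response, so progress is only logged after the pull ends.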
    async fn pull_model(&self, model_name: &str) -> Result<()> {
        info!("📥 Pulling model: {}", model_name);

        let request = OllamaPullRequest {
            name: model_name.to_string(),
        };

        let response = self
            .client
            .post(format!("{}/api/pull", self.config.embedding.base_url))
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(anyhow::anyhow!(
                "Failed to pull model {}: HTTP {} - {}",
                model_name,
                status,
                error_text
            ));
        }

        // The pull endpoint returns newline-delimited JSON progress records.
        let lines = response.text().await?;

        for line in lines.lines() {
            if line.trim().is_empty() {
                continue;
            }

            match serde_json::from_str::<OllamaPullResponse>(line) {
                Ok(pull_response) => match pull_response.status.as_str() {
                    "downloading" => {
                        if let (Some(completed), Some(total)) =
                            (pull_response.completed, pull_response.total)
                        {
                            let progress = (completed as f64 / total as f64) * 100.0;
                            info!(
                                "  📊 Downloading: {:.1}% ({}/{})",
                                progress, completed, total
                            );
                        }
                    }
                    "verifying sha256" => {
                        info!("  🔐 Verifying checksum...");
                    }
                    "success" => {
                        info!("  ✅ Pull completed successfully");
                        return Ok(());
                    }
                    status => {
                        info!("  📦 Status: {}", status);
                    }
                },
                Err(_) => {
                    // Some records are not valid JSON; fall back to a plain
                    // string match for the success marker.
                    if line.contains("success") {
                        info!("  ✅ Pull completed successfully");
                        return Ok(());
                    }
                    info!("  📦 {}", line);
                }
            }
        }

        Ok(())
    }

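    /// Generates an embedding for a fixed test sentence to prove the
    /// selected model works end to end.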
    async fn test_embedding_generation(&self, config: &Config) -> Result<()> {
        info!("🧪 Testing embedding generation...");

        let embedder = SimpleEmbedder::new_ollama(
            config.embedding.base_url.clone(),
            config.embedding.model.clone(),
        );

        let test_text = "This is a test sentence for embedding generation.";

        match embedder.generate_embedding(test_text).await {
            Ok(embedding) => {
                info!("✅ Embedding generation successful!");
                info!("  📏 Embedding dimensions: {}", embedding.len());
                info!(
                    "  📊 Sample values: [{:.4}, {:.4}, {:.4}, ...]",
                    embedding.first().unwrap_or(&0.0),
                    embedding.get(1).unwrap_or(&0.0),
                    embedding.get(2).unwrap_or(&0.0)
                );
                Ok(())
            }
            Err(e) => {
                error!("❌ Embedding generation failed: {}", e);
                Err(e)
            }
        }
    }

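    /// Connects to PostgreSQL, enables the pgvector extension, and checks
    /// whether the `memories` table exists. Running the migrations
    /// themselves is left to the separate migration binary.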
    async fn setup_database(&self) -> Result<()> {
        info!("🗄️ Setting up database...");

        let db_config: PgConfig = self
            .config
            .database_url
            .parse()
            .context("Invalid database URL")?;

        let (client, connection) = db_config
            .connect(NoTls)
            .await
            .context("Failed to connect to database")?;

        // Drive the connection on a background task.
        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        info!("🔍 Checking for pgvector extension...");

        let extension_check = client
            .query(
                "SELECT 1 FROM pg_available_extensions WHERE name = 'vector'",
                &[],
            )
            .await?;

        if extension_check.is_empty() {
            warn!("⚠️ pgvector extension is not available in this PostgreSQL instance");
            warn!("   Please install pgvector: https://github.com/pgvector/pgvector");
            return Err(anyhow::anyhow!("pgvector extension not available"));
        }

        info!("🔧 Enabling pgvector extension...");
        client
            .execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
            .await
            .context("Failed to enable pgvector extension")?;

        info!("📋 Checking database schema...");
        let table_check = client
            .query(
                "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'memories'",
                &[]
            )
            .await?;

        if table_check.is_empty() {
            info!("📝 Database schema not found");
            warn!("⚠️ Please run database migrations: cargo run --bin migration");
        } else {
            info!("✅ Database schema is ready");
        }

        Ok(())
    }

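    /// Runs every health check and reports a pass/fail summary rather than
    /// stopping at the first failure.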
    pub async fn run_health_checks(&self, config: &Config) -> Result<()> {
        info!("🩺 Running comprehensive health checks...");

        let mut checks_passed = 0;
        let mut total_checks = 0;

        // Check 1: Ollama connectivity.
        total_checks += 1;
        match self.check_ollama_connectivity().await {
            Ok(_) => {
                info!("  ✅ Ollama connectivity");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Ollama connectivity: {}", e);
            }
        }

        // Check 2: Embedding generation.
        total_checks += 1;
        let embedder = SimpleEmbedder::new_ollama(
            config.embedding.base_url.clone(),
            config.embedding.model.clone(),
        );

        match embedder.generate_embedding("health check").await {
            Ok(_) => {
                info!("  ✅ Embedding generation");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Embedding generation: {}", e);
            }
        }

        // Check 3: Database connectivity.
        total_checks += 1;
        match self.check_database_connectivity().await {
            Ok(_) => {
                info!("  ✅ Database connectivity");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ Database connectivity: {}", e);
            }
        }

        // Check 4: pgvector extension.
        total_checks += 1;
        match self.check_pgvector_extension().await {
            Ok(_) => {
                info!("  ✅ pgvector extension");
                checks_passed += 1;
            }
            Err(e) => {
                error!("  ❌ pgvector extension: {}", e);
            }
        }

        info!(
            "📊 Health check summary: {}/{} checks passed",
            checks_passed, total_checks
        );

        if checks_passed == total_checks {
            info!("🎉 All health checks passed! System is ready.");
            Ok(())
        } else {
            Err(anyhow::anyhow!(
                "Some health checks failed. Please address the issues above."
            ))
        }
    }

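    /// Verifies a database round trip with a trivial `SELECT 1`.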
    async fn check_database_connectivity(&self) -> Result<()> {
        let db_config: PgConfig = self.config.database_url.parse()?;
        let (client, connection) = db_config.connect(NoTls).await?;

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        client.query("SELECT 1", &[]).await?;
        Ok(())
    }

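    /// Verifies pgvector works by evaluating a vector expression.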
    async fn check_pgvector_extension(&self) -> Result<()> {
        let db_config: PgConfig = self.config.database_url.parse()?;
        let (client, connection) = db_config.connect(NoTls).await?;

        tokio::spawn(async move {
            if let Err(e) = connection.await {
                error!("Database connection error: {}", e);
            }
        });

        client
            .query("SELECT vector_dims(vector '[1,2,3]')", &[])
            .await
            .context("pgvector extension not available or not working")?;

        Ok(())
    }

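    /// Logs the embedding models currently available, or suggested `ollama
    /// pull` commands if none are installed.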
    pub async fn list_available_models(&self) -> Result<()> {
        info!("📋 Available embedding models:");

        let available_models = self.detect_embedding_models().await?;

        if available_models.is_empty() {
            info!("  No embedding models currently available");
            info!("  Recommended models to pull:");
            info!("    ollama pull nomic-embed-text");
            info!("    ollama pull mxbai-embed-large");
            info!("    ollama pull all-minilm");
        } else {
            for model in available_models {
                let icon = if model.preferred { "⭐" } else { " " };
                info!(
                    "{} {} ({}D) - {}",
                    icon, model.name, model.dimensions, model.description
                );
            }
        }

        Ok(())
    }

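    /// Lightweight check that logs per-component status and never returns an
    /// error, even when a component is down.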
    pub async fn quick_health_check(&self) -> Result<()> {
        info!("🏥 Running quick health check...");

        match self.check_ollama_connectivity().await {
            Ok(_) => info!("✅ Ollama: Running"),
            Err(_) => info!("❌ Ollama: Not accessible"),
        }

        match self.check_database_connectivity().await {
            Ok(_) => info!("✅ Database: Connected"),
            Err(_) => info!("❌ Database: Connection failed"),
        }

        let embedder = SimpleEmbedder::new_ollama(
            self.config.embedding.base_url.clone(),
            self.config.embedding.model.clone(),
        );

        match embedder.generate_embedding("test").await {
            Ok(_) => info!("✅ Embeddings: Working"),
            Err(_) => info!("❌ Embeddings: Failed"),
        }

        Ok(())
    }
}

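/// Writes a `.env.example` file populated with default configuration values.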
pub fn create_sample_env_file() -> Result<()> {
    let env_content = r#"# Agentic Memory System Configuration

# Database Configuration
DATABASE_URL=postgresql://postgres:postgres@localhost:5432/codex_memory

# Embedding Configuration
EMBEDDING_PROVIDER=ollama
EMBEDDING_MODEL=nomic-embed-text
EMBEDDING_BASE_URL=http://192.168.1.110:11434
EMBEDDING_TIMEOUT_SECONDS=60

# Server Configuration
HTTP_PORT=8080
LOG_LEVEL=info

# Memory Tier Configuration
WORKING_TIER_LIMIT=1000
WARM_TIER_LIMIT=10000
WORKING_TO_WARM_DAYS=7
WARM_TO_COLD_DAYS=30
IMPORTANCE_THRESHOLD=0.7

# Operational Configuration
MAX_DB_CONNECTIONS=10
REQUEST_TIMEOUT_SECONDS=30
ENABLE_METRICS=true
"#;

    std::fs::write(".env.example", env_content).context("Failed to create .env.example file")?;

    info!("📝 Created .env.example file with default configuration");
    info!("   Copy this to .env and modify as needed");

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_embedding_model() {
        let setup = SetupManager::new(Config::default());

        let nomic = setup.classify_embedding_model("nomic-embed-text").unwrap();
        assert_eq!(nomic.dimensions, 768);
        assert!(nomic.preferred);

        let mxbai = setup.classify_embedding_model("mxbai-embed-large").unwrap();
        assert_eq!(mxbai.dimensions, 1024);
        assert!(mxbai.preferred);

        let unknown = setup
            .classify_embedding_model("custom-embed-model")
            .unwrap();
        assert_eq!(unknown.dimensions, 768);
        assert!(!unknown.preferred);

        let non_embed = setup.classify_embedding_model("llama2");
        assert!(non_embed.is_none());
    }

    #[test]
    fn test_known_models_classification() {
        let setup = SetupManager::new(Config::default());

        let test_cases = [
            ("nomic-embed-text", true, 768),
            ("all-minilm", false, 384),
            ("bge-base-en", false, 768),
            ("e5-large", false, 1024),
        ];

        for (model_name, expected_preferred, expected_dims) in test_cases {
            let result = setup.classify_embedding_model(model_name);
            assert!(
                result.is_some(),
                "Should classify {model_name} as embedding model"
            );

            let info = result.unwrap();
            assert_eq!(info.preferred, expected_preferred);
            assert_eq!(info.dimensions, expected_dims);
        }
    }
}