1use anyhow::{Context, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::time::Duration;
5use tokio_postgres::{Config as PgConfig, NoTls};
6use tracing::{error, info, warn};
7
8use crate::config::Config;
9use crate::embedding::SimpleEmbedder;
10
/// Metadata describing an embedding model discovered on (or pulled to) the
/// Ollama server.
#[derive(Debug, Clone)]
pub struct EmbeddingModelInfo {
    // Full model name as reported by Ollama (e.g. "nomic-embed-text").
    pub name: String,
    // Dimensionality of the vectors this model produces.
    pub dimensions: usize,
    // Short human-readable description used in log output.
    pub description: String,
    // Whether this model is one of the recommended defaults.
    pub preferred: bool,
}
19
/// Drives first-run setup: checks Ollama and PostgreSQL connectivity, selects
/// (or pulls) an embedding model, and verifies the database schema.
pub struct SetupManager {
    // HTTP client for all Ollama API calls (timeout configured in `new`).
    client: Client,
    config: Config,
}
25
/// One entry from Ollama's `GET /api/tags` model listing.
#[derive(Debug, Deserialize)]
struct OllamaModel {
    name: String,
    // Reported size in bytes; currently unused but kept for future diagnostics.
    #[allow(dead_code)]
    size: u64,
    // Model family; may be absent in the response, hence the default.
    #[serde(default)]
    #[allow(dead_code)]
    family: String,
}
36
/// Response body of Ollama's `GET /api/tags`.
#[derive(Debug, Deserialize)]
struct OllamaModelsResponse {
    models: Vec<OllamaModel>,
}
41
/// Request body for Ollama's `POST /api/pull`.
#[derive(Debug, Serialize)]
struct OllamaPullRequest {
    name: String,
}
46
/// One JSON progress record returned by Ollama's `POST /api/pull`.
#[derive(Debug, Deserialize)]
struct OllamaPullResponse {
    status: String,
    // Bytes downloaded so far; present only during the "downloading" phase.
    #[serde(default)]
    completed: Option<u64>,
    // Total bytes to download; present only during the "downloading" phase.
    #[serde(default)]
    total: Option<u64>,
}
55
56impl SetupManager {
57 pub fn new(config: Config) -> Self {
58 let client = Client::builder()
59 .timeout(Duration::from_secs(120))
60 .build()
61 .expect("Failed to create HTTP client");
62
63 Self { client, config }
64 }
65
66 pub async fn run_setup(&self) -> Result<()> {
68 info!("๐ Starting Agentic Memory System setup...");
69
70 self.check_ollama_connectivity().await?;
72
73 let available_models = self.detect_embedding_models().await?;
75 let selected_model = self.ensure_embedding_model(available_models).await?;
76
77 let mut updated_config = self.config.clone();
79 updated_config.embedding.model = selected_model.name.clone();
80
81 self.test_embedding_generation(&updated_config).await?;
83
84 self.setup_database().await?;
86
87 self.run_health_checks(&updated_config).await?;
89
90 info!("โ
Setup completed successfully!");
91 info!("Selected embedding model: {} ({}D)", selected_model.name, selected_model.dimensions);
92
93 Ok(())
94 }
95
96 async fn check_ollama_connectivity(&self) -> Result<()> {
98 info!("๐ Checking Ollama connectivity at {}", self.config.embedding.base_url);
99
100 let response = self
101 .client
102 .get(&format!("{}/api/tags", self.config.embedding.base_url))
103 .send()
104 .await
105 .context("Failed to connect to Ollama. Is it running and accessible?")?;
106
107 if !response.status().is_success() {
108 return Err(anyhow::anyhow!(
109 "Ollama returned error status: {}",
110 response.status()
111 ));
112 }
113
114 info!("โ
Ollama is running and accessible");
115 Ok(())
116 }
117
118 async fn detect_embedding_models(&self) -> Result<Vec<EmbeddingModelInfo>> {
120 info!("๐ Detecting available embedding models...");
121
122 let response = self
123 .client
124 .get(&format!("{}/api/tags", self.config.embedding.base_url))
125 .send()
126 .await?;
127
128 let models_response: OllamaModelsResponse = response.json().await?;
129
130 let mut embedding_models = Vec::new();
131
132 for model in models_response.models {
133 if let Some(model_info) = self.classify_embedding_model(&model.name) {
134 embedding_models.push(model_info);
135 }
136 }
137
138 if embedding_models.is_empty() {
139 warn!("No embedding models found on Ollama");
140 } else {
141 info!("Found {} embedding models:", embedding_models.len());
142 for model in &embedding_models {
143 info!(" - {} ({}D) {}",
144 model.name,
145 model.dimensions,
146 if model.preferred { "โญ RECOMMENDED" } else { "" }
147 );
148 }
149 }
150
151 Ok(embedding_models)
152 }
153
154 fn classify_embedding_model(&self, model_name: &str) -> Option<EmbeddingModelInfo> {
156 let name_lower = model_name.to_lowercase();
157
158 let known_models = [
160 ("nomic-embed-text", 768, "High-quality text embeddings", true),
161 ("mxbai-embed-large", 1024, "Large multilingual embeddings", true),
162 ("all-minilm", 384, "Compact sentence embeddings", false),
163 ("all-mpnet-base-v2", 768, "Sentence transformer embeddings", false),
164 ("bge-small-en", 384, "BGE small English embeddings", false),
165 ("bge-base-en", 768, "BGE base English embeddings", false),
166 ("bge-large-en", 1024, "BGE large English embeddings", false),
167 ("e5-small", 384, "E5 small embeddings", false),
168 ("e5-base", 768, "E5 base embeddings", false),
169 ("e5-large", 1024, "E5 large embeddings", false),
170 ];
171
172 for (pattern, dimensions, description, preferred) in known_models {
173 if name_lower.contains(pattern) || model_name.contains(pattern) {
174 return Some(EmbeddingModelInfo {
175 name: model_name.to_string(),
176 dimensions,
177 description: description.to_string(),
178 preferred,
179 });
180 }
181 }
182
183 if name_lower.contains("embed") ||
185 name_lower.contains("sentence") ||
186 name_lower.contains("vector") {
187 return Some(EmbeddingModelInfo {
188 name: model_name.to_string(),
189 dimensions: 768, description: "Detected embedding model".to_string(),
191 preferred: false,
192 });
193 }
194
195 None
196 }
197
198 async fn ensure_embedding_model(&self, available_models: Vec<EmbeddingModelInfo>) -> Result<EmbeddingModelInfo> {
200 info!("๐ฏ Selecting embedding model...");
201
202 if let Some(preferred) = available_models.iter().find(|m| m.preferred) {
204 info!("โ
Using preferred model: {}", preferred.name);
205 return Ok(preferred.clone());
206 }
207
208 if !available_models.is_empty() {
210 let selected = available_models[0].clone();
211 info!("โ
Using available model: {}", selected.name);
212 return Ok(selected);
213 }
214
215 info!("๐ฅ No embedding models found. Attempting to pull recommended models...");
217
218 let recommended_models = [
219 ("nomic-embed-text", 768, "High-quality text embeddings"),
220 ("mxbai-embed-large", 1024, "Large multilingual embeddings"),
221 ("all-minilm", 384, "Compact sentence embeddings"),
222 ];
223
224 for (model_name, dimensions, description) in recommended_models {
225 info!("๐ฅ Attempting to pull model: {}", model_name);
226
227 match self.pull_model(model_name).await {
228 Ok(_) => {
229 info!("โ
Successfully pulled model: {}", model_name);
230 return Ok(EmbeddingModelInfo {
231 name: model_name.to_string(),
232 dimensions,
233 description: description.to_string(),
234 preferred: true,
235 });
236 }
237 Err(e) => {
238 warn!("Failed to pull model {}: {}", model_name, e);
239 continue;
240 }
241 }
242 }
243
244 Err(anyhow::anyhow!(
245 "Failed to find or pull any suitable embedding models. Please manually pull an embedding model using 'ollama pull nomic-embed-text'"
246 ))
247 }
248
249 async fn pull_model(&self, model_name: &str) -> Result<()> {
251 info!("๐ฅ Pulling model: {}", model_name);
252
253 let request = OllamaPullRequest {
254 name: model_name.to_string(),
255 };
256
257 let response = self
258 .client
259 .post(&format!("{}/api/pull", self.config.embedding.base_url))
260 .json(&request)
261 .send()
262 .await?;
263
264 if !response.status().is_success() {
265 let status = response.status();
266 let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
267 return Err(anyhow::anyhow!(
268 "Failed to pull model {}: HTTP {} - {}",
269 model_name,
270 status,
271 error_text
272 ));
273 }
274
275 let lines = response.text().await?;
277
278 for line in lines.lines() {
280 if line.trim().is_empty() {
281 continue;
282 }
283
284 match serde_json::from_str::<OllamaPullResponse>(line) {
285 Ok(pull_response) => {
286 match pull_response.status.as_str() {
287 "downloading" => {
288 if let (Some(completed), Some(total)) = (pull_response.completed, pull_response.total) {
289 let progress = (completed as f64 / total as f64) * 100.0;
290 info!(" ๐ Downloading: {:.1}% ({}/{})", progress, completed, total);
291 }
292 }
293 "verifying sha256" => {
294 info!(" ๐ Verifying checksum...");
295 }
296 "success" => {
297 info!(" โ
Pull completed successfully");
298 return Ok(());
299 }
300 status => {
301 info!(" ๐ฆ Status: {}", status);
302 }
303 }
304 }
305 Err(_) => {
306 if line.contains("success") {
308 info!(" โ
Pull completed successfully");
309 return Ok(());
310 }
311 info!(" ๐ฆ {}", line);
312 }
313 }
314 }
315
316 Ok(())
317 }
318
319 async fn test_embedding_generation(&self, config: &Config) -> Result<()> {
321 info!("๐งช Testing embedding generation...");
322
323 let embedder = SimpleEmbedder::new_ollama(
324 config.embedding.base_url.clone(),
325 config.embedding.model.clone(),
326 );
327
328 let test_text = "This is a test sentence for embedding generation.";
329
330 match embedder.generate_embedding(test_text).await {
331 Ok(embedding) => {
332 info!("โ
Embedding generation successful!");
333 info!(" ๐ Embedding dimensions: {}", embedding.len());
334 info!(" ๐ Sample values: [{:.4}, {:.4}, {:.4}, ...]",
335 embedding.get(0).unwrap_or(&0.0),
336 embedding.get(1).unwrap_or(&0.0),
337 embedding.get(2).unwrap_or(&0.0)
338 );
339 Ok(())
340 }
341 Err(e) => {
342 error!("โ Embedding generation failed: {}", e);
343 Err(e)
344 }
345 }
346 }
347
348 async fn setup_database(&self) -> Result<()> {
350 info!("๐๏ธ Setting up database...");
351
352 let db_config: PgConfig = self.config.database_url.parse()
354 .context("Invalid database URL")?;
355
356 let (client, connection) = db_config.connect(NoTls).await
358 .context("Failed to connect to database")?;
359
360 tokio::spawn(async move {
362 if let Err(e) = connection.await {
363 error!("Database connection error: {}", e);
364 }
365 });
366
367 info!("๐ Checking for pgvector extension...");
369
370 let extension_check = client
371 .query("SELECT 1 FROM pg_available_extensions WHERE name = 'vector'", &[])
372 .await?;
373
374 if extension_check.is_empty() {
375 warn!("โ ๏ธ pgvector extension is not available in this PostgreSQL instance");
376 warn!(" Please install pgvector: https://github.com/pgvector/pgvector");
377 return Err(anyhow::anyhow!("pgvector extension not available"));
378 }
379
380 info!("๐ง Enabling pgvector extension...");
382 client
383 .execute("CREATE EXTENSION IF NOT EXISTS vector", &[])
384 .await
385 .context("Failed to enable pgvector extension")?;
386
387 info!("๐ Checking database schema...");
389 let table_check = client
390 .query(
391 "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name = 'memories'",
392 &[]
393 )
394 .await?;
395
396 if table_check.is_empty() {
397 info!("๐ Running database migrations...");
398 warn!("โ ๏ธ Please run database migrations: cargo run --bin migration");
401 } else {
402 info!("โ
Database schema is ready");
403 }
404
405 Ok(())
406 }
407
408 pub async fn run_health_checks(&self, config: &Config) -> Result<()> {
410 info!("๐ฉบ Running comprehensive health checks...");
411
412 let mut checks_passed = 0;
413 let mut total_checks = 0;
414
415 total_checks += 1;
417 match self.check_ollama_connectivity().await {
418 Ok(_) => {
419 info!(" โ
Ollama connectivity");
420 checks_passed += 1;
421 }
422 Err(e) => {
423 error!(" โ Ollama connectivity: {}", e);
424 }
425 }
426
427 total_checks += 1;
429 let embedder = SimpleEmbedder::new_ollama(
430 config.embedding.base_url.clone(),
431 config.embedding.model.clone(),
432 );
433
434 match embedder.generate_embedding("health check").await {
435 Ok(_) => {
436 info!(" โ
Embedding generation");
437 checks_passed += 1;
438 }
439 Err(e) => {
440 error!(" โ Embedding generation: {}", e);
441 }
442 }
443
444 total_checks += 1;
446 match self.check_database_connectivity().await {
447 Ok(_) => {
448 info!(" โ
Database connectivity");
449 checks_passed += 1;
450 }
451 Err(e) => {
452 error!(" โ Database connectivity: {}", e);
453 }
454 }
455
456 total_checks += 1;
458 match self.check_pgvector_extension().await {
459 Ok(_) => {
460 info!(" โ
pgvector extension");
461 checks_passed += 1;
462 }
463 Err(e) => {
464 error!(" โ pgvector extension: {}", e);
465 }
466 }
467
468 info!("๐ Health check summary: {}/{} checks passed", checks_passed, total_checks);
470
471 if checks_passed == total_checks {
472 info!("๐ All health checks passed! System is ready.");
473 Ok(())
474 } else {
475 Err(anyhow::anyhow!(
476 "Some health checks failed. Please address the issues above."
477 ))
478 }
479 }
480
481 async fn check_database_connectivity(&self) -> Result<()> {
483 let db_config: PgConfig = self.config.database_url.parse()?;
484 let (client, connection) = db_config.connect(NoTls).await?;
485
486 tokio::spawn(async move {
487 if let Err(e) = connection.await {
488 error!("Database connection error: {}", e);
489 }
490 });
491
492 client.query("SELECT 1", &[]).await?;
494 Ok(())
495 }
496
497 async fn check_pgvector_extension(&self) -> Result<()> {
499 let db_config: PgConfig = self.config.database_url.parse()?;
500 let (client, connection) = db_config.connect(NoTls).await?;
501
502 tokio::spawn(async move {
503 if let Err(e) = connection.await {
504 error!("Database connection error: {}", e);
505 }
506 });
507
508 client
510 .query("SELECT vector_dims(vector '[1,2,3]')", &[])
511 .await
512 .context("pgvector extension not available or not working")?;
513
514 Ok(())
515 }
516
517 pub async fn list_available_models(&self) -> Result<()> {
519 info!("๐ Available embedding models:");
520
521 let available_models = self.detect_embedding_models().await?;
522
523 if available_models.is_empty() {
524 info!(" No embedding models currently available");
525 info!(" Recommended models to pull:");
526 info!(" ollama pull nomic-embed-text");
527 info!(" ollama pull mxbai-embed-large");
528 info!(" ollama pull all-minilm");
529 } else {
530 for model in available_models {
531 let icon = if model.preferred { "โญ" } else { " " };
532 info!("{} {} ({}D) - {}", icon, model.name, model.dimensions, model.description);
533 }
534 }
535
536 Ok(())
537 }
538
539 pub async fn quick_health_check(&self) -> Result<()> {
541 info!("๐ฅ Running quick health check...");
542
543 match self.check_ollama_connectivity().await {
545 Ok(_) => info!("โ
Ollama: Running"),
546 Err(_) => info!("โ Ollama: Not accessible"),
547 }
548
549 match self.check_database_connectivity().await {
551 Ok(_) => info!("โ
Database: Connected"),
552 Err(_) => info!("โ Database: Connection failed"),
553 }
554
555 let embedder = SimpleEmbedder::new_ollama(
557 self.config.embedding.base_url.clone(),
558 self.config.embedding.model.clone(),
559 );
560
561 match embedder.generate_embedding("test").await {
562 Ok(_) => info!("โ
Embeddings: Working"),
563 Err(_) => info!("โ Embeddings: Failed"),
564 }
565
566 Ok(())
567 }
568}
569
570pub fn create_sample_env_file() -> Result<()> {
572 let env_content = r#"# Agentic Memory System Configuration
573
574# Database Configuration
575DATABASE_URL=postgresql://postgres:postgres@localhost:5432/codex_memory
576
577# Embedding Configuration
578EMBEDDING_PROVIDER=ollama
579EMBEDDING_MODEL=nomic-embed-text
580EMBEDDING_BASE_URL=http://192.168.1.110:11434
581EMBEDDING_TIMEOUT_SECONDS=60
582
583# Server Configuration
584HTTP_PORT=8080
585LOG_LEVEL=info
586
587# Memory Tier Configuration
588WORKING_TIER_LIMIT=1000
589WARM_TIER_LIMIT=10000
590WORKING_TO_WARM_DAYS=7
591WARM_TO_COLD_DAYS=30
592IMPORTANCE_THRESHOLD=0.7
593
594# Operational Configuration
595MAX_DB_CONNECTIONS=10
596REQUEST_TIMEOUT_SECONDS=30
597ENABLE_METRICS=true
598"#;
599
600 std::fs::write(".env.example", env_content)
601 .context("Failed to create .env.example file")?;
602
603 info!("๐ Created .env.example file with default configuration");
604 info!(" Copy this to .env and modify as needed");
605
606 Ok(())
607}
608
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_embedding_model() {
        let setup = SetupManager::new(Config::default());

        // Known preferred models resolve to their documented dimensions.
        let nomic = setup
            .classify_embedding_model("nomic-embed-text")
            .expect("nomic-embed-text should classify");
        assert!(nomic.preferred);
        assert_eq!(nomic.dimensions, 768);

        let mxbai = setup
            .classify_embedding_model("mxbai-embed-large")
            .expect("mxbai-embed-large should classify");
        assert!(mxbai.preferred);
        assert_eq!(mxbai.dimensions, 1024);

        // Unknown names containing "embed" use the 768D fallback,
        // not preferred.
        let unknown = setup
            .classify_embedding_model("custom-embed-model")
            .expect("heuristic match should classify");
        assert!(!unknown.preferred);
        assert_eq!(unknown.dimensions, 768);

        // Plain LLMs are not embedding models.
        assert!(setup.classify_embedding_model("llama2").is_none());
    }

    #[test]
    fn test_known_models_classification() {
        let setup = SetupManager::new(Config::default());

        // (model name, expected preferred flag, expected dimensions)
        let cases = [
            ("nomic-embed-text", true, 768),
            ("all-minilm", false, 384),
            ("bge-base-en", false, 768),
            ("e5-large", false, 1024),
        ];

        for (model_name, expected_preferred, expected_dims) in cases {
            let result = setup.classify_embedding_model(model_name);
            assert!(result.is_some(), "Should classify {} as embedding model", model_name);

            let info = result.unwrap();
            assert_eq!(info.preferred, expected_preferred);
            assert_eq!(info.dimensions, expected_dims);
        }
    }
}
656}