use std::fs;
use std::path::PathBuf;
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Config {
pub models: ModelsConfig,
pub ollama: OllamaConfig,
pub index: IndexConfig,
}
/// Ollama model names used by the application (the `[models]` table).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ModelsConfig {
/// Model used for embeddings. `Config::needs_index_rebuild` compares it with
/// `IndexConfig::last_embedding_model`.
pub embedding_model: String,
/// LLM model name, exposed via `Config::llm_model`.
pub llm_model: String,
}
/// Settings for the Ollama server (the `[ollama]` table): base URL plus
/// timeouts for generate and embedding requests.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OllamaConfig {
/// Base URL of the Ollama server; `Config::default` uses `http://localhost:11434`.
pub url: String,
/// Generate-request timeout in seconds; falls back to 120 when the key is absent.
#[serde(default = "default_generate_timeout")]
pub generate_timeout_secs: u64,
/// Embedding-request timeout in seconds; falls back to 60 when the key is absent.
#[serde(default = "default_embedding_timeout")]
pub embedding_timeout_secs: u64,
}
/// Serde and `Default` fallback for `OllamaConfig::generate_timeout_secs`:
/// two minutes.
const fn default_generate_timeout() -> u64 {
    2 * 60
}
/// Serde and `Default` fallback for `OllamaConfig::embedding_timeout_secs`:
/// one minute.
const fn default_embedding_timeout() -> u64 {
    const ONE_MINUTE_SECS: u64 = 60;
    ONE_MINUTE_SECS
}
/// Index bookkeeping (the `[index]` table), written by
/// `Config::update_index_metadata`. Both fields are `None` until something
/// has been recorded.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IndexConfig {
/// Embedding dimension passed to the last `Config::update_index_metadata` call.
pub embedding_dimension: Option<u32>,
/// `ModelsConfig::embedding_model` at the time metadata was last recorded.
/// A mismatch with the current model makes `Config::needs_index_rebuild` true.
pub last_embedding_model: Option<String>,
}
impl Default for Config {
    /// Out-of-the-box configuration: `nomic-embed-text` for embeddings,
    /// `llama3.2:3b` as the LLM, a local Ollama at `http://localhost:11434`
    /// with the default timeouts, and no index metadata recorded.
    fn default() -> Self {
        let models = ModelsConfig {
            embedding_model: String::from("nomic-embed-text"),
            llm_model: String::from("llama3.2:3b"),
        };
        let ollama = OllamaConfig {
            url: String::from("http://localhost:11434"),
            generate_timeout_secs: default_generate_timeout(),
            embedding_timeout_secs: default_embedding_timeout(),
        };
        let index = IndexConfig {
            embedding_dimension: None,
            last_embedding_model: None,
        };
        Self {
            models,
            ollama,
            index,
        }
    }
}
impl Config {
    /// Name of the configured embedding model.
    #[must_use]
    pub fn embedding_model(&self) -> &str {
        self.models.embedding_model.as_str()
    }

    /// Name of the configured LLM model.
    #[must_use]
    pub fn llm_model(&self) -> &str {
        self.models.llm_model.as_str()
    }

    /// Base URL of the Ollama server.
    #[must_use]
    pub fn ollama_url(&self) -> &str {
        self.ollama.url.as_str()
    }

    /// Generate-request timeout, in seconds.
    #[must_use]
    pub const fn generate_timeout_secs(&self) -> u64 {
        let ollama = &self.ollama;
        ollama.generate_timeout_secs
    }

    /// Embedding-request timeout, in seconds.
    #[must_use]
    pub const fn embedding_timeout_secs(&self) -> u64 {
        let ollama = &self.ollama;
        ollama.embedding_timeout_secs
    }

    /// Records `dimension` and the current embedding model as the index
    /// metadata, so later `needs_index_rebuild` calls compare against them.
    pub fn update_index_metadata(&mut self, dimension: u32) {
        let current_model = self.models.embedding_model.clone();
        self.index.embedding_dimension = Some(dimension);
        self.index.last_embedding_model = Some(current_model);
    }

    /// Returns `true` when index metadata names an embedding model that differs
    /// from the currently configured one.
    ///
    /// With no recorded model this returns `false`.
    #[must_use]
    pub fn needs_index_rebuild(&self) -> bool {
        let Some(indexed_with) = self.index.last_embedding_model.as_deref() else {
            return false;
        };
        indexed_with != self.models.embedding_model
    }

    /// Embedding dimension recorded by the last `update_index_metadata`, if any.
    #[must_use]
    pub const fn index_dimension(&self) -> Option<u32> {
        let index = &self.index;
        index.embedding_dimension
    }
}
/// Pre-multi-model config layout: a single `model_name` and an `ollama_url`
/// as top-level keys. Only deserialized by `load_config`, which converts it to
/// `Config` and saves the result.
#[derive(Debug, Deserialize)]
// Private, migration-only type.
#[allow(clippy::missing_docs_in_private_items)]
struct LegacyConfig {
/// The one model name; migration copies it into both embedding and LLM slots.
model_name: String,
/// Ollama server URL; migration copies it into `OllamaConfig::url`.
ollama_url: String,
}
pub fn get_config_path() -> Result<PathBuf> {
let dirs = directories::ProjectDirs::from("", "", "ulm")
.context("Could not determine config directory")?;
let config_dir = dirs.config_dir();
fs::create_dir_all(config_dir).with_context(|| {
format!(
"Failed to create config directory: {}",
config_dir.display()
)
})?;
Ok(config_dir.join("config.toml"))
}
pub fn load_config() -> Result<Config> {
let config_path = get_config_path()?;
if !config_path.exists() {
return Ok(Config::default());
}
let content = fs::read_to_string(&config_path)
.with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
if let Ok(config) = toml::from_str::<Config>(&content) {
return Ok(config);
}
if let Ok(legacy) = toml::from_str::<LegacyConfig>(&content) {
tracing::info!("Migrating legacy config to new multi-model format");
let config = Config {
models: ModelsConfig {
embedding_model: legacy.model_name.clone(),
llm_model: legacy.model_name,
},
ollama: OllamaConfig {
url: legacy.ollama_url,
generate_timeout_secs: default_generate_timeout(),
embedding_timeout_secs: default_embedding_timeout(),
},
index: IndexConfig {
embedding_dimension: None,
last_embedding_model: None,
},
};
save_config(&config)?;
tracing::info!("Config migrated successfully");
return Ok(config);
}
anyhow::bail!("Failed to parse config file: {}", config_path.display())
}
pub fn save_config(config: &Config) -> Result<()> {
let config_path = get_config_path()?;
let content = toml::to_string_pretty(config).context("Failed to serialize config")?;
fs::write(&config_path, &content)
.with_context(|| format!("Failed to write config file: {}", config_path.display()))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let permissions = fs::Permissions::from_mode(0o600);
fs::set_permissions(&config_path, permissions)
.with_context(|| format!("Failed to set permissions on: {}", config_path.display()))?;
}
Ok(())
}
// Unit tests for the config types' defaults, TOML (de)serialization,
// legacy-format mapping, and index-metadata helpers.
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
// Pins the values produced by Config::default().
#[test]
fn test_config_default() {
let config = Config::default();
assert_eq!(config.models.embedding_model, "nomic-embed-text");
assert_eq!(config.models.llm_model, "llama3.2:3b");
assert_eq!(config.ollama.url, "http://localhost:11434");
assert_eq!(config.ollama.generate_timeout_secs, 120);
assert_eq!(config.ollama.embedding_timeout_secs, 60);
assert_eq!(config.index.embedding_dimension, None);
}
// Spot-checks individual `key = value` lines in the serialized TOML
// (substring matches, not a comparison of the full output).
#[test]
fn test_config_serialization() {
let config = Config {
models: ModelsConfig {
embedding_model: "nomic-embed-text".to_string(),
llm_model: "mistral:7b".to_string(),
},
ollama: OllamaConfig {
url: "http://localhost:11434".to_string(),
generate_timeout_secs: 180,
embedding_timeout_secs: 90,
},
index: IndexConfig {
embedding_dimension: Some(768),
last_embedding_model: Some("nomic-embed-text".to_string()),
},
};
let toml_str = toml::to_string(&config).unwrap();
assert!(toml_str.contains("embedding_model = \"nomic-embed-text\""));
assert!(toml_str.contains("llm_model = \"mistral:7b\""));
assert!(toml_str.contains("embedding_dimension = 768"));
assert!(toml_str.contains("generate_timeout_secs = 180"));
assert!(toml_str.contains("embedding_timeout_secs = 90"));
}
// Parses a fully specified current-format file.
#[test]
fn test_config_deserialization() {
let toml_str = r#"
[models]
embedding_model = "mxbai-embed-large"
llm_model = "llama3.1:8b"
[ollama]
url = "http://localhost:11434"
generate_timeout_secs = 300
embedding_timeout_secs = 120
[index]
embedding_dimension = 1024
last_embedding_model = "mxbai-embed-large"
"#;
let config: Config = toml::from_str(toml_str).unwrap();
assert_eq!(config.models.embedding_model, "mxbai-embed-large");
assert_eq!(config.models.llm_model, "llama3.1:8b");
assert_eq!(config.index.embedding_dimension, Some(1024));
assert_eq!(config.ollama.generate_timeout_secs, 300);
assert_eq!(config.ollama.embedding_timeout_secs, 120);
}
// Omitted timeout keys fall back to the serde default functions (120 / 60).
#[test]
fn test_config_deserialization_with_defaults() {
let toml_str = r#"
[models]
embedding_model = "nomic-embed-text"
llm_model = "llama3"
[ollama]
url = "http://localhost:11434"
[index]
"#;
let config: Config = toml::from_str(toml_str).unwrap();
assert_eq!(config.ollama.generate_timeout_secs, 120);
assert_eq!(config.ollama.embedding_timeout_secs, 60);
}
// NOTE(review): this rebuilds the legacy -> current mapping inline instead of
// exercising load_config's migration path, so it can drift from the real code.
#[test]
fn test_legacy_config_migration() {
let legacy_toml = r#"
model_name = "llama3.1:8b"
ollama_url = "http://localhost:11434"
"#;
let legacy: LegacyConfig = toml::from_str(legacy_toml).unwrap();
let config = Config {
models: ModelsConfig {
embedding_model: legacy.model_name.clone(),
llm_model: legacy.model_name,
},
ollama: OllamaConfig {
url: legacy.ollama_url,
generate_timeout_secs: default_generate_timeout(),
embedding_timeout_secs: default_embedding_timeout(),
},
index: IndexConfig {
embedding_dimension: None,
last_embedding_model: None,
},
};
assert_eq!(config.models.embedding_model, "llama3.1:8b");
assert_eq!(config.models.llm_model, "llama3.1:8b");
}
// Serialize -> deserialize must reproduce an equal Config.
#[test]
fn test_config_roundtrip() {
let original = Config {
models: ModelsConfig {
embedding_model: "nomic-embed-text".to_string(),
llm_model: "phi3:mini".to_string(),
},
ollama: OllamaConfig {
url: "http://127.0.0.1:11434".to_string(),
generate_timeout_secs: 240,
embedding_timeout_secs: 90,
},
index: IndexConfig {
embedding_dimension: Some(768),
last_embedding_model: Some("nomic-embed-text".to_string()),
},
};
let toml_str = toml::to_string(&original).unwrap();
let restored: Config = toml::from_str(&toml_str).unwrap();
assert_eq!(original, restored);
}
// NOTE(review): calls the real get_config_path, which resolves the user's actual
// platform config directory and creates it with create_dir_all as a side effect.
#[test]
fn test_get_config_path_returns_toml_file() {
let result = get_config_path();
assert!(result.is_ok());
let path = result.unwrap();
assert_eq!(path.file_name().unwrap(), "config.toml");
}
// NOTE(review): despite the name, this never calls save_config/load_config; it
// only round-trips a Config through toml and a temp file.
#[test]
fn test_save_and_load_config() {
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let config = Config {
models: ModelsConfig {
embedding_model: "test-embed".to_string(),
llm_model: "test-llm".to_string(),
},
ollama: OllamaConfig {
url: "http://test:11434".to_string(),
generate_timeout_secs: 200,
embedding_timeout_secs: 100,
},
index: IndexConfig {
embedding_dimension: Some(512),
last_embedding_model: Some("test-embed".to_string()),
},
};
let content = toml::to_string_pretty(&config).unwrap();
fs::write(&config_path, &content).unwrap();
let loaded_content = fs::read_to_string(&config_path).unwrap();
let loaded_config: Config = toml::from_str(&loaded_content).unwrap();
assert_eq!(config, loaded_config);
}
// No recorded model -> false; recorded model unchanged -> false;
// configured model changed afterwards -> true.
#[test]
fn test_needs_index_rebuild() {
let mut config = Config::default();
assert!(!config.needs_index_rebuild());
config.update_index_metadata(768);
assert!(!config.needs_index_rebuild());
config.models.embedding_model = "different-model".to_string();
assert!(config.needs_index_rebuild());
}
// update_index_metadata records both the dimension and the current embedding model.
#[test]
fn test_update_index_metadata() {
let mut config = Config::default();
config.update_index_metadata(1024);
assert_eq!(config.index.embedding_dimension, Some(1024));
assert_eq!(
config.index.last_embedding_model,
Some("nomic-embed-text".to_string())
);
}
// NOTE(review): applies 0o600 itself rather than going through save_config, so it
// only checks the platform permission API, not save_config's behavior.
#[cfg(unix)]
#[test]
fn test_file_permissions() {
use std::os::unix::fs::PermissionsExt;
let temp_dir = TempDir::new().unwrap();
let config_path = temp_dir.path().join("config.toml");
let config = Config::default();
let content = toml::to_string_pretty(&config).unwrap();
fs::write(&config_path, &content).unwrap();
let permissions = fs::Permissions::from_mode(0o600);
fs::set_permissions(&config_path, permissions).unwrap();
let metadata = fs::metadata(&config_path).unwrap();
let mode = metadata.permissions().mode() & 0o777;
assert_eq!(mode, 0o600);
}
}