// sc/embeddings/config.rs

//! Embedding configuration management.
//!
//! Loads and saves embedding settings from `~/.savecontext/config.json`,
//! maintaining compatibility with the TypeScript MCP server.

use crate::error::{Error, Result};
use std::fs;
use std::path::PathBuf;

use super::types::{EmbeddingSettings, SaveContextConfig};

12/// Get the config file path.
13fn config_path() -> Result<PathBuf> {
14    directories::BaseDirs::new()
15        .map(|b| b.home_dir().join(".savecontext").join("config.json"))
16        .ok_or(Error::Config("Could not determine home directory".into()))
17}
18
19/// Load the full SaveContext configuration.
20pub fn load_config() -> Result<SaveContextConfig> {
21    let path = config_path()?;
22
23    if !path.exists() {
24        return Ok(SaveContextConfig::default());
25    }
26
27    let content = fs::read_to_string(&path).map_err(|e| {
28        Error::Config(format!("Failed to read config file: {e}"))
29    })?;
30
31    serde_json::from_str(&content).map_err(|e| {
32        Error::Config(format!("Failed to parse config file: {e}"))
33    })
34}
35
36/// Save the full SaveContext configuration.
37pub fn save_config(config: &SaveContextConfig) -> Result<()> {
38    let path = config_path()?;
39
40    // Ensure directory exists
41    if let Some(parent) = path.parent() {
42        fs::create_dir_all(parent).map_err(|e| {
43            Error::Config(format!("Failed to create config directory: {e}"))
44        })?;
45    }
46
47    let content = serde_json::to_string_pretty(config).map_err(|e| {
48        Error::Config(format!("Failed to serialize config: {e}"))
49    })?;
50
51    fs::write(&path, content).map_err(|e| {
52        Error::Config(format!("Failed to write config file: {e}"))
53    })?;
54
55    Ok(())
56}
57
58/// Get embedding settings from config file.
59pub fn get_embedding_settings() -> Result<Option<EmbeddingSettings>> {
60    let config = load_config()?;
61    Ok(config.embeddings)
62}
63
64/// Save embedding settings (merges with existing config).
65pub fn save_embedding_settings(settings: &EmbeddingSettings) -> Result<()> {
66    let mut config = load_config()?;
67
68    // Merge with existing settings
69    let existing = config.embeddings.unwrap_or_default();
70    config.embeddings = Some(EmbeddingSettings {
71        enabled: settings.enabled.or(existing.enabled),
72        provider: settings.provider.or(existing.provider),
73        HF_TOKEN: settings.HF_TOKEN.clone().or(existing.HF_TOKEN),
74        HF_MODEL: settings.HF_MODEL.clone().or(existing.HF_MODEL),
75        HF_ENDPOINT: settings.HF_ENDPOINT.clone().or(existing.HF_ENDPOINT),
76        OLLAMA_ENDPOINT: settings.OLLAMA_ENDPOINT.clone().or(existing.OLLAMA_ENDPOINT),
77        OLLAMA_MODEL: settings.OLLAMA_MODEL.clone().or(existing.OLLAMA_MODEL),
78        TRANSFORMERS_MODEL: settings.TRANSFORMERS_MODEL.clone().or(existing.TRANSFORMERS_MODEL),
79    });
80
81    save_config(&config)
82}
83
84/// Reset embedding settings (removes from config).
85pub fn reset_embedding_settings() -> Result<()> {
86    let mut config = load_config()?;
87    config.embeddings = None;
88    save_config(&config)
89}
90
91/// Resolve Ollama endpoint from config or environment.
92pub fn resolve_ollama_endpoint() -> String {
93    // Priority: env var > config > default
94    if let Ok(endpoint) = std::env::var("OLLAMA_ENDPOINT") {
95        if !endpoint.is_empty() {
96            return endpoint;
97        }
98    }
99
100    if let Ok(Some(settings)) = get_embedding_settings() {
101        if let Some(endpoint) = settings.OLLAMA_ENDPOINT {
102            return endpoint;
103        }
104    }
105
106    "http://localhost:11434".to_string()
107}
108
109/// Resolve Ollama model from config or environment.
110pub fn resolve_ollama_model() -> String {
111    // Priority: env var > config > default
112    if let Ok(model) = std::env::var("OLLAMA_MODEL") {
113        if !model.is_empty() {
114            return model;
115        }
116    }
117
118    if let Ok(Some(settings)) = get_embedding_settings() {
119        if let Some(model) = settings.OLLAMA_MODEL {
120            return model;
121        }
122    }
123
124    "nomic-embed-text".to_string()
125}
126
127/// Resolve HuggingFace token from config or environment.
128pub fn resolve_hf_token() -> Option<String> {
129    // Priority: env var > config
130    if let Ok(token) = std::env::var("HF_TOKEN") {
131        if !token.is_empty() {
132            return Some(token);
133        }
134    }
135
136    if let Ok(Some(settings)) = get_embedding_settings() {
137        return settings.HF_TOKEN;
138    }
139
140    None
141}
142
143/// Resolve HuggingFace model from config or environment.
144pub fn resolve_hf_model() -> String {
145    // Priority: env var > config > default
146    if let Ok(model) = std::env::var("HF_MODEL") {
147        if !model.is_empty() {
148            return model;
149        }
150    }
151
152    if let Ok(Some(settings)) = get_embedding_settings() {
153        if let Some(model) = settings.HF_MODEL {
154            return model;
155        }
156    }
157
158    "sentence-transformers/all-MiniLM-L6-v2".to_string()
159}
160
161/// Resolve HuggingFace endpoint from config or environment.
162pub fn resolve_hf_endpoint() -> String {
163    // Priority: env var > config > default
164    if let Ok(endpoint) = std::env::var("HF_ENDPOINT") {
165        if !endpoint.is_empty() {
166            return endpoint;
167        }
168    }
169
170    if let Ok(Some(settings)) = get_embedding_settings() {
171        if let Some(endpoint) = settings.HF_ENDPOINT {
172            return endpoint;
173        }
174    }
175
176    "https://router.huggingface.co/hf-inference".to_string()
177}
178
179/// Check if embeddings are enabled.
180pub fn is_embeddings_enabled() -> bool {
181    // Check env var first
182    if let Ok(enabled) = std::env::var("SAVECONTEXT_EMBEDDINGS_ENABLED") {
183        return enabled != "false" && enabled != "0";
184    }
185
186    // Check config
187    if let Ok(Some(settings)) = get_embedding_settings() {
188        return settings.enabled.unwrap_or(true);
189    }
190
191    true // Enabled by default
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_ollama_endpoint() {
        // The resolved endpoint depends on env/config, but must never be
        // empty: the function falls back to a built-in default. The old
        // assertion `contains(...) || !is_empty()` was a tautology
        // equivalent to just `!is_empty()`.
        let endpoint = resolve_ollama_endpoint();
        assert!(!endpoint.is_empty());
    }

    #[test]
    fn test_default_ollama_model() {
        // Either the env/config value or the "nomic-embed-text" default.
        let model = resolve_ollama_model();
        assert!(!model.is_empty());
    }

    #[test]
    fn test_embeddings_enabled_by_default() {
        // Without SAVECONTEXT_EMBEDDINGS_ENABLED set to "false"/"0" and
        // without a config disabling them, embeddings default to enabled.
        // NOTE(review): this can fail on hosts whose env/config disables
        // embeddings — consider isolating env and config in these tests.
        let enabled = is_embeddings_enabled();
        assert!(enabled);
    }
}