next_plaid_cli/
config.rs

1//! User configuration persistence
2//!
3//! Stores user preferences (like default model) in the plaid data directory.
4
5use std::fs;
6use std::path::PathBuf;
7
8use anyhow::{Context, Result};
9use serde::{Deserialize, Serialize};
10
11use crate::index::paths::get_plaid_data_dir;
12
13const CONFIG_FILE: &str = "config.json";
14
15/// User configuration stored in the plaid data directory
16#[derive(Debug, Clone, Serialize, Deserialize, Default)]
17pub struct Config {
18    /// Default model to use (HuggingFace model ID or local path)
19    #[serde(skip_serializing_if = "Option::is_none")]
20    pub default_model: Option<String>,
21}
22
23impl Config {
24    /// Load config from the plaid data directory
25    /// Returns default config if file doesn't exist
26    pub fn load() -> Result<Self> {
27        let path = get_config_path()?;
28        if !path.exists() {
29            return Ok(Self::default());
30        }
31
32        let content = fs::read_to_string(&path)
33            .with_context(|| format!("Failed to read config from {}", path.display()))?;
34        let config: Config = serde_json::from_str(&content)
35            .with_context(|| format!("Failed to parse config from {}", path.display()))?;
36        Ok(config)
37    }
38
39    /// Save config to the plaid data directory
40    pub fn save(&self) -> Result<()> {
41        let path = get_config_path()?;
42
43        // Ensure parent directory exists
44        if let Some(parent) = path.parent() {
45            fs::create_dir_all(parent)?;
46        }
47
48        let content = serde_json::to_string_pretty(self)?;
49        fs::write(&path, content)?;
50        Ok(())
51    }
52
53    /// Get the default model, if set
54    pub fn get_default_model(&self) -> Option<&str> {
55        self.default_model.as_deref()
56    }
57
58    /// Set the default model
59    pub fn set_default_model(&mut self, model: impl Into<String>) {
60        self.default_model = Some(model.into());
61    }
62}
63
64/// Get the path to the config file
65pub fn get_config_path() -> Result<PathBuf> {
66    let data_dir = get_plaid_data_dir()?;
67    // Go up one level from indices directory
68    let parent = data_dir
69        .parent()
70        .context("Could not determine config directory")?;
71    Ok(parent.join(CONFIG_FILE))
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly defaulted config carries no model preference.
    #[test]
    fn test_config_default() {
        let fresh = Config::default();
        assert_eq!(fresh.default_model, None);
        assert_eq!(fresh.get_default_model(), None);
    }

    /// Setting the model from a `&str` is reflected by the getter.
    #[test]
    fn test_config_set_default_model() {
        let mut cfg = Config::default();
        cfg.set_default_model("test-model");
        assert_eq!(cfg.get_default_model(), Some("test-model"));
    }

    /// Setting the model from an owned `String` works as well.
    #[test]
    fn test_config_set_default_model_string() {
        let owned_name = String::from("another-model");
        let mut cfg = Config::default();
        cfg.set_default_model(owned_name);
        assert_eq!(cfg.get_default_model(), Some("another-model"));
    }

    /// A populated config survives a JSON round trip.
    #[test]
    fn test_config_serialization() {
        let mut original = Config::default();
        original.set_default_model("lightonai/GTE-ModernColBERT-v1-onnx");

        let encoded = serde_json::to_string(&original).unwrap();
        assert!(encoded.contains("lightonai/GTE-ModernColBERT-v1-onnx"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert_eq!(
            decoded.get_default_model(),
            Some("lightonai/GTE-ModernColBERT-v1-onnx")
        );
    }

    /// An empty config serializes without the `default_model` key
    /// (skip_serializing_if) and still round-trips to an empty config.
    #[test]
    fn test_config_serialization_empty() {
        let encoded = serde_json::to_string(&Config::default()).unwrap();
        assert!(!encoded.contains("default_model"));

        let decoded: Config = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.get_default_model(), None);
    }

    /// A JSON object lacking `default_model` still deserializes.
    #[test]
    fn test_config_deserialization_missing_field() {
        let parsed: Config = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.get_default_model(), None);
    }

    /// An explicit JSON `null` is treated as "no model set".
    #[test]
    fn test_config_deserialization_null_field() {
        let raw = r#"{"default_model": null}"#;
        let parsed: Config = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.get_default_model(), None);
    }

    /// The config path resolves and points at the config file name.
    #[test]
    fn test_config_path_exists() {
        let resolved = get_config_path();
        assert!(resolved.is_ok());

        let rendered = resolved.unwrap().to_string_lossy().into_owned();
        assert!(rendered.contains("config.json"));
    }
}