// papers_core/config.rs

use std::io;
use std::path::{Path, PathBuf};

/// Embedding model names accepted by [`PapersConfig::validate_model`].
pub const VALID_MODELS: &[&str] = &["embedding-gemma-300m"];

6#[derive(Debug, thiserror::Error)]
7pub enum ConfigError {
8    #[error("IO error: {0}")]
9    Io(#[from] io::Error),
10    #[error("JSON error: {0}")]
11    Json(#[from] serde_json::Error),
12    #[error("unknown model: {0}")]
13    UnknownModel(String),
14}
15
16#[derive(Debug, serde::Serialize, serde::Deserialize)]
17pub struct PapersConfig {
18    pub embedding_model: String,
19}
20
21impl Default for PapersConfig {
22    fn default() -> Self {
23        Self {
24            embedding_model: "embedding-gemma-300m".to_string(),
25        }
26    }
27}
28
29impl PapersConfig {
30    /// Returns `<config_dir>/.papers/config.json`.
31    pub fn config_path() -> PathBuf {
32        dirs::config_dir()
33            .unwrap_or_else(|| PathBuf::from("."))
34            .join(".papers")
35            .join("config.json")
36    }
37
38    /// Loads from disk. Returns `Default` if file missing. Errors on bad JSON or I/O failure.
39    pub fn load() -> Result<Self, ConfigError> {
40        let path = Self::config_path();
41        if !path.exists() {
42            return Ok(Self::default());
43        }
44        let bytes = std::fs::read(&path)?;
45        let cfg = serde_json::from_slice(&bytes)?;
46        Ok(cfg)
47    }
48
49    /// Writes to disk, creating parent directories as needed.
50    pub fn save(&self) -> Result<(), ConfigError> {
51        let path = Self::config_path();
52        if let Some(parent) = path.parent() {
53            std::fs::create_dir_all(parent)?;
54        }
55        let json = serde_json::to_vec_pretty(self)?;
56        std::fs::write(&path, json)?;
57        Ok(())
58    }
59
60    /// Returns `Err(ConfigError::UnknownModel)` if `name` is not in `VALID_MODELS`.
61    pub fn validate_model(name: &str) -> Result<(), ConfigError> {
62        if VALID_MODELS.contains(&name) {
63            Ok(())
64        } else {
65            Err(ConfigError::UnknownModel(name.to_string()))
66        }
67    }
68}
69
70#[cfg(test)]
71mod tests {
72    use super::*;
73    use tempfile::TempDir;
74
75    #[test]
76    fn test_config_default() {
77        let cfg = PapersConfig::default();
78        assert_eq!(cfg.embedding_model, "embedding-gemma-300m");
79    }
80
81    #[test]
82    fn test_config_roundtrip() {
83        let dir = TempDir::new().unwrap();
84        let path = dir.path().join("config.json");
85        let cfg = PapersConfig {
86            embedding_model: "embedding-gemma-300m".to_string(),
87        };
88        let json = serde_json::to_vec_pretty(&cfg).unwrap();
89        std::fs::write(&path, &json).unwrap();
90        let bytes = std::fs::read(&path).unwrap();
91        let loaded: PapersConfig = serde_json::from_slice(&bytes).unwrap();
92        assert_eq!(loaded.embedding_model, cfg.embedding_model);
93    }
94
95    #[test]
96    fn test_config_missing_file_returns_default() {
97        // Test that a missing path leads to the default being used (simulating load behavior).
98        let dir = TempDir::new().unwrap();
99        let path = dir.path().join("nonexistent.json");
100        assert!(!path.exists());
101        // Simulate load logic: if path doesn't exist, return default.
102        let cfg = if path.exists() {
103            let bytes = std::fs::read(&path).unwrap();
104            serde_json::from_slice::<PapersConfig>(&bytes).unwrap()
105        } else {
106            PapersConfig::default()
107        };
108        assert_eq!(cfg.embedding_model, "embedding-gemma-300m");
109    }
110
111    #[test]
112    fn test_config_invalid_json() {
113        let dir = TempDir::new().unwrap();
114        let path = dir.path().join("config.json");
115        std::fs::write(&path, b"not valid json{{{").unwrap();
116        let bytes = std::fs::read(&path).unwrap();
117        let result: Result<PapersConfig, _> = serde_json::from_slice(&bytes);
118        assert!(result.is_err());
119    }
120
121    #[test]
122    fn test_valid_model_accepted() {
123        assert!(PapersConfig::validate_model("embedding-gemma-300m").is_ok());
124    }
125
126    #[test]
127    fn test_invalid_model_rejected() {
128        let err = PapersConfig::validate_model("gpt-4").unwrap_err();
129        assert!(matches!(err, ConfigError::UnknownModel(ref s) if s == "gpt-4"));
130    }
131
132    #[test]
133    fn test_config_path_is_platform_appropriate() {
134        let path = PapersConfig::config_path();
135        let s = path.to_string_lossy();
136        assert!(s.contains(".papers"));
137        assert!(s.ends_with("config.json"));
138    }
139}