1use std::io;
2use std::path::PathBuf;
3
/// Embedding model identifiers accepted by [`PapersConfig::validate_model`].
pub const VALID_MODELS: &[&str] = &[
    "embedding-gemma-300m",
];
5
6#[derive(Debug, thiserror::Error)]
7pub enum ConfigError {
8 #[error("IO error: {0}")]
9 Io(#[from] io::Error),
10 #[error("JSON error: {0}")]
11 Json(#[from] serde_json::Error),
12 #[error("unknown model: {0}")]
13 UnknownModel(String),
14}
15
16#[derive(Debug, serde::Serialize, serde::Deserialize)]
17pub struct PapersConfig {
18 pub embedding_model: String,
19}
20
21impl Default for PapersConfig {
22 fn default() -> Self {
23 Self {
24 embedding_model: "embedding-gemma-300m".to_string(),
25 }
26 }
27}
28
29impl PapersConfig {
30 pub fn config_path() -> PathBuf {
32 dirs::config_dir()
33 .unwrap_or_else(|| PathBuf::from("."))
34 .join(".papers")
35 .join("config.json")
36 }
37
38 pub fn load() -> Result<Self, ConfigError> {
40 let path = Self::config_path();
41 if !path.exists() {
42 return Ok(Self::default());
43 }
44 let bytes = std::fs::read(&path)?;
45 let cfg = serde_json::from_slice(&bytes)?;
46 Ok(cfg)
47 }
48
49 pub fn save(&self) -> Result<(), ConfigError> {
51 let path = Self::config_path();
52 if let Some(parent) = path.parent() {
53 std::fs::create_dir_all(parent)?;
54 }
55 let json = serde_json::to_vec_pretty(self)?;
56 std::fs::write(&path, json)?;
57 Ok(())
58 }
59
60 pub fn validate_model(name: &str) -> Result<(), ConfigError> {
62 if VALID_MODELS.contains(&name) {
63 Ok(())
64 } else {
65 Err(ConfigError::UnknownModel(name.to_string()))
66 }
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// The only model in `VALID_MODELS`, and the default.
    const GEMMA: &str = "embedding-gemma-300m";

    /// Writes `bytes` to `name` inside `tmp`, then reads the file back.
    fn write_and_read_back(tmp: &TempDir, name: &str, bytes: &[u8]) -> Vec<u8> {
        let file = tmp.path().join(name);
        std::fs::write(&file, bytes).unwrap();
        std::fs::read(&file).unwrap()
    }

    #[test]
    fn test_config_default() {
        assert_eq!(PapersConfig::default().embedding_model, GEMMA);
    }

    #[test]
    fn test_config_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let original = PapersConfig {
            embedding_model: GEMMA.to_string(),
        };
        let encoded = serde_json::to_vec_pretty(&original).unwrap();
        let on_disk = write_and_read_back(&tmp, "config.json", &encoded);
        let decoded: PapersConfig = serde_json::from_slice(&on_disk).unwrap();
        assert_eq!(decoded.embedding_model, original.embedding_model);
    }

    #[test]
    fn test_config_missing_file_returns_default() {
        let tmp = TempDir::new().unwrap();
        let missing = tmp.path().join("nonexistent.json");
        assert!(!missing.exists());
        // NOTE(review): this reproduces `load()`'s missing-file branch by hand
        // instead of calling `load()`, which always reads the real user config path.
        let cfg = missing
            .exists()
            .then(|| {
                let raw = std::fs::read(&missing).unwrap();
                serde_json::from_slice::<PapersConfig>(&raw).unwrap()
            })
            .unwrap_or_default();
        assert_eq!(cfg.embedding_model, GEMMA);
    }

    #[test]
    fn test_config_invalid_json() {
        let tmp = TempDir::new().unwrap();
        let garbage = write_and_read_back(&tmp, "config.json", b"not valid json{{{");
        assert!(serde_json::from_slice::<PapersConfig>(&garbage).is_err());
    }

    #[test]
    fn test_valid_model_accepted() {
        assert!(PapersConfig::validate_model(GEMMA).is_ok());
    }

    #[test]
    fn test_invalid_model_rejected() {
        let err = PapersConfig::validate_model("gpt-4").unwrap_err();
        assert!(matches!(err, ConfigError::UnknownModel(ref name) if name == "gpt-4"));
    }

    #[test]
    fn test_config_path_is_platform_appropriate() {
        let rendered = PapersConfig::config_path().to_string_lossy().into_owned();
        assert!(rendered.contains(".papers"));
        assert!(rendered.ends_with("config.json"));
    }
}