mastermind_cli/configs/
config.rs

1use dotenv::dotenv;
2use std::fs;
3use std::path::Path;
4use toml_edit::{value, DocumentMut, Item, Table};
5
6use crate::configs::config_error::ConfigError;
7
8pub struct Config {
9    document: DocumentMut,
10}
11
12impl Config {
13    pub fn new() -> Result<Self, ConfigError> {
14        dotenv().ok();
15
16        // Get the user's home directory
17        let Some(config_dir) = dirs::config_dir() else {
18            return Err(ConfigError::FileNotFound(
19                "No config directory found".to_string(),
20            ));
21        };
22
23        // Define the config folder
24        let mastermind_dir = config_dir.join("mastermind");
25        if !mastermind_dir.exists() {
26            match fs::create_dir_all(&mastermind_dir) {
27                Ok(()) => println!("Config directory created at {mastermind_dir:?}"),
28                Err(e) => {
29                    return Err(ConfigError::FileNotFound(format!(
30                        "Failed to create folder: {e}"
31                    )))
32                }
33            }
34        }
35
36        // Define config file path
37        let config_file = mastermind_dir.join("config.toml");
38
39        // Read or create a document
40        let document = match fs::read_to_string(&config_file) {
41            Ok(content) if !content.is_empty() => content.parse::<DocumentMut>()?,
42            _ => {
43                let mut doc = DocumentMut::new();
44
45                // Make .toml file in table-like format
46                doc["api"] = Item::Table(Table::new());
47                doc["api"]["base-url"] = value("");
48                doc["api"]["key"] = value("");
49
50                doc["model"] = Item::Table(Table::new());
51                doc["model"]["default"] = value("");
52
53                // Write the document to the config file
54                println!(
55                    "Looks like it's your first run\n\
56                          Creating a config file at {}\n\
57                          Make sure to modify it first or use the proper environment variables\n\
58                          See: https://github.com/theoforger/mastermind?tab=readme-ov-file#%EF%B8%8F-configure",
59                    config_file.display()
60                );
61                fs::write(&config_file, doc.to_string())?;
62
63                doc
64            }
65        };
66
67        Ok(Config { document })
68    }
69
70    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), ConfigError> {
71        fs::write(&path, self.document.to_string())?;
72        Ok(())
73    }
74
75    pub fn get_base_url(&self) -> Option<&str> {
76        self.document["api"]["base-url"]
77            .as_str()
78            .filter(|s| !s.is_empty())
79    }
80
81    pub fn get_api_key(&self) -> Option<&str> {
82        self.document["api"]["key"]
83            .as_str()
84            .filter(|s| !s.is_empty())
85    }
86
87    pub fn get_default_model(&self) -> Option<&str> {
88        self.document["model"]["default"]
89            .as_str()
90            .filter(|s| !s.is_empty())
91    }
92}
93
// `dirs::config_dir()` only honours `XDG_CONFIG_HOME` on Linux. On other
// platforms the override below would be ignored and the test would write
// into the real user config directory, so the module is Linux-only.
#[cfg(all(test, target_os = "linux"))]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A first run must create the `mastermind` folder and a default
    /// `config.toml` whose values are all reported as unset.
    ///
    /// No async runtime is needed: nothing in this test awaits.
    #[test]
    fn test_new() {
        // Create a temporary directory
        let temp_dir = tempdir().unwrap();
        let config_dir = temp_dir.path().join("mastermind");
        assert!(!config_dir.exists());

        // Override config home (process-wide; affects every thread in this test binary)
        std::env::set_var("XDG_CONFIG_HOME", temp_dir.path().to_str().unwrap());

        // Create a config
        let Ok(config) = Config::new() else {
            panic!("Config::new() should succeed with a writable config home");
        };
        assert!(config_dir.exists());

        // Check if config.toml exists
        let config_file = config_dir.join("config.toml");
        assert!(config_file.exists());

        // Check the content
        let content = fs::read_to_string(config_file).unwrap();
        assert!(content.contains("[api]"));
        assert!(content.contains("base-url"));
        assert!(content.contains("key"));
        assert!(content.contains("[model]"));
        assert!(content.contains("default"));

        // Default values are empty strings, which the getters report as unset
        assert!(config.get_base_url().is_none());
        assert!(config.get_api_key().is_none());
        assert!(config.get_default_model().is_none());
    }
}