Skip to main content

koi_proxy/
config.rs

1use serde::{Deserialize, Serialize};
2
3use koi_common::integration::CertmeshSnapshot;
4use koi_common::paths;
5
6use crate::ProxyError;
7
/// A single proxy entry as stored in the `[proxy]` section of `config.toml`.
///
/// Entries are identified by `name`: upsert, removal and roster merging in this module
/// all key on it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, utoipa::ToSchema)]
pub struct ProxyEntry {
    /// Unique entry name; the key used for upsert, removal and merging.
    pub name: String,
    /// Port this entry listens on (enforced by the proxy runtime, not this module).
    pub listen_port: u16,
    /// Backend target, e.g. `http://localhost:3000`.
    pub backend: String,
    /// Presumably permits non-loopback clients — enforcement lives outside this module;
    /// verify. Defaults to `false` when the key is absent from the config.
    #[serde(default)]
    pub allow_remote: bool,
}
16
/// Serde shape of the `[proxy]` table in `config.toml`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct ProxySection {
    /// Configured entries; a missing `entries` key deserializes to an empty list.
    #[serde(default)]
    entries: Vec<ProxyEntry>,
}
22
23pub fn config_path() -> std::path::PathBuf {
24    paths::koi_data_dir().join("config.toml")
25}
26
27pub fn config_path_with_override(data_dir: Option<&std::path::Path>) -> std::path::PathBuf {
28    paths::koi_data_dir_with_override(data_dir).join("config.toml")
29}
30
31pub fn load_entries() -> Result<Vec<ProxyEntry>, ProxyError> {
32    load_entries_from(&config_path())
33}
34
35pub fn load_entries_with_data_dir(
36    data_dir: Option<&std::path::Path>,
37) -> Result<Vec<ProxyEntry>, ProxyError> {
38    load_entries_from(&config_path_with_override(data_dir))
39}
40
41/// Load entries from config file and merge with roster entries from certmesh.
42pub fn load_entries_with_certmesh(
43    certmesh: Option<&dyn CertmeshSnapshot>,
44) -> Result<Vec<ProxyEntry>, ProxyError> {
45    let mut entries = load_entries_from(&config_path())?;
46
47    if let Some(cm) = certmesh {
48        let hostname = hostname::get()
49            .map_err(|e| ProxyError::Io(e.to_string()))?
50            .to_string_lossy()
51            .to_string();
52
53        let members = cm.active_members();
54        if let Some(member) = members.iter().find(|m| m.hostname == hostname) {
55            let roster_entries: Vec<ProxyEntry> = member
56                .proxy_entries
57                .iter()
58                .map(|entry| ProxyEntry {
59                    name: entry.name.clone(),
60                    listen_port: entry.listen_port,
61                    backend: entry.backend.clone(),
62                    allow_remote: entry.allow_remote,
63                })
64                .collect();
65            merge_entries(&mut entries, roster_entries);
66        }
67    }
68
69    entries.sort_by(|a, b| a.name.cmp(&b.name));
70    Ok(entries)
71}
72
73fn load_entries_from(path: &std::path::Path) -> Result<Vec<ProxyEntry>, ProxyError> {
74    if !path.exists() {
75        return Ok(Vec::new());
76    }
77    let raw = std::fs::read_to_string(path).map_err(|e| ProxyError::Io(e.to_string()))?;
78    let value: toml::Value = raw
79        .parse()
80        .map_err(|e| ProxyError::Config(format!("Invalid config.toml: {e}")))?;
81    let proxy = value
82        .get("proxy")
83        .cloned()
84        .unwrap_or_else(|| toml::Value::Table(toml::map::Map::new()));
85    let proxy: ProxySection = proxy
86        .try_into()
87        .map_err(|e| ProxyError::Config(format!("Invalid proxy section: {e}")))?;
88    Ok(proxy.entries)
89}
90
91pub fn save_entries(entries: &[ProxyEntry]) -> Result<(), ProxyError> {
92    save_entries_to(entries, &config_path())
93}
94
95fn save_entries_to(entries: &[ProxyEntry], path: &std::path::Path) -> Result<(), ProxyError> {
96    if let Some(parent) = path.parent() {
97        std::fs::create_dir_all(parent).map_err(|e| ProxyError::Io(e.to_string()))?;
98    }
99
100    let mut root = if path.exists() {
101        let raw = std::fs::read_to_string(path).map_err(|e| ProxyError::Io(e.to_string()))?;
102        raw.parse::<toml::Value>()
103            .unwrap_or_else(|_| toml::Value::Table(toml::map::Map::new()))
104    } else {
105        toml::Value::Table(toml::map::Map::new())
106    };
107
108    let proxy = ProxySection {
109        entries: entries.to_vec(),
110    };
111    let proxy_value = toml::Value::try_from(proxy)
112        .map_err(|e| ProxyError::Config(format!("Proxy config serialize error: {e}")))?;
113
114    if let toml::Value::Table(table) = &mut root {
115        table.insert("proxy".to_string(), proxy_value);
116    }
117
118    let raw = toml::to_string_pretty(&root)
119        .map_err(|e| ProxyError::Config(format!("Config serialize error: {e}")))?;
120    std::fs::write(path, raw).map_err(|e| ProxyError::Io(e.to_string()))?;
121    Ok(())
122}
123
124pub fn upsert_entry(entry: ProxyEntry) -> Result<Vec<ProxyEntry>, ProxyError> {
125    upsert_entry_with_data_dir(entry, None)
126}
127
128pub fn upsert_entry_with_data_dir(
129    entry: ProxyEntry,
130    data_dir: Option<&std::path::Path>,
131) -> Result<Vec<ProxyEntry>, ProxyError> {
132    let path = config_path_with_override(data_dir);
133    let mut entries = load_entries_from(&path)?;
134    if let Some(existing) = entries.iter_mut().find(|e| e.name == entry.name) {
135        *existing = entry;
136    } else {
137        entries.push(entry);
138    }
139    entries.sort_by(|a, b| a.name.cmp(&b.name));
140    save_entries_to(&entries, &path)?;
141    Ok(entries)
142}
143
144pub fn remove_entry(name: &str) -> Result<Vec<ProxyEntry>, ProxyError> {
145    remove_entry_with_data_dir(name, None)
146}
147
148pub fn remove_entry_with_data_dir(
149    name: &str,
150    data_dir: Option<&std::path::Path>,
151) -> Result<Vec<ProxyEntry>, ProxyError> {
152    let path = config_path_with_override(data_dir);
153    let mut entries = load_entries_from(&path)?;
154    let before = entries.len();
155    entries.retain(|e| e.name != name);
156    if entries.len() == before {
157        return Err(ProxyError::NotFound(name.to_string()));
158    }
159    save_entries_to(&entries, &path)?;
160    Ok(entries)
161}
162
163fn merge_entries(entries: &mut Vec<ProxyEntry>, roster_entries: Vec<ProxyEntry>) {
164    let mut map: std::collections::BTreeMap<String, ProxyEntry> = std::collections::BTreeMap::new();
165    for entry in roster_entries {
166        map.insert(entry.name.clone(), entry);
167    }
168    for entry in entries.drain(..) {
169        map.insert(entry.name.clone(), entry);
170    }
171    *entries = map.into_values().collect();
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh, per-test `config.toml` path under the system temp directory.
    fn scratch_config(tag: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "koi-proxy-config-{tag}-{}",
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir.join("config.toml")
    }

    /// Builds a local-only entry with a backend derived from `port`.
    fn entry(name: &str, port: u16) -> ProxyEntry {
        ProxyEntry {
            name: name.to_string(),
            listen_port: port,
            backend: format!("http://localhost:{port}"),
            allow_remote: false,
        }
    }

    #[test]
    fn config_path_is_under_data_dir() {
        let _ = koi_common::test::ensure_data_dir("koi-proxy-config-tests");
        let path = config_path();
        assert!(path.ends_with("config.toml"));
    }

    #[test]
    fn proxy_entry_round_trip() {
        let entry = ProxyEntry {
            name: "grafana".to_string(),
            listen_port: 443,
            backend: "http://localhost:3000".to_string(),
            allow_remote: false,
        };
        let proxy = ProxySection {
            entries: vec![entry.clone()],
        };
        let value = toml::Value::try_from(proxy).unwrap();
        let decoded: ProxySection = value.try_into().unwrap();
        assert_eq!(decoded.entries[0], entry);
    }

    #[test]
    fn merge_prefers_local_entries_and_sorts_by_name() {
        let mut local = vec![entry("zeta", 1), entry("alpha", 2)];
        merge_entries(&mut local, vec![entry("alpha", 99), entry("mid", 3)]);
        assert_eq!(
            local,
            vec![entry("alpha", 2), entry("mid", 3), entry("zeta", 1)]
        );
    }

    #[test]
    fn load_from_missing_file_is_empty() {
        let path = scratch_config("missing");
        assert!(load_entries_from(&path).unwrap().is_empty());
    }

    #[test]
    fn save_preserves_unrelated_sections() {
        let path = scratch_config("preserve");
        let dir = path.parent().unwrap().to_path_buf();
        std::fs::create_dir_all(&dir).unwrap();
        std::fs::write(&path, "[other]\nkey = \"value\"\n").unwrap();

        let entries = vec![entry("grafana", 443)];
        save_entries_to(&entries, &path).unwrap();

        let raw = std::fs::read_to_string(&path).unwrap();
        let value: toml::Value = raw.parse().unwrap();
        let kept = value
            .get("other")
            .and_then(|section| section.get("key"))
            .and_then(|v| v.as_str());
        assert_eq!(kept, Some("value"));
        assert_eq!(load_entries_from(&path).unwrap(), entries);

        let _ = std::fs::remove_dir_all(&dir);
    }
}