Skip to main content

koi_proxy/
config.rs

1use serde::{Deserialize, Serialize};
2
3use koi_certmesh::roster::ProxyConfigEntry;
4use koi_common::paths;
5
6use crate::ProxyError;
7
/// A single reverse-proxy entry as stored in the `[proxy]` table of `config.toml`.
///
/// The same data is mirrored into this host's certmesh roster member as a
/// `ProxyConfigEntry` (see `sync_roster` / `load_entries_from_roster`).
// NOTE(review): `///` comments here also feed the utoipa schema descriptions,
// so reviewer notes are kept in `//` comments.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, utoipa::ToSchema)]
pub struct ProxyEntry {
    /// Identifier of the entry; upsert, removal and the config/roster merge all
    /// match entries by this field.
    pub name: String,
    /// Listening port for this entry.
    pub listen_port: u16,
    /// Backend target, e.g. `http://localhost:3000`.
    pub backend: String,
    /// Defaults to `false` when the field is omitted from the serialized form.
    // NOTE(review): presumably controls whether non-local clients may connect — confirm.
    #[serde(default)]
    pub allow_remote: bool,
}
16
/// Serde shape of the `[proxy]` table in `config.toml`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct ProxySection {
    // A missing `entries` key deserializes as an empty list.
    #[serde(default)]
    entries: Vec<ProxyEntry>,
}
22
23pub fn config_path() -> std::path::PathBuf {
24    paths::koi_data_dir().join("config.toml")
25}
26
27pub fn load_entries() -> Result<Vec<ProxyEntry>, ProxyError> {
28    let mut entries = load_entries_from_config()?;
29
30    match load_entries_from_roster() {
31        Ok(roster_entries) => {
32            merge_entries(&mut entries, roster_entries);
33        }
34        Err(e) => {
35            tracing::debug!(error = %e, "Failed to load proxy entries from roster");
36        }
37    }
38
39    entries.sort_by(|a, b| a.name.cmp(&b.name));
40    Ok(entries)
41}
42
43fn load_entries_from_config() -> Result<Vec<ProxyEntry>, ProxyError> {
44    let path = config_path();
45    if !path.exists() {
46        return Ok(Vec::new());
47    }
48    let raw = std::fs::read_to_string(&path).map_err(|e| ProxyError::Io(e.to_string()))?;
49    let value: toml::Value = raw
50        .parse()
51        .map_err(|e| ProxyError::Config(format!("Invalid config.toml: {e}")))?;
52    let proxy = value
53        .get("proxy")
54        .cloned()
55        .unwrap_or_else(|| toml::Value::Table(toml::map::Map::new()));
56    let proxy: ProxySection = proxy
57        .try_into()
58        .map_err(|e| ProxyError::Config(format!("Invalid proxy section: {e}")))?;
59    Ok(proxy.entries)
60}
61
62fn load_entries_from_roster() -> Result<Vec<ProxyEntry>, ProxyError> {
63    let roster_path = koi_certmesh::ca::roster_path();
64    if !roster_path.exists() {
65        return Ok(Vec::new());
66    }
67
68    let roster = koi_certmesh::roster::load_roster(&roster_path)
69        .map_err(|e| ProxyError::Io(e.to_string()))?;
70    let hostname = hostname::get()
71        .map_err(|e| ProxyError::Io(e.to_string()))?
72        .to_string_lossy()
73        .to_string();
74
75    let Some(member) = roster.find_member(&hostname) else {
76        return Ok(Vec::new());
77    };
78
79    Ok(member
80        .proxy_entries
81        .iter()
82        .map(|entry| ProxyEntry {
83            name: entry.name.clone(),
84            listen_port: entry.listen_port,
85            backend: entry.backend.clone(),
86            allow_remote: entry.allow_remote,
87        })
88        .collect())
89}
90
91pub fn save_entries(entries: &[ProxyEntry]) -> Result<(), ProxyError> {
92    let path = config_path();
93    if let Some(parent) = path.parent() {
94        std::fs::create_dir_all(parent).map_err(|e| ProxyError::Io(e.to_string()))?;
95    }
96
97    let mut root = if path.exists() {
98        let raw = std::fs::read_to_string(&path).map_err(|e| ProxyError::Io(e.to_string()))?;
99        raw.parse::<toml::Value>()
100            .unwrap_or_else(|_| toml::Value::Table(toml::map::Map::new()))
101    } else {
102        toml::Value::Table(toml::map::Map::new())
103    };
104
105    let proxy = ProxySection {
106        entries: entries.to_vec(),
107    };
108    let proxy_value = toml::Value::try_from(proxy)
109        .map_err(|e| ProxyError::Config(format!("Proxy config serialize error: {e}")))?;
110
111    if let toml::Value::Table(table) = &mut root {
112        table.insert("proxy".to_string(), proxy_value);
113    }
114
115    let raw = toml::to_string_pretty(&root)
116        .map_err(|e| ProxyError::Config(format!("Config serialize error: {e}")))?;
117    std::fs::write(&path, raw).map_err(|e| ProxyError::Io(e.to_string()))?;
118    Ok(())
119}
120
121pub fn upsert_entry(entry: ProxyEntry) -> Result<Vec<ProxyEntry>, ProxyError> {
122    let mut entries = load_entries_from_config()?;
123    if let Some(existing) = entries.iter_mut().find(|e| e.name == entry.name) {
124        *existing = entry;
125    } else {
126        entries.push(entry);
127    }
128    entries.sort_by(|a, b| a.name.cmp(&b.name));
129    save_entries(&entries)?;
130    sync_roster(&entries)?;
131    Ok(entries)
132}
133
134pub fn remove_entry(name: &str) -> Result<Vec<ProxyEntry>, ProxyError> {
135    let mut entries = load_entries_from_config()?;
136    let before = entries.len();
137    entries.retain(|e| e.name != name);
138    if entries.len() == before {
139        return Err(ProxyError::NotFound(name.to_string()));
140    }
141    save_entries(&entries)?;
142    sync_roster(&entries)?;
143    Ok(entries)
144}
145
146fn sync_roster(entries: &[ProxyEntry]) -> Result<(), ProxyError> {
147    let roster_path = koi_certmesh::ca::roster_path();
148    if !roster_path.exists() {
149        return Ok(());
150    }
151
152    let mut roster = koi_certmesh::roster::load_roster(&roster_path)
153        .map_err(|e| ProxyError::Io(e.to_string()))?;
154    let hostname = hostname::get()
155        .map_err(|e| ProxyError::Io(e.to_string()))?
156        .to_string_lossy()
157        .to_string();
158
159    let Some(member) = roster.find_member_mut(&hostname) else {
160        return Ok(());
161    };
162
163    member.proxy_entries = entries
164        .iter()
165        .map(|entry| ProxyConfigEntry {
166            name: entry.name.clone(),
167            listen_port: entry.listen_port,
168            backend: entry.backend.clone(),
169            allow_remote: entry.allow_remote,
170        })
171        .collect();
172
173    koi_certmesh::roster::save_roster(&roster, &roster_path)
174        .map_err(|e| ProxyError::Io(e.to_string()))?;
175    Ok(())
176}
177
178fn merge_entries(entries: &mut Vec<ProxyEntry>, roster_entries: Vec<ProxyEntry>) {
179    let mut map: std::collections::BTreeMap<String, ProxyEntry> = std::collections::BTreeMap::new();
180    for entry in roster_entries {
181        map.insert(entry.name.clone(), entry);
182    }
183    for entry in entries.drain(..) {
184        map.insert(entry.name.clone(), entry);
185    }
186    *entries = map.into_values().collect();
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal entry whose backend is derived from its port.
    fn sample(name: &str, listen_port: u16) -> ProxyEntry {
        ProxyEntry {
            name: name.to_string(),
            listen_port,
            backend: format!("http://localhost:{listen_port}"),
            allow_remote: false,
        }
    }

    #[test]
    fn config_path_is_under_data_dir() {
        let path = config_path();
        // The test name promises "under the data dir", so check the prefix as
        // well as the file name.
        assert!(path.starts_with(koi_common::paths::koi_data_dir()));
        assert!(path.ends_with("config.toml"));
    }

    #[test]
    fn proxy_entry_round_trip() {
        let entry = ProxyEntry {
            name: "grafana".to_string(),
            listen_port: 443,
            backend: "http://localhost:3000".to_string(),
            allow_remote: false,
        };
        let proxy = ProxySection {
            entries: vec![entry.clone()],
        };
        let value = toml::Value::try_from(proxy).unwrap();
        let decoded: ProxySection = value.try_into().unwrap();
        assert_eq!(decoded.entries, vec![entry]);
    }

    #[test]
    fn allow_remote_defaults_to_false_when_omitted() {
        let section: ProxySection = toml::from_str(
            r#"
            [[entries]]
            name = "grafana"
            listen_port = 443
            backend = "http://localhost:3000"
            "#,
        )
        .unwrap();
        assert_eq!(section.entries.len(), 1);
        assert!(!section.entries[0].allow_remote);
    }

    #[test]
    fn missing_entries_key_yields_empty_section() {
        let section: ProxySection = toml::from_str("").unwrap();
        assert!(section.entries.is_empty());
    }

    #[test]
    fn merge_prefers_config_entries_and_sorts_by_name() {
        let mut entries = vec![sample("zeta", 1), sample("alpha", 2)];
        let roster = vec![sample("alpha", 3), sample("beta", 4)];
        merge_entries(&mut entries, roster);
        let summary: Vec<(&str, u16)> = entries
            .iter()
            .map(|e| (e.name.as_str(), e.listen_port))
            .collect();
        assert_eq!(summary, vec![("alpha", 2), ("beta", 4), ("zeta", 1)]);
    }
}