1use serde::{Deserialize, Serialize};
2
3use koi_certmesh::roster::ProxyConfigEntry;
4use koi_common::paths;
5
6use crate::ProxyError;
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, utoipa::ToSchema)]
9pub struct ProxyEntry {
10 pub name: String,
11 pub listen_port: u16,
12 pub backend: String,
13 #[serde(default)]
14 pub allow_remote: bool,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize, Default)]
18struct ProxySection {
19 #[serde(default)]
20 entries: Vec<ProxyEntry>,
21}
22
23pub fn config_path() -> std::path::PathBuf {
24 paths::koi_data_dir().join("config.toml")
25}
26
27pub fn load_entries() -> Result<Vec<ProxyEntry>, ProxyError> {
28 let mut entries = load_entries_from_config()?;
29
30 match load_entries_from_roster() {
31 Ok(roster_entries) => {
32 merge_entries(&mut entries, roster_entries);
33 }
34 Err(e) => {
35 tracing::debug!(error = %e, "Failed to load proxy entries from roster");
36 }
37 }
38
39 entries.sort_by(|a, b| a.name.cmp(&b.name));
40 Ok(entries)
41}
42
43fn load_entries_from_config() -> Result<Vec<ProxyEntry>, ProxyError> {
44 let path = config_path();
45 if !path.exists() {
46 return Ok(Vec::new());
47 }
48 let raw = std::fs::read_to_string(&path).map_err(|e| ProxyError::Io(e.to_string()))?;
49 let value: toml::Value = raw
50 .parse()
51 .map_err(|e| ProxyError::Config(format!("Invalid config.toml: {e}")))?;
52 let proxy = value
53 .get("proxy")
54 .cloned()
55 .unwrap_or_else(|| toml::Value::Table(toml::map::Map::new()));
56 let proxy: ProxySection = proxy
57 .try_into()
58 .map_err(|e| ProxyError::Config(format!("Invalid proxy section: {e}")))?;
59 Ok(proxy.entries)
60}
61
62fn load_entries_from_roster() -> Result<Vec<ProxyEntry>, ProxyError> {
63 let roster_path = koi_certmesh::ca::roster_path();
64 if !roster_path.exists() {
65 return Ok(Vec::new());
66 }
67
68 let roster = koi_certmesh::roster::load_roster(&roster_path)
69 .map_err(|e| ProxyError::Io(e.to_string()))?;
70 let hostname = hostname::get()
71 .map_err(|e| ProxyError::Io(e.to_string()))?
72 .to_string_lossy()
73 .to_string();
74
75 let Some(member) = roster.find_member(&hostname) else {
76 return Ok(Vec::new());
77 };
78
79 Ok(member
80 .proxy_entries
81 .iter()
82 .map(|entry| ProxyEntry {
83 name: entry.name.clone(),
84 listen_port: entry.listen_port,
85 backend: entry.backend.clone(),
86 allow_remote: entry.allow_remote,
87 })
88 .collect())
89}
90
91pub fn save_entries(entries: &[ProxyEntry]) -> Result<(), ProxyError> {
92 let path = config_path();
93 if let Some(parent) = path.parent() {
94 std::fs::create_dir_all(parent).map_err(|e| ProxyError::Io(e.to_string()))?;
95 }
96
97 let mut root = if path.exists() {
98 let raw = std::fs::read_to_string(&path).map_err(|e| ProxyError::Io(e.to_string()))?;
99 raw.parse::<toml::Value>()
100 .unwrap_or_else(|_| toml::Value::Table(toml::map::Map::new()))
101 } else {
102 toml::Value::Table(toml::map::Map::new())
103 };
104
105 let proxy = ProxySection {
106 entries: entries.to_vec(),
107 };
108 let proxy_value = toml::Value::try_from(proxy)
109 .map_err(|e| ProxyError::Config(format!("Proxy config serialize error: {e}")))?;
110
111 if let toml::Value::Table(table) = &mut root {
112 table.insert("proxy".to_string(), proxy_value);
113 }
114
115 let raw = toml::to_string_pretty(&root)
116 .map_err(|e| ProxyError::Config(format!("Config serialize error: {e}")))?;
117 std::fs::write(&path, raw).map_err(|e| ProxyError::Io(e.to_string()))?;
118 Ok(())
119}
120
121pub fn upsert_entry(entry: ProxyEntry) -> Result<Vec<ProxyEntry>, ProxyError> {
122 let mut entries = load_entries_from_config()?;
123 if let Some(existing) = entries.iter_mut().find(|e| e.name == entry.name) {
124 *existing = entry;
125 } else {
126 entries.push(entry);
127 }
128 entries.sort_by(|a, b| a.name.cmp(&b.name));
129 save_entries(&entries)?;
130 sync_roster(&entries)?;
131 Ok(entries)
132}
133
134pub fn remove_entry(name: &str) -> Result<Vec<ProxyEntry>, ProxyError> {
135 let mut entries = load_entries_from_config()?;
136 let before = entries.len();
137 entries.retain(|e| e.name != name);
138 if entries.len() == before {
139 return Err(ProxyError::NotFound(name.to_string()));
140 }
141 save_entries(&entries)?;
142 sync_roster(&entries)?;
143 Ok(entries)
144}
145
146fn sync_roster(entries: &[ProxyEntry]) -> Result<(), ProxyError> {
147 let roster_path = koi_certmesh::ca::roster_path();
148 if !roster_path.exists() {
149 return Ok(());
150 }
151
152 let mut roster = koi_certmesh::roster::load_roster(&roster_path)
153 .map_err(|e| ProxyError::Io(e.to_string()))?;
154 let hostname = hostname::get()
155 .map_err(|e| ProxyError::Io(e.to_string()))?
156 .to_string_lossy()
157 .to_string();
158
159 let Some(member) = roster.find_member_mut(&hostname) else {
160 return Ok(());
161 };
162
163 member.proxy_entries = entries
164 .iter()
165 .map(|entry| ProxyConfigEntry {
166 name: entry.name.clone(),
167 listen_port: entry.listen_port,
168 backend: entry.backend.clone(),
169 allow_remote: entry.allow_remote,
170 })
171 .collect();
172
173 koi_certmesh::roster::save_roster(&roster, &roster_path)
174 .map_err(|e| ProxyError::Io(e.to_string()))?;
175 Ok(())
176}
177
178fn merge_entries(entries: &mut Vec<ProxyEntry>, roster_entries: Vec<ProxyEntry>) {
179 let mut map: std::collections::BTreeMap<String, ProxyEntry> = std::collections::BTreeMap::new();
180 for entry in roster_entries {
181 map.insert(entry.name.clone(), entry);
182 }
183 for entry in entries.drain(..) {
184 map.insert(entry.name.clone(), entry);
185 }
186 *entries = map.into_values().collect();
187}
188
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ProxyEntry` with `allow_remote = false`.
    fn sample(name: &str, listen_port: u16, backend: &str) -> ProxyEntry {
        ProxyEntry {
            name: name.to_string(),
            listen_port,
            backend: backend.to_string(),
            allow_remote: false,
        }
    }

    #[test]
    fn config_path_is_under_data_dir() {
        let path = config_path();
        assert!(path.ends_with("config.toml"));
        // The test name promises this; the file-name check alone did not verify it.
        // Assumes koi_data_dir() is stable between calls in the same process.
        assert!(path.starts_with(koi_common::paths::koi_data_dir()));
    }

    #[test]
    fn proxy_entry_round_trip() {
        let entry = sample("grafana", 443, "http://localhost:3000");
        let proxy = ProxySection {
            entries: vec![entry.clone()],
        };
        let value = toml::Value::try_from(proxy).unwrap();
        let decoded: ProxySection = value.try_into().unwrap();
        // Compare the whole list so extra or missing entries are caught as well.
        assert_eq!(decoded.entries, vec![entry]);
    }

    #[test]
    fn allow_remote_defaults_to_false_when_absent() {
        let decoded: ProxySection = toml::from_str(
            "[[entries]]\nname = \"grafana\"\nlisten_port = 443\nbackend = \"http://localhost:3000\"\n",
        )
        .unwrap();
        assert_eq!(
            decoded.entries,
            vec![sample("grafana", 443, "http://localhost:3000")]
        );
    }

    #[test]
    fn empty_proxy_section_has_no_entries() {
        let decoded: ProxySection = toml::from_str("").unwrap();
        assert!(decoded.entries.is_empty());
    }

    #[test]
    fn merge_prefers_config_entries_and_sorts_by_name() {
        let mut entries = vec![
            sample("zeta", 8443, "http://localhost:1"),
            sample("alpha", 443, "http://localhost:2"),
        ];
        let roster = vec![
            sample("alpha", 9443, "http://roster:2"),
            sample("mid", 7443, "http://roster:3"),
        ];
        merge_entries(&mut entries, roster);
        assert_eq!(
            entries,
            vec![
                sample("alpha", 443, "http://localhost:2"),
                sample("mid", 7443, "http://roster:3"),
                sample("zeta", 8443, "http://localhost:1"),
            ]
        );
    }
}