1use std::fs::{self, File};
2use std::io::Write;
3use std::io::{self, ErrorKind};
4use std::path::{Path, PathBuf};
5use std::time::{SystemTime, UNIX_EPOCH};
6
/// Serialization format used when writing a config file.
///
/// Paths ending in `.json` and `.jsonc` both correspond to [`FileType::JSON`]
/// when writing; only `.toml` corresponds to [`FileType::TOML`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FileType {
    /// Pretty-printed TOML.
    TOML,
    /// Pretty-printed JSON (also chosen for `.jsonc` paths).
    JSON,
}
12
/// Config format recognized from a path's extension when reading.
///
/// Unlike `FileType`, this keeps `.jsonc` distinct from `.json` so that
/// comments can be stripped before the text is handed to the JSON parser.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FileFormat {
    /// `.toml`
    Toml,
    /// `.json` (strict JSON)
    Json,
    /// `.jsonc` (JSON with `//` and `/* */` comments)
    Jsonc,
}
19
20impl FileFormat {
21 fn from_path(path: &Path) -> io::Result<Self> {
22 match path.extension().and_then(|suffix| suffix.to_str()) {
23 Some("toml") => Ok(Self::Toml),
24 Some("json") => Ok(Self::Json),
25 Some("jsonc") => Ok(Self::Jsonc),
26 _ => Err(io::Error::new(
27 ErrorKind::InvalidInput,
28 format!(
29 "config file type not supported for {} (expected .toml, .json, or .jsonc)",
30 path.display()
31 ),
32 )),
33 }
34 }
35
36 fn write_type(self) -> FileType {
37 match self {
38 Self::Toml => FileType::TOML,
39 Self::Json | Self::Jsonc => FileType::JSON,
40 }
41 }
42}
43
44pub(crate) fn infer_file_type(path: &Path) -> io::Result<FileType> {
45 Ok(FileFormat::from_path(path)?.write_type())
46}
47
48impl FileType {
49 fn as_str(self) -> &'static str {
50 match self {
51 Self::TOML => "TOML",
52 Self::JSON => "JSON",
53 }
54 }
55}
56
57fn serialize_config<T>(config: &T, file_type: FileType, path: &Path) -> io::Result<String>
58where
59 T: serde::Serialize + ?Sized,
60{
61 match file_type {
62 FileType::TOML => toml::to_string_pretty(config).map_err(|e| {
63 io::Error::new(
64 ErrorKind::InvalidData,
65 format!(
66 "failed to serialize TOML config for {}: {e}",
67 path.display()
68 ),
69 )
70 }),
71 FileType::JSON => serde_json::to_string_pretty(config).map_err(|e| {
72 io::Error::new(
73 ErrorKind::InvalidData,
74 format!(
75 "failed to serialize JSON config for {}: {e}",
76 path.display()
77 ),
78 )
79 }),
80 }
81}
82
83pub(crate) fn write_config<T>(path: &Path, config: &T, file_type: FileType) -> io::Result<()>
84where
85 T: serde::Serialize + ?Sized,
86{
87 let inferred_type = infer_file_type(path)?;
88 if inferred_type != file_type {
89 return Err(io::Error::new(
90 ErrorKind::InvalidInput,
91 format!(
92 "refusing to write {} config to {} because its extension expects {}",
93 file_type.as_str(),
94 path.display(),
95 inferred_type.as_str()
96 ),
97 ));
98 }
99
100 let content = serialize_config(config, file_type, path)?;
101 atomic_write(path, &content)
102}
103
104pub(crate) fn write_config_inferred<T>(path: &Path, config: &T) -> io::Result<()>
105where
106 T: serde::Serialize + ?Sized,
107{
108 let file_type = infer_file_type(path)?;
109 let content = serialize_config(config, file_type, path)?;
110 atomic_write(path, &content)
111}
112
/// Atomically replaces the file at `path` with `content`.
///
/// Missing parent directories are created first. The data is written to a
/// freshly created, uniquely named temporary file beside `path` and flushed
/// with [`File::sync_all`]. The temporary file is then renamed over `path`.
/// Because the temporary file lives in the same directory, the rename stays on
/// one filesystem. On platforms where rename is atomic, readers therefore see
/// either the old file or the complete new one.
///
/// On Unix, an existing target's permission bits are copied onto the
/// replacement. Without this, a restricted config (e.g. mode `0600`) would
/// become readable according to the umask after every save.
///
/// # Errors
///
/// Returns [`ErrorKind::InvalidInput`] if `path` is an existing directory.
/// Otherwise it returns the failing operation's error kind, with a message
/// naming the paths involved. The operations are: creating the directory,
/// creating, writing or syncing the temporary file, and the final rename.
/// Before returning an error, a temporary file created by this call is removed
/// on a best-effort basis.
fn atomic_write(path: &Path, content: &str) -> io::Result<()> {
    if path.is_dir() {
        return Err(io::Error::new(
            ErrorKind::InvalidInput,
            format!("refusing to write config to directory {}", path.display()),
        ));
    }

    if let Some(dir) = path.parent().filter(|dir| !dir.as_os_str().is_empty()) {
        fs::create_dir_all(dir).map_err(|err| {
            io::Error::new(
                err.kind(),
                format!(
                    "failed to create config directory {} for {}: {err}",
                    dir.display(),
                    path.display()
                ),
            )
        })?;
    }

    let temp_path = temporary_path_for(path);

    // `create_new` (O_CREAT | O_EXCL) refuses to reuse or follow anything
    // already at the temporary name. A colliding writer or a planted symlink
    // therefore cannot receive our data. If this open fails, the entry is not
    // ours, so it must NOT be removed.
    let mut file = fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(&temp_path)
        .map_err(|err| {
            io::Error::new(
                err.kind(),
                format!(
                    "failed to create temporary config file {} for {}: {err}",
                    temp_path.display(),
                    path.display()
                ),
            )
        })?;

    let fill_result = fill_temporary_file(&mut file, &temp_path, path, content);
    // Close the handle before removing or renaming; Windows refuses either
    // operation on an open file.
    drop(file);
    if let Err(err) = fill_result {
        let _ = fs::remove_file(&temp_path);
        return Err(err);
    }

    if let Err(err) = fs::rename(&temp_path, path) {
        let _ = fs::remove_file(&temp_path);
        return Err(io::Error::new(
            err.kind(),
            format!(
                "failed to replace config file {} with {}: {err}",
                path.display(),
                temp_path.display()
            ),
        ));
    }

    Ok(())
}

/// Copies the existing target's permissions onto `file` (Unix only), then
/// writes `content` to it and syncs it to disk.
fn fill_temporary_file(
    file: &mut File,
    temp_path: &Path,
    path: &Path,
    content: &str,
) -> io::Result<()> {
    // Done before any content is written, so restricted data never sits in a
    // more permissive file. This is a best effort: some filesystems cannot
    // change modes, and that must not make a save fail when it would have
    // succeeded without this step.
    if cfg!(unix) {
        if let Ok(existing) = fs::metadata(path) {
            let _ = file.set_permissions(existing.permissions());
        }
    }

    file.write_all(content.as_bytes()).map_err(|err| {
        io::Error::new(
            err.kind(),
            format!(
                "failed to write temporary config file {} for {}: {err}",
                temp_path.display(),
                path.display()
            ),
        )
    })?;
    file.sync_all().map_err(|err| {
        io::Error::new(
            err.kind(),
            format!(
                "failed to sync temporary config file {} for {}: {err}",
                temp_path.display(),
                path.display()
            ),
        )
    })?;
    Ok(())
}

/// Builds a hidden temporary path next to `path`.
///
/// The name has the form `.{file_name}.{pid}.{nanos}.{seq}.tmp`. A file name
/// that is missing or not valid UTF-8 falls back to `config`.
fn temporary_path_for(path: &Path) -> PathBuf {
    // The clock alone does not make names unique: coarse timers, or two
    // threads in the same tick, can produce identical timestamps. A
    // per-process counter disambiguates every call. Relaxed ordering is
    // enough because only the uniqueness of fetch_add's result matters.
    static SEQUENCE: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
    let sequence = SEQUENCE.fetch_add(1, std::sync::atomic::Ordering::Relaxed);

    let unique = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|duration| duration.as_nanos())
        .unwrap_or(0);
    let file_name = path
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("config");
    let temp_name = format!(
        ".{file_name}.{}.{unique}.{sequence}.tmp",
        std::process::id()
    );

    match path.parent() {
        Some(parent) if !parent.as_os_str().is_empty() => parent.join(temp_name),
        _ => PathBuf::from(temp_name),
    }
}
205
/// Blanks out `//` line comments and `/* */` block comments in JSONC text so
/// the result can be parsed as plain JSON.
///
/// Each comment character is replaced by a single space, and newlines inside
/// comments are kept. Line numbers in later parse errors therefore still match
/// the original text. Comment markers inside string literals are left alone,
/// including a quote escaped with a backslash. An unterminated block comment
/// blanks everything to the end of the input.
fn strip_jsonc_comments(content: &str) -> String {
    #[derive(Clone, Copy)]
    enum Mode {
        Code,
        Str,
        StrEscape,
        Line,
        Block,
    }

    let mut out = String::with_capacity(content.len());
    let mut mode = Mode::Code;
    let mut iter = content.chars().peekable();

    while let Some(c) = iter.next() {
        let lookahead = iter.peek().copied();
        mode = match (mode, c) {
            (Mode::Code, '"') => {
                out.push(c);
                Mode::Str
            }
            (Mode::Code, '/') if lookahead == Some('/') || lookahead == Some('*') => {
                // Blank both characters of the opening marker.
                out.push_str("  ");
                iter.next();
                if lookahead == Some('/') {
                    Mode::Line
                } else {
                    Mode::Block
                }
            }
            (Mode::Code, _) => {
                out.push(c);
                Mode::Code
            }
            (Mode::Str, '\\') => {
                out.push(c);
                Mode::StrEscape
            }
            (Mode::Str, '"') => {
                out.push(c);
                Mode::Code
            }
            // Ordinary string characters, and the single character after a
            // backslash, are copied verbatim and leave us inside the string.
            (Mode::Str, _) | (Mode::StrEscape, _) => {
                out.push(c);
                Mode::Str
            }
            (Mode::Line, '\n') => {
                out.push('\n');
                Mode::Code
            }
            (Mode::Line, _) => {
                out.push(' ');
                Mode::Line
            }
            (Mode::Block, '*') if lookahead == Some('/') => {
                out.push_str("  ");
                iter.next();
                Mode::Code
            }
            (Mode::Block, '\n') => {
                out.push('\n');
                Mode::Block
            }
            (Mode::Block, _) => {
                out.push(' ');
                Mode::Block
            }
        };
    }

    out
}
277
278pub(crate) fn read_config<T>(path: &Path) -> Result<T, io::Error>
279where
280 T: serde::de::DeserializeOwned,
281{
282 let format = FileFormat::from_path(path)?;
283 let content = fs::read_to_string(path).map_err(|err| {
284 io::Error::new(
285 err.kind(),
286 format!("failed to read config {}: {err}", path.display()),
287 )
288 })?;
289
290 match format {
291 FileFormat::Toml => toml::from_str(&content).map_err(|e| {
292 io::Error::new(
293 ErrorKind::InvalidData,
294 format!("failed to parse TOML config {}: {e}", path.display()),
295 )
296 }),
297 FileFormat::Json => serde_json::from_str(&content).map_err(|e| {
298 io::Error::new(
299 ErrorKind::InvalidData,
300 format!("failed to parse JSON config {}: {e}", path.display()),
301 )
302 }),
303 FileFormat::Jsonc => {
304 let json_content = strip_jsonc_comments(&content);
305 serde_json::from_str(&json_content).map_err(|e| {
306 io::Error::new(
307 ErrorKind::InvalidData,
308 format!("failed to parse JSONC config {}: {e}", path.display()),
309 )
310 })
311 }
312 }
313}