1use crate::schema::{Config, ConfigError, LogOutput};
19
/// Line-oriented view of a configuration source, as produced by `Document::parse`.
pub(crate) struct Document {
    /// One entry per source line, in source order, each stored without its
    /// terminating `\n`.
    lines: Vec<Line>,
    /// Whether the parsed source text ended with `\n`.
    trailing_newline: bool,
}
30
/// A single source line together with its classification.
struct Line {
    /// The line text exactly as it appeared in the source, minus the
    /// terminating `\n` (a trailing `\r` is not stripped).
    raw: String,
    /// What the line was recognised as.
    kind: LineKind,
}
35
/// Classification of one source line.
enum LineKind {
    /// The line is empty, whitespace-only, or its first non-whitespace byte is `#`.
    BlankOrComment,
    /// A `[name]` header; carries the trimmed section name.
    Section(String),
    /// A `key = value` line.
    Pair {
        /// Section header in effect when the line was read; `None` when the
        /// pair appears before any header.
        section: Option<String>,
        /// The key identifier.
        key: String,
        /// Byte offset into the line's `raw` text where the value begins.
        value_start: usize,
        /// Byte offset into the line's `raw` text one past the end of the value.
        value_end: usize,
    },
}
46
47impl Document {
48 pub(crate) fn parse(src: &str) -> Result<Self, ConfigError> {
49 let trailing_newline = src.ends_with('\n');
50 let mut lines = Vec::new();
51 let mut current: Option<String> = None;
52 for (idx, raw_with_nl) in src.split_inclusive('\n').enumerate() {
53 let line_no = idx + 1;
54 let raw = raw_with_nl.strip_suffix('\n').unwrap_or(raw_with_nl);
55 let kind = classify_line(raw, ¤t, line_no)?;
56 if let LineKind::Section(name) = &kind {
57 current = Some(name.clone());
58 }
59 lines.push(Line { raw: raw.to_string(), kind });
60 }
61 Ok(Document { lines, trailing_newline })
62 }
63}
64
65impl Config {
66 pub fn to_toml_string_preserving(
73 &self,
74 original_source: &str,
75 ) -> Result<String, ConfigError> {
76 let doc = Document::parse(original_source)?;
77 let pairs = canonical_pairs(self);
78 let last_idx = last_line_per_known_section(&doc);
79 let mut emitted = vec![false; pairs.len()];
80 let mut out = String::with_capacity(original_source.len() + 256);
81 for (i, line) in doc.lines.iter().enumerate() {
82 emit_line(line, &pairs, &mut emitted, &mut out);
83 for (section_name, idx) in &last_idx {
84 if *idx == i {
85 inline_flush_section(section_name, &pairs, &mut emitted, &mut out);
86 }
87 }
88 }
89 append_orphan_sections(&pairs, &mut emitted, &mut out, doc.trailing_newline);
90 Ok(out)
91 }
92}
93
94fn last_line_per_known_section(doc: &Document) -> Vec<(String, usize)> {
99 let mut acc: Vec<(String, usize)> = Vec::new();
100 for (i, line) in doc.lines.iter().enumerate() {
101 let name = match &line.kind {
102 LineKind::Section(n) => Some(n.clone()),
103 LineKind::Pair { section: Some(s), .. } => Some(s.clone()),
104 _ => None,
105 };
106 if let Some(n) = name {
107 if let Some(slot) = acc.iter_mut().find(|(k, _)| *k == n) {
108 slot.1 = i;
109 } else {
110 acc.push((n, i));
111 }
112 }
113 }
114 acc
115}
116
117fn inline_flush_section(
118 section: &str,
119 pairs: &[CanonicalPair],
120 emitted: &mut [bool],
121 out: &mut String,
122) {
123 for (j, p) in pairs.iter().enumerate() {
124 if !emitted[j] && p.section == section {
125 out.push_str(p.key);
126 out.push_str(" = ");
127 out.push_str(&p.value);
128 out.push('\n');
129 emitted[j] = true;
130 }
131 }
132}
133
134fn append_orphan_sections(
135 pairs: &[CanonicalPair],
136 emitted: &mut [bool],
137 out: &mut String,
138 src_had_trailing_newline: bool,
139) {
140 let mut any_appended = false;
141 let mut current_section: Option<&'static str> = None;
142 for (i, p) in pairs.iter().enumerate() {
143 if emitted[i] {
144 continue;
145 }
146 if !any_appended {
147 if !out.is_empty() {
148 if !out.ends_with('\n') {
149 out.push('\n');
150 }
151 if !out.ends_with("\n\n") {
152 out.push('\n');
153 }
154 }
155 any_appended = true;
156 }
157 if current_section != Some(p.section) {
158 if current_section.is_some() {
159 out.push('\n');
160 }
161 out.push('[');
162 out.push_str(p.section);
163 out.push_str("]\n");
164 current_section = Some(p.section);
165 }
166 out.push_str(p.key);
167 out.push_str(" = ");
168 out.push_str(&p.value);
169 out.push('\n');
170 emitted[i] = true;
171 }
172 if !any_appended && !src_had_trailing_newline && out.ends_with('\n') {
173 out.pop();
174 }
175}
176
177fn emit_line(
178 line: &Line,
179 pairs: &[CanonicalPair],
180 emitted: &mut [bool],
181 out: &mut String,
182) {
183 match &line.kind {
184 LineKind::BlankOrComment | LineKind::Section(_) => {
185 out.push_str(&line.raw);
186 out.push('\n');
187 }
188 LineKind::Pair { section, key, value_start, value_end } => {
189 let canonical = pairs.iter().enumerate().find(|(_, p)| {
190 p.section == section.as_deref().unwrap_or("") && p.key == key
191 });
192 match canonical {
193 Some((idx, p)) => {
194 out.push_str(&line.raw[..*value_start]);
195 out.push_str(&p.value);
196 out.push_str(&line.raw[*value_end..]);
197 out.push('\n');
198 emitted[idx] = true;
199 }
200 None => {
201 out.push_str(&line.raw);
206 out.push('\n');
207 }
208 }
209 }
210 }
211}
212
/// One `key = value` entry rendered from a `Config`.
struct CanonicalPair {
    /// Name of the section the key is written under.
    section: &'static str,
    /// The key name.
    key: &'static str,
    /// The value text exactly as it is written after `key = `.
    value: String,
}
220
221fn canonical_pairs(cfg: &Config) -> Vec<CanonicalPair> {
222 let mut v = Vec::with_capacity(22);
223 push_server(&mut v, cfg);
224 push_persistence(&mut v, cfg);
225 push_memory(&mut v, cfg);
226 push_expiry(&mut v, cfg);
227 push_log(&mut v, cfg);
228 push_notification(&mut v, cfg);
229 push_advanced(&mut v, cfg);
230 push_slowlog(&mut v, cfg);
231 v
232}
233
234fn push_server(v: &mut Vec<CanonicalPair>, cfg: &Config) {
235 let [a, b, c, d] = cfg.server.bind;
236 push(v, "server", "bind", format!("\"{a}.{b}.{c}.{d}\""));
237 push(v, "server", "port", cfg.server.port.to_string());
238 push(v, "server", "threads", cfg.server.threads.to_string());
239 push(
240 v,
241 "server",
242 "data_dir",
243 toml_string(&cfg.server.data_dir.display().to_string()),
244 );
245}
246
247fn push_persistence(v: &mut Vec<CanonicalPair>, cfg: &Config) {
248 let p = &cfg.persistence;
249 push(v, "persistence", "aof", p.aof.to_string());
250 push(v, "persistence", "appendfsync", toml_string(p.appendfsync.as_str()));
251 push(
252 v,
253 "persistence",
254 "auto_aof_rewrite_percentage",
255 p.auto_aof_rewrite_percentage.to_string(),
256 );
257 push(
258 v,
259 "persistence",
260 "auto_aof_rewrite_min_size",
261 p.auto_aof_rewrite_min_size.to_string(),
262 );
263}
264
265fn push_memory(v: &mut Vec<CanonicalPair>, cfg: &Config) {
266 push(v, "memory", "maxmemory", cfg.memory.maxmemory.to_string());
267 push(
268 v,
269 "memory",
270 "maxmemory_policy",
271 toml_string(cfg.memory.maxmemory_policy.as_str()),
272 );
273}
274
275fn push_expiry(v: &mut Vec<CanonicalPair>, cfg: &Config) {
276 push(v, "expiry", "hz", cfg.expiry.hz.to_string());
277 push(v, "expiry", "sample", cfg.expiry.sample.to_string());
278}
279
280fn push_log(v: &mut Vec<CanonicalPair>, cfg: &Config) {
281 push(v, "log", "level", toml_string(cfg.log.level.as_str()));
282 push(v, "log", "output", toml_string(&log_output_str(&cfg.log.output)));
283}
284
285fn push_notification(v: &mut Vec<CanonicalPair>, cfg: &Config) {
286 push(
287 v,
288 "notification",
289 "notify_keyspace_events",
290 toml_string(&cfg.notification.notify_keyspace_events),
291 );
292}
293
294fn push_advanced(v: &mut Vec<CanonicalPair>, cfg: &Config) {
295 let a = &cfg.advanced;
296 push(v, "advanced", "spin_limit", a.spin_limit.to_string());
297 push(v, "advanced", "park_timeout_ms", a.park_timeout_ms.to_string());
298 push(v, "advanced", "tick_check_every", a.tick_check_every.to_string());
299 push(v, "advanced", "ring_capacity", a.ring_capacity.to_string());
300}
301
302fn push_slowlog(v: &mut Vec<CanonicalPair>, cfg: &Config) {
303 push(
304 v,
305 "slowlog",
306 "slower_than_micros",
307 cfg.slowlog.slower_than_micros.to_string(),
308 );
309 push(v, "slowlog", "max_len", cfg.slowlog.max_len.to_string());
310}
311
312fn push(v: &mut Vec<CanonicalPair>, section: &'static str, key: &'static str, value: String) {
313 v.push(CanonicalPair { section, key, value });
314}
315
316fn log_output_str(o: &LogOutput) -> String {
317 o.as_str().into_owned()
318}
319
/// Renders `s` as a double-quoted TOML basic string.
///
/// Backslash and double quote are escaped, and so is every control character
/// that may not appear raw in a basic string (U+0000–U+0008, U+000A–U+001F,
/// U+007F). Without this, a newline inside a value would be written verbatim
/// and split the value across several lines of the output. Tab is legal
/// unescaped and is kept as is.
fn toml_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '\\' => out.push_str("\\\\"),
            '"' => out.push_str("\\\""),
            // Control characters with a short escape form.
            '\u{08}' => out.push_str("\\b"),
            '\n' => out.push_str("\\n"),
            '\u{0C}' => out.push_str("\\f"),
            '\r' => out.push_str("\\r"),
            // Remaining forbidden control characters use the `\uXXXX` form.
            '\u{00}'..='\u{07}' | '\u{0B}' | '\u{0E}'..='\u{1F}' | '\u{7F}' => {
                out.push_str(&format!("\\u{:04X}", c as u32));
            }
            other => out.push(other),
        }
    }
    out.push('"');
    out
}
333
334fn classify_line(
337 raw: &str,
338 section_ctx: &Option<String>,
339 line_no: usize,
340) -> Result<LineKind, ConfigError> {
341 let bytes = raw.as_bytes();
342 let Some(i) = first_nonws(bytes) else {
343 return Ok(LineKind::BlankOrComment);
344 };
345 let first = bytes[i];
346 if first == b'#' {
347 return Ok(LineKind::BlankOrComment);
348 }
349 if first == b'[' {
350 return parse_section_line(bytes, i, line_no);
351 }
352 parse_pair_line(bytes, i, section_ctx, line_no)
353}
354
355fn parse_section_line(bytes: &[u8], i: usize, line_no: usize) -> Result<LineKind, ConfigError> {
356 let rest = &bytes[i + 1..];
357 let end = rest
358 .iter()
359 .position(|&b| b == b']')
360 .ok_or_else(|| parse_err(line_no, i + 2, "expected ']' in section header"))?;
361 let name = std::str::from_utf8(&rest[..end])
362 .map_err(|_| parse_err(line_no, i + 2, "section name not UTF-8"))?
363 .trim();
364 if name.is_empty() {
365 return Err(parse_err(line_no, i + 2, "empty section name"));
366 }
367 check_trailing_or_comment(&rest[end + 1..], line_no, i + end + 2)?;
368 Ok(LineKind::Section(name.to_string()))
369}
370
371fn parse_pair_line(
372 bytes: &[u8],
373 key_start: usize,
374 section_ctx: &Option<String>,
375 line_no: usize,
376) -> Result<LineKind, ConfigError> {
377 let mut j = key_start;
378 while j < bytes.len() && is_ident_char(bytes[j]) {
379 j += 1;
380 }
381 if j == key_start {
382 return Err(parse_err(line_no, key_start + 1, "expected key identifier"));
383 }
384 let key = std::str::from_utf8(&bytes[key_start..j])
385 .map_err(|_| parse_err(line_no, key_start + 1, "key not UTF-8"))?
386 .to_string();
387 j = skip_ws(bytes, j);
388 if j >= bytes.len() || bytes[j] != b'=' {
389 return Err(parse_err(line_no, j + 1, "expected '='"));
390 }
391 j += 1;
392 j = skip_ws(bytes, j);
393 let value_start = j;
394 let value_end = scan_value_end(bytes, j, line_no)?;
395 check_trailing_or_comment(&bytes[value_end..], line_no, value_end + 1)?;
396 Ok(LineKind::Pair {
397 section: section_ctx.clone(),
398 key,
399 value_start,
400 value_end,
401 })
402}
403
404fn scan_value_end(bytes: &[u8], start: usize, line_no: usize) -> Result<usize, ConfigError> {
405 if start >= bytes.len() {
406 return Err(parse_err(line_no, start + 1, "expected value"));
407 }
408 let first = bytes[start];
409 if first == b'"' || first == b'\'' {
410 let mut k = start + 1;
411 while k < bytes.len() {
412 let b = bytes[k];
413 if b == first {
414 return Ok(k + 1);
415 }
416 if b == b'\\' && first == b'"' && k + 1 < bytes.len() {
417 k += 2;
418 continue;
419 }
420 k += 1;
421 }
422 return Err(parse_err(line_no, start + 1, "unterminated string"));
423 }
424 let mut k = start;
425 while k < bytes.len() {
426 let b = bytes[k];
427 if b == b' ' || b == b'\t' || b == b'\r' || b == b'#' {
428 break;
429 }
430 k += 1;
431 }
432 Ok(k)
433}
434
435fn check_trailing_or_comment(
436 rest: &[u8],
437 line_no: usize,
438 col_base: usize,
439) -> Result<(), ConfigError> {
440 let mut k = 0;
441 while k < rest.len() {
442 let b = rest[k];
443 if b == b' ' || b == b'\t' || b == b'\r' {
444 k += 1;
445 continue;
446 }
447 if b == b'#' {
448 return Ok(());
449 }
450 return Err(parse_err(
451 line_no,
452 col_base + k,
453 format!("unexpected trailing content {:?}", b as char),
454 ));
455 }
456 Ok(())
457}
458
/// Returns the index of the first byte that is not a space, tab or carriage
/// return, or `None` when no such byte exists.
fn first_nonws(bytes: &[u8]) -> Option<usize> {
    for (idx, byte) in bytes.iter().enumerate() {
        if !matches!(*byte, b' ' | b'\t' | b'\r') {
            return Some(idx);
        }
    }
    None
}
464
/// Advances `k` past any run of spaces and tabs and returns the new index.
///
/// Carriage returns are not skipped. An index at or beyond the end of
/// `bytes` is returned unchanged.
fn skip_ws(bytes: &[u8], mut k: usize) -> usize {
    while matches!(bytes.get(k), Some(b' ' | b'\t')) {
        k += 1;
    }
    k
}
471
/// Reports whether `b` may appear in a key: an ASCII letter, an ASCII digit,
/// `_` or `-`.
fn is_ident_char(b: u8) -> bool {
    matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'-')
}
475
476fn parse_err(line: usize, col: usize, msg: impl Into<String>) -> ConfigError {
477 ConfigError::Parse {
478 line,
479 col,
480 msg: msg.into(),
481 }
482}
483