1use std::collections::BTreeMap;
2
3use super::Data;
4use super::Inline;
5use super::Position;
6
7pub(crate) fn get() -> std::sync::MutexGuard<'static, Runtime> {
8 static RT: std::sync::Mutex<Runtime> = std::sync::Mutex::new(Runtime::new());
9 RT.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
10}
11
12#[derive(Default)]
13pub(crate) struct Runtime {
14 per_file: Vec<SourceFileRuntime>,
15 path_count: Vec<PathRuntime>,
16}
17
18impl Runtime {
19 const fn new() -> Self {
20 Self {
21 per_file: Vec::new(),
22 path_count: Vec::new(),
23 }
24 }
25
26 pub(crate) fn count(&mut self, path_prefix: &str) -> usize {
27 if let Some(entry) = self
28 .path_count
29 .iter_mut()
30 .find(|entry| entry.is(path_prefix))
31 {
32 entry.next()
33 } else {
34 let entry = PathRuntime::new(path_prefix);
35 let next = entry.count();
36 self.path_count.push(entry);
37 next
38 }
39 }
40
41 pub(crate) fn write(&mut self, actual: &Data, inline: &Inline) -> std::io::Result<()> {
42 let actual = actual.render().expect("`actual` must be UTF-8");
43 if let Some(entry) = self
44 .per_file
45 .iter_mut()
46 .find(|f| f.path == inline.position.file)
47 {
48 entry.update(&actual, inline)?;
49 } else {
50 let mut entry = SourceFileRuntime::new(inline)?;
51 entry.update(&actual, inline)?;
52 self.per_file.push(entry);
53 }
54
55 Ok(())
56 }
57}
58
59struct SourceFileRuntime {
60 path: std::path::PathBuf,
61 original_text: String,
62 patchwork: Patchwork,
63}
64
65impl SourceFileRuntime {
66 fn new(inline: &Inline) -> std::io::Result<SourceFileRuntime> {
67 let path = inline.position.file.clone();
68 let original_text = std::fs::read_to_string(&path)?;
69 let patchwork = Patchwork::new(original_text.clone());
70 Ok(SourceFileRuntime {
71 path,
72 original_text,
73 patchwork,
74 })
75 }
76 fn update(&mut self, actual: &str, inline: &Inline) -> std::io::Result<()> {
77 let span = Span::from_pos(&inline.position, &self.original_text);
78 let patch = format_patch(actual);
79 self.patchwork.patch(span.literal_range, &patch)?;
80 std::fs::write(&inline.position.file, &self.patchwork.text)
81 }
82}
83
/// Accumulates non-overlapping replacements against an original text.
///
/// Every patch is addressed by a byte range into the *original* text; [`Patchwork::patch`]
/// translates it into the current `text` by accounting for the length changes made by the
/// patches that precede it.
#[derive(Debug)]
struct Patchwork {
    /// The original text with every accepted patch applied.
    text: String,
    /// Accepted patches, keyed by their range in the original text, mapping to the byte length
    /// of the replacement and the replacement itself.
    indels: BTreeMap<OrdRange, (usize, String)>,
}

impl Patchwork {
    /// Starts a patchwork over `text` with no patches applied.
    fn new(text: String) -> Patchwork {
        Patchwork {
            text,
            indels: BTreeMap::new(),
        }
    }

    /// Replaces `range` (byte offsets into the original text) with `patch`.
    ///
    /// Re-applying the exact same patch to the same range is a no-op.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `text` untouched, when `range` was already patched with
    /// different content or when it overlaps a different, previously patched range.
    ///
    /// # Panics
    ///
    /// Panics (via [`String::replace_range`]) if the translated range is out of bounds or does
    /// not fall on `char` boundaries of `text`.
    fn patch(&mut self, range: std::ops::Range<usize>, patch: &str) -> std::io::Result<()> {
        let key = OrdRange::from(range.clone());

        if let Some((_, existing)) = self.indels.get(&key) {
            return if existing == patch {
                Ok(())
            } else {
                Err(std::io::Error::other(
                    "cannot update as it was already modified",
                ))
            };
        }

        // Offsets inside an already replaced range no longer exist in `text`; translating
        // through them would corrupt the text (or underflow), so refuse overlaps outright.
        let overlaps_prev = self
            .indels
            .range(..key)
            .next_back()
            .is_some_and(|(prev, _)| prev.end > key.start);
        let overlaps_next = self
            .indels
            .range(key..)
            .next()
            .is_some_and(|(next, _)| next.start < key.end);
        if overlaps_prev || overlaps_next {
            return Err(std::io::Error::other(
                "cannot update as it overlaps an already modified range",
            ));
        }

        // With no overlaps, every patch ordered before `key` lies entirely at or before
        // `key.start` -- including an empty insertion at the same offset -- so its net length
        // change shifts this range. That also guarantees `deleted <= range.start`.
        let (deleted, inserted) = self.indels.range(..key).fold(
            (0usize, 0usize),
            |(deleted, inserted), (prev, (len, _))| {
                (deleted + (prev.end - prev.start), inserted + len)
            },
        );
        let translate = |offset: usize| offset - deleted + inserted;

        self.text
            .replace_range(translate(range.start)..translate(range.end), patch);
        // Record the patch only once it has actually been applied.
        self.indels.insert(key, (patch.len(), patch.to_owned()));
        Ok(())
    }
}

/// A `Range<usize>` that can be ordered (by `start`, then `end`) and used as a map key.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct OrdRange {
    start: usize,
    end: usize,
}

impl From<std::ops::Range<usize>> for OrdRange {
    fn from(other: std::ops::Range<usize>) -> Self {
        Self {
            start: other.start,
            end: other.end,
        }
    }
}
145
146fn lit_kind_for_patch(patch: &str) -> StrLitKind {
147 let has_dquote = patch.chars().any(|c| c == '"');
148 if !has_dquote {
149 let has_bslash_or_newline = patch.chars().any(|c| matches!(c, '\\' | '\n'));
150 return if has_bslash_or_newline {
151 StrLitKind::Raw(1)
152 } else {
153 StrLitKind::Normal
154 };
155 }
156
157 let leading_hashes = |s: &str| s.chars().take_while(|&c| c == '#').count();
160 let max_hashes = patch.split('"').map(leading_hashes).max().unwrap();
161 StrLitKind::Raw(max_hashes + 1)
162}
163
164fn format_patch(patch: &str) -> String {
165 let lit_kind = lit_kind_for_patch(patch);
166 let is_multiline = patch.contains('\n');
167
168 let mut buf = String::new();
169 if matches!(lit_kind, StrLitKind::Raw(_)) {
170 buf.push('[');
171 }
172 lit_kind.write_start(&mut buf).unwrap();
173 if is_multiline {
174 buf.push('\n');
175 }
176 buf.push_str(patch);
177 if is_multiline {
178 buf.push('\n');
179 }
180 lit_kind.write_end(&mut buf).unwrap();
181 if matches!(lit_kind, StrLitKind::Raw(_)) {
182 buf.push(']');
183 }
184 buf
185}
186
187#[derive(Clone, Debug)]
188struct Span {
189 literal_range: std::ops::Range<usize>,
191}
192
193impl Span {
194 fn from_pos(pos: &Position, file: &str) -> Span {
195 let mut target_line = None;
196 let mut line_start = 0;
197 for (i, line) in crate::utils::LinesWithTerminator::new(file).enumerate() {
198 if i == pos.line as usize - 1 {
199 #[allow(clippy::skip_while_next)]
208 let byte_offset = line
209 .char_indices()
210 .skip((pos.column - 1).try_into().unwrap())
211 .skip_while(|&(_, c)| c != '!')
212 .skip(1) .skip_while(|&(_, c)| c.is_whitespace())
214 .skip(1) .skip_while(|&(_, c)| c.is_whitespace())
216 .next()
217 .expect("Failed to parse macro invocation")
218 .0;
219
220 let literal_start = line_start + byte_offset;
221 target_line = Some(literal_start);
222 break;
223 }
224 line_start += line.len();
225 }
226 let literal_start = target_line.unwrap();
227
228 let lit_to_eof = &file[literal_start..];
229 let lit_to_eof_trimmed = lit_to_eof.trim_start();
230
231 let literal_start = literal_start + (lit_to_eof.len() - lit_to_eof_trimmed.len());
232
233 let literal_len =
234 locate_end(lit_to_eof_trimmed).expect("Couldn't find closing delimiter for `expect!`.");
235 let literal_range = literal_start..literal_start + literal_len;
236 Span { literal_range }
237 }
238}
239
240fn locate_end(arg_start_to_eof: &str) -> Option<usize> {
241 match arg_start_to_eof.chars().next()? {
242 c if c.is_whitespace() => panic!("skip whitespace before calling `locate_end`"),
243
244 '[' => {
246 let str_start_to_eof = arg_start_to_eof[1..].trim_start();
247 let str_len = find_str_lit_len(str_start_to_eof)?;
248 let str_end_to_eof = &str_start_to_eof[str_len..];
249 let closing_brace_offset = str_end_to_eof.find(']')?;
250 Some((arg_start_to_eof.len() - str_end_to_eof.len()) + closing_brace_offset + 1)
251 }
252
253 ']' | '}' | ')' => Some(0),
255
256 _ => find_str_lit_len(arg_start_to_eof),
258 }
259}
260
/// Returns the byte length of the string literal at the start of `str_lit_to_eof`.
///
/// Plain `"..."` literals (with `\` escaping the next character) and raw `r"..."`,
/// `r#"..."#`, ... literals are recognized. Returns `None` if the input does not start with
/// such a literal or the literal is not terminated.
fn find_str_lit_len(str_lit_to_eof: &str) -> Option<usize> {
    /// Reads consecutive `#` characters from `chars`, stopping as soon as `limit` of them have
    /// been read.
    ///
    /// Returns the number of `#` read, paired with the first non-`#` character consumed, or
    /// with `None` if the limit was reached first. Returns `None` overall if `chars` runs out.
    fn read_hashes(
        chars: &mut impl Iterator<Item = char>,
        limit: usize,
    ) -> Option<(usize, Option<char>)> {
        let mut count = 0;
        loop {
            let ch = chars.next()?;
            if ch != '#' {
                return Some((count, Some(ch)));
            }
            count += 1;
            if count == limit {
                return Some((count, None));
            }
        }
    }

    let mut chars = str_lit_to_eof.chars();
    let kind = match chars.next()? {
        '"' => StrLitKind::Normal,
        'r' => match read_hashes(&mut chars, usize::MAX)? {
            (hashes, Some('"')) => StrLitKind::Raw(hashes),
            _ => return None,
        },
        _ => return None,
    };

    // A character read while looking for closing hashes that must be examined again.
    let mut pending: Option<char> = None;
    loop {
        let ch = match pending.take() {
            Some(ch) => ch,
            None => chars.next()?,
        };
        if ch != '"' {
            if ch == '\\' && matches!(kind, StrLitKind::Normal) {
                // Skip the escaped character, which may itself be a `"`.
                chars.next()?;
            }
            continue;
        }
        match kind {
            StrLitKind::Normal | StrLitKind::Raw(0) => break,
            StrLitKind::Raw(hashes) => {
                let (seen, next) = read_hashes(&mut chars, hashes)?;
                if seen == hashes {
                    break;
                }
                // Too few hashes: this `"` was content; re-examine what stopped the scan.
                pending = next;
            }
        }
    }

    Some(str_lit_to_eof.len() - chars.as_str().len())
}

/// The flavor of a Rust string literal.
#[derive(Copy, Clone)]
enum StrLitKind {
    /// `"..."`, with backslash escapes.
    Normal,
    /// `r"..."` with the given number of `#` on each side.
    Raw(usize),
}

impl StrLitKind {
    /// Writes the opening delimiter: `"`, or `r`, the hashes and `"`.
    fn write_start(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        if let Self::Raw(hashes) = self {
            w.write_char('r')?;
            (0..hashes).try_for_each(|_| w.write_char('#'))?;
        }
        w.write_char('"')
    }

    /// Writes the closing delimiter: `"`, followed by the hashes for a raw literal.
    fn write_end(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        w.write_char('"')?;
        match self {
            Self::Normal => Ok(()),
            Self::Raw(hashes) => (0..hashes).try_for_each(|_| w.write_char('#')),
        }
    }
}
350
/// Call counter for a single path prefix, as used by `Runtime::count`.
#[derive(Clone)]
struct PathRuntime {
    path_prefix: String,
    count: usize,
}

impl PathRuntime {
    /// Starts a counter for `path_prefix` at zero.
    fn new(path_prefix: &str) -> Self {
        let path_prefix = String::from(path_prefix);
        Self {
            count: 0,
            path_prefix,
        }
    }

    /// Whether this counter belongs to exactly `path_prefix`.
    fn is(&self, path_prefix: &str) -> bool {
        path_prefix == self.path_prefix
    }

    /// Increments the counter and returns its new value.
    fn next(&mut self) -> usize {
        let bumped = self.count + 1;
        self.count = bumped;
        bumped
    }

    /// Current value of the counter.
    fn count(&self) -> usize {
        self.count
    }
}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::assert_data_eq;
    use crate::prelude::*;
    use crate::str;

    /// `format_patch` picks the literal kind and layout for representative payloads.
    #[test]
    fn test_format_patch() {
        // Multi-line text becomes a bracketed raw literal with the content on its own lines.
        let patch = format_patch("hello\nworld\n");

        assert_data_eq!(
            patch,
            str![[r##"
[r#"
hello
world

"#]
"##]],
        );

        // A backslash also selects a raw literal, so it is emitted without escaping.
        let patch = format_patch(r"hello\tworld");
        assert_data_eq!(patch, str![[r##"[r#"hello\tworld"#]"##]].raw());

        // Double quotes select a raw literal with one more `#` than any run following a `"`.
        let patch = format_patch("{\"foo\": 42}");
        assert_data_eq!(patch, str![[r##"[r#"{"foo": 42}"#]"##]]);
    }

    /// Patches given in original-text offsets apply correctly in any order, including
    /// replacements whose byte length differs from the original (`"один"` is 8 bytes).
    #[test]
    fn test_patchwork() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(0..3, "один").unwrap();
        patchwork.patch(8..13, "3").unwrap();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "один zwei 3",
    indels: {
        OrdRange {
            start: 0,
            end: 3,
        }: (
            8,
            "один",
        ),
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
        OrdRange {
            start: 8,
            end: 13,
        }: (
            1,
            "3",
        ),
    },
}

"#]],
        );
    }

    /// Re-patching a range with different content is rejected and leaves the text unchanged.
    #[test]
    fn test_patchwork_overlap_diverge() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(4..7, "abcd").unwrap_err();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "one zwei three",
    indels: {
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
    },
}

"#]],
        );
    }

    /// Re-patching a range with identical content is accepted as a no-op.
    #[test]
    fn test_patchwork_overlap_converge() {
        let mut patchwork = Patchwork::new("one two three".to_owned());
        patchwork.patch(4..7, "zwei").unwrap();
        patchwork.patch(4..7, "zwei").unwrap();
        assert_data_eq!(
            patchwork.to_debug(),
            str![[r#"
Patchwork {
    text: "one zwei three",
    indels: {
        OrdRange {
            start: 4,
            end: 7,
        }: (
            4,
            "zwei",
        ),
    },
}

"#]],
        );
    }

    /// `locate_end` measures exactly the bracketed literal, ignoring the trailing `]]`.
    #[test]
    fn test_locate() {
        macro_rules! check_locate {
            ($( [[$s:literal]] ),* $(,)?) => {$({
                let lit = stringify!($s);
                let with_trailer = format!("{} \t]]\n", lit);
                assert_eq!(locate_end(&with_trailer), Some(lit.len()));
            })*};
        }

        // Literals that themselves contain `]]` must not end the argument early.
        check_locate!(
            [[r#"{ arr: [[1, 2], [3, 4]], other: "foo" } "#]],
            [["]]"]],
            [["\"]]"]],
            [[r#""]]"#]],
        );

        // An argument-less invocation (`str![]`) has a zero-length argument.
        assert_eq!(locate_end("]]"), Some(0));
    }

    /// `find_str_lit_len` measures whole literals, including escapes and embedded quotes/hashes.
    #[test]
    fn test_find_str_lit_len() {
        macro_rules! check_str_lit_len {
            ($( $s:literal ),* $(,)?) => {$({
                let lit = stringify!($s);
                assert_eq!(find_str_lit_len(lit), Some(lit.len()));
            })*}
        }

        check_str_lit_len![
            r##"foa\""#"##,
            r##"

                asdf][]]""""#
            "##,
            "",
            "\"",
            "\"\"",
            "#\"#\"#",
        ];
    }
}
541}