Skip to main content

snapbox/data/
runtime.rs

1use std::collections::BTreeMap;
2
3use super::Data;
4use super::Inline;
5use super::Position;
6
7pub(crate) fn get() -> std::sync::MutexGuard<'static, Runtime> {
8    static RT: std::sync::Mutex<Runtime> = std::sync::Mutex::new(Runtime::new());
9    RT.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
10}
11
12#[derive(Default)]
13pub(crate) struct Runtime {
14    per_file: Vec<SourceFileRuntime>,
15    path_count: Vec<PathRuntime>,
16}
17
18impl Runtime {
19    const fn new() -> Self {
20        Self {
21            per_file: Vec::new(),
22            path_count: Vec::new(),
23        }
24    }
25
26    pub(crate) fn count(&mut self, path_prefix: &str) -> usize {
27        if let Some(entry) = self
28            .path_count
29            .iter_mut()
30            .find(|entry| entry.is(path_prefix))
31        {
32            entry.next()
33        } else {
34            let entry = PathRuntime::new(path_prefix);
35            let next = entry.count();
36            self.path_count.push(entry);
37            next
38        }
39    }
40
41    pub(crate) fn write(&mut self, actual: &Data, inline: &Inline) -> std::io::Result<()> {
42        let actual = actual.render().expect("`actual` must be UTF-8");
43        if let Some(entry) = self
44            .per_file
45            .iter_mut()
46            .find(|f| f.path == inline.position.file)
47        {
48            entry.update(&actual, inline)?;
49        } else {
50            let mut entry = SourceFileRuntime::new(inline)?;
51            entry.update(&actual, inline)?;
52            self.per_file.push(entry);
53        }
54
55        Ok(())
56    }
57}
58
59struct SourceFileRuntime {
60    path: std::path::PathBuf,
61    original_text: String,
62    patchwork: Patchwork,
63}
64
65impl SourceFileRuntime {
66    fn new(inline: &Inline) -> std::io::Result<SourceFileRuntime> {
67        let path = inline.position.file.clone();
68        let original_text = std::fs::read_to_string(&path)?;
69        let patchwork = Patchwork::new(original_text.clone());
70        Ok(SourceFileRuntime {
71            path,
72            original_text,
73            patchwork,
74        })
75    }
76    fn update(&mut self, actual: &str, inline: &Inline) -> std::io::Result<()> {
77        let span = Span::from_pos(&inline.position, &self.original_text);
78        let patch = format_patch(actual);
79        self.patchwork.patch(span.literal_range, &patch)?;
80        std::fs::write(&inline.position.file, &self.patchwork.text)
81    }
82}
83
84#[derive(Debug)]
85struct Patchwork {
86    text: String,
87    indels: BTreeMap<OrdRange, (usize, String)>,
88}
89
90impl Patchwork {
91    fn new(text: String) -> Patchwork {
92        Patchwork {
93            text,
94            indels: BTreeMap::new(),
95        }
96    }
97    fn patch(&mut self, mut range: std::ops::Range<usize>, patch: &str) -> std::io::Result<()> {
98        let key: OrdRange = range.clone().into();
99        match self.indels.entry(key) {
100            std::collections::btree_map::Entry::Vacant(entry) => {
101                entry.insert((patch.len(), patch.to_owned()));
102            }
103            std::collections::btree_map::Entry::Occupied(entry) => {
104                if entry.get().1 == patch {
105                    return Ok(());
106                } else {
107                    return Err(std::io::Error::other(
108                        "cannot update as it was already modified",
109                    ));
110                }
111            }
112        }
113
114        let (delete, insert) = self
115            .indels
116            .iter()
117            .take_while(|(delete, _)| delete.start < range.start)
118            .map(|(delete, (insert, _))| (delete.end - delete.start, insert))
119            .fold((0usize, 0usize), |(x1, y1), (x2, y2)| (x1 + x2, y1 + y2));
120
121        for pos in &mut [&mut range.start, &mut range.end] {
122            **pos -= delete;
123            **pos += insert;
124        }
125
126        self.text.replace_range(range, patch);
127        Ok(())
128    }
129}
130
/// A half-open byte range that can be used as an ordered map key.
///
/// Ordering is derived field by field: first by `start`, then by `end`.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
struct OrdRange {
    start: usize,
    end: usize,
}

impl From<std::ops::Range<usize>> for OrdRange {
    /// Copies the bounds of a standard range.
    fn from(other: std::ops::Range<usize>) -> Self {
        let std::ops::Range { start, end } = other;
        Self { start, end }
    }
}
145
146fn lit_kind_for_patch(patch: &str) -> StrLitKind {
147    let has_dquote = patch.chars().any(|c| c == '"');
148    if !has_dquote {
149        let has_bslash_or_newline = patch.chars().any(|c| matches!(c, '\\' | '\n'));
150        return if has_bslash_or_newline {
151            StrLitKind::Raw(1)
152        } else {
153            StrLitKind::Normal
154        };
155    }
156
157    // Find the maximum number of hashes that follow a double quote in the string.
158    // We need to use one more than that to delimit the string.
159    let leading_hashes = |s: &str| s.chars().take_while(|&c| c == '#').count();
160    let max_hashes = patch.split('"').map(leading_hashes).max().unwrap();
161    StrLitKind::Raw(max_hashes + 1)
162}
163
164fn format_patch(patch: &str) -> String {
165    let lit_kind = lit_kind_for_patch(patch);
166    let is_multiline = patch.contains('\n');
167
168    let mut buf = String::new();
169    if matches!(lit_kind, StrLitKind::Raw(_)) {
170        buf.push('[');
171    }
172    lit_kind.write_start(&mut buf).unwrap();
173    if is_multiline {
174        buf.push('\n');
175    }
176    buf.push_str(patch);
177    if is_multiline {
178        buf.push('\n');
179    }
180    lit_kind.write_end(&mut buf).unwrap();
181    if matches!(lit_kind, StrLitKind::Raw(_)) {
182        buf.push(']');
183    }
184    buf
185}
186
187#[derive(Clone, Debug)]
188struct Span {
189    /// The byte range of the argument to `expect!`, including the inner `[]` if it exists.
190    literal_range: std::ops::Range<usize>,
191}
192
193impl Span {
194    fn from_pos(pos: &Position, file: &str) -> Span {
195        let mut target_line = None;
196        let mut line_start = 0;
197        for (i, line) in crate::utils::LinesWithTerminator::new(file).enumerate() {
198            if i == pos.line as usize - 1 {
199                // `column` points to the first character of the macro invocation:
200                //
201                //    expect![[r#""#]]        expect![""]
202                //    ^       ^               ^       ^
203                //  column   offset                 offset
204                //
205                // Seek past the exclam, then skip any whitespace and
206                // the macro delimiter to get to our argument.
207                #[allow(clippy::skip_while_next)]
208                let byte_offset = line
209                    .char_indices()
210                    .skip((pos.column - 1).try_into().unwrap())
211                    .skip_while(|&(_, c)| c != '!')
212                    .skip(1) // !
213                    .skip_while(|&(_, c)| c.is_whitespace())
214                    .skip(1) // [({
215                    .skip_while(|&(_, c)| c.is_whitespace())
216                    .next()
217                    .expect("Failed to parse macro invocation")
218                    .0;
219
220                let literal_start = line_start + byte_offset;
221                target_line = Some(literal_start);
222                break;
223            }
224            line_start += line.len();
225        }
226        let literal_start = target_line.unwrap();
227
228        let lit_to_eof = &file[literal_start..];
229        let lit_to_eof_trimmed = lit_to_eof.trim_start();
230
231        let literal_start = literal_start + (lit_to_eof.len() - lit_to_eof_trimmed.len());
232
233        let literal_len =
234            locate_end(lit_to_eof_trimmed).expect("Couldn't find closing delimiter for `expect!`.");
235        let literal_range = literal_start..literal_start + literal_len;
236        Span { literal_range }
237    }
238}
239
240fn locate_end(arg_start_to_eof: &str) -> Option<usize> {
241    match arg_start_to_eof.chars().next()? {
242        c if c.is_whitespace() => panic!("skip whitespace before calling `locate_end`"),
243
244        // expect![[]]
245        '[' => {
246            let str_start_to_eof = arg_start_to_eof[1..].trim_start();
247            let str_len = find_str_lit_len(str_start_to_eof)?;
248            let str_end_to_eof = &str_start_to_eof[str_len..];
249            let closing_brace_offset = str_end_to_eof.find(']')?;
250            Some((arg_start_to_eof.len() - str_end_to_eof.len()) + closing_brace_offset + 1)
251        }
252
253        // expect![] | expect!{} | expect!()
254        ']' | '}' | ')' => Some(0),
255
256        // expect!["..."] | expect![r#"..."#]
257        _ => find_str_lit_len(arg_start_to_eof),
258    }
259}
260
261/// Parses a string literal, returning the byte index of its last character
262/// (either a quote or a hash).
263fn find_str_lit_len(str_lit_to_eof: &str) -> Option<usize> {
264    fn try_find_n_hashes(
265        s: &mut impl Iterator<Item = char>,
266        desired_hashes: usize,
267    ) -> Option<(usize, Option<char>)> {
268        let mut n = 0;
269        loop {
270            match s.next()? {
271                '#' => n += 1,
272                c => return Some((n, Some(c))),
273            }
274
275            if n == desired_hashes {
276                return Some((n, None));
277            }
278        }
279    }
280
281    let mut s = str_lit_to_eof.chars();
282    let kind = match s.next()? {
283        '"' => StrLitKind::Normal,
284        'r' => {
285            let (n, c) = try_find_n_hashes(&mut s, usize::MAX)?;
286            if c != Some('"') {
287                return None;
288            }
289            StrLitKind::Raw(n)
290        }
291        _ => return None,
292    };
293
294    let mut oldc = None;
295    loop {
296        let c = oldc.take().or_else(|| s.next())?;
297        match (c, kind) {
298            ('\\', StrLitKind::Normal) => {
299                let _escaped = s.next()?;
300            }
301            ('"', StrLitKind::Normal) => break,
302            ('"', StrLitKind::Raw(0)) => break,
303            ('"', StrLitKind::Raw(n)) => {
304                let (seen, c) = try_find_n_hashes(&mut s, n)?;
305                if seen == n {
306                    break;
307                }
308                oldc = c;
309            }
310            _ => {}
311        }
312    }
313
314    Some(str_lit_to_eof.len() - s.as_str().len())
315}
316
/// Flavour of a Rust string literal.
#[derive(Copy, Clone)]
enum StrLitKind {
    /// `"..."`
    Normal,
    /// `r"..."` surrounded by the given number of `#`.
    Raw(usize),
}

impl StrLitKind {
    /// Writes the opening delimiter: `"` for normal literals, `r`, the hashes
    /// and `"` for raw ones.
    fn write_start(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        if let Self::Raw(hashes) = self {
            w.write_str("r")?;
            for _ in 0..hashes {
                w.write_str("#")?;
            }
        }
        w.write_str("\"")
    }

    /// Writes the closing delimiter: `"` followed by the hashes of a raw literal.
    fn write_end(self, w: &mut impl std::fmt::Write) -> std::fmt::Result {
        w.write_str("\"")?;
        if let Self::Raw(hashes) = self {
            for _ in 0..hashes {
                w.write_str("#")?;
            }
        }
        Ok(())
    }
}
350
/// Counter attached to one path prefix.
#[derive(Clone)]
struct PathRuntime {
    /// The prefix this counter belongs to.
    path_prefix: String,
    /// Value most recently handed out; starts at zero.
    count: usize,
}

impl PathRuntime {
    /// Creates a counter for `path_prefix` with a count of zero.
    fn new(path_prefix: &str) -> Self {
        Self {
            path_prefix: String::from(path_prefix),
            count: 0,
        }
    }

    /// Whether this counter belongs to exactly `path_prefix`.
    fn is(&self, path_prefix: &str) -> bool {
        self.path_prefix.as_str() == path_prefix
    }

    /// Increments the counter and returns the new value.
    fn next(&mut self) -> usize {
        let bumped = self.count + 1;
        self.count = bumped;
        bumped
    }

    /// Returns the current value without changing it.
    fn count(&self) -> usize {
        self.count
    }
}
378
// Unit tests for the snapshot-rewriting helpers in this file: patch formatting,
// range bookkeeping in `Patchwork`, and literal location/measurement.
379#[cfg(test)]
380mod tests {
381    use super::*;
382    use crate::assert_data_eq;
383    use crate::prelude::*;
384    use crate::str;
385
    // `format_patch` output for three inputs: multi-line text, text with a
    // backslash, and text with double quotes. All three expected snapshots
    // are raw literals wrapped in square brackets.
386    #[test]
387    fn test_format_patch() {
388        let patch = format_patch("hello\nworld\n");
389
390        assert_data_eq!(
391            patch,
392            str![[r##"
393[r#"
394hello
395world
396
397"#]
398"##]],
399        );
400
401        let patch = format_patch(r"hello\tworld");
402        assert_data_eq!(patch, str![[r##"[r#"hello\tworld"#]"##]].raw());
403
404        let patch = format_patch("{\"foo\": 42}");
405        assert_data_eq!(patch, str![[r##"[r#"{"foo": 42}"#]"##]]);
406    }
407
    // Patches are applied out of order and always addressed in the coordinates
    // of the original text ("one two three"); the expected debug output shows
    // the final text and the recorded (inserted length, inserted text) pairs.
408    #[test]
409    fn test_patchwork() {
410        let mut patchwork = Patchwork::new("one two three".to_owned());
411        patchwork.patch(4..7, "zwei").unwrap();
412        patchwork.patch(0..3, "один").unwrap();
413        patchwork.patch(8..13, "3").unwrap();
414        assert_data_eq!(
415            patchwork.to_debug(),
416            str![[r#"
417Patchwork {
418    text: "один zwei 3",
419    indels: {
420        OrdRange {
421            start: 0,
422            end: 3,
423        }: (
424            8,
425            "один",
426        ),
427        OrdRange {
428            start: 4,
429            end: 7,
430        }: (
431            4,
432            "zwei",
433        ),
434        OrdRange {
435            start: 8,
436            end: 13,
437        }: (
438            1,
439            "3",
440        ),
441    },
442}
443
444"#]],
445        );
446    }
447
    // Patching the same range a second time with different text must fail and
    // leave the first patch in place.
448    #[test]
449    fn test_patchwork_overlap_diverge() {
450        let mut patchwork = Patchwork::new("one two three".to_owned());
451        patchwork.patch(4..7, "zwei").unwrap();
452        patchwork.patch(4..7, "abcd").unwrap_err();
453        assert_data_eq!(
454            patchwork.to_debug(),
455            str![[r#"
456Patchwork {
457    text: "one zwei three",
458    indels: {
459        OrdRange {
460            start: 4,
461            end: 7,
462        }: (
463            4,
464            "zwei",
465        ),
466    },
467}
468
469"#]],
470        );
471    }
472
    // Patching the same range a second time with identical text succeeds and
    // changes nothing.
473    #[test]
474    fn test_patchwork_overlap_converge() {
475        let mut patchwork = Patchwork::new("one two three".to_owned());
476        patchwork.patch(4..7, "zwei").unwrap();
477        patchwork.patch(4..7, "zwei").unwrap();
478        assert_data_eq!(
479            patchwork.to_debug(),
480            str![[r#"
481Patchwork {
482    text: "one zwei three",
483    indels: {
484        OrdRange {
485            start: 4,
486            end: 7,
487        }: (
488            4,
489            "zwei",
490        ),
491    },
492}
493
494"#]],
495        );
496    }
497
    // `locate_end` must report exactly the literal's length even when the
    // literal itself contains `]]` and is followed by whitespace and `]]`.
498    #[test]
499    fn test_locate() {
        // For each literal: take its source text via `stringify!`, append a
        // ` \t]]\n` trailer, and expect `locate_end` to stop at the literal's end.
500        macro_rules! check_locate {
501            ($( [[$s:literal]] ),* $(,)?) => {$({
502                let lit = stringify!($s);
503                let with_trailer = format!("{} \t]]\n", lit);
504                assert_eq!(locate_end(&with_trailer), Some(lit.len()));
505            })*};
506        }
507
508        // Check that we handle string literals containing "]]" correctly.
509        check_locate!(
510            [[r#"{ arr: [[1, 2], [3, 4]], other: "foo" } "#]],
511            [["]]"]],
512            [["\"]]"]],
513            [[r#""]]"#]],
514        );
515
516        // Check `str![[  ]]` as well.
517        assert_eq!(locate_end("]]"), Some(0));
518    }
519
    // `find_str_lit_len` must consume the whole literal: the expected length is
    // the length of the literal's own source text.
520    #[test]
521    fn test_find_str_lit_len() {
        // `stringify!` yields the literal's source text, delimiters included.
522        macro_rules! check_str_lit_len {
523            ($( $s:literal ),* $(,)?) => {$({
524                let lit = stringify!($s);
525                assert_eq!(find_str_lit_len(lit), Some(lit.len()));
526            })*}
527        }
528
529        check_str_lit_len![
530            r##"foa\""#"##,
531            r##"
532
533                asdf][]]""""#
534            "##,
535            "",
536            "\"",
537            "\"\"",
538            "#\"#\"#",
539        ];
540    }
541}
537            "\"\"",
538            "#\"#\"#",
539        ];
540    }
541}