1use std::fs;
2use std::io;
3use std::path::Path;
4
5use thiserror::Error;
6use toml::de::Error as TomlDeError;
7use toml::ser::Error as TomlSerError;
8
9use super::Trace;
10
/// Errors returned by [`Trace::read_from_file`] and [`Trace::write_to_file`].
///
/// Every variant is `#[error(transparent)]`, so its `Display` output and
/// `source()` are forwarded from the wrapped error unchanged.
///
/// NOTE(review): no variant records which path was being accessed; callers that
/// need that context in messages must attach it themselves.
#[derive(Debug, Error)]
pub enum TraceIoError {
    /// Reading or writing the trace file failed at the filesystem level.
    #[error(transparent)]
    Io(#[from] io::Error),

    /// The file was read but its contents could not be deserialized from TOML
    /// into a [`Trace`].
    #[error(transparent)]
    Parse(#[from] TomlDeError),

    /// The [`Trace`] could not be serialized to a TOML string.
    #[error(transparent)]
    Serialize(#[from] TomlSerError),
}
22
23impl Trace {
24 pub fn read_from_file(path: &Path) -> Result<Self, TraceIoError> {
25 let content = fs::read_to_string(path)?;
26 let trace = toml::from_str(&content)?;
27 Ok(trace)
28 }
29
30 pub fn write_to_file(&self, path: &Path) -> Result<(), TraceIoError> {
31 let content = toml::to_string(self)?;
32 fs::write(path, content)?;
33 Ok(())
34 }
35}
36
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;
    use std::io::Write;

    use googletest::prelude::*;
    use tempfile::{tempdir, NamedTempFile};

    use super::*;
    use crate::trace::{PkgTrace, Trace};

    /// Builds a small two-package trace used by the write/read tests.
    fn sample_trace() -> Trace {
        Trace {
            packages: BTreeMap::from([
                (
                    "pkg1".to_string(),
                    PkgTrace {
                        directory: "dir1".to_string(),
                        maps: BTreeMap::from([
                            ("src1".to_string(), "dst1".to_string()),
                            ("src2".to_string(), "dst2".to_string()),
                        ]),
                    },
                ),
                (
                    "pkg2".to_string(),
                    PkgTrace {
                        directory: "dir2".to_string(),
                        maps: BTreeMap::from([("src3".to_string(), "dst3".to_string())]),
                    },
                ),
            ]),
        }
    }

    /// A trace written to disk reads back unchanged.
    #[gtest]
    fn write_and_read_trace() {
        let trace = sample_trace();

        let file = NamedTempFile::new().unwrap();
        trace.write_to_file(file.path()).unwrap();

        assert_eq!(trace, Trace::read_from_file(file.path()).unwrap());
    }

    /// Reading a path that does not exist surfaces the underlying I/O error.
    #[gtest]
    fn read_non_existent_file() {
        // Build the missing path inside a fresh temp dir so the test does not
        // depend on the current working directory; a stray `no_such_file.toml`
        // there would otherwise make this test fail.
        let dir = tempdir().unwrap();
        let missing = dir.path().join("no_such_file.toml");

        let result = Trace::read_from_file(&missing).unwrap_err();
        assert_that!(result, pat!(TraceIoError::Io(_)));
    }

    /// Syntactically invalid TOML is reported as a parse error, not an I/O error.
    #[gtest]
    fn read_invalid_toml() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "invalid = [this is not toml]").unwrap();

        let result = Trace::read_from_file(file.path()).unwrap_err();
        assert_that!(result, pat!(TraceIoError::Parse(_)));
    }

    /// Writing into a directory that does not exist surfaces an I/O error.
    #[gtest]
    fn write_to_missing_directory() {
        let dir = tempdir().unwrap();
        let target = dir.path().join("missing_dir").join("trace.toml");

        let result = sample_trace().write_to_file(&target).unwrap_err();
        assert_that!(result, pat!(TraceIoError::Io(_)));
    }
}