grpc_build/
tree.rs

//! Contains the [`Tree`] type, which turns the dotted package names of generated files
//! (e.g. `google.protobuf.foo.rs`) into a directory-structured set of module files.
3
4use std::{
5    collections::{BTreeSet, HashMap},
6    ffi::{OsStr, OsString},
7    fmt::{Debug, Display},
8    iter::FromIterator,
9    path::{Path, PathBuf},
10};
11
12use anyhow::{bail, Context, Result};
13use fs_err::OpenOptions;
14
/// A tree of package-name components built from dotted generated file names.
///
/// Each key is one component of a dotted name (e.g. `google`, then `protobuf`) and maps
/// to the subtree of names that continue from it; a node with no children is a leaf.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Tree(pub(crate) HashMap<PathBuf, Tree>);
17
18impl Extend<PathBuf> for Tree {
19    fn extend<T: IntoIterator<Item = PathBuf>>(&mut self, iter: T) {
20        for path in iter {
21            self.insert_path(path)
22        }
23    }
24}
25
26impl FromIterator<PathBuf> for Tree {
27    fn from_iter<T: IntoIterator<Item = PathBuf>>(iter: T) -> Self {
28        let mut tree = Tree::default();
29        tree.extend(iter);
30        tree
31    }
32}
33
34impl Tree {
35    /// Given a file path that is `.` separated, it loads it into the tree.
36    pub fn insert_path(mut self: &mut Self, path: PathBuf) {
37        for comp in path.file_stem().unwrap().to_str().unwrap().split('.') {
38            self = self.0.entry(PathBuf::from(comp)).or_default()
39        }
40    }
41
42    /// Generates the module at the root level of the tree
43    pub fn generate_module(&self) -> String {
44        let mut module = String::from("// Module generated with `grpc_build`\n");
45        let sorted: BTreeSet<_> = self.0.keys().collect();
46        for k in sorted {
47            module.push_str(&format!("pub mod {};\n", k.display()));
48        }
49
50        module.push('\n');
51        module
52    }
53
54    /// Loop through the tree, determining where all the files should be
55    /// and moving them there
56    pub fn move_paths(&self, root: &Path, filename: OsString, output: PathBuf) -> Result<()> {
57        if self.0.is_empty() {
58            fs_err::create_dir_all(root.join(&output).parent().unwrap())
59                .with_context(|| format!("could not create dir for file {}", output.display()))?;
60
61            let from = root.join(filename.add("rs"));
62            let to = root.join(output.with_extension("rs"));
63            fs_err::rename(&from, &to).with_context(|| {
64                format!("could not move {} to {}", from.display(), to.display())
65            })?;
66        } else {
67            for (k, tree) in &self.0 {
68                tree.move_paths(root, filename.add(k), output.join(k))?;
69            }
70
71            if !filename.is_empty() {
72                self.create_module_file(root, filename, output)?;
73            }
74        }
75        Ok(())
76    }
77
78    fn create_module_file(
79        &self,
80        root: &Path,
81        filename: OsString,
82        output: PathBuf,
83    ) -> Result<(), anyhow::Error> {
84        let maybe_proto_file_name = root.join(filename.add("rs"));
85        let dest_tmp_file_name = root.join(output.with_extension("tmp"));
86        let final_dest_name = root.join(output.with_extension("rs"));
87
88        // Write a temporary file with the module contents
89        let modules = self.generate_module();
90        fs_err::write(&dest_tmp_file_name, modules)
91            .with_context(|| format!("could not write to file {}", final_dest_name.display()))?;
92
93        // If there is a proto file in this directory, we append its contents to the already written temporary module file
94        if fs_err::metadata(&maybe_proto_file_name)
95            .map(|m| m.is_file())
96            .unwrap_or(false)
97        {
98            merge_file_into(&maybe_proto_file_name, &dest_tmp_file_name)?;
99        }
100
101        // Finally, move the temporary file to the final destination
102        fs_err::rename(&dest_tmp_file_name, &final_dest_name).with_context(|| {
103            format!(
104                "could not move {} to {}",
105                dest_tmp_file_name.display(),
106                final_dest_name.display()
107            )
108        })?;
109
110        Ok(())
111    }
112}
113
114fn merge_file_into(from: &PathBuf, to: &PathBuf) -> Result<(), anyhow::Error> {
115    if from == to {
116        bail!("Merging files, source and destination files are the same");
117    }
118
119    let mut source = OpenOptions::new()
120        .read(true)
121        .open(from)
122        .with_context(|| format!("Failed to open not source file {}", to.display()))?;
123
124    let mut dest = OpenOptions::new()
125        .create_new(false)
126        .write(true)
127        .append(true)
128        .open(to)
129        .with_context(|| format!("Failed to open the destination file {}", from.display()))?;
130
131    std::io::copy(&mut source, &mut dest).with_context(|| {
132        format!(
133            "could not copy contents from {} to {}",
134            from.display(),
135            to.display()
136        )
137    })?;
138
139    fs_err::remove_file(from)
140        .with_context(|| format!("could not remove file {}", from.display()))?;
141    Ok(())
142}
143
/// Private helper for building dotted file names out of [`OsStr`] pieces.
trait OsStrExt {
    /// Returns a new [`OsString`] made of `self` followed by `add`.
    ///
    /// When `self` is non-empty a `.` separator is placed in between, so `"foo"` + `"rs"`
    /// yields `"foo.rs"` while `""` + `"rs"` yields `"rs"`.
    fn add(&self, add: impl AsRef<OsStr>) -> OsString;
}

impl OsStrExt for OsStr {
    fn add(&self, add: impl AsRef<OsStr>) -> OsString {
        let suffix = add.as_ref();
        // Reserve room for `self`, the optional separator and the suffix up front.
        let mut joined = OsString::with_capacity(self.len() + 1 + suffix.len());
        joined.push(self);
        if !self.is_empty() {
            joined.push(".");
        }
        joined.push(suffix);
        joined
    }
}
161
162impl Display for Tree {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        for (k, tree) in &self.0 {
165            write!(f, "pub mod {}", k.display())?;
166            if tree.0.is_empty() {
167                write!(f, ";")?;
168            } else {
169                write!(f, "{{{}}}", tree)?;
170            }
171        }
172        Ok(())
173    }
174}
175
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::ffi::{OsStr, OsString};
    use std::path::PathBuf;

    use super::{OsStrExt, Tree};

    // Builds a `Tree` literal: `tree! { "name": subtree, ... }`.
    macro_rules! tree {
        ($($key:literal : $val:expr,)*) => {
            Tree(HashMap::from_iter([
                $(
                    (PathBuf::from($key), $val)
                ),*
            ]))
        };
    }

    /// Builds a [`Tree`] from flat, dot-separated generated file names.
    fn tree_of(names: &[&str]) -> Tree {
        names.iter().copied().map(PathBuf::from).collect()
    }

    /// File names shared by the `generate_module_*` tests.
    const MIXED_NAMES: [&str; 7] = [
        "grpc_build.client.helloworld.rs",
        "grpc_build.request.helloworld.rs",
        "grpc_build.response.helloworld.rs",
        "google.protobuf.foo.rs",
        "google.protobuf.bar.rs",
        "alphabet.foo.rs",
        "hello.rs",
    ];

    #[test]
    fn build_tree() {
        let tree = tree_of(&[
            "grpc_build.client.helloworld.rs",
            "grpc_build.request.helloworld.rs",
            "grpc_build.response.helloworld.rs",
            "google.protobuf.foo.rs",
            "google.protobuf.bar.rs",
        ]);

        let expected = tree! {
            "grpc_build": tree! {
                "client": tree! {
                    "helloworld": tree!{},
                },
                "request": tree! {
                    "helloworld": tree!{},
                },
                "response": tree! {
                    "helloworld": tree!{},
                },
            },
            "google": tree! {
                "protobuf": tree! {
                    "foo": tree!{},
                    "bar": tree!{},
                },
            },
        };

        assert_eq!(tree, expected);
    }

    #[test]
    fn package_that_is_both_leaf_and_parent_shares_one_node() {
        // `foo.rs` and `foo.bar.rs` both start at `foo`; the tree only records structure.
        let tree = tree_of(&["foo.rs", "foo.bar.rs"]);
        let expected = tree! {
            "foo": tree! {
                "bar": tree!{},
            },
        };

        assert_eq!(tree, expected);
    }

    #[test]
    fn inserting_duplicate_paths_is_idempotent() {
        assert_eq!(tree_of(&["a.b.rs", "a.b.rs"]), tree_of(&["a.b.rs"]));
    }

    #[test]
    fn generate_module_returns_at_current_level() {
        let tree = tree_of(&MIXED_NAMES);

        let expected = "// Module generated with `grpc_build`
pub mod alphabet;
pub mod google;
pub mod grpc_build;
pub mod hello;

";

        assert_eq!(tree.generate_module(), expected);
    }

    #[test]
    fn generate_module_returns_at_current_level_nested() {
        let tree = tree_of(&MIXED_NAMES);

        let inner_tree = tree
            .0
            .get(&PathBuf::from("grpc_build"))
            .expect("grpc_build node should exist");
        let expected = "// Module generated with `grpc_build`
pub mod client;
pub mod request;
pub mod response;

";

        assert_eq!(inner_tree.generate_module(), expected);
    }

    #[test]
    fn generate_module_for_leaf_has_only_header() {
        assert_eq!(
            Tree::default().generate_module(),
            "// Module generated with `grpc_build`\n\n"
        );
    }

    #[test]
    fn os_str_add_inserts_dot_only_between_parts() {
        assert_eq!(OsStr::new("").add("rs"), OsString::from("rs"));
        assert_eq!(
            OsStr::new("google.protobuf").add("foo"),
            OsString::from("google.protobuf.foo")
        );
    }
}