1use std::{
5 collections::{BTreeSet, HashMap},
6 ffi::{OsStr, OsString},
7 fmt::{Debug, Display},
8 iter::FromIterator,
9 path::{Path, PathBuf},
10};
11
12use anyhow::{bail, Context, Result};
13use fs_err::OpenOptions;
14
/// A tree of module names built from dotted generated-file names.
///
/// Each key is one module-name component, and its value holds that module's child
/// modules. A node with no children is a leaf module.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Tree(pub(crate) HashMap<PathBuf, Tree>);
17
18impl Extend<PathBuf> for Tree {
19 fn extend<T: IntoIterator<Item = PathBuf>>(&mut self, iter: T) {
20 for path in iter {
21 self.insert_path(path)
22 }
23 }
24}
25
26impl FromIterator<PathBuf> for Tree {
27 fn from_iter<T: IntoIterator<Item = PathBuf>>(iter: T) -> Self {
28 let mut tree = Tree::default();
29 tree.extend(iter);
30 tree
31 }
32}
33
34impl Tree {
35 pub fn insert_path(mut self: &mut Self, path: PathBuf) {
37 for comp in path.file_stem().unwrap().to_str().unwrap().split('.') {
38 self = self.0.entry(PathBuf::from(comp)).or_default()
39 }
40 }
41
42 pub fn generate_module(&self) -> String {
44 let mut module = String::from("// Module generated with `grpc_build`\n");
45 let sorted: BTreeSet<_> = self.0.keys().collect();
46 for k in sorted {
47 module.push_str(&format!("pub mod {};\n", k.display()));
48 }
49
50 module.push('\n');
51 module
52 }
53
54 pub fn move_paths(&self, root: &Path, filename: OsString, output: PathBuf) -> Result<()> {
57 if self.0.is_empty() {
58 fs_err::create_dir_all(root.join(&output).parent().unwrap())
59 .with_context(|| format!("could not create dir for file {}", output.display()))?;
60
61 let from = root.join(filename.add("rs"));
62 let to = root.join(output.with_extension("rs"));
63 fs_err::rename(&from, &to).with_context(|| {
64 format!("could not move {} to {}", from.display(), to.display())
65 })?;
66 } else {
67 for (k, tree) in &self.0 {
68 tree.move_paths(root, filename.add(k), output.join(k))?;
69 }
70
71 if !filename.is_empty() {
72 self.create_module_file(root, filename, output)?;
73 }
74 }
75 Ok(())
76 }
77
78 fn create_module_file(
79 &self,
80 root: &Path,
81 filename: OsString,
82 output: PathBuf,
83 ) -> Result<(), anyhow::Error> {
84 let maybe_proto_file_name = root.join(filename.add("rs"));
85 let dest_tmp_file_name = root.join(output.with_extension("tmp"));
86 let final_dest_name = root.join(output.with_extension("rs"));
87
88 let modules = self.generate_module();
90 fs_err::write(&dest_tmp_file_name, modules)
91 .with_context(|| format!("could not write to file {}", final_dest_name.display()))?;
92
93 if fs_err::metadata(&maybe_proto_file_name)
95 .map(|m| m.is_file())
96 .unwrap_or(false)
97 {
98 merge_file_into(&maybe_proto_file_name, &dest_tmp_file_name)?;
99 }
100
101 fs_err::rename(&dest_tmp_file_name, &final_dest_name).with_context(|| {
103 format!(
104 "could not move {} to {}",
105 dest_tmp_file_name.display(),
106 final_dest_name.display()
107 )
108 })?;
109
110 Ok(())
111 }
112}
113
114fn merge_file_into(from: &PathBuf, to: &PathBuf) -> Result<(), anyhow::Error> {
115 if from == to {
116 bail!("Merging files, source and destination files are the same");
117 }
118
119 let mut source = OpenOptions::new()
120 .read(true)
121 .open(from)
122 .with_context(|| format!("Failed to open not source file {}", to.display()))?;
123
124 let mut dest = OpenOptions::new()
125 .create_new(false)
126 .write(true)
127 .append(true)
128 .open(to)
129 .with_context(|| format!("Failed to open the destination file {}", from.display()))?;
130
131 std::io::copy(&mut source, &mut dest).with_context(|| {
132 format!(
133 "could not copy contents from {} to {}",
134 from.display(),
135 to.display()
136 )
137 })?;
138
139 fs_err::remove_file(from)
140 .with_context(|| format!("could not remove file {}", from.display()))?;
141 Ok(())
142}
143
/// Helpers for building dotted names out of OS strings.
trait OsStrExt {
    /// Returns a new string consisting of `self`, a `.` separator, and `suffix`.
    /// The separator is left out when `self` is empty.
    fn add(&self, suffix: impl AsRef<OsStr>) -> OsString;
}

impl OsStrExt for OsStr {
    fn add(&self, suffix: impl AsRef<OsStr>) -> OsString {
        let mut joined = OsString::from(self);
        if !self.is_empty() {
            joined.push(".");
        }
        joined.push(suffix);
        joined
    }
}
161
162impl Display for Tree {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 for (k, tree) in &self.0 {
165 write!(f, "pub mod {}", k.display())?;
166 if tree.0.is_empty() {
167 write!(f, ";")?;
168 } else {
169 write!(f, "{{{}}}", tree)?;
170 }
171 }
172 Ok(())
173 }
174}
175
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::path::PathBuf;

    use super::Tree;

    /// Builds a `Tree` literal from `"name": subtree,` entries (every entry needs a trailing comma).
    macro_rules! tree {
        ($($key:literal : $val:expr,)*) => {
            Tree(HashMap::from_iter([
                $( (PathBuf::from($key), $val) ),*
            ]))
        };
    }

    /// Input shared by the `generate_module` tests: packages of varying depth.
    const MIXED_DEPTH_FILES: &[&str] = &[
        "grpc_build.client.helloworld.rs",
        "grpc_build.request.helloworld.rs",
        "grpc_build.response.helloworld.rs",
        "google.protobuf.foo.rs",
        "google.protobuf.bar.rs",
        "alphabet.foo.rs",
        "hello.rs",
    ];

    /// Collects file names into a `Tree` through its `FromIterator` impl.
    fn tree_of(file_names: &[&str]) -> Tree {
        file_names.iter().copied().map(PathBuf::from).collect()
    }

    #[test]
    fn build_tree() {
        let actual = tree_of(&[
            "grpc_build.client.helloworld.rs",
            "grpc_build.request.helloworld.rs",
            "grpc_build.response.helloworld.rs",
            "google.protobuf.foo.rs",
            "google.protobuf.bar.rs",
        ]);

        let expected = tree! {
            "grpc_build": tree! {
                "client": tree! {
                    "helloworld": tree!{},
                },
                "request": tree! {
                    "helloworld": tree!{},
                },
                "response": tree! {
                    "helloworld": tree!{},
                },
            },
            "google": tree! {
                "protobuf": tree! {
                    "foo": tree!{},
                    "bar": tree!{},
                },
            },
        };

        assert_eq!(actual, expected);
    }

    #[test]
    fn generate_module_returns_at_current_level() {
        let expected = concat!(
            "// Module generated with `grpc_build`\n",
            "pub mod alphabet;\n",
            "pub mod google;\n",
            "pub mod grpc_build;\n",
            "pub mod hello;\n",
            "\n",
        );

        assert_eq!(tree_of(MIXED_DEPTH_FILES).generate_module(), expected);
    }

    #[test]
    fn generate_module_returns_at_current_level_nested() {
        let root = tree_of(MIXED_DEPTH_FILES);
        let grpc_build = root
            .0
            .get(&PathBuf::from("grpc_build"))
            .expect("`grpc_build` should be a top-level module");

        let expected = concat!(
            "// Module generated with `grpc_build`\n",
            "pub mod client;\n",
            "pub mod request;\n",
            "pub mod response;\n",
            "\n",
        );

        assert_eq!(grpc_build.generate_module(), expected);
    }
}