Skip to main content

pilota_build/codegen/
workspace.rs

use std::{borrow::Cow, path::PathBuf, process::Command, sync::Arc};

use ahash::AHashMap;
use anyhow::{Context as _, bail};
use faststr::FastStr;
use itertools::Itertools;
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use rustc_hash::FxHashMap;

use super::CodegenItem;
use crate::{
    Codegen, CodegenBackend, Context, DefId, fmt::fmt_file, middle::context::DefLocation,
    rir::ItemPath, symbol::ModPath,
};
15
16#[derive(Clone)]
17pub struct Workspace<B> {
18    base_dir: Arc<std::path::Path>,
19    cg: Codegen<B>,
20}
21
22fn run_cmd(cmd: &mut Command) -> Result<(), anyhow::Error> {
23    let status = cmd.status()?;
24
25    if !status.success() {
26        bail!("run cmd {:?} failed", cmd)
27    }
28
29    Ok(())
30}
31
32struct CrateInfo {
33    name: FastStr,
34    main_mod_path: Option<ItemPath>,
35    deps: Vec<FastStr>,
36    workspace_deps: Vec<FastStr>,
37    mod_items: AHashMap<ModPath, Vec<DefId>>,
38    re_pubs: AHashMap<ModPath, Vec<DefId>>,
39    user_gen: Option<String>,
40}
41
42impl<B> Workspace<B>
43where
44    B: CodegenBackend + Send,
45{
46    fn cx(&self) -> &Context {
47        &self.cg
48    }
49
50    pub fn new(base_dir: PathBuf, cg: Codegen<B>) -> Self {
51        Workspace {
52            base_dir: Arc::from(base_dir),
53            cg,
54        }
55    }
56
57    pub fn group_defs(&self, entry_def_ids: &[DefId]) -> Result<(), anyhow::Error> {
58        let location_map = self.collect_def_ids(entry_def_ids, None);
59        let entry_map = location_map.iter().into_group_map_by(|item| item.1);
60
61        let entry_deps = entry_map
62            .iter()
63            .map(|(k, v)| {
64                let def_ids = v.iter().map(|i| i.0).copied().collect_vec();
65                let deps = self
66                    .collect_def_ids(&def_ids, Some(&location_map))
67                    .into_iter()
68                    .collect_vec();
69                (k, deps)
70            })
71            .collect::<FxHashMap<_, _>>();
72
73        if !self.base_dir.exists() {
74            std::fs::create_dir_all(&*self.base_dir).unwrap();
75        }
76
77        let this = self.clone();
78
79        let members = entry_map
80            .keys()
81            .map(|k| {
82                let name = self.cx().crate_name(k);
83                format!("    \"{name}\"")
84            })
85            .dedup()
86            .sorted()
87            .join(",\n");
88
89        let mut cargo_toml = toml::from_str::<toml::Value>(&unsafe {
90            String::from_utf8_unchecked(std::fs::read(self.base_dir.join("Cargo.toml")).unwrap())
91        })
92        .unwrap();
93
94        let reflect_dep = if self.cg.config.with_descriptor {
95            r#"pilota-thrift-reflect = "*""#
96        } else {
97            r#""#
98        };
99
100        let fieldmask_dep = if self.cg.config.with_field_mask {
101            r#"pilota-thrift-fieldmask = "*""#
102        } else {
103            r#""#
104        };
105
106        crate::codegen::toml::merge_tomls(
107            &mut cargo_toml,
108            toml::from_str::<toml::Value>(&format!(
109                r#"[workspace]
110    members = [
111    {members}
112    ]
113    edition = "2024"
114    resolver = "3"
115
116    [workspace.dependencies]
117    pilota = "*"
118    {reflect_dep}
119    {fieldmask_dep}
120    anyhow = "1"
121    volo = "*"
122    volo-{} = "*""#,
123                if B::PROTOCOL == "thrift" {
124                    "thrift"
125                } else if B::PROTOCOL == "protobuf" {
126                    "grpc"
127                } else {
128                    panic!("unknown protocol")
129                }
130            ))
131            .unwrap(),
132        );
133
134        let workspace_deps = cargo_toml
135            .get("workspace")
136            .unwrap()
137            .get("dependencies")
138            .unwrap()
139            .as_table()
140            .unwrap()
141            .keys()
142            .map(FastStr::new)
143            .collect_vec();
144
145        std::fs::write(
146            self.base_dir.join("Cargo.toml"),
147            toml::to_string_pretty(&cargo_toml).unwrap(),
148        )?;
149
150        entry_deps
151            .par_iter()
152            .try_for_each_with(this, |this, (k, deps)| {
153                let name = this.cx().crate_name(k);
154                let deps = deps.iter().filter(|dep| dep.1 != ***k).collect_vec();
155                let (main_mod_path, re_pubs, deps) = match k {
156                    DefLocation::Fixed(_, path) => (
157                        Some(path.clone()),
158                        deps.iter()
159                            .map(|v| (this.cg.mod_index(v.0), v.0))
160                            .into_group_map_by(|(mod_path, _)| mod_path.clone())
161                            .into_iter()
162                            .map(|(mod_path, items)| {
163                                (
164                                    mod_path,
165                                    items.iter().map(|(_, def_id)| *def_id).collect_vec(),
166                                )
167                            })
168                            .collect::<AHashMap<_, _>>(),
169                        deps.iter()
170                            .map(|dep| this.cx().crate_name(&dep.1))
171                            .sorted()
172                            .dedup()
173                            .collect_vec(),
174                    ),
175                    DefLocation::Dynamic => (None, AHashMap::default(), vec![]),
176                };
177
178                let mod_items = entry_map[*k]
179                    .iter()
180                    .map(|(k, _)| (this.cg.mod_index(**k), **k))
181                    .into_group_map_by(|(mod_path, _)| mod_path.clone())
182                    .into_iter()
183                    .map(|(mod_path, items)| {
184                        (
185                            mod_path,
186                            items.iter().map(|(_, def_id)| *def_id).collect_vec(),
187                        )
188                    })
189                    .collect::<AHashMap<_, _>>();
190
191                this.create_crate(
192                    &this.base_dir,
193                    CrateInfo {
194                        main_mod_path,
195                        workspace_deps: workspace_deps.clone(),
196                        name,
197                        re_pubs,
198                        mod_items,
199                        deps,
200                        user_gen: this.cx().cache.plugin_gen.get(k).map(|v| v.value().clone()),
201                    },
202                )
203            })?;
204
205        Ok(())
206    }
207
208    fn collect_def_ids(
209        &self,
210        input: &[DefId],
211        locations: Option<&FxHashMap<DefId, DefLocation>>,
212    ) -> FxHashMap<DefId, DefLocation> {
213        self.cg.db.collect_def_ids(input, locations)
214    }
215
216    fn create_crate(
217        &self,
218        base_dir: impl AsRef<std::path::Path>,
219        info: CrateInfo,
220    ) -> anyhow::Result<()> {
221        if !base_dir.as_ref().join(&*info.name).exists() {
222            run_cmd(
223                Command::new("cargo")
224                    .arg("init")
225                    .arg("--lib")
226                    .arg("--vcs")
227                    .arg("none")
228                    .current_dir(base_dir.as_ref())
229                    .arg(&*info.name),
230            )?;
231        };
232
233        let cargo_toml_path = base_dir.as_ref().join(&*info.name).join("Cargo.toml");
234
235        let mut cargo_toml = toml::from_str::<toml::Value>(&unsafe {
236            String::from_utf8_unchecked(std::fs::read(&cargo_toml_path)?)
237        })
238        .unwrap();
239
240        let deps = info
241            .deps
242            .iter()
243            .map(|s| Cow::from(format!(r#"{s} = {{ path = "../{s}" }}"#)))
244            .chain(
245                info.workspace_deps
246                    .iter()
247                    .map(|s| Cow::from(format!(r#"{s}.workspace = true"#))),
248            )
249            .join("\n");
250
251        super::toml::merge_tomls(
252            &mut cargo_toml,
253            toml::from_str::<toml::Value>(&format!("[dependencies]\n{deps}")).unwrap(),
254        );
255
256        std::fs::write(
257            &cargo_toml_path,
258            toml::to_string_pretty(&cargo_toml).unwrap(),
259        )?;
260
261        let mut lib_rs_stream = String::default();
262        lib_rs_stream.push_str("include!(\"gen.rs\");\n");
263        lib_rs_stream.push_str("pub use r#gen::*;\n\n");
264
265        if let Some(user_gen) = info.user_gen {
266            if !user_gen.is_empty() {
267                lib_rs_stream.push_str("include!(\"custom.rs\");\n");
268
269                let mut custom_rs_stream = String::default();
270                custom_rs_stream.push_str(&user_gen);
271
272                let custom_rs = base_dir.as_ref().join(&*info.name).join("src/custom.rs");
273
274                std::fs::write(&custom_rs, custom_rs_stream)?;
275
276                fmt_file(custom_rs);
277            }
278        }
279
280        let mut gen_rs_stream = String::default();
281
282        let mut mod_items = self.cg.collect_direct_codegen_items(&info.mod_items);
283
284        for (mod_path, def_ids) in info.re_pubs.iter() {
285            mod_items
286                .entry(mod_path.clone())
287                .or_default()
288                .extend(def_ids.iter().map(|&def_id| CodegenItem {
289                    def_id,
290                    kind: super::CodegenKind::RePub,
291                }));
292        }
293
294        self.cg.write_items(
295            &mut gen_rs_stream,
296            mod_items,
297            base_dir.as_ref().join(&*info.name).join("src").as_path(),
298        );
299        if let Some(main_mod_path) = info.main_mod_path {
300            gen_rs_stream.push_str(&format!(
301                "pub use {}::*;",
302                main_mod_path.iter().map(|item| item.to_string()).join("::")
303            ));
304        }
305        gen_rs_stream = format! {r#"pub mod r#gen {{
306            #![allow(warnings, clippy::all)]
307            {gen_rs_stream}
308        }}"#};
309
310        let lib_rs_stream = lib_rs_stream.lines().map(|s| s.trim_end()).join("\n");
311        let gen_rs_stream = gen_rs_stream.lines().map(|s| s.trim_end()).join("\n");
312
313        let lib_rs = base_dir.as_ref().join(&*info.name).join("src/lib.rs");
314        let gen_rs = base_dir.as_ref().join(&*info.name).join("src/gen.rs");
315
316        std::fs::write(&lib_rs, lib_rs_stream)?;
317        std::fs::write(&gen_rs, gen_rs_stream)?;
318
319        fmt_file(lib_rs);
320        fmt_file(gen_rs);
321
322        Ok(())
323    }
324
325    pub(crate) fn write_crates(self) -> anyhow::Result<()> {
326        self.group_defs(&self.cx().cache.codegen_items)
327    }
328}