use std::{borrow::Cow, path::PathBuf, process::Command, sync::Arc};

use ahash::AHashMap;
use anyhow::{Context as _, bail};
use faststr::FastStr;
use itertools::Itertools;
use rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
use rustc_hash::FxHashMap;

use super::CodegenItem;
use crate::{
    Codegen, CodegenBackend, Context, DefId, fmt::fmt_file, middle::context::DefLocation,
    rir::ItemPath, symbol::ModPath,
};
15
16#[derive(Clone)]
17pub struct Workspace<B> {
18 base_dir: Arc<std::path::Path>,
19 cg: Codegen<B>,
20}
21
22fn run_cmd(cmd: &mut Command) -> Result<(), anyhow::Error> {
23 let status = cmd.status()?;
24
25 if !status.success() {
26 bail!("run cmd {:?} failed", cmd)
27 }
28
29 Ok(())
30}
31
32struct CrateInfo {
33 name: FastStr,
34 main_mod_path: Option<ItemPath>,
35 deps: Vec<FastStr>,
36 workspace_deps: Vec<FastStr>,
37 mod_items: AHashMap<ModPath, Vec<DefId>>,
38 re_pubs: AHashMap<ModPath, Vec<DefId>>,
39 user_gen: Option<String>,
40}
41
42impl<B> Workspace<B>
43where
44 B: CodegenBackend + Send,
45{
46 fn cx(&self) -> &Context {
47 &self.cg
48 }
49
50 pub fn new(base_dir: PathBuf, cg: Codegen<B>) -> Self {
51 Workspace {
52 base_dir: Arc::from(base_dir),
53 cg,
54 }
55 }
56
57 pub fn group_defs(&self, entry_def_ids: &[DefId]) -> Result<(), anyhow::Error> {
58 let location_map = self.collect_def_ids(entry_def_ids, None);
59 let entry_map = location_map.iter().into_group_map_by(|item| item.1);
60
61 let entry_deps = entry_map
62 .iter()
63 .map(|(k, v)| {
64 let def_ids = v.iter().map(|i| i.0).copied().collect_vec();
65 let deps = self
66 .collect_def_ids(&def_ids, Some(&location_map))
67 .into_iter()
68 .collect_vec();
69 (k, deps)
70 })
71 .collect::<FxHashMap<_, _>>();
72
73 if !self.base_dir.exists() {
74 std::fs::create_dir_all(&*self.base_dir).unwrap();
75 }
76
77 let this = self.clone();
78
79 let members = entry_map
80 .keys()
81 .map(|k| {
82 let name = self.cx().crate_name(k);
83 format!(" \"{name}\"")
84 })
85 .dedup()
86 .sorted()
87 .join(",\n");
88
89 let mut cargo_toml = toml::from_str::<toml::Value>(&unsafe {
90 String::from_utf8_unchecked(std::fs::read(self.base_dir.join("Cargo.toml")).unwrap())
91 })
92 .unwrap();
93
94 let reflect_dep = if self.cg.config.with_descriptor {
95 r#"pilota-thrift-reflect = "*""#
96 } else {
97 r#""#
98 };
99
100 let fieldmask_dep = if self.cg.config.with_field_mask {
101 r#"pilota-thrift-fieldmask = "*""#
102 } else {
103 r#""#
104 };
105
106 crate::codegen::toml::merge_tomls(
107 &mut cargo_toml,
108 toml::from_str::<toml::Value>(&format!(
109 r#"[workspace]
110 members = [
111 {members}
112 ]
113 edition = "2024"
114 resolver = "3"
115
116 [workspace.dependencies]
117 pilota = "*"
118 {reflect_dep}
119 {fieldmask_dep}
120 anyhow = "1"
121 volo = "*"
122 volo-{} = "*""#,
123 if B::PROTOCOL == "thrift" {
124 "thrift"
125 } else if B::PROTOCOL == "protobuf" {
126 "grpc"
127 } else {
128 panic!("unknown protocol")
129 }
130 ))
131 .unwrap(),
132 );
133
134 let workspace_deps = cargo_toml
135 .get("workspace")
136 .unwrap()
137 .get("dependencies")
138 .unwrap()
139 .as_table()
140 .unwrap()
141 .keys()
142 .map(FastStr::new)
143 .collect_vec();
144
145 std::fs::write(
146 self.base_dir.join("Cargo.toml"),
147 toml::to_string_pretty(&cargo_toml).unwrap(),
148 )?;
149
150 entry_deps
151 .par_iter()
152 .try_for_each_with(this, |this, (k, deps)| {
153 let name = this.cx().crate_name(k);
154 let deps = deps.iter().filter(|dep| dep.1 != ***k).collect_vec();
155 let (main_mod_path, re_pubs, deps) = match k {
156 DefLocation::Fixed(_, path) => (
157 Some(path.clone()),
158 deps.iter()
159 .map(|v| (this.cg.mod_index(v.0), v.0))
160 .into_group_map_by(|(mod_path, _)| mod_path.clone())
161 .into_iter()
162 .map(|(mod_path, items)| {
163 (
164 mod_path,
165 items.iter().map(|(_, def_id)| *def_id).collect_vec(),
166 )
167 })
168 .collect::<AHashMap<_, _>>(),
169 deps.iter()
170 .map(|dep| this.cx().crate_name(&dep.1))
171 .sorted()
172 .dedup()
173 .collect_vec(),
174 ),
175 DefLocation::Dynamic => (None, AHashMap::default(), vec![]),
176 };
177
178 let mod_items = entry_map[*k]
179 .iter()
180 .map(|(k, _)| (this.cg.mod_index(**k), **k))
181 .into_group_map_by(|(mod_path, _)| mod_path.clone())
182 .into_iter()
183 .map(|(mod_path, items)| {
184 (
185 mod_path,
186 items.iter().map(|(_, def_id)| *def_id).collect_vec(),
187 )
188 })
189 .collect::<AHashMap<_, _>>();
190
191 this.create_crate(
192 &this.base_dir,
193 CrateInfo {
194 main_mod_path,
195 workspace_deps: workspace_deps.clone(),
196 name,
197 re_pubs,
198 mod_items,
199 deps,
200 user_gen: this.cx().cache.plugin_gen.get(k).map(|v| v.value().clone()),
201 },
202 )
203 })?;
204
205 Ok(())
206 }
207
208 fn collect_def_ids(
209 &self,
210 input: &[DefId],
211 locations: Option<&FxHashMap<DefId, DefLocation>>,
212 ) -> FxHashMap<DefId, DefLocation> {
213 self.cg.db.collect_def_ids(input, locations)
214 }
215
216 fn create_crate(
217 &self,
218 base_dir: impl AsRef<std::path::Path>,
219 info: CrateInfo,
220 ) -> anyhow::Result<()> {
221 if !base_dir.as_ref().join(&*info.name).exists() {
222 run_cmd(
223 Command::new("cargo")
224 .arg("init")
225 .arg("--lib")
226 .arg("--vcs")
227 .arg("none")
228 .current_dir(base_dir.as_ref())
229 .arg(&*info.name),
230 )?;
231 };
232
233 let cargo_toml_path = base_dir.as_ref().join(&*info.name).join("Cargo.toml");
234
235 let mut cargo_toml = toml::from_str::<toml::Value>(&unsafe {
236 String::from_utf8_unchecked(std::fs::read(&cargo_toml_path)?)
237 })
238 .unwrap();
239
240 let deps = info
241 .deps
242 .iter()
243 .map(|s| Cow::from(format!(r#"{s} = {{ path = "../{s}" }}"#)))
244 .chain(
245 info.workspace_deps
246 .iter()
247 .map(|s| Cow::from(format!(r#"{s}.workspace = true"#))),
248 )
249 .join("\n");
250
251 super::toml::merge_tomls(
252 &mut cargo_toml,
253 toml::from_str::<toml::Value>(&format!("[dependencies]\n{deps}")).unwrap(),
254 );
255
256 std::fs::write(
257 &cargo_toml_path,
258 toml::to_string_pretty(&cargo_toml).unwrap(),
259 )?;
260
261 let mut lib_rs_stream = String::default();
262 lib_rs_stream.push_str("include!(\"gen.rs\");\n");
263 lib_rs_stream.push_str("pub use r#gen::*;\n\n");
264
265 if let Some(user_gen) = info.user_gen {
266 if !user_gen.is_empty() {
267 lib_rs_stream.push_str("include!(\"custom.rs\");\n");
268
269 let mut custom_rs_stream = String::default();
270 custom_rs_stream.push_str(&user_gen);
271
272 let custom_rs = base_dir.as_ref().join(&*info.name).join("src/custom.rs");
273
274 std::fs::write(&custom_rs, custom_rs_stream)?;
275
276 fmt_file(custom_rs);
277 }
278 }
279
280 let mut gen_rs_stream = String::default();
281
282 let mut mod_items = self.cg.collect_direct_codegen_items(&info.mod_items);
283
284 for (mod_path, def_ids) in info.re_pubs.iter() {
285 mod_items
286 .entry(mod_path.clone())
287 .or_default()
288 .extend(def_ids.iter().map(|&def_id| CodegenItem {
289 def_id,
290 kind: super::CodegenKind::RePub,
291 }));
292 }
293
294 self.cg.write_items(
295 &mut gen_rs_stream,
296 mod_items,
297 base_dir.as_ref().join(&*info.name).join("src").as_path(),
298 );
299 if let Some(main_mod_path) = info.main_mod_path {
300 gen_rs_stream.push_str(&format!(
301 "pub use {}::*;",
302 main_mod_path.iter().map(|item| item.to_string()).join("::")
303 ));
304 }
305 gen_rs_stream = format! {r#"pub mod r#gen {{
306 #![allow(warnings, clippy::all)]
307 {gen_rs_stream}
308 }}"#};
309
310 let lib_rs_stream = lib_rs_stream.lines().map(|s| s.trim_end()).join("\n");
311 let gen_rs_stream = gen_rs_stream.lines().map(|s| s.trim_end()).join("\n");
312
313 let lib_rs = base_dir.as_ref().join(&*info.name).join("src/lib.rs");
314 let gen_rs = base_dir.as_ref().join(&*info.name).join("src/gen.rs");
315
316 std::fs::write(&lib_rs, lib_rs_stream)?;
317 std::fs::write(&gen_rs, gen_rs_stream)?;
318
319 fmt_file(lib_rs);
320 fmt_file(gen_rs);
321
322 Ok(())
323 }
324
325 pub(crate) fn write_crates(self) -> anyhow::Result<()> {
326 self.group_defs(&self.cx().cache.codegen_items)
327 }
328}