cargo_futhark/
generator.rs1use crate::{manifest::Manifest, template, Target};
2use bindgen::callbacks::ParseCallbacks;
3use enumflags2::BitFlags;
4use eyre::{bail, ensure, Context, Result};
5use rerun_except::rerun_except;
6use std::{
7 env,
8 fs::{self, File},
9 io::{BufWriter, Write},
10 path::{Path, PathBuf},
11 process::Command,
12};
13
/// File and directory names for the artifacts this generator places under `OUT_DIR`.
mod names {
    // Directories directly below `OUT_DIR`; each holds one sub-directory per target.

    /// Where the compiler output lives after symbol prefixing, together with the
    /// per-target bindings and the combined Rust library.
    pub const TARGET_DIR: &str = "futhark";
    /// Where the `futhark` compiler writes its output before any post-processing.
    pub const RAW_TARGET_DIR: &str = "futhark_raw";

    // Per-target artifact names.

    /// Base name handed to `futhark <backend> --library -o`.
    pub const LIBRARY: &str = "futhark_lib";
    /// Futhark's JSON manifest describing the library's entry points and types.
    pub const MANIFEST: &str = "futhark_lib.json";
    /// Generated C source file.
    pub const C_FILE: &str = "futhark_lib.c";
    /// Generated C header file, also the input to bindgen.
    pub const H_FILE: &str = "futhark_lib.h";
    /// Rust output: bindgen bindings per target, and the combined library at the top level.
    pub const RS_FILE: &str = "futhark_lib.rs";
}
24}
25
26fn cargo_out_dir() -> Result<PathBuf> {
27 env::var("OUT_DIR")
28 .wrap_err("OUT_DIR is undefined.")
29 .map(PathBuf::from)
30}
31
32fn cargo_manifest_dir() -> Result<PathBuf> {
33 env::var("CARGO_MANIFEST_DIR")
34 .wrap_err("CARGO_MANIFEST_DIR is undefined.")
35 .map(PathBuf::from)
36}
37
38pub struct Generator {
63 source: PathBuf,
64 watch: bool,
65 cuda_home: Option<PathBuf>,
66 targets: BitFlags<Target>,
67}
68
69impl Generator {
70 pub fn new(source: impl Into<PathBuf>) -> Self {
80 Generator {
81 source: source.into(),
82 cuda_home: None,
83 watch: true,
84 targets: BitFlags::empty(),
85 }
86 }
87
88 pub fn watch_sources(&mut self, watch: bool) -> &mut Self {
92 self.watch = watch;
93 self
94 }
95
96 pub fn with_cuda_home(&mut self, cuda_home: impl Into<PathBuf>) -> Result<&mut Self> {
102 let cuda_home = cuda_home.into();
103 ensure!(
104 cuda_home.to_str().is_some(),
105 "cuda_home must be representable using UTF8."
106 );
107 self.cuda_home = Some(cuda_home);
108 Ok(self)
109 }
110
111 pub fn with_target(&mut self, target: Target) -> &mut Self {
113 self.targets |= target;
114 self
115 }
116
117 pub fn with_target_if(&mut self, target: Target, condition: bool) -> &mut Self {
121 if condition {
122 self.targets |= target;
123 }
124 self
125 }
126
127 pub fn run(&mut self) -> Result<()> {
129 ensure!(self.source.is_file(), "Futhark source file does not exist.");
130
131 ensure!(
132 !self.targets.is_empty(),
133 "At least one target must be built."
134 );
135
136 self.build_targets().wrap_err("Failed to build targets.")?;
137
138 self.generate_library()
139 .wrap_err("Failed to generate Rust library.")?;
140
141 Ok(())
142 }
143}
144
145impl Generator {
146 fn generate_library(&mut self) -> Result<(), eyre::ErrReport> {
147 let any_target = self.targets.iter().next().unwrap();
148 let manifest_path = cargo_out_dir()?
149 .join(names::TARGET_DIR)
150 .join(any_target.name())
151 .join(names::MANIFEST);
152 let manifest = Manifest::from_json_file(&manifest_path).wrap_err_with(|| {
153 format!(
154 "Failed to load manifest file at {}.",
155 manifest_path.display()
156 )
157 })?;
158
159 let rust_lib = template::combined(&manifest, self.targets).to_string();
160 let rust_lib_path = cargo_out_dir()?
161 .join(names::TARGET_DIR)
162 .join(names::RS_FILE);
163
164 let mut rust_lib_file = fs::File::create(&rust_lib_path)
165 .wrap_err("Failed to create generated Rust library file.")?;
166 writeln!(rust_lib_file, "{}", rust_lib)
167 .wrap_err("Failed to write generated Rust library.")?;
168 rust_lib_file.flush().wrap_err("Failed to flush.")?;
169
170 let rustfmt_status = Command::new("rustfmt")
171 .arg(rust_lib_path)
172 .status()
173 .wrap_err("Failed to run rustfmt.")?
174 .success();
175
176 if !rustfmt_status {
177 bail!("Failed to format generated Rust library.");
178 };
179
180 Ok(())
181 }
182
183 fn build_targets(&self) -> Result<()> {
184 if self.watch {
185 watch_source(&self.source).wrap_err("Failed to watch source files for changes.")?;
186 }
187
188 if self.targets.contains(Target::C) {
189 self.build_target(Target::C)
190 .wrap_err("Failed to build C target.")?;
191 }
192
193 if self.targets.contains(Target::MultiCore) {
194 self.build_target(Target::MultiCore)
195 .wrap_err("Failed to build Multi-Core target.")?;
196 }
197
198 if self.targets.contains(Target::OpenCL) {
199 self.build_target(Target::OpenCL)
200 .wrap_err("Failed to build OpenCL target.")?;
201
202 println!("cargo:rustc-link-lib=OpenCL");
203 }
204
205 if self.targets.contains(Target::Cuda) {
206 self.build_target(Target::Cuda)
207 .wrap_err("Failed to build Cuda target.")?;
208
209 println!("cargo:rustc-link-lib=cuda");
210 println!("cargo:rustc-link-lib=cudart");
211 println!("cargo:rustc-link-lib=nvrtc");
212
213 if let Some(cuda_home) = &self.cuda_home {
214 let cuda_lib64 = cuda_home.join("lib64");
215
216 println!("cargo:rustc-link-search={}", cuda_lib64.to_str().unwrap());
217 }
218 }
219
220 Ok(())
221 }
222
223 fn build_target(&self, target: Target) -> Result<()> {
224 let out_dir = cargo_out_dir()?;
225 let target_dir = out_dir.join(names::TARGET_DIR).join(target.name());
226 fs::create_dir_all(&target_dir).wrap_err("Could not create target dir.")?;
227
228 let raw_target_dir = out_dir.join(names::RAW_TARGET_DIR).join(target.name());
229 fs::create_dir_all(&raw_target_dir).wrap_err("Could not create raw target dir.")?;
230
231 let futhark_status = Command::new("futhark")
232 .args([target.name(), "--library", "-o"])
233 .arg(raw_target_dir.join(names::LIBRARY))
234 .arg(self.source.as_os_str())
235 .status()
236 .wrap_err("Failed to run Futhark compiler.")?
237 .success();
238
239 if !futhark_status {
240 bail!("Failed to compile Futhark code.");
241 }
242
243 fs::copy(
244 raw_target_dir.join(names::MANIFEST),
245 target_dir.join(names::MANIFEST),
246 )
247 .wrap_err("Failed to copy manifest file")?;
248
249 let prefix = format!("futhark_{target}_");
250
251 prefix_items(
252 &prefix,
253 raw_target_dir.join(names::H_FILE),
254 target_dir.join(names::H_FILE),
255 )
256 .wrap_err("Failed to prefix header file items.")?;
257
258 prefix_items(
259 &prefix,
260 raw_target_dir.join(names::C_FILE),
261 target_dir.join(names::C_FILE),
262 )
263 .wrap_err("Failed to prefix source file items.")?;
264
265 let cuda_include_path = match (target, &self.cuda_home) {
266 (Target::Cuda, Some(cuda_home)) => Some(cuda_home.join("include")),
267 _ => None,
268 };
269
270 let cuda_include_flag = cuda_include_path
271 .as_ref()
272 .map(|path| format!("-I{}", path.to_str().unwrap()));
273
274 bindgen::Builder::default()
275 .clang_args(cuda_include_flag)
276 .header(target_dir.join(names::H_FILE).to_string_lossy())
277 .allowlist_function("free")
278 .allowlist_function("futhark_.*")
279 .allowlist_type("futhark_.*")
280 .parse_callbacks(Box::new(bindgen::CargoCallbacks))
281 .parse_callbacks(PrefixRemover::new(prefix))
282 .generate()
283 .wrap_err("Failed to generate bindings.")?
284 .write_to_file(target_dir.join(names::RS_FILE))
285 .wrap_err("Failed to write bindings to file.")?;
286
287 cc::Build::new()
288 .file(target_dir.join(names::C_FILE))
289 .includes(cuda_include_path)
290 .static_flag(true)
291 .warnings(false)
292 .try_compile(&format!("futhark-lib-{compiler}", compiler = target))
293 .wrap_err("Failed to compile the generated c code.")?;
294
295 Ok(())
296 }
297}
298
299fn watch_source(source: &Path) -> Result<()> {
300 let old_manifest_dir = cargo_manifest_dir()?;
301
302 env::set_var("CARGO_MANIFEST_DIR", source.parent().unwrap().as_os_str());
303
304 rerun_except(&[])
305 .map_err(|err| eyre::eyre!("{}", err))
306 .wrap_err("Failed to watch files.")?;
307
308 env::set_var("CARGO_MANIFEST_DIR", old_manifest_dir);
309
310 Ok(())
311}
312
313fn prefix_items(prefix: &str, input: impl AsRef<Path>, output: impl AsRef<Path>) -> Result<()> {
314 let mut out = BufWriter::new(File::create(output).wrap_err("Failed to open output file.")?);
315
316 let memblock_prefix = &format!("{prefix}_memblock_");
317 let lexical_realloc_error_prefix = &format!("{prefix}_lexical_realloc_error");
318
319 for line in fs::read_to_string(input)?.lines() {
320 let new_line = line
321 .replace("memblock_", memblock_prefix)
322 .replace("lexical_realloc_error", lexical_realloc_error_prefix)
323 .replace("futhark_", prefix);
324
325 writeln!(out, "{}", new_line).wrap_err("Failed to write line to output file.")?;
326 }
327
328 out.flush().wrap_err("Failed to flush output file.")?;
329
330 Ok(())
331}
332
333#[derive(Debug)]
334struct PrefixRemover {
335 prefix: String,
336}
337
338impl PrefixRemover {
339 #[allow(clippy::new_ret_no_self)]
340 pub fn new(prefix: impl ToOwned<Owned = String>) -> Box<dyn ParseCallbacks> {
341 Box::new(PrefixRemover {
342 prefix: prefix.to_owned(),
343 })
344 }
345}
346
347impl ParseCallbacks for PrefixRemover {
348 fn item_name(&self, original_item_name: &str) -> Option<String> {
349 if original_item_name.contains(&self.prefix) {
350 return Some(original_item_name.replace(&self.prefix, "futhark_"));
351 }
352
353 None
354 }
355}