cargo_futhark/
generator.rs

1use crate::{manifest::Manifest, template, Target};
2use bindgen::callbacks::ParseCallbacks;
3use enumflags2::BitFlags;
4use eyre::{bail, ensure, Context, Result};
5use rerun_except::rerun_except;
6use std::{
7    env,
8    fs::{self, File},
9    io::{BufWriter, Write},
10    path::{Path, PathBuf},
11    process::Command,
12};
13
/// File and directory names shared by the generation steps.
mod names {
    // Directories created below `OUT_DIR`.

    /// Directory holding the prefixed sources and generated bindings.
    pub const TARGET_DIR: &str = "futhark";
    /// Directory the Futhark compiler writes its unmodified output to.
    pub const RAW_TARGET_DIR: &str = "futhark_raw";

    // Artifacts of a single Futhark library build.

    /// Base name passed to the Futhark compiler via `-o`.
    pub const LIBRARY: &str = "futhark_lib";
    /// Generated C source file.
    pub const C_FILE: &str = "futhark_lib.c";
    /// Generated C header file.
    pub const H_FILE: &str = "futhark_lib.h";
    /// Manifest file copied next to the prefixed sources.
    pub const MANIFEST: &str = "futhark_lib.json";
    /// Generated Rust file (per-target bindings and the combined library).
    pub const RS_FILE: &str = "futhark_lib.rs";
}
25
26fn cargo_out_dir() -> Result<PathBuf> {
27    env::var("OUT_DIR")
28        .wrap_err("OUT_DIR is undefined.")
29        .map(PathBuf::from)
30}
31
32fn cargo_manifest_dir() -> Result<PathBuf> {
33    env::var("CARGO_MANIFEST_DIR")
34        .wrap_err("CARGO_MANIFEST_DIR is undefined.")
35        .map(PathBuf::from)
36}
37
38/// Bindings generator.
39///
40/// This does:
41/// - Compile Futhark code to C code for each target.
42/// - Generate unsafe Rust bindings for each target.
43/// - Generate a single safe wrapper around all targets.
44/// - Compile and link generated C code.
45///
46/// # Usage
47///
48/// In your `build.rs` file:
49/// ```no_run
50/// use cargo_futhark::{Generator, Result, Target};
51///
52/// fn main() -> Result<()> {
53///     Generator::new("src/lib.fut")
54///         .with_target_if(Target::C, cfg!(feature = "c"))
55///         .with_target_if(Target::MultiCore, cfg!(feature = "multicore"))
56///         .with_target_if(Target::OpenCL, cfg!(feature = "opencl"))
57///         .with_target_if(Target::Cuda, cfg!(feature = "cuda"))
58///         .with_target_if(Target::ISPC, cfg!(feature = "ispc"))
59///         .run()
60/// }
61/// ```
62pub struct Generator {
63    source: PathBuf,
64    watch: bool,
65    cuda_home: Option<PathBuf>,
66    targets: BitFlags<Target>,
67}
68
69impl Generator {
70    /// Returns a new [`Generator`] with default settings.
71    ///
72    /// The `source` should be the `.fut` file containing the `entry` functions.
73    ///
74    /// The defaults are:
75    /// - `watch_sources = true`
76    /// - `targets = EMPTY`
77    ///
78    /// You must add at least on [`Target`] before you call [`Generator::run`].
79    pub fn new(source: impl Into<PathBuf>) -> Self {
80        Generator {
81            source: source.into(),
82            cuda_home: None,
83            watch: true,
84            targets: BitFlags::empty(),
85        }
86    }
87
88    /// Watch Futhark source file for changes.
89    ///
90    /// Enabled by default.
91    pub fn watch_sources(&mut self, watch: bool) -> &mut Self {
92        self.watch = watch;
93        self
94    }
95
96    /// Specify a custom CUDA home path.
97    ///
98    /// This will add the following:
99    /// - `$cuda_home/include` to include path
100    /// - `$cuda_home/lib64` to link path
101    pub fn with_cuda_home(&mut self, cuda_home: impl Into<PathBuf>) -> Result<&mut Self> {
102        let cuda_home = cuda_home.into();
103        ensure!(
104            cuda_home.to_str().is_some(),
105            "cuda_home must be representable using UTF8."
106        );
107        self.cuda_home = Some(cuda_home);
108        Ok(self)
109    }
110
111    /// Enable the given [Target].
112    pub fn with_target(&mut self, target: Target) -> &mut Self {
113        self.targets |= target;
114        self
115    }
116
117    /// Enable the given [Target] conditionally.
118    ///
119    /// This is especially useful with the [`cfg!`] macro.
120    pub fn with_target_if(&mut self, target: Target, condition: bool) -> &mut Self {
121        if condition {
122            self.targets |= target;
123        }
124        self
125    }
126
127    /// Run the generator.
128    pub fn run(&mut self) -> Result<()> {
129        ensure!(self.source.is_file(), "Futhark source file does not exist.");
130
131        ensure!(
132            !self.targets.is_empty(),
133            "At least one target must be built."
134        );
135
136        self.build_targets().wrap_err("Failed to build targets.")?;
137
138        self.generate_library()
139            .wrap_err("Failed to generate Rust library.")?;
140
141        Ok(())
142    }
143}
144
145impl Generator {
146    fn generate_library(&mut self) -> Result<(), eyre::ErrReport> {
147        let any_target = self.targets.iter().next().unwrap();
148        let manifest_path = cargo_out_dir()?
149            .join(names::TARGET_DIR)
150            .join(any_target.name())
151            .join(names::MANIFEST);
152        let manifest = Manifest::from_json_file(&manifest_path).wrap_err_with(|| {
153            format!(
154                "Failed to load manifest file at {}.",
155                manifest_path.display()
156            )
157        })?;
158
159        let rust_lib = template::combined(&manifest, self.targets).to_string();
160        let rust_lib_path = cargo_out_dir()?
161            .join(names::TARGET_DIR)
162            .join(names::RS_FILE);
163
164        let mut rust_lib_file = fs::File::create(&rust_lib_path)
165            .wrap_err("Failed to create generated Rust library file.")?;
166        writeln!(rust_lib_file, "{}", rust_lib)
167            .wrap_err("Failed to write generated Rust library.")?;
168        rust_lib_file.flush().wrap_err("Failed to flush.")?;
169
170        let rustfmt_status = Command::new("rustfmt")
171            .arg(rust_lib_path)
172            .status()
173            .wrap_err("Failed to run rustfmt.")?
174            .success();
175
176        if !rustfmt_status {
177            bail!("Failed to format generated Rust library.");
178        };
179
180        Ok(())
181    }
182
183    fn build_targets(&self) -> Result<()> {
184        if self.watch {
185            watch_source(&self.source).wrap_err("Failed to watch source files for changes.")?;
186        }
187
188        if self.targets.contains(Target::C) {
189            self.build_target(Target::C)
190                .wrap_err("Failed to build C target.")?;
191        }
192
193        if self.targets.contains(Target::MultiCore) {
194            self.build_target(Target::MultiCore)
195                .wrap_err("Failed to build Multi-Core target.")?;
196        }
197
198        if self.targets.contains(Target::OpenCL) {
199            self.build_target(Target::OpenCL)
200                .wrap_err("Failed to build OpenCL target.")?;
201
202            println!("cargo:rustc-link-lib=OpenCL");
203        }
204
205        if self.targets.contains(Target::Cuda) {
206            self.build_target(Target::Cuda)
207                .wrap_err("Failed to build Cuda target.")?;
208
209            println!("cargo:rustc-link-lib=cuda");
210            println!("cargo:rustc-link-lib=cudart");
211            println!("cargo:rustc-link-lib=nvrtc");
212
213            if let Some(cuda_home) = &self.cuda_home {
214                let cuda_lib64 = cuda_home.join("lib64");
215
216                println!("cargo:rustc-link-search={}", cuda_lib64.to_str().unwrap());
217            }
218        }
219
220        Ok(())
221    }
222
223    fn build_target(&self, target: Target) -> Result<()> {
224        let out_dir = cargo_out_dir()?;
225        let target_dir = out_dir.join(names::TARGET_DIR).join(target.name());
226        fs::create_dir_all(&target_dir).wrap_err("Could not create target dir.")?;
227
228        let raw_target_dir = out_dir.join(names::RAW_TARGET_DIR).join(target.name());
229        fs::create_dir_all(&raw_target_dir).wrap_err("Could not create raw target dir.")?;
230
231        let futhark_status = Command::new("futhark")
232            .args([target.name(), "--library", "-o"])
233            .arg(raw_target_dir.join(names::LIBRARY))
234            .arg(self.source.as_os_str())
235            .status()
236            .wrap_err("Failed to run Futhark compiler.")?
237            .success();
238
239        if !futhark_status {
240            bail!("Failed to compile Futhark code.");
241        }
242
243        fs::copy(
244            raw_target_dir.join(names::MANIFEST),
245            target_dir.join(names::MANIFEST),
246        )
247        .wrap_err("Failed to copy manifest file")?;
248
249        let prefix = format!("futhark_{target}_");
250
251        prefix_items(
252            &prefix,
253            raw_target_dir.join(names::H_FILE),
254            target_dir.join(names::H_FILE),
255        )
256        .wrap_err("Failed to prefix header file items.")?;
257
258        prefix_items(
259            &prefix,
260            raw_target_dir.join(names::C_FILE),
261            target_dir.join(names::C_FILE),
262        )
263        .wrap_err("Failed to prefix source file items.")?;
264
265        let cuda_include_path = match (target, &self.cuda_home) {
266            (Target::Cuda, Some(cuda_home)) => Some(cuda_home.join("include")),
267            _ => None,
268        };
269
270        let cuda_include_flag = cuda_include_path
271            .as_ref()
272            .map(|path| format!("-I{}", path.to_str().unwrap()));
273
274        bindgen::Builder::default()
275            .clang_args(cuda_include_flag)
276            .header(target_dir.join(names::H_FILE).to_string_lossy())
277            .allowlist_function("free")
278            .allowlist_function("futhark_.*")
279            .allowlist_type("futhark_.*")
280            .parse_callbacks(Box::new(bindgen::CargoCallbacks))
281            .parse_callbacks(PrefixRemover::new(prefix))
282            .generate()
283            .wrap_err("Failed to generate bindings.")?
284            .write_to_file(target_dir.join(names::RS_FILE))
285            .wrap_err("Failed to write bindings to file.")?;
286
287        cc::Build::new()
288            .file(target_dir.join(names::C_FILE))
289            .includes(cuda_include_path)
290            .static_flag(true)
291            .warnings(false)
292            .try_compile(&format!("futhark-lib-{compiler}", compiler = target))
293            .wrap_err("Failed to compile the generated c code.")?;
294
295        Ok(())
296    }
297}
298
299fn watch_source(source: &Path) -> Result<()> {
300    let old_manifest_dir = cargo_manifest_dir()?;
301
302    env::set_var("CARGO_MANIFEST_DIR", source.parent().unwrap().as_os_str());
303
304    rerun_except(&[])
305        .map_err(|err| eyre::eyre!("{}", err))
306        .wrap_err("Failed to watch files.")?;
307
308    env::set_var("CARGO_MANIFEST_DIR", old_manifest_dir);
309
310    Ok(())
311}
312
313fn prefix_items(prefix: &str, input: impl AsRef<Path>, output: impl AsRef<Path>) -> Result<()> {
314    let mut out = BufWriter::new(File::create(output).wrap_err("Failed to open output file.")?);
315
316    let memblock_prefix = &format!("{prefix}_memblock_");
317    let lexical_realloc_error_prefix = &format!("{prefix}_lexical_realloc_error");
318
319    for line in fs::read_to_string(input)?.lines() {
320        let new_line = line
321            .replace("memblock_", memblock_prefix)
322            .replace("lexical_realloc_error", lexical_realloc_error_prefix)
323            .replace("futhark_", prefix);
324
325        writeln!(out, "{}", new_line).wrap_err("Failed to write line to output file.")?;
326    }
327
328    out.flush().wrap_err("Failed to flush output file.")?;
329
330    Ok(())
331}
332
/// Bindgen callback that renames items carrying a target-specific prefix.
#[derive(Debug)]
struct PrefixRemover {
    /// The text that is replaced by `futhark_` in item names.
    prefix: String,
}
337
338impl PrefixRemover {
339    #[allow(clippy::new_ret_no_self)]
340    pub fn new(prefix: impl ToOwned<Owned = String>) -> Box<dyn ParseCallbacks> {
341        Box::new(PrefixRemover {
342            prefix: prefix.to_owned(),
343        })
344    }
345}
346
347impl ParseCallbacks for PrefixRemover {
348    fn item_name(&self, original_item_name: &str) -> Option<String> {
349        if original_item_name.contains(&self.prefix) {
350            return Some(original_item_name.replace(&self.prefix, "futhark_"));
351        }
352
353        None
354    }
355}