// nvbit_build/lib.rs

use std::ffi::{OsStr, OsString};
use std::path::{Path, PathBuf};
use std::process::{Command, Output};
4
/// Get the nvbit include dir.
///
/// **Note**: This function is intended to be used in the `build.rs` context.
///
/// This can be useful when your crate uses nvbit and requires access to
/// the nvbit header files.
///
/// The directory is read from the `DEP_NVBIT_INCLUDE` environment variable
/// (via [`std::env::var_os`], so non-UTF-8 paths are accepted) and returned
/// in canonical form.
///
/// # Panics
/// When the `DEP_NVBIT_INCLUDE` environment variable is not set, or when the
/// path it names cannot be canonicalized (for example, it does not exist).
#[inline]
#[must_use]
pub fn nvbit_include() -> PathBuf {
    PathBuf::from(std::env::var_os("DEP_NVBIT_INCLUDE").expect("nvbit include path"))
        .canonicalize()
        .expect("canonicalize path")
}
21
/// Get the cargo output directory.
///
/// **Note**: This function is intended to be used in the `build.rs` context.
///
/// The directory is read from the `OUT_DIR` environment variable (via
/// [`std::env::var_os`], so non-UTF-8 paths are accepted) and returned in
/// canonical form.
///
/// # Panics
/// When the `OUT_DIR` environment variable is not set, or when the path it
/// names cannot be canonicalized (for example, it does not exist).
#[inline]
#[must_use]
pub fn output_path() -> PathBuf {
    PathBuf::from(std::env::var_os("OUT_DIR").expect("cargo out dir"))
        .canonicalize()
        .expect("canonicalize path")
}
35
/// Get the cargo manifest directory.
///
/// **Note**: This function is intended to be used in the `build.rs` context.
///
/// The directory is read from the `CARGO_MANIFEST_DIR` environment variable
/// (via [`std::env::var_os`], so non-UTF-8 paths are accepted) and returned
/// in canonical form.
///
/// # Panics
/// When the `CARGO_MANIFEST_DIR` environment variable is not set, or when the
/// path it names cannot be canonicalized (for example, it does not exist).
#[inline]
#[must_use]
pub fn manifest_path() -> PathBuf {
    PathBuf::from(std::env::var_os("CARGO_MANIFEST_DIR").expect("cargo manifest dir"))
        .canonicalize()
        .expect("canonicalize path")
}
49
50#[derive(thiserror::Error, Debug)]
51pub enum Error {
52    #[error(transparent)]
53    Io(#[from] std::io::Error),
54    #[error("Command failed")]
55    Command(Output),
56}
57
/// Builder that compiles nvbit instrumentation code with `nvcc` and bundles
/// the result into a static library for cargo to link.
#[derive(Clone, Debug)]
pub struct Build {
    /// Directories forwarded to every `nvcc` invocation as `-I<dir>`.
    include_directories: Vec<PathBuf>,
    /// Precompiled object files added to the device link and the archive.
    objects: Vec<PathBuf>,
    /// Source files compiled as relocatable device code.
    sources: Vec<PathBuf>,
    /// Source files compiled with the instrumentation-specific flag set.
    instrumentation_sources: Vec<PathBuf>,
    /// Extra flags appended to every `nvcc` invocation.
    compiler_flags: Vec<String>,
    /// Optional host compiler, passed to `nvcc` via `-ccbin`.
    host_compiler: Option<PathBuf>,
    /// Optional `nvcc` executable; `nvcc` from `PATH` is used when unset.
    nvcc_compiler: Option<PathBuf>,
    /// Whether `-Wall` is passed to the host compiler.
    warnings: bool,
    /// Whether `-Werror` is passed to the host compiler.
    warnings_as_errors: bool,
}
70
71impl Default for Build {
72    fn default() -> Self {
73        Self::new()
74    }
75}
76
77impl Build {
78    #[inline]
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            include_directories: Vec::new(),
83            objects: Vec::new(),
84            sources: Vec::new(),
85            instrumentation_sources: Vec::new(),
86            compiler_flags: Vec::new(),
87            host_compiler: None,
88            nvcc_compiler: None,
89            warnings: false,
90            warnings_as_errors: false,
91        }
92    }
93
94    fn compile_instrumentation_functions(
95        &self,
96        nvcc_compiler: &Path,
97        include_args: &[String],
98        compiler_flags: &[&str],
99        objects: &mut Vec<PathBuf>,
100    ) -> Result<(), Error> {
101        for (i, src) in self.instrumentation_sources.iter().enumerate() {
102            let default_name = format!("instr_src_{i}");
103            let obj = output_path()
104                .join(
105                    src.file_name()
106                        .and_then(OsStr::to_str)
107                        .unwrap_or(&default_name),
108                )
109                .with_extension("o");
110            let mut cmd = Command::new(nvcc_compiler);
111            if let Some(host_compiler) = &self.host_compiler {
112                cmd.args(["-ccbin", &*host_compiler.to_string_lossy()]);
113            }
114            cmd.args(include_args);
115            cmd.args([
116                "-maxrregcount=24",
117                "-arch=sm_35",
118                "-Xptxas",
119                "-astoolspatch",
120                "--keep-device-functions",
121            ])
122            .args(compiler_flags)
123            .args(&self.compiler_flags)
124            .arg("-c")
125            .arg(src)
126            .arg("-o")
127            .arg(&*obj.to_string_lossy());
128
129            println!("cargo:warning={cmd:?}");
130            let result = cmd.output()?;
131            if !result.status.success() {
132                return Err(Error::Command(result));
133            }
134            objects.push(obj);
135        }
136        Ok(())
137    }
138
139    fn compile_sources(
140        &self,
141        nvcc_compiler: &Path,
142        include_args: &[String],
143        compiler_flags: &[&str],
144        objects: &mut Vec<PathBuf>,
145    ) -> Result<(), Error> {
146        for (i, src) in self.sources.iter().enumerate() {
147            let default_name = format!("src_{i}");
148            let obj = output_path()
149                .join(
150                    src.file_name()
151                        .and_then(OsStr::to_str)
152                        .unwrap_or(&default_name),
153                )
154                .with_extension("o");
155            let mut cmd = Command::new(nvcc_compiler);
156            if let Some(host_compiler) = &self.host_compiler {
157                cmd.args(["-ccbin", &*host_compiler.to_string_lossy()]);
158            }
159            cmd.args(include_args)
160                .args(compiler_flags)
161                .args(&self.compiler_flags)
162                .args(["-dc", "-c"])
163                .arg(src)
164                .arg("-o")
165                .arg(&*obj.to_string_lossy());
166            println!("cargo:warning={cmd:?}");
167            let result = cmd.output()?;
168            if !result.status.success() {
169                return Err(Error::Command(result));
170            }
171            objects.push(obj);
172        }
173        Ok(())
174    }
175
176    /// Compile and link static library with given name from inputs.
177    ///
178    /// # Errors
179    /// When compilation fails, an error is returned.
180    pub fn compile<O: AsRef<str>>(&self, output: O) -> Result<(), Error> {
181        let mut objects = self.objects.clone();
182        let include_args: Vec<_> = self
183            .include_directories
184            .iter()
185            .map(|d| format!("-I{}", &d.to_string_lossy()))
186            .collect();
187
188        let mut compiler_flags = vec!["-arch=sm_35", "-Xcompiler", "-fPIC"];
189        if self.warnings {
190            compiler_flags.extend(["-Xcompiler", "-Wall"]);
191        }
192        if self.warnings_as_errors {
193            compiler_flags.extend(["-Xcompiler", "-Werror"]);
194        }
195
196        let default_nvcc_compiler = PathBuf::from("nvcc");
197        let nvcc_compiler = self
198            .nvcc_compiler
199            .as_ref()
200            .unwrap_or(&default_nvcc_compiler);
201
202        // compile instrumentation functions
203        self.compile_instrumentation_functions(
204            nvcc_compiler,
205            &include_args,
206            &compiler_flags,
207            &mut objects,
208        )?;
209
210        // compile sources
211        self.compile_sources(nvcc_compiler, &include_args, &compiler_flags, &mut objects)?;
212
213        // link device functions
214        let dev_link_obj = output_path().join("dev_link.o");
215        let mut cmd = Command::new(nvcc_compiler);
216        if let Some(host_compiler) = &self.host_compiler {
217            cmd.args(["-ccbin", &*host_compiler.to_string_lossy()]);
218        }
219
220        cmd.args(&include_args)
221            .args(&compiler_flags)
222            .args(&self.compiler_flags)
223            .arg("-dlink")
224            .args(&objects)
225            .arg("-o")
226            .arg(&*dev_link_obj.to_string_lossy());
227        println!("cargo:warning={cmd:?}");
228        let result = cmd.output()?;
229        if !result.status.success() {
230            return Err(Error::Command(result));
231        }
232        objects.push(dev_link_obj);
233
234        // link everything together
235        let mut cmd = Command::new("ar");
236        cmd.args([
237            "cru",
238            &output_path()
239                .join(format!("lib{}.a", output.as_ref()))
240                .to_string_lossy(),
241        ])
242        .args(&objects);
243        println!("cargo:warning={cmd:?}");
244        let result = cmd.output()?;
245        if !result.status.success() {
246            return Err(Error::Command(result));
247        }
248
249        println!("cargo:rustc-link-search=native={}", output_path().display());
250        println!(
251            "cargo:rustc-link-lib=static:+whole-archive={}",
252            output.as_ref()
253        );
254        Ok(())
255    }
256
257    /// Configures the host compiler to be used to produce output.
258    pub fn host_compiler<P: Into<PathBuf>>(&mut self, compiler: P) -> &mut Self {
259        self.host_compiler = Some(compiler.into());
260        self
261    }
262
263    /// Configures the host compiler to be used to produce output.
264    pub fn nvcc_compiler<P: Into<PathBuf>>(&mut self, compiler: P) -> &mut Self {
265        self.nvcc_compiler = Some(compiler.into());
266        self
267    }
268
269    pub fn object<P: Into<PathBuf>>(&mut self, obj: P) -> &mut Self {
270        self.objects.push(obj.into());
271        self
272    }
273
274    pub fn objects<P>(&mut self, objects: P) -> &mut Self
275    where
276        P: IntoIterator,
277        P::Item: Into<PathBuf>,
278    {
279        for obj in objects {
280            self.object(obj);
281        }
282        self
283    }
284
285    pub fn instrumentation_source<P: Into<PathBuf>>(&mut self, src: P) -> &mut Self {
286        self.instrumentation_sources.push(src.into());
287        self
288    }
289
290    pub fn instrumentation_sources<P>(&mut self, sources: P) -> &mut Self
291    where
292        P: IntoIterator,
293        P::Item: Into<PathBuf>,
294    {
295        for src in sources {
296            self.instrumentation_source(src);
297        }
298        self
299    }
300
301    pub fn source<P: Into<PathBuf>>(&mut self, dir: P) -> &mut Self {
302        self.sources.push(dir.into());
303        self
304    }
305
306    pub fn sources<P>(&mut self, sources: P) -> &mut Self
307    where
308        P: IntoIterator,
309        P::Item: Into<PathBuf>,
310    {
311        for src in sources {
312            self.source(src);
313        }
314        self
315    }
316
317    /// Add an arbitrary flag to the invocation of nvcc.
318    pub fn nvcc_flag<F: Into<String>>(&mut self, flag: F) -> &mut Build {
319        self.compiler_flags.push(flag.into());
320        self
321    }
322
323    /// Add an arbitrary flag to the invocation of the host compiler.
324    pub fn host_compiler_flag<F: Into<String>>(&mut self, flag: F) -> &mut Build {
325        self.compiler_flags
326            .extend(["-Xcompiler".to_string(), flag.into()]);
327        self
328    }
329
330    /// Add arbitrary flags to the invocation of nvcc.
331    pub fn nvcc_flags<I>(&mut self, flags: I) -> &mut Self
332    where
333        I: IntoIterator,
334        I::Item: Into<String>,
335    {
336        for flag in flags {
337            self.nvcc_flag(flag);
338        }
339        self
340    }
341
342    /// Add arbitrary flags to the invocation of the host compiler.
343    pub fn host_compiler_flags<I>(&mut self, flags: I) -> &mut Self
344    where
345        I: IntoIterator,
346        I::Item: Into<String>,
347    {
348        for flag in flags {
349            self.host_compiler_flag(flag);
350        }
351        self
352    }
353
354    pub fn include<P: Into<PathBuf>>(&mut self, dir: P) -> &mut Self {
355        self.include_directories.push(dir.into());
356        self
357    }
358
359    pub fn includes<P>(&mut self, dirs: P) -> &mut Self
360    where
361        P: IntoIterator,
362        P::Item: Into<PathBuf>,
363    {
364        for dir in dirs {
365            self.include(dir);
366        }
367        self
368    }
369
370    pub fn warnings(&mut self, enable: bool) -> &mut Self {
371        self.warnings = enable;
372        self
373    }
374
375    pub fn warnings_as_errors(&mut self, enable: bool) -> &mut Self {
376        self.warnings_as_errors = enable;
377        self
378    }
379}