1use std::ffi::OsStr;
2use std::path::{Path, PathBuf};
3use std::process::{Command, Output};
4
/// Returns the canonicalized path stored in the `DEP_NVBIT_INCLUDE`
/// environment variable.
///
/// # Panics
///
/// Panics if `DEP_NVBIT_INCLUDE` cannot be read (unset or not valid Unicode)
/// or if the path cannot be canonicalized.
#[inline]
#[must_use]
pub fn nvbit_include() -> PathBuf {
    let raw = std::env::var("DEP_NVBIT_INCLUDE").expect("nvbit include path");
    Path::new(&raw).canonicalize().expect("canonicalize path")
}
21
/// Returns the canonicalized path stored in the `OUT_DIR` environment
/// variable.
///
/// # Panics
///
/// Panics if `OUT_DIR` cannot be read (unset or not valid Unicode) or if the
/// path cannot be canonicalized.
#[inline]
#[must_use]
pub fn output_path() -> PathBuf {
    let raw = std::env::var("OUT_DIR").expect("cargo out dir");
    Path::new(&raw).canonicalize().expect("canonicalize path")
}
35
/// Returns the canonicalized path stored in the `CARGO_MANIFEST_DIR`
/// environment variable.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` cannot be read (unset or not valid Unicode)
/// or if the path cannot be canonicalized.
#[inline]
#[must_use]
pub fn manifest_path() -> PathBuf {
    let raw = std::env::var("CARGO_MANIFEST_DIR").expect("cargo manifest dir");
    Path::new(&raw).canonicalize().expect("canonicalize path")
}
49
/// Errors returned by the [`Build`] compilation steps.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// An I/O error, converted from [`std::io::Error`] via `?`
    /// (for example when a command cannot be run at all).
    #[error(transparent)]
    Io(#[from] std::io::Error),
    /// A command ran but reported an unsuccessful exit status.
    ///
    /// Carries the captured [`Output`] of that command so callers can inspect
    /// its status, stdout and stderr; the display message itself does not
    /// include them.
    #[error("Command failed")]
    Command(Output),
}
57
/// Configuration for compiling sources with nvcc and archiving the resulting
/// objects into a static library (see `Build::compile`).
#[derive(Debug, Clone)]
pub struct Build {
    /// Directories passed to nvcc as `-I<dir>`.
    include_directories: Vec<PathBuf>,
    /// Already-built object files; `compile` starts its object list from a
    /// clone of these.
    objects: Vec<PathBuf>,
    /// Sources compiled with `-dc -c`.
    sources: Vec<PathBuf>,
    /// Sources compiled with the instrumentation-specific flag set
    /// (`-maxrregcount=24 -arch=sm_35 -Xptxas -astoolspatch
    /// --keep-device-functions`).
    instrumentation_sources: Vec<PathBuf>,
    /// Extra flags appended to every nvcc invocation.
    compiler_flags: Vec<String>,
    /// Optional host compiler, forwarded to nvcc via `-ccbin`.
    host_compiler: Option<PathBuf>,
    /// Optional nvcc executable; `compile` falls back to `nvcc` when unset.
    nvcc_compiler: Option<PathBuf>,
    /// When set, `-Xcompiler -Wall` is added to the nvcc flags.
    warnings: bool,
    /// When set, `-Xcompiler -Werror` is added to the nvcc flags.
    warnings_as_errors: bool,
}
70
71impl Default for Build {
72 fn default() -> Self {
73 Self::new()
74 }
75}
76
77impl Build {
78 #[inline]
79 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 include_directories: Vec::new(),
83 objects: Vec::new(),
84 sources: Vec::new(),
85 instrumentation_sources: Vec::new(),
86 compiler_flags: Vec::new(),
87 host_compiler: None,
88 nvcc_compiler: None,
89 warnings: false,
90 warnings_as_errors: false,
91 }
92 }
93
94 fn compile_instrumentation_functions(
95 &self,
96 nvcc_compiler: &Path,
97 include_args: &[String],
98 compiler_flags: &[&str],
99 objects: &mut Vec<PathBuf>,
100 ) -> Result<(), Error> {
101 for (i, src) in self.instrumentation_sources.iter().enumerate() {
102 let default_name = format!("instr_src_{i}");
103 let obj = output_path()
104 .join(
105 src.file_name()
106 .and_then(OsStr::to_str)
107 .unwrap_or(&default_name),
108 )
109 .with_extension("o");
110 let mut cmd = Command::new(nvcc_compiler);
111 if let Some(host_compiler) = &self.host_compiler {
112 cmd.args(["-ccbin", &*host_compiler.to_string_lossy()]);
113 }
114 cmd.args(include_args);
115 cmd.args([
116 "-maxrregcount=24",
117 "-arch=sm_35",
118 "-Xptxas",
119 "-astoolspatch",
120 "--keep-device-functions",
121 ])
122 .args(compiler_flags)
123 .args(&self.compiler_flags)
124 .arg("-c")
125 .arg(src)
126 .arg("-o")
127 .arg(&*obj.to_string_lossy());
128
129 println!("cargo:warning={cmd:?}");
130 let result = cmd.output()?;
131 if !result.status.success() {
132 return Err(Error::Command(result));
133 }
134 objects.push(obj);
135 }
136 Ok(())
137 }
138
139 fn compile_sources(
140 &self,
141 nvcc_compiler: &Path,
142 include_args: &[String],
143 compiler_flags: &[&str],
144 objects: &mut Vec<PathBuf>,
145 ) -> Result<(), Error> {
146 for (i, src) in self.sources.iter().enumerate() {
147 let default_name = format!("src_{i}");
148 let obj = output_path()
149 .join(
150 src.file_name()
151 .and_then(OsStr::to_str)
152 .unwrap_or(&default_name),
153 )
154 .with_extension("o");
155 let mut cmd = Command::new(nvcc_compiler);
156 if let Some(host_compiler) = &self.host_compiler {
157 cmd.args(["-ccbin", &*host_compiler.to_string_lossy()]);
158 }
159 cmd.args(include_args)
160 .args(compiler_flags)
161 .args(&self.compiler_flags)
162 .args(["-dc", "-c"])
163 .arg(src)
164 .arg("-o")
165 .arg(&*obj.to_string_lossy());
166 println!("cargo:warning={cmd:?}");
167 let result = cmd.output()?;
168 if !result.status.success() {
169 return Err(Error::Command(result));
170 }
171 objects.push(obj);
172 }
173 Ok(())
174 }
175
176 pub fn compile<O: AsRef<str>>(&self, output: O) -> Result<(), Error> {
181 let mut objects = self.objects.clone();
182 let include_args: Vec<_> = self
183 .include_directories
184 .iter()
185 .map(|d| format!("-I{}", &d.to_string_lossy()))
186 .collect();
187
188 let mut compiler_flags = vec!["-arch=sm_35", "-Xcompiler", "-fPIC"];
189 if self.warnings {
190 compiler_flags.extend(["-Xcompiler", "-Wall"]);
191 }
192 if self.warnings_as_errors {
193 compiler_flags.extend(["-Xcompiler", "-Werror"]);
194 }
195
196 let default_nvcc_compiler = PathBuf::from("nvcc");
197 let nvcc_compiler = self
198 .nvcc_compiler
199 .as_ref()
200 .unwrap_or(&default_nvcc_compiler);
201
202 self.compile_instrumentation_functions(
204 nvcc_compiler,
205 &include_args,
206 &compiler_flags,
207 &mut objects,
208 )?;
209
210 self.compile_sources(nvcc_compiler, &include_args, &compiler_flags, &mut objects)?;
212
213 let dev_link_obj = output_path().join("dev_link.o");
215 let mut cmd = Command::new(nvcc_compiler);
216 if let Some(host_compiler) = &self.host_compiler {
217 cmd.args(["-ccbin", &*host_compiler.to_string_lossy()]);
218 }
219
220 cmd.args(&include_args)
221 .args(&compiler_flags)
222 .args(&self.compiler_flags)
223 .arg("-dlink")
224 .args(&objects)
225 .arg("-o")
226 .arg(&*dev_link_obj.to_string_lossy());
227 println!("cargo:warning={cmd:?}");
228 let result = cmd.output()?;
229 if !result.status.success() {
230 return Err(Error::Command(result));
231 }
232 objects.push(dev_link_obj);
233
234 let mut cmd = Command::new("ar");
236 cmd.args([
237 "cru",
238 &output_path()
239 .join(format!("lib{}.a", output.as_ref()))
240 .to_string_lossy(),
241 ])
242 .args(&objects);
243 println!("cargo:warning={cmd:?}");
244 let result = cmd.output()?;
245 if !result.status.success() {
246 return Err(Error::Command(result));
247 }
248
249 println!("cargo:rustc-link-search=native={}", output_path().display());
250 println!(
251 "cargo:rustc-link-lib=static:+whole-archive={}",
252 output.as_ref()
253 );
254 Ok(())
255 }
256
257 pub fn host_compiler<P: Into<PathBuf>>(&mut self, compiler: P) -> &mut Self {
259 self.host_compiler = Some(compiler.into());
260 self
261 }
262
263 pub fn nvcc_compiler<P: Into<PathBuf>>(&mut self, compiler: P) -> &mut Self {
265 self.nvcc_compiler = Some(compiler.into());
266 self
267 }
268
269 pub fn object<P: Into<PathBuf>>(&mut self, obj: P) -> &mut Self {
270 self.objects.push(obj.into());
271 self
272 }
273
274 pub fn objects<P>(&mut self, objects: P) -> &mut Self
275 where
276 P: IntoIterator,
277 P::Item: Into<PathBuf>,
278 {
279 for obj in objects {
280 self.object(obj);
281 }
282 self
283 }
284
285 pub fn instrumentation_source<P: Into<PathBuf>>(&mut self, src: P) -> &mut Self {
286 self.instrumentation_sources.push(src.into());
287 self
288 }
289
290 pub fn instrumentation_sources<P>(&mut self, sources: P) -> &mut Self
291 where
292 P: IntoIterator,
293 P::Item: Into<PathBuf>,
294 {
295 for src in sources {
296 self.instrumentation_source(src);
297 }
298 self
299 }
300
301 pub fn source<P: Into<PathBuf>>(&mut self, dir: P) -> &mut Self {
302 self.sources.push(dir.into());
303 self
304 }
305
306 pub fn sources<P>(&mut self, sources: P) -> &mut Self
307 where
308 P: IntoIterator,
309 P::Item: Into<PathBuf>,
310 {
311 for src in sources {
312 self.source(src);
313 }
314 self
315 }
316
317 pub fn nvcc_flag<F: Into<String>>(&mut self, flag: F) -> &mut Build {
319 self.compiler_flags.push(flag.into());
320 self
321 }
322
323 pub fn host_compiler_flag<F: Into<String>>(&mut self, flag: F) -> &mut Build {
325 self.compiler_flags
326 .extend(["-Xcompiler".to_string(), flag.into()]);
327 self
328 }
329
330 pub fn nvcc_flags<I>(&mut self, flags: I) -> &mut Self
332 where
333 I: IntoIterator,
334 I::Item: Into<String>,
335 {
336 for flag in flags {
337 self.nvcc_flag(flag);
338 }
339 self
340 }
341
342 pub fn host_compiler_flags<I>(&mut self, flags: I) -> &mut Self
344 where
345 I: IntoIterator,
346 I::Item: Into<String>,
347 {
348 for flag in flags {
349 self.host_compiler_flag(flag);
350 }
351 self
352 }
353
354 pub fn include<P: Into<PathBuf>>(&mut self, dir: P) -> &mut Self {
355 self.include_directories.push(dir.into());
356 self
357 }
358
359 pub fn includes<P>(&mut self, dirs: P) -> &mut Self
360 where
361 P: IntoIterator,
362 P::Item: Into<PathBuf>,
363 {
364 for dir in dirs {
365 self.include(dir);
366 }
367 self
368 }
369
    /// Enables or disables host-compiler warnings.
    ///
    /// When enabled, `compile` adds `-Xcompiler -Wall` to the nvcc flags.
    /// Returns `self` for chaining.
    pub fn warnings(&mut self, enable: bool) -> &mut Self {
        self.warnings = enable;
        self
    }
374
    /// Enables or disables treating host-compiler warnings as errors.
    ///
    /// When enabled, `compile` adds `-Xcompiler -Werror` to the nvcc flags.
    /// Returns `self` for chaining.
    pub fn warnings_as_errors(&mut self, enable: bool) -> &mut Self {
        self.warnings_as_errors = enable;
        self
    }
379}