use super::{result, sys};
use core::ffi::{c_char, CStr};
use std::ffi::CString;
use std::{borrow::ToOwned, path::PathBuf, string::String, vec::Vec};
/// A PTX payload.
///
/// The wrapped [`PtxKind`] records which form the payload takes: an image
/// produced by compilation, source text, or a path to a file.
#[derive(Debug, Clone)]
pub struct Ptx(pub(crate) PtxKind);
impl Ptx {
pub fn from_file<P: Into<PathBuf>>(path: P) -> Self {
Self(PtxKind::File(path.into()))
}
pub fn from_src<S: Into<String>>(src: S) -> Self {
Self(PtxKind::Src(src.into()))
}
}
impl<S: Into<String>> From<S> for Ptx {
fn from(value: S) -> Self {
Self::from_src(value)
}
}
/// The forms a [`Ptx`] value can take.
#[derive(Debug, Clone)]
pub(crate) enum PtxKind {
/// Buffer returned by `result::get_ptx` after a successful compile.
Image(Vec<c_char>),
/// Source text supplied through [`Ptx::from_src`].
Src(String),
/// Path supplied through [`Ptx::from_file`].
File(PathBuf),
}
pub fn compile_ptx<S: AsRef<str>>(src: S) -> Result<Ptx, CompileError> {
compile_ptx_with_opts(src, Default::default())
}
/// Compiles `src` using the given `opts`.
///
/// A `Program` is created from `src` and then compiled; the program is
/// consumed by the compile step.
///
/// # Errors
/// Returns the [`CompileError`] from program creation or from compilation,
/// whichever fails first.
pub fn compile_ptx_with_opts<S: AsRef<str>>(
    src: S,
    opts: CompileOptions,
) -> Result<Ptx, CompileError> {
    Program::create(src).and_then(|program| program.compile(opts))
}
/// Owner of an NVRTC program handle; the handle is destroyed by the `Drop` impl.
pub(crate) struct Program {
/// Raw handle obtained from `result::create_program`.
prog: sys::nvrtcProgram,
}
impl Program {
    /// Creates an NVRTC program from the source text `src`.
    ///
    /// # Errors
    /// Returns [`CompileError::CreationError`] wrapping the error reported by
    /// `result::create_program`.
    pub(crate) fn create<S: AsRef<str>>(src: S) -> Result<Self, CompileError> {
        result::create_program(src)
            .map(|prog| Self { prog })
            .map_err(CompileError::CreationError)
    }

    /// Compiles the program with `opts` and returns the resulting PTX image.
    ///
    /// Consumes `self`, so the program handle is released by `Drop` when this
    /// method returns.
    ///
    /// # Errors
    /// - [`CompileError::CompileError`] when compilation fails; it carries the
    ///   NVRTC error, the option strings that were used, and the program log.
    /// - [`CompileError::GetLogError`] when compilation fails and the program
    ///   log cannot be retrieved either (instead of panicking).
    /// - [`CompileError::GetPtxError`] when compilation succeeds but the PTX
    ///   cannot be retrieved.
    pub(crate) fn compile(self, opts: CompileOptions) -> Result<Ptx, CompileError> {
        let options = opts.build();

        // SAFETY: `self.prog` was returned by `result::create_program` in
        // `create` and is only destroyed when `self` is dropped.
        let compiled = unsafe { result::compile_program(self.prog, &options) };

        if let Err(nvrtc) = compiled {
            // SAFETY: same handle invariant as above.
            let log_raw = unsafe { result::get_program_log(self.prog) }
                .map_err(CompileError::GetLogError)?;
            // SAFETY: assumes the buffer returned by `get_program_log` is
            // NUL-terminated — TODO confirm against the `result` module.
            // `log_raw` outlives the borrow because the `CStr` is copied into
            // an owned `CString` right away.
            let log = unsafe { CStr::from_ptr(log_raw.as_ptr()) }.to_owned();
            return Err(CompileError::CompileError {
                nvrtc,
                options,
                log,
            });
        }

        // SAFETY: same handle invariant as above; compilation succeeded.
        let image = unsafe { result::get_ptx(self.prog) }.map_err(CompileError::GetPtxError)?;
        Ok(Ptx(PtxKind::Image(image)))
    }
}
impl Drop for Program {
    /// Destroys the underlying program handle, if one is still held.
    ///
    /// # Panics
    /// Panics if `result::destroy_program` returns an error.
    fn drop(&mut self) {
        // Swap in a null handle so the stored value is never destroyed twice.
        let handle = std::mem::replace(&mut self.prog, std::ptr::null_mut());
        if handle.is_null() {
            return;
        }
        unsafe { result::destroy_program(handle) }.unwrap();
    }
}
/// Errors produced while creating, compiling, or reading back an NVRTC program.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompileError {
/// Program creation failed; built from the error of `result::create_program`.
CreationError(result::NvrtcError),
/// Compilation failed.
CompileError {
/// Error returned by `result::compile_program`.
nvrtc: result::NvrtcError,
/// Option strings that were passed to the compiler.
options: Vec<String>,
/// Program log read back after the failed compile.
log: CString,
},
/// NOTE(review): presumably meant for failures while fetching the program log — confirm.
GetLogError(result::NvrtcError),
/// Retrieving the PTX failed; built from the error of `result::get_ptx`.
GetPtxError(result::NvrtcError),
/// NOTE(review): presumably meant for failures while destroying the program;
/// the `Drop` impl currently unwraps instead — confirm.
DestroyError(result::NvrtcError),
}
/// `Display` for [`CompileError`] reuses the `Debug` representation.
#[cfg(feature = "std")]
impl std::fmt::Display for CompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Go through `format_args!` rather than calling `Debug::fmt` directly,
        // so flags given to the outer formatter are not forwarded.
        f.write_fmt(format_args!("{:?}", self))
    }
}
// Marker impl: relies entirely on the trait's default methods. Only built
// when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for CompileError {}
/// Options turned into NVRTC command-line style flags by [`CompileOptions::build`].
///
/// A field left at `None` (or an empty `include_paths`) contributes no flag,
/// so the default value builds an empty option list.
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)]
pub struct CompileOptions {
    /// Emitted as `--ftz=<value>` when set.
    pub ftz: Option<bool>,
    /// Emitted as `--prec-sqrt=<value>` when set.
    pub prec_sqrt: Option<bool>,
    /// Emitted as `--prec-div=<value>` when set.
    pub prec_div: Option<bool>,
    /// Emitted as `--fmad=<value>` when set.
    pub fmad: Option<bool>,
    /// Emitted as `--use_fast_math` when `Some(true)`; `None` and
    /// `Some(false)` emit nothing.
    pub use_fast_math: Option<bool>,
    /// Emitted as `--maxrregcount=<value>` when set.
    pub maxrregcount: Option<usize>,
    /// Each entry is emitted as `--include-path=<entry>`, in order.
    pub include_paths: Vec<String>,
    /// Emitted as `--gpu-architecture=<value>` when set.
    pub arch: Option<&'static str>,
}

impl CompileOptions {
    /// Converts the options into the list of flag strings handed to the compiler.
    ///
    /// Flags appear in a fixed order: `ftz`, `prec_sqrt`, `prec_div`, `fmad`,
    /// `use_fast_math`, `maxrregcount`, every include path, then `arch`.
    pub(crate) fn build(self) -> Vec<String> {
        let mut options: Vec<String> = Vec::new();

        if let Some(v) = self.ftz {
            options.push(std::format!("--ftz={v}"));
        }
        if let Some(v) = self.prec_sqrt {
            options.push(std::format!("--prec-sqrt={v}"));
        }
        if let Some(v) = self.prec_div {
            options.push(std::format!("--prec-div={v}"));
        }
        if let Some(v) = self.fmad {
            options.push(std::format!("--fmad={v}"));
        }
        // Fast math has its own flag. Emitting `--fmad=true` here (as before)
        // only turned on fused multiply-add and silently skipped the rest of
        // what fast math enables.
        if self.use_fast_math == Some(true) {
            options.push("--use_fast_math".into());
        }
        if let Some(count) = self.maxrregcount {
            options.push(std::format!("--maxrregcount={count}"));
        }
        options.extend(
            self.include_paths
                .into_iter()
                .map(|path| std::format!("--include-path={path}")),
        );
        if let Some(arch) = self.arch {
            options.push(std::format!("--gpu-architecture={arch}"));
        }

        options
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A small kernel compiles with default options.
    #[test]
    fn test_compile_no_opts() {
        const SRC: &str =
            "extern \"C\" __global__ void sin_kernel(float *out, const float *inp, int numel) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < numel) {
out[i] = sin(inp[i]);
}
}";
        let _ptx = compile_ptx_with_opts(SRC, CompileOptions::default()).unwrap();
    }

    /// Default options produce no flags at all.
    #[test]
    fn test_compile_options_build_none() {
        let flags = CompileOptions::default().build();
        assert!(flags.is_empty());
    }

    /// A single boolean option produces exactly one flag.
    #[test]
    fn test_compile_options_build_ftz() {
        let mut opts = CompileOptions::default();
        opts.ftz = Some(true);
        assert_eq!(opts.build(), ["--ftz=true"]);
    }

    /// Several options produce their flags in the builder's fixed order.
    #[test]
    fn test_compile_options_build_multi() {
        let mut opts = CompileOptions::default();
        opts.prec_div = Some(false);
        opts.maxrregcount = Some(60);
        assert_eq!(opts.build(), ["--prec-div=false", "--maxrregcount=60"]);
    }
}