use std::ffi::{CString, c_void};
use crate::error::{CudaError, CudaResult};
use crate::ffi::{CUfunction, CUjit_option, CUmodule};
use crate::loader::try_driver;
/// Settings forwarded to the driver's JIT compiler when a module is loaded
/// through `Module::from_ptx_with_options`.
#[derive(Debug, Clone)]
pub struct JitOptions {
    /// Value sent as `CUjit_option::MaxRegisters`; `0` means the option is
    /// not sent at all.
    pub max_registers: u32,
    /// Value sent as `CUjit_option::OptimizationLevel` (always sent).
    pub optimization_level: u32,
    /// When `true`, `CUjit_option::GenerateDebugInfo` is requested.
    pub generate_debug_info: bool,
    /// When `true`, `CUjit_option::TargetFromCuContext` is requested.
    pub target_from_context: bool,
}

impl Default for JitOptions {
    /// No register cap, optimization level 4, no debug info, and the target
    /// taken from the current context.
    fn default() -> Self {
        JitOptions {
            target_from_context: true,
            generate_debug_info: false,
            optimization_level: 4,
            max_registers: 0,
        }
    }
}
/// Severity class of a parsed `ptxas` diagnostic.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitSeverity {
    /// Line tagged `fatal`.
    Fatal,
    /// Line tagged `error`.
    Error,
    /// Line tagged `warning`.
    Warning,
    /// Line tagged `info`.
    Info,
}

impl std::fmt::Display for JitSeverity {
    /// Writes the lower-case severity name, ignoring width/padding flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            JitSeverity::Fatal => "fatal",
            JitSeverity::Error => "error",
            JitSeverity::Warning => "warning",
            JitSeverity::Info => "info",
        };
        f.write_str(label)
    }
}
/// One diagnostic parsed from a single log line that starts with `ptxas `.
#[derive(Debug, Clone)]
pub struct JitDiagnostic {
/// Severity derived from the first word after the `ptxas ` prefix.
pub severity: JitSeverity,
/// Text found between single quotes at the start of the line body, if any.
pub kernel: Option<String>,
/// Number following a `line ` marker, if one was found and parsed.
pub line: Option<u32>,
/// Remaining text of the line once the prefix parts have been consumed.
pub message: String,
}
/// Text captured from the JIT compiler's two log buffers.
///
/// `Module::from_ptx_with_options` fills both fields from fixed-size,
/// NUL-terminated buffers and trims surrounding whitespace.
#[derive(Debug, Clone, Default)]
pub struct JitLog {
/// Contents of the info log buffer.
pub info: String,
/// Contents of the error log buffer.
pub error: String,
}
impl JitLog {
#[must_use]
pub fn is_empty(&self) -> bool {
self.info.is_empty() && self.error.is_empty()
}
#[must_use]
pub fn has_errors(&self) -> bool {
!self.error.is_empty()
}
#[must_use]
pub fn parse_diagnostics(&self) -> Vec<JitDiagnostic> {
let mut out = Vec::new();
for line in self.error.lines().chain(self.info.lines()) {
if let Some(d) = parse_ptxas_line(line) {
out.push(d);
}
}
out
}
#[must_use]
pub fn errors(&self) -> Vec<JitDiagnostic> {
self.parse_diagnostics()
.into_iter()
.filter(|d| matches!(d.severity, JitSeverity::Error | JitSeverity::Fatal))
.collect()
}
#[must_use]
pub fn warnings(&self) -> Vec<JitDiagnostic> {
self.parse_diagnostics()
.into_iter()
.filter(|d| matches!(d.severity, JitSeverity::Warning))
.collect()
}
}
fn parse_ptxas_line(line: &str) -> Option<JitDiagnostic> {
let line = line.trim();
if line.is_empty() {
return None;
}
let rest = line.strip_prefix("ptxas ")?;
let (sev_str, after_sev) = split_first_word(rest.trim_start());
let severity = match sev_str.to_ascii_lowercase().trim_end_matches(':') {
"fatal" => JitSeverity::Fatal,
"error" => JitSeverity::Error,
"warning" => JitSeverity::Warning,
"info" => JitSeverity::Info,
_ => JitSeverity::Info,
};
let body = skip_colon(after_sev.trim_start());
let (kernel, after_kernel) = extract_kernel_name(body);
let (line_no, after_line) = extract_line_number(after_kernel);
let message = extract_message(after_line.trim());
Some(JitDiagnostic {
severity,
kernel,
line: line_no,
message: message.to_string(),
})
}
/// Splits `s` at its first whitespace character.
///
/// Returns `(word, rest)` where `rest` still begins with that whitespace.
/// When `s` contains no whitespace the result is `(s, "")`.
fn split_first_word(s: &str) -> (&str, &str) {
    s.find(char::is_whitespace)
        .map_or((s, ""), |cut| s.split_at(cut))
}
/// Returns the text after the first `:` in `s`, with leading whitespace
/// removed; returns `s` unchanged when it contains no colon.
///
/// The colon is searched for anywhere in `s`, so any text in front of it is
/// discarded as well.
fn skip_colon(s: &str) -> &str {
    match s.split_once(':') {
        Some((_, tail)) => tail.trim_start(),
        None => s,
    }
}
/// Pulls a leading single-quoted name off `s`.
///
/// Leading whitespace is ignored. On success returns the text between the
/// quotes together with whatever follows the closing quote. When `s` does
/// not start with a quote, or the quote is never closed, returns `None` and
/// the whitespace-trimmed input.
fn extract_kernel_name(s: &str) -> (Option<String>, &str) {
    let trimmed = s.trim_start();
    let Some(quoted) = trimmed.strip_prefix('\'') else {
        return (None, trimmed);
    };
    match quoted.split_once('\'') {
        Some((name, tail)) => (Some(name.to_owned()), tail),
        None => (None, trimmed),
    }
}
/// Pulls a `line <n>` marker off the front of `s`.
///
/// Leading commas, spaces and semicolons are skipped before looking for the
/// case-insensitive marker `line `. The whitespace-delimited token after the
/// marker supplies the number: every ASCII digit in it is kept and all other
/// characters (such as a trailing `;`) are dropped.
///
/// Returns the number and the text after that token, or `(None, s)` with the
/// untouched input when there is no marker or the digits do not fit a `u32`.
fn extract_line_number(s: &str) -> (Option<u32>, &str) {
    const MARKER: &str = "line ";

    let candidate = s.trim_start_matches(|c: char| matches!(c, ',' | ' ' | ';'));
    let has_marker = candidate
        .get(..MARKER.len())
        .is_some_and(|head| head.eq_ignore_ascii_case(MARKER));
    if !has_marker {
        return (None, s);
    }

    let (token, tail) = split_first_word(candidate[MARKER.len()..].trim_start());
    let digits: String = token.chars().filter(char::is_ascii_digit).collect();
    match digits.parse::<u32>() {
        Ok(number) => (Some(number), tail),
        Err(_) => (None, s),
    }
}
/// Strips a repeated severity tag from the front of a message.
///
/// If the first word of `s` (trailing colons ignored, case-insensitive) is
/// `error`, `warning`, `info` or `fatal`, the result is whatever `skip_colon`
/// yields for the remainder; otherwise `s` is returned unchanged.
fn extract_message(s: &str) -> &str {
    const SEVERITY_WORDS: [&str; 4] = ["error", "warning", "info", "fatal"];

    let (head, tail) = split_first_word(s);
    let tag = head.trim_end_matches(':');
    let is_severity = SEVERITY_WORDS
        .iter()
        .any(|word| tag.eq_ignore_ascii_case(word));

    if is_severity {
        skip_colon(tail.trim_start())
    } else {
        s
    }
}
/// Owning wrapper around a driver module handle.
///
/// The handle is passed to `cu_module_unload` when the value is dropped.
pub struct Module {
// Raw handle written by `cu_module_load_data` / `cu_module_load_data_ex`.
raw: CUmodule,
}
// SAFETY: NOTE(review): no justification is recorded for this impl. Presumably
// the driver handle is not bound to the thread that created it — confirm
// against the driver's threading rules before relying on it.
unsafe impl Send for Module {}
/// Size in bytes of each JIT log buffer (info and error) handed to the
/// driver by `Module::from_ptx_with_options`; the same value is passed as the
/// buffer-size option for both logs.
const JIT_LOG_BUFFER_SIZE: usize = 4096;
impl Module {
pub fn from_ptx(ptx: &str) -> CudaResult<Self> {
let api = try_driver()?;
let c_ptx = CString::new(ptx).map_err(|_| CudaError::InvalidValue)?;
let mut raw = CUmodule::default();
crate::cuda_call!((api.cu_module_load_data)(
&mut raw,
c_ptx.as_ptr().cast::<c_void>()
))?;
Ok(Self { raw })
}
pub fn from_ptx_with_options(ptx: &str, options: &JitOptions) -> CudaResult<(Self, JitLog)> {
let api = try_driver()?;
let c_ptx = CString::new(ptx).map_err(|_| CudaError::InvalidValue)?;
let mut info_buf: Vec<u8> = vec![0u8; JIT_LOG_BUFFER_SIZE];
let mut error_buf: Vec<u8> = vec![0u8; JIT_LOG_BUFFER_SIZE];
let mut opt_keys: Vec<CUjit_option> = Vec::with_capacity(8);
let mut opt_vals: Vec<*mut c_void> = Vec::with_capacity(8);
opt_keys.push(CUjit_option::InfoLogBuffer);
opt_vals.push(info_buf.as_mut_ptr().cast::<c_void>());
opt_keys.push(CUjit_option::InfoLogBufferSizeBytes);
opt_vals.push(JIT_LOG_BUFFER_SIZE as *mut c_void);
opt_keys.push(CUjit_option::ErrorLogBuffer);
opt_vals.push(error_buf.as_mut_ptr().cast::<c_void>());
opt_keys.push(CUjit_option::ErrorLogBufferSizeBytes);
opt_vals.push(JIT_LOG_BUFFER_SIZE as *mut c_void);
opt_keys.push(CUjit_option::OptimizationLevel);
opt_vals.push(options.optimization_level as *mut c_void);
if options.max_registers > 0 {
opt_keys.push(CUjit_option::MaxRegisters);
opt_vals.push(options.max_registers as *mut c_void);
}
if options.generate_debug_info {
opt_keys.push(CUjit_option::GenerateDebugInfo);
opt_vals.push(core::ptr::without_provenance_mut::<c_void>(1));
}
if options.target_from_context {
opt_keys.push(CUjit_option::TargetFromCuContext);
opt_vals.push(core::ptr::without_provenance_mut::<c_void>(1));
}
let num_options = opt_keys.len() as u32;
let mut raw = CUmodule::default();
let result = crate::cuda_call!((api.cu_module_load_data_ex)(
&mut raw,
c_ptx.as_ptr().cast::<c_void>(),
num_options,
opt_keys.as_mut_ptr(),
opt_vals.as_mut_ptr(),
));
let log = JitLog {
info: buf_to_string(&info_buf),
error: buf_to_string(&error_buf),
};
result?;
Ok((Self { raw }, log))
}
pub fn get_function(&self, name: &str) -> CudaResult<Function> {
let api = try_driver()?;
let c_name = CString::new(name).map_err(|_| CudaError::InvalidValue)?;
let mut raw = CUfunction::default();
crate::cuda_call!((api.cu_module_get_function)(
&mut raw,
self.raw,
c_name.as_ptr()
))?;
Ok(Function { raw })
}
#[inline]
pub fn raw(&self) -> CUmodule {
self.raw
}
}
impl Drop for Module {
    /// Unloads the module handle.
    ///
    /// Nothing is done when the driver cannot be obtained. A non-zero status
    /// from the unload call is only logged, since `drop` cannot return it.
    fn drop(&mut self) {
        let Ok(api) = try_driver() else {
            return;
        };
        // SAFETY: NOTE(review): relies on `self.raw` being the handle this
        // `Module` received from the driver and not yet unloaded — confirm
        // nothing else unloads it.
        let status = unsafe { (api.cu_module_unload)(self.raw) };
        if status == 0 {
            return;
        }
        tracing::warn!(
            cuda_error = status,
            module = ?self.raw,
            "cuModuleUnload failed during drop"
        );
    }
}
/// Copyable wrapper around a driver function handle, as produced by
/// `Module::get_function`.
///
/// NOTE(review): the type carries no lifetime linking it to the `Module` it
/// came from; presumably callers must keep that module alive while using the
/// handle — confirm.
#[derive(Debug, Clone, Copy)]
pub struct Function {
    raw: CUfunction,
}

impl Function {
    /// Returns the underlying driver function handle.
    #[inline]
    pub fn raw(&self) -> CUfunction {
        let Self { raw } = *self;
        raw
    }
}
/// Converts a NUL-terminated byte buffer into a trimmed `String`.
///
/// Only the bytes before the first NUL are used (the whole buffer when there
/// is none). Invalid UTF-8 is replaced lossily and surrounding whitespace is
/// removed.
fn buf_to_string(buf: &[u8]) -> String {
    // `split` always yields at least one chunk, even for an empty slice.
    let text_bytes = buf.split(|&byte| byte == 0).next().unwrap_or(buf);
    String::from_utf8_lossy(text_bytes).trim().to_owned()
}
// Unit tests for the log parser, the `JitLog` helpers, `buf_to_string` and
// the `JitSeverity` display text. None of them touch the driver.
#[cfg(test)]
mod tests {
use super::*;
// Empty and whitespace-only lines are not diagnostics.
#[test]
fn parse_blank_line_returns_none() {
assert!(parse_ptxas_line("").is_none());
assert!(parse_ptxas_line(" ").is_none());
}
// Lines that do not start with the `ptxas ` prefix are ignored.
#[test]
fn parse_non_ptxas_line_returns_none() {
assert!(parse_ptxas_line("nvcc error: something").is_none());
assert!(parse_ptxas_line(" error: foo").is_none());
}
// Full form: severity, quoted kernel, line number, repeated severity, message.
#[test]
fn parse_standard_error_with_kernel_and_line() {
let line = "ptxas error : 'vec_add', line 42; error : Unknown instruction 'xyz.f32'";
let d = parse_ptxas_line(line).expect("should parse");
assert_eq!(d.severity, JitSeverity::Error);
assert_eq!(d.kernel.as_deref(), Some("vec_add"));
assert_eq!(d.line, Some(42));
assert!(
d.message.contains("Unknown instruction"),
"msg: {}",
d.message
);
}
// Same full form with the `warning` severity.
#[test]
fn parse_warning_with_kernel_and_line() {
let line = "ptxas warning : 'my_kernel', line 7; warning : Double-precision instructions will be slow";
let d = parse_ptxas_line(line).expect("should parse");
assert_eq!(d.severity, JitSeverity::Warning);
assert_eq!(d.kernel.as_deref(), Some("my_kernel"));
assert_eq!(d.line, Some(7));
assert!(d.message.contains("Double-precision"), "msg: {}", d.message);
}
// Info line with a kernel name but no `line` marker.
#[test]
fn parse_info_register_usage() {
let line =
"ptxas info : 'reduce_kernel' used 32 registers, 0 bytes smem, 0 bytes cmem[0]";
let d = parse_ptxas_line(line).expect("should parse");
assert_eq!(d.severity, JitSeverity::Info);
assert_eq!(d.kernel.as_deref(), Some("reduce_kernel"));
assert!(d.message.contains("32 registers"), "msg: {}", d.message);
assert!(d.line.is_none());
}
// A quote later in the message must not be mistaken for a kernel name.
#[test]
fn parse_fatal_no_kernel() {
let line = "ptxas fatal : Unresolved extern function 'missing_func'";
let d = parse_ptxas_line(line).expect("should parse");
assert_eq!(d.severity, JitSeverity::Fatal);
assert!(d.kernel.is_none());
assert!(d.message.contains("Unresolved"), "msg: {}", d.message);
}
// Bare message: neither a kernel name nor a line number.
#[test]
fn parse_error_no_kernel_no_line() {
let line = "ptxas error : syntax error near token ';'";
let d = parse_ptxas_line(line).expect("should parse");
assert_eq!(d.severity, JitSeverity::Error);
assert!(d.kernel.is_none());
assert!(d.line.is_none());
assert!(d.message.contains("syntax error"), "msg: {}", d.message);
}
// A default log has no text in either buffer.
#[test]
fn jitlog_is_empty_for_default() {
let log = JitLog::default();
assert!(log.is_empty());
assert!(!log.has_errors());
}
// `has_errors` depends only on the error text being non-empty.
#[test]
fn jitlog_has_errors_when_error_buf_nonempty() {
let log = JitLog {
info: String::new(),
error: "ptxas error : something went wrong".to_string(),
};
assert!(log.has_errors());
assert!(!log.is_empty());
}
// Diagnostics come out in order: error-buffer lines first, then info-buffer lines.
#[test]
fn jitlog_parse_diagnostics_multiline() {
let log = JitLog {
error: concat!(
"ptxas error : 'k1', line 5; error : bad opcode\n",
"ptxas warning : 'k1', line 8; warning : slow path\n",
)
.to_string(),
info: "ptxas info : 'k1' used 8 registers, 0 bytes smem\n".to_string(),
};
let diags = log.parse_diagnostics();
assert_eq!(diags.len(), 3);
assert_eq!(diags[0].severity, JitSeverity::Error);
assert_eq!(diags[1].severity, JitSeverity::Warning);
assert_eq!(diags[2].severity, JitSeverity::Info);
}
// `errors` drops warnings and info lines.
#[test]
fn jitlog_errors_filter() {
let log = JitLog {
error: concat!(
"ptxas error : 'k', line 1; error : bad\n",
"ptxas warning : 'k', line 2; warning : slow\n",
)
.to_string(),
info: "ptxas info : 'k' used 4 registers\n".to_string(),
};
let errs = log.errors();
assert_eq!(errs.len(), 1);
assert_eq!(errs[0].severity, JitSeverity::Error);
}
// `warnings` finds warnings even when they sit in the error text.
#[test]
fn jitlog_warnings_filter() {
let log = JitLog {
error: "ptxas warning : 'k', line 3; warning : something slow\n".to_string(),
info: String::new(),
};
let warns = log.warnings();
assert_eq!(warns.len(), 1);
assert_eq!(warns[0].severity, JitSeverity::Warning);
assert_eq!(warns[0].line, Some(3));
}
// Text stops at the first NUL byte.
#[test]
fn buf_to_string_null_terminated() {
let mut buf = b"hello\0\0\0".to_vec();
buf.extend_from_slice(&[0u8; 100]);
assert_eq!(buf_to_string(&buf), "hello");
}
// An all-zero buffer yields an empty string.
#[test]
fn buf_to_string_empty() {
assert_eq!(buf_to_string(&[0u8; 10]), "");
}
// Without a NUL byte the whole buffer is used.
#[test]
fn buf_to_string_no_null() {
let buf = b"abc".to_vec();
assert_eq!(buf_to_string(&buf), "abc");
}
// Display text is the lower-case variant name.
#[test]
fn jit_severity_display() {
assert_eq!(JitSeverity::Fatal.to_string(), "fatal");
assert_eq!(JitSeverity::Error.to_string(), "error");
assert_eq!(JitSeverity::Warning.to_string(), "warning");
assert_eq!(JitSeverity::Info.to_string(), "info");
}
}