1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
//! Data-driven kernel module loading.
//!
//! Replaces the 812-line hand-unrolled `new()` constructor with a loop over
//! `KERNEL_MODULES`, reducing ~800 LOC to ~30.
use std::sync::Arc;
use std::time::Instant;
use cudarc::nvrtc::Ptx;
use xlog_core::{Result, XlogError};
use super::{CudaKernelProvider, KernelModuleSource, PtxLoadProfile};
use crate::kernel_manifest_data::KERNEL_MODULES;
use crate::CudaDevice;
impl CudaKernelProvider {
    /// Loads each entry of `KERNEL_MODULES` into `device`, in manifest order.
    ///
    /// For every entry the module source is resolved through
    /// `super::load_module_source` for the detected compute capability and
    /// then loaded either from a file on disk or from embedded portable PTX.
    ///
    /// When `profiling` is true the device is synchronized after each load and
    /// the elapsed time (source resolution + load + synchronize) is recorded
    /// per module; the filled-in `PtxLoadProfile` is returned as `Some`.
    /// When `profiling` is false no timing or synchronization is performed and
    /// `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// Propagates failures from compute-capability detection and module source
    /// resolution, and returns `XlogError::Kernel` when a module fails to load
    /// or when the post-load synchronize fails.
    pub(crate) fn load_all_kernel_modules(
        device: &Arc<CudaDevice>,
        profiling: bool,
    ) -> Result<Option<PtxLoadProfile>> {
        let cc = super::detect_compute_capability(device)?;
        let mut profile = PtxLoadProfile::default();

        for spec in KERNEL_MODULES {
            // The stopwatch only exists in profiling mode; everything gated on
            // it below is therefore skipped entirely otherwise.
            let started = profiling.then(Instant::now);

            // Each arm loads the module and reports whether it came from a
            // cubin file (as opposed to PTX, on disk or embedded).
            let loaded_cubin = match super::load_module_source(spec.cu_name, cc)? {
                KernelModuleSource::File { path, is_cubin, .. } => {
                    device
                        .inner()
                        .load_file(&path, spec.module_name, spec.kernels)
                        .map_err(|e| {
                            XlogError::Kernel(format!(
                                "Failed to load {} module from {}: {}",
                                spec.cu_name,
                                path.display(),
                                e
                            ))
                        })?;
                    is_cubin
                }
                KernelModuleSource::EmbeddedPortablePtx { ptx } => {
                    device
                        .inner()
                        .load_ptx(Ptx::from_src(ptx), spec.module_name, spec.kernels)
                        .map_err(|e| {
                            XlogError::Kernel(format!(
                                "Failed to load embedded {} portable PTX: {}",
                                spec.cu_name, e
                            ))
                        })?;
                    false
                }
            };

            if let Some(started) = started {
                // Synchronize before reading the clock so the recorded time
                // is taken only after the device has finished.
                device.inner().synchronize().map_err(|e| {
                    XlogError::Kernel(format!("sync after {} load: {}", spec.cu_name, e))
                })?;

                let secs = started.elapsed().as_secs_f64();
                profile.per_module_sec.push((spec.cu_name.to_string(), secs));
                profile.total_sec += secs;

                if loaded_cubin {
                    profile.cubin_loaded += 1;
                } else {
                    profile.ptx_fallback += 1;
                }
            }
        }

        Ok(profiling.then_some(profile))
    }
}