use super::Error;
use slang::Downcast;
use std::ffi::CString;
pub use slang::{CompileTarget, OptimizationLevel, Stage};
/// Errors produced while driving the Slang compiler.
///
/// Every payload-carrying variant holds a human-readable message describing
/// the failure of the corresponding compilation stage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SlangError {
    /// The Slang global session could not be initialized.
    GlobalInit,
    /// Creating a compilation session failed.
    SessionCreate(String),
    /// Loading/compiling a Slang module failed.
    LoadModule(String),
    /// The requested entry point does not exist in the module (payload: its name).
    EntryPointNotFound(String),
    /// Composing the module and entry point into one program failed.
    Composite(String),
    /// Linking the composed program failed.
    Link(String),
    /// Retrieving the compiled code for the entry point failed.
    EntryPointCode(String),
    /// Slang returned a blob that is not well-formed SPIR-V.
    MalformedSpirv(String),
}
impl std::fmt::Display for SlangError {
    /// Renders a fixed description of the failing stage, followed by the
    /// message carried by the variant (if any).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail): (&str, &str) = match self {
            Self::GlobalInit => ("failed to initialize Slang global session", ""),
            Self::SessionCreate(msg) => ("Slang session creation failed: ", msg.as_str()),
            Self::LoadModule(msg) => ("Slang module compilation failed:\n", msg.as_str()),
            Self::EntryPointNotFound(msg) => ("Slang entry point not found: ", msg.as_str()),
            Self::Composite(msg) => ("Slang component composition failed: ", msg.as_str()),
            Self::Link(msg) => ("Slang linking failed: ", msg.as_str()),
            Self::EntryPointCode(msg) => {
                ("Slang entry-point code retrieval failed: ", msg.as_str())
            }
            Self::MalformedSpirv(msg) => ("Slang produced malformed SPIR-V: ", msg.as_str()),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

/// Marker impl so `SlangError` can be used as a `dyn std::error::Error`;
/// it wraps no underlying source error.
impl std::error::Error for SlangError {}
impl From<SlangError> for Error {
fn from(e: SlangError) -> Self {
Error::SlangCompile(e.to_string())
}
}
/// Owns a Slang global session and one compilation session created from it.
///
/// Field order matters: Rust drops struct fields in declaration order, so
/// `session` is declared first. That way it is released before the global
/// session that created it.
pub struct SlangSession {
    session: slang::Session,
    _global: slang::GlobalSession,
    /// NUL-terminated copies of the search paths whose pointers were handed to
    /// Slang at session creation. They are retained for the lifetime of the
    /// session. NOTE(review): confirm whether Slang copies them, which would
    /// make retention unnecessary.
    _search_paths: Vec<CString>,
}
impl SlangSession {
    /// Creates a session targeting SPIR-V (`glsl_450` profile) with no extra
    /// search paths.
    ///
    /// # Errors
    /// Same as [`SlangSession::with_search_paths`].
    pub fn new() -> Result<Self, SlangError> {
        Self::with_search_paths(&[])
    }

    /// Creates a session targeting SPIR-V (`glsl_450` profile) that uses
    /// `paths` as Slang search paths.
    ///
    /// # Errors
    /// - [`SlangError::SessionCreate`] if any path contains an interior NUL byte
    ///   (it cannot be passed to Slang as a C string), or if Slang fails to
    ///   create the session.
    /// - [`SlangError::GlobalInit`] if the Slang global session cannot be created.
    pub fn with_search_paths(paths: &[&str]) -> Result<Self, SlangError> {
        // The paths are caller input, so an interior NUL is reported as an error
        // rather than panicking. They are validated before the (comparatively
        // expensive) global session is created.
        let search_paths = paths
            .iter()
            .map(|p| {
                CString::new(*p).map_err(|_| {
                    SlangError::SessionCreate(format!("search path contains NUL byte: {p:?}"))
                })
            })
            .collect::<Result<Vec<CString>, SlangError>>()?;
        let global = slang::GlobalSession::new().ok_or(SlangError::GlobalInit)?;
        // These raw pointers borrow the CStrings' heap buffers. Moving
        // `search_paths` into `Self` below does not relocate those buffers.
        let search_path_ptrs: Vec<*const i8> =
            search_paths.iter().map(|s| s.as_ptr()).collect();
        let target_desc = slang::TargetDesc::default()
            .format(slang::CompileTarget::Spirv)
            .profile(global.find_profile("glsl_450"));
        let targets = [target_desc];
        let session_desc = slang::SessionDesc::default()
            .targets(&targets)
            .search_paths(&search_path_ptrs);
        let session = global
            .create_session(&session_desc)
            .ok_or_else(|| SlangError::SessionCreate("create_session returned None".into()))?;
        Ok(Self {
            _global: global,
            session,
            _search_paths: search_paths,
        })
    }

    /// Loads the module `name` through this session's `load_module`. The
    /// returned module borrows the session.
    ///
    /// # Errors
    /// [`SlangError::LoadModule`] with Slang's diagnostic text if loading fails.
    pub fn load_file(&self, name: &str) -> Result<SlangModule<'_>, SlangError> {
        let module = self
            .session
            .load_module(name)
            .map_err(|e| SlangError::LoadModule(e.to_string()))?;
        Ok(SlangModule {
            module,
            session: &self.session,
        })
    }
}
/// A Slang module loaded by a [`SlangSession`].
///
/// The `'sess` lifetime borrows the owning session, so a module can never
/// outlive the session it was loaded from.
pub struct SlangModule<'sess> {
    /// Session used to compose and link this module's entry points.
    session: &'sess slang::Session,
    /// The loaded Slang module itself.
    module: slang::Module,
}
impl<'sess> SlangModule<'sess> {
    /// Compiles the entry point `name` of this module to SPIR-V words.
    ///
    /// Composes the module and the entry point into one program, links it,
    /// and retrieves the code for that entry point.
    ///
    /// # Errors
    /// - [`SlangError::EntryPointNotFound`] if the module has no entry point `name`.
    /// - [`SlangError::Composite`], [`SlangError::Link`] or
    ///   [`SlangError::EntryPointCode`] if the corresponding Slang step fails.
    /// - [`SlangError::MalformedSpirv`] if the returned blob is empty, is not a
    ///   whole number of 32-bit words, or lacks the SPIR-V magic number.
    pub fn compile_entry_point(&self, name: &str) -> Result<Vec<u32>, SlangError> {
        let entry_point = self
            .module
            .find_entry_point_by_name(name)
            .ok_or_else(|| SlangError::EntryPointNotFound(name.to_owned()))?;
        let program = self
            .session
            .create_composite_component_type(&[
                self.module.downcast().clone(),
                entry_point.downcast().clone(),
            ])
            .map_err(|e| SlangError::Composite(e.to_string()))?;
        let linked = program
            .link()
            .map_err(|e| SlangError::Link(e.to_string()))?;
        // Entry-point index 0: the composite above contains exactly one entry
        // point. Target index 0: `SlangSession` creates sessions with a single
        // (SPIR-V) target.
        let blob = linked
            .entry_point_code(0, 0)
            .map_err(|e| SlangError::EntryPointCode(e.to_string()))?;
        spirv_words_from_bytes(blob.as_slice())
    }
}

/// Decodes a SPIR-V byte blob into 32-bit words and sanity-checks the header.
///
/// Bytes are assembled little-endian. The SPIR-V spec allows a module in
/// either byte order, with the magic number identifying which. A first word
/// equal to the magic in either order is therefore accepted, and the words are
/// returned as decoded.
fn spirv_words_from_bytes(bytes: &[u8]) -> Result<Vec<u32>, SlangError> {
    /// First word of every SPIR-V module.
    const SPIRV_MAGIC: u32 = 0x0723_0203;

    if bytes.len() % 4 != 0 {
        return Err(SlangError::MalformedSpirv(format!(
            "blob length {} is not a multiple of 4",
            bytes.len()
        )));
    }
    let words: Vec<u32> = bytes
        .chunks_exact(4)
        .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect();
    match words.first().copied() {
        None => Err(SlangError::MalformedSpirv("blob is empty".to_owned())),
        Some(magic) if magic == SPIRV_MAGIC || magic.swap_bytes() == SPIRV_MAGIC => Ok(words),
        Some(other) => Err(SlangError::MalformedSpirv(format!(
            "first word {other:#010x} is not the SPIR-V magic number {SPIRV_MAGIC:#010x}"
        ))),
    }
}
/// Convenience one-shot: compiles `entry_point` from module `module_name` to SPIR-V.
///
/// A fresh [`SlangSession`] is created with `search_dir` as its only search
/// path. It is discarded once the words have been produced.
///
/// # Errors
/// Any [`SlangError`] from session creation, module loading, or compilation.
pub fn compile_slang_file(
    search_dir: &str,
    module_name: &str,
    entry_point: &str,
) -> Result<Vec<u32>, SlangError> {
    SlangSession::with_search_paths(&[search_dir])?
        .load_file(module_name)?
        .compile_entry_point(entry_point)
}