use std::sync::Arc;
use snafu::ResultExt;
use svod_device::device::Device;
use svod_device::registry::DeviceRegistry;
use svod_ir::DeviceSpec;
use svod_runtime::CpuBackend;
use svod_schedule::OptimizerConfig;
use crate::error::{DeviceFactorySnafu, DeviceSnafu};
/// Strategy object that turns a [`DeviceSpec`] into a concrete, shared [`Device`].
///
/// The `Send + Sync` supertraits let an implementor be stored behind an
/// `Arc<dyn DeviceResolver>` and shared across threads.
pub(crate) trait DeviceResolver: Send + Sync {
    /// Resolves `spec` to a device, using `registry` for whatever lookup or
    /// construction the implementor performs.
    ///
    /// # Errors
    ///
    /// Returns a `crate::Error` when the implementor cannot produce a device
    /// for `spec`.
    fn resolve(
        &self,
        spec: &DeviceSpec,
        registry: &DeviceRegistry,
    ) -> crate::Result<Arc<Device>>;
}
/// Resolver that hands every spec to `svod_runtime::DEVICE_FACTORIES` unchanged.
///
/// NOTE(review): the name suggests the factories consult the process environment
/// to pick a backend; that is not visible from this file — confirm in `svod_runtime`.
struct EnvResolver;

impl DeviceResolver for EnvResolver {
    /// Looks `spec` up through `svod_runtime::DEVICE_FACTORIES`.
    ///
    /// # Errors
    ///
    /// A failed factory lookup is wrapped with the `DeviceFactorySnafu` context.
    fn resolve(
        &self,
        spec: &DeviceSpec,
        registry: &DeviceRegistry,
    ) -> crate::Result<Arc<Device>> {
        let lookup = svod_runtime::DEVICE_FACTORIES.device(spec, registry);
        lookup.context(DeviceFactorySnafu)
    }
}
/// Resolver that pins `DeviceSpec::Cpu` to one explicit [`CpuBackend`].
///
/// Every other spec is forwarded to `svod_runtime::DEVICE_FACTORIES`.
struct CpuBackendResolver(CpuBackend);

impl DeviceResolver for CpuBackendResolver {
    /// Builds a CPU device with the stored backend for `DeviceSpec::Cpu`, and
    /// defers to `svod_runtime::DEVICE_FACTORIES` for anything else.
    ///
    /// # Errors
    ///
    /// * CPU device creation failures carry the `DeviceSnafu` context.
    /// * Factory lookup failures carry the `DeviceFactorySnafu` context.
    fn resolve(
        &self,
        spec: &DeviceSpec,
        registry: &DeviceRegistry,
    ) -> crate::Result<Arc<Device>> {
        // Only the CPU spec is overridden; handle it first and return early.
        if let DeviceSpec::Cpu = spec {
            let cpu_device = svod_runtime::create_cpu_device_with_backend(registry, self.0)
                .context(DeviceSnafu)?;
            return Ok(Arc::new(cpu_device));
        }

        svod_runtime::DEVICE_FACTORIES
            .device(spec, registry)
            .context(DeviceFactorySnafu)
    }
}
/// Configuration bundle: optimizer settings plus the strategy used to resolve devices.
///
/// Devices are resolved through a crate-private [`DeviceResolver`], selected by
/// the constructor used ([`PrepareConfig::from_env`],
/// [`PrepareConfig::for_cpu_backend`], `Default`, or `From<OptimizerConfig>`).
// The docs above intentionally link to the crate-private `DeviceResolver`.
#[allow(rustdoc::private_intra_doc_links)]
pub struct PrepareConfig {
    /// Optimizer settings carried alongside the device resolver.
    pub optimizer: OptimizerConfig,
    /// Device-resolution strategy; shared so the config can hold any implementor.
    pub(crate) resolver: Arc<dyn DeviceResolver>,
    /// Every constructor in this file initialises this to `false`.
    ///
    /// NOTE(review): presumably turns off schedule caching when `true`; the
    /// consumer of this flag is not visible in this file — confirm at the use site.
    pub disable_schedule_cache: bool,
}
/// Manual `Debug`: the `resolver` field is left out and the output is marked
/// non-exhaustive to signal the omission.
impl std::fmt::Debug for PrepareConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pull out only the fields that get printed; `resolver` is skipped.
        let Self { optimizer, disable_schedule_cache, .. } = self;

        let mut builder = f.debug_struct("PrepareConfig");
        builder.field("optimizer", optimizer);
        builder.field("disable_schedule_cache", disable_schedule_cache);
        builder.finish_non_exhaustive()
    }
}
impl Default for PrepareConfig {
fn default() -> Self {
Self { optimizer: OptimizerConfig::default(), resolver: Arc::new(EnvResolver), disable_schedule_cache: false }
}
}
impl PrepareConfig {
pub fn from_env() -> Self {
Self { optimizer: OptimizerConfig::from_env(), resolver: Arc::new(EnvResolver), disable_schedule_cache: false }
}
pub fn for_cpu_backend(backend: CpuBackend) -> Self {
Self {
optimizer: OptimizerConfig::from_env(),
resolver: Arc::new(CpuBackendResolver(backend)),
disable_schedule_cache: false,
}
}
pub(crate) fn resolve_device(&self, spec: &DeviceSpec, registry: &DeviceRegistry) -> crate::Result<Arc<Device>> {
self.resolver.resolve(spec, registry)
}
}
impl From<OptimizerConfig> for PrepareConfig {
fn from(optimizer: OptimizerConfig) -> Self {
Self { optimizer, resolver: Arc::new(EnvResolver), disable_schedule_cache: false }
}
}
/// Declares tests that are run once per CPU backend (`Clang` and `Llvm`).
///
/// Each declared function receives, as its first parameter, the name to bind a
/// `PrepareConfig::for_cpu_backend(..)` value to. Three shapes are accepted and
/// may be mixed freely inside one invocation:
///
/// * `fn name(config) { .. }` — expands to `mod name` with `#[test] fn clang()`
///   and `#[test] fn llvm()`.
/// * `fn name(config, x in strategy, ..) { .. }` — same layout, but the body runs
///   inside a `proptest` `TestRunner`. The body is followed by `Ok(())`, and the
///   runner's result is unwrapped. An optional leading
///   `#[proptest_config(<args>)]` is passed to `TestRunner::new(<args>)`;
///   otherwise `TestRunner::default()` is used.
/// * `fn name(config, x: Type, ..) { .. }` — expands to `mod name { mod clang; mod llvm }`,
///   each holding `fn name(x: Type, ..)`. The macro adds no `#[test]` here; the
///   forwarded attributes (e.g. `#[test_case(..)]`, which is imported into scope)
///   are expected to register the cases.
///
/// Every generated function first calls
/// `::svod_schedule::testing::setup_test_tracing()`. Attributes written on the
/// declaration are forwarded to both generated functions.
///
/// The expansion refers to `$crate::PrepareConfig`, `$crate::CpuBackend`,
/// `::svod_schedule`, `::proptest` and `::test_case`, so those paths must
/// resolve wherever the macro is used.
///
/// ```ignore
/// codegen_tests! {
///     fn adds(config) { /* uses `config` */ }
///
///     #[proptest_config(Default::default())]
///     fn adds_any(config, a in 0i32..10, b in 0i32..10) { /* .. */ }
///
///     #[test_case(1; "one")]
///     fn sized(config, n: usize) { /* .. */ }
/// }
/// ```
#[macro_export]
macro_rules! codegen_tests {
    // Terminator: no declarations left.
    () => {};

    // Plain test: `fn name(config) { .. }`.
    (
        $(#[$attr:meta])*
        fn $name:ident($config:ident) $body:block
        $($tail:tt)*
    ) => {
        mod $name {
            #[allow(unused_imports)]
            use super::*;

            #[test]
            $(#[$attr])*
            fn clang() {
                ::svod_schedule::testing::setup_test_tracing();
                let $config = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Clang);
                $body
            }

            #[test]
            $(#[$attr])*
            fn llvm() {
                ::svod_schedule::testing::setup_test_tracing();
                let $config = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Llvm);
                $body
            }
        }

        $crate::codegen_tests!($($tail)*);
    };

    // Property test with an explicit runner configuration.
    (
        #[proptest_config($($runner_cfg:tt)*)]
        $(#[$attr:meta])*
        fn $name:ident($config:ident, $($param:ident in $strategy:expr),+ $(,)?) $body:block
        $($tail:tt)*
    ) => {
        $crate::codegen_tests!(
            @proptest $name, $config, [$($param in $strategy),+], $body,
            ::proptest::test_runner::TestRunner::new($($runner_cfg)*), [$(#[$attr])*]
        );
        $crate::codegen_tests!($($tail)*);
    };

    // Property test using the default runner.
    (
        $(#[$attr:meta])*
        fn $name:ident($config:ident, $($param:ident in $strategy:expr),+ $(,)?) $body:block
        $($tail:tt)*
    ) => {
        $crate::codegen_tests!(
            @proptest $name, $config, [$($param in $strategy),+], $body,
            ::proptest::test_runner::TestRunner::default(), [$(#[$attr])*]
        );
        $crate::codegen_tests!($($tail)*);
    };

    // Internal: shared expansion for both property-test forms. The runner
    // expression is instantiated separately inside each generated function.
    (
        @proptest $name:ident, $config:ident,
        [$($param:ident in $strategy:expr),+],
        $body:block, $make_runner:expr, [$(#[$attr:meta])*]
    ) => {
        mod $name {
            #[allow(unused_imports)]
            use super::*;

            #[test]
            // A single parameter yields `(strategy)` / `(param)` rather than tuples.
            #[allow(unused_parens)]
            $(#[$attr])*
            fn clang() {
                ::svod_schedule::testing::setup_test_tracing();
                let mut test_runner = $make_runner;
                test_runner.run(&($($strategy),+), |($($param),+)| {
                    let $config = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Clang);
                    $body
                    Ok(())
                }).unwrap();
            }

            #[test]
            // A single parameter yields `(strategy)` / `(param)` rather than tuples.
            #[allow(unused_parens)]
            $(#[$attr])*
            fn llvm() {
                ::svod_schedule::testing::setup_test_tracing();
                let mut test_runner = $make_runner;
                test_runner.run(&($($strategy),+), |($($param),+)| {
                    let $config = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Llvm);
                    $body
                    Ok(())
                }).unwrap();
            }
        }
    };

    // Parameterised test with typed arguments; the forwarded attributes are
    // responsible for turning the function into test cases.
    (
        $(#[$attr:meta])*
        fn $name:ident($config:ident, $($param:ident: $ty:ty),+ $(,)?) $body:block
        $($tail:tt)*
    ) => {
        mod $name {
            mod clang {
                #[allow(unused_imports)]
                use super::super::*;
                use ::test_case::test_case;

                $(#[$attr])*
                fn $name($($param: $ty),+) {
                    ::svod_schedule::testing::setup_test_tracing();
                    let $config = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Clang);
                    $body
                }
            }

            mod llvm {
                #[allow(unused_imports)]
                use super::super::*;
                use ::test_case::test_case;

                $(#[$attr])*
                fn $name($($param: $ty),+) {
                    ::svod_schedule::testing::setup_test_tracing();
                    let $config = $crate::PrepareConfig::for_cpu_backend($crate::CpuBackend::Llvm);
                    $body
                }
            }
        }

        $crate::codegen_tests!($($tail)*);
    };
}