use super::error::HybridResult;
/// A unit of work that can execute on either the CPU or the GPU.
///
/// Implementors supply both execution paths; callers can consult
/// [`HybridWorkload::workload_size`], [`HybridWorkload::supports_gpu`], and
/// [`HybridWorkload::memory_estimate`] when deciding which backend to use.
/// The `Send + Sync` supertraits allow workloads to be shared across threads.
pub trait HybridWorkload: Send + Sync {
/// Output type produced by both the CPU and GPU paths.
type Result;
/// Size of this workload, e.g. the number of elements to process.
// NOTE(review): the exact unit is not established by this file — confirm at call sites.
fn workload_size(&self) -> usize;
/// Runs the workload on the CPU, returning the result directly (infallible).
fn execute_cpu(&self) -> Self::Result;
/// Runs the workload on the GPU; returns an error through [`HybridResult`] on failure.
fn execute_gpu(&self) -> HybridResult<Self::Result>;
/// Human-readable identifier; defaults to the concrete type's full type name.
fn name(&self) -> &str {
std::any::type_name::<Self>()
}
/// Whether this workload has a usable GPU path; defaults to `true`.
fn supports_gpu(&self) -> bool {
true
}
/// Estimated memory required by the workload; the default `0` means "no estimate".
// NOTE(review): presumably bytes — unit is not established here, confirm.
fn memory_estimate(&self) -> usize {
0
}
}
/// A [`HybridWorkload`]-style wrapper around a plain closure.
///
/// `cpu_fn` holds the CPU implementation (an `Option` so the `FnOnce` closure
/// can be taken out exactly once when executed); `size` reports the workload's
/// extent; `_marker` ties the otherwise-unused result type `R` to the struct.
///
/// Trait bounds are intentionally NOT declared on the struct itself — per Rust
/// API guidelines, bounds belong on the impl blocks that need them. This is
/// strictly more permissive than bounding the definition, so existing callers
/// are unaffected.
#[allow(dead_code)]
pub struct FnWorkload<F, R> {
    /// CPU implementation; `None` after it has been consumed.
    cpu_fn: Option<F>,
    /// Reported workload size.
    size: usize,
    /// Marks `R` as logically owned by this type without storing one.
    _marker: std::marker::PhantomData<R>,
}
#[allow(dead_code)]
impl<F, R> FnWorkload<F, R>
where
F: FnOnce() -> R + Send + Sync,
{
pub fn cpu_only(f: F, size: usize) -> Self {
Self {
cpu_fn: Some(f),
size,
_marker: std::marker::PhantomData,
}
}
}
/// Owned, heap-allocated, type-erased workload producing an `R`.
///
/// Because `HybridWorkload` has `Send + Sync` supertraits, the trait object
/// (and therefore the box) can be moved and shared across threads without an
/// explicit `+ Send + Sync` in the alias.
#[allow(dead_code)]
pub type BoxedWorkload<R> = Box<dyn HybridWorkload<Result = R>>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal workload used to exercise both the required and the
    /// overridden methods of [`HybridWorkload`].
    struct TestWorkload {
        data: Vec<f32>,
    }

    impl HybridWorkload for TestWorkload {
        type Result = f32;

        fn workload_size(&self) -> usize {
            self.data.len()
        }

        fn execute_cpu(&self) -> Self::Result {
            self.data.iter().sum()
        }

        fn execute_gpu(&self) -> HybridResult<Self::Result> {
            // Same summation as the CPU path, wrapped in Ok.
            Ok(self.execute_cpu())
        }

        fn name(&self) -> &str {
            "TestWorkload"
        }
    }

    /// Shared fixture: four elements summing to 10.0.
    fn sample() -> TestWorkload {
        TestWorkload {
            data: vec![1.0, 2.0, 3.0, 4.0],
        }
    }

    #[test]
    fn test_workload_cpu() {
        let w = sample();
        assert_eq!(w.workload_size(), 4);
        assert!((w.execute_cpu() - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_workload_gpu() {
        let total = sample().execute_gpu().unwrap();
        assert!((total - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_workload_name() {
        let empty = TestWorkload { data: Vec::new() };
        assert_eq!(empty.name(), "TestWorkload");
    }
}