// ringkernel_core/hybrid/traits.rs
//! Traits for hybrid CPU-GPU workloads.

use super::error::HybridResult;
5/// Trait for workloads that can be executed on CPU or GPU.
6///
7/// Implementors provide both CPU and GPU execution paths, allowing the
8/// `HybridDispatcher` to choose the optimal backend based on workload size
9/// and runtime measurements.
10///
11/// # Example
12///
13/// ```ignore
14/// use ringkernel_core::hybrid::{HybridWorkload, HybridResult};
15///
16/// struct VectorAdd {
17///     a: Vec<f32>,
18///     b: Vec<f32>,
19/// }
20///
21/// impl HybridWorkload for VectorAdd {
22///     type Result = Vec<f32>;
23///
24///     fn workload_size(&self) -> usize {
25///         self.a.len()
26///     }
27///
28///     fn execute_cpu(&self) -> Self::Result {
29///         self.a.iter().zip(&self.b).map(|(a, b)| a + b).collect()
30///     }
31///
32///     fn execute_gpu(&self) -> HybridResult<Self::Result> {
33///         // GPU implementation
34///         todo!("GPU kernel execution")
35///     }
36/// }
37/// ```
38pub trait HybridWorkload: Send + Sync {
39    /// The result type produced by the workload.
40    type Result;
41
42    /// Returns the size of the workload (number of elements to process).
43    ///
44    /// This is used by the dispatcher to decide between CPU and GPU execution.
45    fn workload_size(&self) -> usize;
46
47    /// Executes the workload on CPU.
48    ///
49    /// This should typically use Rayon or similar for parallel CPU execution.
50    fn execute_cpu(&self) -> Self::Result;
51
52    /// Executes the workload on GPU.
53    ///
54    /// Returns an error if GPU execution fails.
55    fn execute_gpu(&self) -> HybridResult<Self::Result>;
56
57    /// Returns the name of the workload (for logging/metrics).
58    fn name(&self) -> &str {
59        std::any::type_name::<Self>()
60    }
61
62    /// Returns whether GPU execution is supported.
63    ///
64    /// Override to return `false` if this workload doesn't have a GPU implementation.
65    fn supports_gpu(&self) -> bool {
66        true
67    }
68
69    /// Returns an estimate of memory bytes required for this workload.
70    ///
71    /// Used by the resource guard to prevent OOM situations.
72    fn memory_estimate(&self) -> usize {
73        0
74    }
75}
76
/// A wrapper to execute any `FnOnce` as a hybrid workload.
///
/// Note: per the Rust API guidelines, trait bounds belong on the `impl`
/// blocks that need them rather than on the struct definition; the bounds
/// previously duplicated here were removed (strictly backward-compatible).
#[allow(dead_code)]
pub struct FnWorkload<F, R> {
    // CPU closure; wrapped in `Option` so it can be taken out once for
    // `FnOnce` invocation.
    cpu_fn: Option<F>,
    // Workload size reported to the dispatcher.
    size: usize,
    // Ties the result type `R` to the struct without storing a value of it.
    _marker: std::marker::PhantomData<R>,
}
87
88#[allow(dead_code)]
89impl<F, R> FnWorkload<F, R>
90where
91    F: FnOnce() -> R + Send + Sync,
92{
93    /// Creates a CPU-only workload from a function.
94    pub fn cpu_only(f: F, size: usize) -> Self {
95        Self {
96            cpu_fn: Some(f),
97            size,
98            _marker: std::marker::PhantomData,
99        }
100    }
101}
102
/// A boxed hybrid workload for dynamic dispatch.
///
/// Lets heterogeneous workloads that share a result type `R` be stored in
/// one collection, trading monomorphized calls for vtable dispatch.
#[allow(dead_code)]
pub type BoxedWorkload<R> = Box<dyn HybridWorkload<Result = R>>;
106
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal workload that sums its input; exercises the trait surface.
    struct TestWorkload {
        data: Vec<f32>,
    }

    impl HybridWorkload for TestWorkload {
        type Result = f32;

        fn workload_size(&self) -> usize {
            self.data.len()
        }

        fn execute_cpu(&self) -> Self::Result {
            // Left-to-right fold, identical to `iter().sum()` for f32.
            let mut total = 0.0;
            for value in &self.data {
                total += value;
            }
            total
        }

        fn execute_gpu(&self) -> HybridResult<Self::Result> {
            // Simulate GPU execution by producing the same reduction.
            Ok(self.execute_cpu())
        }

        fn name(&self) -> &str {
            "TestWorkload"
        }
    }

    // Shared fixture: four elements summing to 10.0.
    fn sample() -> TestWorkload {
        TestWorkload {
            data: vec![1.0, 2.0, 3.0, 4.0],
        }
    }

    #[test]
    fn test_workload_cpu() {
        let workload = sample();
        assert_eq!(workload.workload_size(), 4);
        let cpu_sum = workload.execute_cpu();
        assert!((cpu_sum - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_workload_gpu() {
        let gpu_sum = sample().execute_gpu().unwrap();
        assert!((gpu_sum - 10.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_workload_name() {
        let empty = TestWorkload { data: vec![] };
        assert_eq!(empty.name(), "TestWorkload");
    }
}