use super::GpuBackend;
use super::compute::{ComputeManager, TransformParams};
use super::memory::GpuMemoryPool;
use crate::core::{PlottingError, Result};
use crate::data::{Data1D, PooledVec, SharedMemoryPool};
use crate::render::backend::Renderer;
use crate::render::pooled::PooledRenderer;
use bytemuck::{Pod, Zeroable};
use std::sync::Arc;
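/// Hybrid renderer that sends coordinate transforms to the GPU for large
/// datasets and falls back to the pooled CPU renderer for small ones or
/// when no GPU is available. The crossover point is derived from the
/// device capabilities by `calculate_gpu_threshold`.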
pub struct GpuRenderer {
gpu_backend: Arc<GpuBackend>,
cpu_fallback: PooledRenderer,
compute_manager: ComputeManager,
gpu_memory_pool: GpuMemoryPool,
gpu_threshold: usize,
stats: GpuRendererStats,
}
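/// GPU-side vertex layout; `#[repr(C)]` with `Pod`/`Zeroable` lets the
/// struct be uploaded via `bytemuck` as raw bytes.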
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
pub struct GpuVertex {
position: [f32; 2],
color: u32,
size: f32,
}
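/// Running GPU/CPU workload counters; `avg_gpu_time` and `avg_cpu_time`
/// are rolling averages in microseconds.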
#[derive(Debug, Clone, Default)]
pub struct GpuRendererStats {
pub gpu_points_processed: u64,
pub cpu_points_processed: u64,
pub gpu_operations: u64,
pub cpu_operations: u64,
pub avg_gpu_time: f32,
pub avg_cpu_time: f32,
pub gpu_memory_used: u64,
pub cpu_memory_used: u64,
}
impl GpuRenderer {
pub async fn new() -> Result<Self> {
let config = super::GpuConfig::default();
Self::with_config(config).await
}
pub async fn with_config(config: super::GpuConfig) -> Result<Self> {
let gpu_backend = Arc::new(GpuBackend::with_config(config).await.map_err(|e| {
PlottingError::GpuInitError {
backend: "wgpu".to_string(),
error: format!("{}", e),
}
})?);
let cpu_fallback = PooledRenderer::new();
let compute_manager =
gpu_backend
.create_compute_manager()
.map_err(|e| PlottingError::GpuInitError {
backend: "compute".to_string(),
error: format!("{}", e),
})?;
let gpu_memory_pool = GpuMemoryPool::new(
gpu_backend.device().device().clone(),
gpu_backend.device().queue().clone(),
gpu_backend.capabilities(),
)?;
let gpu_threshold = Self::calculate_gpu_threshold(gpu_backend.capabilities());
Ok(Self {
gpu_backend,
cpu_fallback,
compute_manager,
gpu_memory_pool,
gpu_threshold,
stats: GpuRendererStats::default(),
})
}
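    /// Picks the point count above which the GPU path is preferred.
    /// Compute-capable devices get a low base threshold (5k points), while
    /// devices without compute shaders need far more points (50k) to
    /// amortize upload costs. The base is then scaled by reported VRAM:
    /// doubled at <= 2 GiB, kept as-is at 3-8 GiB, halved above that.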
fn calculate_gpu_threshold(capabilities: &super::GpuCapabilities) -> usize {
        let base_threshold = if capabilities.supports_compute {
            5_000
        } else {
            50_000
        };
        if let Some(memory) = capabilities.memory_size {
            let memory_gb = memory / (1024 * 1024 * 1024);
            match memory_gb {
                0..=2 => base_threshold * 2,
                3..=8 => base_threshold,
                _ => base_threshold / 2,
            }
        } else {
            base_threshold
        }
}
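    /// Transforms data coordinates into viewport pixels, picking the GPU
    /// path when the dataset reaches the threshold and the backend is
    /// available; a failed GPU attempt logs a warning and retries on the
    /// CPU. The viewport tuple is `(left, top, right, bottom)`.
    ///
    /// A usage sketch (marked `ignore`: it assumes a GPU-capable host and
    /// an async executor such as tokio):
    ///
    /// ```ignore
    /// let mut renderer = GpuRenderer::new().await?;
    /// let (xs, ys) = renderer.transform_coordinates_optimal(
    ///     &x_data,
    ///     &y_data,
    ///     (0.0, 1.0),               // x data range
    ///     (0.0, 1.0),               // y data range
    ///     (0.0, 0.0, 800.0, 600.0), // viewport: left, top, right, bottom
    /// )?;
    /// ```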
pub fn transform_coordinates_optimal<T>(
&mut self,
x_data: &T,
y_data: &T,
x_range: (f64, f64),
y_range: (f64, f64),
        viewport: (f32, f32, f32, f32),
    ) -> Result<(PooledVec<f32>, PooledVec<f32>)>
where
T: Data1D<f64>,
{
let point_count = x_data.len().min(y_data.len());
if point_count >= self.gpu_threshold && self.gpu_backend.is_available() {
            self.transform_coordinates_gpu(x_data, y_data, x_range, y_range, viewport)
                .or_else(|_| {
                    log::warn!("GPU coordinate transformation failed, falling back to CPU");
                    // transform_coordinates_cpu updates the CPU counters
                    // itself, so they are not incremented again here.
                    self.transform_coordinates_cpu(x_data, y_data, x_range, y_range, viewport)
                })
} else {
self.transform_coordinates_cpu(x_data, y_data, x_range, y_range, viewport)
}
}
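    /// GPU path: converts the inputs to `[f32; 2]` points, uploads them to
    /// a storage buffer, dispatches the transform compute pass, reads the
    /// output buffer back, and unpacks it into pooled x/y vectors while
    /// updating the rolling GPU timing average (microseconds).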
fn transform_coordinates_gpu<T>(
&mut self,
x_data: &T,
y_data: &T,
x_range: (f64, f64),
y_range: (f64, f64),
viewport: (f32, f32, f32, f32),
) -> Result<(PooledVec<f32>, PooledVec<f32>)>
where
T: Data1D<f64>,
{
let start_time = std::time::Instant::now();
let point_count = x_data.len().min(y_data.len());
let mut input_points = Vec::with_capacity(point_count);
for i in 0..point_count {
let x = x_data.get(i).copied().unwrap_or(0.0) as f32;
let y = y_data.get(i).copied().unwrap_or(0.0) as f32;
input_points.push([x, y]);
}
let input_buffer = self.gpu_memory_pool.create_buffer(
&input_points,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
)?;
let output_buffer = self.gpu_memory_pool.create_buffer_empty::<[f32; 2]>(
point_count,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
)?;
let (left, top, right, bottom) = viewport;
let params = TransformParams {
scale_x: (right - left) / (x_range.1 - x_range.0) as f32,
scale_y: (bottom - top) / (y_range.1 - y_range.0) as f32,
offset_x: left - (x_range.0 as f32 * (right - left) / (x_range.1 - x_range.0) as f32),
offset_y: top - (y_range.0 as f32 * (bottom - top) / (y_range.1 - y_range.0) as f32),
width: (right - left) as u32,
height: (bottom - top) as u32,
_padding: [0, 0],
};
self.compute_manager.execute_transform(
&input_buffer,
&output_buffer,
            &params,
point_count as u32,
)?;
let output_data = self
.gpu_memory_pool
.read_buffer::<[f32; 2]>(&output_buffer)?;
let x_pool = SharedMemoryPool::new(point_count);
let y_pool = SharedMemoryPool::new(point_count);
let mut x_result = PooledVec::with_capacity(point_count, x_pool);
let mut y_result = PooledVec::with_capacity(point_count, y_pool);
for point in output_data {
x_result.push(point[0]);
y_result.push(point[1]);
}
let elapsed = start_time.elapsed().as_micros() as f32;
self.stats.gpu_operations += 1;
self.stats.gpu_points_processed += point_count as u64;
self.stats.avg_gpu_time =
(self.stats.avg_gpu_time * (self.stats.gpu_operations - 1) as f32 + elapsed)
/ self.stats.gpu_operations as f32;
Ok((x_result, y_result))
}
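    /// CPU fallback: delegates to the pooled renderer's per-axis transforms
    /// and updates the rolling CPU timing average (microseconds).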
fn transform_coordinates_cpu<T>(
&mut self,
x_data: &T,
y_data: &T,
x_range: (f64, f64),
y_range: (f64, f64),
viewport: (f32, f32, f32, f32),
) -> Result<(PooledVec<f32>, PooledVec<f32>)>
where
T: Data1D<f64>,
{
let start_time = std::time::Instant::now();
let point_count = x_data.len().min(y_data.len());
        let (left, top, right, bottom) = viewport;
let x_result = self
.cpu_fallback
.transform_x_coordinates_pooled(x_data, x_range.0, x_range.1, left, right)?;
        // Matches the GPU path: y_range.0 maps to the viewport top,
        // y_range.1 to the bottom.
        let y_result = self.cpu_fallback.transform_y_coordinates_pooled(
            y_data, y_range.0, y_range.1, top, bottom,
        )?;
let elapsed = start_time.elapsed().as_micros() as f32;
self.stats.cpu_operations += 1;
self.stats.cpu_points_processed += point_count as u64;
self.stats.avg_cpu_time =
(self.stats.avg_cpu_time * (self.stats.cpu_operations - 1) as f32 + elapsed)
/ self.stats.cpu_operations as f32;
Ok((x_result, y_result))
}
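    /// Returns true when `point_count` meets the GPU threshold and the
    /// backend is available; this is the same test used internally by
    /// `transform_coordinates_optimal`.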
pub fn should_use_gpu(&self, point_count: usize) -> bool {
point_count >= self.gpu_threshold && self.gpu_backend.is_available()
}
pub fn gpu_threshold(&self) -> usize {
self.gpu_threshold
}
pub fn set_gpu_threshold(&mut self, threshold: usize) {
self.gpu_threshold = threshold;
}
pub fn get_stats(&self) -> &GpuRendererStats {
&self.stats
}
pub fn reset_stats(&mut self) {
self.stats = GpuRendererStats::default();
}
pub fn gpu_capabilities(&self) -> &super::GpuCapabilities {
self.gpu_backend.capabilities()
}
pub fn is_gpu_available(&self) -> bool {
self.gpu_backend.is_available()
}
}
impl Renderer for GpuRenderer {
type Error = PlottingError;
    fn render(&self) -> std::result::Result<(), Self::Error> {
        // No-op: GPU work is submitted from the transform and compute
        // paths; this impl exists to satisfy the `Renderer` trait.
        Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_gpu_renderer_creation() {
let result = GpuRenderer::new().await;
if let Ok(renderer) = result {
assert!(renderer.gpu_threshold() > 0);
println!(
"GPU renderer created successfully with threshold: {}",
renderer.gpu_threshold()
);
} else {
println!("GPU not available, which is expected in CI environments");
}
}
#[tokio::test]
async fn test_coordinate_transformation() {
if let Ok(mut renderer) = GpuRenderer::new().await {
let x_data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y_data = vec![10.0, 20.0, 30.0, 40.0, 50.0];
let result = renderer.transform_coordinates_optimal(
&x_data,
&y_data,
                (1.0, 5.0),
                (10.0, 50.0),
                (0.0, 0.0, 800.0, 600.0),
            );
assert!(result.is_ok());
let (x_transformed, y_transformed) = result.unwrap();
assert_eq!(x_transformed.len(), 5);
assert_eq!(y_transformed.len(), 5);
}
}
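    // A minimal sketch of forcing the CPU path: raising the threshold above
    // the data size must route through the pooled CPU renderer and bump its
    // stats. Skipped (like the tests above) when no GPU backend initializes.
    #[tokio::test]
    async fn test_cpu_fallback_below_threshold() {
        if let Ok(mut renderer) = GpuRenderer::new().await {
            renderer.set_gpu_threshold(1_000_000);
            let x_data = vec![1.0, 2.0, 3.0];
            let y_data = vec![4.0, 5.0, 6.0];
            assert!(!renderer.should_use_gpu(x_data.len()));
            let result = renderer.transform_coordinates_optimal(
                &x_data,
                &y_data,
                (1.0, 3.0),
                (4.0, 6.0),
                (0.0, 0.0, 800.0, 600.0),
            );
            assert!(result.is_ok());
            assert!(renderer.get_stats().cpu_operations >= 1);
        }
    }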
}