use encase::{ShaderType, UniformBuffer};
use glam::{Mat4, Vec2, Vec4};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::collections::HashMap;
use crate::picking::LayerPickingResult;
use crate::render_traits::{DrawToRasterGpu, DrawToRasterCpu, DrawToSvg, MarginParams, PickableLayer, PreparedLayer, ViewParams};
use crate::viewport::{DataCoord, ScreenCoord};
use crate::render_types::{CpuContext, CpuRenderPass, PrepareResult, RenderResult};
use crate::render_types::GpuContext;
use crate::wgpu;
use crate::two::shapes::TwoElement;
use crate::two::svg::{update_svg, SvgContext};
use crate::layers::point_layer::PointShapeMode;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Point3dLayerParams {
pub layer_id: String,
pub bounds: Option<MarginParams>,
pub point_radius: f32,
pub point_shape_mode: PointShapeMode,
pub position_x: Arc<Vec<f32>>,
pub position_y: Arc<Vec<f32>>,
pub position_z: Arc<Vec<f32>>,
pub labels_vec: Arc<Vec<i32>>,
}
pub struct Point3dLayer {
view_params: ViewParams,
layer_params: Point3dLayerParams,
}
impl Point3dLayer {
pub fn new(
view_params: ViewParams,
layer_params: Point3dLayerParams,
) -> Self {
Self {
view_params,
layer_params,
}
}
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl PreparedLayer for Point3dLayer {
async fn prepare(&mut self, _gpu_context: Option<&GpuContext<'_>>) -> PrepareResult {
return PrepareResult {
bailed_early: false,
};
}
}
#[derive(ShaderType, Debug)]
struct Point3dLayerUniforms {
layer_size: Vec2, camera_view: Mat4, point_radius: f32, point_shape_mode: u32, color: Vec4, }
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl DrawToRasterGpu for Point3dLayer {
async fn draw(&self, gpu_context: &GpuContext<'_>, pass: &mut wgpu::RenderPass) {
let GpuContext { device, queue } = gpu_context;
let Self { layer_params, view_params } = self;
let x_bytes = bytemuck::cast_slice(&layer_params.position_x);
let y_bytes = bytemuck::cast_slice(&layer_params.position_y);
let z_bytes = bytemuck::cast_slice(&layer_params.position_z);
let n = layer_params.labels_vec.len();
let labels_bytes: &[u8] = bytemuck::cast_slice(&layer_params.labels_vec);
let x_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("X Coordinates Storage Buffer"),
size: x_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&x_buffer, 0, x_bytes);
let y_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Y Coordinates Storage Buffer"),
size: y_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&y_buffer, 0, y_bytes);
let z_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Z Coordinates Storage Buffer"),
size: z_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&z_buffer, 0, z_bytes);
let labels_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Class labels Storage Buffer"),
size: labels_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&labels_buffer, 0, labels_bytes);
let camera_view = view_params.camera_view.unwrap_or([
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
]);
let bounds = if layer_params.bounds.is_none() {
&view_params.margins
} else {
&layer_params.bounds
};
let margin_top = if let Some(margin_params) = &bounds {
margin_params.margin_top.unwrap_or(0.0)
} else { 0.0 } as f64;
let margin_right = if let Some(margin_params) = &bounds {
margin_params.margin_right.unwrap_or(0.0)
} else { 0.0 } as f64;
let margin_bottom = if let Some(margin_params) = &bounds {
margin_params.margin_bottom.unwrap_or(0.0)
} else { 0.0 } as f64;
let margin_left = if let Some(margin_params) = &bounds {
margin_params.margin_left.unwrap_or(0.0)
} else { 0.0 } as f64;
let viewport_w = view_params.width as f32;
let viewport_h = view_params.height as f32;
let layer_w = viewport_w - (margin_left + margin_right) as f32;
let layer_h = viewport_h - (margin_top + margin_bottom) as f32;
let uniform_struct = Point3dLayerUniforms {
layer_size: Vec2::new(layer_w, layer_h),
camera_view: Mat4::from_cols_array(&camera_view),
point_radius: layer_params.point_radius,
point_shape_mode: match layer_params.point_shape_mode {
PointShapeMode::Square => 0,
PointShapeMode::Circle => 1,
},
color: Vec4::from_array([1.0, 0.0, 0.0, 1.0]),
};
let mut buffer = UniformBuffer::new(Vec::<u8>::new());
buffer.write(&uniform_struct).unwrap();
let uniform_bytes = buffer.into_inner();
let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Uniform Buffer"),
size: uniform_bytes.len() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&uniform_buffer, 0, &uniform_bytes);
let bind_group_layout = device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("Point3dLayer BGL"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::VERTEX,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::VERTEX,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::VERTEX,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 4,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let bind_group = device
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Point3dLayer BG"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: uniform_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: x_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: y_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: z_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 4,
resource: labels_buffer.as_entire_binding(),
},
],
});
let shader = device
.create_shader_module(wgpu::include_wgsl!("shaders/point_3d_layer.wgsl"));
let render_pipeline_layout = device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Render Pipeline Layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let render_pipeline = device
.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("Render Pipeline"),
layout: Some(&render_pipeline_layout),
vertex: wgpu::VertexState {
module: &shader,
entry_point: Some("vs_main"),
compilation_options: Default::default(),
buffers: &[],
},
fragment: Some(wgpu::FragmentState {
module: &shader,
entry_point: Some("fs_main"),
compilation_options: Default::default(),
targets: &[Some(wgpu::ColorTargetState {
format: wgpu::TextureFormat::Rgba8UnormSrgb,
blend: Some(wgpu::BlendState {
color: wgpu::BlendComponent {
src_factor: wgpu::BlendFactor::SrcAlpha,
dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
operation: wgpu::BlendOperation::Add,
},
alpha: wgpu::BlendComponent {
src_factor: wgpu::BlendFactor::SrcAlpha,
dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
operation: wgpu::BlendOperation::Add,
},
}),
write_mask: wgpu::ColorWrites::ALL,
})],
}),
primitive: wgpu::PrimitiveState {
topology: wgpu::PrimitiveTopology::TriangleStrip,
..Default::default()
},
depth_stencil: None,
multisample: wgpu::MultisampleState::default(),
cache: None,
multiview_mask: None,
});
pass.set_viewport(
margin_left as f32,
margin_top as f32,
viewport_w - (margin_left + margin_right) as f32,
viewport_h - (margin_top + margin_bottom) as f32,
0.0,
1.0,
);
pass.set_scissor_rect(
margin_left as u32,
margin_top as u32,
(viewport_w - (margin_left + margin_right) as f32) as u32,
(viewport_h - (margin_top + margin_bottom) as f32) as u32,
);
pass.set_pipeline(&render_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.draw(0..4, 0..(n as u32));
}
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl DrawToRasterCpu for Point3dLayer {
async fn draw(&self, _cpu_context: &CpuContext<'_>, _pass: &mut CpuRenderPass) {}
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl DrawToSvg for Point3dLayer {
async fn draw(&self, _ctx: &mut SvgContext) {
}
}
inventory::submit! {
crate::registry::LayerRegistration {
layer_type_name: "Point3dLayer",
create_layer: |value, view_params| {
let params: Point3dLayerParams = serde_json::from_value(value).unwrap();
Box::new(Point3dLayer::new(view_params.clone(), params))
},
}
}
impl PickableLayer for Point3dLayer {
fn pick(&self, _screen_coord: ScreenCoord, data_coord: Option<DataCoord>) -> Option<LayerPickingResult> {
let DataCoord::ThreeD { x: cx, y: cy, z: cz } = data_coord? else {
return None;
};
let n = self.layer_params.labels_vec.len();
if n == 0 {
return None;
}
let mut min_dist_sq = f32::MAX;
let mut closest_idx = 0usize;
for i in 0..n {
let dx = self.layer_params.position_x[i] - cx;
let dy = self.layer_params.position_y[i] - cy;
let dz = self.layer_params.position_z[i] - cz;
let dist_sq = dx * dx + dy * dy + dz * dz;
if dist_sq < min_dist_sq {
min_dist_sq = dist_sq;
closest_idx = i;
}
}
let mut info = HashMap::new();
info.insert("index".to_string(), closest_idx.to_string());
info.insert("label".to_string(), self.layer_params.labels_vec[closest_idx].to_string());
info.insert("x".to_string(), self.layer_params.position_x[closest_idx].to_string());
info.insert("y".to_string(), self.layer_params.position_y[closest_idx].to_string());
info.insert("z".to_string(), self.layer_params.position_z[closest_idx].to_string());
Some(LayerPickingResult {
layer_id: self.layer_params.layer_id.clone(),
info,
})
}
}