use encase::{ShaderType, UniformBuffer};
use glam::{Mat4, Vec2, Vec4};
use serde::{Deserialize, Serialize};
use std::sync::{Arc};
use std::collections::HashMap;
use crate::picking::LayerPickingResult;
use crate::render_traits::{AspectRatioMode, AspectRatioAlignmentMode, DrawToRasterGpu, DrawToRasterCpu, DrawToSvg, MarginParams, PickableLayer, PreparedLayer, UnitsMode, ViewParams};
use crate::viewport::{DataCoord, ScreenCoord};
use crate::render_types::{CpuContext, CpuRenderPass, PrepareResult, RenderResult};
use crate::render_types::GpuContext;
use crate::wgpu;
use crate::two::shapes::{TwoCircle, TwoColor, TwoElement, TwoGroup, TwoLine, TwoPath, TwoRectangle, TwoText};
use crate::two::svg::{update_svg, SvgContext};
use crate::positioning::get_point_position;
/// Marker shape used when drawing each point of a [`PointLayer`].
///
/// The GPU path encodes `Square` as `0` and `Circle` as `1` in the uniform
/// block; the SVG path emits a rectangle or a circle element respectively.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum PointShapeMode {
    /// Axis-aligned square whose side is twice the point radius (SVG path).
    Square,
    /// Circle with the configured point radius.
    Circle,
}
/// Serializable configuration of a [`PointLayer`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PointLayerParams {
    /// Identifier reported in picking results and attached to the SVG group.
    pub layer_id: String,
    /// Layer-specific margins; when `None`, the view-level margins are used.
    pub bounds: Option<MarginParams>,
    /// Unit mode of the x positions (forwarded to the shader / positioning helper).
    pub data_unit_mode_x: UnitsMode,
    /// Unit mode of the y positions (forwarded to the shader / positioning helper).
    pub data_unit_mode_y: UnitsMode,
    /// Radius of every point; the SVG path uses this value directly as the
    /// element radius / half side length.
    pub point_radius: f32,
    /// Unit mode of the radius along x. Must not be `Data` when
    /// `data_unit_mode_x` is `Pixels` (enforced by `PointLayer::new`).
    pub point_radius_unit_mode_x: UnitsMode,
    /// Unit mode of the radius along y. Must not be `Data` when
    /// `data_unit_mode_y` is `Pixels` (enforced by `PointLayer::new`).
    pub point_radius_unit_mode_y: UnitsMode,
    /// Shape drawn for each point.
    pub point_shape_mode: PointShapeMode,
    /// Per-point x coordinates.
    pub position_x: Arc<Vec<f32>>,
    /// Per-point y coordinates.
    pub position_y: Arc<Vec<f32>>,
    /// Per-point integer labels; the SVG renderer maps each label to a
    /// categorical colour.
    pub labels_vec: Arc<Vec<i32>>,
}
/// A layer that draws one marker per `(x, y)` position, with GPU raster,
/// SVG and picking support.
pub struct PointLayer {
    /// View-level settings (size, margins, camera, aspect-ratio handling).
    view_params: ViewParams,
    /// Layer-level settings and the point data itself.
    layer_params: PointLayerParams,
}
impl PointLayer {
pub fn new(
view_params: ViewParams,
layer_params: PointLayerParams,
) -> Self {
if layer_params.point_radius_unit_mode_x == UnitsMode::Data && layer_params.data_unit_mode_x == UnitsMode::Pixels {
panic!("point_radius_unit_mode cannot be 'data' when data_unit_mode is 'pixels'");
}
if layer_params.point_radius_unit_mode_y == UnitsMode::Data && layer_params.data_unit_mode_y == UnitsMode::Pixels {
panic!("point_radius_unit_mode cannot be 'data' when data_unit_mode is 'pixels'");
}
Self {
view_params,
layer_params,
}
}
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl PreparedLayer for PointLayer {
    /// Performs no work: the GPU context is ignored and the result always
    /// reports `bailed_early: false`.
    async fn prepare(&mut self, _gpu_context: Option<&GpuContext<'_>>) -> PrepareResult {
        PrepareResult { bailed_early: false }
    }
}
/// Uniform block uploaded for the point-layer shader (bound at binding 0).
///
/// NOTE(review): the field order defines the encoded layout and presumably
/// has to mirror the uniform struct in `shaders/point_layer.wgsl` — confirm
/// before reordering anything here.
#[derive(ShaderType, Debug)]
struct PointLayerUniforms {
    /// Size of the drawable area (viewport minus margins).
    layer_size: Vec2,
    /// Camera matrix; identity when the view supplies none.
    camera_view: Mat4,
    /// 0 = pixels, 1 = data.
    data_unit_mode_x: u32,
    /// 0 = pixels, 1 = data.
    data_unit_mode_y: u32,
    /// Point radius, copied from the layer parameters.
    point_radius: f32,
    /// 0 = pixels, 1 = data.
    point_radius_unit_mode_x: u32,
    /// 0 = pixels, 1 = data.
    point_radius_unit_mode_y: u32,
    /// 0 = square, 1 = circle.
    point_shape_mode: u32,
    /// 0 = ignore, 1 = contain, 2 = cover.
    aspect_ratio_mode: u32,
    /// 0 = center, 1 = start, 2 = end.
    aspect_ratio_alignment_mode: u32,
    /// Colour-mode selector interpreted by the shader.
    fill_color_mode: u32,
    /// RGBA fill colour.
    fill_color: Vec4,
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl DrawToRasterGpu for PointLayer {
async fn draw(&self, gpu_context: &GpuContext<'_>, pass: &mut wgpu::RenderPass) {
let GpuContext { device, queue } = gpu_context;
let Self { layer_params, view_params } = self;
let x_bytes = bytemuck::cast_slice(&layer_params.position_x);
let y_bytes = bytemuck::cast_slice(&layer_params.position_y);
let n = layer_params.labels_vec.len();
let labels_bytes: &[u8] = bytemuck::cast_slice(&layer_params.labels_vec);
let x_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("X Coordinates Storage Buffer"),
size: x_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&x_buffer, 0, x_bytes);
let y_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Y Coordinates Storage Buffer"),
size: y_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&y_buffer, 0, y_bytes);
let labels_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Class labels Storage Buffer"),
size: labels_bytes.len() as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&labels_buffer, 0, labels_bytes);
let camera_view = view_params.camera_view.unwrap_or([
1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0,
]);
let bounds = if layer_params.bounds.is_none() {
&view_params.margins
} else {
&layer_params.bounds
};
let margin_top = if let Some(margin_params) = &bounds {
margin_params.margin_top.unwrap_or(0.0)
} else { 0.0 } as f64;
let margin_right = if let Some(margin_params) = &bounds {
margin_params.margin_right.unwrap_or(0.0)
} else { 0.0 } as f64;
let margin_bottom = if let Some(margin_params) = &bounds {
margin_params.margin_bottom.unwrap_or(0.0)
} else { 0.0 } as f64;
let margin_left = if let Some(margin_params) = &bounds {
margin_params.margin_left.unwrap_or(0.0)
} else { 0.0 } as f64;
let viewport_w = view_params.width as f32;
let viewport_h = view_params.height as f32;
let layer_w = viewport_w - (margin_left + margin_right) as f32;
let layer_h = viewport_h - (margin_top + margin_bottom) as f32;
let uniform_struct = PointLayerUniforms {
layer_size: Vec2::new(layer_w, layer_h),
camera_view: Mat4::from_cols_array(&camera_view),
data_unit_mode_x: match layer_params.data_unit_mode_x {
UnitsMode::Pixels => 0,
UnitsMode::Data => 1,
},
data_unit_mode_y: match layer_params.data_unit_mode_y {
UnitsMode::Pixels => 0,
UnitsMode::Data => 1,
},
point_radius: layer_params.point_radius,
point_radius_unit_mode_x: match layer_params.point_radius_unit_mode_x {
UnitsMode::Pixels => 0,
UnitsMode::Data => 1,
},
point_radius_unit_mode_y: match layer_params.point_radius_unit_mode_y {
UnitsMode::Pixels => 0,
UnitsMode::Data => 1,
},
point_shape_mode: match layer_params.point_shape_mode {
PointShapeMode::Square => 0,
PointShapeMode::Circle => 1,
},
aspect_ratio_mode: match view_params.aspect_ratio_mode {
AspectRatioMode::Ignore => 0,
AspectRatioMode::Contain => 1,
AspectRatioMode::Cover => 2,
},
aspect_ratio_alignment_mode: match view_params.aspect_ratio_alignment_mode {
AspectRatioAlignmentMode::Center => 0,
AspectRatioAlignmentMode::Start => 1,
AspectRatioAlignmentMode::End => 2,
},
fill_color_mode: 2, fill_color: Vec4::from_array([1.0, 0.0, 0.0, 1.0]),
};
let mut buffer = UniformBuffer::new(Vec::<u8>::new());
buffer.write(&uniform_struct).unwrap();
let uniform_bytes = buffer.into_inner();
let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Uniform Buffer"),
size: uniform_bytes.len() as u64,
usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
queue.write_buffer(&uniform_buffer, 0, &uniform_bytes);
let bind_group_layout = device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("PointLayer BGL"),
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::VERTEX_FRAGMENT,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::VERTEX,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::VERTEX,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 3,
visibility: wgpu::ShaderStages::FRAGMENT,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
],
});
let bind_group = device
.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("PointLayer BG"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: uniform_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: x_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: y_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 3,
resource: labels_buffer.as_entire_binding(),
},
],
});
let shader = device
.create_shader_module(wgpu::include_wgsl!("shaders/point_layer.wgsl"));
let render_pipeline_layout = device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("Render Pipeline Layout"),
bind_group_layouts: &[Some(&bind_group_layout)],
immediate_size: 0,
});
let render_pipeline = device
.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("Render Pipeline"),
layout: Some(&render_pipeline_layout),
vertex: wgpu::VertexState {
module: &shader,
entry_point: Some("vs_main"),
compilation_options: Default::default(),
buffers: &[],
},
fragment: Some(wgpu::FragmentState {
module: &shader,
entry_point: Some("fs_main"),
compilation_options: Default::default(),
targets: &[Some(wgpu::ColorTargetState {
format: wgpu::TextureFormat::Rgba8UnormSrgb,
blend: Some(wgpu::BlendState {
color: wgpu::BlendComponent {
src_factor: wgpu::BlendFactor::SrcAlpha,
dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
operation: wgpu::BlendOperation::Add,
},
alpha: wgpu::BlendComponent {
src_factor: wgpu::BlendFactor::SrcAlpha,
dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
operation: wgpu::BlendOperation::Add,
},
}),
write_mask: wgpu::ColorWrites::ALL,
})],
}),
primitive: wgpu::PrimitiveState {
topology: wgpu::PrimitiveTopology::TriangleStrip,
..Default::default()
},
depth_stencil: None,
multisample: wgpu::MultisampleState::default(),
cache: None,
multiview_mask: None,
});
pass.set_viewport(
margin_left as f32,
margin_top as f32,
viewport_w - (margin_left + margin_right) as f32,
viewport_h - (margin_top + margin_bottom) as f32,
0.0, 1.0, );
pass.set_scissor_rect(
margin_left as u32,
margin_top as u32,
(viewport_w - (margin_left + margin_right) as f32) as u32,
(viewport_h - (margin_top + margin_bottom) as f32) as u32,
);
pass.set_pipeline(&render_pipeline);
pass.set_bind_group(0, &bind_group, &[]);
pass.draw(0..4, 0..(n as u32));
}
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl DrawToRasterCpu for PointLayer {
    /// No-op: this layer draws nothing on the CPU raster path.
    async fn draw(&self, _cpu_context: &CpuContext<'_>, _pass: &mut CpuRenderPass) {
        // Nothing to do.
    }
}
/// Ten-entry RGB palette used to colour points by label.
const CATEGORICAL_COLORS: [(u8, u8, u8); 10] = [
    (31, 119, 180),
    (255, 127, 14),
    (44, 160, 44),
    (214, 39, 40),
    (148, 103, 189),
    (227, 119, 194),
    (127, 127, 127),
    (188, 189, 34),
    (23, 190, 207),
    (219, 219, 219),
];

/// Returns the palette colour for `index`.
///
/// The index wraps around the palette using Euclidean remainder, so negative
/// labels map to a valid entry as well (e.g. `-1` selects the last colour).
fn get_categorical_color(index: i32) -> (u8, u8, u8) {
    let palette_len = CATEGORICAL_COLORS.len() as i32;
    let slot = index.rem_euclid(palette_len) as usize;
    CATEGORICAL_COLORS[slot]
}
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl DrawToSvg for PointLayer {
    /// Appends this layer to the SVG document held by `ctx`.
    ///
    /// Each point becomes a circle or rectangle element filled with the
    /// categorical colour of its label. All elements are wrapped in a single
    /// group that is translated by the left/top margins, tagged with the
    /// layer id and clipped to the drawable area.
    ///
    /// When the position and label arrays differ in length, only the points
    /// present in all three arrays are emitted (instead of panicking on an
    /// out-of-range index). The group is emitted even when it is empty.
    async fn draw(&self, ctx: &mut SvgContext) {
        let Self { layer_params, view_params } = self;

        // Identity matrix when the view does not supply a camera.
        let camera_view = view_params.camera_view.unwrap_or([
            1.0, 0.0, 0.0, 0.0, //
            0.0, 1.0, 0.0, 0.0, //
            0.0, 0.0, 1.0, 0.0, //
            0.0, 0.0, 0.0, 1.0,
        ]);

        // Layer-specific bounds take precedence over the view-wide margins.
        let bounds = if layer_params.bounds.is_some() {
            &layer_params.bounds
        } else {
            &view_params.margins
        };
        let (margin_top, margin_right, margin_bottom, margin_left) = match bounds {
            Some(margin_params) => (
                margin_params.margin_top.unwrap_or(0.0) as f64,
                margin_params.margin_right.unwrap_or(0.0) as f64,
                margin_params.margin_bottom.unwrap_or(0.0) as f64,
                margin_params.margin_left.unwrap_or(0.0) as f64,
            ),
            None => (0.0, 0.0, 0.0, 0.0),
        };

        let viewport_w = view_params.width as f32;
        let viewport_h = view_params.height as f32;
        let layer_w = viewport_w - (margin_left + margin_right) as f32;
        let layer_h = viewport_h - (margin_top + margin_bottom) as f32;

        // NOTE(review): the radius is used as-is on this path;
        // `point_radius_unit_mode_x/y` are not consulted here — confirm whether
        // data-unit radii should be converted for SVG output.
        let point_radius = layer_params.point_radius;

        let n = layer_params
            .position_x
            .len()
            .min(layer_params.position_y.len())
            .min(layer_params.labels_vec.len());
        let mut point_elements: Vec<TwoElement> = Vec::with_capacity(n);

        // Zipping stops at the shortest array, so no index can go out of range.
        let points = layer_params
            .position_x
            .iter()
            .zip(layer_params.position_y.iter())
            .zip(layer_params.labels_vec.iter());
        for ((&x, &y), &label) in points {
            let (px, py) = get_point_position(
                x,
                y,
                layer_w,
                layer_h,
                &camera_view,
                layer_params.data_unit_mode_x,
                layer_params.data_unit_mode_y,
                view_params.aspect_ratio_mode,
                view_params.aspect_ratio_alignment_mode,
                None,
            );
            // The y coordinate is flipped: SVG y grows downwards.
            let flipped_y = layer_h - py;

            let (r, g, b) = get_categorical_color(label);
            let fill = Some(TwoColor::Rgb((r, g, b)));

            let element = match layer_params.point_shape_mode {
                PointShapeMode::Circle => TwoElement::Circle(TwoCircle {
                    x: px as f64,
                    y: flipped_y as f64,
                    radius: point_radius as f64,
                    fill,
                    ..Default::default()
                }),
                PointShapeMode::Square => TwoElement::Rectangle(TwoRectangle {
                    // Centre the square on the point.
                    x: (px - point_radius) as f64,
                    y: (flipped_y - point_radius) as f64,
                    width: (point_radius * 2.0) as f64,
                    height: (point_radius * 2.0) as f64,
                    fill,
                    ..Default::default()
                }),
            };
            point_elements.push(element);
        }

        let layer_group = TwoElement::Group(TwoGroup {
            elements: point_elements,
            translate: Some((margin_left, margin_top)),
            layer_id: Some(layer_params.layer_id.clone()),
            clip_rect: Some((0.0, 0.0, layer_w as f64, layer_h as f64)),
            ..Default::default()
        });
        update_svg(ctx, &vec![layer_group]);
    }
}
// Registers this layer type with the layer registry under the name
// "PointLayer". The factory deserialises the JSON parameters and panics if
// they do not match `PointLayerParams` (or if `PointLayer::new` rejects them).
inventory::submit! {
    crate::registry::LayerRegistration {
        layer_type_name: "PointLayer",
        create_layer: |raw_params, view| {
            let layer_params: PointLayerParams = serde_json::from_value(raw_params).unwrap();
            let layer = PointLayer::new(view.clone(), layer_params);
            Box::new(layer)
        },
    }
}
impl PickableLayer for PointLayer {
    /// Returns the point nearest to `data_coord` (squared Euclidean distance
    /// in data space), or `None` when nothing can be picked.
    ///
    /// `None` is returned when:
    /// * no data coordinate is supplied, or it is not a `DataCoord::TwoD`;
    /// * the layer has no points; or
    /// * no point has a comparable distance to the query (e.g. the query
    ///   coordinate is NaN) — previously point 0 was reported in that case.
    ///
    /// Only points present in all of `position_x`, `position_y` and
    /// `labels_vec` are considered, so arrays of different lengths cannot
    /// cause an out-of-range index. On ties the lowest index wins.
    ///
    /// The result's `info` map holds the keys `"index"`, `"label"`, `"x"`
    /// and `"y"` for the chosen point. The screen coordinate is unused.
    fn pick(&self, _screen_coord: ScreenCoord, data_coord: Option<DataCoord>) -> Option<LayerPickingResult> {
        let DataCoord::TwoD { x: cx, y: cy } = data_coord? else {
            return None;
        };
        let params = &self.layer_params;

        // (index, squared distance) of the best candidate seen so far.
        let mut closest: Option<(usize, f32)> = None;
        let candidates = params
            .position_x
            .iter()
            .zip(params.position_y.iter())
            .take(params.labels_vec.len())
            .enumerate();
        for (i, (&x, &y)) in candidates {
            let dx = x - cx;
            let dy = y - cy;
            let dist_sq = dx * dx + dy * dy;
            if dist_sq.is_nan() {
                continue;
            }
            // Strict comparison keeps the earliest point on ties.
            let is_better = match closest {
                Some((_, best_dist_sq)) => dist_sq < best_dist_sq,
                None => true,
            };
            if is_better {
                closest = Some((i, dist_sq));
            }
        }
        let (closest_idx, _) = closest?;

        let mut info = HashMap::new();
        info.insert("index".to_string(), closest_idx.to_string());
        info.insert("label".to_string(), params.labels_vec[closest_idx].to_string());
        info.insert("x".to_string(), params.position_x[closest_idx].to_string());
        info.insert("y".to_string(), params.position_y[closest_idx].to_string());

        Some(LayerPickingResult {
            layer_id: params.layer_id.clone(),
            info,
        })
    }
}