use bevy::asset::RenderAssetUsages;
use bevy::mesh::{Indices, PrimitiveTopology};
use bevy::platform::collections::HashMap;
use bevy::prelude::*;
use std::hash::{DefaultHasher, Hash, Hasher};
use symbios_turtle_3d::{Skeleton, SkeletonPoint};
/// CPU-side vertex and index buffers for a single material bucket.
///
/// `add_ring` pushes to the four per-vertex vectors in lockstep, so they always
/// share a length; `indices` refers to vertices in them as a triangle list.
#[derive(Default)]
struct MeshData {
    /// Triangle-list indices into the per-vertex vectors.
    indices: Vec<u32>,
    /// Vertex positions.
    positions: Vec<Vec3>,
    /// Vertex normals.
    normals: Vec<Vec3>,
    /// Per-vertex RGBA colors.
    colors: Vec<[f32; 4]>,
    /// Per-vertex UVs (u runs around a ring, v runs along the strand).
    uvs: Vec<[f32; 2]>,
}
impl MeshData {
fn to_mesh(&self) -> Mesh {
let mut mesh = Mesh::new(
PrimitiveTopology::TriangleList,
RenderAssetUsages::default(),
);
mesh.insert_attribute(Mesh::ATTRIBUTE_POSITION, self.positions.clone());
mesh.insert_attribute(Mesh::ATTRIBUTE_NORMAL, self.normals.clone());
mesh.insert_attribute(Mesh::ATTRIBUTE_COLOR, self.colors.clone());
mesh.insert_attribute(Mesh::ATTRIBUTE_UV_0, self.uvs.clone());
mesh.insert_indices(Indices::U32(self.indices.clone()));
let _ = mesh.generate_tangents();
mesh
}
}
/// Largest ring resolution accepted by [`LSystemMeshBuilder::with_resolution`];
/// larger requests are clamped to this value.
const MAX_RESOLUTION: u32 = 128;

/// Turns a turtle [`Skeleton`] into tube meshes, one per material id.
///
/// Configure with [`LSystemMeshBuilder::with_resolution`], then call
/// [`LSystemMeshBuilder::build`] (or `build_cached` to go through a cache).
pub struct LSystemMeshBuilder {
    /// Segments around each ring; `with_resolution` keeps it within 3..=MAX_RESOLUTION.
    resolution: u32,
    /// Geometry accumulated so far, keyed by material id.
    buckets: HashMap<u16, MeshData>,
}
impl Default for LSystemMeshBuilder {
    /// An empty builder that produces 8-sided rings.
    fn default() -> Self {
        let resolution = 8;
        Self {
            resolution,
            buckets: HashMap::new(),
        }
    }
}
impl LSystemMeshBuilder {
    /// Smallest accepted ring resolution; fewer than three sides cannot
    /// enclose a cross-section.
    const MIN_RESOLUTION: u32 = 3;

    /// Squared distance at or below which a point is treated as a duplicate of
    /// the previously kept point and dropped.
    const MIN_SEGMENT_LENGTH_SQ: f32 = 0.000001;

    /// Creates a builder with the settings from its [`Default`] impl.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the number of segments around each generated ring.
    ///
    /// `res` is clamped to `3..=128`. A warning is logged whenever the
    /// requested value falls outside that range, in either direction.
    pub fn with_resolution(mut self, res: u32) -> Self {
        if res > MAX_RESOLUTION {
            warn!(
                "Mesh resolution {} exceeds maximum of {}; clamping to {}",
                res, MAX_RESOLUTION, MAX_RESOLUTION
            );
        } else if res < Self::MIN_RESOLUTION {
            warn!(
                "Mesh resolution {} is below minimum of {}; clamping to {}",
                res,
                Self::MIN_RESOLUTION,
                Self::MIN_RESOLUTION
            );
        }
        self.resolution = res.clamp(Self::MIN_RESOLUTION, MAX_RESOLUTION);
        self
    }

    /// Meshes every strand of `skeleton` and returns one [`Mesh`] per material.
    ///
    /// Strands with fewer than two points are skipped. Keys of the returned map
    /// are each segment's start-point `material_id`, converted with `as u16`.
    pub fn build(mut self, skeleton: &Skeleton) -> HashMap<u16, Mesh> {
        for strand in &skeleton.strands {
            if strand.len() < 2 {
                continue;
            }
            self.process_strand(strand);
        }
        self.buckets
            .into_iter()
            .map(|(k, v)| (k, v.to_mesh()))
            .collect()
    }

    /// Appends the tube geometry for one strand to the per-material buckets.
    ///
    /// Consecutive near-duplicate points are merged first. If fewer than two
    /// distinct points remain (including for an empty slice), nothing is
    /// emitted. Each segment goes into the bucket of its *start* point's
    /// material. The ring at a shared point is reused when both adjacent
    /// segments land in the same bucket.
    fn process_strand(&mut self, points: &[SkeletonPoint]) {
        let points = Self::dedup_points(points);
        if points.len() < 2 {
            return;
        }
        let rotations = Self::compute_rotations(&points);
        let v_coords = Self::compute_v_coords(&points);
        let res = self.resolution;

        // Top ring of the previous segment as (material bucket, first vertex
        // index). Only the immediately preceding ring can ever be reused.
        let mut prev_ring: Option<(u16, u32)> = None;
        for (i, pair) in points.windows(2).enumerate() {
            let (curr, next) = (pair[0], pair[1]);
            let mat_id = curr.material_id as u16;
            let bucket = self.buckets.entry(mat_id).or_default();
            let bottom_idx = match prev_ring {
                Some((cached_mat, idx)) if cached_mat == mat_id => idx,
                _ => Self::add_ring(
                    bucket,
                    curr.position,
                    rotations[i],
                    curr.radius,
                    curr.color,
                    v_coords[i],
                    res,
                ),
            };
            let top_idx = Self::add_ring(
                bucket,
                next.position,
                rotations[i + 1],
                next.radius,
                next.color,
                v_coords[i + 1],
                res,
            );
            Self::connect_rings(bucket, bottom_idx, top_idx, res);
            prev_ring = Some((mat_id, top_idx));
        }
    }

    /// Returns references to `points` with consecutive near-duplicates removed.
    ///
    /// The first point is always kept. Each later point is kept only if its
    /// squared distance to the last *kept* point exceeds `MIN_SEGMENT_LENGTH_SQ`.
    fn dedup_points(points: &[SkeletonPoint]) -> Vec<&SkeletonPoint> {
        let mut kept: Vec<&SkeletonPoint> = Vec::with_capacity(points.len());
        for point in points {
            let is_distinct = match kept.last() {
                Some(last) => {
                    last.position.distance_squared(point.position) > Self::MIN_SEGMENT_LENGTH_SQ
                }
                None => true,
            };
            if is_distinct {
                kept.push(point);
            }
        }
        kept
    }

    /// Computes one orientation per point by transporting a frame along the strand.
    ///
    /// The first frame is the first point's turtle rotation, bent so that its
    /// local +Y axis follows the first segment. Each later frame is the previous
    /// one, bent by the shortest-arc rotation onto that point's tangent.
    /// Quaternions are renormalised after every composition, so rounding error
    /// (or a non-unit input rotation) cannot skew the frames on long strands.
    ///
    /// Expects at least two points with distinct consecutive positions, as
    /// produced by `dedup_points`.
    fn compute_rotations(points: &[&SkeletonPoint]) -> Vec<Quat> {
        debug_assert!(
            points.len() >= 2,
            "compute_rotations needs at least two points"
        );
        let first_tangent = (points[1].position - points[0].position).normalize_or_zero();
        let start = points[0].rotation.normalize();
        let mut rot =
            (Self::robust_rotation_arc(start * Vec3::Y, first_tangent) * start).normalize();
        let mut rotations = Vec::with_capacity(points.len());
        rotations.push(rot);
        for i in 1..points.len() {
            let tangent = Self::tangent_at(points, i);
            let bend = Self::robust_rotation_arc(rot * Vec3::Y, tangent);
            rot = (bend * rot).normalize();
            rotations.push(rot);
        }
        rotations
    }

    /// Returns the tube direction at point `i` (`i >= 1`).
    ///
    /// For an interior point, this is the normalised bisector of the incoming
    /// and outgoing directions. At a near U-turn, where the bisector vanishes,
    /// it falls back to the incoming direction. For the last point it is the
    /// incoming direction.
    fn tangent_at(points: &[&SkeletonPoint], i: usize) -> Vec3 {
        let v_in = (points[i].position - points[i - 1].position).normalize_or_zero();
        let Some(next) = points.get(i + 1) else {
            return v_in;
        };
        let v_out = (next.position - points[i].position).normalize_or_zero();
        let sum = v_in + v_out;
        if sum.length_squared() < 0.001 {
            v_in
        } else {
            sum.normalize()
        }
    }

    /// Computes the V texture coordinate of each point.
    ///
    /// V starts at 0. Each segment adds its length divided by the circumference
    /// at its average radius, so one unit of V spans one circumference of
    /// length, matching U's 0..1 span around a ring. The result is scaled by the
    /// start point's `uv_scale`. Segments with a near-zero circumference use a
    /// divisor of 1 instead.
    fn compute_v_coords(points: &[&SkeletonPoint]) -> Vec<f32> {
        let mut coords = Vec::with_capacity(points.len());
        let mut cumulative_v = 0.0f32;
        coords.push(cumulative_v);
        for pair in points.windows(2) {
            let (start, end) = (pair[0], pair[1]);
            let seg_len = start.position.distance(end.position);
            let avg_radius = (start.radius + end.radius) * 0.5;
            let circumference = avg_radius * std::f32::consts::TAU;
            let v_scale = if circumference > 0.0001 {
                1.0 / circumference
            } else {
                1.0
            };
            cumulative_v += seg_len * v_scale * start.uv_scale;
            coords.push(cumulative_v);
        }
        coords
    }

    /// Returns the shortest-arc rotation taking unit vector `from` onto unit vector `to`.
    ///
    /// Nearly parallel inputs (dot > 0.9999) yield the identity. Nearly opposite
    /// inputs (dot < -0.9999) yield a half-turn about an axis perpendicular to
    /// `from`, which avoids the ill-conditioned general case.
    fn robust_rotation_arc(from: Vec3, to: Vec3) -> Quat {
        const DOT_THRESHOLD: f32 = 0.9999;
        let dot = from.dot(to);
        if dot < -DOT_THRESHOLD {
            // Cross with whichever world axis is far from `from` to get a
            // well-defined perpendicular rotation axis.
            let axis = if from.x.abs() < 0.8 {
                Vec3::X.cross(from).normalize()
            } else {
                Vec3::Y.cross(from).normalize()
            };
            return Quat::from_axis_angle(axis, std::f32::consts::PI);
        } else if dot > DOT_THRESHOLD {
            return Quat::IDENTITY;
        }
        Quat::from_rotation_arc(from, to)
    }

    /// Appends one ring of `res + 1` vertices around `center` and returns the
    /// index of its first vertex.
    ///
    /// Vertices lie in the local XZ plane of `rotation`. The first and last
    /// vertex (u = 0 and u = 1) sit at the same angle, so the UV seam gets
    /// separate vertices. Normals are purely radial; they ignore any taper
    /// between neighbouring rings.
    fn add_ring(
        data: &mut MeshData,
        center: Vec3,
        rotation: Quat,
        radius: f32,
        color: Vec4,
        v_coord: f32,
        res: u32,
    ) -> u32 {
        let start_index = data.positions.len() as u32;
        let color_array = color.to_array();
        for i in 0..=res {
            let u = i as f32 / res as f32;
            let theta = u * std::f32::consts::TAU;
            let (sin, cos) = theta.sin_cos();
            let local_pos = Vec3::new(cos * radius, 0.0, sin * radius);
            let local_normal = Vec3::new(cos, 0.0, sin);
            data.positions.push(center + (rotation * local_pos));
            data.normals.push(rotation * local_normal);
            data.colors.push(color_array);
            data.uvs.push([u, v_coord]);
        }
        start_index
    }

    /// Stitches two rings of `res + 1` vertices together with two triangles per quad.
    ///
    /// Triangles are emitted as (bottom, top, bottom-next) and
    /// (bottom-next, top, top-next).
    fn connect_rings(data: &mut MeshData, bottom_start: u32, top_start: u32, res: u32) {
        for i in 0..res {
            let bottom_curr = bottom_start + i;
            let bottom_next = bottom_start + i + 1;
            let top_curr = top_start + i;
            let top_next = top_start + i + 1;
            data.indices.extend_from_slice(&[
                bottom_curr,
                top_curr,
                bottom_next,
                bottom_next,
                top_curr,
                top_next,
            ]);
        }
    }
}
/// Bevy resource that memoises L-system mesh handles by skeleton fingerprint.
///
/// Each entry maps a fingerprint (see `compute_skeleton_fingerprint`) to the
/// per-material mesh handles built for it. The cache is unbounded: entries are
/// only removed by `clear`.
#[derive(Resource, Default, Debug)]
pub struct MeshCache {
    /// Fingerprint -> (material id -> mesh handle).
    entries: HashMap<u64, HashMap<u16, Handle<Mesh>>>,
    /// Lookups answered from `entries`.
    hits: u64,
    /// Lookups that had to build new meshes.
    misses: u64,
}
impl MeshCache {
    /// Creates an empty cache with zeroed hit/miss counters.
    pub fn new() -> Self {
        Default::default()
    }

    /// Number of cached skeleton fingerprints.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no fingerprints are cached.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Drops every cached entry; the hit/miss counters are left unchanged.
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Reports whether meshes for `skeleton` at `resolution` are already cached.
    ///
    /// Does not touch the hit/miss counters.
    pub fn contains(&self, skeleton: &Skeleton, resolution: u32) -> bool {
        let fingerprint = compute_fingerprint(skeleton, resolution);
        self.entries.contains_key(&fingerprint)
    }

    /// Number of lookups served from the cache since the last `reset_stats`.
    pub fn hits(&self) -> u64 {
        self.hits
    }

    /// Number of lookups that had to build since the last `reset_stats`.
    pub fn misses(&self) -> u64 {
        self.misses
    }

    /// Zeroes both the hit and miss counters.
    pub fn reset_stats(&mut self) {
        (self.hits, self.misses) = (0, 0);
    }

    /// Returns the handles cached under `fingerprint`, building them with
    /// `build` on a miss.
    ///
    /// A hit increments `hits` and returns a clone of the stored map. A miss
    /// increments `misses`, calls `build` once, stores a clone of its result,
    /// and returns it.
    pub fn get_or_insert_with<F>(
        &mut self,
        fingerprint: u64,
        build: F,
    ) -> HashMap<u16, Handle<Mesh>>
    where
        F: FnOnce() -> HashMap<u16, Handle<Mesh>>,
    {
        match self.entries.get(&fingerprint) {
            Some(cached) => {
                self.hits += 1;
                cached.clone()
            }
            None => {
                self.misses += 1;
                let fresh = build();
                self.entries.insert(fingerprint, fresh.clone());
                fresh
            }
        }
    }
}
pub fn compute_skeleton_fingerprint(skeleton: &Skeleton, resolution: u32) -> u64 {
compute_fingerprint(skeleton, resolution)
}
impl LSystemMeshBuilder {
    /// Like [`LSystemMeshBuilder::build`], but memoised through `cache`.
    ///
    /// On a cache hit, the stored handles are cloned and returned without
    /// rebuilding. On a miss, the skeleton is meshed, every mesh is added to
    /// `meshes`, and the resulting handles are stored under the skeleton's
    /// fingerprint. The lookup and the hit/miss accounting go through
    /// [`MeshCache::get_or_insert_with`], so they cannot diverge from the path
    /// external callers use.
    pub fn build_cached(
        self,
        skeleton: &Skeleton,
        cache: &mut MeshCache,
        meshes: &mut Assets<Mesh>,
    ) -> HashMap<u16, Handle<Mesh>> {
        let fingerprint = compute_fingerprint(skeleton, self.resolution);
        cache.get_or_insert_with(fingerprint, move || {
            self.build(skeleton)
                .into_iter()
                .map(|(id, mesh)| (id, meshes.add(mesh)))
                .collect()
        })
    }
}
/// Hashes everything that identifies a meshing job into a 64-bit cache key.
///
/// The key covers `resolution`, the strand count, each strand's length and
/// points, and `strand_parents`. Length prefixes keep different strand
/// splittings of the same points from colliding. std does not guarantee that
/// `DefaultHasher` output is stable across Rust releases, so keys should not be
/// persisted.
fn compute_fingerprint(skeleton: &Skeleton, resolution: u32) -> u64 {
    let mut state = DefaultHasher::new();
    resolution.hash(&mut state);
    let strands = &skeleton.strands;
    strands.len().hash(&mut state);
    strands.iter().for_each(|strand| {
        strand.len().hash(&mut state);
        strand
            .iter()
            .for_each(|point| hash_skeleton_point(point, &mut state));
    });
    skeleton.strand_parents.hash(&mut state);
    state.finish()
}
/// Feeds every field of `p` that affects the generated mesh into `hasher`.
///
/// Floats are hashed by bit pattern (`to_bits`), so the result is exact and
/// total even for NaN. The feed order is: position xyz, rotation xyzw, radius,
/// color xyzw, material id, uv scale.
fn hash_skeleton_point<H: Hasher>(p: &SkeletonPoint, hasher: &mut H) {
    let position = p.position.to_array();
    let rotation = p.rotation.to_array();
    let color = p.color.to_array();
    let float_fields = position
        .iter()
        .chain(&rotation)
        .chain(std::iter::once(&p.radius))
        .chain(&color);
    for value in float_fields {
        value.to_bits().hash(hasher);
    }
    p.material_id.hash(hasher);
    p.uv_scale.to_bits().hash(hasher);
}