use crate::{Actor, ActorBehavior, Message, Port};
use anyhow::{Error, Result};
use reflow_actor::{message::EncodableValue, ActorContext};
use reflow_actor_macro::actor;
use serde_json::{json, Value};
use std::collections::HashMap;
#[actor(
SkinningActor,
inports::<10>(mesh, skinned_mesh, bone_transforms, skin),
outports::<1>(deformed_mesh, metadata),
state(MemoryState),
await_inports(bone_transforms)
)]
pub async fn skinning_actor(ctx: ActorContext) -> Result<HashMap<String, Message>, Error> {
let payload = ctx.get_payload();
let config = ctx.get_config_hashmap();
let stride = config.get("stride").and_then(|v| v.as_u64()).unwrap_or(24) as usize;
if let Some(Message::Bytes(b)) = payload.get("mesh") {
ctx.pool_upsert("_cache", "mesh_b64", json!(b64_encode(&b)));
}
if let Some(Message::Bytes(b)) = payload.get("skinned_mesh") {
ctx.pool_upsert("_cache", "weights_b64", json!(b64_encode(&b)));
}
if let Some(Message::Object(obj)) = payload.get("skin") {
let v: Value = obj.as_ref().clone().into();
ctx.pool_upsert("_cache", "skin", v);
}
let bone_bytes = match payload.get("bone_transforms") {
Some(Message::Bytes(b)) => b.to_vec(),
_ => unreachable!("await_inports guarantees bone_transforms"),
};
let cache: HashMap<String, Value> = ctx.get_pool("_cache").into_iter().collect();
let mesh_bytes = match cache.get("mesh_b64").and_then(|v| v.as_str()) {
Some(s) => b64_decode(s),
None => return Ok(HashMap::new()), };
let skin_bytes = match cache.get("weights_b64").and_then(|v| v.as_str()) {
Some(s) => b64_decode(s),
None => return Ok(HashMap::new()), };
let skin_info = cache
.get("skin")
.cloned()
.unwrap_or(json!({"maxInfluences": 4}));
let vertex_count = mesh_bytes.len() / stride;
let max_influences = skin_info
.get("maxInfluences")
.and_then(|v| v.as_u64())
.unwrap_or(4) as usize;
let bone_count = bone_bytes.len() / 64;
let mut bone_matrices: Vec<[f32; 16]> = Vec::with_capacity(bone_count);
for i in 0..bone_count {
let off = i * 64;
let mut m = [0.0f32; 16];
for j in 0..16 {
m[j] = f32::from_le_bytes(bone_bytes[off + j * 4..off + j * 4 + 4].try_into().unwrap());
}
bone_matrices.push(m);
}
let entry_size = 6; let weights_per_vertex = max_influences * entry_size;
let mut output = vec![0u8; vertex_count * stride];
for i in 0..vertex_count {
let mesh_off = i * stride;
let skin_off = i * weights_per_vertex;
let px = f32::from_le_bytes(mesh_bytes[mesh_off..mesh_off + 4].try_into().unwrap());
let py = f32::from_le_bytes(mesh_bytes[mesh_off + 4..mesh_off + 8].try_into().unwrap());
let pz = f32::from_le_bytes(mesh_bytes[mesh_off + 8..mesh_off + 12].try_into().unwrap());
let nx = f32::from_le_bytes(mesh_bytes[mesh_off + 12..mesh_off + 16].try_into().unwrap());
let ny = f32::from_le_bytes(mesh_bytes[mesh_off + 16..mesh_off + 20].try_into().unwrap());
let nz = f32::from_le_bytes(mesh_bytes[mesh_off + 20..mesh_off + 24].try_into().unwrap());
let mut blended = [0.0f32; 16];
let mut total_weight = 0.0f32;
for j in 0..max_influences {
let w_off = skin_off + j * entry_size;
if w_off + entry_size > skin_bytes.len() {
break;
}
let bone_idx =
u16::from_le_bytes(skin_bytes[w_off..w_off + 2].try_into().unwrap()) as usize;
let weight = f32::from_le_bytes(skin_bytes[w_off + 2..w_off + 6].try_into().unwrap());
if weight < 1e-6 || bone_idx >= bone_count {
continue;
}
let m = &bone_matrices[bone_idx];
for k in 0..16 {
blended[k] += m[k] * weight;
}
total_weight += weight;
}
if total_weight < 1e-6 {
blended = super::math_helpers::MAT4_IDENTITY;
}
let npx = blended[0] * px + blended[4] * py + blended[8] * pz + blended[12];
let npy = blended[1] * px + blended[5] * py + blended[9] * pz + blended[13];
let npz = blended[2] * px + blended[6] * py + blended[10] * pz + blended[14];
let tnx = blended[0] * nx + blended[4] * ny + blended[8] * nz;
let tny = blended[1] * nx + blended[5] * ny + blended[9] * nz;
let tnz = blended[2] * nx + blended[6] * ny + blended[10] * nz;
let nlen = (tnx * tnx + tny * tny + tnz * tnz).sqrt().max(1e-8);
let out_off = i * stride;
output[out_off..out_off + 4].copy_from_slice(&npx.to_le_bytes());
output[out_off + 4..out_off + 8].copy_from_slice(&npy.to_le_bytes());
output[out_off + 8..out_off + 12].copy_from_slice(&npz.to_le_bytes());
output[out_off + 12..out_off + 16].copy_from_slice(&(tnx / nlen).to_le_bytes());
output[out_off + 16..out_off + 20].copy_from_slice(&(tny / nlen).to_le_bytes());
output[out_off + 20..out_off + 24].copy_from_slice(&(tnz / nlen).to_le_bytes());
if stride > 24 && mesh_off + stride <= mesh_bytes.len() {
output[out_off + 24..out_off + stride]
.copy_from_slice(&mesh_bytes[mesh_off + 24..mesh_off + stride]);
}
}
let mut out = HashMap::new();
out.insert("deformed_mesh".to_string(), Message::bytes(output));
out.insert(
"metadata".to_string(),
Message::object(EncodableValue::from(json!({
"vertexCount": vertex_count,
"boneCount": bone_count,
"stride": stride,
"format": if stride > 24 { "pos3_normal3_color3_f32" } else { "pos3_normal3_f32" },
}))),
);
Ok(out)
}
/// Encodes `data` as text using the base64 crate's `STANDARD` engine.
fn b64_encode(data: &[u8]) -> String {
    use base64::{engine::general_purpose::STANDARD as ENGINE, Engine as _};
    ENGINE.encode(data)
}
/// Decodes base64 text with the base64 crate's `STANDARD` engine.
///
/// Decoding failures are not reported: malformed input yields an empty vector.
fn b64_decode(s: &str) -> Vec<u8> {
    use base64::{engine::general_purpose::STANDARD as ENGINE, Engine as _};
    ENGINE.decode(s).unwrap_or_else(|_| Vec::new())
}