use glam::{Vec3, Quat, Mat4};
use crate::error::{AnvilKitError, Result};
/// A local transform stored as separate translation, rotation and scale (TRS)
/// components.
///
/// Its matrix form (see `Transform::compute_matrix`) is built with glam's
/// `Mat4::from_scale_rotation_translation`: points are scaled first, then
/// rotated, then translated.
///
/// Derives serde support when the `serde` feature is enabled and is a
/// `bevy_ecs` `Component` when the `bevy_ecs` feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bevy_ecs", derive(bevy_ecs::component::Component))]
pub struct Transform {
/// Position offset, applied after rotation and scale.
pub translation: Vec3,
/// Orientation, applied after scale and before translation.
pub rotation: Quat,
/// Per-axis scale factors, applied first.
pub scale: Vec3,
}
impl Default for Transform {
fn default() -> Self {
Self::IDENTITY
}
}
impl Transform {
    /// The identity transform: zero translation, identity rotation, unit scale.
    pub const IDENTITY: Self = Self {
        translation: Vec3::ZERO,
        rotation: Quat::IDENTITY,
        scale: Vec3::ONE,
    };

    /// Creates a transform from explicit translation, rotation and scale.
    pub fn new(translation: Vec3, rotation: Quat, scale: Vec3) -> Self {
        Self {
            translation,
            rotation,
            scale,
        }
    }

    /// Creates a translation-only transform (identity rotation, unit scale).
    pub fn from_translation(translation: Vec3) -> Self {
        Self {
            translation,
            ..Self::IDENTITY
        }
    }

    /// Creates a rotation-only transform (zero translation, unit scale).
    pub fn from_rotation(rotation: Quat) -> Self {
        Self {
            rotation,
            ..Self::IDENTITY
        }
    }

    /// Creates a scale-only transform (zero translation, identity rotation).
    pub fn from_scale(scale: Vec3) -> Self {
        Self {
            scale,
            ..Self::IDENTITY
        }
    }

    /// Creates a translation-only transform from individual coordinates.
    pub fn from_xyz(x: f32, y: f32, z: f32) -> Self {
        Self::from_translation(Vec3::new(x, y, z))
    }

    /// Creates a translation-only transform in the XY plane (`z = 0`).
    pub fn from_xy(x: f32, y: f32) -> Self {
        Self::from_translation(Vec3::new(x, y, 0.0))
    }

    /// Returns `self` with its translation replaced (builder style).
    pub fn with_translation(mut self, translation: Vec3) -> Self {
        self.translation = translation;
        self
    }

    /// Returns `self` with its rotation replaced (builder style).
    pub fn with_rotation(mut self, rotation: Quat) -> Self {
        self.rotation = rotation;
        self
    }

    /// Returns `self` with its scale replaced (builder style).
    pub fn with_scale(mut self, scale: Vec3) -> Self {
        self.scale = scale;
        self
    }

    /// Builds a unit-scale transform positioned at `eye` whose local `-Z` axis
    /// points at `target` and whose local `+Y` axis is as close to `up` as
    /// possible.
    ///
    /// # Errors
    ///
    /// Returns an error if `eye` and `target` coincide or their difference is
    /// not finite, if `up` is zero or parallel to the viewing direction, or if
    /// the resulting basis or rotation is not finite.
    pub fn looking_at(eye: Vec3, target: Vec3, up: Vec3) -> Result<Self> {
        // `try_normalize` yields `None` for zero-length or non-finite vectors.
        // Plain `normalize` would return NaN there, or panic when glam's
        // assertion features are enabled.
        let forward = (target - eye)
            .try_normalize()
            .ok_or_else(|| AnvilKitError::generic("无效的朝向向量:目标和眼睛位置相同或无效"))?;
        // A zero cross product means `up` is zero or parallel to `forward`.
        let right = forward
            .cross(up)
            .try_normalize()
            .ok_or_else(|| AnvilKitError::generic("无效的上方向向量:与前向向量平行"))?;
        // Re-derive `up` so the basis is orthonormal even when the caller's `up`
        // is not perpendicular to `forward`.
        let up = right.cross(forward);
        if !up.is_finite() {
            return Err(AnvilKitError::generic("计算上方向向量时出现数值错误"));
        }
        // Basis columns: X = right, Y = up, Z = -forward (looking down local -Z).
        let rotation_matrix = glam::Mat3::from_cols(right, up, -forward);
        let rotation = Quat::from_mat3(&rotation_matrix);
        if !rotation.is_finite() {
            return Err(AnvilKitError::generic("计算旋转四元数时出现数值错误"));
        }
        Ok(Self::new(eye, rotation, Vec3::ONE))
    }

    /// Returns the affine matrix `T * R * S` for this transform.
    pub fn compute_matrix(&self) -> Mat4 {
        Mat4::from_scale_rotation_translation(self.scale, self.rotation, self.translation)
    }

    /// Transforms a point: applies scale, rotation, then translation.
    pub fn transform_point(&self, point: Vec3) -> Vec3 {
        self.compute_matrix().transform_point3(point)
    }

    /// Transforms a direction: applies scale and rotation, ignoring translation.
    pub fn transform_vector(&self, vector: Vec3) -> Vec3 {
        self.compute_matrix().transform_vector3(vector)
    }

    /// Composes `self` (parent) with `other` (child). The result applies
    /// `other` first, then `self`.
    ///
    /// The composition is done on matrices and then decomposed back into TRS.
    /// Shear, which appears when non-uniform scale is combined with rotation,
    /// cannot be represented and is lost.
    pub fn mul_transform(&self, other: &Transform) -> Transform {
        let matrix = self.compute_matrix() * other.compute_matrix();
        Transform::from_matrix(matrix)
    }

    /// Decomposes an affine matrix into translation, rotation and scale.
    ///
    /// Shear or projective components in `matrix` cannot be represented.
    pub fn from_matrix(matrix: Mat4) -> Self {
        let (scale, rotation, translation) = matrix.to_scale_rotation_translation();
        Self {
            translation,
            rotation,
            scale,
        }
    }

    /// Returns the inverse of this transform.
    ///
    /// The translation of the result is always the exact translation of the
    /// inverse affine matrix. Rotation and scale are inverted component-wise.
    /// That is exact when the scale is uniform or the rotation is the identity.
    /// With non-uniform scale and a non-identity rotation, the true inverse
    /// contains shear, which a `Transform` cannot express.
    ///
    /// # Errors
    ///
    /// Returns an error if any scale component has magnitude below `f32::EPSILON`.
    pub fn inverse(&self) -> Result<Self> {
        if self.scale.x.abs() < f32::EPSILON
            || self.scale.y.abs() < f32::EPSILON
            || self.scale.z.abs() < f32::EPSILON
        {
            return Err(AnvilKitError::generic("无法计算逆变换:缩放包含零值"));
        }
        let inv_scale = Vec3::new(1.0 / self.scale.x, 1.0 / self.scale.y, 1.0 / self.scale.z);
        let inv_rotation = self.rotation.inverse();
        // M = T(t)·R·S  =>  M⁻¹ = S⁻¹·R⁻¹·T(-t), so the inverse's translation is
        // -(S⁻¹ · (R⁻¹ · t)). The inverse scale must be applied AFTER the
        // inverse rotation; the two only commute when the scale is uniform.
        let inv_translation = -(inv_scale * (inv_rotation * self.translation));
        Ok(Self {
            translation: inv_translation,
            rotation: inv_rotation,
            scale: inv_scale,
        })
    }

    /// Returns `true` if every component (translation, rotation, scale) is finite.
    pub fn is_finite(&self) -> bool {
        self.translation.is_finite() && self.rotation.is_finite() && self.scale.is_finite()
    }
}
/// A transform stored as a full 4x4 matrix rather than as TRS components.
///
/// Unlike `Transform`, the wrapped `Mat4` can hold any matrix, including shear.
/// NOTE(review): presumably this holds the composed (world-space) transform of
/// an entity; confirm against the hierarchy/propagation code.
///
/// Derives serde support when the `serde` feature is enabled and is a
/// `bevy_ecs` `Component` when the `bevy_ecs` feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "bevy_ecs", derive(bevy_ecs::component::Component))]
pub struct GlobalTransform(pub Mat4);
impl Default for GlobalTransform {
fn default() -> Self {
Self::IDENTITY
}
}
impl GlobalTransform {
    /// The identity transform (identity matrix).
    pub const IDENTITY: Self = Self(Mat4::IDENTITY);

    /// Wraps `matrix` as-is; no validation or decomposition is performed.
    pub fn from_matrix(matrix: Mat4) -> Self {
        Self(matrix)
    }

    /// Converts a TRS [`Transform`] into its matrix form.
    pub fn from_transform(transform: &Transform) -> Self {
        Self(transform.compute_matrix())
    }

    /// Returns the underlying matrix.
    pub fn matrix(&self) -> Mat4 {
        self.0
    }

    /// Returns the translation, read directly from the matrix's fourth column.
    pub fn translation(&self) -> Vec3 {
        self.0.w_axis.truncate()
    }

    /// Returns the rotation obtained by decomposing the matrix.
    ///
    /// The full decomposition is recomputed on every call.
    pub fn rotation(&self) -> Quat {
        let (_, rotation, _) = self.0.to_scale_rotation_translation();
        rotation
    }

    /// Returns the scale obtained by decomposing the matrix.
    ///
    /// The full decomposition is recomputed on every call.
    pub fn scale(&self) -> Vec3 {
        let (scale, _, _) = self.0.to_scale_rotation_translation();
        scale
    }

    /// Transforms a point (translation is applied).
    pub fn transform_point(&self, point: Vec3) -> Vec3 {
        self.0.transform_point3(point)
    }

    /// Transforms a direction (translation is ignored).
    pub fn transform_vector(&self, vector: Vec3) -> Vec3 {
        self.0.transform_vector3(vector)
    }

    /// Composes two transforms as `self * other`, so `other` is applied first.
    pub fn mul_transform(&self, other: &GlobalTransform) -> GlobalTransform {
        GlobalTransform(self.0 * other.0)
    }

    /// Returns the inverse matrix.
    ///
    /// # Errors
    ///
    /// Returns an error if the matrix is singular (zero determinant) or if the
    /// computed inverse is not finite (e.g. a near-singular matrix overflowed).
    pub fn inverse(&self) -> Result<Self> {
        // `Mat4::inverse` divides by the determinant without checking it. It
        // only asserts (and therefore panics) when glam's assertion features are
        // enabled; otherwise it returns an invalid matrix. Reject the exactly
        // singular case up front so we always return an error instead.
        if self.0.determinant() == 0.0 {
            return Err(AnvilKitError::generic("无法计算全局变换的逆变换"));
        }
        let inv_matrix = self.0.inverse();
        if !inv_matrix.is_finite() {
            return Err(AnvilKitError::generic("无法计算全局变换的逆变换"));
        }
        Ok(Self(inv_matrix))
    }

    /// Returns `true` if every matrix element is finite.
    pub fn is_finite(&self) -> bool {
        self.0.is_finite()
    }
}
impl From<Transform> for GlobalTransform {
fn from(transform: Transform) -> Self {
Self::from_transform(&transform)
}
}
impl From<Mat4> for GlobalTransform {
fn from(matrix: Mat4) -> Self {
Self::from_matrix(matrix)
}
}
// Unit tests for `Transform` and `GlobalTransform`.
#[cfg(test)]
mod tests {
use super::*;
use glam::Vec3;
// Approximate vector equality: true when the Euclidean distance between `a`
// and `b` is below `epsilon`.
fn vec3_approx_eq(a: Vec3, b: Vec3, epsilon: f32) -> bool {
(a - b).length() < epsilon
}
// Approximate rotation equality. `q` and `-q` encode the same rotation, so a
// dot product near +1 or near -1 counts as equal (assumes unit quaternions).
fn quat_approx_eq(a: glam::Quat, b: glam::Quat, epsilon: f32) -> bool {
(a.dot(b) - 1.0).abs() < epsilon || (a.dot(b) + 1.0).abs() < epsilon
}
#[test]
fn test_transform_identity() {
let transform = Transform::IDENTITY;
assert_eq!(transform.translation, Vec3::ZERO);
assert_eq!(transform.rotation, Quat::IDENTITY);
assert_eq!(transform.scale, Vec3::ONE);
}
#[test]
fn test_transform_creation() {
let transform = Transform::from_xyz(1.0, 2.0, 3.0);
assert_eq!(transform.translation, Vec3::new(1.0, 2.0, 3.0));
let transform = Transform::from_xy(1.0, 2.0);
assert_eq!(transform.translation, Vec3::new(1.0, 2.0, 0.0));
}
#[test]
fn test_transform_chaining() {
let transform = Transform::IDENTITY
.with_translation(Vec3::new(1.0, 2.0, 3.0))
.with_scale(Vec3::splat(2.0));
assert_eq!(transform.translation, Vec3::new(1.0, 2.0, 3.0));
assert_eq!(transform.scale, Vec3::splat(2.0));
}
#[test]
fn test_transform_point() {
let transform = Transform::from_xyz(1.0, 2.0, 3.0);
let point = Vec3::ZERO;
let transformed = transform.transform_point(point);
assert!(vec3_approx_eq(transformed, Vec3::new(1.0, 2.0, 3.0), 1e-6));
}
#[test]
fn test_transform_vector() {
let transform = Transform::from_scale(Vec3::splat(2.0));
let vector = Vec3::new(1.0, 1.0, 1.0);
let transformed = transform.transform_vector(vector);
assert!(vec3_approx_eq(transformed, Vec3::new(2.0, 2.0, 2.0), 1e-6));
}
// Composing two pure translations sums their offsets.
#[test]
fn test_transform_composition() {
let parent = Transform::from_xyz(1.0, 0.0, 0.0);
let child = Transform::from_xyz(0.0, 1.0, 0.0);
let combined = parent.mul_transform(&child);
assert!(vec3_approx_eq(combined.translation, Vec3::new(1.0, 1.0, 0.0), 1e-6));
}
// TRS -> matrix -> TRS recovers the original components (non-uniform scale,
// rotation about Y).
#[test]
fn test_transform_matrix_roundtrip() {
let original = Transform::from_xyz(1.0, 2.0, 3.0)
.with_rotation(Quat::from_rotation_y(0.5))
.with_scale(Vec3::new(2.0, 1.5, 0.5));
let matrix = original.compute_matrix();
let reconstructed = Transform::from_matrix(matrix);
assert!(vec3_approx_eq(original.translation, reconstructed.translation, 1e-5));
assert!(quat_approx_eq(original.rotation, reconstructed.rotation, 1e-5));
assert!(vec3_approx_eq(original.scale, reconstructed.scale, 1e-5));
}
// A transform composed with its inverse gives the identity (uniform scale).
#[test]
fn test_transform_inverse() {
let transform = Transform::from_xyz(1.0, 2.0, 3.0)
.with_rotation(Quat::from_rotation_y(0.5))
.with_scale(Vec3::splat(2.0));
let inverse = transform.inverse().unwrap();
let identity = transform.mul_transform(&inverse);
assert!(vec3_approx_eq(identity.translation, Vec3::ZERO, 1e-5));
assert!(quat_approx_eq(identity.rotation, Quat::IDENTITY, 1e-5));
assert!(vec3_approx_eq(identity.scale, Vec3::ONE, 1e-5));
}
#[test]
fn test_transform_inverse_zero_scale() {
let transform = Transform::from_scale(Vec3::new(0.0, 1.0, 1.0));
assert!(transform.inverse().is_err());
}
// An eye at +Z looking at the origin must have its local -Z axis pointing
// toward the origin.
#[test]
fn test_looking_at() {
let transform = Transform::looking_at(
Vec3::new(0.0, 0.0, 5.0),
Vec3::ZERO,
Vec3::Y
).unwrap();
let forward = transform.transform_vector(-Vec3::Z);
let expected_direction = (Vec3::ZERO - Vec3::new(0.0, 0.0, 5.0)).normalize();
assert!(vec3_approx_eq(forward, expected_direction, 1e-5));
}
// Degenerate inputs: eye == target, and `up` parallel to the view direction.
#[test]
fn test_looking_at_invalid() {
assert!(Transform::looking_at(Vec3::ZERO, Vec3::ZERO, Vec3::Y).is_err());
let result = Transform::looking_at(Vec3::ZERO, Vec3::new(0.0, 0.0, 1.0), Vec3::new(0.0, 0.0, 1.0));
assert!(result.is_err(), "Expected error for parallel forward and up vectors, but got: {:?}", result);
assert!(Transform::looking_at(Vec3::ZERO, Vec3::Y, Vec3::Y).is_err());
}
#[test]
fn test_global_transform() {
let transform = Transform::from_xyz(1.0, 2.0, 3.0);
let global = GlobalTransform::from_transform(&transform);
assert_eq!(global.translation(), Vec3::new(1.0, 2.0, 3.0));
}
#[test]
fn test_global_transform_composition() {
let global1 = GlobalTransform::from_matrix(Mat4::from_translation(Vec3::X));
let global2 = GlobalTransform::from_matrix(Mat4::from_translation(Vec3::Y));
let combined = global1.mul_transform(&global2);
assert!(vec3_approx_eq(combined.translation(), Vec3::new(1.0, 1.0, 0.0), 1e-6));
}
#[test]
fn test_finite_checks() {
let valid_transform = Transform::from_xyz(1.0, 2.0, 3.0);
assert!(valid_transform.is_finite());
let invalid_transform = Transform::from_xyz(f32::NAN, 2.0, 3.0);
assert!(!invalid_transform.is_finite());
}
#[test]
fn test_transform_nan_input() {
let transform = Transform::from_xyz(f32::NAN, 0.0, 0.0);
assert!(!transform.is_finite());
let transform = Transform::from_xyz(f32::INFINITY, 0.0, 0.0);
assert!(!transform.is_finite());
}
// Composing 100 small steps must not produce NaN/inf; exact values are not
// checked.
#[test]
fn test_transform_deep_chain_precision() {
let mut composed = Transform::IDENTITY;
let step = Transform::from_xyz(0.1, 0.0, 0.0)
.with_rotation(Quat::from_rotation_z(0.01));
for _ in 0..100 {
composed = composed.mul_transform(&step);
}
assert!(composed.is_finite());
assert!(composed.translation.is_finite());
}
// Inverse round-trip with a compound Euler rotation and uniform scale.
#[test]
fn test_transform_inverse_roundtrip_with_rotation_and_scale() {
let transform = Transform::new(
Vec3::new(3.0, -7.0, 11.0),
Quat::from_euler(glam::EulerRot::YXZ, 0.5, 0.3, 0.7),
Vec3::splat(2.0),
);
let inverse = transform.inverse().unwrap();
let identity = transform.mul_transform(&inverse);
assert!(vec3_approx_eq(identity.translation, Vec3::ZERO, 1e-4));
assert!(quat_approx_eq(identity.rotation, Quat::IDENTITY, 1e-4));
assert!(vec3_approx_eq(identity.scale, Vec3::ONE, 1e-4));
}
#[test]
fn test_transform_from_matrix_identity() {
let transform = Transform::from_matrix(Mat4::IDENTITY);
assert!(vec3_approx_eq(transform.translation, Vec3::ZERO, 1e-6));
assert!(quat_approx_eq(transform.rotation, Quat::IDENTITY, 1e-6));
assert!(vec3_approx_eq(transform.scale, Vec3::ONE, 1e-6));
}
// A zero scale on any single axis (or all axes) makes the transform
// non-invertible.
#[test]
fn test_transform_all_axes_inverse() {
assert!(Transform::from_scale(Vec3::new(0.0, 1.0, 1.0)).inverse().is_err());
assert!(Transform::from_scale(Vec3::new(1.0, 0.0, 1.0)).inverse().is_err());
assert!(Transform::from_scale(Vec3::new(1.0, 1.0, 0.0)).inverse().is_err());
assert!(Transform::from_scale(Vec3::new(0.0, 0.0, 0.0)).inverse().is_err());
}
#[test]
fn test_global_transform_from_transform_conversion() {
let transform = Transform::from_xyz(5.0, 10.0, 15.0)
.with_scale(Vec3::splat(2.0));
let global: GlobalTransform = transform.into();
assert!(vec3_approx_eq(global.translation(), Vec3::new(5.0, 10.0, 15.0), 1e-6));
assert!(vec3_approx_eq(global.scale(), Vec3::splat(2.0), 1e-6));
}
#[test]
fn test_global_transform_inverse() {
let global = GlobalTransform::from_matrix(Mat4::from_translation(Vec3::new(1.0, 2.0, 3.0)));
let inverse = global.inverse().unwrap();
let identity = global.mul_transform(&inverse);
assert!(vec3_approx_eq(identity.translation(), Vec3::ZERO, 1e-5));
}
// The zero matrix is singular, so inversion must fail rather than return NaNs.
#[test]
fn test_global_transform_inverse_singular() {
let global = GlobalTransform::from_matrix(Mat4::ZERO);
let result = global.inverse();
assert!(result.is_err());
}
#[test]
fn test_transform_default() {
let transform = Transform::default();
assert_eq!(transform, Transform::IDENTITY);
}
#[test]
fn test_global_transform_default() {
let global = GlobalTransform::default();
assert_eq!(global, GlobalTransform::IDENTITY);
}
// Scale by 2, then rotate 90° about Z: (1,0,0) -> (2,0,0) -> (0,2,0).
#[test]
fn test_transform_point_with_scale_and_rotation() {
let transform = Transform::from_scale(Vec3::splat(2.0))
.with_rotation(Quat::from_rotation_z(std::f32::consts::FRAC_PI_2));
let point = Vec3::new(1.0, 0.0, 0.0);
let transformed = transform.transform_point(point);
assert!(vec3_approx_eq(transformed, Vec3::new(0.0, 2.0, 0.0), 1e-5));
}
#[test]
fn test_transform_vector_ignores_translation() {
let transform = Transform::from_xyz(100.0, 200.0, 300.0);
let vector = Vec3::new(1.0, 0.0, 0.0);
let transformed = transform.transform_vector(vector);
assert!(vec3_approx_eq(transformed, Vec3::new(1.0, 0.0, 0.0), 1e-6));
}
#[test]
fn test_global_transform_transform_point() {
let global = GlobalTransform::from_matrix(
Mat4::from_translation(Vec3::new(10.0, 20.0, 30.0))
);
let point = Vec3::ZERO;
let result = global.transform_point(point);
assert!(vec3_approx_eq(result, Vec3::new(10.0, 20.0, 30.0), 1e-6));
}
#[test]
fn test_global_transform_rotation_extraction() {
let rotation = Quat::from_rotation_y(std::f32::consts::FRAC_PI_4);
let transform = Transform::from_rotation(rotation);
let global = GlobalTransform::from_transform(&transform);
assert!(quat_approx_eq(global.rotation(), rotation, 1e-5));
}
}