use crate::interpolate::error::{InterpolateError, InterpolateResult};
use crate::interpolate::impl_generic::bezier_curve::bernstein_basis_matrix;
use crate::interpolate::traits::bezier_surface::BezierSurface;
use numr::algorithm::special::SpecialFunctions;
use numr::ops::{CompareOps, ScalarOps};
use numr::prelude::DType;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Dimensions of a validated surface-evaluation request.
struct SurfaceDims {
    /// Number of control-point rows along `u` (equals `degree_u + 1`).
    nu: usize,
    /// Number of control-point columns along `v` (equals `degree_v + 1`).
    nv: usize,
    /// Number of coordinates per control point.
    n_dims: usize,
    /// Number of `(u, v)` parameter pairs to evaluate.
    m: usize,
}

/// Checks that `surface`, `u` and `v` describe a consistent evaluation request.
///
/// # Errors
/// * `InterpolateError::InvalidParameter` if the control points are not a rank-3
///   `[nu, nv, n_dims]` tensor, if the grid is not `(degree_u + 1) x (degree_v + 1)`,
///   or if `u` / `v` is a rank-0 tensor (no leading length).
/// * `InterpolateError::ShapeMismatch` if `u` and `v` have different lengths.
fn validate_surface_inputs<R>(
    surface: &BezierSurface<R>,
    u: &Tensor<R>,
    v: &Tensor<R>,
) -> InterpolateResult<SurfaceDims>
where
    R: Runtime<DType = DType>,
{
    let shape = surface.control_points.shape();
    if shape.len() != 3 {
        return Err(InterpolateError::InvalidParameter {
            parameter: "control_points".to_string(),
            message: "expected shape [nu, nv, n_dims]".to_string(),
        });
    }
    let (nu, nv, n_dims) = (shape[0], shape[1], shape[2]);

    // A degree-n Bézier direction needs exactly n + 1 control points. Catching a
    // mismatch here gives a clear error instead of a reshape failure in `evaluate`
    // or a `usize` underflow while differencing in `partial`.
    if nu != surface.degree_u + 1 || nv != surface.degree_v + 1 {
        return Err(InterpolateError::InvalidParameter {
            parameter: "control_points".to_string(),
            message: format!(
                "control point grid {}x{} does not match degrees ({}, {}); expected {}x{}",
                nu,
                nv,
                surface.degree_u,
                surface.degree_v,
                surface.degree_u + 1,
                surface.degree_v + 1
            ),
        });
    }

    // Leading length of a parameter tensor. Rank-0 tensors have none, and indexing
    // `shape()[0]` on them would panic.
    let leading_len = |t: &Tensor<R>, name: &str| {
        t.shape()
            .first()
            .copied()
            .ok_or_else(|| InterpolateError::InvalidParameter {
                parameter: name.to_string(),
                message: "expected a 1-D tensor of parameter values".to_string(),
            })
    };
    let m = leading_len(u, "u")?;
    let v_len = leading_len(v, "v")?;
    if m != v_len {
        return Err(InterpolateError::ShapeMismatch {
            expected: m,
            actual: v_len,
            context: "bezier_surface: u and v must have same length".to_string(),
        });
    }

    Ok(SurfaceDims { nu, nv, n_dims, m })
}

/// Returns the first forward difference of `points` along axis `dim`:
/// `out[.., i, ..] = points[.., i + 1, ..] - points[.., i, ..]`.
///
/// The result is one element shorter along `dim`. Callers must ensure that axis
/// has at least two entries.
fn forward_difference<R, C>(
    client: &C,
    points: &Tensor<R>,
    dim: usize,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + RuntimeClient<R>,
{
    let len = points.shape()[dim];
    debug_assert!(len >= 2, "forward_difference needs at least two points");
    let hi = points.narrow(dim, 1, len - 1)?.contiguous()?;
    let lo = points.narrow(dim, 0, len - 1)?.contiguous()?;
    Ok(client.sub(&hi, &lo)?)
}

/// Evaluates a tensor-product Bézier surface at `m` paired parameters `(u[k], v[k])`.
///
/// The surface is `S(u, v) = Σ_i Σ_j B_i(u) · B_j(v) · P[i][j]`, where `B_i` / `B_j`
/// are the Bernstein polynomials of degree `degree_u` / `degree_v`. It is computed
/// as a single `[m, nu*nv] x [nu*nv, n_dims]` matrix product.
///
/// # Arguments
/// * `client` - runtime client used for all tensor operations.
/// * `surface` - control points of shape `[nu, nv, n_dims]` with
///   `nu == degree_u + 1` and `nv == degree_v + 1`.
/// * `u`, `v` - parameter tensors of equal leading length `m`. Values are not
///   range-checked here.
///
/// # Returns
/// A tensor of shape `[m, n_dims]` whose row `k` is `S(u[k], v[k])`.
///
/// # Errors
/// * `InterpolateError::InvalidParameter` for malformed control points, a grid that
///   does not match the degrees, or rank-0 parameter tensors.
/// * `InterpolateError::ShapeMismatch` if `u` and `v` differ in length.
/// * Any error propagated from the Bernstein basis or the tensor operations.
pub fn bezier_surface_evaluate_impl<R, C>(
    client: &C,
    surface: &BezierSurface<R>,
    u: &Tensor<R>,
    v: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
    let SurfaceDims { nu, nv, n_dims, m } = validate_surface_inputs(surface, u, v)?;

    // Basis matrices; the reshapes below expect `[m, nu]` and `[m, nv]`.
    let basis_u = bernstein_basis_matrix(client, u, surface.degree_u)?;
    let basis_v = bernstein_basis_matrix(client, v, surface.degree_v)?;

    // Flatten the control grid so the double sum becomes one matmul.
    let cp_flat = surface
        .control_points
        .reshape(&[nu * nv, n_dims])?
        .contiguous()?;

    // Per-sample outer product of the two bases: w[k, i, j] = Bu[k, i] * Bv[k, j].
    let bu_exp = basis_u
        .reshape(&[m, nu, 1])?
        .broadcast_to(&[m, nu, nv])?
        .contiguous()?;
    let bv_exp = basis_v
        .reshape(&[m, 1, nv])?
        .broadcast_to(&[m, nu, nv])?
        .contiguous()?;
    let product = client.mul(&bu_exp, &bv_exp)?;
    let product_flat = product.reshape(&[m, nu * nv])?;

    let result = client.matmul(&product_flat, &cp_flat)?;
    Ok(result)
}

/// Evaluates the partial derivative `∂^(du+dv) S / ∂u^du ∂v^dv` of a Bézier surface.
///
/// Each differentiation in a direction uses the Bézier derivative identity. The
/// derivative of a degree-`n` Bézier is `n` times the degree-`n-1` Bézier whose
/// control points are the forward differences of the originals. The control net is
/// differenced `du` times along axis 0 and `dv` times along axis 1. It is then
/// scaled by `degree_u!/(degree_u-du)! · degree_v!/(degree_v-dv)!`, and the
/// reduced-degree surface is evaluated.
///
/// # Returns
/// A tensor of shape `[m, n_dims]`. If `du == dv == 0` this is the surface itself.
/// If `du > degree_u` or `dv > degree_v` the derivative vanishes, and a zero tensor
/// with the control points' dtype is returned.
///
/// # Errors
/// The same validation errors as [`bezier_surface_evaluate_impl`], plus any error
/// propagated from the tensor operations.
pub fn bezier_surface_partial_impl<R, C>(
    client: &C,
    surface: &BezierSurface<R>,
    u: &Tensor<R>,
    v: &Tensor<R>,
    du: usize,
    dv: usize,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
    if du == 0 && dv == 0 {
        return bezier_surface_evaluate_impl(client, surface, u, v);
    }

    // Validate before anything else: the zero shortcut below and the differencing
    // loops both rely on a well-formed surface and matching u/v lengths.
    let SurfaceDims { n_dims, m, .. } = validate_surface_inputs(surface, u, v)?;

    if du > surface.degree_u || dv > surface.degree_v {
        // Differentiating a polynomial past its degree gives zero. Keep the
        // control points' dtype so callers see a consistent result type.
        return Ok(Tensor::zeros(
            &[m, n_dims],
            surface.control_points.dtype(),
            client.device(),
        ));
    }

    let mut diff_cp = surface.control_points.clone();

    // Validation guarantees the grid has degree + 1 rows, so each difference below
    // operates on at least two points.
    let mut scale_u = 1.0;
    for k in 0..du {
        scale_u *= (surface.degree_u - k) as f64;
        diff_cp = forward_difference(client, &diff_cp, 0)?;
    }

    let mut scale_v = 1.0;
    for k in 0..dv {
        scale_v *= (surface.degree_v - k) as f64;
        diff_cp = forward_difference(client, &diff_cp, 1)?;
    }

    let deriv_surface = BezierSurface {
        control_points: client.mul_scalar(&diff_cp, scale_u * scale_v)?,
        degree_u: surface.degree_u - du,
        degree_v: surface.degree_v - dv,
    };
    bezier_surface_evaluate_impl(client, &deriv_surface, u, v)
}
/// Computes the surface normal `S_u × S_v` at each parameter pair `(u[k], v[k])`.
///
/// The result is **not normalized**. Its length is `|S_u × S_v|`, and its
/// orientation follows the right-hand rule from the `u` tangent to the `v` tangent.
/// If either degree is zero, that tangent is identically zero, so the normal is
/// zero as well.
///
/// # Returns
/// A tensor of shape `[m, 3]`.
///
/// # Errors
/// * `InterpolateError::InvalidParameter` if the control points are not a rank-3
///   tensor, or if their last dimension is not 3.
/// * Any error from evaluating the partial derivatives or the cross product.
pub fn bezier_surface_normal_impl<R, C>(
    client: &C,
    surface: &BezierSurface<R>,
    u: &Tensor<R>,
    v: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + CompareOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
{
    let shape = surface.control_points.shape();
    // Check the rank first: indexing `shape[2]` on a lower-rank tensor would panic.
    if shape.len() != 3 {
        return Err(InterpolateError::InvalidParameter {
            parameter: "control_points".to_string(),
            message: "expected shape [nu, nv, n_dims]".to_string(),
        });
    }
    if shape[2] != 3 {
        return Err(InterpolateError::InvalidParameter {
            parameter: "control_points".to_string(),
            message: "normals require 3D control points".to_string(),
        });
    }

    let tangent_u = bezier_surface_partial_impl(client, surface, u, v, 1, 0)?;
    let tangent_v = bezier_surface_partial_impl(client, surface, u, v, 0, 1)?;
    cross_product_3d(client, &tangent_u, &tangent_v)
}
/// Row-wise 3-D cross product: row `k` of the result is `a[k] × b[k]`.
///
/// Both inputs are read through columns 0..3 of axis 1 (presumably `[m, 3]`
/// tensors). Each output column is built as a 2x2 minor, and the three columns are
/// concatenated back along axis 1.
pub(crate) fn cross_product_3d<R, C>(
    client: &C,
    a: &Tensor<R>,
    b: &Tensor<R>,
) -> InterpolateResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + RuntimeClient<R>,
{
    // Column `idx` of `t` as a contiguous single-column tensor.
    let column = |t: &Tensor<R>, idx: usize| -> InterpolateResult<Tensor<R>> {
        Ok(t.narrow(1, idx, 1)?.contiguous()?)
    };

    // Element-wise `p * q - r * s`: one component of the cross product.
    let minor = |p: &Tensor<R>,
                 q: &Tensor<R>,
                 r: &Tensor<R>,
                 s: &Tensor<R>|
     -> InterpolateResult<Tensor<R>> {
        let pq = client.mul(p, q)?;
        let rs = client.mul(r, s)?;
        Ok(client.sub(&pq, &rs)?)
    };

    let (ax, ay, az) = (column(a, 0)?, column(a, 1)?, column(a, 2)?);
    let (bx, by, bz) = (column(b, 0)?, column(b, 1)?, column(b, 2)?);

    let cx = minor(&ay, &bz, &az, &by)?;
    let cy = minor(&az, &bx, &ax, &bz)?;
    let cz = minor(&ax, &by, &ay, &bx)?;

    Ok(client.cat(&[&cx, &cy, &cz], 1)?)
}
#[cfg(test)]
mod tests {
    use super::*;
    use numr::runtime::cpu::{CpuClient, CpuDevice, CpuRuntime};

    /// Absolute tolerance for every floating-point comparison in this module.
    const TOL: f64 = 1e-10;

    fn setup() -> (CpuDevice, CpuClient) {
        let device = CpuDevice::new();
        let client = CpuClient::new(device.clone());
        (device, client)
    }

    /// Builds a 1-D parameter tensor from `values`.
    fn params(values: &[f64], device: &CpuDevice) -> Tensor<CpuRuntime> {
        Tensor::<CpuRuntime>::from_slice(values, &[values.len()], device)
    }

    /// Asserts element-wise closeness, including equal lengths.
    fn assert_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {:?} vs {:?}",
            actual,
            expected
        );
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!((a - e).abs() < TOL, "index {}: got {}, expected {}", i, a, e);
        }
    }

    /// Flat unit square in the z = 0 plane: S(u, v) = (v, u, 0).
    fn flat_square(device: &CpuDevice) -> BezierSurface<CpuRuntime> {
        BezierSurface {
            control_points: Tensor::<CpuRuntime>::from_slice(
                &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0],
                &[2, 2, 3],
                device,
            ),
            degree_u: 1,
            degree_v: 1,
        }
    }

    /// Bilinear patch with the (1, 1) corner lifted: S(u, v) = (v, u, u*v).
    fn twisted_bilinear(device: &CpuDevice) -> BezierSurface<CpuRuntime> {
        BezierSurface {
            control_points: Tensor::<CpuRuntime>::from_slice(
                &[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0],
                &[2, 2, 3],
                device,
            ),
            degree_u: 1,
            degree_v: 1,
        }
    }

    /// Quadratic-in-u, linear-in-v patch: S(u, v) = (u, v, u^2).
    fn quadratic_u(device: &CpuDevice) -> BezierSurface<CpuRuntime> {
        BezierSurface {
            control_points: Tensor::<CpuRuntime>::from_slice(
                &[
                    0.0, 0.0, 0.0, 0.0, 1.0, 0.0, // i = 0
                    0.5, 0.0, 0.0, 0.5, 1.0, 0.0, // i = 1
                    1.0, 0.0, 1.0, 1.0, 1.0, 1.0, // i = 2
                ],
                &[3, 2, 3],
                device,
            ),
            degree_u: 2,
            degree_v: 1,
        }
    }

    #[test]
    fn test_bezier_surface_bilinear() {
        let (device, client) = setup();
        let surface = twisted_bilinear(&device);
        let u = params(&[0.0, 0.5, 1.0], &device);
        let v = params(&[0.0, 0.5, 1.0], &device);
        let result = bezier_surface_evaluate_impl(&client, &surface, &u, &v).unwrap();
        let vals: Vec<f64> = result.to_vec();
        assert_close(&vals, &[0.0, 0.0, 0.0, 0.5, 0.5, 0.25, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_bezier_surface_corners() {
        let (device, client) = setup();
        let surface = flat_square(&device);
        let u = params(&[0.0, 1.0, 0.0, 1.0], &device);
        let v = params(&[0.0, 0.0, 1.0, 1.0], &device);
        let result = bezier_surface_evaluate_impl(&client, &surface, &u, &v).unwrap();
        let vals: Vec<f64> = result.to_vec();
        // Each corner reproduces its control point exactly.
        assert_close(
            &vals,
            &[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
        );
    }

    #[test]
    fn test_bezier_surface_mismatched_lengths() {
        let (device, client) = setup();
        let surface = flat_square(&device);
        let u = params(&[0.0, 0.5], &device);
        let v = params(&[0.5], &device);
        assert!(bezier_surface_evaluate_impl(&client, &surface, &u, &v).is_err());
    }

    #[test]
    fn test_bezier_surface_partial_quadratic() {
        let (device, client) = setup();
        let surface = quadratic_u(&device);
        let u = params(&[0.5], &device);
        let v = params(&[0.25], &device);
        let partial = |du: usize, dv: usize| -> Vec<f64> {
            bezier_surface_partial_impl(&client, &surface, &u, &v, du, dv)
                .unwrap()
                .to_vec()
        };
        // S = (u, v, u^2), S_u = (1, 0, 2u), S_uu = (0, 0, 2), S_v = (0, 1, 0).
        assert_close(&partial(0, 0), &[0.5, 0.25, 0.25]);
        assert_close(&partial(1, 0), &[1.0, 0.0, 1.0]);
        assert_close(&partial(2, 0), &[0.0, 0.0, 2.0]);
        assert_close(&partial(0, 1), &[0.0, 1.0, 0.0]);
        // Beyond the degree in either direction the derivative vanishes.
        assert_close(&partial(3, 0), &[0.0, 0.0, 0.0]);
        assert_close(&partial(0, 2), &[0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_bezier_surface_partial_mixed() {
        let (device, client) = setup();
        let surface = twisted_bilinear(&device);
        let u = params(&[0.25, 0.75], &device);
        let v = params(&[0.5, 0.1], &device);
        let vals: Vec<f64> = bezier_surface_partial_impl(&client, &surface, &u, &v, 1, 1)
            .unwrap()
            .to_vec();
        // S = (v, u, u*v) has the constant mixed partial S_uv = (0, 0, 1).
        assert_close(&vals, &[0.0, 0.0, 1.0, 0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_bezier_surface_normal() {
        let (device, client) = setup();
        let surface = flat_square(&device);
        let u = params(&[0.5], &device);
        let v = params(&[0.5], &device);
        let normal = bezier_surface_normal_impl(&client, &surface, &u, &v).unwrap();
        let vals: Vec<f64> = normal.to_vec();
        // S_u = (0, 1, 0), S_v = (1, 0, 0), so S_u x S_v = (0, 0, -1).
        assert_close(&vals, &[0.0, 0.0, -1.0]);
    }

    #[test]
    fn test_bezier_surface_normal_requires_3d() {
        let (device, client) = setup();
        let surface = BezierSurface {
            control_points: Tensor::<CpuRuntime>::from_slice(
                &[0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 1.0],
                &[2, 2, 2],
                &device,
            ),
            degree_u: 1,
            degree_v: 1,
        };
        let u = params(&[0.5], &device);
        let v = params(&[0.5], &device);
        assert!(bezier_surface_normal_impl(&client, &surface, &u, &v).is_err());
    }
}