use crate::spatial::impl_generic::convex_hull::convex_hull_impl;
use crate::spatial::traits::halfspace_intersection::HalfspaceIntersection;
use crate::spatial::validate_points_dtype;
use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::{
CompareOps, IndexingOps, LinalgOps, ReduceOps, ScalarOps, SortingOps, TensorOps,
TypeConversionOps,
};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use std::collections::HashSet;
/// Computes the vertices of the region bounded by a set of halfspaces.
///
/// `halfspaces` is a 2D tensor of shape `[m, d + 1]`. Each row holds a normal
/// `n` (first `d` columns) followed by an offset `b` (last column).
/// `interior_point` is a 1D tensor of shape `[d]` that must satisfy
/// `n·x + b < 0` for every row.
///
/// The algorithm:
/// 1. evaluates `n·x + b` at the interior point for every halfspace and
///    rejects the input unless all values are strictly negative;
/// 2. maps every halfspace to the dual point `-n / (n·x + b)`;
/// 3. runs [`convex_hull_impl`] on the dual points;
/// 4. for every distinct hull simplex, solves `A x = -b` where the rows of `A`
///    are the normals of the halfspaces referenced by that simplex, and
///    collects each solution as a vertex.
///
/// Vertices are emitted in the order their simplices first appear in the hull
/// output. Simplices whose linear system cannot be solved are skipped.
///
/// # Errors
///
/// Returns [`Error::InvalidArgument`] when
/// - `halfspaces` is not 2D, has fewer than two columns, or has no rows;
/// - `interior_point` does not have shape `[d]`;
/// - the interior point is not strictly inside every halfspace (including the
///   case where the evaluated maximum is NaN);
/// - the hull output does not have `d` indices per simplex or references a
///   halfspace index outside `0..m`;
/// - no simplex yields a solvable system.
///
/// Errors from the dtype validation, tensor operations and the convex hull
/// computation are propagated unchanged.
pub fn halfspace_intersection_impl<R, C>(
    client: &C,
    halfspaces: &Tensor<R>,
    interior_point: &Tensor<R>,
) -> Result<HalfspaceIntersection<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R>
        + ScalarOps<R>
        + ReduceOps<R>
        + CompareOps<R>
        + IndexingOps<R>
        + SortingOps<R>
        + TypeConversionOps<R>
        + LinalgOps<R>
        + RuntimeClient<R>,
{
    validate_points_dtype(halfspaces.dtype(), "halfspace_intersection")?;

    // --- Shape validation -------------------------------------------------
    let shape = halfspaces.shape();
    if shape.len() != 2 {
        return Err(Error::InvalidArgument {
            arg: "halfspaces",
            reason: "Expected 2D tensor [m, d+1]".to_string(),
        });
    }
    let m = shape[0];
    let d_plus_1 = shape[1];
    if d_plus_1 < 2 {
        return Err(Error::InvalidArgument {
            arg: "halfspaces",
            reason: "Need at least 2 columns (d >= 1)".to_string(),
        });
    }
    let d = d_plus_1 - 1;

    let ip_shape = interior_point.shape();
    if ip_shape.len() != 1 || ip_shape[0] != d {
        return Err(Error::InvalidArgument {
            arg: "interior_point",
            reason: format!("Expected shape [{}], got {:?}", d, ip_shape),
        });
    }

    // Checked after the shape checks above so their error precedence is kept;
    // without it the reductions below would run over an empty axis.
    if m == 0 {
        return Err(Error::InvalidArgument {
            arg: "halfspaces",
            reason: "Need at least one halfspace".to_string(),
        });
    }

    let device = halfspaces.device();
    let dtype = halfspaces.dtype();

    // --- Evaluate n·x + b at the interior point ----------------------------
    let normals = halfspaces.narrow(1, 0, d)?.contiguous()?;
    let offsets = halfspaces.narrow(1, d, 1)?.contiguous()?.reshape(&[m])?;
    let ip_col = interior_point.reshape(&[d, 1])?;
    let n_dot_ip = client.matmul(&normals, &ip_col)?.reshape(&[m])?;
    let vals = client.add(&n_dot_ip, &offsets)?;

    // The point must be strictly inside: every value has to be < 0. A plain
    // `>= 0.0` test is false for NaN, so NaN is rejected explicitly.
    // NOTE(review): this catches NaN only if the max reduction propagates it —
    // confirm against the backend's reduction semantics.
    let max_val: f64 = client.max(&vals, &[0], false)?.item::<f64>()?;
    if max_val.is_nan() || max_val >= 0.0 {
        let zero = Tensor::<R>::zeros(&[], dtype, device);
        let violated_raw = client.ge(&vals, &zero)?;
        let violated = client.cast(&violated_raw, DType::U8)?;
        let violated_data: Vec<u8> = violated.to_vec();

        // Name a specific halfspace only when one was actually flagged; never
        // fabricate index 0.
        let reason = match violated_data.iter().position(|&v| v > 0) {
            Some(idx) => {
                let val_at_idx: f64 = vals.narrow(0, idx, 1)?.item::<f64>()?;
                format!(
                    "Interior point violates halfspace {} (n·x + b = {:.6} >= 0)",
                    idx, val_at_idx
                )
            }
            None => format!(
                "Interior point is not strictly inside every halfspace (max n·x + b = {})",
                max_val
            ),
        };
        return Err(Error::InvalidArgument {
            arg: "interior_point",
            reason,
        });
    }

    // --- Dual transform: p_i = -n_i / (n_i·x + b_i) ------------------------
    let shifted_b_col = vals.reshape(&[m, 1])?;
    let shifted_b_broadcast = shifted_b_col.broadcast_to(&[m, d])?.contiguous()?;
    let neg_normals = client.mul_scalar(&normals, -1.0)?;
    let dual_points = client.div(&neg_normals, &shifted_b_broadcast)?;

    let hull = convex_hull_impl(client, &dual_points)?;

    // --- Map hull simplices back to primal vertices ------------------------
    let simplices_data: Vec<i64> = hull.simplices.to_vec();
    // NOTE(review): the host copy is read as f64 regardless of `dtype` —
    // confirm `to_vec` behaves as intended for every dtype accepted by
    // `validate_points_dtype`.
    let hs_data: Vec<f64> = halfspaces.to_vec();
    let n_simplices = hull.simplices.shape()[0];
    let simplex_dim = hull.simplices.shape()[1];

    // Each simplex must reference exactly d halfspaces, otherwise the d x d
    // system assembled below would be built from the wrong number of values.
    if n_simplices > 0 && simplex_dim != d {
        return Err(Error::InvalidArgument {
            arg: "halfspaces",
            reason: format!(
                "Convex hull produced simplices with {} indices, expected {}",
                simplex_dim, d
            ),
        });
    }

    let mut primal_vertices: Vec<f64> = Vec::new();
    let mut seen: HashSet<Vec<i64>> = HashSet::with_capacity(n_simplices);

    // `d >= 1` here, and `simplex_dim == d` whenever there is any simplex.
    for simplex in simplices_data.chunks_exact(d).take(n_simplices) {
        // Sorted index tuples identify a simplex independently of vertex order.
        let mut hs_indices: Vec<i64> = simplex.to_vec();
        hs_indices.sort_unstable();
        if !seen.insert(hs_indices.clone()) {
            continue;
        }

        // Assemble A (normals) and the right-hand side -b for this simplex.
        let mut a_data: Vec<f64> = Vec::with_capacity(d * d);
        let mut b_data: Vec<f64> = Vec::with_capacity(d);
        for &hi in &hs_indices {
            if hi < 0 || (hi as usize) >= m {
                return Err(Error::InvalidArgument {
                    arg: "halfspaces",
                    reason: format!(
                        "Convex hull referenced halfspace index {} outside 0..{}",
                        hi, m
                    ),
                });
            }
            let row = &hs_data[(hi as usize) * d_plus_1..(hi as usize + 1) * d_plus_1];
            a_data.extend_from_slice(&row[..d]);
            b_data.push(-row[d]);
        }

        let a_tensor = Tensor::<R>::from_slice(&a_data, &[d, d], device);
        let b_tensor = Tensor::<R>::from_slice(&b_data, &[d, 1], device);
        match LinalgOps::solve(client, &a_tensor, &b_tensor) {
            Ok(x) => {
                let x_data: Vec<f64> = x.to_vec::<f64>();
                primal_vertices.extend_from_slice(&x_data[..d]);
            }
            // Deliberate best effort: a simplex whose system cannot be solved
            // contributes no vertex (presumably a singular matrix).
            Err(_) => continue,
        }
    }

    let n_vertices = primal_vertices.len() / d;
    if n_vertices == 0 {
        return Err(Error::InvalidArgument {
            arg: "halfspaces",
            reason: "No valid intersection vertices found".to_string(),
        });
    }

    let intersections = Tensor::<R>::from_slice(&primal_vertices, &[n_vertices, d], device);
    Ok(HalfspaceIntersection {
        halfspaces: halfspaces.clone(),
        intersections,
        dual_points,
        interior_point: interior_point.clone(),
    })
}