use crate::error::{SpatialError, SpatialResult};
use std::cmp::Ordering;
// Absolute tolerance shared by the floating-point comparisons in this file:
// point equality, `float_cmp`, the vertical/parallel checks, and the
// parameter-range checks in `segment_intersection`.
const EPSILON: f64 = 1e-10;
/// A point in the 2D plane with `f64` coordinates.
///
/// Equality and ordering are implemented by hand further down in this file
/// and are tolerance-based (they go through `EPSILON`), not bit-exact.
#[derive(Debug, Clone, Copy)]
pub struct Point2D {
/// The x coordinate.
pub x: f64,
/// The y coordinate.
pub y: f64,
}
impl Point2D {
pub fn new(x: f64, y: f64) -> Self {
Self { x, y }
}
}
impl PartialEq for Point2D {
    /// Two points are equal when both coordinates differ by less than
    /// `EPSILON` in absolute value.
    fn eq(&self, other: &Self) -> bool {
        let dx = (self.x - other.x).abs();
        let dy = (self.y - other.y).abs();
        dx < EPSILON && dy < EPSILON
    }
}
// Marker impl. NOTE(review): the `PartialEq` impl for `Point2D` is
// tolerance-based, so equality is not transitive for coordinates that differ
// by just under `EPSILON`; `Eq` is asserted here regardless.
impl Eq for Point2D {}
impl PartialOrd for Point2D {
    /// Always returns `Some`, delegating to the `Ord` implementation.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for Point2D {
    /// Lexicographic comparison: `x` first, then `y`, each compared with
    /// `float_cmp` (so values within `EPSILON` count as equal).
    fn cmp(&self, other: &Self) -> Ordering {
        float_cmp(self.x, other.x).then_with(|| float_cmp(self.y, other.y))
    }
}
/// A line segment between two points in the 2D plane.
///
/// The constructors in this file store the endpoints so that
/// `start <= end` under `Point2D`'s ordering. The fields are public, so that
/// ordering is not enforced after construction.
#[derive(Debug, Clone, Copy)]
pub struct Segment2D {
/// First endpoint (the smaller one when built through the constructors).
pub start: Point2D,
/// Second endpoint (the larger one when built through the constructors).
pub end: Point2D,
// Private identifier; the constructors in this file initialize it to 0.
id: usize,
}
impl Segment2D {
pub fn new(x1: f64, y1: f64, x2: f64, y2: f64) -> Self {
let p1 = Point2D::new(x1, y1);
let p2 = Point2D::new(x2, y2);
let (start, end) = if p1 <= p2 { (p1, p2) } else { (p2, p1) };
Self { start, end, id: 0 }
}
pub fn from_points(p1: Point2D, p2: Point2D) -> Self {
let (start, end) = if p1 <= p2 { (p1, p2) } else { (p2, p1) };
Self { start, end, id: 0 }
}
fn y_at_x(&self, x: f64) -> Option<f64> {
let dx = self.end.x - self.start.x;
if dx.abs() < EPSILON {
Some((self.start.y + self.end.y) / 2.0)
} else {
let t = (x - self.start.x) / dx;
if !(-EPSILON..=1.0 + EPSILON).contains(&t) {
return None;
}
let t_clamped = t.clamp(0.0, 1.0);
Some(self.start.y + t_clamped * (self.end.y - self.start.y))
}
}
fn is_vertical(&self) -> bool {
(self.end.x - self.start.x).abs() < EPSILON
}
fn slope(&self) -> Option<f64> {
let dx = self.end.x - self.start.x;
if dx.abs() < EPSILON {
None
} else {
Some((self.end.y - self.start.y) / dx)
}
}
}
impl PartialEq for Segment2D {
    /// Two segments are equal when they carry the same `id` and their
    /// endpoints are equal under `Point2D`'s tolerance-based equality.
    ///
    /// Comparing the `id` alone is not enough: both public constructors set
    /// `id` to 0, so every pair of user-built segments would compare equal
    /// regardless of geometry.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.start == other.start && self.end == other.end
    }
}
// Marker impl; equality semantics come from the `PartialEq` impl for `Segment2D`.
impl Eq for Segment2D {}
/// One intersection between two segments of an input slice.
///
/// As produced by the finder functions in this file, `segment_a` and
/// `segment_b` are indices into the slice that was passed in, with
/// `segment_a < segment_b`.
#[derive(Debug, Clone)]
pub struct Intersection {
/// Location of the intersection.
pub point: Point2D,
/// Index of the first segment of the pair.
pub segment_a: usize,
/// Index of the second segment of the pair.
pub segment_b: usize,
}
// Kind of event attached to a `SweepEvent`.
// NOTE(review): nothing in the visible code constructs or matches on this
// enum; the `usize` payloads are presumably segment indices — confirm.
#[derive(Debug, Clone)]
enum EventType {
// Presumably: the sweep reaches a segment's left endpoint.
LeftEndpoint(usize),
// Presumably: the sweep reaches a segment's right endpoint.
RightEndpoint(usize),
// Presumably: the two referenced segments cross at the event point.
IntersectionEvent(usize, usize),
}
// An event located at `point`, tagged with its `EventType`.
// Equality and ordering (implemented below) look at `point` only and ignore
// `event_type`.
// NOTE(review): not constructed anywhere in the visible code.
#[derive(Debug, Clone)]
struct SweepEvent {
// Where the event happens.
point: Point2D,
// What happens there.
event_type: EventType,
}
impl PartialEq for SweepEvent {
    /// Events are equal when their points are equal; the event type is ignored.
    fn eq(&self, other: &Self) -> bool {
        let (mine, theirs) = (&self.point, &other.point);
        mine == theirs
    }
}
// Marker impl; equality semantics come from the `PartialEq` impl for `SweepEvent`.
impl Eq for SweepEvent {}
impl PartialOrd for SweepEvent {
    /// Always returns `Some`, delegating to the `Ord` implementation.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for SweepEvent {
    /// Orders events by their point (x first, then y); the event type is ignored.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.point, &other.point)
    }
}
// Entry describing one segment, ordered (see the `Ord` impl below) by
// `current_y`, then `slope`, then `segment_id`.
// NOTE(review): not constructed anywhere in the visible code; presumably an
// element of a sweep-line status structure — confirm.
#[derive(Debug, Clone)]
struct StatusEntry {
// Identifier of the segment; the only field the `PartialEq` impl looks at.
segment_id: usize,
// Primary sort key.
current_y: f64,
// Secondary sort key, used when `current_y` values tie within tolerance.
slope: f64,
}
impl PartialEq for StatusEntry {
    /// Entries are equal when they refer to the same segment id; `current_y`
    /// and `slope` are not considered.
    fn eq(&self, other: &Self) -> bool {
        self.segment_id.eq(&other.segment_id)
    }
}
// Marker impl; equality semantics come from the `PartialEq` impl for `StatusEntry`.
impl Eq for StatusEntry {}
impl PartialOrd for StatusEntry {
    /// Always returns `Some`, delegating to the `Ord` implementation.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for StatusEntry {
    /// Orders entries by `current_y`, breaking ties (within `EPSILON`) by
    /// `slope`, and remaining ties by `segment_id`.
    fn cmp(&self, other: &Self) -> Ordering {
        float_cmp(self.current_y, other.current_y)
            .then_with(|| float_cmp(self.slope, other.slope))
            .then_with(|| self.segment_id.cmp(&other.segment_id))
    }
}
/// Compares two floats, treating values closer than `EPSILON` as equal.
///
/// Returns `Ordering::Equal` when `|a - b| < EPSILON`, `Ordering::Less` when
/// `a < b`, and `Ordering::Greater` in every remaining case (which includes
/// comparisons involving NaN).
fn float_cmp(a: f64, b: f64) -> Ordering {
    let gap = (a - b).abs();
    if gap < EPSILON {
        return Ordering::Equal;
    }
    if a < b {
        Ordering::Less
    } else {
        Ordering::Greater
    }
}
/// Computes the intersection of two segments, if they cross at a single point.
///
/// Returns `Some((point, t, u))` where `point` is the intersection, `t` is the
/// parameter along `seg_a` (0 at `seg_a.start`, 1 at `seg_a.end`) and `u` the
/// corresponding parameter along `seg_b`. Both parameters are accepted with
/// an `EPSILON` slack beyond `[0, 1]`, so touching endpoints count.
///
/// Returns `None` when the segments do not meet, and also whenever the
/// determinant of the two direction vectors is below `EPSILON` in absolute
/// value — i.e. parallel or collinear segments are never reported, even if
/// they overlap.
pub fn segment_intersection(seg_a: &Segment2D, seg_b: &Segment2D) -> Option<(Point2D, f64, f64)> {
    let (x1, y1) = (seg_a.start.x, seg_a.start.y);
    let (x2, y2) = (seg_a.end.x, seg_a.end.y);
    let (x3, y3) = (seg_b.start.x, seg_b.start.y);
    let (x4, y4) = (seg_b.end.x, seg_b.end.y);

    // Reversed direction vectors of both segments and the offset between
    // their start points.
    let (a_dx, a_dy) = (x1 - x2, y1 - y2);
    let (b_dx, b_dy) = (x3 - x4, y3 - y4);
    let (off_x, off_y) = (x1 - x3, y1 - y3);

    let denom = a_dx * b_dy - a_dy * b_dx;
    if denom.abs() < EPSILON {
        return None;
    }

    let t = (off_x * b_dy - off_y * b_dx) / denom;
    let u = -(a_dx * off_y - a_dy * off_x) / denom;

    let on_segment = |param: f64| param >= -EPSILON && param <= 1.0 + EPSILON;
    if !(on_segment(t) && on_segment(u)) {
        return None;
    }

    let hit = Point2D::new(x1 + t * (x2 - x1), y1 + t * (y2 - y1));
    Some((hit, t, u))
}
pub fn find_all_intersections(segments: &[Segment2D]) -> SpatialResult<Vec<Intersection>> {
if segments.is_empty() || segments.len() < 2 {
return Ok(Vec::new());
}
let mut segs: Vec<Segment2D> = segments.to_vec();
for (i, seg) in segs.iter_mut().enumerate() {
seg.id = i;
}
let mut events: Vec<(f64, bool, usize)> = Vec::with_capacity(segs.len() * 2);
for (i, seg) in segs.iter().enumerate() {
events.push((seg.start.x, true, i)); events.push((seg.end.x, false, i)); }
events.sort_by(|a, b| {
float_cmp(a.0, b.0).then_with(|| {
match (a.1, b.1) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => a.2.cmp(&b.2),
}
})
});
let mut active: Vec<usize> = Vec::new();
let mut intersections: Vec<Intersection> = Vec::new();
let mut found_pairs: std::collections::HashSet<(usize, usize)> =
std::collections::HashSet::new();
for (_, is_start, seg_idx) in &events {
if *is_start {
for &other_idx in &active {
let pair = if *seg_idx < other_idx {
(*seg_idx, other_idx)
} else {
(other_idx, *seg_idx)
};
if !found_pairs.contains(&pair) {
if let Some((pt, _, _)) = segment_intersection(&segs[pair.0], &segs[pair.1]) {
found_pairs.insert(pair);
intersections.push(Intersection {
point: pt,
segment_a: pair.0,
segment_b: pair.1,
});
}
}
}
active.push(*seg_idx);
} else {
active.retain(|&id| id != *seg_idx);
}
}
Ok(intersections)
}
/// Finds all pairwise intersections by testing every pair of segments.
///
/// Pairs are examined in increasing `(i, j)` order with `i < j`, and each hit
/// is reported with `segment_a = i` and `segment_b = j`.
pub fn find_all_intersections_brute_force(segments: &[Segment2D]) -> Vec<Intersection> {
    segments
        .iter()
        .enumerate()
        .flat_map(|(i, first)| {
            segments
                .iter()
                .enumerate()
                .skip(i + 1)
                .filter_map(move |(j, second)| {
                    segment_intersection(first, second).map(|(point, _, _)| Intersection {
                        point,
                        segment_a: i,
                        segment_b: j,
                    })
                })
        })
        .collect()
}
/// Returns how many intersections `find_all_intersections` reports for `segments`.
///
/// # Errors
///
/// Propagates any error returned by `find_all_intersections`.
pub fn count_intersections(segments: &[Segment2D]) -> SpatialResult<usize> {
    find_all_intersections(segments).map(|found| found.len())
}
/// Reports whether at least one pair of `segments` intersects.
///
/// Inputs of up to 10 segments are checked pair by pair, stopping at the
/// first hit. Larger inputs are handed to `find_all_intersections` and the
/// result is tested for emptiness.
///
/// # Errors
///
/// Propagates any error returned by `find_all_intersections`.
pub fn has_any_intersection(segments: &[Segment2D]) -> SpatialResult<bool> {
    if segments.len() > 10 {
        return find_all_intersections(segments).map(|found| !found.is_empty());
    }

    let any_hit = segments.iter().enumerate().any(|(i, first)| {
        segments[i + 1..]
            .iter()
            .any(|second| segment_intersection(first, second).is_some())
    });
    Ok(any_hit)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the segment-intersection routines of the parent module.

    use super::*;

    /// Two diagonals of a square cross exactly once, at the centre.
    #[test]
    fn test_two_crossing_segments() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 2.0, 2.0),
            Segment2D::new(0.0, 2.0, 2.0, 0.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 1);
        assert!((intersections[0].point.x - 1.0).abs() < 1e-6);
        assert!((intersections[0].point.y - 1.0).abs() < 1e-6);
    }

    /// Two horizontal segments at different heights never meet.
    #[test]
    fn test_no_intersections() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 1.0, 0.0),
            Segment2D::new(0.0, 1.0, 1.0, 1.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 0);
    }

    /// Three segments through a common point yield one hit per pair.
    #[test]
    fn test_multiple_intersections() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 4.0, 4.0),
            Segment2D::new(0.0, 4.0, 4.0, 0.0),
            Segment2D::new(0.0, 2.0, 4.0, 2.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 3);
    }

    /// Parallel segments are never reported as intersecting.
    #[test]
    fn test_parallel_segments() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 2.0, 0.0),
            Segment2D::new(0.0, 1.0, 2.0, 1.0),
            Segment2D::new(0.0, 2.0, 2.0, 2.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 0);
    }

    /// Segments that only share an endpoint still count as intersecting.
    #[test]
    fn test_endpoint_intersection() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 1.0, 1.0),
            Segment2D::new(1.0, 1.0, 2.0, 0.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 1);
        assert!((intersections[0].point.x - 1.0).abs() < 1e-6);
        assert!((intersections[0].point.y - 1.0).abs() < 1e-6);
    }

    /// The sweep and the brute-force search agree on the number of hits.
    #[test]
    fn test_brute_force_matches_sweep() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 3.0, 3.0),
            Segment2D::new(0.0, 3.0, 3.0, 0.0),
            Segment2D::new(1.0, 0.0, 1.0, 4.0),
        ];
        let sweep_result = find_all_intersections(&segments).expect("Operation failed");
        let brute_result = find_all_intersections_brute_force(&segments);
        assert_eq!(sweep_result.len(), brute_result.len());
    }

    /// An empty input produces an empty result.
    #[test]
    fn test_empty_input() {
        let segments: Vec<Segment2D> = vec![];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 0);
    }

    /// A single segment has nothing to intersect with.
    #[test]
    fn test_single_segment() {
        let segments = vec![Segment2D::new(0.0, 0.0, 1.0, 1.0)];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 0);
    }

    /// `segment_intersection` returns the point and both line parameters.
    #[test]
    fn test_segment_intersection_function() {
        let seg1 = Segment2D::new(0.0, 0.0, 2.0, 2.0);
        let seg2 = Segment2D::new(0.0, 2.0, 2.0, 0.0);
        let result = segment_intersection(&seg1, &seg2);
        assert!(result.is_some());
        let (pt, t, u) = result.expect("Operation failed");
        assert!((pt.x - 1.0).abs() < 1e-9);
        assert!((pt.y - 1.0).abs() < 1e-9);
        assert!((t - 0.5).abs() < 1e-9);
        assert!((u - 0.5).abs() < 1e-9);
    }

    /// `has_any_intersection` is true for crossing and false for parallel input.
    #[test]
    fn test_has_any_intersection() {
        let crossing = vec![
            Segment2D::new(0.0, 0.0, 2.0, 2.0),
            Segment2D::new(0.0, 2.0, 2.0, 0.0),
        ];
        assert!(has_any_intersection(&crossing).expect("Operation failed"));
        let parallel = vec![
            Segment2D::new(0.0, 0.0, 2.0, 0.0),
            Segment2D::new(0.0, 1.0, 2.0, 1.0),
        ];
        assert!(!has_any_intersection(&parallel).expect("Operation failed"));
    }

    /// `count_intersections` matches the number of reported intersections.
    #[test]
    fn test_count_intersections() {
        let segments = vec![
            Segment2D::new(0.0, 0.0, 4.0, 4.0),
            Segment2D::new(0.0, 4.0, 4.0, 0.0),
            Segment2D::new(0.0, 2.0, 4.0, 2.0),
        ];
        let count = count_intersections(&segments).expect("Operation failed");
        assert_eq!(count, 3);
    }

    /// Four segments through one centre: sweep and brute force agree.
    #[test]
    fn test_star_pattern() {
        let segments = vec![
            Segment2D::new(0.0, 2.0, 4.0, 2.0),
            Segment2D::new(2.0, 0.0, 2.0, 4.0),
            Segment2D::new(0.0, 0.0, 4.0, 4.0),
            Segment2D::new(0.0, 4.0, 4.0, 0.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        let brute = find_all_intersections_brute_force(&segments);
        assert_eq!(intersections.len(), brute.len());
    }

    /// A vertical segment crossing a horizontal one is handled correctly.
    #[test]
    fn test_vertical_segment() {
        let segments = vec![
            Segment2D::new(1.0, 0.0, 1.0, 2.0),
            Segment2D::new(0.0, 1.0, 2.0, 1.0),
        ];
        let intersections = find_all_intersections(&segments).expect("Operation failed");
        assert_eq!(intersections.len(), 1);
        assert!((intersections[0].point.x - 1.0).abs() < 1e-6);
        assert!((intersections[0].point.y - 1.0).abs() < 1e-6);
    }
}