use std::collections::HashSet;
use crate::{
cartesian::ToCartesian,
distance::Distance,
edge::{EdgeAx, FEdgeAx},
hex::{FHexAx, FHexCb, HexAx, HexCb},
FNodeAx, Hex, NodeAx,
};
/// Conversion from a value to a rounded counterpart of type `T`.
///
/// The implementations in this file use it to turn fractional (`F`-prefixed)
/// coordinate types into their integer equivalents.
pub trait Roundable<T> {
/// Returns the rounded counterpart of `self` without modifying `self`.
fn round(&self) -> T;
}
impl Roundable<HexAx> for FHexAx {
    /// Rounds a fractional axial hex coordinate to an integer axial coordinate.
    ///
    /// The value is converted to its fractional cube form, rounded there via
    /// the `FHexCb` implementation of [`Roundable`], and converted back to
    /// axial form.
    fn round(&self) -> HexAx {
        HexAx::from(FHexCb::from(self).round())
    }
}
impl Roundable<HexCb> for FHexCb {
    /// Rounds a fractional cube coordinate to an integer cube coordinate.
    ///
    /// Each component is rounded to the nearest integer independently; the
    /// component whose rounding error is largest is then recomputed from the
    /// other two so that the result satisfies `q + r + s == 0`.
    fn round(&self) -> HexCb {
        let (near_q, near_r, near_s) = (self.q.round(), self.r.round(), self.s.round());

        // Absolute rounding error introduced on each axis.
        let err_q = (near_q - self.q).abs();
        let err_r = (near_r - self.r).abs();
        let err_s = (near_s - self.s).abs();

        // Replace the worst-rounded component. `q` is only chosen when it is
        // strictly the worst; otherwise `r` is chosen when it is strictly
        // worse than `s`, and `s` in every remaining case.
        let (q, r, s) = if err_q > err_r && err_q > err_s {
            (-near_r - near_s, near_r, near_s)
        } else if err_r > err_s {
            (near_q, -near_q - near_s, near_s)
        } else {
            (near_q, near_r, -near_q - near_r)
        };

        HexCb::new(q as isize, r as isize, s as isize)
    }
}
impl Roundable<EdgeAx> for FEdgeAx {
    /// Rounds a fractional edge coordinate to an integer edge coordinate.
    ///
    /// `q` and `r` are first rounded with the same "fix the worst component"
    /// scheme used for cube coordinates, with the third component derived as
    /// `-q - r`. If either rounded component is odd, that pair is returned
    /// directly. If both are even, the pair is halved to form a `HexAx`, and
    /// the edge of that hex closest to `self` in cartesian space is returned.
    /// (Presumably even/even pairs denote hex centres on a doubled grid —
    /// confirm against `EdgeAx`.)
    ///
    /// NOTE: the candidate edges come from a `HashSet`, whose iteration order
    /// is unspecified, so when several edges are exactly equidistant the one
    /// returned may differ between runs.
    ///
    /// # Panics
    ///
    /// Panics if `HexAx::edges` returns an empty set.
    fn round(&self) -> EdgeAx {
        let q = self.q;
        let r = self.r;
        let s = -self.q - self.r;

        let mut q_rounded = q.round();
        let mut r_rounded = r.round();
        let s_rounded = s.round();

        let q_err = (q_rounded - q).abs();
        let r_err = (r_rounded - r).abs();
        let s_err = (s_rounded - s).abs();

        // Recompute whichever of q / r carries the largest rounding error.
        // When s is the worst nothing changes, because s is not part of the
        // result.
        if q_err > r_err && q_err > s_err {
            q_rounded = -r_rounded - s_rounded;
        } else if r_err > s_err {
            r_rounded = -q_rounded - s_rounded;
        }

        if (q_rounded as isize) % 2 != 0 || (r_rounded as isize) % 2 != 0 {
            return EdgeAx::new(q_rounded as isize, r_rounded as isize);
        }

        // Both components are even: fall back to the nearest edge of the hex
        // at half the rounded coordinates.
        let hex = HexAx::new((q_rounded / 2.0) as isize, (r_rounded / 2.0) as isize);
        let candidates: HashSet<EdgeAx> = hex.edges();

        // The cartesian position of `self` is the same for every candidate,
        // so compute it once instead of twice per comparison.
        let origin = self.to_cartesian();

        let mut remaining = candidates.into_iter();
        let mut nearest = remaining
            .next()
            .expect("expected at least 1 valid rounded edge coordinate while rounding FEdgeAx");
        let mut nearest_dist = origin.dist(&nearest.to_cartesian());

        for edge in remaining {
            let edge_dist = origin.dist(&edge.to_cartesian());
            // Strict comparison: on a tie the earlier candidate is kept.
            if edge_dist < nearest_dist {
                nearest = edge;
                nearest_dist = edge_dist;
            }
        }

        nearest
    }
}
impl Roundable<NodeAx> for FNodeAx {
    /// Rounds a fractional node coordinate to an integer node coordinate.
    ///
    /// Each component is split into a base (the component minus its Euclidean
    /// remainder modulo 3) and that remainder. The result is the base shifted
    /// by the same offset on both axes: `2` when the two remainders sum to
    /// more than `3.0`, otherwise `1`.
    /// (Presumably nodes sit at offsets 1 and 2 inside each 3-unit cell —
    /// confirm against `NodeAx`.)
    fn round(&self) -> NodeAx {
        let q_rem = self.q.rem_euclid(3.0);
        let r_rem = self.r.rem_euclid(3.0);

        // Base of the 3-unit cell containing the point, per axis.
        let q_base = (self.q - q_rem).round() as isize;
        let r_base = (self.r - r_rem).round() as isize;

        let offset = if q_rem + r_rem > 3.0 { 2 } else { 1 };

        NodeAx::new(q_base + offset, r_base + offset)
    }
}