use std::sync::{Arc, RwLock};
use crate::coord::CoordSystem;
use crate::store::ZoneStore;
use crate::zone::{ZoneEntry, ZoneShape};
use crate::octree::OctreeNode;
/// Snapshot of the vehicle's kinematic state.
///
/// NOTE(review): coordinates are consumed as local ENU by the zone shapes in this
/// module; units (metres, m/s) are assumed — confirm against the producer.
#[derive(Debug, Clone)]
pub struct VehiclePose {
    /// Position `[x, y, z]`.
    pub position: [f64; 3],
    /// Velocity `[vx, vy, vz]`.
    pub velocity: [f64; 3],
    /// Heading in radians; the sensor zones in this module rotate world offsets by
    /// this angle, with `0` pointing along +x.
    pub heading_rad: f64,
    /// Timestamp in nanoseconds.
    pub ts_ns: u64,
}

impl VehiclePose {
    /// Pose at `pos` with zero velocity and zero heading.
    pub fn stationary(pos: [f64; 3], ts_ns: u64) -> Self {
        Self { position: pos, velocity: [0.0; 3], heading_rad: 0.0, ts_ns }
    }
}

/// Minimal constant-acceleration filter fusing IMU acceleration with RTK fixes.
///
/// The state is `[x, y, z, vx, vy, vz]`. Only the diagonal of `cov` is ever read or
/// written, so the filter behaves as independent scalar filters: velocity
/// uncertainty is not propagated into position, and a position fix does not
/// correct velocity.
pub struct ImuFusion {
    state: [f64; 6],
    cov: [[f64; 6]; 6],
    /// Timestamp of the last accepted `predict` sample; `None` until the first one,
    /// so a first sample at `ts_ns == 0` is handled like any other.
    last_ts: Option<u64>,
}

// Index loops are deliberate: each iteration addresses `state[i]`, `state[i + 3]`
// and `accel_enu[i]` together, which is clearer than nested zips.
#[allow(clippy::needless_range_loop)]
impl ImuFusion {
    /// Creates a filter at `initial_pos` with zero velocity and identity covariance.
    pub fn new(initial_pos: [f64; 3]) -> Self {
        let mut state = [0.0; 6];
        state[..3].copy_from_slice(&initial_pos);
        let mut cov = [[0.0; 6]; 6];
        for i in 0..6 {
            cov[i][i] = 1.0;
        }
        Self { state, cov, last_ts: None }
    }

    /// Propagates the state to `ts_ns` assuming constant acceleration `accel_enu`.
    ///
    /// The first call only records the timestamp. A sample older than the last
    /// accepted one is ignored instead of underflowing the `u64` subtraction (which
    /// would panic in debug builds and yield an enormous `dt` in release builds).
    pub fn predict(&mut self, accel_enu: [f64; 3], ts_ns: u64) {
        let last = match self.last_ts {
            Some(last) => last,
            None => {
                self.last_ts = Some(ts_ns);
                return;
            }
        };
        let elapsed_ns = match ts_ns.checked_sub(last) {
            Some(elapsed) => elapsed,
            None => return,
        };
        self.last_ts = Some(ts_ns);
        let dt = elapsed_ns as f64 / 1e9;
        for i in 0..3 {
            self.state[i] += self.state[i + 3] * dt + 0.5 * accel_enu[i] * dt * dt;
            self.state[i + 3] += accel_enu[i] * dt;
        }
        // Process noise: every diagonal variance grows linearly with dt.
        let q = 0.1 * dt;
        for i in 0..6 {
            self.cov[i][i] += q;
        }
    }

    /// Fuses an RTK position fix `pos`; `h_acc` is treated as a standard deviation
    /// (its square is the measurement variance).
    ///
    /// NOTE(review): the same `h_acc` is applied to all three axes, including z —
    /// confirm that vertical accuracy should not be supplied separately.
    pub fn update_rtk(&mut self, pos: [f64; 3], h_acc: f64) {
        let r = h_acc * h_acc;
        for i in 0..3 {
            let innovation_var = self.cov[i][i] + r;
            // A zero (or NaN) innovation variance would make the gain 0/0 = NaN and
            // poison the state for good; such a fix carries no usable weighting.
            if innovation_var.is_nan() || innovation_var <= 0.0 {
                continue;
            }
            let k = self.cov[i][i] / innovation_var;
            let innovation = pos[i] - self.state[i];
            self.state[i] += k * innovation;
            self.cov[i][i] *= 1.0 - k;
        }
    }

    /// Current position estimate `[x, y, z]`.
    pub fn position(&self) -> [f64; 3] {
        [self.state[0], self.state[1], self.state[2]]
    }

    /// Current velocity estimate `[vx, vy, vz]`.
    pub fn velocity(&self) -> [f64; 3] {
        [self.state[3], self.state[4], self.state[5]]
    }

    /// Packages the estimate as a [`VehiclePose`]; heading and timestamp are supplied
    /// by the caller because the filter does not estimate them.
    pub fn pose(&self, heading_rad: f64, ts_ns: u64) -> VehiclePose {
        VehiclePose {
            position: self.position(),
            velocity: self.velocity(),
            heading_rad,
            ts_ns,
        }
    }

    /// Square root of the trace of the position variances.
    pub fn uncertainty_m(&self) -> f64 {
        (self.cov[0][0] + self.cov[1][1] + self.cov[2][2]).sqrt()
    }
}
/// Annular, vertically bounded LiDAR coverage that follows the vehicle pose.
///
/// Range limits apply to the horizontal distance from the pose; `z_min`/`z_max`
/// bound the height offset relative to the pose.
pub struct LidarFovZone {
    /// Shared pose the zone follows; read on every query.
    pub pose: Arc<RwLock<VehiclePose>>,
    /// Inner horizontal radius (dead zone).
    pub min_range: f64,
    /// Outer horizontal radius.
    pub max_range: f64,
    /// Lowest accepted height offset relative to the pose.
    pub z_min: f64,
    /// Highest accepted height offset relative to the pose.
    pub z_max: f64,
}

impl ZoneShape for LidarFovZone {
    /// Inside when the horizontal distance lies in `[min_range, max_range]` and the
    /// height offset lies in `[z_min, z_max]` (both bounds inclusive).
    fn contains_enu(&self, p: [f64; 3]) -> bool {
        let origin = self.pose.read().unwrap().position;
        let (east, north, up) = (p[0] - origin[0], p[1] - origin[1], p[2] - origin[2]);
        let horizontal_sq = east * east + north * north;
        let ring = self.min_range * self.min_range..=self.max_range * self.max_range;
        ring.contains(&horizontal_sq) && (self.z_min..=self.z_max).contains(&up)
    }

    /// Square footprint of side `2 * max_range` spanning the accepted height band.
    fn aabb_enu(&self) -> [f64; 6] {
        let [x, y, z] = self.pose.read().unwrap().position;
        let reach = self.max_range;
        [x - reach, y - reach, z + self.z_min, x + reach, y + reach, z + self.z_max]
    }
}
/// Level, pinhole-style camera frustum attached to the vehicle pose.
///
/// World offsets are rotated into a body frame (forward along `heading_rad`, left
/// perpendicular to it, up = +z). A point is inside when its forward distance lies in
/// `[near_m, far_m]` and its horizontal/vertical bearing angles are within half of
/// `hfov_rad`/`vfov_rad`. The pose carries no pitch or roll.
///
/// NOTE(review): assumes `0 < near_m <= far_m`; a non-positive `near_m` lets the
/// bearing test divide by a zero or negative forward distance.
pub struct CameraFrustumZone {
    /// Shared pose the frustum follows; read on every query.
    pub pose: Arc<RwLock<VehiclePose>>,
    /// Near clipping distance along the heading.
    pub near_m: f64,
    /// Far clipping distance along the heading.
    pub far_m: f64,
    /// Full horizontal field of view in radians.
    pub hfov_rad: f64,
    /// Full vertical field of view in radians.
    pub vfov_rad: f64,
}

impl ZoneShape for CameraFrustumZone {
    fn contains_enu(&self, p: [f64; 3]) -> bool {
        let pose = self.pose.read().unwrap();
        let dx = p[0] - pose.position[0];
        let dy = p[1] - pose.position[1];
        let dz = p[2] - pose.position[2];
        // Rotate the world offset into the body frame.
        let (sin_h, cos_h) = pose.heading_rad.sin_cos();
        let body_fwd = dx * cos_h + dy * sin_h;
        let body_left = -dx * sin_h + dy * cos_h;
        let body_up = dz;
        if body_fwd < self.near_m || body_fwd > self.far_m {
            return false;
        }
        let h_angle = (body_left / body_fwd).atan().abs();
        let v_angle = (body_up / body_fwd).atan().abs();
        h_angle <= self.hfov_rad / 2.0 && v_angle <= self.vfov_rad / 2.0
    }

    /// Axis-aligned box enclosing the frustum for any heading.
    ///
    /// At forward distance `f` the lateral offset reaches `f * tan(hfov/2)`, so the
    /// horizontal reach is `far_m / cos(hfov/2)`. That exceeds `far_m` for any
    /// non-zero FOV (about 1.41x at 90°). The vertical reach is `far_m * tan(vfov/2)`.
    /// A half-angle of 90° or more admits unbounded offsets, so that axis becomes
    /// infinite.
    fn aabb_enu(&self) -> [f64; 6] {
        let p = self.pose.read().unwrap().position;
        let half_h = (self.hfov_rad / 2.0).abs();
        let half_v = (self.vfov_rad / 2.0).abs();
        let reach_xy = if half_h < std::f64::consts::FRAC_PI_2 {
            self.far_m / half_h.cos()
        } else {
            f64::INFINITY
        };
        let reach_z = if half_v < std::f64::consts::FRAC_PI_2 {
            self.far_m * half_v.tan()
        } else {
            f64::INFINITY
        };
        [
            p[0] - reach_xy,
            p[1] - reach_xy,
            p[2] - reach_z,
            p[0] + reach_xy,
            p[1] + reach_xy,
            p[2] + reach_z,
        ]
    }
}
/// Horizontal radar sector centred on the vehicle heading, within a fixed vertical band.
pub struct RadarSectorZone {
    /// Shared pose the sector follows; read on every query.
    pub pose: Arc<RwLock<VehiclePose>>,
    /// Maximum horizontal distance from the pose.
    pub range_m: f64,
    /// Half of the sector's opening angle, in radians, either side of the heading.
    pub half_angle: f64,
}

/// Half-height of the vertical band a radar sector covers around the pose.
///
/// Shared by `contains_enu` and `aabb_enu` so the exact test can never accept a
/// point outside the bounding box.
const RADAR_HALF_HEIGHT_M: f64 = 2.0;

impl ZoneShape for RadarSectorZone {
    /// Inside when the point is within [`RADAR_HALF_HEIGHT_M`] vertically, within
    /// `range_m` horizontally, and its bearing is within `half_angle` of the heading.
    fn contains_enu(&self, p: [f64; 3]) -> bool {
        let pose = self.pose.read().unwrap();
        let dx = p[0] - pose.position[0];
        let dy = p[1] - pose.position[1];
        let dz = p[2] - pose.position[2];
        // The bounding box has always been limited to this band; enforce it here too
        // so callers that skip the box pre-filter see the same zone.
        if dz.abs() > RADAR_HALF_HEIGHT_M {
            return false;
        }
        let dist = (dx * dx + dy * dy).sqrt();
        if dist > self.range_m {
            return false;
        }
        // Bearing relative to the heading, wrapped into [0, 2π); accept it if it is
        // within `half_angle` on either side.
        let tau = std::f64::consts::TAU;
        let norm = (dy.atan2(dx) - pose.heading_rad).rem_euclid(tau);
        norm <= self.half_angle || (tau - norm) <= self.half_angle
    }

    /// Square footprint of side `2 * range_m` spanning the vertical band.
    fn aabb_enu(&self) -> [f64; 6] {
        let pose = self.pose.read().unwrap();
        let r = self.range_m;
        let h = RADAR_HALF_HEIGHT_M;
        let p = pose.position;
        [p[0] - r, p[1] - r, p[2] - h, p[0] + r, p[1] + r, p[2] + h]
    }
}
/// Speed-dependent keep-out region around the vehicle.
///
/// In the body frame (forward along `heading_rad`, lateral perpendicular to it,
/// vertical = z) the region is an ellipsoid split at the vehicle. Ahead, the forward
/// semi-axis is `base_front_m + speed * speed_factor`; behind, it is `base_rear_m`.
/// The lateral semi-axis is `width_m` and the vertical semi-axis is `width_m / 2`.
///
/// NOTE(review): `width_m` is used directly as the lateral semi-axis (it is not
/// halved) — confirm whether callers pass a half-width or a full vehicle width.
pub struct SafetyEnvelope {
    /// Shared pose the envelope follows; read on every query.
    pub pose: Arc<RwLock<VehiclePose>>,
    /// Lateral semi-axis; half of it is the vertical semi-axis.
    pub width_m: f64,
    /// Forward semi-axis at standstill.
    pub base_front_m: f64,
    /// Rearward semi-axis (independent of speed).
    pub base_rear_m: f64,
    /// Extra forward reach per unit of speed.
    pub speed_factor: f64,
}

impl SafetyEnvelope {
    /// Returns the `(front, rear, lateral)` semi-axes for `pose`.
    ///
    /// The speed term uses the full 3-D velocity magnitude.
    fn dims(&self, pose: &VehiclePose) -> (f64, f64, f64) {
        let speed =
            (pose.velocity[0].powi(2) + pose.velocity[1].powi(2) + pose.velocity[2].powi(2))
                .sqrt();
        let front = self.base_front_m + speed * self.speed_factor;
        (front, self.base_rear_m, self.width_m)
    }
}

impl ZoneShape for SafetyEnvelope {
    /// Inside when the normalised body-frame offset has squared length ≤ 1.
    fn contains_enu(&self, p: [f64; 3]) -> bool {
        let pose = self.pose.read().unwrap();
        let (front, rear, side) = self.dims(&pose);
        let dx = p[0] - pose.position[0];
        let dy = p[1] - pose.position[1];
        let dz = p[2] - pose.position[2];
        // Rotate the world offset into the body frame.
        let fwd = pose.heading_rad.cos() * dx + pose.heading_rad.sin() * dy;
        let lat = -pose.heading_rad.sin() * dx + pose.heading_rad.cos() * dy;
        // Points ahead are scaled by the front semi-axis, points behind by the rear one.
        let norm_fwd = if fwd >= 0.0 { fwd / front } else { fwd.abs() / rear };
        let norm_lat = lat.abs() / side;
        let norm_up = dz.abs() / (side * 0.5);
        norm_fwd * norm_fwd + norm_lat * norm_lat + norm_up * norm_up <= 1.0
    }

    /// Axis-aligned bound of the envelope for any heading.
    ///
    /// Every point of a half-ellipsoid lies within its longest semi-axis of the centre.
    /// The largest of the front, rear and lateral semi-axes therefore bounds the
    /// footprint. The rear semi-axis must be part of that maximum: at low speed a long
    /// `base_rear_m` can exceed both other axes, and leaving it out clips the region
    /// behind the vehicle from the box. Vertically the box (±`width_m`) is looser than
    /// the shape (±`width_m / 2`).
    fn aabb_enu(&self) -> [f64; 6] {
        let pose = self.pose.read().unwrap();
        let (front, rear, side) = self.dims(&pose);
        let r = front.max(rear).max(side);
        let p = pose.position;
        [p[0] - r, p[1] - r, p[2] - side, p[0] + r, p[1] + r, p[2] + side]
    }
}
/// Returns `true` when any octree point near `pose` lies inside `envelope`, i.e. an intrusion.
///
/// Candidates come from the axis-aligned cube of half-size `range_m` centred on
/// `pose.position`. Each candidate is then tested exactly with
/// [`ZoneShape::contains_enu`].
///
/// NOTE(review): the candidate cube is centred on the `pose` argument, while
/// `envelope` tests against its own shared pose. `range_m` must also cover the
/// envelope's extent, or intrusions outside the cube are missed. `contains_enu`
/// takes the envelope's pose read lock once per candidate. A caller holding that lock
/// for writing on the same thread would deadlock or panic.
pub fn safety_check(
    envelope: &SafetyEnvelope,
    octree: &OctreeNode,
    pose: &VehiclePose,
    range_m: f64,
) -> bool {
    let [cx, cy, cz] = pose.position;
    let lower = [cx - range_m, cy - range_m, cz - range_m];
    let upper = [cx + range_m, cy + range_m, cz + range_m];
    let candidates = octree.range_query(lower, upper);
    for candidate in candidates.iter() {
        if envelope.contains_enu(**candidate) {
            return true;
        }
    }
    false
}
/// Gaussian risk field around a constant-velocity prediction of another road user.
///
/// At time `t` the predicted centre is `origin + velocity * t`. The risk at a point is
/// `exp(-0.5 * d²)`, where `d` is the Mahalanobis distance. Its standard deviation is
/// `sigma_fwd * t + 1` along the direction of travel and `sigma_lat` across it; the
/// cross-track part includes any vertical offset.
pub struct PredictionZone {
    /// Predicted position at `t = 0`.
    pub origin: [f64; 3],
    /// Constant velocity used to extrapolate the centre.
    pub velocity: [f64; 3],
    /// Latest prediction time in seconds; risk is zero beyond it.
    pub horizon_s: f64,
    /// Cross-track standard deviation.
    pub sigma_lat: f64,
    /// Growth of the along-track standard deviation per second (base value 1).
    pub sigma_fwd: f64,
}

impl PredictionZone {
    /// Risk at `p` at time `t_s`; `0.0` when `t_s > horizon_s`.
    ///
    /// If the speed is not above 0.01 the direction of travel defaults to +x.
    pub fn risk_at(&self, p: [f64; 3], t_s: f64) -> f64 {
        if t_s > self.horizon_s {
            return 0.0;
        }
        let center: [f64; 3] = std::array::from_fn(|i| self.origin[i] + self.velocity[i] * t_s);
        let speed = (0..3).map(|i| self.velocity[i].powi(2)).sum::<f64>().sqrt();
        // Near-zero speed gives no usable direction; fall back to +x.
        let fwd: [f64; 3] = if speed > 0.01 {
            std::array::from_fn(|i| self.velocity[i] / speed)
        } else {
            [1.0, 0.0, 0.0]
        };
        let dp: [f64; 3] = std::array::from_fn(|i| p[i] - center[i]);
        let dp_fwd: f64 = (0..3).map(|i| dp[i] * fwd[i]).sum();
        let dist_sq = (0..3).map(|i| dp[i].powi(2)).sum::<f64>();
        // Cross-track distance by Pythagoras. For points on or near the travel axis,
        // rounding can make `dist_sq - dp_fwd²` slightly negative, and `sqrt` would
        // return NaN. Clamp it at zero.
        let dp_lat = (dist_sq - dp_fwd * dp_fwd).max(0.0).sqrt();
        let sigma_fwd_t = self.sigma_fwd * t_s + 1.0;
        let exponent = (dp_fwd / sigma_fwd_t).powi(2) + (dp_lat / self.sigma_lat).powi(2);
        (-0.5 * exponent).exp()
    }

    /// Maximum of [`Self::risk_at`] over `t = 0, dt_s, 2·dt_s, …` up to and including
    /// `horizon_s`.
    ///
    /// The last sample is clamped to `horizon_s`, so the horizon itself is evaluated
    /// even when `dt_s` does not divide it. A non-positive or non-finite `dt_s`
    /// evaluates only `t = 0`, instead of iterating an effectively unbounded number of
    /// steps. A NaN risk contributes `0.0`.
    pub fn max_risk(&self, p: [f64; 3], dt_s: f64) -> f64 {
        if dt_s <= 0.0 || !dt_s.is_finite() {
            return 0.0f64.max(self.risk_at(p, 0.0));
        }
        // A negative horizon yields `steps == 0` and a single t = 0 sample (zero risk).
        let horizon = self.horizon_s.max(0.0);
        let steps = (self.horizon_s / dt_s).ceil() as usize;
        (0..=steps)
            .map(|i| self.risk_at(p, (i as f64 * dt_s).min(horizon)))
            .fold(0.0f64, f64::max)
    }
}
impl ZoneShape for PredictionZone {
    /// Inside when the peak risk, sampled every 0.1 s over the horizon, exceeds 0.05.
    fn contains_enu(&self, p: [f64; 3]) -> bool {
        self.max_risk(p, 0.1) > 0.05
    }

    /// Box enclosing every point `contains_enu` can accept.
    ///
    /// `risk > 0.05` ⇔ Mahalanobis distance `< sqrt(-2 ln 0.05) ≈ 2.45`. An accepted
    /// point is therefore within 2.45σ along the travel direction and 2.45σ across it
    /// of some predicted centre; 3σ is used for margin. The along-track σ is
    /// `sigma_fwd * t + 1`. Its magnitude peaks at an endpoint of `[0, horizon_s]` and
    /// can far exceed `sigma_lat`. The centre sweeps from `origin` to
    /// `origin + velocity * horizon_s`. Per axis `i`, an offset `a·fwd + b·perp` is
    /// bounded by `|a|·|fwd_i| + |b|`.
    fn aabb_enu(&self) -> [f64; 6] {
        const K_SIGMA: f64 = 3.0;
        let horizon = self.horizon_s.max(0.0);
        let speed = (0..3).map(|i| self.velocity[i].powi(2)).sum::<f64>().sqrt();
        // Same direction convention as `risk_at`.
        let fwd: [f64; 3] = if speed > 0.01 {
            std::array::from_fn(|i| self.velocity[i] / speed)
        } else {
            [1.0, 0.0, 0.0]
        };
        let sigma_fwd_max = (self.sigma_fwd * horizon + 1.0).abs().max(1.0);
        let sigma_lat = self.sigma_lat.abs();
        let end: [f64; 3] = std::array::from_fn(|i| self.origin[i] + self.velocity[i] * horizon);
        let reach: [f64; 3] =
            std::array::from_fn(|i| K_SIGMA * (fwd[i].abs() * sigma_fwd_max + sigma_lat));
        let lo = |i: usize| self.origin[i].min(end[i]) - reach[i];
        let hi = |i: usize| self.origin[i].max(end[i]) + reach[i];
        [lo(0), lo(1), lo(2), hi(0), hi(1), hi(2)]
    }
}
/// Traffic rule attached to a [`BehaviorZone`].
///
/// NOTE(review): the `max_mps` fields are presumably speed caps in metres per second
/// (inferred from the field name and from `BehaviorZoneStore::speed_limit_mps`, which
/// returns them unchanged) — confirm.
///
/// Variant and field order are part of the derived serde/Debug behaviour; append new
/// variants rather than reordering existing ones.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum TrafficBehavior {
/// Speed cap; considered by `BehaviorZoneStore::speed_limit_mps`.
SpeedLimit { max_mps: f32 },
/// Speed cap; considered by `BehaviorZoneStore::speed_limit_mps`.
/// NOTE(review): `active_hours` is presumably `(start_hour, end_hour)`; it is not read
/// by `speed_limit_mps`, which applies `max_mps` regardless of time.
SchoolZone { max_mps: f32, active_hours: (u8, u8) },
YieldZone,
StopZone,
NoPassing,
/// NOTE(review): `clockwise` presumably gives the circulation direction — confirm.
Roundabout { clockwise: bool },
/// Operational-domain boundary: `BehaviorZoneStore::odd_violated` reports a
/// violation when any such zone exists and the position is inside none of them.
OddBoundary,
/// Speed cap; considered by `BehaviorZoneStore::speed_limit_mps`.
ConstructionZone { max_mps: f32 },
PedestrianPriority,
}
/// A zone geometry entry paired with the rule returned when a position falls inside it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct BehaviorZone {
/// Geometry and id; the id links spatial-index hits back to this zone.
pub entry: ZoneEntry,
/// Rule that applies inside the zone.
pub behavior: TrafficBehavior,
}
/// Behaviour zones plus a spatial index over their geometry.
///
/// `zones` and `store` are expected to describe the same entries (as built by
/// [`BehaviorZoneStore::build`]); index hits are mapped back to `zones` by `entry.id`.
pub struct BehaviorZoneStore {
    /// The zones in their original order.
    pub zones: Vec<BehaviorZone>,
    /// Spatial index built from the zones' entries.
    pub store: ZoneStore,
}

impl BehaviorZoneStore {
    /// Indexes the entries of `zones` with `conv` and keeps the zones alongside.
    pub fn build(zones: Vec<BehaviorZone>, conv: &dyn CoordSystem) -> Self {
        let entries = zones
            .iter()
            .map(|zone| zone.entry.clone())
            .collect::<Vec<ZoneEntry>>();
        let store = ZoneStore::from_entries(&entries, conv);
        Self { zones, store }
    }

    /// Behaviours of every zone whose geometry contains `pos`.
    ///
    /// `pos` is converted with `conv.to_internal` before querying. Results follow the
    /// order in which `store.query_enu` reports hits. Each hit resolves to the first
    /// zone with that `entry.id`; hits with no matching zone are skipped.
    pub fn active_behaviors(&self, pos: [f64; 3], conv: &dyn CoordSystem) -> Vec<&TrafficBehavior> {
        let internal = conv.to_internal(pos);
        let hits = self.store.query_enu(internal);
        let mut behaviors = Vec::new();
        for &id in hits.iter() {
            if let Some(zone) = self.zones.iter().find(|z| z.entry.id == id) {
                behaviors.push(&zone.behavior);
            }
        }
        behaviors
    }

    /// Lowest `max_mps` among the active `SpeedLimit`, `SchoolZone` and
    /// `ConstructionZone` behaviours at `pos`.
    ///
    /// Returns `f32::MAX` when none applies. `SchoolZone::active_hours` is not consulted.
    pub fn speed_limit_mps(&self, pos: [f64; 3], conv: &dyn CoordSystem) -> f32 {
        let mut limit = f32::MAX;
        for behavior in self.active_behaviors(pos, conv) {
            let cap = match behavior {
                TrafficBehavior::SpeedLimit { max_mps }
                | TrafficBehavior::SchoolZone { max_mps, .. }
                | TrafficBehavior::ConstructionZone { max_mps } => *max_mps,
                _ => continue,
            };
            limit = limit.min(cap);
        }
        limit
    }

    /// Whether `pos` lies outside every `OddBoundary` zone.
    ///
    /// With no `OddBoundary` zone configured the check is disabled: it returns `false`
    /// without converting `pos` or querying the index.
    pub fn odd_violated(&self, pos: [f64; 3], conv: &dyn CoordSystem) -> bool {
        let odd_configured = self
            .zones
            .iter()
            .any(|zone| matches!(zone.behavior, TrafficBehavior::OddBoundary));
        if !odd_configured {
            return false;
        }
        let inside_odd = self
            .active_behaviors(pos, conv)
            .into_iter()
            .any(|behavior| matches!(behavior, TrafficBehavior::OddBoundary));
        !inside_odd
    }
}
/// Waypoint in the lane graph searched by `HdMap`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LaneNode {
/// Identifier referenced by `LaneEdge::from` / `LaneEdge::to`.
pub id: u32,
/// Position; `HdMap` compares it to query positions by Euclidean distance.
pub pos: [f64; 3],
/// NOTE(review): presumably the lane this waypoint belongs to; not read by the routing code here.
pub lane_id: u32,
/// NOTE(review): units not established here (tests use 13.9, consistent with m/s);
/// not read by the routing code here.
pub speed_limit: f32,
}
/// Directed connection between two lane nodes.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LaneEdge {
/// Source node id.
pub from: u32,
/// Destination node id.
pub to: u32,
/// Manoeuvre type; selects the cost multiplier.
pub kind: EdgeKind,
/// Base cost; routing uses `cost * kind.penalty()` as the edge weight.
pub cost: f64,
}
/// Manoeuvre type of a lane edge (see `EdgeKind::penalty` for its routing weight).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum EdgeKind {
Forward,
LaneChange,
Intersection,
UTurn,
}
impl EdgeKind {
pub fn penalty(&self) -> f64 {
match self {
Self::Forward => 1.0,
Self::LaneChange => 2.5,
Self::Intersection => 1.5,
Self::UTurn => 5.0,
}
}
}
/// Lane-level road graph: waypoints plus directed, weighted connections.
pub struct HdMap {
    pub nodes: Vec<LaneNode>,
    pub edges: Vec<LaneEdge>,
}

impl HdMap {
    /// Wraps the given nodes and edges; nothing is validated or indexed up front.
    pub fn build(nodes: Vec<LaneNode>, edges: Vec<LaneEdge>) -> Self {
        Self { nodes, edges }
    }

    /// Squared Euclidean distance between `a` and `b`.
    fn dist_sq(a: [f64; 3], b: [f64; 3]) -> f64 {
        (0..3).map(|i| (a[i] - b[i]).powi(2)).sum::<f64>()
    }

    /// Node closest to `pos`, or `None` if the map has no nodes; ties go to the first node.
    ///
    /// Distances are compared with `f64::total_cmp`, so NaN coordinates no longer
    /// panic the comparison.
    pub fn nearest_node(&self, pos: [f64; 3]) -> Option<&LaneNode> {
        self.nodes
            .iter()
            .min_by(|a, b| Self::dist_sq(a.pos, pos).total_cmp(&Self::dist_sq(b.pos, pos)))
    }

    /// A* route between the nodes nearest to `start_pos` and `goal_pos`.
    ///
    /// Edge weight is `edge.cost * edge.kind.penalty()`, and the heuristic is the
    /// straight-line distance to the goal node. Returns the nodes from start to goal
    /// inclusive, or `None` when the map is empty or the goal is unreachable.
    ///
    /// NOTE(review): the route is only guaranteed optimal if the heuristic never
    /// overestimates, i.e. every weighted edge cost is at least the straight-line
    /// length of that edge — confirm how `LaneEdge::cost` is produced.
    pub fn find_route(&self, start_pos: [f64; 3], goal_pos: [f64; 3]) -> Option<Vec<&LaneNode>> {
        use std::cmp::Ordering;
        use std::collections::{BinaryHeap, HashMap};

        /// Open-set entry; ordering is reversed on `f` so `BinaryHeap` pops the lowest.
        #[derive(Clone)]
        struct State {
            f: f64,
            g: f64,
            id: u32,
        }
        impl PartialEq for State {
            fn eq(&self, o: &Self) -> bool {
                self.f == o.f
            }
        }
        impl Eq for State {}
        impl PartialOrd for State {
            fn partial_cmp(&self, o: &Self) -> Option<Ordering> {
                Some(self.cmp(o))
            }
        }
        impl Ord for State {
            fn cmp(&self, o: &Self) -> Ordering {
                o.f.partial_cmp(&self.f).unwrap_or(Ordering::Equal)
            }
        }

        let start = self.nearest_node(start_pos)?.id;
        let goal = self.nearest_node(goal_pos)?.id;

        // Index nodes by id and edges by source once, so every lookup during the search
        // is O(1) rather than a scan of the whole map. For duplicate ids the first node
        // wins, matching a linear `find`.
        let mut node_by_id: HashMap<u32, &LaneNode> = HashMap::with_capacity(self.nodes.len());
        for node in &self.nodes {
            node_by_id.entry(node.id).or_insert(node);
        }
        let mut out_edges: HashMap<u32, Vec<&LaneEdge>> = HashMap::new();
        for edge in &self.edges {
            out_edges.entry(edge.from).or_default().push(edge);
        }

        let goal_pos = node_by_id.get(&goal)?.pos;
        // Ids with no node (dangling edge targets) get a zero heuristic.
        let h = |id: u32| -> f64 {
            node_by_id
                .get(&id)
                .map(|n| Self::dist_sq(n.pos, goal_pos).sqrt())
                .unwrap_or(0.0)
        };

        let mut open = BinaryHeap::new();
        let mut came = HashMap::<u32, u32>::new();
        let mut g_cost = HashMap::<u32, f64>::new();
        open.push(State { f: h(start), g: 0.0, id: start });
        g_cost.insert(start, 0.0);
        while let Some(State { g, id, .. }) = open.pop() {
            // Skip heap entries superseded by a cheaper path found after they were pushed;
            // expanding them again cannot improve any neighbour.
            if g > g_cost.get(&id).copied().unwrap_or(f64::MAX) {
                continue;
            }
            if id == goal {
                let mut path = vec![id];
                let mut cur = id;
                while let Some(&prev) = came.get(&cur) {
                    path.push(prev);
                    cur = prev;
                }
                path.reverse();
                return Some(path.iter().filter_map(|i| node_by_id.get(i).copied()).collect());
            }
            for edge in out_edges.get(&id).into_iter().flatten() {
                let ng = g + edge.cost * edge.kind.penalty();
                if ng < g_cost.get(&edge.to).copied().unwrap_or(f64::MAX) {
                    g_cost.insert(edge.to, ng);
                    came.insert(edge.to, id);
                    open.push(State { f: ng + h(edge.to), g: ng, id: edge.to });
                }
            }
        }
        None
    }
}
/// Vehicle-to-everything status beacon.
///
/// NOTE(review): field units are inferred from names (metres, m/s, radians,
/// milliseconds) — confirm against the wire specification. Field order is part of the
/// derived serde encoding; do not reorder.
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub struct V2xBeacon {
pub vehicle_id: u32,
pub pos: [f32; 3],
pub vel: [f32; 3],
pub heading_rad: f32,
pub speed_mps: f32,
/// NOTE(review): a `u32` millisecond counter wraps after about 49.7 days; confirm
/// whether this is a rolling/relative timestamp.
pub ts_ms: u32,
/// Optional alert carried with the beacon; `None` when there is nothing to report.
pub safety_alert: Option<SafetyAlert>,
}
/// Safety condition a vehicle can broadcast in a [`V2xBeacon`].
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub enum SafetyAlert {
EmergencyBrake,
/// NOTE(review): `dist_m` is presumably the distance to the obstacle in metres.
ObstacleAhead { dist_m: f32 },
OddViolation,
/// Names the sensor whose performance is reported as degraded.
SensorDegraded { sensor: SensorKind },
}
/// Sensor categories referenced by [`SafetyAlert::SensorDegraded`].
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)]
pub enum SensorKind {
Lidar,
Camera,
Radar,
Gnss,
Imu,
}
// Unit tests for the fusion filter, sensor zones, prediction risk, routing and
// behaviour-zone lookups.
#[cfg(test)]
mod tests {
use super::*;
use crate::zone::Zone;
// Shared pose at the origin with zero velocity, heading and timestamp.
fn still_pose() -> Arc<RwLock<VehiclePose>> {
Arc::new(RwLock::new(VehiclePose::stationary([0.0, 0.0, 0.0], 0)))
}
// One second of 1 m/s² along x should move the estimate ~0.5 m; a tight RTK fix
// (h_acc = 0.02) should then dominate the estimate.
#[test]
fn ekf_predict_and_rtk_update() {
let mut ekf = ImuFusion::new([0.0, 0.0, 0.0]);
// The first predict only records the timestamp.
ekf.predict([1.0, 0.0, 0.0], 1);
ekf.predict([1.0, 0.0, 0.0], 1_000_000_001);
let pos = ekf.position();
assert!(pos[0] > 0.4, "after 1 s accel, x should be ~0.5 m: {pos:?}");
ekf.update_rtk([0.6, 0.0, 0.0], 0.02);
let pos2 = ekf.position();
assert!((pos2[0] - 0.6).abs() < 0.1, "RTK should pull toward 0.6: {pos2:?}");
}
// Ring coverage: inside the ring, inside the dead zone, and beyond max range.
#[test]
fn lidar_fov_contains_nearby_point() {
let pose = still_pose();
let fov = LidarFovZone {
pose: pose.clone(),
min_range: 1.0,
max_range: 50.0,
z_min: -2.0,
z_max: 2.0,
};
assert!(fov.contains_enu([10.0, 0.0, 0.0]));
assert!(!fov.contains_enu([0.5, 0.0, 0.0]), "inside dead zone");
assert!(!fov.contains_enu([60.0, 0.0, 0.0]), "beyond range");
}
// Camera at 1.5 m height facing +x with a 90° x 45° field of view.
#[test]
fn camera_frustum_forward_fov() {
let pose = Arc::new(RwLock::new(VehiclePose {
position: [0.0, 0.0, 1.5],
velocity: [0.0; 3],
heading_rad: 0.0,
ts_ns: 0,
}));
let fov = CameraFrustumZone {
pose: pose.clone(),
near_m: 1.0,
far_m: 50.0,
hfov_rad: std::f64::consts::FRAC_PI_2, vfov_rad: std::f64::consts::FRAC_PI_4, };
assert!(fov.contains_enu([10.0, 0.0, 1.5]), "ahead should be inside FOV");
assert!(!fov.contains_enu([-5.0, 0.0, 1.5]), "behind should be outside");
// Bearing atan(10/5) ≈ 63° exceeds the 45° half-angle.
assert!(!fov.contains_enu([5.0, 10.0, 1.5]), "outside horizontal FOV");
assert!(!fov.contains_enu([0.5, 0.0, 1.5]), "inside dead zone");
}
// At 10 m/s with speed_factor 1.0 the front semi-axis is 3 + 10 = 13 m.
#[test]
fn safety_envelope_stretches_with_speed() {
let pose = Arc::new(RwLock::new(VehiclePose {
position: [0.0, 0.0, 0.0],
velocity: [10.0, 0.0, 0.0],
heading_rad: 0.0,
ts_ns: 0,
}));
let env = SafetyEnvelope {
pose: pose.clone(),
width_m: 2.0,
base_front_m: 3.0,
base_rear_m: 2.0,
speed_factor: 1.0,
};
assert!(env.contains_enu([10.0, 0.0, 0.0]), "10 m ahead at 10 m/s → in envelope");
assert!(!env.contains_enu([20.0, 0.0, 0.0]), "20 m ahead → too far");
}
// At t = 1 s the predicted centre is (10, 0, 0); risk there is the Gaussian peak.
#[test]
fn prediction_zone_risk_decays_with_distance() {
let pz = PredictionZone {
origin: [0.0, 0.0, 0.0],
velocity: [10.0, 0.0, 0.0],
horizon_s: 5.0,
sigma_lat: 1.0,
sigma_fwd: 1.0,
};
let risk_center = pz.risk_at([10.0, 0.0, 0.0], 1.0);
let risk_far = pz.risk_at([50.0, 0.0, 0.0], 1.0);
assert!(risk_center > risk_far, "risk at predicted center should exceed far point");
}
// Straight three-node chain 1 → 2 → 3; the route must visit all three in order.
#[test]
fn hd_map_finds_route() {
let nodes = vec![
LaneNode { id: 1, pos: [0.0, 0.0, 0.0], lane_id: 1, speed_limit: 13.9 },
LaneNode { id: 2, pos: [10.0, 0.0, 0.0], lane_id: 1, speed_limit: 13.9 },
LaneNode { id: 3, pos: [20.0, 0.0, 0.0], lane_id: 1, speed_limit: 13.9 },
];
let edges = vec![
LaneEdge { from: 1, to: 2, kind: EdgeKind::Forward, cost: 10.0 },
LaneEdge { from: 2, to: 3, kind: EdgeKind::Forward, cost: 10.0 },
];
let map = HdMap::build(nodes, edges);
let route = map.find_route([0.0, 0.0, 0.0], [20.0, 0.0, 0.0]).unwrap();
assert_eq!(route.len(), 3);
assert_eq!(route[0].id, 1);
assert_eq!(route[2].id, 3);
}
// A single speed-limit cylinder containing the query point. With no OddBoundary
// zones configured, the ODD check is disabled and never reports a violation.
#[test]
fn behavior_zone_speed_limit() {
let conv = crate::coord::EnuConverter::new(0.0, 0.0, 0.0);
let bz = BehaviorZoneStore::build(
vec![BehaviorZone {
entry: ZoneEntry::new(
1,
Zone::Cylinder { center: [0.0, 0.0], radius_m: 50.0, z_min: 0.0, z_max: 10.0 },
),
behavior: TrafficBehavior::SpeedLimit { max_mps: 8.3 },
}],
&conv,
);
let limit = bz.speed_limit_mps([0.0, 0.0, 5.0], &conv);
assert!((limit - 8.3).abs() < 0.01);
assert!(!bz.odd_violated([0.0, 0.0, 5.0], &conv));
}
}