use hisab::Vec2;
use serde::{Deserialize, Serialize};
use crate::steer::SteerOutput;
#[cfg(feature = "logging")]
use tracing::instrument;
/// One steering contribution paired with its blending weight.
///
/// Consumed by [`blend_weighted`], which multiplies `output.velocity` by
/// `weight`, sums the products, and divides by the total weight.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
pub struct WeightedSteer {
    /// Steering output contributed by this entry.
    pub output: SteerOutput,
    /// Relative influence of `output` in the weighted average.
    pub weight: f32,
}
/// Blends several steering outputs into one by a weighted average of their velocities.
///
/// Each entry's `output.velocity` is multiplied by its `weight`. The products are
/// summed and the sum is divided by the total weight. The blended vector is then
/// rescaled so its length does not exceed `max_speed`.
///
/// Returns `SteerOutput::default()` in two cases:
/// - `inputs` is empty;
/// - the weights sum to less than `f32::EPSILON`. This includes zero and net-negative
///   totals, and avoids dividing by a (near-)zero weight.
///
/// A negative `max_speed` is treated as `0.0`. Without that guard the rescale would
/// multiply by a negative factor and reverse the blended direction.
#[cfg_attr(feature = "logging", instrument)]
#[must_use]
pub fn blend_weighted(inputs: &[WeightedSteer], max_speed: f32) -> SteerOutput {
    if inputs.is_empty() {
        return SteerOutput::default();
    }
    let mut total = Vec2::ZERO;
    let mut weight_sum = 0.0f32;
    for entry in inputs {
        total += entry.output.velocity * entry.weight;
        weight_sum += entry.weight;
    }
    if weight_sum < f32::EPSILON {
        return SteerOutput::default();
    }
    // A speed limit below zero is meaningless; clamp it so the rescale below can
    // only shorten the vector, never flip it. (NaN passes through unchanged.)
    let max_speed = if max_speed < 0.0 { 0.0 } else { max_speed };
    let blended = total / weight_sum;
    let len = blended.length();
    // The EPSILON guard avoids dividing by a (near-)zero length when normalising.
    if len > max_speed && len > f32::EPSILON {
        SteerOutput::from_vec2(blended / len * max_speed)
    } else {
        SteerOutput::from_vec2(blended)
    }
}
/// A steering contribution tagged with a priority level and a blending weight.
///
/// Consumed by [`blend_priority`], which groups entries by `priority`. Smaller
/// values are consulted first. Entries within a group are blended by `weight`.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
pub struct PrioritizedSteer {
    /// Steering output contributed by this entry.
    pub output: SteerOutput,
    /// Priority level; entries sharing a value are blended together.
    pub priority: u32,
    /// Relative influence of `output` within its priority group.
    pub weight: f32,
}
/// Picks the highest-priority group of steering outputs that produces meaningful motion.
///
/// Inputs are grouped by `priority`, and groups are consulted in ascending order
/// (smaller numbers first). Each group is combined with [`blend_weighted`] using the
/// entries' `weight`s and clamped to `max_speed`.
///
/// - The first group whose blended speed is at least `threshold` is returned.
/// - If no group reaches `threshold`, the blend of the last group is returned.
/// - Empty `inputs` yield `SteerOutput::default()`.
///
/// Entries sharing a priority keep their input order, because the sort is stable.
/// The per-group summation order therefore follows the caller's ordering.
#[cfg_attr(feature = "logging", instrument)]
#[must_use]
pub fn blend_priority(inputs: &[PrioritizedSteer], max_speed: f32, threshold: f32) -> SteerOutput {
    if inputs.is_empty() {
        return SteerOutput::default();
    }
    // Stable sort of indices: ascending priority, ties kept in input order.
    let mut order: Vec<usize> = (0..inputs.len()).collect();
    order.sort_by_key(|&i| inputs[i].priority);

    // One scratch buffer reused for every group instead of a fresh Vec per group.
    let mut group: Vec<WeightedSteer> = Vec::with_capacity(inputs.len());
    let mut start = 0;
    loop {
        let priority = inputs[order[start]].priority;
        // `order[start]` always matches `priority`, so every group has at least one
        // entry and `start` strictly advances; the loop therefore terminates.
        let end = start
            + order[start..]
                .iter()
                .take_while(|&&i| inputs[i].priority == priority)
                .count();

        group.clear();
        group.extend(order[start..end].iter().map(|&i| WeightedSteer {
            output: inputs[i].output,
            weight: inputs[i].weight,
        }));

        let result = blend_weighted(&group, max_speed);
        // The last group is returned even below `threshold`: nothing is left to fall back to.
        if result.speed() >= threshold || end == order.len() {
            return result;
        }
        start = end;
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `blend_weighted` / `blend_priority` and serde round-trips of the
    // input structs.
    use super::*;

    #[test]
    fn blend_empty() {
        let result = blend_weighted(&[], 5.0);
        assert!(result.speed() < f32::EPSILON);
    }

    #[test]
    fn blend_single() {
        let inputs = [WeightedSteer {
            output: SteerOutput::new(3.0, 4.0),
            weight: 1.0,
        }];
        let result = blend_weighted(&inputs, 10.0);
        assert!((result.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn blend_equal_weights() {
        let inputs = [
            WeightedSteer {
                output: SteerOutput::new(4.0, 0.0),
                weight: 1.0,
            },
            WeightedSteer {
                output: SteerOutput::new(0.0, 4.0),
                weight: 1.0,
            },
        ];
        let result = blend_weighted(&inputs, 10.0);
        assert!((result.velocity.x - 2.0).abs() < 0.01);
        assert!((result.velocity.y - 2.0).abs() < 0.01);
    }

    #[test]
    fn blend_unequal_weights() {
        let inputs = [
            WeightedSteer {
                output: SteerOutput::new(10.0, 0.0),
                weight: 3.0,
            },
            WeightedSteer {
                output: SteerOutput::new(0.0, 10.0),
                weight: 1.0,
            },
        ];
        let result = blend_weighted(&inputs, 100.0);
        assert!((result.velocity.x - 7.5).abs() < 0.01);
        assert!((result.velocity.y - 2.5).abs() < 0.01);
    }

    #[test]
    fn blend_clamps_to_max_speed() {
        let inputs = [WeightedSteer {
            output: SteerOutput::new(100.0, 0.0),
            weight: 1.0,
        }];
        let result = blend_weighted(&inputs, 5.0);
        assert!((result.speed() - 5.0).abs() < 0.01);
    }

    #[test]
    fn blend_zero_weights() {
        let inputs = [WeightedSteer {
            output: SteerOutput::new(10.0, 0.0),
            weight: 0.0,
        }];
        let result = blend_weighted(&inputs, 5.0);
        assert!(result.speed() < f32::EPSILON);
    }

    #[test]
    fn priority_empty() {
        let result = blend_priority(&[], 5.0, 0.1);
        assert!(result.speed() < f32::EPSILON);
    }

    #[test]
    fn priority_high_overrides_low() {
        let inputs = [
            PrioritizedSteer {
                output: SteerOutput::new(5.0, 0.0),
                priority: 0,
                weight: 1.0,
            },
            PrioritizedSteer {
                output: SteerOutput::new(0.0, 5.0),
                priority: 1,
                weight: 1.0,
            },
        ];
        let result = blend_priority(&inputs, 10.0, 0.1);
        assert!((result.velocity.x - 5.0).abs() < 0.01);
        assert!(result.velocity.y.abs() < 0.01);
    }

    #[test]
    fn priority_unsorted_input_order() {
        // Priority 0 must win even when it appears after priority 1 in the slice.
        let inputs = [
            PrioritizedSteer {
                output: SteerOutput::new(0.0, 5.0),
                priority: 1,
                weight: 1.0,
            },
            PrioritizedSteer {
                output: SteerOutput::new(5.0, 0.0),
                priority: 0,
                weight: 1.0,
            },
        ];
        let result = blend_priority(&inputs, 10.0, 0.1);
        assert!((result.velocity.x - 5.0).abs() < 0.01);
        assert!(result.velocity.y.abs() < 0.01);
    }

    #[test]
    fn priority_falls_through() {
        let inputs = [
            PrioritizedSteer {
                output: SteerOutput::default(),
                priority: 0,
                weight: 1.0,
            },
            PrioritizedSteer {
                output: SteerOutput::new(0.0, 5.0),
                priority: 1,
                weight: 1.0,
            },
        ];
        let result = blend_priority(&inputs, 10.0, 0.1);
        assert!((result.velocity.y - 5.0).abs() < 0.01);
    }

    #[test]
    fn priority_all_below_threshold_returns_last_group() {
        // No group reaches the threshold, so the last (lowest-priority) group's
        // blend is returned rather than a zero output.
        let inputs = [
            PrioritizedSteer {
                output: SteerOutput::default(),
                priority: 0,
                weight: 1.0,
            },
            PrioritizedSteer {
                output: SteerOutput::new(0.05, 0.0),
                priority: 1,
                weight: 1.0,
            },
        ];
        let result = blend_priority(&inputs, 10.0, 0.1);
        assert!((result.velocity.x - 0.05).abs() < 0.001);
        assert!(result.velocity.y.abs() < 0.001);
    }

    #[test]
    fn priority_same_level_blended() {
        let inputs = [
            PrioritizedSteer {
                output: SteerOutput::new(4.0, 0.0),
                priority: 0,
                weight: 1.0,
            },
            PrioritizedSteer {
                output: SteerOutput::new(0.0, 4.0),
                priority: 0,
                weight: 1.0,
            },
        ];
        let result = blend_priority(&inputs, 10.0, 0.1);
        assert!((result.velocity.x - 2.0).abs() < 0.01);
        assert!((result.velocity.y - 2.0).abs() < 0.01);
    }

    #[test]
    fn blend_weighted_serde_roundtrip() {
        let ws = WeightedSteer {
            output: SteerOutput::new(1.0, 2.0),
            weight: 0.5,
        };
        let json = serde_json::to_string(&ws).unwrap();
        let deserialized: WeightedSteer = serde_json::from_str(&json).unwrap();
        assert!((deserialized.weight - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn blend_priority_serde_roundtrip() {
        let ps = PrioritizedSteer {
            output: SteerOutput::new(1.0, 2.0),
            priority: 3,
            weight: 0.5,
        };
        let json = serde_json::to_string(&ps).unwrap();
        let deserialized: PrioritizedSteer = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.priority, 3);
    }

    #[test]
    fn blend_opposing_forces_cancel() {
        let inputs = [
            WeightedSteer {
                output: SteerOutput::new(5.0, 0.0),
                weight: 1.0,
            },
            WeightedSteer {
                output: SteerOutput::new(-5.0, 0.0),
                weight: 1.0,
            },
        ];
        let result = blend_weighted(&inputs, 10.0);
        assert!(result.speed() < f32::EPSILON);
    }
}