1use bevy::ecs::query::QueryData;
2use enum_map::{Enum, EnumMap};
3use itertools::Itertools;
4use std::f32::consts::PI;
5
6use avian3d::prelude::*;
7use bevy::prelude::*;
8use derivative::Derivative;
9
10use crate::SMALL_THRESHOLD;
11
/// Number of evenly spaced directions ("slots") in every interest and danger map.
/// Slot `i` points at angle `i * 2π / NUM_SLOTS` on the XZ plane.
const NUM_SLOTS: usize = 16;
13
/// Identifies a steering behavior whose per-agent output is stored in [`SteeringOutputs`].
///
/// Derives `enum_map::Enum` so it can key an `EnumMap`; the variant declaration order
/// therefore fixes the map's iteration order, so do not reorder variants casually.
#[derive(Debug, Copy, Clone, Enum, Hash, PartialEq, Eq)]
pub enum BehaviorType {
    Alignment,
    Approach,
    Avoid,
    Cohere,
    Evasion,
    Flee,
    PathFollowing,
    Pursuit,
    Seek,
    Separation,
    Wander,
}
29
/// Global resource holding the blend factor used by `temporal_smoothing`.
///
/// The wrapped value is passed as the blend when the current [`SteeringOutputs`] are
/// interpolated towards the previous frame's outputs, so larger values weight the
/// previous frame more heavily. The `Default` value is 0.0. When the resource is not
/// present at all, `temporal_smoothing` does nothing.
/// Prefer [`TemporalSmoothing::new`], which keeps the value within `[0.0, 1.0]`.
#[derive(Resource, Debug, Default, Copy, Clone)]
pub struct TemporalSmoothing(f32);
37
38impl TemporalSmoothing {
39 pub fn new(blend: f32) -> Self {
44 Self(blend.clamp(0.0, 1.0))
45 }
46}
47
/// Direction an agent is currently heading.
///
/// Written by `update_forward_dir`: the normalized linear velocity when the agent is
/// moving, otherwise the forward vector of its `GlobalTransform`.
#[derive(Component, Debug, Copy, Clone, Deref, DerefMut)]
pub struct ForwardDir(Dir3);
54
55pub(crate) fn update_forward_dir(
56 mut commands: Commands,
57 query: Query<(Entity, &LinearVelocity, &GlobalTransform)>,
58) {
59 for (entity, velocity, transform) in query.iter() {
60 let forward = if velocity.0.length_squared() > SMALL_THRESHOLD {
61 Dir3::new_unchecked(velocity.normalize())
62 } else {
63 transform.forward()
64 };
65 commands.entity(entity).insert(ForwardDir(forward));
66 }
67}
68
/// Snapshot of an agent's [`SteeringOutputs`] from the previous update.
///
/// Written by `update_previous_steering_outputs` and read by `temporal_smoothing` as the
/// value the current outputs are blended towards.
#[derive(Component, Default, Debug, Copy, Clone, Deref, DerefMut)]
pub(crate) struct PreviousSteeringOutputs(SteeringOutputs);
71
/// Per-behavior steering targets produced for one agent.
///
/// Each [`BehaviorType`] maps to `Some(target)` while that behavior has an output and to
/// `None` otherwise. The `require` attribute makes Bevy add a default
/// [`PreviousSteeringOutputs`] alongside this component, which temporal smoothing needs.
#[derive(Component, Default, Debug, Copy, Clone)]
#[require(PreviousSteeringOutputs)]
pub struct SteeringOutputs {
    // One optional target per behavior; `None` means the behavior is inactive.
    values: EnumMap<BehaviorType, Option<SteeringTarget>>,
}
81
82impl SteeringOutputs {
83 #[allow(dead_code)]
84 pub(crate) fn get(&self, behavior: BehaviorType) -> Option<SteeringTarget> {
85 self.values[behavior]
86 }
87
88 pub(crate) fn set(&mut self, behavior: BehaviorType, target: SteeringTarget) {
89 self.values[behavior] = Some(target);
90 }
91
92 pub(crate) fn clear(&mut self, behavior: BehaviorType) {
93 self.values[behavior] = None;
94 }
95
96 fn only_some(&self) -> Vec<(BehaviorType, SteeringTarget)> {
99 self.values
100 .iter()
101 .filter_map(|(behavior, target)| target.map(|t| (behavior, t)))
102 .collect()
103 }
104
105 fn has_some(&self) -> bool {
107 self.values.iter().any(|(_, target)| target.is_some())
108 }
109
110 fn lerp(&self, other: &SteeringOutputs, blend: f32) -> SteeringOutputs {
111 let mut result = SteeringOutputs::default();
112 let keys = self
113 .values
114 .iter()
115 .zip(other.values.iter())
116 .flat_map(|((a, _), (b, _))| [a, b])
117 .unique();
118 for behavior in keys {
119 let a = self.get(behavior).unwrap_or_default();
120 let b = other.get(behavior).unwrap_or_default();
121 result.set(behavior, a.lerp(&b, blend));
122 }
123 result
124 }
125}
126
/// Interest and danger maps produced by one steering behavior.
///
/// Both maps hold one value per slot. Slot `i` corresponds to the direction returned by
/// `SteeringTarget::slot_to_dir(i)`.
#[derive(Debug, Copy, Clone, PartialEq, Reflect, Derivative)]
#[derivative(Default)]
pub struct SteeringTarget {
    /// How desirable moving in each slot direction is. Values written by `set_interest`
    /// are non-negative.
    pub(crate) interest_map: [f32; NUM_SLOTS],
    /// How dangerous moving in each slot direction is. Values written by `set_danger`
    /// lie in `[0.0, 1.0]`.
    pub(crate) danger_map: [f32; NUM_SLOTS],
}
143
144impl SteeringTarget {
145 const ZERO: Self = Self {
146 interest_map: [0.0; NUM_SLOTS],
147 danger_map: [0.0; NUM_SLOTS],
148 };
149
150 pub(crate) fn slot_to_dir(slot: usize) -> Vec3 {
151 const ANGLE_PER_SLOT: f32 = 2.0 * PI / NUM_SLOTS as f32;
152 let angle = slot as f32 * ANGLE_PER_SLOT;
153 Vec3::new(angle.cos(), 0.0, angle.sin())
154 }
155
156 fn interpolate_peak(values: &[f32], peak_slot: usize) -> Vec3 {
160 let mut sum_dir = Vec3::ZERO;
161 let mut sum_weight = 0.0;
162
163 for offset in -2i32..=2 {
165 let n = NUM_SLOTS as i32;
166 let slot = ((peak_slot as i32 + offset + n) % n) as usize;
167 let weight = values[slot];
168 sum_dir += weight * Self::slot_to_dir(slot);
169 sum_weight += weight;
170 }
171
172 if sum_weight > f32::EPSILON {
173 sum_dir.normalize_or_zero()
174 } else {
175 Self::slot_to_dir(peak_slot)
176 }
177 }
178
179 pub fn set_interest(&mut self, direction: Vec3) {
185 for i in 0..NUM_SLOTS {
186 let slot_dir = SteeringTarget::slot_to_dir(i);
187 self.interest_map[i] = slot_dir.dot(direction).max(0.0);
188 }
189 }
190
191 pub fn set_danger(&mut self, direction: Vec3) {
197 let direction = direction.normalize_or_zero();
198 for i in 0..NUM_SLOTS {
199 let slot_dir = SteeringTarget::slot_to_dir(i);
200 self.danger_map[i] = slot_dir.dot(direction).max(0.0);
201 }
202 }
203
204 fn lerp(&self, other: &SteeringTarget, blend: f32) -> SteeringTarget {
207 let mut result = SteeringTarget::default();
208
209 for i in 0..NUM_SLOTS {
211 result.interest_map[i] = self.interest_map[i].lerp(other.interest_map[i], blend);
212 result.danger_map[i] = self.danger_map[i].lerp(other.danger_map[i], blend);
213 }
214
215 result
216 }
217}
218
/// Slot-wise merge of all active [`SteeringTarget`]s of an agent.
///
/// Inserted by `combine_steering_targets` and converted into a heading with
/// `CombinedSteeringTarget::into_heading`.
#[derive(Component, Debug, Copy, Clone)]
pub(crate) struct CombinedSteeringTarget(SteeringTarget);
221
222impl CombinedSteeringTarget {
223 const ZERO: Self = Self(SteeringTarget::ZERO);
224
225 pub fn new(targets: impl Iterator<Item = SteeringTarget>) -> Self {
226 let mut final_target = SteeringTarget::default();
227 for target in targets {
228 for i in 0..NUM_SLOTS {
229 let target_interest = target.interest_map[i];
230 let target_danger = target.danger_map[i];
231 final_target.interest_map[i] = target_interest.max(final_target.interest_map[i]);
232 final_target.danger_map[i] = target_danger.max(final_target.danger_map[i]);
233 }
234 }
235 CombinedSteeringTarget(final_target)
236 }
237
238 pub fn into_heading(self, danger_sensitivity: f32) -> Vec3 {
245 let inner = self.0;
246 let min_danger = inner
247 .danger_map
248 .iter()
249 .filter(|x| **x > 0.0)
250 .fold(f32::MAX, |a, &b| a.min(b));
251 let danger_threshold = min_danger + danger_sensitivity;
252 let mask = inner
253 .danger_map
254 .into_iter()
255 .map(|x| if x <= danger_threshold { 1.0 } else { 0.0 })
256 .collect::<Vec<_>>();
257 let masked_interest = inner
258 .interest_map
259 .into_iter()
260 .zip(mask.iter())
261 .map(|(x, y)| x * y)
262 .collect::<Vec<_>>();
263 let max_interest = masked_interest.iter().fold(f32::MIN, |a, &b| a.max(b));
264 if max_interest <= f32::EPSILON {
265 return Vec3::ZERO;
267 }
268 let target_slot = masked_interest
269 .iter()
270 .position(|x| *x >= max_interest)
271 .unwrap_or(0);
272 SteeringTarget::interpolate_peak(&masked_interest, target_slot)
273 }
274}
275
/// Read-only query data used by `combine_steering_targets`.
#[derive(QueryData)]
pub(crate) struct CombinedSteeringTargetQuery {
    // The agent that receives the `CombinedSteeringTarget`.
    agent: Entity,
    // The per-behavior outputs that are merged.
    steering_outputs: &'static SteeringOutputs,
}
281
282pub(crate) fn combine_steering_targets(
283 mut commands: Commands,
284 query: Query<CombinedSteeringTargetQuery>,
285) {
286 for query_item in query.iter() {
287 let steering_outputs = query_item.steering_outputs;
288 let mut entity = commands.entity(query_item.agent);
289 if steering_outputs.has_some() {
290 let targets = steering_outputs
291 .only_some()
292 .into_iter()
293 .map(|(_, target)| target);
294 entity.insert(CombinedSteeringTarget::new(targets));
295 } else {
296 entity.insert(CombinedSteeringTarget::ZERO);
299 }
300 }
301}
302
/// Mutable query data used by `temporal_smoothing`.
#[derive(QueryData)]
#[query_data(mutable)]
pub(crate) struct TemporalSmoothingQuery {
    // Last update's outputs: the blend target.
    previous_outputs: &'static PreviousSteeringOutputs,
    // Current outputs, overwritten with the blended result.
    outputs: &'static mut SteeringOutputs,
}
309
310pub(crate) fn temporal_smoothing(
311 mut query: Query<TemporalSmoothingQuery>,
312 res_smoothing: Option<Res<TemporalSmoothing>>,
313) {
314 let Some(factor) = res_smoothing else {
315 return;
316 };
317
318 for query_item in query.iter_mut() {
319 let previous_outputs = query_item.previous_outputs;
320 let mut outputs = query_item.outputs;
321 *outputs = outputs.lerp(&previous_outputs.0, factor.0);
322 }
323}
324
325pub(crate) fn update_previous_steering_outputs(
326 mut query: Query<(&SteeringOutputs, &mut PreviousSteeringOutputs)>,
327) {
328 for (outputs, mut previous) in query.iter_mut() {
329 previous.0 = *outputs;
330 }
331}
332
333pub(crate) fn debug_combined_steering(
336 mut gizmos: Gizmos,
337 query: Query<(&GlobalTransform, &CombinedSteeringTarget)>,
338) {
339 const BASE_LINE_LENGTH: f32 = 8.0;
340
341 for (transform, combined_target) in query.iter() {
342 let agent_position = transform.translation();
343 let target = &combined_target.0;
344
345 for i in 0..NUM_SLOTS {
347 let direction = SteeringTarget::slot_to_dir(i);
348
349 let interest_value = target.interest_map[i];
351 if interest_value > 0.01 {
352 let interest_length = BASE_LINE_LENGTH * interest_value;
353 let end_point = agent_position + direction * interest_length;
354 gizmos.line(agent_position, end_point, Color::srgb(0.0, 1.0, 0.0));
355 }
356
357 let danger_value = target.danger_map[i];
359 if danger_value > 0.01 {
360 let danger_length = BASE_LINE_LENGTH * danger_value;
361 let end_point = agent_position + direction * danger_length;
362 let offset = Vec3::new(0.1, 0.0, 0.1);
363 gizmos.line(
364 agent_position + offset,
365 end_point + offset,
366 Color::srgb(1.0, 0.0, 0.0),
367 );
368 }
369 }
370 }
371}
372
373pub(crate) fn debug_forward_dir(mut gizmos: Gizmos, query: Query<(&GlobalTransform, &ForwardDir)>) {
374 const ARROW_LENGTH: f32 = 3.0;
375 const ARROW_COLOR: Color = Color::srgb(0.6, 0.3, 0.0);
376
377 for (transform, forward_dir) in query.iter() {
378 let start = transform.translation();
379 let end = start + forward_dir.as_vec3() * ARROW_LENGTH;
380 gizmos.arrow(start, end, ARROW_COLOR);
381 }
382}
383
#[cfg(test)]
mod tests {
    use super::*;

    /// The four quarter-turn slots must land on the axis-aligned unit vectors of the
    /// XZ plane.
    #[test]
    fn test_slot_to_dir() {
        let quarter_turn_slots = [
            (0, Vec3::X),
            (4, Vec3::Z),
            (8, Vec3::NEG_X),
            (12, Vec3::NEG_Z),
        ];
        for (slot, expected) in quarter_turn_slots {
            let actual = SteeringTarget::slot_to_dir(slot);
            assert!(
                actual.abs_diff_eq(expected, 0.0001),
                "Failed for input: {}",
                slot
            );
        }
    }

    /// With danger on every side, the smallest danger is 0.3 (slots 2 and 6). At a
    /// sensitivity of 0.05 only those two slots stay unmasked, and slot 6 wins because
    /// its interest (0.5) beats slot 2's (0.4).
    #[test]
    fn test_into_heading_surrounded_agent() {
        let target = SteeringTarget {
            interest_map: [
                0.9, 0.7, 0.4, 0.2, 0.1, 0.2, 0.5, 0.8, 0.3, 0.2, 0.1, 0.1, 0.2, 0.1, 0.2, 0.3,
            ],
            danger_map: [
                0.8, 0.6, 0.3, 0.4, 0.5, 0.4, 0.3, 0.7, 0.9, 0.8, 0.6, 0.7, 0.8, 0.7, 0.6, 0.9,
            ],
        };

        let heading = CombinedSteeringTarget(target).into_heading(0.05);

        // Computed independently of `slot_to_dir`: slot 6 of 16 around the XZ circle.
        let slot_six_angle = 6.0 * (2.0 * PI / 16.0);
        let expected_direction = Vec3::new(slot_six_angle.cos(), 0.0, slot_six_angle.sin());

        assert!(
            heading.normalize().dot(expected_direction) > 0.9,
            "Expected heading to be close to slot 6 direction. Got: {:?}, expected direction: {:?}",
            heading,
            expected_direction
        );
    }

    /// A single interesting, danger-free slot (slot 0) must produce a +X heading.
    #[test]
    fn test_into_heading_clear_path() {
        let mut target = SteeringTarget::default();
        target.interest_map[0] = 1.0;
        target.danger_map[0] = 0.0;

        let heading = CombinedSteeringTarget(target).into_heading(0.05);

        assert!(
            heading.normalize().dot(Vec3::X) > 0.95,
            "Expected heading to point in +X direction. Got: {:?}",
            heading
        );
    }
}