1#![cfg(feature = "alloc")]
7
8use crate::motion::types::{ChainFeatures, Path, PathCache, SearchNode};
9use crate::motion::{MotionPlanner, MotionPlannerConfig};
10use crate::tonnetz::StdHyperTonnetz;
11use crate::triad::DynTriad;
12use crate::types::LPR;
13
14use alloc::collections::{BinaryHeap, VecDeque};
15use core::hash::Hash;
16use hashbrown::{HashMap, HashSet};
17use num_traits::{FromPrimitive, One, ToPrimitive, Zero};
18use rshyper::EdgeId;
19use rstmt::PitchMod;
20
21fn binary_heuristic<T>(triad: &DynTriad<T>, target: T) -> T
23where
24 T: PartialEq + One + Zero,
25{
26 match triad.contains(&target) {
29 true => return T::zero(),
30 false => T::one(),
31 }
32}
33
34impl<'a, T> MotionPlanner<'a, T>
35where
36 T: Eq + Hash,
37{
38 pub fn new(tonnetz: &'a StdHyperTonnetz<T>) -> Self {
40 let capacity = 1000; MotionPlanner {
42 cache: PathCache::new(capacity),
43 tonnetz,
44 config: MotionPlannerConfig::default(),
45 }
46 }
47 pub const fn cache(&self) -> &PathCache<T> {
49 &self.cache
50 }
51 pub const fn cache_mut(&mut self) -> &mut PathCache<T> {
53 &mut self.cache
54 }
55 pub const fn config(&self) -> &MotionPlannerConfig {
57 &self.config
58 }
59 pub const fn config_mut(&mut self) -> &mut MotionPlannerConfig {
61 &mut self.config
62 }
63 pub const fn max_depth(&self) -> usize {
65 self.config().max_depth()
66 }
67 pub const fn max_paths(&self) -> usize {
69 self.config().max_paths()
70 }
71 pub const fn tonnetz(&self) -> &StdHyperTonnetz<T> {
73 self.tonnetz
74 }
75 pub fn set_config(&mut self, config: MotionPlannerConfig) -> &mut Self {
77 self.config = config;
78 self
79 }
80 pub fn set_max_depth(&mut self, depth: usize) -> &mut Self {
82 self.config_mut().set_max_depth(depth);
83 self
84 }
85 pub fn set_max_paths(&mut self, paths: usize) -> &mut Self {
87 self.config_mut().set_max_paths(paths);
88 self
89 }
90 pub fn with_max_depth(self, depth: usize) -> Self {
92 Self {
93 config: self.config.with_max_depth(depth),
94 ..self
95 }
96 }
97 pub fn with_max_paths(self, paths: usize) -> Self {
99 Self {
100 config: self.config.with_max_paths(paths),
101 ..self
102 }
103 }
104}
105
106impl<'a, T> MotionPlanner<'a, T>
107where
108 T: Copy
109 + Eq
110 + Hash
111 + Ord
112 + FromPrimitive
113 + ToPrimitive
114 + One
115 + Zero
116 + PitchMod<Output = T>
117 + core::ops::Add<Output = T>
118 + core::ops::Sub<Output = T>
119 + core::ops::AddAssign,
120{
121 pub fn find_paths_to_pitch(&mut self, start_edge: EdgeId, target_pitch: T) -> Vec<Path<T>> {
123 let start_edge_t = <T>::from_usize(*start_edge).unwrap();
124 if let Some(paths) = self
126 .cache
127 .get(&[start_edge_t, T::zero(), T::zero()], target_pitch)
128 .cloned()
129 {
130 return paths;
131 }
132
133 let start_triad = match self.tonnetz().get_triad(&start_edge) {
135 Some(&triad) => triad,
136 None => return Vec::new(),
137 };
138
139 if start_triad.contains(&target_pitch) {
141 let features = ChainFeatures::default();
142 let path = Path::<T> {
143 transforms: Vec::new(),
144 triads: vec![start_triad],
145 edges: vec![Some(start_edge)],
146 cost: 0,
147 features,
148 };
149
150 self.cache.insert(
151 [start_edge_t, T::zero(), T::zero()],
152 target_pitch,
153 vec![path.clone()],
154 );
155 return vec![path];
156 }
157
158 let mut result_paths = Vec::new();
159 let mut open_set = BinaryHeap::new();
160 let mut visited = HashMap::<[T; 3], usize>::new(); open_set.push(SearchNode {
164 priority: 0, cost: 0,
166 triad: start_triad,
167 transforms: Vec::new(),
168 visited: vec![start_triad],
169 edges: vec![Some(start_edge)],
170 });
171
172 visited.insert(*start_triad.chord(), 0);
173
174 while let Some(node) = open_set.pop() {
175 if let Some(&prev_cost) = visited.get(node.triad.chord()) {
177 if prev_cost < node.cost {
178 continue;
179 }
180 }
181
182 if node.transforms.len() >= self.max_depth() {
184 continue;
185 }
186
187 for transform in LPR::iter() {
189 let next_triad = node.triad.transform(transform);
191 let new_cost = node.cost + 1;
192
193 if let Some(&prev_cost) = visited.get(next_triad.chord()) {
195 if prev_cost <= new_cost {
196 continue;
197 }
198 }
199
200 visited.insert(*next_triad.chord(), new_cost);
202
203 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
205 if facet.chord() == next_triad.chord() {
206 Some(id)
207 } else {
208 None
209 }
210 });
211
212 let mut new_transforms = node.transforms.clone();
214 new_transforms.push(transform);
215
216 let mut new_triads = node.visited.clone();
217 new_triads.push(next_triad);
218
219 let mut new_edge_ids = node.edges.clone();
220 new_edge_ids.push(next_edge_id);
221
222 if next_triad.contains(&target_pitch) {
224 let features = self.analyze_path_features(&new_triads);
226
227 let path = Path {
229 transforms: new_transforms.clone(),
230 triads: new_triads.clone(),
231 edges: new_edge_ids.clone(),
232 cost: new_cost,
233 features,
234 };
235
236 result_paths.push(path);
237
238 if result_paths.len() >= self.max_paths() {
240 result_paths.sort_by_key(|p| p.cost);
242
243 self.cache.insert(
245 [start_edge_t, T::zero(), T::zero()],
246 target_pitch,
247 result_paths.clone(),
248 );
249
250 return result_paths;
251 }
252 }
253
254 let h = binary_heuristic(&next_triad, target_pitch);
256 let priority = -((new_cost as i32) + h.to_i32().unwrap()); let next_node = SearchNode {
259 priority,
260 cost: new_cost,
261 triad: next_triad,
262 transforms: new_transforms,
263 visited: new_triads,
264 edges: new_edge_ids,
265 };
266
267 open_set.push(next_node);
268 }
269 }
270
271 result_paths.sort_by_key(|p| p.cost);
273
274 self.cache.insert(
276 [start_edge_t, T::zero(), T::zero()],
277 target_pitch,
278 result_paths.clone(),
279 );
280
281 result_paths
282 }
283 pub fn find_paths_between_edges(
285 &mut self,
286 start_edge: EdgeId,
287 goal_edge: EdgeId,
288 ) -> Vec<Path<T>> {
289 let start_edge_t = <T>::from_usize(*start_edge).unwrap();
290 let goal_edge_t = <T>::from_usize(*goal_edge).unwrap();
291 if let Some(paths) = self
293 .cache_mut()
294 .get(&[start_edge_t, goal_edge_t, T::one()], T::zero())
295 .cloned()
296 {
297 return paths;
298 }
299
300 let start_triad = match self.tonnetz().get_triad(&start_edge) {
302 Some(triad) => *triad,
303 None => return Vec::new(),
304 };
305
306 if start_edge == goal_edge {
313 let features = ChainFeatures::default();
314 let path = Path::<T> {
315 transforms: Vec::new(),
316 triads: vec![start_triad],
317 edges: vec![Some(start_edge)],
318 cost: 0,
319 features,
320 };
321
322 self.cache_mut().insert(
324 [start_edge_t, goal_edge_t, T::one()],
325 T::zero(),
326 vec![path.clone()],
327 );
328
329 return vec![path];
330 }
331
332 let mut all_paths = Vec::new();
334
335 for depth in 1..=self.max_depth() {
337 let paths = self.bfs_between_edges(start_edge, goal_edge, depth);
338
339 if !paths.is_empty() {
340 all_paths = paths;
341 break;
342 }
343 }
344
345 all_paths.sort_by_key(|p| p.cost);
347 if all_paths.len() > self.max_paths() {
348 all_paths.truncate(self.max_paths());
349 }
350
351 self.cache_mut().insert(
353 [start_edge_t, goal_edge_t, T::one()],
354 T::zero(),
355 all_paths.clone(),
356 );
357
358 all_paths
359 }
360
361 fn bfs_between_edges(
363 &self,
364 start_edge: EdgeId,
365 goal_edge: EdgeId,
366 max_depth: usize,
367 ) -> Vec<Path<T>> {
368 let mut result_paths = Vec::new();
369
370 let start_triad = match self.tonnetz().get_triad(&start_edge) {
372 Some(&triad) => triad,
373 None => return Vec::new(),
374 };
375
376 let mut queue = VecDeque::new();
378 queue.push_back((
379 start_triad, Vec::<LPR>::new(), vec![start_triad], vec![Some(start_edge)], 0usize, ));
385
386 let mut visited = HashMap::<[T; 3], HashSet<T>>::new();
388 visited
389 .entry(*start_triad.chord())
390 .or_default()
391 .insert(T::zero());
392
393 while let Some((current_triad, transforms, triads, edge_ids, depth)) = queue.pop_front() {
394 if depth >= max_depth {
396 continue;
397 }
398
399 for transform in LPR::iter() {
401 let next_triad = current_triad.transform(transform);
403 let next_depth = depth + 1;
404 let next_depth_t = T::from_usize(next_depth).unwrap();
405
406 let depths = visited.entry(*next_triad.chord()).or_default();
408 if depths.contains(&next_depth_t) {
409 continue;
410 }
411
412 depths.insert(next_depth_t);
414
415 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
417 if facet.chord() == next_triad.chord() {
418 Some(id)
419 } else {
420 None
421 }
422 });
423
424 let mut new_transforms = transforms.clone();
426 new_transforms.push(transform);
427
428 let mut new_triads = triads.clone();
429 new_triads.push(next_triad);
430
431 let mut new_edge_ids = edge_ids.clone();
432 new_edge_ids.push(next_edge_id);
433
434 if next_edge_id == Some(goal_edge) {
436 let features = self.analyze_path_features(&new_triads);
438 let cost = features.distance + new_transforms.len();
439
440 let path = Path::<T> {
442 transforms: new_transforms,
443 triads: new_triads,
444 edges: new_edge_ids,
445 cost,
446 features,
447 };
448
449 result_paths.push(path);
450
451 if result_paths.len() >= self.max_paths() {
453 return result_paths;
454 }
455 } else if next_depth < max_depth {
456 queue.push_back((
459 next_triad,
460 new_transforms,
461 new_triads,
462 new_edge_ids,
463 next_depth,
464 ));
465 }
466 }
467 }
468
469 result_paths
470 }
471 #[allow(clippy::too_many_arguments)]
472 pub fn search_from(
475 &self,
476 start_triad: DynTriad<T>,
477 target_pitch: T,
478 transforms: Vec<LPR>,
479 triads: Vec<DynTriad<T>>,
480 edge_ids: Vec<Option<EdgeId>>,
481 max_paths: usize,
482 remaining_depth: usize,
483 ) -> Vec<Path<T>> {
484 let mut result_paths = Vec::new();
485
486 if start_triad.contains(&target_pitch) {
488 let features = self.analyze_path_features(&triads);
489 let cost = features.distance + transforms.len();
490
491 result_paths.push(Path {
492 transforms,
493 triads,
494 edges: edge_ids,
495 cost,
496 features,
497 });
498
499 return result_paths;
500 }
501
502 if remaining_depth == 0 {
504 return result_paths;
505 }
506
507 let mut queue = VecDeque::new();
509 queue.push_back((
510 start_triad,
511 transforms.clone(),
512 triads.clone(),
513 edge_ids.clone(),
514 0usize,
515 ));
516
517 let mut visited = HashMap::<[T; 3], HashSet<T>>::new();
519 visited
520 .entry(*start_triad.chord())
521 .or_default()
522 .insert(T::zero());
523
524 while let Some((
525 current_triad,
526 current_transforms,
527 current_triads,
528 current_edge_ids,
529 depth,
530 )) = queue.pop_front()
531 {
532 if depth >= remaining_depth {
534 continue;
535 }
536
537 for transform in LPR::iter() {
539 let next_triad = current_triad.transform(transform);
541 let next_depth = depth + 1;
542 let next_depth_t = T::from_usize(next_depth).unwrap();
543
544 let depths = visited.entry(*next_triad.chord()).or_default();
546 if depths.contains(&next_depth_t) {
547 continue;
548 }
549
550 depths.insert(next_depth_t);
552
553 let next_edge_id = self.tonnetz().triads().iter().find_map(|(&id, facet)| {
555 if facet.chord() == next_triad.chord() {
556 Some(id)
557 } else {
558 None
559 }
560 });
561
562 let mut new_transforms = current_transforms.clone();
564 new_transforms.push(transform);
565
566 let mut new_triads = current_triads.clone();
567 new_triads.push(next_triad);
568
569 let mut new_edge_ids = current_edge_ids.clone();
570 new_edge_ids.push(next_edge_id);
571
572 if next_triad.contains(&target_pitch) {
574 let features = self.analyze_path_features(&new_triads);
576 let cost = features.distance() + new_transforms.len();
577
578 let path = Path::<T> {
580 transforms: new_transforms,
581 triads: new_triads,
582 edges: new_edge_ids,
583 cost,
584 features,
585 };
586
587 result_paths.push(path);
588
589 if result_paths.len() >= max_paths {
591 result_paths.sort_by_key(|p| p.cost);
592 return result_paths;
593 }
594 } else if next_depth < remaining_depth {
595 queue.push_back((
598 next_triad,
599 new_transforms,
600 new_triads,
601 new_edge_ids,
602 next_depth,
603 ));
604 }
605 }
606 }
607
608 result_paths.sort_by_key(|p| p.cost());
610 result_paths
611 }
612
613 #[cfg(feature = "rayon")]
615 pub fn find_paths_parallel(&self, start_edge: EdgeId, target_pitch: T) -> Vec<Path<T>> {
616 use rayon::iter::{ParallelBridge, ParallelIterator};
617
618 let start_triad = match self.tonnetz().get_triad(&start_edge) {
620 Some(triad) => *triad,
621 None => return Vec::new(),
622 };
623
624 if start_triad.contains(&target_pitch) {
626 let features = ChainFeatures::default();
627 let path = Path {
628 transforms: Vec::new(),
629 triads: vec![start_triad],
630 edges: vec![Some(start_edge)],
631 cost: 0,
632 features,
633 };
634
635 return vec![path];
636 }
637
638 let results: Vec<Vec<Path<T>>> = LPR::iter()
640 .par_bridge()
641 .filter_map(|transform| {
642 match transform.apply(&start_triad) {
644 Ok(next_triad) => {
645 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
647 if facet.chord() == next_triad.chord() {
648 Some(id)
649 } else {
650 None
651 }
652 });
653
654 let transforms = vec![transform];
655 let triads = vec![start_triad, next_triad];
656 let edge_ids = vec![Some(start_edge), next_edge_id];
657
658 Some(self.search_from(
660 next_triad,
661 target_pitch,
662 transforms,
663 triads,
664 edge_ids,
665 self.max_paths(),
666 self.max_depth() - 1,
667 ))
668 }
669 Err(_) => None,
670 }
671 })
672 .collect();
673
674 let mut all_paths = Vec::new();
676 for paths in results {
677 all_paths.extend(paths);
678 }
679
680 all_paths.sort_by_key(|p| p.cost);
682 if all_paths.len() > self.max_paths() {
683 all_paths.truncate(self.max_paths());
684 }
685
686 all_paths
687 }
688
689 fn analyze_path_features(&self, triads: &[DynTriad<T>]) -> ChainFeatures {
691 let mut features = ChainFeatures::default();
692
693 let mut transform_counts = HashMap::new();
695 let mut modality_changes = 0;
696 let mut voice_leading_distance = 0;
697
698 for i in 1..triads.len() {
700 let prev = &triads[i - 1];
701 let curr = &triads[i];
702
703 let transform = if prev.is_major() != curr.is_major() {
705 if prev.root() == curr.root() {
707 LPR::Parallel
708 }
709 else if prev.common_tones(curr).len() == 2 {
711 LPR::Relative
712 }
713 else {
715 LPR::Leading
716 }
717 } else {
718 LPR::Leading
720 };
721
722 *transform_counts.entry(transform).or_insert(0) += 1;
723
724 if prev.is_major() != curr.is_major() {
726 modality_changes += 1;
727 }
728
729 for &prev_note in prev.chord() {
731 let min_distance = curr
733 .chord()
734 .iter()
735 .map(|&curr_note| {
736 let dist = (curr_note - prev_note).pmod();
737 core::cmp::min(dist, T::from_usize(12).unwrap() - dist)
738 })
739 .min()
740 .unwrap_or(T::zero());
741
742 voice_leading_distance += min_distance.to_usize().unwrap();
743 }
744 }
745
746 features
747 .set_transform_counts(transform_counts)
748 .set_modality_changes(modality_changes)
749 .set_distance(voice_leading_distance);
750 features
751 }
752}