1use super::cache::PathCache;
6use super::types::{Path, PathFeatures, SearchNode};
7use crate::tonnetz::Tonnetz;
8use crate::traits::PitchMod;
9use crate::{LPR, Triad};
10use rshyper::EdgeId;
11use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
use strum::IntoEnumIterator;

/// Plans sequences of LPR transformations ("paths") across a [`Tonnetz`],
/// memoizing search results in a bounded cache.
pub struct MotionPlanner<'a> {
    // Memoized search results; keyed by a `[usize; 3]` key plus a pitch
    // (see the `cache.get(&[...], pitch)` call sites in the impl).
    cache: PathCache,
    // The tonnetz being searched; borrowed for the planner's lifetime.
    tonnetz: &'a Tonnetz,
    // Maximum number of transformations allowed in any single path.
    max_depth: usize,
    // Maximum number of paths returned per query.
    max_paths: usize,
}
25
26impl<'a> MotionPlanner<'a> {
27 pub fn new(tonnetz: &'a Tonnetz) -> Self {
29 let capacity = 1000; MotionPlanner {
31 cache: PathCache::new(capacity),
32 tonnetz,
33 max_depth: 5, max_paths: 5, }
36 }
37
38 pub fn set_max_depth(&mut self, depth: usize) {
40 self.max_depth = depth;
41 }
42
43 pub fn set_max_paths(&mut self, paths: usize) {
45 self.max_paths = paths;
46 }
47
48 pub fn with_max_depth(self, depth: usize) -> Self {
50 Self {
51 max_depth: depth,
52 ..self
53 }
54 }
55
56 pub fn with_max_paths(self, paths: usize) -> Self {
58 Self {
59 max_paths: paths,
60 ..self
61 }
62 }
63
64 pub const fn cache(&self) -> &PathCache {
66 &self.cache
67 }
68
69 pub fn cache_mut(&mut self) -> &mut PathCache {
71 &mut self.cache
72 }
73 pub fn find_paths_to_pitch(&mut self, start_edge: EdgeId, target_pitch: usize) -> Vec<Path> {
75 fn improved_heuristic(triad: &Triad, target: usize) -> usize {
77 if triad.contains(&target) {
78 return 0;
79 }
80
81 1
84 }
85 if let Some(paths) = self.cache.get(&[start_edge.0, 0, 0], target_pitch).cloned() {
87 return paths;
88 }
89
90 let start_triad = match self.tonnetz.get_triad(start_edge) {
92 Some(triad) => triad.clone(),
93 None => return Vec::new(),
94 };
95
96 if start_triad.contains(&target_pitch) {
98 let features = PathFeatures::default();
99 let path = Path {
100 transforms: Vec::new(),
101 triads: vec![start_triad],
102 edge_ids: vec![Some(start_edge)],
103 cost: 0,
104 features,
105 };
106
107 self.cache
108 .insert([start_edge.0, 0, 0], target_pitch, vec![path.clone()]);
109 return vec![path];
110 }
111
112 let mut result_paths = Vec::new();
113 let mut open_set = BinaryHeap::new();
114 let mut visited = HashMap::<[usize; 3], usize>::new(); open_set.push(SearchNode {
118 priority: 0, cost: 0,
120 triad: start_triad.clone(),
121 transforms: Vec::new(),
122 triads: vec![start_triad.clone()],
123 edge_ids: vec![Some(start_edge)],
124 });
125
126 visited.insert(start_triad.notes(), 0);
127
128 while let Some(node) = open_set.pop() {
129 if let Some(&prev_cost) = visited.get(&node.triad.notes()) {
131 if prev_cost < node.cost {
132 continue;
133 }
134 }
135
136 if node.transforms.len() >= self.max_depth {
138 continue;
139 }
140
141 for transform in LPR::iter() {
143 let next_triad = node.triad.transform(transform);
145 let new_cost = node.cost + 1;
146
147 if let Some(&prev_cost) = visited.get(&next_triad.notes()) {
149 if prev_cost <= new_cost {
150 continue;
151 }
152 }
153
154 visited.insert(next_triad.notes(), new_cost);
156
157 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
159 if facet.notes() == next_triad.notes() {
160 Some(id)
161 } else {
162 None
163 }
164 });
165
166 let mut new_transforms = node.transforms.clone();
168 new_transforms.push(transform);
169
170 let mut new_triads = node.triads.clone();
171 new_triads.push(next_triad);
172
173 let mut new_edge_ids = node.edge_ids.clone();
174 new_edge_ids.push(next_edge_id);
175
176 if next_triad.contains(&target_pitch) {
178 let features = self.analyze_path_features(&new_triads);
180
181 let path = Path {
183 transforms: new_transforms.clone(),
184 triads: new_triads.clone(),
185 edge_ids: new_edge_ids.clone(),
186 cost: new_cost,
187 features,
188 };
189
190 result_paths.push(path);
191
192 if result_paths.len() >= self.max_paths {
194 result_paths.sort_by_key(|p| p.cost);
196
197 self.cache
199 .insert([start_edge.0, 0, 0], target_pitch, result_paths.clone());
200
201 return result_paths;
202 }
203 }
204
205 let h = improved_heuristic(&next_triad, target_pitch);
207 let priority = -((new_cost as i32) + (h as i32)); let next_node = SearchNode {
210 priority,
211 cost: new_cost,
212 triad: next_triad,
213 transforms: new_transforms,
214 triads: new_triads,
215 edge_ids: new_edge_ids,
216 };
217
218 open_set.push(next_node);
219 }
220 }
221
222 result_paths.sort_by_key(|p| p.cost);
224
225 self.cache
227 .insert([start_edge.0, 0, 0], target_pitch, result_paths.clone());
228
229 result_paths
230 }
231 pub fn find_paths_between_edges(&mut self, start_edge: EdgeId, goal_edge: EdgeId) -> Vec<Path> {
233 if let Some(paths) = self.cache.get(&[start_edge.0, goal_edge.0, 1], 0).cloned() {
235 return paths;
236 }
237
238 let start_triad = match self.tonnetz.get_triad(start_edge) {
240 Some(triad) => triad.clone(),
241 None => return Vec::new(),
242 };
243
244 if start_edge == goal_edge {
251 let features = PathFeatures::default();
252 let path = Path {
253 transforms: Vec::new(),
254 triads: vec![start_triad],
255 edge_ids: vec![Some(start_edge)],
256 cost: 0,
257 features,
258 };
259
260 self.cache
262 .insert([start_edge.0, goal_edge.0, 1], 0, vec![path.clone()]);
263
264 return vec![path];
265 }
266
267 let mut all_paths = Vec::new();
269
270 for depth in 1..=self.max_depth {
272 let paths = self.bfs_between_edges(start_edge, goal_edge, depth);
273
274 if !paths.is_empty() {
275 all_paths = paths;
276 break;
277 }
278 }
279
280 all_paths.sort_by_key(|p| p.cost);
282 if all_paths.len() > self.max_paths {
283 all_paths.truncate(self.max_paths);
284 }
285
286 self.cache
288 .insert([start_edge.0, goal_edge.0, 1], 0, all_paths.clone());
289
290 all_paths
291 }
292
293 fn bfs_between_edges(
295 &self,
296 start_edge: EdgeId,
297 goal_edge: EdgeId,
298 max_depth: usize,
299 ) -> Vec<Path> {
300 let mut result_paths = Vec::new();
301
302 let start_triad = match self.tonnetz.get_triad(start_edge) {
304 Some(triad) => triad.clone(),
305 None => return Vec::new(),
306 };
307
308 let mut queue = VecDeque::new();
310 queue.push_back((
311 start_triad.clone(), Vec::<LPR>::new(), vec![start_triad], vec![Some(start_edge)], 0usize, ));
317
318 let mut visited = HashMap::<[usize; 3], HashSet<usize>>::new();
320 visited.entry(start_triad.notes()).or_default().insert(0);
321
322 while let Some((current_triad, transforms, triads, edge_ids, depth)) = queue.pop_front() {
323 if depth >= max_depth {
325 continue;
326 }
327
328 for transform in LPR::iter() {
330 let next_triad = current_triad.transform(transform);
332 let next_depth = depth + 1;
333
334 let depths = visited.entry(next_triad.notes()).or_default();
336 if depths.contains(&next_depth) {
337 continue;
338 }
339
340 depths.insert(next_depth);
342
343 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
345 if facet.notes() == next_triad.notes() {
346 Some(id)
347 } else {
348 None
349 }
350 });
351
352 let mut new_transforms = transforms.clone();
354 new_transforms.push(transform);
355
356 let mut new_triads = triads.clone();
357 new_triads.push(next_triad);
358
359 let mut new_edge_ids = edge_ids.clone();
360 new_edge_ids.push(next_edge_id);
361
362 if next_edge_id == Some(goal_edge) {
364 let features = self.analyze_path_features(&new_triads);
366 let cost = features.distance + new_transforms.len();
367
368 let path = Path {
370 transforms: new_transforms,
371 triads: new_triads,
372 edge_ids: new_edge_ids,
373 cost,
374 features,
375 };
376
377 result_paths.push(path);
378
379 if result_paths.len() >= self.max_paths {
381 return result_paths;
382 }
383 } else if next_depth < max_depth {
384 queue.push_back((
387 next_triad,
388 new_transforms,
389 new_triads,
390 new_edge_ids,
391 next_depth,
392 ));
393 }
394 }
395 }
396
397 result_paths
398 }
399
400 pub fn search_from(
403 &self,
404 start_triad: Triad,
405 target_pitch: usize,
406 transforms: Vec<LPR>,
407 triads: Vec<Triad>,
408 edge_ids: Vec<Option<EdgeId>>,
409 max_paths: usize,
410 remaining_depth: usize,
411 ) -> Vec<Path> {
412 let mut result_paths = Vec::new();
413
414 if start_triad.contains(&target_pitch) {
416 let features = self.analyze_path_features(&triads);
417 let cost = features.distance + transforms.len();
418
419 result_paths.push(Path {
420 transforms,
421 triads,
422 edge_ids,
423 cost,
424 features,
425 });
426
427 return result_paths;
428 }
429
430 if remaining_depth == 0 {
432 return result_paths;
433 }
434
435 let mut queue = VecDeque::new();
437 queue.push_back((
438 start_triad.clone(),
439 transforms.clone(),
440 triads.clone(),
441 edge_ids.clone(),
442 0usize,
443 ));
444
445 let mut visited = HashMap::<[usize; 3], HashSet<usize>>::new();
447 visited.entry(start_triad.notes()).or_default().insert(0);
448
449 while let Some((
450 current_triad,
451 current_transforms,
452 current_triads,
453 current_edge_ids,
454 depth,
455 )) = queue.pop_front()
456 {
457 if depth >= remaining_depth {
459 continue;
460 }
461
462 for transform in LPR::iter() {
464 let next_triad = current_triad.transform(transform);
466 let next_depth = depth + 1;
467
468 let depths = visited.entry(next_triad.notes()).or_default();
470 if depths.contains(&next_depth) {
471 continue;
472 }
473
474 depths.insert(next_depth);
476
477 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
479 if facet.notes() == next_triad.notes() {
480 Some(id)
481 } else {
482 None
483 }
484 });
485
486 let mut new_transforms = current_transforms.clone();
488 new_transforms.push(transform);
489
490 let mut new_triads = current_triads.clone();
491 new_triads.push(next_triad);
492
493 let mut new_edge_ids = current_edge_ids.clone();
494 new_edge_ids.push(next_edge_id);
495
496 if next_triad.contains(&target_pitch) {
498 let features = self.analyze_path_features(&new_triads);
500 let cost = features.distance + new_transforms.len();
501
502 let path = Path {
504 transforms: new_transforms,
505 triads: new_triads,
506 edge_ids: new_edge_ids,
507 cost,
508 features,
509 };
510
511 result_paths.push(path);
512
513 if result_paths.len() >= max_paths {
515 result_paths.sort_by_key(|p| p.cost);
516 return result_paths;
517 }
518 } else if next_depth < remaining_depth {
519 queue.push_back((
522 next_triad,
523 new_transforms,
524 new_triads,
525 new_edge_ids,
526 next_depth,
527 ));
528 }
529 }
530 }
531
532 result_paths.sort_by_key(|p| p.cost);
534 result_paths
535 }
536
537 #[cfg(feature = "rayon")]
539 pub fn find_paths_parallel(&self, start_edge: EdgeId, target_pitch: usize) -> Vec<Path> {
540 use rayon::prelude::*;
541
542 let start_triad = match self.tonnetz.get_triad(start_edge) {
544 Some(triad) => triad.clone(),
545 None => return Vec::new(),
546 };
547
548 if start_triad.contains(&target_pitch) {
550 let features = PathFeatures::default();
551 let path = Path {
552 transforms: Vec::new(),
553 triads: vec![start_triad],
554 edge_ids: vec![Some(start_edge)],
555 cost: 0,
556 features,
557 };
558
559 return vec![path];
560 }
561
562 let results: Vec<Vec<Path>> = LPR::iter()
564 .par_bridge()
565 .filter_map(|transform| {
566 match transform.try_apply(&start_triad) {
568 Ok(next_triad) => {
569 let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
571 if facet.notes() == next_triad.notes() {
572 Some(id)
573 } else {
574 None
575 }
576 });
577
578 let transforms = vec![transform];
579 let triads = vec![start_triad.clone(), next_triad.clone()];
580 let edge_ids = vec![Some(start_edge), next_edge_id];
581
582 Some(self.search_from(
584 next_triad,
585 target_pitch,
586 transforms,
587 triads,
588 edge_ids,
589 self.max_paths,
590 self.max_depth - 1,
591 ))
592 }
593 Err(_) => None,
594 }
595 })
596 .collect();
597
598 let mut all_paths = Vec::new();
600 for paths in results {
601 all_paths.extend(paths);
602 }
603
604 all_paths.sort_by_key(|p| p.cost);
606 if all_paths.len() > self.max_paths {
607 all_paths.truncate(self.max_paths);
608 }
609
610 all_paths
611 }
612
613 fn analyze_path_features(&self, triads: &[Triad]) -> PathFeatures {
615 let mut features = PathFeatures::default();
616
617 let mut transform_counts = HashMap::new();
619 let mut modality_changes = 0;
620 let mut voice_leading_distance = 0;
621
622 for i in 1..triads.len() {
624 let prev = &triads[i - 1];
625 let curr = &triads[i];
626
627 let transform = if prev.is_major() != curr.is_major() {
629 if prev.root() == curr.root() {
631 LPR::Parallel
632 }
633 else if prev.common_tones(curr).len() == 2 {
635 LPR::Relative
636 }
637 else {
639 LPR::Leading
640 }
641 } else {
642 LPR::Leading
644 };
645
646 *transform_counts.entry(transform).or_insert(0) += 1;
647
648 if prev.is_major() != curr.is_major() {
650 modality_changes += 1;
651 }
652
653 for prev_note in prev.notes() {
655 let min_distance = curr
657 .notes()
658 .iter()
659 .map(|&curr_note| {
660 let dist = (curr_note as isize - prev_note as isize).abs().pmod();
661 std::cmp::min(dist, 12 - dist) as usize
662 })
663 .min()
664 .unwrap_or(0);
665
666 voice_leading_distance += min_distance;
667 }
668 }
669
670 features.transform_counts = transform_counts;
671 features.modality_changes = modality_changes;
672 features.distance = voice_leading_distance;
673
674 features
675 }
676}