// rstmt_nrt/motion/impls/impl_motion_planner.rs

/*
    Appellation: impl_motion_planner <module>
    Created At: 2025.12.29:20:34:56
    Contrib: @FL03
*/
6#![cfg(feature = "alloc")]
7
8use crate::motion::types::{ChainFeatures, Path, PathCache, SearchNode};
9use crate::motion::{MotionPlanner, MotionPlannerConfig};
10use crate::tonnetz::StdHyperTonnetz;
11use crate::triad::DynTriad;
12use crate::types::LPR;
13
14use alloc::collections::{BinaryHeap, VecDeque};
15use core::hash::Hash;
16use hashbrown::{HashMap, HashSet};
17use num_traits::{FromPrimitive, One, ToPrimitive, Zero};
18use rshyper::EdgeId;
19use rstmt::PitchMod;
20
21/// a binary-like step heuristic for estimating distance to target pitch
22fn binary_heuristic<T>(triad: &DynTriad<T>, target: T) -> T
23where
24    T: PartialEq + One + Zero,
25{
26    // For voice-leading distance, we use max of 1 as estimate
27    // This ensures heuristic is admissible (never overestimates)
28    match triad.contains(&target) {
29        true => return T::zero(),
30        false => T::one(),
31    }
32}
33
34impl<'a, T> MotionPlanner<'a, T>
35where
36    T: Eq + Hash,
37{
38    /// Create a new motion planner for the given tonnetz
39    pub fn new(tonnetz: &'a StdHyperTonnetz<T>) -> Self {
40        let capacity = 1000; // Default cache capacity
41        MotionPlanner {
42            cache: PathCache::new(capacity),
43            tonnetz,
44            config: MotionPlannerConfig::default(),
45        }
46    }
47    /// returns an immutable reference to the cache
48    pub const fn cache(&self) -> &PathCache<T> {
49        &self.cache
50    }
51    /// returns a mutable reference to the cache
52    pub const fn cache_mut(&mut self) -> &mut PathCache<T> {
53        &mut self.cache
54    }
55    /// returns an immutable reference to the configuration of the planner
56    pub const fn config(&self) -> &MotionPlannerConfig {
57        &self.config
58    }
59    /// returns a mutable reference to the configuration of the planner
60    pub const fn config_mut(&mut self) -> &mut MotionPlannerConfig {
61        &mut self.config
62    }
63    /// returns the maximum depth for pathfinding
64    pub const fn max_depth(&self) -> usize {
65        self.config().max_depth()
66    }
67    /// returns the maximum number of paths to find
68    pub const fn max_paths(&self) -> usize {
69        self.config().max_paths()
70    }
71    /// returns an immutable reference to the tonnetz
72    pub const fn tonnetz(&self) -> &StdHyperTonnetz<T> {
73        self.tonnetz
74    }
75    /// updates the current configuration and returns a mutable reference to the instance.
76    pub fn set_config(&mut self, config: MotionPlannerConfig) -> &mut Self {
77        self.config = config;
78        self
79    }
80    /// set the maximum depth for pathfinding
81    pub fn set_max_depth(&mut self, depth: usize) -> &mut Self {
82        self.config_mut().set_max_depth(depth);
83        self
84    }
85    /// set the maximum number of paths to find
86    pub fn set_max_paths(&mut self, paths: usize) -> &mut Self {
87        self.config_mut().set_max_paths(paths);
88        self
89    }
90    /// consumes the current instance to create another with the given maximum depth
91    pub fn with_max_depth(self, depth: usize) -> Self {
92        Self {
93            config: self.config.with_max_depth(depth),
94            ..self
95        }
96    }
97    /// consumes the current instance to create another with the given maximum number of paths
98    pub fn with_max_paths(self, paths: usize) -> Self {
99        Self {
100            config: self.config.with_max_paths(paths),
101            ..self
102        }
103    }
104}
105
106impl<'a, T> MotionPlanner<'a, T>
107where
108    T: Copy
109        + Eq
110        + Hash
111        + Ord
112        + FromPrimitive
113        + ToPrimitive
114        + One
115        + Zero
116        + PitchMod<Output = T>
117        + core::ops::Add<Output = T>
118        + core::ops::Sub<Output = T>
119        + core::ops::AddAssign,
120{
121    /// find a set of paths from one triad to one that contains the target pitch
122    pub fn find_paths_to_pitch(&mut self, start_edge: EdgeId, target_pitch: T) -> Vec<Path<T>> {
123        let start_edge_t = <T>::from_usize(*start_edge).unwrap();
124        // Check cache first
125        if let Some(paths) = self
126            .cache
127            .get(&[start_edge_t, T::zero(), T::zero()], target_pitch)
128            .cloned()
129        {
130            return paths;
131        }
132
133        // Get the starting triad
134        let start_triad = match self.tonnetz().get_triad(&start_edge) {
135            Some(&triad) => triad,
136            None => return Vec::new(),
137        };
138
139        // Check if the starting triad already contains the target pitch
140        if start_triad.contains(&target_pitch) {
141            let features = ChainFeatures::default();
142            let path = Path::<T> {
143                transforms: Vec::new(),
144                triads: vec![start_triad],
145                edges: vec![Some(start_edge)],
146                cost: 0,
147                features,
148            };
149
150            self.cache.insert(
151                [start_edge_t, T::zero(), T::zero()],
152                target_pitch,
153                vec![path.clone()],
154            );
155            return vec![path];
156        }
157
158        let mut result_paths = Vec::new();
159        let mut open_set = BinaryHeap::new();
160        let mut visited = HashMap::<[T; 3], usize>::new(); // Track visited triads with their path length
161
162        // Start with initial node
163        open_set.push(SearchNode {
164            priority: 0, // -heuristic for min-heap
165            cost: 0,
166            triad: start_triad,
167            transforms: Vec::new(),
168            visited: vec![start_triad],
169            edges: vec![Some(start_edge)],
170        });
171
172        visited.insert(*start_triad.chord(), 0);
173
174        while let Some(node) = open_set.pop() {
175            // Skip if we've found a shorter path to this triad
176            if let Some(&prev_cost) = visited.get(node.triad.chord()) {
177                if prev_cost < node.cost {
178                    continue;
179                }
180            }
181
182            // Check depth limit
183            if node.transforms.len() >= self.max_depth() {
184                continue;
185            }
186
187            // Try each transformation
188            for transform in LPR::iter() {
189                // Apply transformation
190                let next_triad = node.triad.transform(transform);
191                let new_cost = node.cost + 1;
192
193                // Skip if we've found a shorter path to this triad
194                if let Some(&prev_cost) = visited.get(next_triad.chord()) {
195                    if prev_cost <= new_cost {
196                        continue;
197                    }
198                }
199
200                // Update visited with this triad's path length
201                visited.insert(*next_triad.chord(), new_cost);
202
203                // Find edge ID if this exists in the tonnetz
204                let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
205                    if facet.chord() == next_triad.chord() {
206                        Some(id)
207                    } else {
208                        None
209                    }
210                });
211
212                // Build new path
213                let mut new_transforms = node.transforms.clone();
214                new_transforms.push(transform);
215
216                let mut new_triads = node.visited.clone();
217                new_triads.push(next_triad);
218
219                let mut new_edge_ids = node.edges.clone();
220                new_edge_ids.push(next_edge_id);
221
222                // Check if this triad contains our target pitch
223                if next_triad.contains(&target_pitch) {
224                    // Calculate path features
225                    let features = self.analyze_path_features(&new_triads);
226
227                    // Found a path
228                    let path = Path {
229                        transforms: new_transforms.clone(),
230                        triads: new_triads.clone(),
231                        edges: new_edge_ids.clone(),
232                        cost: new_cost,
233                        features,
234                    };
235
236                    result_paths.push(path);
237
238                    // Check if we've found enough paths
239                    if result_paths.len() >= self.max_paths() {
240                        // Sort paths by cost
241                        result_paths.sort_by_key(|p| p.cost);
242
243                        // Cache the result
244                        self.cache.insert(
245                            [start_edge_t, T::zero(), T::zero()],
246                            target_pitch,
247                            result_paths.clone(),
248                        );
249
250                        return result_paths;
251                    }
252                }
253
254                // Continue search - use admissible heuristic
255                let h = binary_heuristic(&next_triad, target_pitch);
256                let priority = -((new_cost as i32) + h.to_i32().unwrap()); // Negative for min-heap
257
258                let next_node = SearchNode {
259                    priority,
260                    cost: new_cost,
261                    triad: next_triad,
262                    transforms: new_transforms,
263                    visited: new_triads,
264                    edges: new_edge_ids,
265                };
266
267                open_set.push(next_node);
268            }
269        }
270
271        // Sort by cost
272        result_paths.sort_by_key(|p| p.cost);
273
274        // Cache results
275        self.cache.insert(
276            [start_edge_t, T::zero(), T::zero()],
277            target_pitch,
278            result_paths.clone(),
279        );
280
281        result_paths
282    }
283    /// Search for paths between two specific edges in the tonnetz
284    pub fn find_paths_between_edges(
285        &mut self,
286        start_edge: EdgeId,
287        goal_edge: EdgeId,
288    ) -> Vec<Path<T>> {
289        let start_edge_t = <T>::from_usize(*start_edge).unwrap();
290        let goal_edge_t = <T>::from_usize(*goal_edge).unwrap();
291        // Check cache first
292        if let Some(paths) = self
293            .cache_mut()
294            .get(&[start_edge_t, goal_edge_t, T::one()], T::zero())
295            .cloned()
296        {
297            return paths;
298        }
299
300        // Get the starting and goal triads
301        let start_triad = match self.tonnetz().get_triad(&start_edge) {
302            Some(triad) => *triad,
303            None => return Vec::new(),
304        };
305
306        // let goal_triad = match self.tonnetz().get_triad(goal_edge) {
307        //     Some(triad) => triad.clone(),
308        //     None => return Vec::new(),
309        // };
310
311        // Check if start and goal are the same
312        if start_edge == goal_edge {
313            let features = ChainFeatures::default();
314            let path = Path::<T> {
315                transforms: Vec::new(),
316                triads: vec![start_triad],
317                edges: vec![Some(start_edge)],
318                cost: 0,
319                features,
320            };
321
322            // Cache the result
323            self.cache_mut().insert(
324                [start_edge_t, goal_edge_t, T::one()],
325                T::zero(),
326                vec![path.clone()],
327            );
328
329            return vec![path];
330        }
331
332        // Use BFS with iterative deepening to find paths to goal edge
333        let mut all_paths = Vec::new();
334
335        // Iterative deepening
336        for depth in 1..=self.max_depth() {
337            let paths = self.bfs_between_edges(start_edge, goal_edge, depth);
338
339            if !paths.is_empty() {
340                all_paths = paths;
341                break;
342            }
343        }
344
345        // Sort by cost and limit to max_paths
346        all_paths.sort_by_key(|p| p.cost);
347        if all_paths.len() > self.max_paths() {
348            all_paths.truncate(self.max_paths());
349        }
350
351        // Cache the result
352        self.cache_mut().insert(
353            [start_edge_t, goal_edge_t, T::one()],
354            T::zero(),
355            all_paths.clone(),
356        );
357
358        all_paths
359    }
360
361    /// BFS to find paths between two edges with a depth limit
362    fn bfs_between_edges(
363        &self,
364        start_edge: EdgeId,
365        goal_edge: EdgeId,
366        max_depth: usize,
367    ) -> Vec<Path<T>> {
368        let mut result_paths = Vec::new();
369
370        // Get the starting triad
371        let start_triad = match self.tonnetz().get_triad(&start_edge) {
372            Some(&triad) => triad,
373            None => return Vec::new(),
374        };
375
376        // Initialize BFS queue
377        let mut queue = VecDeque::new();
378        queue.push_back((
379            start_triad,            // Current triad
380            Vec::<LPR>::new(),      // Transformation path
381            vec![start_triad],      // Triad history
382            vec![Some(start_edge)], // Edge IDs
383            0usize,                 // Current depth
384        ));
385
386        // Track visited triads at each depth
387        let mut visited = HashMap::<[T; 3], HashSet<T>>::new();
388        visited
389            .entry(*start_triad.chord())
390            .or_default()
391            .insert(T::zero());
392
393        while let Some((current_triad, transforms, triads, edge_ids, depth)) = queue.pop_front() {
394            // If we've reached max depth, skip this path
395            if depth >= max_depth {
396                continue;
397            }
398
399            // Try each transformation
400            for transform in LPR::iter() {
401                // Apply transformation
402                let next_triad = current_triad.transform(transform);
403                let next_depth = depth + 1;
404                let next_depth_t = T::from_usize(next_depth).unwrap();
405
406                // Check if this triad+depth combination has been visited before
407                let depths = visited.entry(*next_triad.chord()).or_default();
408                if depths.contains(&next_depth_t) {
409                    continue;
410                }
411
412                // Mark as visited at this depth
413                depths.insert(next_depth_t);
414
415                // Find edge ID if this triad exists in the tonnetz
416                let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
417                    if facet.chord() == next_triad.chord() {
418                        Some(id)
419                    } else {
420                        None
421                    }
422                });
423
424                // Create new path components
425                let mut new_transforms = transforms.clone();
426                new_transforms.push(transform);
427
428                let mut new_triads = triads.clone();
429                new_triads.push(next_triad);
430
431                let mut new_edge_ids = edge_ids.clone();
432                new_edge_ids.push(next_edge_id);
433
434                // Check if we've reached the goal edge
435                if next_edge_id == Some(goal_edge) {
436                    // Calculate path features
437                    let features = self.analyze_path_features(&new_triads);
438                    let cost = features.distance + new_transforms.len();
439
440                    // Create path
441                    let path = Path::<T> {
442                        transforms: new_transforms,
443                        triads: new_triads,
444                        edges: new_edge_ids,
445                        cost,
446                        features,
447                    };
448
449                    result_paths.push(path);
450
451                    // If we've found max_paths, return early
452                    if result_paths.len() >= self.max_paths() {
453                        return result_paths;
454                    }
455                } else if next_depth < max_depth {
456                    // If we haven't reached the goal and haven't reached max depth,
457                    // add to queue for further exploration
458                    queue.push_back((
459                        next_triad,
460                        new_transforms,
461                        new_triads,
462                        new_edge_ids,
463                        next_depth,
464                    ));
465                }
466            }
467        }
468
469        result_paths
470    }
471    #[allow(clippy::too_many_arguments)]
472    /// Find paths from a specified edge by continuing search from a given state
473    /// Used for parallel search implementations
474    pub fn search_from(
475        &self,
476        start_triad: DynTriad<T>,
477        target_pitch: T,
478        transforms: Vec<LPR>,
479        triads: Vec<DynTriad<T>>,
480        edge_ids: Vec<Option<EdgeId>>,
481        max_paths: usize,
482        remaining_depth: usize,
483    ) -> Vec<Path<T>> {
484        let mut result_paths = Vec::new();
485
486        // Check if we're already at a triad containing the target pitch
487        if start_triad.contains(&target_pitch) {
488            let features = self.analyze_path_features(&triads);
489            let cost = features.distance + transforms.len();
490
491            result_paths.push(Path {
492                transforms,
493                triads,
494                edges: edge_ids,
495                cost,
496                features,
497            });
498
499            return result_paths;
500        }
501
502        // If we've reached max depth, return empty results
503        if remaining_depth == 0 {
504            return result_paths;
505        }
506
507        // Use BFS for continuation of search
508        let mut queue = VecDeque::new();
509        queue.push_back((
510            start_triad,
511            transforms.clone(),
512            triads.clone(),
513            edge_ids.clone(),
514            0usize,
515        ));
516
517        // Track visited triads at each depth
518        let mut visited = HashMap::<[T; 3], HashSet<T>>::new();
519        visited
520            .entry(*start_triad.chord())
521            .or_default()
522            .insert(T::zero());
523
524        while let Some((
525            current_triad,
526            current_transforms,
527            current_triads,
528            current_edge_ids,
529            depth,
530        )) = queue.pop_front()
531        {
532            // If we've reached max depth, skip this path
533            if depth >= remaining_depth {
534                continue;
535            }
536
537            // Try each transformation
538            for transform in LPR::iter() {
539                // Apply transformation
540                let next_triad = current_triad.transform(transform);
541                let next_depth = depth + 1;
542                let next_depth_t = T::from_usize(next_depth).unwrap();
543
544                // Check if this triad+depth combination has been visited before
545                let depths = visited.entry(*next_triad.chord()).or_default();
546                if depths.contains(&next_depth_t) {
547                    continue;
548                }
549
550                // Mark as visited at this depth
551                depths.insert(next_depth_t);
552
553                // Find edge ID if this triad exists in the tonnetz
554                let next_edge_id = self.tonnetz().triads().iter().find_map(|(&id, facet)| {
555                    if facet.chord() == next_triad.chord() {
556                        Some(id)
557                    } else {
558                        None
559                    }
560                });
561
562                // Create new path components
563                let mut new_transforms = current_transforms.clone();
564                new_transforms.push(transform);
565
566                let mut new_triads = current_triads.clone();
567                new_triads.push(next_triad);
568
569                let mut new_edge_ids = current_edge_ids.clone();
570                new_edge_ids.push(next_edge_id);
571
572                // Check if this triad contains the target pitch
573                if next_triad.contains(&target_pitch) {
574                    // Calculate path features
575                    let features = self.analyze_path_features(&new_triads);
576                    let cost = features.distance() + new_transforms.len();
577
578                    // Create path
579                    let path = Path::<T> {
580                        transforms: new_transforms,
581                        triads: new_triads,
582                        edges: new_edge_ids,
583                        cost,
584                        features,
585                    };
586
587                    result_paths.push(path);
588
589                    // If we've found max_paths, return early
590                    if result_paths.len() >= max_paths {
591                        result_paths.sort_by_key(|p| p.cost);
592                        return result_paths;
593                    }
594                } else if next_depth < remaining_depth {
595                    // If we haven't reached the target and haven't reached max depth,
596                    // add to queue for further exploration
597                    queue.push_back((
598                        next_triad,
599                        new_transforms,
600                        new_triads,
601                        new_edge_ids,
602                        next_depth,
603                    ));
604                }
605            }
606        }
607
608        // Sort and return paths
609        result_paths.sort_by_key(|p| p.cost());
610        result_paths
611    }
612
613    /// Run searches in parallel from initial transformations
614    #[cfg(feature = "rayon")]
615    pub fn find_paths_parallel(&self, start_edge: EdgeId, target_pitch: T) -> Vec<Path<T>> {
616        use rayon::iter::{ParallelBridge, ParallelIterator};
617
618        // Get the starting triad
619        let start_triad = match self.tonnetz().get_triad(&start_edge) {
620            Some(triad) => *triad,
621            None => return Vec::new(),
622        };
623
624        // Check if starting triad already contains the target pitch
625        if start_triad.contains(&target_pitch) {
626            let features = ChainFeatures::default();
627            let path = Path {
628                transforms: Vec::new(),
629                triads: vec![start_triad],
630                edges: vec![Some(start_edge)],
631                cost: 0,
632                features,
633            };
634
635            return vec![path];
636        }
637
638        // Initial transformations for parallel searches
639        let results: Vec<Vec<Path<T>>> = LPR::iter()
640            .par_bridge()
641            .filter_map(|transform| {
642                // Try applying the transformation
643                match transform.apply(&start_triad) {
644                    Ok(next_triad) => {
645                        // Find edge ID if it exists
646                        let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
647                            if facet.chord() == next_triad.chord() {
648                                Some(id)
649                            } else {
650                                None
651                            }
652                        });
653
654                        let transforms = vec![transform];
655                        let triads = vec![start_triad, next_triad];
656                        let edge_ids = vec![Some(start_edge), next_edge_id];
657
658                        // Start search from this branch
659                        Some(self.search_from(
660                            next_triad,
661                            target_pitch,
662                            transforms,
663                            triads,
664                            edge_ids,
665                            self.max_paths(),
666                            self.max_depth() - 1,
667                        ))
668                    }
669                    Err(_) => None,
670                }
671            })
672            .collect();
673
674        // Combine and sort results
675        let mut all_paths = Vec::new();
676        for paths in results {
677            all_paths.extend(paths);
678        }
679
680        // Sort by cost and take max_paths
681        all_paths.sort_by_key(|p| p.cost);
682        if all_paths.len() > self.max_paths() {
683            all_paths.truncate(self.max_paths());
684        }
685
686        all_paths
687    }
688
689    /// Analyze musical features of a transformation path
690    fn analyze_path_features(&self, triads: &[DynTriad<T>]) -> ChainFeatures {
691        let mut features = ChainFeatures::default();
692
693        // Count transforms (infer from triad progression)
694        let mut transform_counts = HashMap::new();
695        let mut modality_changes = 0;
696        let mut voice_leading_distance = 0;
697
698        // Analyze modality changes and voice leading
699        for i in 1..triads.len() {
700            let prev = &triads[i - 1];
701            let curr = &triads[i];
702
703            // Determine which transform was applied (approximate)
704            let transform = if prev.is_major() != curr.is_major() {
705                // Parallel transform changes mode while preserving root
706                if prev.root() == curr.root() {
707                    LPR::Parallel
708                }
709                // Relative transform preserves two notes()
710                else if prev.common_tones(curr).len() == 2 {
711                    LPR::Relative
712                }
713                // Leading transform if no better match
714                else {
715                    LPR::Leading
716                }
717            } else {
718                // If mode is preserved, likely Leading transform
719                LPR::Leading
720            };
721
722            *transform_counts.entry(transform).or_insert(0) += 1;
723
724            // Check for modality change
725            if prev.is_major() != curr.is_major() {
726                modality_changes += 1;
727            }
728
729            // Calculate voice leading distance (semitone movement between triads)
730            for &prev_note in prev.chord() {
731                // Find the minimum distance to move from prev_note to any note in curr
732                let min_distance = curr
733                    .chord()
734                    .iter()
735                    .map(|&curr_note| {
736                        let dist = (curr_note - prev_note).pmod();
737                        core::cmp::min(dist, T::from_usize(12).unwrap() - dist)
738                    })
739                    .min()
740                    .unwrap_or(T::zero());
741
742                voice_leading_distance += min_distance.to_usize().unwrap();
743            }
744        }
745
746        features
747            .set_transform_counts(transform_counts)
748            .set_modality_changes(modality_changes)
749            .set_distance(voice_leading_distance);
750        features
751    }
752}