eryon_nrt/transform/
planner.rs

1/*
2    Appellation: motion <planner>
3    Contrib: @FL03
4*/
5use super::cache::PathCache;
6use super::types::{Path, PathFeatures, SearchNode};
7use crate::tonnetz::Tonnetz;
8use crate::traits::PitchMod;
9use crate::{LPR, Triad};
10use rshyper::EdgeId;
11use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
12use strum::IntoEnumIterator; // For LPR::iter()
13
/// Motion planner that searches a [`Tonnetz`] for paths of [`LPR`] transformations.
///
/// The planner borrows the tonnetz for the lifetime `'a`, owns a [`PathCache`] in which
/// finished searches are stored, and carries the two limits (`max_depth`, `max_paths`)
/// that bound every search it runs.
pub struct MotionPlanner<'a> {
    /// Cache of previously computed path sets
    cache: PathCache,
    /// Borrowed tonnetz whose triads are searched
    tonnetz: &'a Tonnetz,
    /// Maximum number of transformations a single path may contain
    max_depth: usize,
    /// Number of paths after which a search stops collecting results
    max_paths: usize,
}
25
26impl<'a> MotionPlanner<'a> {
27    /// Create a new motion planner for the given tonnetz
28    pub fn new(tonnetz: &'a Tonnetz) -> Self {
29        let capacity = 1000; // Default cache capacity
30        MotionPlanner {
31            cache: PathCache::new(capacity),
32            tonnetz,
33            max_depth: 5, // Default search depth
34            max_paths: 5, // Default number of paths
35        }
36    }
37
    /// Set the maximum search depth, i.e. the largest number of transformations a
    /// path produced by later searches may contain.
    ///
    /// NOTE(review): the cache keys used by the `find_*` methods do not include the
    /// depth, so results cached before this call appear to stay valid hits afterwards —
    /// confirm whether the cache should be reset here.
    pub fn set_max_depth(&mut self, depth: usize) {
        self.max_depth = depth;
    }
42
    /// Set the maximum number of paths that later searches collect before stopping.
    ///
    /// NOTE(review): the cache keys used by the `find_*` methods do not include this
    /// limit, so results cached before this call appear to stay valid hits afterwards —
    /// confirm whether the cache should be reset here.
    pub fn set_max_paths(&mut self, paths: usize) {
        self.max_paths = paths;
    }
47
48    /// consumes the current instance to create another with the given maximum depth
49    pub fn with_max_depth(self, depth: usize) -> Self {
50        Self {
51            max_depth: depth,
52            ..self
53        }
54    }
55
56    /// consumes the current instance to create another with the given maximum number of paths
57    pub fn with_max_paths(self, paths: usize) -> Self {
58        Self {
59            max_paths: paths,
60            ..self
61        }
62    }
63
    /// Returns an immutable reference to the planner's path cache.
    pub const fn cache(&self) -> &PathCache {
        &self.cache
    }
68
    /// Returns a mutable reference to the planner's path cache, giving callers direct
    /// access to the stored results.
    pub fn cache_mut(&mut self) -> &mut PathCache {
        &mut self.cache
    }
    /// Find paths from the triad stored under `start_edge` to triads that contain
    /// `target_pitch`.
    ///
    /// Resolution order:
    /// 1. a cached result for `([start_edge.0, 0, 0], target_pitch)` is returned as is;
    /// 2. an unknown `start_edge` yields an empty vector (nothing is cached);
    /// 3. a start triad that already contains the pitch yields one zero-cost path;
    /// 4. otherwise a best-first search over the [`LPR`] transformations runs, limited to
    ///    `max_depth` transformations per path, and stops as soon as `max_paths` paths
    ///    have been collected.
    ///
    /// The returned paths are sorted by `cost` and stored in the cache before returning.
    ///
    /// NOTE(review): `Path::cost` is the number of transformations here, whereas
    /// `search_from` and `bfs_between_edges` use `features.distance + transforms.len()` —
    /// confirm which definition is intended.
    /// NOTE(review): the limit is only checked after a path has been pushed, so one path
    /// is still returned when `max_paths` is 0.
    pub fn find_paths_to_pitch(&mut self, start_edge: EdgeId, target_pitch: usize) -> Vec<Path> {
        // Heuristic for the search: an estimate of the transformations still required
        fn improved_heuristic(triad: &Triad, target: usize) -> usize {
            if triad.contains(&target) {
                return 0;
            }

            // A triad without the target needs at least one more transformation, so
            // the constant 1 never overestimates the remaining cost
            1
        }
        // Serve a previously computed result if one exists
        if let Some(paths) = self.cache.get(&[start_edge.0, 0, 0], target_pitch).cloned() {
            return paths;
        }

        // Resolve the starting triad; an unknown edge has no paths
        let start_triad = match self.tonnetz.get_triad(start_edge) {
            Some(triad) => triad.clone(),
            None => return Vec::new(),
        };

        // Trivial case: the starting triad already contains the target pitch
        if start_triad.contains(&target_pitch) {
            let features = PathFeatures::default();
            let path = Path {
                transforms: Vec::new(),
                triads: vec![start_triad],
                edge_ids: vec![Some(start_edge)],
                cost: 0,
                features,
            };

            self.cache
                .insert([start_edge.0, 0, 0], target_pitch, vec![path.clone()]);
            return vec![path];
        }

        let mut result_paths = Vec::new();
        let mut open_set = BinaryHeap::new();
        // Lowest known number of transformations needed to reach each triad (keyed by its notes)
        let mut visited = HashMap::<[usize; 3], usize>::new();

        // Seed the search; this is the only entry when it is popped, so its priority
        // value is never compared against another node
        open_set.push(SearchNode {
            priority: 0,
            cost: 0,
            triad: start_triad.clone(),
            transforms: Vec::new(),
            triads: vec![start_triad.clone()],
            edge_ids: vec![Some(start_edge)],
        });

        visited.insert(start_triad.notes(), 0);

        while let Some(node) = open_set.pop() {
            // Stale entry: a cheaper route to this triad was recorded after it was queued
            if let Some(&prev_cost) = visited.get(&node.triad.notes()) {
                if prev_cost < node.cost {
                    continue;
                }
            }

            // Nodes at the depth limit are not expanded any further
            if node.transforms.len() >= self.max_depth {
                continue;
            }

            // Expand the node with every transformation
            for transform in LPR::iter() {
                // Every transformation costs exactly one step
                let next_triad = node.triad.transform(transform);
                let new_cost = node.cost + 1;

                // Skip triads already reached at the same or a lower cost
                if let Some(&prev_cost) = visited.get(&next_triad.notes()) {
                    if prev_cost <= new_cost {
                        continue;
                    }
                }

                // Record the best known cost for this triad
                visited.insert(next_triad.notes(), new_cost);

                // Look up the edge holding the same notes; `None` when the tonnetz has no such triad
                let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
                    if facet.notes() == next_triad.notes() {
                        Some(id)
                    } else {
                        None
                    }
                });

                // Extend the parent's path by this step
                let mut new_transforms = node.transforms.clone();
                new_transforms.push(transform);

                let mut new_triads = node.triads.clone();
                new_triads.push(next_triad);

                let mut new_edge_ids = node.edge_ids.clone();
                new_edge_ids.push(next_edge_id);

                // Goal test happens at generation time, not when the node is popped
                if next_triad.contains(&target_pitch) {
                    // Summarise the finished path
                    let features = self.analyze_path_features(&new_triads);

                    // The path vectors are cloned because the node is still queued below
                    let path = Path {
                        transforms: new_transforms.clone(),
                        triads: new_triads.clone(),
                        edge_ids: new_edge_ids.clone(),
                        cost: new_cost,
                        features,
                    };

                    result_paths.push(path);

                    // Early exit once enough paths have been collected
                    if result_paths.len() >= self.max_paths {
                        // Cheapest paths first
                        result_paths.sort_by_key(|p| p.cost);

                        // Remember the result for later calls
                        self.cache
                            .insert([start_edge.0, 0, 0], target_pitch, result_paths.clone());

                        return result_paths;
                    }
                }

                // Queue the successor (goal triads included) for further expansion.
                // `BinaryHeap` is a max-heap, so the estimated total cost is negated;
                // assumes `SearchNode` orders by `priority` — TODO confirm
                let h = improved_heuristic(&next_triad, target_pitch);
                let priority = -((new_cost as i32) + (h as i32));

                let next_node = SearchNode {
                    priority,
                    cost: new_cost,
                    triad: next_triad,
                    transforms: new_transforms,
                    triads: new_triads,
                    edge_ids: new_edge_ids,
                };

                open_set.push(next_node);
            }
        }

        // Search space exhausted: cheapest paths first
        result_paths.sort_by_key(|p| p.cost);

        // Remember the (possibly empty) result for later calls
        self.cache
            .insert([start_edge.0, 0, 0], target_pitch, result_paths.clone());

        result_paths
    }
231    /// Search for paths between two specific edges in the tonnetz
232    pub fn find_paths_between_edges(&mut self, start_edge: EdgeId, goal_edge: EdgeId) -> Vec<Path> {
233        // Check cache first
234        if let Some(paths) = self.cache.get(&[start_edge.0, goal_edge.0, 1], 0).cloned() {
235            return paths;
236        }
237
238        // Get the starting and goal triads
239        let start_triad = match self.tonnetz.get_triad(start_edge) {
240            Some(triad) => triad.clone(),
241            None => return Vec::new(),
242        };
243
244        // let goal_triad = match self.tonnetz.get_triad(goal_edge) {
245        //     Some(triad) => triad.clone(),
246        //     None => return Vec::new(),
247        // };
248
249        // Check if start and goal are the same
250        if start_edge == goal_edge {
251            let features = PathFeatures::default();
252            let path = Path {
253                transforms: Vec::new(),
254                triads: vec![start_triad],
255                edge_ids: vec![Some(start_edge)],
256                cost: 0,
257                features,
258            };
259
260            // Cache the result
261            self.cache
262                .insert([start_edge.0, goal_edge.0, 1], 0, vec![path.clone()]);
263
264            return vec![path];
265        }
266
267        // Use BFS with iterative deepening to find paths to goal edge
268        let mut all_paths = Vec::new();
269
270        // Iterative deepening
271        for depth in 1..=self.max_depth {
272            let paths = self.bfs_between_edges(start_edge, goal_edge, depth);
273
274            if !paths.is_empty() {
275                all_paths = paths;
276                break;
277            }
278        }
279
280        // Sort by cost and limit to max_paths
281        all_paths.sort_by_key(|p| p.cost);
282        if all_paths.len() > self.max_paths {
283            all_paths.truncate(self.max_paths);
284        }
285
286        // Cache the result
287        self.cache
288            .insert([start_edge.0, goal_edge.0, 1], 0, all_paths.clone());
289
290        all_paths
291    }
292
293    /// BFS to find paths between two edges with a depth limit
294    fn bfs_between_edges(
295        &self,
296        start_edge: EdgeId,
297        goal_edge: EdgeId,
298        max_depth: usize,
299    ) -> Vec<Path> {
300        let mut result_paths = Vec::new();
301
302        // Get the starting triad
303        let start_triad = match self.tonnetz.get_triad(start_edge) {
304            Some(triad) => triad.clone(),
305            None => return Vec::new(),
306        };
307
308        // Initialize BFS queue
309        let mut queue = VecDeque::new();
310        queue.push_back((
311            start_triad.clone(),    // Current triad
312            Vec::<LPR>::new(),      // Transformation path
313            vec![start_triad],      // Triad history
314            vec![Some(start_edge)], // Edge IDs
315            0usize,                 // Current depth
316        ));
317
318        // Track visited triads at each depth
319        let mut visited = HashMap::<[usize; 3], HashSet<usize>>::new();
320        visited.entry(start_triad.notes()).or_default().insert(0);
321
322        while let Some((current_triad, transforms, triads, edge_ids, depth)) = queue.pop_front() {
323            // If we've reached max depth, skip this path
324            if depth >= max_depth {
325                continue;
326            }
327
328            // Try each transformation
329            for transform in LPR::iter() {
330                // Apply transformation
331                let next_triad = current_triad.transform(transform);
332                let next_depth = depth + 1;
333
334                // Check if this triad+depth combination has been visited before
335                let depths = visited.entry(next_triad.notes()).or_default();
336                if depths.contains(&next_depth) {
337                    continue;
338                }
339
340                // Mark as visited at this depth
341                depths.insert(next_depth);
342
343                // Find edge ID if this triad exists in the tonnetz
344                let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
345                    if facet.notes() == next_triad.notes() {
346                        Some(id)
347                    } else {
348                        None
349                    }
350                });
351
352                // Create new path components
353                let mut new_transforms = transforms.clone();
354                new_transforms.push(transform);
355
356                let mut new_triads = triads.clone();
357                new_triads.push(next_triad);
358
359                let mut new_edge_ids = edge_ids.clone();
360                new_edge_ids.push(next_edge_id);
361
362                // Check if we've reached the goal edge
363                if next_edge_id == Some(goal_edge) {
364                    // Calculate path features
365                    let features = self.analyze_path_features(&new_triads);
366                    let cost = features.distance + new_transforms.len();
367
368                    // Create path
369                    let path = Path {
370                        transforms: new_transforms,
371                        triads: new_triads,
372                        edge_ids: new_edge_ids,
373                        cost,
374                        features,
375                    };
376
377                    result_paths.push(path);
378
379                    // If we've found max_paths, return early
380                    if result_paths.len() >= self.max_paths {
381                        return result_paths;
382                    }
383                } else if next_depth < max_depth {
384                    // If we haven't reached the goal and haven't reached max depth,
385                    // add to queue for further exploration
386                    queue.push_back((
387                        next_triad,
388                        new_transforms,
389                        new_triads,
390                        new_edge_ids,
391                        next_depth,
392                    ));
393                }
394            }
395        }
396
397        result_paths
398    }
399
400    /// Find paths from a specified edge by continuing search from a given state
401    /// Used for parallel search implementations
402    pub fn search_from(
403        &self,
404        start_triad: Triad,
405        target_pitch: usize,
406        transforms: Vec<LPR>,
407        triads: Vec<Triad>,
408        edge_ids: Vec<Option<EdgeId>>,
409        max_paths: usize,
410        remaining_depth: usize,
411    ) -> Vec<Path> {
412        let mut result_paths = Vec::new();
413
414        // Check if we're already at a triad containing the target pitch
415        if start_triad.contains(&target_pitch) {
416            let features = self.analyze_path_features(&triads);
417            let cost = features.distance + transforms.len();
418
419            result_paths.push(Path {
420                transforms,
421                triads,
422                edge_ids,
423                cost,
424                features,
425            });
426
427            return result_paths;
428        }
429
430        // If we've reached max depth, return empty results
431        if remaining_depth == 0 {
432            return result_paths;
433        }
434
435        // Use BFS for continuation of search
436        let mut queue = VecDeque::new();
437        queue.push_back((
438            start_triad.clone(),
439            transforms.clone(),
440            triads.clone(),
441            edge_ids.clone(),
442            0usize,
443        ));
444
445        // Track visited triads at each depth
446        let mut visited = HashMap::<[usize; 3], HashSet<usize>>::new();
447        visited.entry(start_triad.notes()).or_default().insert(0);
448
449        while let Some((
450            current_triad,
451            current_transforms,
452            current_triads,
453            current_edge_ids,
454            depth,
455        )) = queue.pop_front()
456        {
457            // If we've reached max depth, skip this path
458            if depth >= remaining_depth {
459                continue;
460            }
461
462            // Try each transformation
463            for transform in LPR::iter() {
464                // Apply transformation
465                let next_triad = current_triad.transform(transform);
466                let next_depth = depth + 1;
467
468                // Check if this triad+depth combination has been visited before
469                let depths = visited.entry(next_triad.notes()).or_default();
470                if depths.contains(&next_depth) {
471                    continue;
472                }
473
474                // Mark as visited at this depth
475                depths.insert(next_depth);
476
477                // Find edge ID if this triad exists in the tonnetz
478                let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
479                    if facet.notes() == next_triad.notes() {
480                        Some(id)
481                    } else {
482                        None
483                    }
484                });
485
486                // Create new path components
487                let mut new_transforms = current_transforms.clone();
488                new_transforms.push(transform);
489
490                let mut new_triads = current_triads.clone();
491                new_triads.push(next_triad);
492
493                let mut new_edge_ids = current_edge_ids.clone();
494                new_edge_ids.push(next_edge_id);
495
496                // Check if this triad contains the target pitch
497                if next_triad.contains(&target_pitch) {
498                    // Calculate path features
499                    let features = self.analyze_path_features(&new_triads);
500                    let cost = features.distance + new_transforms.len();
501
502                    // Create path
503                    let path = Path {
504                        transforms: new_transforms,
505                        triads: new_triads,
506                        edge_ids: new_edge_ids,
507                        cost,
508                        features,
509                    };
510
511                    result_paths.push(path);
512
513                    // If we've found max_paths, return early
514                    if result_paths.len() >= max_paths {
515                        result_paths.sort_by_key(|p| p.cost);
516                        return result_paths;
517                    }
518                } else if next_depth < remaining_depth {
519                    // If we haven't reached the target and haven't reached max depth,
520                    // add to queue for further exploration
521                    queue.push_back((
522                        next_triad,
523                        new_transforms,
524                        new_triads,
525                        new_edge_ids,
526                        next_depth,
527                    ));
528                }
529            }
530        }
531
532        // Sort and return paths
533        result_paths.sort_by_key(|p| p.cost);
534        result_paths
535    }
536
537    /// Run searches in parallel from initial transformations
538    #[cfg(feature = "rayon")]
539    pub fn find_paths_parallel(&self, start_edge: EdgeId, target_pitch: usize) -> Vec<Path> {
540        use rayon::prelude::*;
541
542        // Get the starting triad
543        let start_triad = match self.tonnetz.get_triad(start_edge) {
544            Some(triad) => triad.clone(),
545            None => return Vec::new(),
546        };
547
548        // Check if starting triad already contains the target pitch
549        if start_triad.contains(&target_pitch) {
550            let features = PathFeatures::default();
551            let path = Path {
552                transforms: Vec::new(),
553                triads: vec![start_triad],
554                edge_ids: vec![Some(start_edge)],
555                cost: 0,
556                features,
557            };
558
559            return vec![path];
560        }
561
562        // Initial transformations for parallel searches
563        let results: Vec<Vec<Path>> = LPR::iter()
564            .par_bridge()
565            .filter_map(|transform| {
566                // Try applying the transformation
567                match transform.try_apply(&start_triad) {
568                    Ok(next_triad) => {
569                        // Find edge ID if it exists
570                        let next_edge_id = self.tonnetz.triads.iter().find_map(|(&id, facet)| {
571                            if facet.notes() == next_triad.notes() {
572                                Some(id)
573                            } else {
574                                None
575                            }
576                        });
577
578                        let transforms = vec![transform];
579                        let triads = vec![start_triad.clone(), next_triad.clone()];
580                        let edge_ids = vec![Some(start_edge), next_edge_id];
581
582                        // Start search from this branch
583                        Some(self.search_from(
584                            next_triad,
585                            target_pitch,
586                            transforms,
587                            triads,
588                            edge_ids,
589                            self.max_paths,
590                            self.max_depth - 1,
591                        ))
592                    }
593                    Err(_) => None,
594                }
595            })
596            .collect();
597
598        // Combine and sort results
599        let mut all_paths = Vec::new();
600        for paths in results {
601            all_paths.extend(paths);
602        }
603
604        // Sort by cost and take max_paths
605        all_paths.sort_by_key(|p| p.cost);
606        if all_paths.len() > self.max_paths {
607            all_paths.truncate(self.max_paths);
608        }
609
610        all_paths
611    }
612
613    /// Analyze musical features of a transformation path
614    fn analyze_path_features(&self, triads: &[Triad]) -> PathFeatures {
615        let mut features = PathFeatures::default();
616
617        // Count transforms (infer from triad progression)
618        let mut transform_counts = HashMap::new();
619        let mut modality_changes = 0;
620        let mut voice_leading_distance = 0;
621
622        // Analyze modality changes and voice leading
623        for i in 1..triads.len() {
624            let prev = &triads[i - 1];
625            let curr = &triads[i];
626
627            // Determine which transform was applied (approximate)
628            let transform = if prev.is_major() != curr.is_major() {
629                // Parallel transform changes mode while preserving root
630                if prev.root() == curr.root() {
631                    LPR::Parallel
632                }
633                // Relative transform preserves two notes()
634                else if prev.common_tones(curr).len() == 2 {
635                    LPR::Relative
636                }
637                // Leading transform if no better match
638                else {
639                    LPR::Leading
640                }
641            } else {
642                // If mode is preserved, likely Leading transform
643                LPR::Leading
644            };
645
646            *transform_counts.entry(transform).or_insert(0) += 1;
647
648            // Check for modality change
649            if prev.is_major() != curr.is_major() {
650                modality_changes += 1;
651            }
652
653            // Calculate voice leading distance (semitone movement between triads)
654            for prev_note in prev.notes() {
655                // Find the minimum distance to move from prev_note to any note in curr
656                let min_distance = curr
657                    .notes()
658                    .iter()
659                    .map(|&curr_note| {
660                        let dist = (curr_note as isize - prev_note as isize).abs().pmod();
661                        std::cmp::min(dist, 12 - dist) as usize
662                    })
663                    .min()
664                    .unwrap_or(0);
665
666                voice_leading_distance += min_distance;
667            }
668        }
669
670        features.transform_counts = transform_counts;
671        features.modality_changes = modality_changes;
672        features.distance = voice_leading_distance;
673
674        features
675    }
676}