use alloc::vec::Vec;
use serde::{Deserialize, Serialize};
use crate::formant::VowelTarget;
use crate::phoneme::{Phoneme, phoneme_formants};
use crate::voice::VoiceProfile;
/// A formant target pinned to a sample position on a trajectory.
///
/// `TrajectoryPlanner::plan` emits one of these per phoneme midpoint, plus
/// start/end anchors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormantKeypoint {
    /// Sample index at which `target` applies (counted from sample 0 of the planned utterance).
    pub time: usize,
    /// Formant frequencies/bandwidths to reach at `time`; `plan` stores voice-scaled targets here.
    pub target: VowelTarget,
    /// Coarticulation resistance of the source phoneme. In
    /// `TrajectoryPlanner::formants_at`, higher resistance on the outer keypoints of
    /// a segment reduces the Catmull-Rom contribution for that segment.
    pub resistance: f32,
}
/// Piecewise formant trajectory over an utterance.
///
/// Built with `TrajectoryPlanner::plan` and sampled with `TrajectoryPlanner::formants_at`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrajectoryPlanner {
    /// Keypoints in non-decreasing `time` order as produced by `plan` (empty for an empty plan).
    keypoints: Vec<FormantKeypoint>,
    /// Total planned length in samples (sum of the per-phoneme sample counts).
    total_samples: usize,
}
impl TrajectoryPlanner {
    /// Builds a formant trajectory for a sequence of phonemes.
    ///
    /// Phoneme `i` occupies `(durations[i] * sample_rate) as usize` samples. Phonemes are
    /// laid end to end starting at sample 0, and `total_samples()` is the sum of those
    /// per-phoneme counts.
    ///
    /// A keypoint is placed at the midpoint of every phoneme. Two anchors are added: one
    /// at sample 0 carrying the first phoneme's target, and one at `total_samples()`
    /// carrying the last phoneme's target. A non-empty input therefore yields
    /// `phonemes.len() + 2` keypoints in non-decreasing time order.
    ///
    /// Each target is `phoneme_formants(phoneme)` passed through
    /// `voice.apply_formant_scale`. Each resistance is
    /// `phoneme.coarticulation_resistance()`.
    ///
    /// An empty `phonemes` slice yields a planner with no keypoints and zero length.
    ///
    /// # Panics
    ///
    /// Panics if `phonemes` and `durations` have different lengths.
    #[must_use]
    pub fn plan(
        phonemes: &[Phoneme],
        durations: &[f32],
        voice: &VoiceProfile,
        sample_rate: f32,
    ) -> Self {
        assert_eq!(
            phonemes.len(),
            durations.len(),
            "TrajectoryPlanner::plan requires exactly one duration per phoneme"
        );
        let (Some(first), Some(last)) = (phonemes.first(), phonemes.last()) else {
            return Self {
                keypoints: Vec::new(),
                total_samples: 0,
            };
        };

        // Cumulative phoneme boundaries in samples: boundaries[i]..boundaries[i + 1] is
        // phoneme i. Each duration is truncated on its own. The float-to-int `as` cast
        // saturates (negative/NaN -> 0, huge/inf -> usize::MAX), so the running sum also
        // saturates instead of overflowing.
        let mut boundaries = Vec::with_capacity(phonemes.len() + 1);
        let mut offset = 0usize;
        boundaries.push(offset);
        for &dur in durations {
            offset = offset.saturating_add((dur * sample_rate) as usize);
            boundaries.push(offset);
        }
        let total_samples = offset;

        let keypoint_for = |phoneme: &Phoneme, time: usize| FormantKeypoint {
            time,
            target: voice.apply_formant_scale(&phoneme_formants(phoneme)),
            resistance: phoneme.coarticulation_resistance(),
        };

        let mut keypoints = Vec::with_capacity(phonemes.len() + 2);
        keypoints.push(keypoint_for(first, 0));
        keypoints.extend(
            phonemes
                .iter()
                .zip(boundaries.windows(2))
                // `start + (end - start) / 2` equals `(start + end) / 2` but cannot
                // overflow; `end >= start` because the boundaries never decrease.
                .map(|(phoneme, span)| keypoint_for(phoneme, span[0] + (span[1] - span[0]) / 2)),
        );
        keypoints.push(keypoint_for(last, total_samples));

        Self {
            keypoints,
            total_samples,
        }
    }

    /// Returns the interpolated formant target at `sample`.
    ///
    /// * With no keypoints, returns a fixed neutral target
    ///   (`VowelTarget::new(500.0, 1500.0, 2500.0, 3300.0, 3750.0)`).
    /// * With exactly one keypoint, returns a clone of its target.
    /// * Otherwise, it locates the keypoint pair bracketing `sample` and computes a
    ///   normalized position `t`, clamped to `[0, 1]`. Samples past the last keypoint
    ///   therefore hold the final target.
    ///
    /// Interior segments have a keypoint on each side. For these, the weight is
    /// `1 - mean(outer resistances)`. When that weight exceeds 0.05, plain linear
    /// interpolation is blended toward a Catmull-Rom spline by that weight.
    ///
    /// Every other segment uses linear interpolation with `t` eased through
    /// `hisab::calc::ease_in_out_smooth`.
    #[must_use]
    pub fn formants_at(&self, sample: usize) -> VowelTarget {
        match self.keypoints.as_slice() {
            // Nothing planned: fall back to a fixed neutral target.
            [] => return VowelTarget::new(500.0, 1500.0, 2500.0, 3300.0, 3750.0),
            [only] => return only.target.clone(),
            _ => {}
        }

        let seg = self.find_segment(sample);
        // `find_segment` caps `seg` at `len - 2`, so `seg + 1` is always in bounds here.
        let k0 = &self.keypoints[seg];
        let k1 = &self.keypoints[seg + 1];
        // `max(1)` guards against zero-length segments (coincident keypoint times).
        let span = k1.time.saturating_sub(k0.time).max(1);
        let t = (sample.saturating_sub(k0.time) as f32 / span as f32).clamp(0.0, 1.0);

        // Catmull-Rom needs a neighbour on both sides of the segment for its tangents.
        if seg > 0 && seg + 2 < self.keypoints.len() {
            let km1 = &self.keypoints[seg - 1];
            let k2 = &self.keypoints[seg + 2];
            // Resistant neighbours pull the curve back toward straight-line motion.
            let catmull_weight = 1.0 - (km1.resistance + k2.resistance) * 0.5;
            if catmull_weight > 0.05 {
                let linear = VowelTarget::interpolate(&k0.target, &k1.target, t);
                let catmull =
                    catmull_rom_vowel(&km1.target, &k0.target, &k1.target, &k2.target, t);
                return blend_targets(&linear, &catmull, catmull_weight);
            }
        }

        let t_smooth = hisab::calc::ease_in_out_smooth(t);
        VowelTarget::interpolate(&k0.target, &k1.target, t_smooth)
    }

    /// Scales every keypoint's coarticulation resistance by `1 / rate`, clamped to `[0, 1]`.
    ///
    /// A faster rate (`rate > 1`) lowers resistance. That raises the Catmull-Rom weight
    /// used by `formants_at` on interior segments. A slower rate raises resistance.
    /// Keypoint times are not changed.
    ///
    /// A rate within `f32::EPSILON` of `1.0` is a no-op. Zero, negative and non-finite
    /// rates are also ignored, because `1 / rate` would be infinite, negative or NaN.
    pub fn apply_speaking_rate(&mut self, rate: f32) {
        if !rate.is_finite() || rate <= 0.0 || (rate - 1.0).abs() < f32::EPSILON {
            return;
        }
        let factor = 1.0 / rate;
        for kp in &mut self.keypoints {
            kp.resistance = (kp.resistance * factor).clamp(0.0, 1.0);
        }
    }

    /// Total planned length in samples.
    #[must_use]
    pub fn total_samples(&self) -> usize {
        self.total_samples
    }

    /// Number of keypoints (0 for an empty plan, otherwise phoneme count + 2 when built by `plan`).
    #[must_use]
    pub fn num_keypoints(&self) -> usize {
        self.keypoints.len()
    }

    /// Read-only view of the keypoints in time order.
    #[must_use]
    pub fn keypoints(&self) -> &[FormantKeypoint] {
        &self.keypoints
    }

    /// Returns the index `i` of the segment `keypoints[i]..=keypoints[i + 1]` for `sample`.
    ///
    /// This is the last keypoint whose `time <= sample` (0 if there is none). It is capped
    /// at `len - 2` so that a successor exists whenever there are at least two keypoints.
    ///
    /// Assumes keypoint times are non-decreasing, as `plan` guarantees.
    fn find_segment(&self, sample: usize) -> usize {
        let last_not_after = self
            .keypoints
            .partition_point(|kp| kp.time <= sample)
            .saturating_sub(1);
        last_not_after.min(self.keypoints.len().saturating_sub(2))
    }
}
/// Evaluates a uniform Catmull-Rom spline through four formant targets at parameter `t`.
///
/// All five formant frequencies and five bandwidths are interpolated independently.
/// The curve passes through `p1` at `t = 0` and through `p2` at `t = 1`. `p0` and `p3`
/// only shape the tangents.
///
/// Nothing is clamped, so values between the endpoints may overshoot or undershoot
/// the range spanned by `p1`/`p2`.
fn catmull_rom_vowel(
    p0: &VowelTarget,
    p1: &VowelTarget,
    p2: &VowelTarget,
    p3: &VowelTarget,
    t: f32,
) -> VowelTarget {
    let t_sq = t * t;
    let t_cu = t_sq * t;

    // Standard Catmull-Rom polynomial, summed term by term in ascending powers of t.
    let spline = |before: f32, start: f32, end: f32, after: f32| -> f32 {
        let constant = 2.0 * start;
        let linear = (-before + end) * t;
        let quadratic = (2.0 * before - 5.0 * start + 4.0 * end - after) * t_sq;
        let cubic = (-before + 3.0 * start - 3.0 * end + after) * t_cu;
        0.5 * (constant + linear + quadratic + cubic)
    };

    let frequencies = [
        spline(p0.f1, p1.f1, p2.f1, p3.f1),
        spline(p0.f2, p1.f2, p2.f2, p3.f2),
        spline(p0.f3, p1.f3, p2.f3, p3.f3),
        spline(p0.f4, p1.f4, p2.f4, p3.f4),
        spline(p0.f5, p1.f5, p2.f5, p3.f5),
    ];
    let bandwidths = [
        spline(p0.b1, p1.b1, p2.b1, p3.b1),
        spline(p0.b2, p1.b2, p2.b2, p3.b2),
        spline(p0.b3, p1.b3, p2.b3, p3.b3),
        spline(p0.b4, p1.b4, p2.b4, p3.b4),
        spline(p0.b5, p1.b5, p2.b5, p3.b5),
    ];
    VowelTarget::with_bandwidths(frequencies, bandwidths)
}
/// Mixes `from` toward `to` by `weight` via `VowelTarget::interpolate`.
///
/// `weight` is forwarded unchanged; no clamping happens here. In this file `weight`
/// is used as the lerp parameter, presumably with `0` giving `from` and `1` giving
/// `to` (confirm against `VowelTarget::interpolate`).
fn blend_targets(from: &VowelTarget, to: &VowelTarget, weight: f32) -> VowelTarget {
    let blended = VowelTarget::interpolate(from, to, weight);
    blended
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::phoneme::Phoneme;

    #[test]
    fn test_empty_plan() {
        let plan = TrajectoryPlanner::plan(&[], &[], &VoiceProfile::new_male(), 44100.0);
        assert_eq!(plan.total_samples(), 0);
        assert_eq!(plan.num_keypoints(), 0);
    }

    #[test]
    fn test_empty_plan_returns_neutral_target() {
        let plan = TrajectoryPlanner::plan(&[], &[], &VoiceProfile::new_male(), 44100.0);
        let target = plan.formants_at(123);
        assert!((target.f1 - 500.0).abs() < 1e-3);
        assert!((target.f2 - 1500.0).abs() < 1e-3);
    }

    #[test]
    fn test_single_phoneme_plan() {
        let voice = VoiceProfile::new_male();
        let plan = TrajectoryPlanner::plan(&[Phoneme::VowelA], &[0.1], &voice, 44100.0);
        assert_eq!(plan.total_samples(), 4410);
        assert_eq!(plan.num_keypoints(), 3);
    }

    #[test]
    fn test_three_phoneme_plan() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelA, Phoneme::NasalN, Phoneme::VowelI];
        let durations = [0.1, 0.06, 0.1];
        let plan = TrajectoryPlanner::plan(&phonemes, &durations, &voice, 44100.0);
        assert_eq!(plan.num_keypoints(), 5);
        assert_eq!(plan.total_samples(), (0.26 * 44100.0) as usize);
    }

    #[test]
    fn test_keypoint_layout() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelA, Phoneme::NasalN, Phoneme::VowelI];
        let durations = [0.1, 0.06, 0.1];
        let plan = TrajectoryPlanner::plan(&phonemes, &durations, &voice, 44100.0);
        let kps = plan.keypoints();
        assert_eq!(kps.first().map(|k| k.time), Some(0));
        assert_eq!(kps.last().map(|k| k.time), Some(plan.total_samples()));
        assert!(kps.windows(2).all(|w| w[0].time <= w[1].time));
        // The first phoneme spans 4410 samples, so its midpoint keypoint sits at 2205.
        assert_eq!(kps[1].time, 2205);
    }

    #[test]
    fn test_formants_at_endpoints() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelA, Phoneme::VowelI];
        let durations = [0.1, 0.1];
        let plan = TrajectoryPlanner::plan(&phonemes, &durations, &voice, 44100.0);
        let target_a = voice.apply_formant_scale(&phoneme_formants(&Phoneme::VowelA));
        let target_i = voice.apply_formant_scale(&phoneme_formants(&Phoneme::VowelI));
        let at_start = plan.formants_at(0);
        assert!((at_start.f1 - target_a.f1).abs() < 1.0);
        let at_end = plan.formants_at(plan.total_samples());
        assert!((at_end.f1 - target_i.f1).abs() < 1.0);
    }

    #[test]
    fn test_formants_past_end_hold_last_target() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelA, Phoneme::VowelI];
        let durations = [0.1, 0.1];
        let plan = TrajectoryPlanner::plan(&phonemes, &durations, &voice, 44100.0);
        let target_i = voice.apply_formant_scale(&phoneme_formants(&Phoneme::VowelI));
        let beyond = plan.formants_at(plan.total_samples() + 10_000);
        assert!((beyond.f1 - target_i.f1).abs() < 1.0);
    }

    #[test]
    fn test_formants_at_midpoint_blends() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelA, Phoneme::VowelI];
        let durations = [0.1, 0.1];
        let plan = TrajectoryPlanner::plan(&phonemes, &durations, &voice, 44100.0);
        let target_a = voice.apply_formant_scale(&phoneme_formants(&Phoneme::VowelA));
        let target_i = voice.apply_formant_scale(&phoneme_formants(&Phoneme::VowelI));
        let boundary = plan.total_samples() / 2;
        let at_boundary = plan.formants_at(boundary);
        let f1_a = target_a.f1;
        let f1_i = target_i.f1;
        let f1_mid = at_boundary.f1;
        assert!(
            (f1_mid > f1_i.min(f1_a) - 10.0) && (f1_mid < f1_a.max(f1_i) + 10.0),
            "boundary F1 should be between /a/ and /i/: got {f1_mid}, range [{f1_i}, {f1_a}]"
        );
    }

    #[test]
    fn test_catmull_rom_influence() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelSchwa, Phoneme::VowelA, Phoneme::VowelSchwa];
        let durations = [0.1, 0.1, 0.1];
        let plan = TrajectoryPlanner::plan(&phonemes, &durations, &voice, 44100.0);
        for sample in (0..plan.total_samples()).step_by(100) {
            let target = plan.formants_at(sample);
            assert!(target.f1.is_finite());
            assert!(target.f2.is_finite());
            assert!(target.f1 > 0.0);
        }
    }

    #[test]
    fn test_speaking_rate_scales_resistance() {
        let voice = VoiceProfile::new_male();
        let phonemes = [Phoneme::VowelA, Phoneme::NasalN, Phoneme::VowelI];
        let mut plan = TrajectoryPlanner::plan(&phonemes, &[0.1, 0.06, 0.1], &voice, 44100.0);
        let before: Vec<f32> = plan.keypoints().iter().map(|k| k.resistance).collect();

        plan.apply_speaking_rate(1.0);
        let after_unity: Vec<f32> = plan.keypoints().iter().map(|k| k.resistance).collect();
        assert_eq!(before, after_unity);

        plan.apply_speaking_rate(2.0);
        for (kp, &orig) in plan.keypoints().iter().zip(&before) {
            let expected = (orig * 0.5).clamp(0.0, 1.0);
            assert!((kp.resistance - expected).abs() < 1e-6);
        }

        plan.apply_speaking_rate(0.25);
        assert!(plan
            .keypoints()
            .iter()
            .all(|k| (0.0..=1.0).contains(&k.resistance)));
    }

    #[test]
    fn test_serde_roundtrip_planner() {
        let voice = VoiceProfile::new_male();
        let plan = TrajectoryPlanner::plan(
            &[Phoneme::VowelA, Phoneme::VowelI],
            &[0.1, 0.1],
            &voice,
            44100.0,
        );
        let json = serde_json::to_string(&plan).unwrap();
        let plan2: TrajectoryPlanner = serde_json::from_str(&json).unwrap();
        assert_eq!(plan2.num_keypoints(), plan.num_keypoints());
        assert_eq!(plan2.total_samples(), plan.total_samples());
    }
}