active_call/synthesis/supertonic.rs

1use crate::offline::get_offline_models;
2use crate::synthesis::{SynthesisClient, SynthesisEvent, SynthesisOption, SynthesisType};
3use anyhow::{Result, anyhow};
4use async_trait::async_trait;
5use bytes::Bytes;
6use futures::stream::BoxStream;
7use tokio::sync::mpsc;
8use tokio_stream::wrappers::UnboundedReceiverStream;
9use tokio_util::sync::CancellationToken;
10use tracing::{debug, warn};
11use rubato::{Resampler, SincFixedIn, SincInterpolationParameters, SincInterpolationType, WindowFunction};
12
13pub struct SupertonicTtsClient {
14    voice_style: String,
15    speed: f32,
16    target_rate: i32,
17    tx: Option<mpsc::UnboundedSender<(Option<usize>, Result<SynthesisEvent>)>>,
18    token: CancellationToken,
19}
20
21impl SupertonicTtsClient {
22    pub fn create(_streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>> {
23        let voice_style = option.speaker.clone().unwrap_or_else(|| "M1".to_string());
24        let speed = option.speed.unwrap_or(1.0);
25        let target_rate = option.samplerate.unwrap_or(16000);
26
27        Ok(Box::new(Self {
28            voice_style,
29            speed,
30            target_rate,
31            tx: None,
32            token: CancellationToken::new(),
33        }))
34    }
35
36    fn ensure_models_initialized() -> Result<()> {
37        if get_offline_models().is_none() {
38            anyhow::bail!(
39                "Offline models not initialized. Please call init_offline_models() first."
40            );
41        }
42        Ok(())
43    }
44
45    async fn synthesize_text(&self, text: String, cmd_seq: Option<usize>) -> Result<()> {
46        let Some(tx) = self.tx.as_ref() else {
47            return Ok(());
48        };
49
50        let models =
51            get_offline_models().ok_or_else(|| anyhow!("offline models not initialized"))?;
52
53        let voice_style = self.voice_style.clone();
54        let speed = self.speed;
55        let target_rate = self.target_rate;
56        let tx_clone = tx.clone();
57        // Supertonic: en is hardcoded for now, or detect from text?
58        let language = "en".to_string();
59
60        let tts_arc = models.get_supertonic().await?;
61
62        // Run synthesis in blocking task
63        tokio::task::spawn_blocking(move || {
64            // Need write lock to use synthesize if it takes &mut self
65            // But if synthesize takes &mut self, blocking_write is needed.
66            let mut guard = tts_arc.blocking_write();
67
68            if let Some(tts) = guard.as_mut() {
69                debug!(
70                    text = %text,
71                    voice = %voice_style,
72                    speed = speed,
73                    target_rate = target_rate,
74                    "Calling Supertonic TTS synthesis"
75                );
76
77                match tts.synthesize(&text, &language, Some(&voice_style), Some(speed)) {
78                    Ok(mut samples) => {
79                        if !samples.is_empty() {
80                            // Resample if needed
81                            if tts.sample_rate() != target_rate {
82                                let ratio = target_rate as f64 / tts.sample_rate() as f64;
83                                let chunk_size = 1024;
84                                let params = SincInterpolationParameters {
85                                    sinc_len: 256,
86                                    f_cutoff: 0.95,
87                                    interpolation: SincInterpolationType::Linear,
88                                    window: WindowFunction::BlackmanHarris2,
89                                    oversampling_factor: 128,
90                                };
91                                match SincFixedIn::<f32>::new(
92                                    ratio,
93                                    2.0,
94                                    params,
95                                    chunk_size,
96                                    1,
97                                ) {
98                                    Ok(mut resampler) => {
99                                        let mut output = Vec::with_capacity((samples.len() as f64 * ratio + 100.0) as usize);
100                                        let mut buffer = vec![vec![0.0; chunk_size]; 1]; 
101                                        
102                                        // Pad input
103                                        let padding = if samples.len() % chunk_size != 0 {
104                                            chunk_size - (samples.len() % chunk_size)
105                                        } else {
106                                            0
107                                        };
108                                        for _ in 0..padding {
109                                            samples.push(0.0);
110                                        }
111
112                                        for chunk in samples.chunks(chunk_size) {
113                                            buffer[0].copy_from_slice(chunk);
114                                            if let Ok(out) = resampler.process(&buffer, None) {
115                                               output.extend_from_slice(&out[0]);
116                                            }
117                                        }
118                                        samples = output;
119                                    }
120                                    Err(e) => {
121                                         warn!(error = %e, "Failed to create resampler, using original");
122                                    }
123                                }
124                            }
125
126                            // Convert f32 samples to PCM bytes (i16)
127                            let mut bytes = Vec::with_capacity(samples.len() * 2);
128                            for sample in samples {
129                                // Clip and convert
130                                let s = (sample * 32768.0).max(-32768.0).min(32767.0) as i16;
131                                bytes.extend_from_slice(&s.to_le_bytes());
132                            }
133
134                            // Send AudioChunk
135                            let _ = tx_clone.send((
136                                cmd_seq,
137                                Ok(SynthesisEvent::AudioChunk(Bytes::from(bytes))),
138                            ));
139
140                            // Send Finished
141                            let _ = tx_clone.send((cmd_seq, Ok(SynthesisEvent::Finished)));
142                        } else {
143                            warn!("Supertonic produced empty audio");
144                            let _ = tx_clone.send((cmd_seq, Ok(SynthesisEvent::Finished)));
145                        }
146                    }
147                    Err(e) => {
148                        warn!(error = %e, "Supertonic inference failed");
149                        let _ = tx_clone.send((cmd_seq, Err(anyhow!("Synthesis failed: {}", e))));
150                    }
151                }
152            } else {
153                warn!("Supertonic TTS not initialized");
154                let _ = tx_clone.send((cmd_seq, Err(anyhow!("TTS not initialized"))));
155            }
156        })
157        .await
158        .map_err(|e| anyhow!("task join error: {}", e))?;
159
160        Ok(())
161    }
162}
163
164#[async_trait]
165impl SynthesisClient for SupertonicTtsClient {
166    fn provider(&self) -> SynthesisType {
167        SynthesisType::Supertonic
168    }
169
170    async fn start(
171        &mut self,
172    ) -> Result<BoxStream<'static, (Option<usize>, Result<SynthesisEvent>)>> {
173        Self::ensure_models_initialized()?;
174
175        let (tx, rx) = mpsc::unbounded_channel();
176        self.tx = Some(tx);
177
178        // Initialize TTS if needed
179        let models =
180            get_offline_models().ok_or_else(|| anyhow!("offline models not initialized"))?;
181        models.init_supertonic().await?;
182
183        debug!(
184            "SupertonicTtsClient started with voice: {}",
185            self.voice_style
186        );
187
188        Ok(Box::pin(UnboundedReceiverStream::new(rx)))
189    }
190
191    async fn synthesize(
192        &mut self,
193        text: &str,
194        cmd_seq: Option<usize>,
195        _option: Option<SynthesisOption>,
196    ) -> Result<()> {
197        self.synthesize_text(text.to_string(), cmd_seq).await
198    }
199
200    async fn stop(&mut self) -> Result<()> {
201        self.token.cancel();
202        self.tx = None;
203        Ok(())
204    }
205}