any_tts/traits.rs
1//! Core TTS trait and request/response types.
2
3use std::path::Path;
4
5use crate::audio::AudioSamples;
6use crate::config::TtsConfig;
7use crate::error::TtsError;
8
9/// Metadata about a loaded model.
/// Metadata about a loaded model.
///
/// Plain data; derives `PartialEq`/`Eq` so callers can compare model
/// descriptions directly (e.g. in tests or capability caches).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ModelInfo {
    /// Human-readable model name.
    pub name: String,
    /// Model version or variant.
    pub variant: String,
    /// Approximate parameter count.
    pub parameters: u64,
    /// Output audio sample rate in Hz.
    pub sample_rate: u32,
    /// Supported languages (ISO 639-1 codes or language names).
    pub languages: Vec<String>,
    /// Available voice/speaker names.
    pub voices: Vec<String>,
}
25
/// A request to synthesize speech from text.
///
/// Construct with [`SynthesisRequest::new`] and refine with the `with_*`
/// builder methods; every optional field defaults to `None`, meaning
/// "use the model's default".
#[derive(Debug, Clone)]
pub struct SynthesisRequest {
    /// The text to synthesize.
    pub text: String,
    /// Target language (ISO code or name). If `None`, auto-detect.
    pub language: Option<String>,
    /// Voice/speaker name. If `None`, use default.
    pub voice: Option<String>,
    /// Style instruction (model-dependent, e.g. "angry", "whisper").
    pub instruct: Option<String>,
    /// Maximum number of tokens to generate. If `None`, use model default.
    pub max_tokens: Option<usize>,
    /// Sampling temperature. If `None`, use model default.
    pub temperature: Option<f64>,
    /// Guidance scale for backends that support classifier-free guidance.
    pub cfg_scale: Option<f64>,
    /// Reference audio for zero-shot voice cloning.
    ///
    /// When provided, the model will attempt to match the voice characteristics
    /// of this audio instead of using a named voice. Not all models support this;
    /// models that don't will return [`TtsError::ModelError`].
    pub reference_audio: Option<ReferenceAudio>,
    /// Pre-extracted voice embedding for voice cloning.
    ///
    /// Use this to re-use an embedding extracted via [`VoiceCloning::extract_voice`]
    /// without re-processing the reference audio each time.
    ///
    /// NOTE(review): precedence when both `reference_audio` and
    /// `voice_embedding` are set is backend-defined — confirm per model.
    pub voice_embedding: Option<VoiceEmbedding>,
}
55
56impl SynthesisRequest {
57 /// Create a new synthesis request with the given text.
58 pub fn new(text: impl Into<String>) -> Self {
59 Self {
60 text: text.into(),
61 language: None,
62 voice: None,
63 instruct: None,
64 max_tokens: None,
65 temperature: None,
66 cfg_scale: None,
67 reference_audio: None,
68 voice_embedding: None,
69 }
70 }
71
72 /// Set the target language.
73 pub fn with_language(mut self, lang: impl Into<String>) -> Self {
74 self.language = Some(lang.into());
75 self
76 }
77
78 /// Set the voice/speaker.
79 pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
80 self.voice = Some(voice.into());
81 self
82 }
83
84 /// Set a style instruction.
85 pub fn with_instruct(mut self, instruct: impl Into<String>) -> Self {
86 self.instruct = Some(instruct.into());
87 self
88 }
89
90 /// Set the maximum number of tokens.
91 pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
92 self.max_tokens = Some(max_tokens);
93 self
94 }
95
96 /// Set the sampling temperature.
97 pub fn with_temperature(mut self, temperature: f64) -> Self {
98 self.temperature = Some(temperature);
99 self
100 }
101
102 /// Set the classifier-free guidance scale.
103 pub fn with_cfg_scale(mut self, cfg_scale: f64) -> Self {
104 self.cfg_scale = Some(cfg_scale);
105 self
106 }
107
108 /// Set reference audio for zero-shot voice cloning.
109 ///
110 /// The model will extract speaker characteristics from this audio and
111 /// use them to condition the synthesis. Overrides any named voice.
112 pub fn with_reference_audio(mut self, audio: ReferenceAudio) -> Self {
113 self.reference_audio = Some(audio);
114 self
115 }
116
117 /// Set a pre-extracted voice embedding.
118 ///
119 /// Useful for caching: extract once with [`VoiceCloning::extract_voice`],
120 /// then reuse across multiple synthesis calls.
121 pub fn with_voice_embedding(mut self, embedding: VoiceEmbedding) -> Self {
122 self.voice_embedding = Some(embedding);
123 self
124 }
125}
126
/// The core trait that all TTS model backends must implement.
///
/// This provides a unified interface for text-to-speech synthesis regardless
/// of the underlying model architecture. Implementors must be `Send + Sync`
/// so a loaded model can be shared across threads.
pub trait TtsModel: Send + Sync {
    /// Load model weights and initialize the model from configuration.
    ///
    /// This may involve reading safetensors files, parsing config JSON,
    /// and moving weights to the target device.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] when loading fails (e.g. missing or malformed
    /// weight/config files).
    fn load(config: TtsConfig) -> Result<Self, TtsError>
    where
        Self: Sized;

    /// Synthesize speech from a text request.
    ///
    /// Returns raw f32 PCM audio samples at the model's native sample rate
    /// (see [`TtsModel::sample_rate`]).
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] when synthesis fails; per the
    /// [`SynthesisRequest::reference_audio`] docs, backends that don't
    /// support a requested feature return [`TtsError::ModelError`].
    fn synthesize(&self, request: &SynthesisRequest) -> Result<AudioSamples, TtsError>;

    /// Return the native output sample rate in Hz (e.g. 24000).
    fn sample_rate(&self) -> u32;

    /// Return the list of supported language identifiers.
    fn supported_languages(&self) -> Vec<String>;

    /// Return the list of available voice/speaker names.
    fn supported_voices(&self) -> Vec<String>;

    /// Return metadata about this model.
    fn model_info(&self) -> ModelInfo;
}
157
158// ---------------------------------------------------------------------------
159// Voice cloning types
160// ---------------------------------------------------------------------------
161
/// Raw audio data used as a reference for voice cloning.
///
/// The model will extract speaker characteristics from this audio and use
/// them to condition speech synthesis. For best results:
///
/// - Use 3–10 seconds of clean speech (single speaker, no background noise)
/// - Match the model's native sample rate (e.g. 24 kHz for Kokoro)
///   or the library will resample automatically
///
/// # Example
///
/// ```rust
/// use any_tts::ReferenceAudio;
///
/// let audio = ReferenceAudio::new(vec![0.0f32; 24000], 24000);
/// assert_eq!(audio.duration_secs(), 1.0);
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct ReferenceAudio {
    /// Raw f32 PCM audio samples in `[-1.0, 1.0]`.
    pub samples: Vec<f32>,
    /// Sample rate of the audio in Hz.
    pub sample_rate: u32,
}

impl ReferenceAudio {
    /// Create a new reference audio from raw samples.
    pub fn new(samples: Vec<f32>, sample_rate: u32) -> Self {
        Self {
            samples,
            sample_rate,
        }
    }

    /// Duration of the reference audio in seconds.
    ///
    /// Returns `0.0` when `sample_rate` is zero instead of dividing by zero
    /// (which would yield `NaN`/`inf` for f32).
    pub fn duration_secs(&self) -> f32 {
        if self.sample_rate == 0 {
            return 0.0;
        }
        self.samples.len() as f32 / self.sample_rate as f32
    }

    /// Whether the audio contains no samples.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }
}
209
210/// An extracted voice embedding that can be saved, loaded, and reused.
211///
212/// Voice embeddings are model-specific opaque data. An embedding extracted
213/// from one model type cannot be used with a different model type.
214///
215/// # Persistence
216///
217/// Embeddings can be saved to and loaded from JSON files:
218///
219/// ```rust,ignore
220/// // Extract once
221/// let embedding = model.extract_voice(&reference)?;
222/// embedding.save("my_voice.json")?;
223///
224/// // Reuse later
225/// let embedding = VoiceEmbedding::load("my_voice.json")?;
226/// let request = SynthesisRequest::new("Hello!")
227/// .with_voice_embedding(embedding);
228/// ```
229#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
230pub struct VoiceEmbedding {
231 /// Raw embedding data as flattened f32 values.
232 data: Vec<f32>,
233 /// Shape of the embedding tensor.
234 shape: Vec<usize>,
235 /// Model type identifier (e.g. "kokoro", "qwen3-tts").
236 model_type: String,
237}
238
239impl VoiceEmbedding {
240 /// Create a new voice embedding from raw data.
241 pub fn new(data: Vec<f32>, shape: Vec<usize>, model_type: impl Into<String>) -> Self {
242 Self {
243 data,
244 shape,
245 model_type: model_type.into(),
246 }
247 }
248
249 /// Reconstruct the embedding as a candle [`Tensor`](candle_core::Tensor).
250 pub fn to_tensor(&self, device: &candle_core::Device) -> Result<candle_core::Tensor, TtsError> {
251 candle_core::Tensor::new(self.data.as_slice(), device)?
252 .reshape(self.shape.as_slice())
253 .map_err(TtsError::from)
254 }
255
256 /// The model type this embedding was extracted for.
257 pub fn model_type(&self) -> &str {
258 &self.model_type
259 }
260
261 /// Shape of the embedding tensor.
262 pub fn shape(&self) -> &[usize] {
263 &self.shape
264 }
265
266 /// Save the embedding to a JSON file.
267 pub fn save(&self, path: impl AsRef<Path>) -> Result<(), TtsError> {
268 let json = serde_json::to_string_pretty(self)?;
269 std::fs::write(path, json)?;
270 Ok(())
271 }
272
273 /// Load an embedding from a JSON file.
274 pub fn load(path: impl AsRef<Path>) -> Result<Self, TtsError> {
275 let json = std::fs::read_to_string(path)?;
276 let embedding: Self = serde_json::from_str(&json)?;
277 Ok(embedding)
278 }
279}
280
/// Trait for TTS models that support voice cloning from reference audio.
///
/// Not all models support voice cloning. Check [`VoiceCloning::supports_voice_cloning`]
/// before calling other methods.
///
/// # Example
///
/// ```rust,ignore
/// use any_tts::{VoiceCloning, ReferenceAudio, SynthesisRequest};
///
/// if model.supports_voice_cloning() {
///     let reference = ReferenceAudio::new(samples, 24000);
///     let embedding = model.extract_voice(&reference)?;
///     let request = SynthesisRequest::new("Hello in the cloned voice!")
///         .with_voice_embedding(embedding);
///     let audio = model.synthesize(&request)?;
/// }
/// ```
pub trait VoiceCloning: TtsModel {
    /// Whether voice cloning is currently available.
    ///
    /// Returns `false` if the model doesn't support voice cloning or if
    /// the required encoder weights were not found during loading.
    fn supports_voice_cloning(&self) -> bool;

    /// Extract a reusable voice embedding from reference audio.
    ///
    /// The returned embedding encodes the speaker's voice characteristics
    /// and can be saved to disk for later use. Embeddings are model-specific:
    /// one extracted here is not usable with a different model type.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] if extraction fails; callers should gate on
    /// [`VoiceCloning::supports_voice_cloning`] first.
    fn extract_voice(&self, audio: &ReferenceAudio) -> Result<VoiceEmbedding, TtsError>;

    /// Synthesize speech conditioned on a pre-extracted voice embedding.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] if synthesis fails.
    fn synthesize_with_voice(
        &self,
        request: &SynthesisRequest,
        voice: &VoiceEmbedding,
    ) -> Result<AudioSamples, TtsError>;
}
318}