// any_tts/traits.rs
1//! Core TTS trait and request/response types.
2
3use std::path::Path;
4
5use crate::audio::AudioSamples;
6use crate::config::TtsConfig;
7use crate::error::TtsError;
8
/// Metadata about a loaded model.
///
/// Returned by [`TtsModel::model_info`] so callers can inspect a backend's
/// capabilities (sample rate, languages, voices) without knowing its
/// concrete type.
#[derive(Debug, Clone)]
pub struct ModelInfo {
    /// Human-readable model name.
    pub name: String,
    /// Model version or variant.
    pub variant: String,
    /// Approximate parameter count.
    pub parameters: u64,
    /// Output audio sample rate in Hz.
    pub sample_rate: u32,
    /// Supported languages (ISO 639-1 codes or language names).
    pub languages: Vec<String>,
    /// Available voice/speaker names.
    pub voices: Vec<String>,
}
25
/// A request to synthesize speech from text.
///
/// Construct with [`SynthesisRequest::new`] and refine with the `with_*`
/// builder methods; any field left `None` falls back to the model default.
#[derive(Debug, Clone)]
pub struct SynthesisRequest {
    /// The text to synthesize.
    pub text: String,
    /// Target language (ISO code or name). If `None`, auto-detect.
    pub language: Option<String>,
    /// Voice/speaker name. If `None`, use default.
    pub voice: Option<String>,
    /// Style instruction (model-dependent, e.g. "angry", "whisper").
    pub instruct: Option<String>,
    /// Maximum number of tokens to generate. If `None`, use model default.
    pub max_tokens: Option<usize>,
    /// Sampling temperature. If `None`, use model default.
    pub temperature: Option<f64>,
    /// Guidance scale for backends that support classifier-free guidance.
    pub cfg_scale: Option<f64>,
    /// Reference audio for zero-shot voice cloning.
    ///
    /// When provided, the model will attempt to match the voice characteristics
    /// of this audio instead of using a named voice. Not all models support this;
    /// models that don't will return [`TtsError::ModelError`].
    pub reference_audio: Option<ReferenceAudio>,
    /// Pre-extracted voice embedding for voice cloning.
    ///
    /// Use this to re-use an embedding extracted via [`VoiceCloning::extract_voice`]
    /// without re-processing the reference audio each time.
    // NOTE(review): the precedence between `voice_embedding` and
    // `reference_audio` when both are set is backend-defined and not
    // specified here — confirm against the backend implementations.
    pub voice_embedding: Option<VoiceEmbedding>,
}
55
56impl SynthesisRequest {
57    /// Create a new synthesis request with the given text.
58    pub fn new(text: impl Into<String>) -> Self {
59        Self {
60            text: text.into(),
61            language: None,
62            voice: None,
63            instruct: None,
64            max_tokens: None,
65            temperature: None,
66            cfg_scale: None,
67            reference_audio: None,
68            voice_embedding: None,
69        }
70    }
71
72    /// Set the target language.
73    pub fn with_language(mut self, lang: impl Into<String>) -> Self {
74        self.language = Some(lang.into());
75        self
76    }
77
78    /// Set the voice/speaker.
79    pub fn with_voice(mut self, voice: impl Into<String>) -> Self {
80        self.voice = Some(voice.into());
81        self
82    }
83
84    /// Set a style instruction.
85    pub fn with_instruct(mut self, instruct: impl Into<String>) -> Self {
86        self.instruct = Some(instruct.into());
87        self
88    }
89
90    /// Set the maximum number of tokens.
91    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
92        self.max_tokens = Some(max_tokens);
93        self
94    }
95
96    /// Set the sampling temperature.
97    pub fn with_temperature(mut self, temperature: f64) -> Self {
98        self.temperature = Some(temperature);
99        self
100    }
101
102    /// Set the classifier-free guidance scale.
103    pub fn with_cfg_scale(mut self, cfg_scale: f64) -> Self {
104        self.cfg_scale = Some(cfg_scale);
105        self
106    }
107
108    /// Set reference audio for zero-shot voice cloning.
109    ///
110    /// The model will extract speaker characteristics from this audio and
111    /// use them to condition the synthesis. Overrides any named voice.
112    pub fn with_reference_audio(mut self, audio: ReferenceAudio) -> Self {
113        self.reference_audio = Some(audio);
114        self
115    }
116
117    /// Set a pre-extracted voice embedding.
118    ///
119    /// Useful for caching: extract once with [`VoiceCloning::extract_voice`],
120    /// then reuse across multiple synthesis calls.
121    pub fn with_voice_embedding(mut self, embedding: VoiceEmbedding) -> Self {
122        self.voice_embedding = Some(embedding);
123        self
124    }
125}
126
/// The core trait that all TTS model backends must implement.
///
/// This provides a unified interface for text-to-speech synthesis regardless
/// of the underlying model architecture.
pub trait TtsModel: Send + Sync {
    /// Load model weights and initialize the model from configuration.
    ///
    /// This may involve reading safetensors files, parsing config JSON,
    /// and moving weights to the target device.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] if loading fails (e.g. missing or unreadable
    /// weight or config files).
    fn load(config: TtsConfig) -> Result<Self, TtsError>
    where
        Self: Sized;

    /// Synthesize speech from a text request.
    ///
    /// Returns raw f32 PCM audio samples at the model's native sample rate.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] if generation fails or the request uses a
    /// feature the backend does not support.
    fn synthesize(&self, request: &SynthesisRequest) -> Result<AudioSamples, TtsError>;

    /// Return the native output sample rate in Hz (e.g. 24000).
    fn sample_rate(&self) -> u32;

    /// Return the list of supported language identifiers.
    fn supported_languages(&self) -> Vec<String>;

    /// Return the list of available voice/speaker names.
    fn supported_voices(&self) -> Vec<String>;

    /// Return metadata about this model.
    fn model_info(&self) -> ModelInfo;
}
157
158// ---------------------------------------------------------------------------
159// Voice cloning types
160// ---------------------------------------------------------------------------
161
/// Raw audio data used as a reference for voice cloning.
///
/// A model extracts speaker characteristics from this audio to condition
/// synthesis. For good results, supply 3–10 seconds of clean single-speaker
/// speech; if the sample rate differs from the model's native rate
/// (e.g. 24 kHz for Kokoro), the library resamples automatically.
///
/// # Example
///
/// ```rust
/// use any_tts::ReferenceAudio;
///
/// let audio = ReferenceAudio::new(vec![0.0f32; 24000], 24000);
/// assert_eq!(audio.duration_secs(), 1.0);
/// ```
#[derive(Debug, Clone)]
pub struct ReferenceAudio {
    /// Raw f32 PCM audio samples in `[-1.0, 1.0]`.
    pub samples: Vec<f32>,
    /// Sample rate of the audio in Hz.
    pub sample_rate: u32,
}

impl ReferenceAudio {
    /// Wrap raw PCM samples and their sample rate as reference audio.
    pub fn new(samples: Vec<f32>, sample_rate: u32) -> Self {
        Self { samples, sample_rate }
    }

    /// Length of the audio in seconds; `0.0` when the sample rate is zero
    /// (avoids division by zero on degenerate input).
    pub fn duration_secs(&self) -> f32 {
        match self.sample_rate {
            0 => 0.0,
            rate => self.samples.len() as f32 / rate as f32,
        }
    }

    /// `true` when no samples are present.
    pub fn is_empty(&self) -> bool {
        self.samples.is_empty()
    }
}
209
/// An extracted voice embedding that can be saved, loaded, and reused.
///
/// Voice embeddings are model-specific opaque data. An embedding extracted
/// from one model type cannot be used with a different model type.
///
/// # Persistence
///
/// Embeddings can be saved to and loaded from JSON files:
///
/// ```rust,ignore
/// // Extract once
/// let embedding = model.extract_voice(&reference)?;
/// embedding.save("my_voice.json")?;
///
/// // Reuse later
/// let embedding = VoiceEmbedding::load("my_voice.json")?;
/// let request = SynthesisRequest::new("Hello!")
///     .with_voice_embedding(embedding);
/// ```
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VoiceEmbedding {
    /// Raw embedding data as flattened f32 values.
    // Kept private: the layout is only meaningful together with `shape`
    // and `model_type`; access goes through `to_tensor`/accessors.
    data: Vec<f32>,
    /// Shape of the embedding tensor.
    shape: Vec<usize>,
    /// Model type identifier (e.g. "kokoro", "qwen3-tts").
    model_type: String,
}
238
239impl VoiceEmbedding {
240    /// Create a new voice embedding from raw data.
241    pub fn new(data: Vec<f32>, shape: Vec<usize>, model_type: impl Into<String>) -> Self {
242        Self {
243            data,
244            shape,
245            model_type: model_type.into(),
246        }
247    }
248
249    /// Reconstruct the embedding as a candle [`Tensor`](candle_core::Tensor).
250    pub fn to_tensor(&self, device: &candle_core::Device) -> Result<candle_core::Tensor, TtsError> {
251        candle_core::Tensor::new(self.data.as_slice(), device)?
252            .reshape(self.shape.as_slice())
253            .map_err(TtsError::from)
254    }
255
256    /// The model type this embedding was extracted for.
257    pub fn model_type(&self) -> &str {
258        &self.model_type
259    }
260
261    /// Shape of the embedding tensor.
262    pub fn shape(&self) -> &[usize] {
263        &self.shape
264    }
265
266    /// Save the embedding to a JSON file.
267    pub fn save(&self, path: impl AsRef<Path>) -> Result<(), TtsError> {
268        let json = serde_json::to_string_pretty(self)?;
269        std::fs::write(path, json)?;
270        Ok(())
271    }
272
273    /// Load an embedding from a JSON file.
274    pub fn load(path: impl AsRef<Path>) -> Result<Self, TtsError> {
275        let json = std::fs::read_to_string(path)?;
276        let embedding: Self = serde_json::from_str(&json)?;
277        Ok(embedding)
278    }
279}
280
/// Trait for TTS models that support voice cloning from reference audio.
///
/// Not all models support voice cloning. Check [`VoiceCloning::supports_voice_cloning`]
/// before calling other methods.
///
/// # Example
///
/// ```rust,ignore
/// use any_tts::{VoiceCloning, ReferenceAudio, SynthesisRequest};
///
/// if model.supports_voice_cloning() {
///     let reference = ReferenceAudio::new(samples, 24000);
///     let embedding = model.extract_voice(&reference)?;
///     let request = SynthesisRequest::new("Hello in the cloned voice!")
///         .with_voice_embedding(embedding);
///     let audio = model.synthesize(&request)?;
/// }
/// ```
pub trait VoiceCloning: TtsModel {
    /// Whether voice cloning is currently available.
    ///
    /// Returns `false` if the model doesn't support voice cloning or if
    /// the required encoder weights were not found during loading.
    fn supports_voice_cloning(&self) -> bool;

    /// Extract a reusable voice embedding from reference audio.
    ///
    /// The returned embedding encodes the speaker's voice characteristics
    /// and can be saved to disk for later use.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] if extraction fails or voice cloning is
    /// unavailable for this model.
    fn extract_voice(&self, audio: &ReferenceAudio) -> Result<VoiceEmbedding, TtsError>;

    /// Synthesize speech conditioned on a pre-extracted voice embedding.
    ///
    /// # Errors
    ///
    /// Returns a [`TtsError`] if synthesis fails or the embedding is not
    /// compatible with this model.
    fn synthesize_with_voice(
        &self,
        request: &SynthesisRequest,
        voice: &VoiceEmbedding,
    ) -> Result<AudioSamples, TtsError>;
}