//! Request, response, and event models for `tauri_plugin_tts` (`models.rs`).

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3use ts_rs::TS;
4
/// Maximum accepted text length, in UTF-8 bytes (10,000 bytes).
///
/// Validation compares against `str::len()`, which counts bytes rather than
/// characters, so text in multi-byte scripts reaches this limit with fewer
/// than 10,000 characters.
pub const MAX_TEXT_LENGTH: usize = 10_000;
/// Maximum accepted voice ID length, in UTF-8 bytes.
pub const MAX_VOICE_ID_LENGTH: usize = 256;
/// Maximum accepted language/locale code length, in UTF-8 bytes.
///
/// NOTE(review): presumably chosen to match the 35-character language-tag size
/// that RFC 5646 asks implementations to support — confirm.
pub const MAX_LANGUAGE_LENGTH: usize = 35;

12#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq, TS)]
13#[ts(export, export_to = "../guest-js/bindings/")]
14#[serde(rename_all = "lowercase")]
15pub enum QueueMode {
16    /// Flush any pending speech and start speaking immediately (default)
17    #[default]
18    Flush,
19    /// Add to queue and speak after current speech finishes
20    Add,
21}
22
23#[derive(Debug, Clone, Deserialize, Serialize, TS)]
24#[ts(export, export_to = "../guest-js/bindings/")]
25#[serde(rename_all = "camelCase")]
26pub struct SpeakOptions {
27    /// The text to speak (max 10,000 characters)
28    pub text: String,
29    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub language: Option<String>,
32    /// Specific voice ID to use (from getVoices). Takes priority over language
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub voice_id: Option<String>,
35    /// Speech rate (0.1 to 4.0, where 1.0 = normal)
36    #[serde(skip_serializing_if = "Option::is_none")]
37    pub rate: Option<f32>,
38    /// Pitch (0.5 to 2.0, where 1.0 = normal)
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub pitch: Option<f32>,
41    /// Volume (0.0 to 1.0, where 1.0 = full volume)
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub volume: Option<f32>,
44    /// Queue mode: "flush" (default) or "add"
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub queue_mode: Option<QueueMode>,
47}
48
49#[derive(Debug, Clone, Deserialize, Serialize, TS)]
50#[ts(export, export_to = "../guest-js/bindings/")]
51#[serde(rename_all = "camelCase")]
52pub struct PreviewVoiceOptions {
53    /// Voice ID to preview
54    pub voice_id: String,
55    /// Optional custom sample text (uses default if not provided)
56    #[serde(skip_serializing_if = "Option::is_none")]
57    pub text: Option<String>,
58}
59
60#[derive(Debug, Deserialize, Serialize)]
61#[serde(rename_all = "camelCase")]
62pub struct SpeakRequest {
63    /// The text to speak
64    pub text: String,
65    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
66    #[serde(default)]
67    pub language: Option<String>,
68    /// Voice ID to use (from getVoices)
69    #[serde(default)]
70    pub voice_id: Option<String>,
71    /// Speech rate (0.1 to 4.0, where 1.0 = normal, 2.0 = double, 0.5 = half)
72    #[serde(default = "default_rate")]
73    pub rate: f32,
74    /// Pitch (0.5 = low, 1.0 = normal, 2.0 = high)
75    #[serde(default = "default_pitch")]
76    pub pitch: f32,
77    /// Volume (0.0 = silent, 1.0 = full volume)
78    #[serde(default = "default_volume")]
79    pub volume: f32,
80    /// Queue mode: "flush" (default) or "add"
81    #[serde(default)]
82    pub queue_mode: QueueMode,
83}
84
/// Shared 1.0 baseline: normal rate, normal pitch, full volume.
const DEFAULT_LEVEL: f32 = 1.0;

/// Serde default for a missing speech rate (normal speed).
fn default_rate() -> f32 {
    DEFAULT_LEVEL
}

/// Serde default for a missing pitch (normal pitch).
fn default_pitch() -> f32 {
    DEFAULT_LEVEL
}

/// Serde default for a missing volume (full volume).
fn default_volume() -> f32 {
    DEFAULT_LEVEL
}

95#[derive(Debug, Clone, thiserror::Error)]
96pub enum ValidationError {
97    #[error("Text cannot be empty")]
98    EmptyText,
99    #[error("Text too long: {len} bytes (max: {max})")]
100    TextTooLong { len: usize, max: usize },
101    #[error("Voice ID too long: {len} chars (max: {max})")]
102    VoiceIdTooLong { len: usize, max: usize },
103    #[error("Invalid voice ID format - only alphanumeric, dots, underscores, and hyphens allowed")]
104    InvalidVoiceId,
105    #[error("Language code too long: {len} chars (max: {max})")]
106    LanguageTooLong { len: usize, max: usize },
107}
108
109#[derive(Debug, Clone)]
110pub struct ValidatedSpeakRequest {
111    pub text: String,
112    pub language: Option<String>,
113    pub voice_id: Option<String>,
114    pub rate: f32,
115    pub pitch: f32,
116    pub volume: f32,
117    pub queue_mode: QueueMode,
118}
119
120impl SpeakRequest {
121    pub fn validate(&self) -> Result<ValidatedSpeakRequest, ValidationError> {
122        // Text validation
123        if self.text.is_empty() {
124            return Err(ValidationError::EmptyText);
125        }
126        if self.text.len() > MAX_TEXT_LENGTH {
127            return Err(ValidationError::TextTooLong {
128                len: self.text.len(),
129                max: MAX_TEXT_LENGTH,
130            });
131        }
132
133        // Language validation (if provided)
134        let sanitized_language = self
135            .language
136            .as_ref()
137            .map(|lang| Self::validate_language(lang))
138            .transpose()?;
139
140        // Voice ID validation (if provided)
141        if let Some(ref voice_id) = self.voice_id {
142            validate_voice_id(voice_id)?;
143        }
144
145        Ok(ValidatedSpeakRequest {
146            text: self.text.clone(),
147            language: sanitized_language,
148            voice_id: self.voice_id.clone(),
149            rate: self.rate.clamp(0.1, 4.0),
150            pitch: self.pitch.clamp(0.5, 2.0),
151            volume: self.volume.clamp(0.0, 1.0),
152            queue_mode: self.queue_mode,
153        })
154    }
155
156    fn validate_language(lang: &str) -> Result<String, ValidationError> {
157        if lang.len() > MAX_LANGUAGE_LENGTH {
158            return Err(ValidationError::LanguageTooLong {
159                len: lang.len(),
160                max: MAX_LANGUAGE_LENGTH,
161            });
162        }
163        Ok(lang.to_string())
164    }
165}
166
167/// Shared voice ID validation: only alphanumeric, dots, underscores, and hyphens allowed.
168/// Matches the validation in iOS (CharacterSet) and Android (VOICE_ID_PATTERN).
169fn validate_voice_id(voice_id: &str) -> Result<(), ValidationError> {
170    if voice_id.len() > MAX_VOICE_ID_LENGTH {
171        return Err(ValidationError::VoiceIdTooLong {
172            len: voice_id.len(),
173            max: MAX_VOICE_ID_LENGTH,
174        });
175    }
176    if !voice_id
177        .chars()
178        .all(|c| c.is_alphanumeric() || c == '.' || c == '_' || c == '-')
179    {
180        return Err(ValidationError::InvalidVoiceId);
181    }
182    Ok(())
183}
184
185#[derive(Debug, Clone, Default, Deserialize, Serialize)]
186#[serde(rename_all = "camelCase")]
187pub struct SpeakResponse {
188    /// Whether speech was successfully initiated
189    pub success: bool,
190    /// Optional warning message (e.g., voice not found, using fallback)
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub warning: Option<String>,
193}
194
195#[derive(Debug, Clone, Default, Deserialize, Serialize)]
196#[serde(rename_all = "camelCase")]
197pub struct StopResponse {
198    pub success: bool,
199}
200
201#[derive(Debug, Deserialize, Serialize)]
202#[serde(rename_all = "camelCase")]
203pub struct SetBackgroundBehaviorRequest {
204    /// Whether TTS should continue speaking when the app goes to background / screen locks.
205    /// Defaults to true. When false, speech is paused on background and a `speech:backgroundPause`
206    /// event is emitted (matching the previous behavior). Desktop: ignored (no-op).
207    pub continue_in_background: bool,
208}
209
210#[derive(Debug, Clone, Default, Deserialize, Serialize)]
211#[serde(rename_all = "camelCase")]
212pub struct SetBackgroundBehaviorResponse {
213    pub success: bool,
214}
215
216#[derive(Debug, Clone, Deserialize, Serialize, TS)]
217#[ts(export, export_to = "../guest-js/bindings/")]
218#[serde(rename_all = "camelCase")]
219pub struct Voice {
220    /// Unique identifier for the voice
221    pub id: String,
222    /// Display name of the voice
223    pub name: String,
224    /// Language code (e.g., "en-US")
225    pub language: String,
226}
227
228#[derive(Debug, Deserialize, Serialize)]
229#[serde(rename_all = "camelCase")]
230pub struct GetVoicesRequest {
231    /// Optional language filter
232    #[serde(default)]
233    pub language: Option<String>,
234}
235
236#[derive(Debug, Clone, Default, Deserialize, Serialize)]
237#[serde(rename_all = "camelCase")]
238pub struct GetVoicesResponse {
239    pub voices: Vec<Voice>,
240}
241
242#[derive(Debug, Clone, Default, Deserialize, Serialize)]
243#[serde(rename_all = "camelCase")]
244pub struct IsSpeakingResponse {
245    pub speaking: bool,
246}
247
248#[derive(Debug, Clone, Default, Deserialize, Serialize)]
249#[serde(rename_all = "camelCase")]
250pub struct IsInitializedResponse {
251    /// Whether the TTS engine is initialized and ready
252    pub initialized: bool,
253    /// Number of available voices (0 if not initialized)
254    pub voice_count: u32,
255}
256
257#[derive(Debug, Clone, Default, Deserialize, Serialize, TS)]
258#[ts(export, export_to = "../guest-js/bindings/")]
259#[serde(rename_all = "camelCase")]
260pub struct PauseResumeResponse {
261    pub success: bool,
262    /// Reason for failure (if success is false)
263    #[serde(skip_serializing_if = "Option::is_none")]
264    pub reason: Option<String>,
265}
266
267#[derive(Debug, Deserialize, Serialize)]
268#[serde(rename_all = "camelCase")]
269pub struct PreviewVoiceRequest {
270    /// Voice ID to preview
271    pub voice_id: String,
272    /// Optional custom sample text (uses default if not provided)
273    #[serde(default)]
274    pub text: Option<String>,
275}
276
277impl PreviewVoiceRequest {
278    pub const DEFAULT_SAMPLE_TEXT: &'static str =
279        "Hello! This is a sample of how this voice sounds.";
280
281    pub fn sample_text(&self) -> Cow<'_, str> {
282        match &self.text {
283            Some(text) => Cow::Borrowed(text.as_str()),
284            None => Cow::Borrowed(Self::DEFAULT_SAMPLE_TEXT),
285        }
286    }
287
288    pub fn validate(&self) -> Result<(), ValidationError> {
289        // Validate voice ID
290        validate_voice_id(&self.voice_id)?;
291
292        // Validate custom text if provided
293        if let Some(ref text) = self.text {
294            if text.is_empty() {
295                return Err(ValidationError::EmptyText);
296            }
297            if text.len() > MAX_TEXT_LENGTH {
298                return Err(ValidationError::TextTooLong {
299                    len: text.len(),
300                    max: MAX_TEXT_LENGTH,
301                });
302            }
303        }
304
305        Ok(())
306    }
307}
308
309/// On desktop, emitted directly via `app.emit("tts://<event_type>", payload)`.
310/// On mobile, native plugins send this through a Tauri `Channel`; the Rust relay
311/// deserializes it and re-emits via `app.emit()` so JS `listen("tts://...")` works
312/// uniformly on every platform.
313///
314/// The shape matches the JS `SpeechEvent` interface.
315#[derive(Debug, Clone, Default, Deserialize, Serialize)]
316#[serde(rename_all = "camelCase")]
317pub struct TtsEventPayload {
318    /// The event name, e.g. "speech:finish". Used to build the emit key "tts://<event_type>".
319    pub event_type: String,
320    /// Unique identifier for the utterance (if available)
321    #[serde(skip_serializing_if = "Option::is_none")]
322    pub id: Option<String>,
323    /// Error message (for error events)
324    #[serde(skip_serializing_if = "Option::is_none")]
325    pub error: Option<String>,
326    /// Whether speech was interrupted
327    #[serde(skip_serializing_if = "Option::is_none")]
328    pub interrupted: Option<bool>,
329    /// Reason for the event (e.g. "audio_focus_lost", "app_paused")
330    #[serde(skip_serializing_if = "Option::is_none")]
331    pub reason: Option<String>,
332}
333
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline request: the given text with 1.0 rate/pitch/volume, no language or
    /// voice, and flush queueing. Tests override fields with struct-update syntax.
    fn speak_request(text: &str) -> SpeakRequest {
        SpeakRequest {
            text: text.to_string(),
            language: None,
            voice_id: None,
            rate: 1.0,
            pitch: 1.0,
            volume: 1.0,
            queue_mode: QueueMode::Flush,
        }
    }

    #[test]
    fn test_speak_request_defaults() {
        let json = r#"{"text": "Hello world"}"#;
        let request: SpeakRequest = serde_json::from_str(json).unwrap();

        assert_eq!(request.text, "Hello world");
        assert!(request.language.is_none());
        assert!(request.voice_id.is_none());
        assert_eq!(request.rate, 1.0);
        assert_eq!(request.pitch, 1.0);
        assert_eq!(request.volume, 1.0);
        assert_eq!(request.queue_mode, QueueMode::Flush);
    }

    #[test]
    fn test_speak_request_full() {
        let json = r#"{
            "text": "Olá",
            "language": "pt-BR",
            "voiceId": "com.apple.voice.enhanced.pt-BR",
            "rate": 0.8,
            "pitch": 1.2,
            "volume": 0.9,
            "queueMode": "add"
        }"#;

        let request: SpeakRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.text, "Olá");
        assert_eq!(request.language, Some("pt-BR".to_string()));
        assert_eq!(
            request.voice_id,
            Some("com.apple.voice.enhanced.pt-BR".to_string())
        );
        assert_eq!(request.rate, 0.8);
        assert_eq!(request.pitch, 1.2);
        assert_eq!(request.volume, 0.9);
        assert_eq!(request.queue_mode, QueueMode::Add);
    }

    #[test]
    fn test_queue_mode_serde_lowercase() {
        assert_eq!(serde_json::to_string(&QueueMode::Flush).unwrap(), r#""flush""#);
        assert_eq!(serde_json::to_string(&QueueMode::Add).unwrap(), r#""add""#);
        let mode: QueueMode = serde_json::from_str(r#""add""#).unwrap();
        assert_eq!(mode, QueueMode::Add);
    }

    #[test]
    fn test_voice_serialization() {
        let voice = Voice {
            id: "test-voice".to_string(),
            name: "Test Voice".to_string(),
            language: "en-US".to_string(),
        };

        let json = serde_json::to_string(&voice).unwrap();
        assert!(json.contains("\"id\":\"test-voice\""));
        assert!(json.contains("\"name\":\"Test Voice\""));
        assert!(json.contains("\"language\":\"en-US\""));
    }

    #[test]
    fn test_get_voices_request_optional_language() {
        let json1 = r#"{}"#;
        let request1: GetVoicesRequest = serde_json::from_str(json1).unwrap();
        assert!(request1.language.is_none());

        let json2 = r#"{"language": "en"}"#;
        let request2: GetVoicesRequest = serde_json::from_str(json2).unwrap();
        assert_eq!(request2.language, Some("en".to_string()));
    }

    #[test]
    fn test_validation_empty_text() {
        let result = speak_request("").validate();
        assert!(matches!(result, Err(ValidationError::EmptyText)));
    }

    #[test]
    fn test_validation_text_too_long() {
        let long_text = "x".repeat(MAX_TEXT_LENGTH + 1);
        let result = speak_request(&long_text).validate();
        assert!(matches!(result, Err(ValidationError::TextTooLong { .. })));
    }

    #[test]
    fn test_validation_text_at_limit_accepted() {
        let text = "x".repeat(MAX_TEXT_LENGTH);
        assert!(speak_request(&text).validate().is_ok());
    }

    #[test]
    fn test_validation_text_limit_counts_bytes() {
        // "é" is two bytes in UTF-8, so this is under the limit in characters
        // but over it in bytes.
        let text = "é".repeat(MAX_TEXT_LENGTH / 2 + 1);
        assert!(text.chars().count() < MAX_TEXT_LENGTH);
        let result = speak_request(&text).validate();
        assert!(matches!(
            result,
            Err(ValidationError::TextTooLong { len, .. }) if len == text.len()
        ));
    }

    #[test]
    fn test_validation_valid_voice_id() {
        let request = SpeakRequest {
            voice_id: Some("com.apple.voice.enhanced.en-US".to_string()),
            ..speak_request("Hello")
        };

        let result = request.validate();
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().voice_id,
            Some("com.apple.voice.enhanced.en-US".to_string())
        );
    }

    #[test]
    fn test_validation_invalid_voice_id_chars() {
        let request = SpeakRequest {
            voice_id: Some("voice id with spaces".to_string()),
            ..speak_request("Hello")
        };
        assert!(matches!(
            request.validate(),
            Err(ValidationError::InvalidVoiceId)
        ));
    }

    #[test]
    fn test_validation_voice_id_too_long() {
        let request = SpeakRequest {
            voice_id: Some("x".repeat(MAX_VOICE_ID_LENGTH + 1)),
            ..speak_request("Hello")
        };

        let result = request.validate();
        assert!(matches!(
            result,
            Err(ValidationError::VoiceIdTooLong { .. })
        ));
    }

    #[test]
    fn test_validation_language() {
        let ok = SpeakRequest {
            language: Some("pt-BR".to_string()),
            ..speak_request("Olá")
        };
        assert_eq!(ok.validate().unwrap().language.as_deref(), Some("pt-BR"));

        let too_long = SpeakRequest {
            language: Some("x".repeat(MAX_LANGUAGE_LENGTH + 1)),
            ..speak_request("Hello")
        };
        assert!(matches!(
            too_long.validate(),
            Err(ValidationError::LanguageTooLong { .. })
        ));
    }

    #[test]
    fn test_validation_rate_clamping() {
        let fast = SpeakRequest {
            rate: 999.0,
            ..speak_request("Hello")
        };
        assert_eq!(fast.validate().unwrap().rate, 4.0); // Clamped to max

        let slow = SpeakRequest {
            rate: 0.0,
            ..speak_request("Hello")
        };
        assert_eq!(slow.validate().unwrap().rate, 0.1); // Clamped to min
    }

    #[test]
    fn test_validation_pitch_clamping() {
        let request = SpeakRequest {
            pitch: 0.1,
            ..speak_request("Hello")
        };

        let result = request.validate();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().pitch, 0.5); // Clamped to min
    }

    #[test]
    fn test_validation_volume_clamping() {
        let request = SpeakRequest {
            volume: 5.0,
            ..speak_request("Hello")
        };

        let result = request.validate();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().volume, 1.0); // Clamped to max
    }

    #[test]
    fn test_preview_voice_validation() {
        // Valid preview
        let valid = PreviewVoiceRequest {
            voice_id: "valid-voice_123".to_string(),
            text: None,
        };
        assert!(valid.validate().is_ok());

        // Invalid voice_id
        let invalid = PreviewVoiceRequest {
            voice_id: "invalid<script>".to_string(),
            text: None,
        };
        assert!(invalid.validate().is_err());

        // Custom text must be non-empty and within the byte limit
        let empty_text = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: Some(String::new()),
        };
        assert!(matches!(
            empty_text.validate(),
            Err(ValidationError::EmptyText)
        ));

        let long_text = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: Some("x".repeat(MAX_TEXT_LENGTH + 1)),
        };
        assert!(matches!(
            long_text.validate(),
            Err(ValidationError::TextTooLong { .. })
        ));
    }

    #[test]
    fn test_preview_voice_sample_text() {
        let without_text = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: None,
        };
        assert_eq!(
            without_text.sample_text(),
            PreviewVoiceRequest::DEFAULT_SAMPLE_TEXT
        );

        let with_text = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: Some("Custom sample".to_string()),
        };
        assert_eq!(with_text.sample_text(), "Custom sample");
    }
}