Skip to main content

tauri_plugin_tts/
models.rs

1use serde::{Deserialize, Serialize};
2use std::borrow::Cow;
3#[cfg(feature = "ts-bindings")]
4use ts_rs::TS;
5
/// Upper bound on the size of text accepted for speech, measured in UTF-8 bytes
/// (10,000 bytes; validation compares it against `str::len`).
pub const MAX_TEXT_LENGTH: usize = 10_000;
/// Upper bound on the size of a voice ID, compared against `str::len` (bytes).
pub const MAX_VOICE_ID_LENGTH: usize = 256;
/// Upper bound on the size of a language code, compared against `str::len` (bytes).
pub const MAX_LANGUAGE_LENGTH: usize = 35;
12
13#[cfg_attr(feature = "ts-bindings", derive(TS))]
14#[cfg_attr(
15    feature = "ts-bindings",
16    ts(export, export_to = "../guest-js/bindings/")
17)]
18#[derive(Debug, Clone, Copy, Default, Deserialize, Serialize, PartialEq, Eq)]
19#[serde(rename_all = "lowercase")]
20pub enum QueueMode {
21    /// Flush any pending speech and start speaking immediately (default)
22    #[default]
23    Flush,
24    /// Add to queue and speak after current speech finishes
25    Add,
26}
27
28#[cfg_attr(feature = "ts-bindings", derive(TS))]
29#[cfg_attr(
30    feature = "ts-bindings",
31    ts(export, export_to = "../guest-js/bindings/")
32)]
33#[derive(Debug, Clone, Deserialize, Serialize)]
34#[serde(rename_all = "camelCase")]
35pub struct SpeakOptions {
36    /// The text to speak (max 10,000 characters)
37    pub text: String,
38    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub language: Option<String>,
41    /// Specific voice ID to use (from getVoices). Takes priority over language
42    #[serde(skip_serializing_if = "Option::is_none")]
43    pub voice_id: Option<String>,
44    /// Speech rate (0.1 to 4.0, where 1.0 = normal)
45    #[serde(skip_serializing_if = "Option::is_none")]
46    pub rate: Option<f32>,
47    /// Pitch (0.5 to 2.0, where 1.0 = normal)
48    #[serde(skip_serializing_if = "Option::is_none")]
49    pub pitch: Option<f32>,
50    /// Volume (0.0 to 1.0, where 1.0 = full volume)
51    #[serde(skip_serializing_if = "Option::is_none")]
52    pub volume: Option<f32>,
53    /// Queue mode: "flush" (default) or "add"
54    #[serde(skip_serializing_if = "Option::is_none")]
55    pub queue_mode: Option<QueueMode>,
56}
57
58#[cfg_attr(feature = "ts-bindings", derive(TS))]
59#[cfg_attr(
60    feature = "ts-bindings",
61    ts(export, export_to = "../guest-js/bindings/")
62)]
63#[derive(Debug, Clone, Deserialize, Serialize)]
64#[serde(rename_all = "camelCase")]
65pub struct PreviewVoiceOptions {
66    /// Voice ID to preview
67    pub voice_id: String,
68    /// Optional custom sample text (uses default if not provided)
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub text: Option<String>,
71}
72
73#[derive(Debug, Deserialize, Serialize)]
74#[serde(rename_all = "camelCase")]
75pub struct SpeakRequest {
76    /// The text to speak
77    pub text: String,
78    /// The language/locale code (e.g., "en-US", "pt-BR", "ja-JP")
79    #[serde(default)]
80    pub language: Option<String>,
81    /// Voice ID to use (from getVoices)
82    #[serde(default)]
83    pub voice_id: Option<String>,
84    /// Speech rate (0.1 to 4.0, where 1.0 = normal, 2.0 = double, 0.5 = half)
85    #[serde(default = "default_rate")]
86    pub rate: f32,
87    /// Pitch (0.5 = low, 1.0 = normal, 2.0 = high)
88    #[serde(default = "default_pitch")]
89    pub pitch: f32,
90    /// Volume (0.0 = silent, 1.0 = full volume)
91    #[serde(default = "default_volume")]
92    pub volume: f32,
93    /// Queue mode: "flush" (default) or "add"
94    #[serde(default)]
95    pub queue_mode: QueueMode,
96}
97
/// Serde default for `SpeakRequest::rate`: normal speed.
fn default_rate() -> f32 {
    1.0_f32
}
/// Serde default for `SpeakRequest::pitch`: normal pitch.
fn default_pitch() -> f32 {
    1.0_f32
}
/// Serde default for `SpeakRequest::volume`: full volume.
fn default_volume() -> f32 {
    1.0_f32
}
107
108#[derive(Debug, Clone, thiserror::Error)]
109pub enum ValidationError {
110    #[error("Text cannot be empty")]
111    EmptyText,
112    #[error("Text too long: {len} bytes (max: {max})")]
113    TextTooLong { len: usize, max: usize },
114    #[error("Voice ID too long: {len} chars (max: {max})")]
115    VoiceIdTooLong { len: usize, max: usize },
116    #[error("Invalid voice ID format - only alphanumeric, dots, underscores, and hyphens allowed")]
117    InvalidVoiceId,
118    #[error("Language code too long: {len} chars (max: {max})")]
119    LanguageTooLong { len: usize, max: usize },
120}
121
122#[derive(Debug, Clone)]
123pub struct ValidatedSpeakRequest {
124    pub text: String,
125    pub language: Option<String>,
126    pub voice_id: Option<String>,
127    pub rate: f32,
128    pub pitch: f32,
129    pub volume: f32,
130    pub queue_mode: QueueMode,
131}
132
133impl SpeakRequest {
134    pub fn validate(&self) -> Result<ValidatedSpeakRequest, ValidationError> {
135        // Text validation
136        if self.text.is_empty() {
137            return Err(ValidationError::EmptyText);
138        }
139        if self.text.len() > MAX_TEXT_LENGTH {
140            return Err(ValidationError::TextTooLong {
141                len: self.text.len(),
142                max: MAX_TEXT_LENGTH,
143            });
144        }
145
146        // Language validation (if provided)
147        let sanitized_language = self
148            .language
149            .as_ref()
150            .map(|lang| Self::validate_language(lang))
151            .transpose()?;
152
153        // Voice ID validation (if provided)
154        if let Some(ref voice_id) = self.voice_id {
155            validate_voice_id(voice_id)?;
156        }
157
158        Ok(ValidatedSpeakRequest {
159            text: self.text.clone(),
160            language: sanitized_language,
161            voice_id: self.voice_id.clone(),
162            rate: self.rate.clamp(0.1, 4.0),
163            pitch: self.pitch.clamp(0.5, 2.0),
164            volume: self.volume.clamp(0.0, 1.0),
165            queue_mode: self.queue_mode,
166        })
167    }
168
169    fn validate_language(lang: &str) -> Result<String, ValidationError> {
170        if lang.len() > MAX_LANGUAGE_LENGTH {
171            return Err(ValidationError::LanguageTooLong {
172                len: lang.len(),
173                max: MAX_LANGUAGE_LENGTH,
174            });
175        }
176        Ok(lang.to_string())
177    }
178}
179
180/// Shared voice ID validation: only alphanumeric, dots, underscores, and hyphens allowed.
181/// Matches the validation in iOS (CharacterSet) and Android (VOICE_ID_PATTERN).
182fn validate_voice_id(voice_id: &str) -> Result<(), ValidationError> {
183    if voice_id.len() > MAX_VOICE_ID_LENGTH {
184        return Err(ValidationError::VoiceIdTooLong {
185            len: voice_id.len(),
186            max: MAX_VOICE_ID_LENGTH,
187        });
188    }
189    if !voice_id
190        .chars()
191        .all(|c| c.is_alphanumeric() || c == '.' || c == '_' || c == '-')
192    {
193        return Err(ValidationError::InvalidVoiceId);
194    }
195    Ok(())
196}
197
198#[derive(Debug, Clone, Default, Deserialize, Serialize)]
199#[serde(rename_all = "camelCase")]
200pub struct SpeakResponse {
201    /// Whether speech was successfully initiated
202    pub success: bool,
203    /// Optional warning message (e.g., voice not found, using fallback)
204    #[serde(skip_serializing_if = "Option::is_none")]
205    pub warning: Option<String>,
206}
207
208#[derive(Debug, Clone, Default, Deserialize, Serialize)]
209#[serde(rename_all = "camelCase")]
210pub struct StopResponse {
211    pub success: bool,
212}
213
214#[derive(Debug, Deserialize, Serialize)]
215#[serde(rename_all = "camelCase")]
216pub struct SetBackgroundBehaviorRequest {
217    /// Whether TTS should continue speaking when the app goes to background / screen locks.
218    /// Defaults to true. When false, speech is paused on background and a `speech:backgroundPause`
219    /// event is emitted (matching the previous behavior). Desktop: ignored (no-op).
220    pub continue_in_background: bool,
221}
222
223#[derive(Debug, Clone, Default, Deserialize, Serialize)]
224#[serde(rename_all = "camelCase")]
225pub struct SetBackgroundBehaviorResponse {
226    pub success: bool,
227}
228
229#[cfg_attr(feature = "ts-bindings", derive(TS))]
230#[cfg_attr(
231    feature = "ts-bindings",
232    ts(export, export_to = "../guest-js/bindings/")
233)]
234#[derive(Debug, Clone, Deserialize, Serialize)]
235#[serde(rename_all = "camelCase")]
236pub struct Voice {
237    /// Unique identifier for the voice
238    pub id: String,
239    /// Display name of the voice
240    pub name: String,
241    /// Language code (e.g., "en-US")
242    pub language: String,
243}
244
245#[derive(Debug, Deserialize, Serialize)]
246#[serde(rename_all = "camelCase")]
247pub struct GetVoicesRequest {
248    /// Optional language filter
249    #[serde(default)]
250    pub language: Option<String>,
251}
252
253#[derive(Debug, Clone, Default, Deserialize, Serialize)]
254#[serde(rename_all = "camelCase")]
255pub struct GetVoicesResponse {
256    pub voices: Vec<Voice>,
257}
258
259#[derive(Debug, Clone, Default, Deserialize, Serialize)]
260#[serde(rename_all = "camelCase")]
261pub struct IsSpeakingResponse {
262    pub speaking: bool,
263}
264
265#[derive(Debug, Clone, Default, Deserialize, Serialize)]
266#[serde(rename_all = "camelCase")]
267pub struct IsInitializedResponse {
268    /// Whether the TTS engine is initialized and ready
269    pub initialized: bool,
270    /// Number of available voices (0 if not initialized)
271    pub voice_count: u32,
272}
273
274#[cfg_attr(feature = "ts-bindings", derive(TS))]
275#[cfg_attr(
276    feature = "ts-bindings",
277    ts(export, export_to = "../guest-js/bindings/")
278)]
279#[derive(Debug, Clone, Default, Deserialize, Serialize)]
280#[serde(rename_all = "camelCase")]
281pub struct PauseResumeResponse {
282    pub success: bool,
283    /// Reason for failure (if success is false)
284    #[serde(skip_serializing_if = "Option::is_none")]
285    pub reason: Option<String>,
286}
287
288#[derive(Debug, Deserialize, Serialize)]
289#[serde(rename_all = "camelCase")]
290pub struct PreviewVoiceRequest {
291    /// Voice ID to preview
292    pub voice_id: String,
293    /// Optional custom sample text (uses default if not provided)
294    #[serde(default)]
295    pub text: Option<String>,
296}
297
298impl PreviewVoiceRequest {
299    pub const DEFAULT_SAMPLE_TEXT: &'static str =
300        "Hello! This is a sample of how this voice sounds.";
301
302    pub fn sample_text(&self) -> Cow<'_, str> {
303        match &self.text {
304            Some(text) => Cow::Borrowed(text.as_str()),
305            None => Cow::Borrowed(Self::DEFAULT_SAMPLE_TEXT),
306        }
307    }
308
309    pub fn validate(&self) -> Result<(), ValidationError> {
310        // Validate voice ID
311        validate_voice_id(&self.voice_id)?;
312
313        // Validate custom text if provided
314        if let Some(ref text) = self.text {
315            if text.is_empty() {
316                return Err(ValidationError::EmptyText);
317            }
318            if text.len() > MAX_TEXT_LENGTH {
319                return Err(ValidationError::TextTooLong {
320                    len: text.len(),
321                    max: MAX_TEXT_LENGTH,
322                });
323            }
324        }
325
326        Ok(())
327    }
328}
329
330/// On desktop, emitted directly via `app.emit("tts://<event_type>", payload)`.
331/// On mobile, native plugins send this through a Tauri `Channel`; the Rust relay
332/// deserializes it and re-emits via `app.emit()` so JS `listen("tts://...")` works
333/// uniformly on every platform.
334///
335/// The shape matches the JS `SpeechEvent` interface.
336#[derive(Debug, Clone, Default, Deserialize, Serialize)]
337#[serde(rename_all = "camelCase")]
338pub struct TtsEventPayload {
339    /// The event name, e.g. "speech:finish". Used to build the emit key "tts://<event_type>".
340    pub event_type: String,
341    /// Unique identifier for the utterance (if available)
342    #[serde(skip_serializing_if = "Option::is_none")]
343    pub id: Option<String>,
344    /// Error message (for error events)
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub error: Option<String>,
347    /// Whether speech was interrupted
348    #[serde(skip_serializing_if = "Option::is_none")]
349    pub interrupted: Option<bool>,
350    /// Reason for the event (e.g. "audio_focus_lost", "app_paused")
351    #[serde(skip_serializing_if = "Option::is_none")]
352    pub reason: Option<String>,
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request with neutral settings so that each test only has to
    /// override the one field it exercises.
    fn neutral_request(text: &str) -> SpeakRequest {
        SpeakRequest {
            text: text.to_string(),
            language: None,
            voice_id: None,
            rate: 1.0,
            pitch: 1.0,
            volume: 1.0,
            queue_mode: QueueMode::Flush,
        }
    }

    #[test]
    fn test_speak_request_defaults() {
        let request: SpeakRequest = serde_json::from_str(r#"{"text": "Hello world"}"#).unwrap();

        assert_eq!(request.text, "Hello world");
        assert_eq!(request.language, None);
        assert_eq!(request.voice_id, None);
        assert_eq!(request.rate, 1.0);
        assert_eq!(request.pitch, 1.0);
        assert_eq!(request.volume, 1.0);
    }

    #[test]
    fn test_speak_request_full() {
        let json = r#"{
            "text": "Olá",
            "language": "pt-BR",
            "voiceId": "com.apple.voice.enhanced.pt-BR",
            "rate": 0.8,
            "pitch": 1.2,
            "volume": 0.9
        }"#;

        let request: SpeakRequest = serde_json::from_str(json).unwrap();
        assert_eq!(request.text, "Olá");
        assert_eq!(request.language.as_deref(), Some("pt-BR"));
        assert_eq!(
            request.voice_id.as_deref(),
            Some("com.apple.voice.enhanced.pt-BR")
        );
        assert_eq!(request.rate, 0.8);
        assert_eq!(request.pitch, 1.2);
        assert_eq!(request.volume, 0.9);
    }

    #[test]
    fn test_voice_serialization() {
        let voice = Voice {
            id: "test-voice".to_string(),
            name: "Test Voice".to_string(),
            language: "en-US".to_string(),
        };

        let json = serde_json::to_string(&voice).unwrap();
        for expected in [
            "\"id\":\"test-voice\"",
            "\"name\":\"Test Voice\"",
            "\"language\":\"en-US\"",
        ] {
            assert!(json.contains(expected));
        }
    }

    #[test]
    fn test_get_voices_request_optional_language() {
        let without_filter: GetVoicesRequest = serde_json::from_str(r#"{}"#).unwrap();
        assert!(without_filter.language.is_none());

        let with_filter: GetVoicesRequest = serde_json::from_str(r#"{"language": "en"}"#).unwrap();
        assert_eq!(with_filter.language, Some("en".to_string()));
    }

    #[test]
    fn test_validation_empty_text() {
        let outcome = neutral_request("").validate();
        assert!(matches!(outcome, Err(ValidationError::EmptyText)));
    }

    #[test]
    fn test_validation_text_too_long() {
        let oversized = "x".repeat(MAX_TEXT_LENGTH + 1);
        let outcome = neutral_request(&oversized).validate();
        assert!(matches!(
            outcome,
            Err(ValidationError::TextTooLong { .. })
        ));
    }

    #[test]
    fn test_validation_valid_voice_id() {
        let request = SpeakRequest {
            voice_id: Some("com.apple.voice.enhanced.en-US".to_string()),
            ..neutral_request("Hello")
        };

        let validated = request.validate().expect("a well-formed voice ID is accepted");
        assert_eq!(
            validated.voice_id,
            Some("com.apple.voice.enhanced.en-US".to_string())
        );
    }

    #[test]
    fn test_validation_voice_id_too_long() {
        let request = SpeakRequest {
            voice_id: Some("x".repeat(MAX_VOICE_ID_LENGTH + 1)),
            ..neutral_request("Hello")
        };

        assert!(matches!(
            request.validate(),
            Err(ValidationError::VoiceIdTooLong { .. })
        ));
    }

    #[test]
    fn test_validation_rate_clamping() {
        let request = SpeakRequest {
            rate: 999.0,
            ..neutral_request("Hello")
        };

        let validated = request.validate().expect("out-of-range rate is clamped, not rejected");
        assert_eq!(validated.rate, 4.0); // Clamped to max
    }

    #[test]
    fn test_validation_pitch_clamping() {
        let request = SpeakRequest {
            pitch: 0.1,
            ..neutral_request("Hello")
        };

        let validated = request.validate().expect("out-of-range pitch is clamped, not rejected");
        assert_eq!(validated.pitch, 0.5); // Clamped to min
    }

    #[test]
    fn test_validation_volume_clamping() {
        let request = SpeakRequest {
            volume: 5.0,
            ..neutral_request("Hello")
        };

        let validated = request.validate().expect("out-of-range volume is clamped, not rejected");
        assert_eq!(validated.volume, 1.0); // Clamped to max
    }

    #[test]
    fn test_preview_voice_validation() {
        // A well-formed voice ID is accepted.
        let accepted = PreviewVoiceRequest {
            voice_id: "valid-voice_123".to_string(),
            text: None,
        };
        assert!(accepted.validate().is_ok());

        // Characters outside the allowed set are rejected.
        let rejected = PreviewVoiceRequest {
            voice_id: "invalid<script>".to_string(),
            text: None,
        };
        assert!(rejected.validate().is_err());
    }

    #[test]
    fn test_preview_voice_sample_text() {
        let without_text = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: None,
        };
        assert_eq!(
            without_text.sample_text(),
            PreviewVoiceRequest::DEFAULT_SAMPLE_TEXT
        );

        let with_text = PreviewVoiceRequest {
            voice_id: "voice".to_string(),
            text: Some("Custom sample".to_string()),
        };
        assert_eq!(with_text.sample_text(), "Custom sample");
    }
}