pocket_tts_cli/commands/generate.rs

1//! Generate command implementation
2//!
3//! Provides `pocket-tts generate` for text-to-speech synthesis.
4
5use anyhow::Result;
6use clap::Parser;
7use indicatif::{ProgressBar, ProgressStyle};
8use owo_colors::OwoColorize;
9use pocket_tts::TTSModel;
10use std::path::PathBuf;
11
12use crate::voice::{PREDEFINED_VOICES, resolve_voice};
13
/// Default text shown when user runs without --text
pub const DEFAULT_TEXT: &str =
    "Hello world! I am Pocket TTS, running blazingly fast in Rust. I hope you'll like me.";

// NOTE: the `///` doc comments on the fields below double as clap's --help
// text, so their wording is user-facing. Edit with care.
#[derive(Parser, Debug)]
pub struct GenerateArgs {
    /// Text to synthesize (defaults to a greeting if not specified)
    #[arg(short, long, default_value = DEFAULT_TEXT)]
    pub text: String,

    /// Voice for synthesis. Can be:
    /// - Predefined name: alba, marius, javert, jean, fantine, cosette, eponine, azelma
    /// - Path to .wav file for voice cloning
    /// - Path to .safetensors embeddings file
    /// - HuggingFace URL: hf://owner/repo/file.wav
    // Resolved by `crate::voice::resolve_voice`; `None` falls back to the
    // default voice ("alba" per the display string in `run`).
    #[arg(short, long)]
    pub voice: Option<String>,

    /// Output audio file path
    #[arg(short, long, default_value = "output.wav")]
    pub output: PathBuf,

    /// Model variant (default: b6369a24)
    #[arg(long, default_value = "b6369a24")]
    pub variant: String,

    /// Sampling temperature (higher = more variation)
    #[arg(long, default_value = "0.7")]
    pub temperature: f32,

    /// LSD decode steps (more steps = better quality, slower)
    #[arg(long, default_value = "1")]
    pub lsd_decode_steps: usize,

    /// EOS threshold (more negative = longer audio)
    #[arg(long, default_value = "-4.0")]
    pub eos_threshold: f32,

    /// Noise clamp value (optional)
    // Passed straight through to the model loaders; semantics are defined by
    // `pocket_tts` — TODO confirm units/range against the library docs.
    #[arg(long)]
    pub noise_clamp: Option<f32>,

    /// Frames to generate after EOS detection (optional, auto-estimated if not set)
    #[arg(long)]
    pub frames_after_eos: Option<usize>,

    /// Stream raw PCM audio to stdout (for piping to audio players)
    // When set, `run` forces quiet mode so status text never mixes into the
    // PCM byte stream on stdout.
    #[arg(long)]
    pub stream: bool,

    /// Use simulated int8 quantization for inference
    // Requires the crate to be built with `--features quantized` (see `run`).
    #[arg(long)]
    pub quantized: bool,

    /// Use Metal acceleration (macOS only)
    // Requires the crate to be built with `--features metal` (see `run`).
    #[arg(long)]
    pub use_metal: bool,

    /// Suppress all output except errors
    #[arg(short, long)]
    pub quiet: bool,
}
76
/// Print styled message (respects quiet mode)
///
/// Expands to a `println!` guarded on the quiet flag: when `$quiet` is
/// `true`, nothing is printed and the format arguments are not evaluated.
macro_rules! info {
    ($quiet:expr, $($arg:tt)*) => {
        match $quiet {
            false => println!($($arg)*),
            true => {}
        }
    };
}
85
86pub fn run(args: GenerateArgs) -> Result<()> {
87    let quiet = args.quiet || args.stream;
88
89    // Print banner
90    if !quiet {
91        print_banner();
92    }
93
94    // Set up device
95    let device = if args.use_metal {
96        #[cfg(feature = "metal")]
97        {
98            candle_core::Device::new_metal(0)?
99        }
100        #[cfg(not(feature = "metal"))]
101        {
102            anyhow::bail!("Metal feature not enabled. Rebuild with --features metal");
103        }
104    } else {
105        candle_core::Device::Cpu
106    };
107
108    if !quiet {
109        println!("  {} Using device: {:?}", "▶".cyan(), device);
110    }
111
112    // Load model
113    info!(quiet, "{} Loading model...", "▶".cyan());
114
115    let quantized = args.quantized;
116
117    let model = if quantized {
118        #[cfg(feature = "quantized")]
119        {
120            TTSModel::load_quantized_with_params_device(
121                &args.variant,
122                args.temperature,
123                args.lsd_decode_steps,
124                args.eos_threshold,
125                args.noise_clamp,
126                &device,
127            )?
128        }
129        #[cfg(not(feature = "quantized"))]
130        {
131            anyhow::bail!("Quantization feature not enabled. Rebuild with --features quantized");
132        }
133    } else {
134        TTSModel::load_with_params_device(
135            &args.variant,
136            args.temperature,
137            args.lsd_decode_steps,
138            args.eos_threshold,
139            args.noise_clamp,
140            &device,
141        )?
142    };
143
144    info!(
145        quiet,
146        "  {} Model loaded (sample rate: {}Hz)",
147        "✓".green(),
148        model.sample_rate
149    );
150
151    // Resolve voice
152    let voice_display = args.voice.as_deref().unwrap_or("alba (default)");
153    info!(
154        quiet,
155        "{} Using voice: {}",
156        "▶".cyan(),
157        voice_display.yellow()
158    );
159
160    let voice_state = resolve_voice(&model, args.voice.as_deref())?;
161
162    info!(quiet, "  {} Voice ready", "✓".green());
163
164    // Generate
165    if args.stream {
166        run_streaming(&model, &args.text, &voice_state)
167    } else {
168        run_to_file(&model, &args, &voice_state, quiet)
169    }
170}
171
172/// Run streaming generation to stdout
173fn run_streaming(model: &TTSModel, text: &str, voice_state: &pocket_tts::ModelState) -> Result<()> {
174    use std::io::Write;
175    let mut stdout = std::io::stdout();
176
177    for chunk_res in model.generate_stream_long(text, voice_state) {
178        let chunk = chunk_res?;
179        // Convert tensor to 16-bit PCM
180        let chunk = chunk.squeeze(0)?;
181        let data = chunk.to_vec2::<f32>()?;
182
183        for (i, _) in data[0].iter().enumerate() {
184            for channel_data in &data {
185                // Hard clamp to [-1, 1] to match Python's behavior
186                let val = channel_data[i].clamp(-1.0, 1.0);
187                let val = (val * 32767.0) as i16;
188                stdout.write_all(&val.to_le_bytes())?;
189            }
190        }
191        stdout.flush()?;
192    }
193
194    Ok(())
195}
196
/// Run generation to file with progress bar
///
/// Streams chunks from the model, tracks progress against an estimated step
/// count, concatenates all chunks, and writes a WAV to `args.output`.
fn run_to_file(
    model: &TTSModel,
    args: &GenerateArgs,
    voice_state: &pocket_tts::ModelState,
    quiet: bool,
) -> Result<()> {
    use candle_core::Tensor;

    info!(
        quiet,
        "{} Generating: \"{}\"",
        "▶".cyan(),
        truncate_text(&args.text, 60).italic()
    );

    // Estimated, not exact: the bar length is a guess so `pos` may not reach
    // `len` precisely when generation ends early or runs long.
    let total_steps = model.estimate_generation_steps(&args.text) as u64;

    // Hidden bar in quiet mode keeps the inc()/set_message() calls below
    // unconditional.
    let pb = if quiet {
        ProgressBar::hidden()
    } else {
        let pb = ProgressBar::new(total_steps);
        pb.set_style(
            ProgressStyle::default_bar()
                .template(
                    "{spinner:.cyan} [{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}",
                )
                .unwrap()
                .progress_chars("█▓░"),
        );
        pb.set_message("generating...");
        pb
    };

    let mut audio_chunks = Vec::new();
    let mut total_samples = 0;

    for chunk_res in model.generate_stream_long(&args.text, voice_state) {
        let chunk = chunk_res?;
        // Samples are assumed to live on the last axis: dims[1] for a 2-D
        // chunk, dims[0] for 1-D. NOTE(review): the cat(…, 2) below implies
        // chunks are normally 3-D (batch, channels, samples) — the 2-D branch
        // looks defensive; confirm against generate_stream_long's contract.
        let dims = chunk.dims();
        let samples = if dims.len() == 2 { dims[1] } else { dims[0] };
        total_samples += samples;

        audio_chunks.push(chunk);
        pb.inc(1);
        pb.set_message(format!(
            "{:.2}s generated",
            total_samples as f32 / model.sample_rate as f32
        ));
    }

    pb.finish_and_clear();

    // Concatenate all audio chunks
    if audio_chunks.is_empty() {
        anyhow::bail!("No audio generated - text may be too short or invalid");
    }
    // Concatenate along the sample axis (dim 2), then drop the batch dim.
    let audio = Tensor::cat(&audio_chunks, 2)?;
    let audio = audio.squeeze(0)?; // Remove batch dimension

    let dims = audio.dims();
    let num_samples = if dims.len() == 2 { dims[1] } else { dims[0] };
    let duration_sec = num_samples as f32 / model.sample_rate as f32;

    // Save to file
    info!(
        quiet,
        "{} Saving to: {}",
        "▶".cyan(),
        args.output.display().yellow()
    );
    pocket_tts::audio::write_wav(&args.output, &audio, model.sample_rate as u32)?;

    // Success message
    if !quiet {
        println!();
        println!(
            "  {} {}",
            "✓".green().bold(),
            "Audio generated successfully!".green().bold()
        );
        println!(
            "    Duration: {:.2}s ({} samples @ {}Hz)",
            duration_sec, num_samples, model.sample_rate
        );
        println!("    Output:   {}", args.output.display().cyan());
        println!();
        println!(
            "  {} {}",
            "💡".dimmed(),
            format!("Play with: ffplay -autoexit {:?}", args.output).dimmed()
        );
    }

    Ok(())
}
293
294/// Print startup banner
295fn print_banner() {
296    println!();
297    println!("  {}  {}", "🗣️".bold(), "Pocket TTS".bold().cyan());
298    println!(
299        "      {} {}",
300        "Rust/Candle port".dimmed(),
301        format!("v{}", env!("CARGO_PKG_VERSION")).dimmed()
302    );
303    println!();
304}
305
/// Truncate text for display
///
/// Limits `text` to `max_len` characters, appending `"..."` when truncated.
/// Operates on `char` boundaries: the previous byte-slice version
/// (`&text[..max_len - 3]`) panicked when the cut landed inside a multi-byte
/// UTF-8 character, and underflowed when `max_len < 3`.
fn truncate_text(text: &str, max_len: usize) -> String {
    if text.chars().count() <= max_len {
        text.to_string()
    } else {
        // saturating_sub guards max_len < 3; keep room for the ellipsis.
        let keep = max_len.saturating_sub(3);
        let head: String = text.chars().take(keep).collect();
        format!("{}...", head)
    }
}
314
315/// Print available voices (for help text)
316pub fn available_voices_help() -> String {
317    format!("Predefined voices: {}", PREDEFINED_VOICES.join(", "))
318}