1use anyhow::Result;
6use clap::Parser;
7use indicatif::{ProgressBar, ProgressStyle};
8use owo_colors::OwoColorize;
9use pocket_tts::TTSModel;
10use std::path::PathBuf;
11
12use crate::voice::{PREDEFINED_VOICES, resolve_voice};
13
/// Sentence synthesized when no `--text` value is given on the command line
/// (it is the `default_value` of [`GenerateArgs::text`]).
pub const DEFAULT_TEXT: &str = "Hello world! I am Pocket TTS, running blazingly fast in Rust. \
                                I hope you'll like me.";
17
18#[derive(Parser, Debug)]
19pub struct GenerateArgs {
20 #[arg(short, long, default_value = DEFAULT_TEXT)]
22 pub text: String,
23
24 #[arg(short, long)]
30 pub voice: Option<String>,
31
32 #[arg(short, long, default_value = "output.wav")]
34 pub output: PathBuf,
35
36 #[arg(long, default_value = "b6369a24")]
38 pub variant: String,
39
40 #[arg(long, default_value = "0.7")]
42 pub temperature: f32,
43
44 #[arg(long, default_value = "1")]
46 pub lsd_decode_steps: usize,
47
48 #[arg(long, default_value = "-4.0")]
50 pub eos_threshold: f32,
51
52 #[arg(long)]
54 pub noise_clamp: Option<f32>,
55
56 #[arg(long)]
58 pub frames_after_eos: Option<usize>,
59
60 #[arg(long)]
62 pub stream: bool,
63
64 #[arg(long)]
66 pub quantized: bool,
67
68 #[arg(long)]
70 pub use_metal: bool,
71
72 #[arg(short, long)]
74 pub quiet: bool,
75}
76
/// Prints a status line with `println!` unless the first argument is `true`.
///
/// Usage: `info!(quiet, "format {}", value)`. The format arguments are only
/// evaluated when the line is actually printed.
macro_rules! info {
    ($silenced:expr, $($fmt_args:tt)*) => {
        if !$silenced {
            println!($($fmt_args)*);
        }
    };
}
85
86pub fn run(args: GenerateArgs) -> Result<()> {
87 let quiet = args.quiet || args.stream;
88
89 if !quiet {
91 print_banner();
92 }
93
94 let device = if args.use_metal {
96 #[cfg(feature = "metal")]
97 {
98 candle_core::Device::new_metal(0)?
99 }
100 #[cfg(not(feature = "metal"))]
101 {
102 anyhow::bail!("Metal feature not enabled. Rebuild with --features metal");
103 }
104 } else {
105 candle_core::Device::Cpu
106 };
107
108 if !quiet {
109 println!(" {} Using device: {:?}", "▶".cyan(), device);
110 }
111
112 info!(quiet, "{} Loading model...", "▶".cyan());
114
115 let quantized = args.quantized;
116
117 let model = if quantized {
118 #[cfg(feature = "quantized")]
119 {
120 TTSModel::load_quantized_with_params_device(
121 &args.variant,
122 args.temperature,
123 args.lsd_decode_steps,
124 args.eos_threshold,
125 args.noise_clamp,
126 &device,
127 )?
128 }
129 #[cfg(not(feature = "quantized"))]
130 {
131 anyhow::bail!("Quantization feature not enabled. Rebuild with --features quantized");
132 }
133 } else {
134 TTSModel::load_with_params_device(
135 &args.variant,
136 args.temperature,
137 args.lsd_decode_steps,
138 args.eos_threshold,
139 args.noise_clamp,
140 &device,
141 )?
142 };
143
144 info!(
145 quiet,
146 " {} Model loaded (sample rate: {}Hz)",
147 "✓".green(),
148 model.sample_rate
149 );
150
151 let voice_display = args.voice.as_deref().unwrap_or("alba (default)");
153 info!(
154 quiet,
155 "{} Using voice: {}",
156 "▶".cyan(),
157 voice_display.yellow()
158 );
159
160 let voice_state = resolve_voice(&model, args.voice.as_deref())?;
161
162 info!(quiet, " {} Voice ready", "✓".green());
163
164 if args.stream {
166 run_streaming(&model, &args.text, &voice_state)
167 } else {
168 run_to_file(&model, &args, &voice_state, quiet)
169 }
170}
171
172fn run_streaming(model: &TTSModel, text: &str, voice_state: &pocket_tts::ModelState) -> Result<()> {
174 use std::io::Write;
175 let mut stdout = std::io::stdout();
176
177 for chunk_res in model.generate_stream_long(text, voice_state) {
178 let chunk = chunk_res?;
179 let chunk = chunk.squeeze(0)?;
181 let data = chunk.to_vec2::<f32>()?;
182
183 for (i, _) in data[0].iter().enumerate() {
184 for channel_data in &data {
185 let val = channel_data[i].clamp(-1.0, 1.0);
187 let val = (val * 32767.0) as i16;
188 stdout.write_all(&val.to_le_bytes())?;
189 }
190 }
191 stdout.flush()?;
192 }
193
194 Ok(())
195}
196
197fn run_to_file(
199 model: &TTSModel,
200 args: &GenerateArgs,
201 voice_state: &pocket_tts::ModelState,
202 quiet: bool,
203) -> Result<()> {
204 use candle_core::Tensor;
205
206 info!(
207 quiet,
208 "{} Generating: \"{}\"",
209 "▶".cyan(),
210 truncate_text(&args.text, 60).italic()
211 );
212
213 let total_steps = model.estimate_generation_steps(&args.text) as u64;
214
215 let pb = if quiet {
216 ProgressBar::hidden()
217 } else {
218 let pb = ProgressBar::new(total_steps);
219 pb.set_style(
220 ProgressStyle::default_bar()
221 .template(
222 "{spinner:.cyan} [{elapsed_precise}] {bar:40.cyan/blue} {pos}/{len} {msg}",
223 )
224 .unwrap()
225 .progress_chars("█▓░"),
226 );
227 pb.set_message("generating...");
228 pb
229 };
230
231 let mut audio_chunks = Vec::new();
232 let mut total_samples = 0;
233
234 for chunk_res in model.generate_stream_long(&args.text, voice_state) {
235 let chunk = chunk_res?;
236 let dims = chunk.dims();
237 let samples = if dims.len() == 2 { dims[1] } else { dims[0] };
238 total_samples += samples;
239
240 audio_chunks.push(chunk);
241 pb.inc(1);
242 pb.set_message(format!(
243 "{:.2}s generated",
244 total_samples as f32 / model.sample_rate as f32
245 ));
246 }
247
248 pb.finish_and_clear();
249
250 if audio_chunks.is_empty() {
252 anyhow::bail!("No audio generated - text may be too short or invalid");
253 }
254 let audio = Tensor::cat(&audio_chunks, 2)?;
255 let audio = audio.squeeze(0)?; let dims = audio.dims();
258 let num_samples = if dims.len() == 2 { dims[1] } else { dims[0] };
259 let duration_sec = num_samples as f32 / model.sample_rate as f32;
260
261 info!(
263 quiet,
264 "{} Saving to: {}",
265 "▶".cyan(),
266 args.output.display().yellow()
267 );
268 pocket_tts::audio::write_wav(&args.output, &audio, model.sample_rate as u32)?;
269
270 if !quiet {
272 println!();
273 println!(
274 " {} {}",
275 "✓".green().bold(),
276 "Audio generated successfully!".green().bold()
277 );
278 println!(
279 " Duration: {:.2}s ({} samples @ {}Hz)",
280 duration_sec, num_samples, model.sample_rate
281 );
282 println!(" Output: {}", args.output.display().cyan());
283 println!();
284 println!(
285 " {} {}",
286 "💡".dimmed(),
287 format!("Play with: ffplay -autoexit {:?}", args.output).dimmed()
288 );
289 }
290
291 Ok(())
292}
293
294fn print_banner() {
296 println!();
297 println!(" {} {}", "🗣️".bold(), "Pocket TTS".bold().cyan());
298 println!(
299 " {} {}",
300 "Rust/Candle port".dimmed(),
301 format!("v{}", env!("CARGO_PKG_VERSION")).dimmed()
302 );
303 println!();
304}
305
/// Shortens `text` to at most `max_len` characters for display.
///
/// Text that already fits is returned unchanged. Longer text is cut and ends
/// in `...`, with the ellipsis counted towards `max_len`. When `max_len` is
/// smaller than the ellipsis itself, the text is cut to `max_len` characters
/// with no ellipsis.
///
/// Length is measured in `char`s and cuts fall on character boundaries, so
/// multi-byte UTF-8 text never causes a slicing panic.
fn truncate_text(text: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // `nth(max_len)` is `None` exactly when the text has at most `max_len` chars.
    if text.char_indices().nth(max_len).is_none() {
        return text.to_string();
    }

    let (keep, suffix) = match max_len.checked_sub(ELLIPSIS.len()) {
        Some(keep) => (keep, ELLIPSIS),
        // No room for the ellipsis: hard cut.
        None => (max_len, ""),
    };

    // Byte offset at which the `keep`-th character starts.
    let end = text
        .char_indices()
        .nth(keep)
        .map_or(text.len(), |(idx, _)| idx);

    format!("{}{}", &text[..end], suffix)
}
314
315pub fn available_voices_help() -> String {
317 format!("Predefined voices: {}", PREDEFINED_VOICES.join(", "))
318}