use std::io::{self, IsTerminal, Write};
use std::path::Path;
use std::sync::{
Arc, Mutex,
atomic::{AtomicBool, Ordering},
};
use std::thread;
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow, bail};
use ct2rs::{ComputeType, Config, Device};
use futures_util::StreamExt;
use indicatif::{ProgressBar, ProgressStyle};
use reqwest::Client;
use tempfile::NamedTempFile;
use tokio::io::AsyncWriteExt;
use tokio::sync::watch;
use tokio::time::{Duration as TokioDuration, sleep};
use url::Url;
use crate::audio::{
PreparedAudio, RawAudioEncoding, RawAudioSpec, RawAudioStreamDecoder, inspect_audio_file,
inspect_raw_audio_file,
};
use crate::micro::MicrophoneCapture;
use crate::model::ModelChoice;
use crate::stream::{StreamConfig, StreamEngine};
use crate::whisper::{Whisper, WhisperChunkResult, WhisperOptions};
/// Resolved inference placement (CPU vs CUDA) plus the compute/quantization
/// type to request from CTranslate2.
#[derive(Clone, Debug)]
pub struct ExecutionMode {
    // Backend device selected for inference.
    device: Device,
    // CTranslate2 compute type (INT8 on CPU, AUTO on CUDA per `from_cli`).
    compute_type: ComputeType,
    // CUDA device index when running on GPU; `None` on CPU.
    gpu_device: Option<i32>,
    // Set only when a GPU was requested but the run fell back to CPU.
    warning: Option<String>,
}
impl ExecutionMode {
    /// Resolves the execution mode from CLI flags.
    ///
    /// With `use_gpu` and the `gpu` feature compiled in: validates that CUDA
    /// devices are visible and that `gpu_device` is in range, falling back to
    /// CPU (with a warning) when no CUDA device is available. Without the
    /// `gpu` feature, `--gpu` is a hard error. The plain CPU path uses INT8.
    pub fn from_cli(use_gpu: bool, gpu_device: i32) -> Result<Self> {
        if use_gpu {
            #[cfg(feature = "gpu")]
            {
                let visible_devices = ct2rs::sys::get_device_count(Device::CUDA);
                if visible_devices <= 0 {
                    // GPU requested but CUDA reports no devices: degrade to
                    // CPU instead of failing, and record why for display.
                    return Ok(Self {
                        device: Device::CPU,
                        compute_type: ComputeType::INT8,
                        gpu_device: None,
                        warning: Some(
                            "CUDA backend is not available in this environment; falling back to CPU"
                                .to_string(),
                        ),
                    });
                }
                if gpu_device < 0 || gpu_device >= visible_devices {
                    bail!(
                        "CUDA device {} is out of range; visible devices: 0..{}",
                        gpu_device,
                        visible_devices - 1
                    );
                }
                return Ok(Self {
                    device: Device::CUDA,
                    // Let CTranslate2 pick the best compute type for the GPU.
                    compute_type: ComputeType::AUTO,
                    gpu_device: Some(gpu_device),
                    warning: None,
                });
            }
            #[cfg(not(feature = "gpu"))]
            {
                let _ = gpu_device;
                bail!("--gpu not supported");
            }
        }
        // Default CPU path: INT8 quantization keeps memory and latency low.
        Ok(Self {
            device: Device::CPU,
            compute_type: ComputeType::INT8,
            gpu_device: None,
            warning: None,
        })
    }
    /// Short lowercase device label for status output.
    pub fn device_label(&self) -> &'static str {
        match self.device {
            Device::CPU => "cpu",
            Device::CUDA => "cuda",
            _ => "unknown",
        }
    }
    /// Short lowercase compute-type label for status output.
    pub fn compute_type_label(&self) -> &'static str {
        match self.compute_type {
            ComputeType::DEFAULT => "default",
            ComputeType::AUTO => "auto",
            ComputeType::FLOAT32 => "float32",
            ComputeType::INT8 => "int8",
            ComputeType::INT8_FLOAT32 => "int8_float32",
            ComputeType::INT8_FLOAT16 => "int8_float16",
            ComputeType::INT8_BFLOAT16 => "int8_bfloat16",
            ComputeType::INT16 => "int16",
            ComputeType::FLOAT16 => "float16",
            ComputeType::BFLOAT16 => "bfloat16",
            _ => "unknown",
        }
    }
    /// CPU-fallback warning to surface to the user, if any.
    pub fn warning(&self) -> Option<&str> {
        self.warning.as_deref()
    }
    /// True when `--gpu` was requested but the run fell back to CPU
    /// (`warning` is only ever set on that path).
    pub fn gpu_requested_but_unavailable(&self) -> bool {
        self.warning.is_some()
    }
}
// ASCII spinner frames for the interactive stream display.
const STREAM_SPINNER_FRAMES: &[&str] = &["|", "/", "-", "\\"];
// Live mode chunks audio into 4 s windows overlapped by 1 s.
const LIVE_WINDOW_SECONDS: usize = 4;
const LIVE_OVERLAP_SECONDS: usize = 1;
// First decode attempt waits for 16 KiB of buffered media; afterwards the
// buffer is re-parsed for every additional 64 KiB received.
const LIVE_INITIAL_PROBE_BYTES: u64 = 16 * 1024;
const LIVE_REPARSE_BYTES: u64 = 64 * 1024;
/// Decoding options for one-shot (file) transcription: beam search width 5,
/// everything else at library defaults.
fn transcription_options() -> WhisperOptions {
    WhisperOptions {
        beam_size: 5,
        ..Default::default()
    }
}
/// Decoding options for live/microphone transcription: beam search plus
/// anti-repetition penalties, and no-speech probabilities enabled so the
/// live hallucination filter can consume them.
fn live_transcription_options() -> WhisperOptions {
    WhisperOptions {
        beam_size: 5,
        repetition_penalty: 1.15,
        no_repeat_ngram_size: 3,
        return_no_speech_prob: true,
        ..Default::default()
    }
}
/// Transcribes an already-decoded audio file, either in streaming display
/// mode (`stream == true`) or as a single batch decode.
///
/// Prints an "Audio parameters" header first, validates that the audio was
/// resampled to the model's expected rate, then either delegates to
/// `stream_transcription` or runs one `generate` call behind a spinner.
pub async fn run_transcription(
    audio: &PreparedAudio,
    model_choice: ModelChoice,
    model_dir: &Path,
    _models_root: &Path,
    execution: &ExecutionMode,
    stream: bool,
) -> Result<()> {
    let whisper = load_whisper_model(model_choice, model_dir, execution)?;
    let options = transcription_options();
    println!();
    println!("Audio parameters");
    println!(
        "backend: CTranslate2 / {} ({})",
        execution.device_label(),
        execution.compute_type_label()
    );
    if let Some(gpu_device) = execution.gpu_device {
        println!("gpu id : {gpu_device}");
    }
    println!("source : {}", audio.display_name);
    if let Some(duration) = audio.metadata.duration {
        println!("length : {}", format_duration(duration.as_secs_f64()));
    }
    if let Some(source_rate) = audio.metadata.source_sample_rate {
        println!("input rate : {source_rate} Hz");
    }
    println!("model rate : {} Hz", audio.metadata.target_sample_rate);
    if let Some(channels) = audio.metadata.channels {
        println!("channels : {channels} -> mono");
    }
    println!("codec : {}", audio.metadata.codec);
    println!("model path : {}", model_dir.display());
    println!();
    // Guard against a preparation/model mismatch before spending decode time.
    if whisper.sampling_rate() != audio.metadata.target_sample_rate as usize {
        anyhow::bail!(
            "audio was resampled to {} Hz but the model expects {} Hz",
            audio.metadata.target_sample_rate,
            whisper.sampling_rate()
        );
    }
    if stream {
        stream_transcription(&whisper, audio, &options)
    } else {
        println!("{}", "=".repeat(72));
        let processing = ProgressBar::new_spinner();
        processing.set_style(
            ProgressStyle::with_template(" processing {spinner:.green} {msg}")
                .context("failed to configure processing spinner")?,
        );
        processing.enable_steady_tick(Duration::from_millis(80));
        processing.set_message(audio.display_name.clone());
        // Single batch decode of the whole file.
        let lines = whisper
            .generate(&audio.samples, None, false, &options)
            .context("CTranslate2 transcription failed")?;
        processing.finish_with_message("processing complete");
        print_transcript_divider();
        for line in lines {
            let line = line.trim();
            if !line.is_empty() {
                println!("{line}");
            }
        }
        Ok(())
    }
}
/// Transcribes a live http(s) media stream incrementally.
///
/// Raw-PCM streams (detected from response headers) are decoded directly via
/// `run_raw_live_transcription`. Containerized streams are spooled to a temp
/// file by a background download task; this function polls the growing file,
/// re-decoding the whole snapshot whenever enough new bytes have arrived
/// (see `LIVE_INITIAL_PROBE_BYTES` / `LIVE_REPARSE_BYTES`), and feeds only
/// the newly decoded samples into the stream engine.
pub async fn run_live_transcription(
    input: &str,
    model_choice: ModelChoice,
    model_dir: &Path,
    _models_root: &Path,
    execution: &ExecutionMode,
) -> Result<()> {
    let url = Url::parse(input).with_context(|| format!("invalid live URL `{input}`"))?;
    if !matches!(url.scheme(), "http" | "https") {
        bail!("--live currently supports only http/https media URLs");
    }
    let whisper = load_whisper_model(model_choice, model_dir, execution)?;
    let options = live_transcription_options();
    let mut engine = StreamEngine::new(live_stream_config(
        whisper.sampling_rate(),
        whisper.n_samples(),
    )?);
    let overlap_seconds = engine.config().overlap_seconds();
    let (response, raw_audio) = connect_live_stream(&url).await?;
    println!();
    println!("Audio parameters");
    println!(
        "backend: CTranslate2 / {} ({})",
        execution.device_label(),
        execution.compute_type_label()
    );
    if let Some(gpu_device) = execution.gpu_device {
        println!("gpu id : {gpu_device}");
    }
    println!("source : {input}");
    println!("length : live");
    println!("model rate : {} Hz", whisper.sampling_rate());
    println!("model path : {}", model_dir.display());
    // Raw streams know their format up front; container formats are only
    // described once the first snapshot parses (below).
    if let Some(raw_audio) = raw_audio.as_ref() {
        println!("input rate : {} Hz", raw_audio.sample_rate);
        println!("channels : {} -> mono", raw_audio.channels);
        println!("codec : raw {:?}", raw_audio.encoding);
    }
    println!();
    let mut display = StreamDisplay::new(format!(
        "live stream / 0 KiB buffered / {overlap_seconds}s overlap"
    ))?;
    if let Some(raw_audio) = raw_audio {
        return run_raw_live_transcription(
            response, raw_audio, &whisper, &options, engine, display,
        )
        .await;
    }
    let live_source = start_live_buffer(response).await?;
    let temp_file = live_source.temp_file;
    let mut status_rx = live_source.status_rx;
    // Keep the handle alive so the download task is not detached/aborted.
    let _download_task = live_source.download_task;
    let mut last_buffered_bytes = 0u64;
    // Samples already pushed into the engine from previous snapshots.
    let mut last_sample_count = 0usize;
    let mut metadata_printed = false;
    // Consecutive decode failures; tolerated while the container header may
    // still be incomplete.
    let mut parse_failures = 0usize;
    let mut last_probe_at = Instant::now();
    loop {
        let status = status_rx.borrow().clone();
        if let Some(error) = status.error {
            return Err(anyhow!("live stream download failed: {error}"));
        }
        let buffered_delta = status.bytes_received.saturating_sub(last_buffered_bytes);
        // Re-decode when the stream ended, or when enough new data arrived.
        let should_poll = status.finished
            || (status.bytes_received > 0
                && ((last_buffered_bytes == 0 && buffered_delta >= LIVE_INITIAL_PROBE_BYTES)
                    || buffered_delta >= LIVE_REPARSE_BYTES));
        if should_poll && status.bytes_received > 0 {
            last_probe_at = Instant::now();
            match inspect_live_snapshot(temp_file.path(), whisper.sampling_rate() as u32, None)
                .await
            {
                Ok((metadata, samples)) => {
                    parse_failures = 0;
                    // Print the deferred format lines on first successful parse.
                    if !metadata_printed {
                        if let Some(source_rate) = metadata.source_sample_rate {
                            println!("input rate : {source_rate} Hz");
                        }
                        if let Some(channels) = metadata.channels {
                            println!("channels : {channels} -> mono");
                        }
                        println!("codec : {}", metadata.codec);
                        println!();
                        metadata_printed = true;
                    }
                    // Each snapshot must be a superset of the previous one.
                    if samples.len() < last_sample_count {
                        bail!("live stream sample count moved backwards");
                    }
                    // Feed only the samples decoded since the last snapshot.
                    for chunk in engine.push_audio(&samples[last_sample_count..]) {
                        display.set_status(format!(
                            "live chunk {} / {} KiB buffered / {}s overlap",
                            chunk.index + 1,
                            status.bytes_received / 1024,
                            overlap_seconds
                        ))?;
                        let raw_text = transcribe_live_chunk_text(
                            &whisper,
                            &chunk.samples,
                            &options,
                            "live streaming chunk transcription failed",
                        )?;
                        if let Some(stable_text) = raw_text
                            .as_deref()
                            .and_then(|raw_text| engine.stabilize_text(raw_text))
                        {
                            display.push_text(&stable_text)?;
                        }
                    }
                    last_sample_count = samples.len();
                    last_buffered_bytes = status.bytes_received;
                    if status.finished {
                        // Flush the engine's trailing partial window.
                        if let Some(chunk) = engine.finish_audio() {
                            display.set_status(format!(
                                "live chunk {} / {} KiB buffered / finalizing",
                                chunk.index + 1,
                                status.bytes_received / 1024
                            ))?;
                            let raw_text = transcribe_live_chunk_text(
                                &whisper,
                                &chunk.samples,
                                &options,
                                "live streaming final chunk transcription failed",
                            )?;
                            if let Some(stable_text) = raw_text
                                .as_deref()
                                .and_then(|raw_text| engine.stabilize_text(raw_text))
                            {
                                display.push_text(&stable_text)?;
                            }
                        }
                        if let Some(final_text) = engine.finish_text() {
                            display.push_text(&final_text)?;
                        }
                        return display.finish("live stream complete");
                    }
                }
                Err(error) => {
                    if status.finished {
                        return Err(error).context("failed to decode final live media snapshot");
                    }
                    parse_failures += 1;
                    last_buffered_bytes = status.bytes_received;
                    // Give up once plenty of data repeatedly failed to parse.
                    if status.bytes_received >= 512 * 1024 && parse_failures >= 3 {
                        return Err(error).context("failed to decode buffered live media stream");
                    }
                }
            }
        }
        if status.finished && status.bytes_received == 0 {
            return Err(anyhow!(
                "live stream ended before any media data was received"
            ));
        }
        // Idle politely: wait on the watch channel when nothing new arrived
        // recently, otherwise back off briefly before the next poll.
        if last_buffered_bytes == status.bytes_received
            && last_probe_at.elapsed() < TokioDuration::from_secs(2)
        {
            let _ = status_rx.changed().await;
        } else {
            sleep(TokioDuration::from_millis(250)).await;
        }
    }
}
/// Transcribes audio captured from the default microphone.
///
/// Captures mono audio at the model's sampling rate and feeds it through the
/// same stream engine + hallucination filter as live URLs. Runs until
/// `read_samples` returns an error (e.g. the device goes away) or the process
/// is interrupted — there is no in-band stop condition, so `finish_text` is
/// never flushed here by design. NOTE(review): presumably `read_samples`
/// blocks until data is available — confirm in `MicrophoneCapture`.
pub async fn run_microphone_transcription(
    model_choice: ModelChoice,
    model_dir: &Path,
    _models_root: &Path,
    execution: &ExecutionMode,
) -> Result<()> {
    let whisper = load_whisper_model(model_choice, model_dir, execution)?;
    let options = live_transcription_options();
    let mut engine = StreamEngine::new(live_stream_config(
        whisper.sampling_rate(),
        whisper.n_samples(),
    )?);
    let overlap_seconds = engine.config().overlap_seconds();
    let mut microphone = MicrophoneCapture::open_default(whisper.sampling_rate() as u32, 1)?;
    println!();
    println!("Audio parameters");
    println!(
        "backend: CTranslate2 / {} ({})",
        execution.device_label(),
        execution.compute_type_label()
    );
    if let Some(gpu_device) = execution.gpu_device {
        println!("gpu id : {gpu_device}");
    }
    println!("source : microphone");
    println!("length : live");
    println!("model rate : {} Hz", whisper.sampling_rate());
    println!("channels : 1 -> mono");
    println!("codec : native S16Le");
    println!("model path : {}", model_dir.display());
    println!();
    let mut display = StreamDisplay::new(format!(
        "microphone / listening / {overlap_seconds}s overlap"
    ))?;
    // Running total of captured samples, for the elapsed-time status line.
    let mut captured_samples = 0usize;
    loop {
        let samples = microphone.read_samples()?;
        captured_samples += samples.len();
        for chunk in engine.push_audio(&samples) {
            display.set_status(format!(
                "microphone chunk {} / {:.1}s captured / {}s overlap",
                chunk.index + 1,
                captured_samples as f64 / whisper.sampling_rate() as f64,
                overlap_seconds
            ))?;
            let raw_text = transcribe_live_chunk_text(
                &whisper,
                &chunk.samples,
                &options,
                "microphone chunk transcription failed",
            )?;
            if let Some(stable_text) = raw_text
                .as_deref()
                .and_then(|raw_text| engine.stabilize_text(raw_text))
            {
                display.push_text(&stable_text)?;
            }
        }
    }
}
/// Loads the CTranslate2 Whisper model from `model_dir`, showing a spinner
/// while the (potentially slow) initialization runs.
///
/// `model_choice` is only used to name the model in error messages.
fn load_whisper_model(
    model_choice: ModelChoice,
    model_dir: &Path,
    execution: &ExecutionMode,
) -> Result<Whisper> {
    let progress = ProgressBar::new_spinner();
    progress.set_style(
        ProgressStyle::with_template(" loading model {spinner:.green} {msg}")
            .context("failed to configure model loading spinner")?,
    );
    progress.enable_steady_tick(Duration::from_millis(80));
    progress.set_message(model_dir.display().to_string());
    let whisper = Whisper::new(model_dir, ctranslate2_config(execution)).with_context(|| {
        format!(
            "failed to initialize CTranslate2 model `{}` from `{}`",
            model_choice.cli_name(),
            model_dir.display()
        )
    })?;
    progress.finish_with_message("model loaded");
    Ok(whisper)
}
/// Snapshot of the background download task's progress, published through a
/// `watch` channel.
#[derive(Clone, Debug, Default)]
struct LiveBufferStatus {
    // Total bytes appended to the temp buffer so far.
    bytes_received: u64,
    // True once the HTTP stream ended cleanly.
    finished: bool,
    // Fatal download/write error, if any (terminates the consumer loop).
    error: Option<String>,
}
/// Handles returned by `start_live_buffer`: the spooled media file, the
/// progress channel, and the download task (kept alive by the caller).
#[derive(Debug)]
struct LiveSourceInfo {
    // Temp file the download task appends raw media bytes to; deleted on drop.
    temp_file: NamedTempFile,
    status_rx: watch::Receiver<LiveBufferStatus>,
    download_task: tokio::task::JoinHandle<()>,
}
/// Opens the HTTP(S) media stream and sniffs the response headers for a raw
/// PCM description (see `raw_audio_spec_from_headers`).
///
/// Returns the live response body (not yet consumed) and, when detected, the
/// raw-audio spec; `None` means the payload is a container format.
async fn connect_live_stream(url: &Url) -> Result<(reqwest::Response, Option<RawAudioSpec>)> {
    let client = Client::builder()
        .user_agent("transcribe-cli/0.1.0")
        .build()
        .context("failed to build HTTP client for live stream")?;
    let response = client
        .get(url.clone())
        .send()
        .await
        .with_context(|| format!("failed to connect to live media stream `{url}`"))?
        // Turn 4xx/5xx statuses into errors up front.
        .error_for_status()
        .with_context(|| format!("live media stream returned an error for `{url}`"))?;
    let raw_audio = raw_audio_spec_from_headers(&response);
    Ok((response, raw_audio))
}
/// Spawns a task that spools the HTTP body into a named temp file, reporting
/// progress through a `watch` channel.
///
/// The consumer reads decode snapshots from `temp_file.path()` while the task
/// appends to the same path through a second file handle. Errors (open,
/// write, or network) are reported via `LiveBufferStatus::error` rather than
/// panicking the task.
async fn start_live_buffer(response: reqwest::Response) -> Result<LiveSourceInfo> {
    let temp_file = tempfile::Builder::new()
        .prefix("transcribe-cli-live-")
        .suffix(".media")
        .tempfile()
        .context("failed to create temporary live media buffer")?;
    let temp_path = temp_file.path().to_path_buf();
    let mut stream = response.bytes_stream();
    let (status_tx, status_rx) = watch::channel(LiveBufferStatus::default());
    let task = tokio::spawn(async move {
        let mut status = LiveBufferStatus::default();
        // Re-open the already-created temp file for appending; the
        // NamedTempFile handle owned by the caller keeps the path alive.
        let file = tokio::fs::OpenOptions::new()
            .append(true)
            .open(&temp_path)
            .await;
        let mut file = match file {
            Ok(file) => file,
            Err(error) => {
                status.error = Some(error.to_string());
                let _ = status_tx.send(status);
                return;
            }
        };
        while let Some(item) = stream.next().await {
            match item {
                Ok(bytes) => {
                    if let Err(error) = file.write_all(&bytes).await {
                        status.error = Some(error.to_string());
                        let _ = status_tx.send(status);
                        return;
                    }
                    status.bytes_received += bytes.len() as u64;
                    // Send failures just mean the receiver is gone; ignore.
                    let _ = status_tx.send(status.clone());
                }
                Err(error) => {
                    status.error = Some(error.to_string());
                    let _ = status_tx.send(status);
                    return;
                }
            }
        }
        // Clean end of stream.
        status.finished = true;
        let _ = status_tx.send(status);
    });
    Ok(LiveSourceInfo {
        temp_file,
        status_rx,
        download_task: task,
    })
}
/// Transcribes a live raw-PCM stream directly from the HTTP body, with no
/// temp-file spooling: bytes are decoded incrementally into model-rate
/// samples and fed through the stream engine as they arrive.
async fn run_raw_live_transcription(
    response: reqwest::Response,
    raw_audio: RawAudioSpec,
    whisper: &Whisper,
    options: &WhisperOptions,
    mut engine: StreamEngine,
    mut display: StreamDisplay,
) -> Result<()> {
    let overlap_seconds = engine.config().overlap_seconds();
    let mut decoder =
        RawAudioStreamDecoder::new(raw_audio.clone(), whisper.sampling_rate() as u32)?;
    let mut bytes_received = 0u64;
    let mut stream = response.bytes_stream();
    while let Some(item) = stream.next().await {
        let bytes = item.context("failed to read live raw audio bytes")?;
        bytes_received += bytes.len() as u64;
        // Decoder buffers partial frames internally; `samples` may be empty.
        let samples = decoder.push_bytes(&bytes)?;
        for chunk in engine.push_audio(&samples) {
            display.set_status(format!(
                "live chunk {} / {} KiB streamed / {}s overlap",
                chunk.index + 1,
                bytes_received / 1024,
                overlap_seconds
            ))?;
            let raw_text = transcribe_live_chunk_text(
                whisper,
                &chunk.samples,
                options,
                "live raw chunk transcription failed",
            )?;
            if let Some(stable_text) = raw_text
                .as_deref()
                .and_then(|raw_text| engine.stabilize_text(raw_text))
            {
                display.push_text(&stable_text)?;
            }
        }
    }
    // Stream ended: flush the decoder's tail, then the engine's last window.
    let samples = decoder.finish()?;
    for chunk in engine
        .push_audio(&samples)
        .into_iter()
        .chain(engine.finish_audio())
    {
        display.set_status(format!(
            "live chunk {} / {} KiB streamed / finalizing",
            chunk.index + 1,
            bytes_received / 1024
        ))?;
        let raw_text = transcribe_live_chunk_text(
            whisper,
            &chunk.samples,
            options,
            "live raw final chunk transcription failed",
        )?;
        if let Some(stable_text) = raw_text
            .as_deref()
            .and_then(|raw_text| engine.stabilize_text(raw_text))
        {
            display.push_text(&stable_text)?;
        }
    }
    if let Some(final_text) = engine.finish_text() {
        display.push_text(&final_text)?;
    }
    display.finish("live stream complete")
}
/// Decodes the current on-disk snapshot of the live buffer on a blocking
/// worker thread, resampled to `target_sample_rate`. Passing a `raw_audio`
/// spec selects the headerless raw-PCM decoder instead of the container one.
async fn inspect_live_snapshot(
    path: &Path,
    target_sample_rate: u32,
    raw_audio: Option<RawAudioSpec>,
) -> Result<(crate::audio::AudioMetadata, Vec<f32>)> {
    let snapshot_path = path.to_path_buf();
    // Decoding is CPU-bound file I/O; keep it off the async runtime threads.
    let decode = move || {
        if let Some(spec) = raw_audio {
            inspect_raw_audio_file(&snapshot_path, target_sample_rate, &spec)
        } else {
            inspect_audio_file(&snapshot_path, target_sample_rate)
        }
    };
    tokio::task::spawn_blocking(decode)
        .await
        .context("failed to join live media decode task")?
}
/// Detects a headerless raw-PCM stream from HTTP response headers.
///
/// Only `application/octet-stream` responses qualify. Sample rate and channel
/// count come from `x-audio-sample-rate` / `x-audio-channels` (defaulting to
/// 48 kHz and mono), and `x-audio-format` must name 32-bit little-endian
/// float — also the default when absent. Anything else yields `None`, which
/// routes the stream to the container decoder instead.
fn raw_audio_spec_from_headers(response: &reqwest::Response) -> Option<RawAudioSpec> {
    let headers = response.headers();
    let content_type = headers
        .get(reqwest::header::CONTENT_TYPE)?
        .to_str()
        .ok()?
        .to_ascii_lowercase();
    if !content_type.starts_with("application/octet-stream") {
        return None;
    }
    // Shared helper: fetch a custom header as an owned string, if readable.
    let header_value = |name: &str| {
        headers
            .get(name)
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned)
    };
    let sample_rate = header_value("x-audio-sample-rate")
        .and_then(|value| value.parse::<u32>().ok())
        .unwrap_or(48_000);
    let channels = header_value("x-audio-channels")
        .and_then(|value| value.parse::<u16>().ok())
        .unwrap_or(1);
    let format = header_value("x-audio-format")
        .map(|value| value.trim().to_ascii_lowercase())
        .unwrap_or_else(|| "f32le".to_string());
    matches!(format.as_str(), "f32" | "f32le" | "float32" | "pcm_f32le").then(|| RawAudioSpec {
        sample_rate,
        channels,
        encoding: RawAudioEncoding::F32Le,
    })
}
/// Derives the stream-engine config for live input: caps the window at
/// `LIVE_WINDOW_SECONDS` of audio (for lower latency than the model's full
/// window) and overlaps consecutive windows by `LIVE_OVERLAP_SECONDS`, never
/// more than half a window.
fn live_stream_config(sample_rate: usize, model_window_samples: usize) -> Result<StreamConfig> {
    let mut config = StreamConfig::for_realtime(sample_rate, model_window_samples)?;
    config.window_samples = config.window_samples.min(sample_rate * LIVE_WINDOW_SECONDS);
    config.overlap_samples = (sample_rate * LIVE_OVERLAP_SECONDS).min(config.window_samples / 2);
    Ok(config)
}
/// Transcribes one live audio chunk, returning `Ok(None)` when the chunk is
/// judged to contain no usable speech: near-silent audio (RMS < 0.003),
/// empty model output, or output rejected by the hallucination filter.
fn transcribe_live_chunk_text(
    whisper: &Whisper,
    samples: &[f32],
    options: &WhisperOptions,
    error_context: &str,
) -> Result<Option<String>> {
    let rms = chunk_rms(samples);
    if rms < 0.003 {
        // Too quiet to plausibly contain speech; skip the expensive decode.
        return Ok(None);
    }
    let results = whisper
        .generate_detailed(samples, None, false, options)
        .with_context(|| error_context.to_string())?;
    // Most pessimistic (highest) no-speech probability across sub-results.
    let no_speech_prob = results
        .iter()
        .map(|result| result.no_speech_prob)
        .fold(0.0f32, f32::max);
    let raw_text = join_chunk_text(results);
    if raw_text.is_empty() {
        return Ok(None);
    }
    if should_drop_live_text(samples, &raw_text, no_speech_prob) {
        return Ok(None);
    }
    Ok(Some(raw_text))
}
/// Concatenates the trimmed, non-empty text of each chunk result into one
/// space-separated string.
fn join_chunk_text(results: Vec<WhisperChunkResult>) -> String {
    let mut parts: Vec<String> = Vec::with_capacity(results.len());
    for result in results {
        let trimmed = result.text.trim();
        if !trimmed.is_empty() {
            parts.push(trimmed.to_string());
        }
    }
    parts.join(" ")
}
/// Hallucination gate for live chunks: drop text when the audio is both quiet
/// (RMS < 0.02) and the model itself rates it likely non-speech (>= 0.65),
/// or when the text looks degenerate per `is_suspicious_live_text`.
fn should_drop_live_text(samples: &[f32], text: &str, no_speech_prob: f32) -> bool {
    let quiet = chunk_rms(samples) < 0.02;
    let likely_silence = no_speech_prob >= 0.65;
    if likely_silence && quiet {
        true
    } else {
        is_suspicious_live_text(text)
    }
}
/// Root-mean-square amplitude of a sample buffer; `0.0` for an empty slice.
fn chunk_rms(samples: &[f32]) -> f32 {
    match samples.len() {
        0 => 0.0,
        count => {
            let energy: f32 = samples.iter().map(|&sample| sample * sample).sum();
            (energy / count as f32).sqrt()
        }
    }
}
/// Heuristically flags degenerate transcription output (typical Whisper
/// hallucinations): a long run of one repeated character, the same word
/// stuttered three or more times in a row, vocabulary where at least half
/// the words are duplicates, or one word making up more than half the text.
/// Texts shorter than four normalized words are never flagged beyond the
/// character-run check.
fn is_suspicious_live_text(text: &str) -> bool {
    if has_long_repeated_char_run(text, 12) {
        return true;
    }
    let words: Vec<String> = text
        .split_whitespace()
        .map(normalize_live_word)
        .filter(|word| !word.is_empty())
        .collect();
    if words.len() < 4 {
        return false;
    }
    if max_consecutive_word_run(&words) >= 3 {
        return true;
    }
    // One occurrence-count pass replaces the original HashSet + HashMap
    // double scan; the distinct-word count is simply the number of map keys.
    let mut counts: std::collections::HashMap<&str, usize> = std::collections::HashMap::new();
    for word in &words {
        *counts.entry(word.as_str()).or_insert(0usize) += 1;
    }
    // At least half the words are duplicates of another word.
    if counts.len() * 2 <= words.len() {
        return true;
    }
    // A single word accounts for more than half of all words.
    counts.values().copied().max().unwrap_or_default() * 2 > words.len()
}
/// Returns `true` when `text` contains `min_run` or more identical
/// non-whitespace characters in a row; whitespace is skipped entirely, so a
/// run may span word boundaries (e.g. "aa aa" counts as a run of four).
fn has_long_repeated_char_run(text: &str, min_run: usize) -> bool {
    let mut last: Option<char> = None;
    let mut length = 0usize;
    for ch in text.chars() {
        if ch.is_whitespace() {
            continue;
        }
        if last == Some(ch) {
            length += 1;
            if length >= min_run {
                return true;
            }
        } else {
            last = Some(ch);
            length = 1;
        }
    }
    false
}
/// Length of the longest run of consecutive identical words; `0` for an
/// empty slice, `1` when no adjacent pair matches.
fn max_consecutive_word_run(words: &[String]) -> usize {
    if words.is_empty() {
        return 0;
    }
    let mut longest = 1usize;
    let mut run = 1usize;
    for index in 1..words.len() {
        if words[index] == words[index - 1] {
            run += 1;
            longest = longest.max(run);
        } else {
            run = 1;
        }
    }
    longest
}
/// Normalizes a word for duplicate detection: keeps only alphanumeric
/// characters and lowercases them (Unicode-aware, so one character may map
/// to several).
fn normalize_live_word(word: &str) -> String {
    let mut normalized = String::new();
    for character in word.chars() {
        if character.is_alphanumeric() {
            normalized.extend(character.to_lowercase());
        }
    }
    normalized
}
/// Streaming-display transcription of a fully decoded file: windows the
/// samples through the stream engine, decodes each window, and shows the
/// stabilized text incrementally.
fn stream_transcription(
    whisper: &Whisper,
    audio: &PreparedAudio,
    options: &WhisperOptions,
) -> Result<()> {
    let mut engine = StreamEngine::new(StreamConfig::for_realtime(
        audio.metadata.target_sample_rate as usize,
        whisper.n_samples(),
    )?);
    let window_count = engine.config().window_count(audio.samples.len());
    let overlap_seconds = engine.config().overlap_seconds();
    let mut display = StreamDisplay::new(format!(
        "chunk 1/{window_count} / stream mode / {overlap_seconds}s overlap"
    ))?;
    // All audio is available up front, so push everything and chain the
    // trailing partial window from `finish_audio`.
    for chunk in engine
        .push_audio(&audio.samples)
        .into_iter()
        .chain(engine.finish_audio())
    {
        display.set_status(chunk.status(window_count, overlap_seconds))?;
        let raw_text = whisper
            .generate(&chunk.samples, None, false, options)
            .context("streaming chunk transcription failed")?
            .into_iter()
            .map(|line| line.trim().to_string())
            .filter(|line| !line.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        if let Some(stable_text) = engine.stabilize_text(&raw_text) {
            display.push_text(&stable_text)?;
        }
    }
    // Flush whatever suffix the stabilizer was still holding back.
    if let Some(final_text) = engine.finish_text() {
        display.push_text(&final_text)?;
    }
    display.finish("processing complete")
}
/// Formats a duration in seconds as `MM:SS.cc` (minutes, seconds,
/// centiseconds), rounding to the nearest millisecond first.
fn format_duration(seconds: f64) -> String {
    let total_millis = (seconds * 1000.0).round() as u64;
    let minutes = total_millis / 60_000;
    let remainder = total_millis % 60_000;
    let secs = remainder / 1000;
    let centis = (remainder % 1000) / 10;
    format!("{minutes:02}:{secs:02}.{centis:02}")
}
/// Prints the 72-column `=` divider followed by two blank lines, separating
/// the parameter header from the transcript body.
fn print_transcript_divider() {
    let divider: String = std::iter::repeat('=').take(72).collect();
    println!("{divider}\n\n");
}
/// Builds the CTranslate2 runtime configuration for the resolved execution
/// mode: all available CPU cores per replica (falling back to 4 when
/// parallelism cannot be queried) and up to 2 queued batches.
fn ctranslate2_config(execution: &ExecutionMode) -> Config {
    let threads = std::thread::available_parallelism()
        .map(|parallelism| parallelism.get())
        .unwrap_or(4);
    Config {
        device: execution.device,
        compute_type: execution.compute_type,
        // A device index is always supplied; index 0 is used on CPU.
        device_indices: execution
            .gpu_device
            .map(|index| vec![index])
            .unwrap_or_else(|| vec![0]),
        num_threads_per_replica: threads,
        max_queued_batches: 2,
        ..Default::default()
    }
}
/// Transcript output sink: an in-place updating display when stdout is a
/// terminal, plain appended text when it is redirected.
struct StreamDisplay {
    // Present only when stdout is a TTY.
    inner: Option<InteractiveStreamDisplay>,
    // Whether any non-empty text has been emitted (controls word separators
    // in the non-interactive path).
    wrote_text: bool,
}
impl StreamDisplay {
    /// Chooses interactive or plain output based on whether stdout is a
    /// terminal, and renders the initial status line either way.
    fn new(status: String) -> Result<Self> {
        if io::stdout().is_terminal() {
            Ok(Self {
                inner: Some(InteractiveStreamDisplay::new(status)?),
                wrote_text: false,
            })
        } else {
            // Non-interactive: print a static header once; text is appended.
            print_transcript_divider();
            println!(" processing {status}");
            println!();
            Ok(Self {
                inner: None,
                wrote_text: false,
            })
        }
    }
    /// Updates the status line (interactive mode only; a no-op otherwise).
    fn set_status(&mut self, status: String) -> Result<()> {
        if let Some(inner) = self.inner.as_mut() {
            inner.set_status(status)?;
        }
        Ok(())
    }
    /// Appends a piece of transcript text, inserting a word separator before
    /// it when earlier text has already been written.
    fn push_text(&mut self, text: &str) -> Result<()> {
        if let Some(inner) = self.inner.as_mut() {
            inner.push_text(text)?;
        } else {
            let mut stdout = io::stdout().lock();
            let rendered = render_stream_chunk(text, self.wrote_text);
            write!(stdout, "{}", rendered.as_deref().unwrap_or_default())
                .context("failed to write stream text")?;
            stdout.flush().context("failed to flush stream text")?;
        }
        if !text.trim().is_empty() {
            self.wrote_text = true;
        }
        Ok(())
    }
    /// Finalizes the display: stops the interactive spinner and prints the
    /// final status, or just terminates the plain-text line.
    fn finish(mut self, final_status: &str) -> Result<()> {
        if let Some(inner) = self.inner.take() {
            inner.finish(final_status)?;
        } else {
            println!();
        }
        Ok(())
    }
}
/// TTY display driven by a background thread that animates the spinner and
/// redraws the screen region in place.
struct InteractiveStreamDisplay {
    // Shared render state (status line, spinner frame, transcript text).
    state: Arc<Mutex<StreamDisplayState>>,
    // Cleared in `finish` to stop the spinner thread.
    running: Arc<AtomicBool>,
    // Spinner thread handle; taken and joined in `finish`.
    spinner_thread: Option<thread::JoinHandle<()>>,
}
impl InteractiveStreamDisplay {
    /// Draws the initial frame, then spawns the spinner thread which advances
    /// the spinner frame and redraws every 80 ms until `running` is cleared.
    fn new(initial_status: String) -> Result<Self> {
        let state = Arc::new(Mutex::new(StreamDisplayState::new(initial_status)));
        {
            // Draw once before the spinner thread starts so the display
            // appears immediately.
            let mut guard = state
                .lock()
                .map_err(|_| anyhow::anyhow!("failed to lock stream display state"))?;
            guard.redraw()?;
        }
        let running = Arc::new(AtomicBool::new(true));
        let state_for_thread = Arc::clone(&state);
        let running_for_thread = Arc::clone(&running);
        let spinner_thread = thread::spawn(move || {
            let mut frame_index = 0usize;
            while running_for_thread.load(Ordering::Relaxed) {
                if let Ok(mut guard) = state_for_thread.lock() {
                    guard.spinner =
                        STREAM_SPINNER_FRAMES[frame_index % STREAM_SPINNER_FRAMES.len()].to_string();
                    // Redraw failures (e.g. closed stdout) are ignored here;
                    // the main thread will surface I/O errors on its own calls.
                    let _ = guard.redraw();
                }
                frame_index = frame_index.wrapping_add(1);
                thread::sleep(Duration::from_millis(80));
            }
        });
        Ok(Self {
            state,
            running,
            spinner_thread: Some(spinner_thread),
        })
    }
    /// Replaces the status line and redraws immediately.
    fn set_status(&mut self, status: String) -> Result<()> {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| anyhow::anyhow!("failed to lock stream display state"))?;
        guard.status = status;
        guard.redraw()
    }
    /// Appends text to the transcript (space-separated) and redraws.
    fn push_text(&mut self, text: &str) -> Result<()> {
        let mut guard = self
            .state
            .lock()
            .map_err(|_| anyhow::anyhow!("failed to lock stream display state"))?;
        if !guard.transcript.is_empty() {
            guard.transcript.push(' ');
        }
        guard.transcript.push_str(text);
        guard.redraw()
    }
    /// Stops and joins the spinner thread, then draws a final frame with the
    /// spinner removed and `final_status` in place.
    fn finish(mut self, final_status: &str) -> Result<()> {
        self.running.store(false, Ordering::Relaxed);
        // Join before the final redraw so the thread cannot repaint over it.
        if let Some(handle) = self.spinner_thread.take() {
            let _ = handle.join();
        }
        let mut guard = self
            .state
            .lock()
            .map_err(|_| anyhow::anyhow!("failed to lock stream display state"))?;
        guard.spinner.clear();
        guard.status = final_status.to_string();
        guard.redraw()?;
        println!();
        Ok(())
    }
}
/// Mutable render state behind the interactive display's mutex.
struct StreamDisplayState {
    // Current status-line text (after the spinner).
    status: String,
    // Current spinner frame; cleared for the final frame.
    spinner: String,
    // Accumulated transcript, re-wrapped on every redraw.
    transcript: String,
    // Wrap width: `$COLUMNS` when parseable and >= 40, otherwise 100.
    width: usize,
    // Line count of the previous frame, so `redraw` knows how far to rewind.
    rendered_lines: usize,
}
impl StreamDisplayState {
    /// Builds the initial state with the first spinner frame and a wrap
    /// width taken from `$COLUMNS` when sane (>= 40), else 100.
    fn new(status: String) -> Self {
        Self {
            status,
            spinner: STREAM_SPINNER_FRAMES[0].to_string(),
            transcript: String::new(),
            width: std::env::var("COLUMNS")
                .ok()
                .and_then(|value| value.parse::<usize>().ok())
                .filter(|width| *width >= 40)
                .unwrap_or(100),
            rendered_lines: 0,
        }
    }
    /// Repaints the display in place: rewinds the cursor over the previous
    /// frame, erases it, and writes the new frame.
    fn redraw(&mut self) -> Result<()> {
        let lines = self.render_lines();
        let mut stdout = io::stdout().lock();
        if self.rendered_lines > 0 {
            // ESC[nF: cursor up n lines to column 1; ESC[J: clear to end of
            // screen — together they erase the previously drawn frame.
            write!(stdout, "\x1b[{}F\x1b[J", self.rendered_lines)
                .context("failed to rewind stream display")?;
        }
        for line in &lines {
            writeln!(stdout, "{line}").context("failed to render stream display")?;
        }
        stdout.flush().context("failed to flush stream display")?;
        self.rendered_lines = lines.len();
        Ok(())
    }
    /// Lays out the full frame: divider, status line, divider, two blank
    /// lines, then the wrapped transcript.
    fn render_lines(&self) -> Vec<String> {
        let divider = "=".repeat(72);
        let status_line = format!(" processing {} {}", self.spinner, self.status);
        let transcript_lines = wrap_text(&self.transcript, self.width.saturating_sub(1).max(40));
        let mut lines = vec![
            divider.clone(),
            status_line,
            divider,
            String::new(),
            String::new(),
        ];
        lines.extend(transcript_lines);
        lines
    }
}
/// Greedy word-wraps `text` to at most `width` characters per line (measured
/// in bytes). A single word longer than `width` is emitted on its own line
/// rather than split. Empty or whitespace-only input yields no lines.
fn wrap_text(text: &str, width: usize) -> Vec<String> {
    let mut wrapped: Vec<String> = Vec::new();
    let mut line = String::new();
    for word in text.split_whitespace() {
        if line.is_empty() {
            line.push_str(word);
        } else if line.len() + 1 + word.len() <= width {
            line.push(' ');
            line.push_str(word);
        } else {
            wrapped.push(std::mem::take(&mut line));
            line.push_str(word);
        }
    }
    if !line.is_empty() {
        wrapped.push(line);
    }
    wrapped
}
/// Prepares a streamed text chunk for plain stdout output: trims it, drops
/// whitespace-only chunks, and prefixes a single space when earlier text was
/// already written so consecutive chunks stay word-separated.
fn render_stream_chunk(text: &str, needs_leading_space: bool) -> Option<String> {
    let trimmed = text.trim();
    (!trimmed.is_empty()).then(|| {
        if needs_leading_space {
            format!(" {trimmed}")
        } else {
            trimmed.to_string()
        }
    })
}
#[cfg(test)]
mod tests {
    use super::{is_suspicious_live_text, render_stream_chunk};
    // The first chunk is printed verbatim; later chunks get a leading space
    // so words do not run together on stdout.
    #[test]
    fn renders_stream_chunks_with_separator_after_first_chunk() {
        assert_eq!(
            render_stream_chunk("hello world", false),
            Some("hello world".to_string())
        );
        assert_eq!(
            render_stream_chunk("again there", true),
            Some(" again there".to_string())
        );
    }
    // Typical Whisper hallucination shapes: one word stuttered, and a long
    // repeated-character run followed by a repeated word.
    #[test]
    fn flags_repetitive_live_hallucinations() {
        assert!(is_suspicious_live_text(
            "пидарасы пидарасы пидарасы пидарасы"
        ));
        assert!(is_suspicious_live_text(
            "аааааааааааа thanks thanks thanks thanks"
        ));
    }
    // Ordinary varied phrases (Cyrillic and Latin) must pass the filter.
    #[test]
    fn keeps_regular_live_phrases() {
        assert!(!is_suspicious_live_text("привет как дела сегодня"));
        assert!(!is_suspicious_live_text("thanks for joining the stream"));
    }
}