use core::panic;
use futures_util::stream::StreamExt;
use kalosm_sound::Whisper;
use rodio::{Decoder, source::Source};
use std::{fs::File, io::BufReader, path::Path, time::Duration};
use crate::{MediaType, TranscriptionData, TranscriptionProgress};
/// Renders `duration` as a clock-style string, discarding sub-second precision.
///
/// Durations shorter than one hour are shown as `MM:SS`; anything longer is
/// shown as `HH:MM:SS`. Every field is zero-padded to at least two digits, so
/// an hour count of 100 or more simply widens the leading field.
pub fn format_duration(duration: Duration) -> String {
    let total_secs = duration.as_secs();
    let (hours, within_hour) = (total_secs / 3600, total_secs % 3600);
    let (minutes, seconds) = (within_hour / 60, within_hour % 60);
    match hours {
        0 => format!("{:0>2}:{:0>2}", minutes, seconds),
        _ => format!("{:0>2}:{:0>2}:{:0>2}", hours, minutes, seconds),
    }
}
pub fn get_media_type(file_path: &str) -> MediaType {
match Path::new(&file_path)
.extension()
.and_then(|ext| ext.to_str())
{
Some(ext) => match ext.to_lowercase().as_str() {
"mp4" | "avi" | "mov" | "mkv" => MediaType::Video,
"mp3" | "wav" | "m4a" | "flac" => MediaType::Audio,
_ => MediaType::Error,
},
None => MediaType::Error,
}
}
/// Returns the playback length of the media file at `file_path`.
///
/// Only [`MediaType::Audio`] is currently supported.
/// - MP3 files are measured with `mp3_duration`.
/// - Every other audio extension is opened with rodio's [`Decoder`] and asked
///   for [`Source::total_duration`].
///
/// A known (non-zero) length is padded by one extra second. NOTE(review):
/// presumably this rounds up the final partial second; confirm intent.
///
/// `Duration::ZERO` means "length unknown". It is returned when:
/// - the path has no UTF-8 extension,
/// - the file cannot be opened or decoded, or
/// - the decoder does not report a total length.
///
/// # Panics
///
/// * `MediaType::Video`: not implemented yet (`todo!`).
/// * `MediaType::Error`: panics with an "unsupported format" message.
pub fn get_total_time(media_type: MediaType, file_path: &str) -> Duration {
    match media_type {
        MediaType::Audio => {
            let extension = Path::new(file_path)
                .extension()
                .and_then(|ext| ext.to_str())
                .map(str::to_lowercase);
            // Failures fall back to ZERO in every branch, so one unreadable or
            // undecodable file (e.g. an m4a rodio cannot handle) does not abort
            // the whole program.
            let duration = match extension.as_deref() {
                Some("mp3") => mp3_duration::from_path(file_path).unwrap_or(Duration::ZERO),
                Some(_) => decoded_total_duration(file_path).unwrap_or(Duration::ZERO),
                None => Duration::ZERO,
            };
            if duration.is_zero() {
                duration
            } else {
                duration + Duration::from_secs(1)
            }
        }
        MediaType::Video => todo!(),
        MediaType::Error => panic!("Can not get time because of unsupported format"),
    }
}

/// Decodes `file_path` with rodio and returns the total length it reports.
///
/// Returns `None` if any of the following happens:
/// - the file cannot be opened,
/// - rodio cannot decode it, or
/// - the decoder does not know the stream's total length.
fn decoded_total_duration(file_path: &str) -> Option<Duration> {
    let file = BufReader::new(File::open(file_path).ok()?);
    let source = Decoder::new(file).ok()?;
    Source::total_duration(&source)
}
pub async fn transcribe_audio(
file_path: &str,
is_timestamped: bool,
progress_sender: Option<tokio::sync::mpsc::UnboundedSender<TranscriptionProgress>>,
) -> Vec<TranscriptionData> {
let model = Whisper::new().await.unwrap();
let file = BufReader::new(File::open(file_path).unwrap());
let audio = Decoder::new(file).unwrap();
let mut text_stream;
let mut transcript: Vec<TranscriptionData> = vec![];
text_stream = model.transcribe(audio).timestamped();
let mut segment_counter = 0.0;
while let Some(segment) = text_stream.next().await {
for chunk in segment.chunks() {
if let Some(time_range) = chunk.timestamp() {
let true_start = time_range.start + (30.0 * segment_counter);
let true_end = time_range.end + (30.0 * segment_counter);
let transcription_data = TranscriptionData {
text: {
if is_timestamped {
format!(
"{}-{}: {}\n",
format_duration(Duration::from_secs_f32(true_start)),
format_duration(Duration::from_secs_f32(true_end)),
chunk
)
} else {
format!("{}", chunk)
}
},
time: Duration::from_secs_f32(true_start),
};
if let Some(ref progress) = progress_sender {
let _ = progress.send(TranscriptionProgress::InProgress(
transcription_data.clone(),
));
}
transcript.push(transcription_data);
}
if let Some(ref progress) = progress_sender {
let _ = progress.send(TranscriptionProgress::Reading);
}
}
segment_counter += 1.0;
}
if let Some(progress) = progress_sender {
let _ = progress.send(TranscriptionProgress::Finished);
}
transcript
}