#![warn(clippy::all, clippy::pedantic, clippy::nursery)]
use clap::Parser;
use hyper::service::{make_service_fn, service_fn};
use hyper::{header::CONTENT_TYPE, Body, Request, Response, Server, StatusCode};
use multer::Multipart;
use serde::Serialize;
use serde_json::to_string;
use std::fs;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::{convert::Infallible, net::SocketAddr};
use whisperd::{Language, Model, Size, Whisper};
use crate::utils::write_to;
mod utils;
/// JSON payload returned by the `/v1/audio/transcriptions` endpoint,
/// serialized as `{"text": "<transcript>"}`.
#[derive(Serialize)]
struct TranscriptionResponse {
    // Full transcript as plain text (built from `Transcript::as_text`).
    text: String,
}
// Top-level CLI options: a single required subcommand selects the mode.
// Plain `//` comments on purpose — `///` doc comments would be picked up
// by clap's derive and change the generated `--help` output.
#[derive(Parser)]
struct Opts {
    #[clap(subcommand)]
    subcmd: SubCommand,
}
// The two operating modes: a long-running HTTP transcription server, or a
// one-shot transcription of a local file.
// `//` comments used (not `///`) so clap's `--help` output is unchanged.
#[derive(Parser)]
enum SubCommand {
    #[command(about = "Start the transcription server.")]
    Serve {
        // TCP port to listen on; the server binds loopback only
        // (see `start_server`).
        #[clap(short, long, default_value = "8000")]
        port: u16,
        // Filesystem path of the Whisper model file to load at startup.
        #[clap(short, long)]
        model_path: String,
    },
    #[command(about = "Transcribe a given audio file.")]
    Transcribe(TranscribeArgs),
}
// Arguments for the one-shot `transcribe` subcommand.
// `//` comments used (not `///`) so clap's `--help` output is unchanged.
#[derive(Parser)]
struct TranscribeArgs {
    // Model size to load (see `whisperd::Size`); defaults to "medium".
    #[clap(short, long, default_value = "medium")]
    model: Size,
    // Spoken language; when omitted, handling is decided in
    // `transcribe_audio` (English-only models force English).
    #[clap(short, long)]
    lang: Option<Language>,
    // Positional path to the audio file to transcribe.
    #[clap(name = "AUDIO")]
    audio: String,
    // Translate the transcript to English instead of transcribing verbatim.
    #[clap(short, long, default_value = "false")]
    translate: bool,
    // Emit word-level (karaoke-style) timing.
    #[clap(short, long, default_value = "false")]
    karaoke: bool,
    // Write .txt/.vtt/.srt files next to the input instead of printing.
    #[clap(short, long, default_value = "false")]
    write: bool,
}
/// Entry point: parse the CLI and dispatch to server or one-shot mode.
#[tokio::main]
async fn main() {
    let opts = Opts::parse();
    match opts.subcmd {
        SubCommand::Serve { port, model_path } => {
            // Pass the `&Path` straight through — the original re-borrowed
            // an already-borrowed path, producing a needless `&&Path`
            // (clippy::needless_borrow, warned by this crate's lint level).
            start_server(port, Path::new(&model_path)).await;
        }
        SubCommand::Transcribe(args) => transcribe_audio(args).await,
    }
}
/// Boot the HTTP transcription server on `127.0.0.1:<port>`.
///
/// The Whisper model is loaded once from `model_path` and shared across
/// all connections behind an `Arc<Mutex<..>>`; each request handler takes
/// the lock for the duration of one transcription.
async fn start_server(port: u16, model_path: &Path) {
    // One model instance for the whole process lifetime.
    let shared_whisper = Arc::new(Mutex::new(
        Whisper::from_model_path(model_path, Some(Language::Auto)).await,
    ));
    // Each connection gets its own clone of the Arc; each request within
    // a connection clones it again for the handler future.
    let service_factory = make_service_fn(move |_conn| {
        let per_conn = Arc::clone(&shared_whisper);
        async move {
            Ok::<_, Infallible>(service_fn(move |request| {
                handle_transcription(request, Arc::clone(&per_conn))
            }))
        }
    });
    let address = SocketAddr::from(([127, 0, 0, 1], port));
    let server = Server::bind(&address).serve(service_factory);
    println!("🏃♀️ Server running at: {}", address);
    if let Err(error) = server.await {
        eprintln!("server error: {}", error);
    }
}
async fn handle_transcription(
req: Request<Body>,
whisper: Arc<Mutex<Whisper>>,
) -> Result<Response<Body>, Infallible> {
if req.method() == hyper::Method::OPTIONS && req.uri().path() == "/v1/audio/transcriptions" {
let res = Response::builder()
.status(StatusCode::OK)
.header("Access-Control-Allow-Origin", "*")
.header("Access-Control-Allow-Methods", "POST, OPTIONS")
.header("Access-Control-Allow-Headers", "Content-Type")
.body(Body::empty())
.unwrap();
return Ok(res);
}
let boundary = req
.headers()
.get(CONTENT_TYPE)
.and_then(|ct| ct.to_str().ok())
.and_then(|ct| multer::parse_boundary(ct).ok());
if boundary.is_none() {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Access-Control-Allow-Origin", "*") .body(Body::from("BAD REQUEST"))
.unwrap());
}
let transcription_request = process_multipart(req.into_body(), boundary.unwrap()).await;
if let Err(err) = transcription_request {
return Ok(Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Access-Control-Allow-Origin", "*") .body(Body::from(format!("INTERNAL SERVER ERROR: {}", err)))
.unwrap());
}
if let Ok(trans_req) = transcription_request {
let audio = Path::new(trans_req.as_str());
let transcript = {
let mut whisper_guard = whisper.lock().unwrap();
whisper_guard.transcribe(audio, false, false).unwrap()
};
println!("time: {:?}", transcript.processing_time);
let transcript_text = transcript.as_text();
let response: TranscriptionResponse = TranscriptionResponse {
text: transcript_text,
};
let json_response = to_string(&response).expect("Failed to serialize to JSON");
let response = Response::builder()
.header("Access-Control-Allow-Origin", "*") .body(Body::from(json_response))
.unwrap();
return Ok(response);
}
Ok(Response::new(Body::from("Success")))
}
async fn process_multipart(body: Body, boundary: String) -> multer::Result<String> {
let mut multipart = Multipart::new(body, boundary);
let mut file_path = String::new();
while let Some(mut field) = multipart.next_field().await? {
if field.name() == Some("file") {
let name = field.name();
let file_name = field.file_name();
let content_type = field.content_type();
println!(
"Name: {:?}, FileName: {:?}, Content-Type: {:?}",
name, file_name, content_type
);
let mut bytes_len = 0;
let mut audio_data = Vec::new();
while let Some(field_chunk) = field.chunk().await? {
audio_data.extend_from_slice(&field_chunk);
bytes_len += field_chunk.len();
}
println!("Bytes Length: {:?}", bytes_len);
let file_name_str: &str = field.file_name().as_ref().unwrap_or(&"audio.wav");
file_path = format!("/tmp/{}", file_name_str); fs::write(&file_path, audio_data).expect("Failed to write to file");
println!("Write the file to {}", file_path);
}
}
Ok(file_path)
}
/// Transcribe a single audio file according to the CLI arguments.
///
/// English-only model sizes force the language to English. With `--write`,
/// `.txt`, `.vtt` and `.srt` siblings of the input file are produced;
/// otherwise the transcript is printed to stdout.
///
/// # Panics
/// Panics when the audio file does not exist, its name is not valid UTF-8,
/// a non-English language is requested for an English-only model, or the
/// transcription itself fails.
async fn transcribe_audio(mut args: TranscribeArgs) {
    let audio = Path::new(&args.audio);
    // Check existence first so a missing file produces this message rather
    // than an opaque unwrap panic from `file_name()`.
    assert!(audio.exists(), "The provided audio file does not exist.");
    let file_name = audio.file_name().unwrap().to_str().unwrap();
    // English-only checkpoints cannot auto-detect a language: pin English
    // when the user left the language unset or on auto.
    if args.model.is_english_only() && (args.lang == Some(Language::Auto) || args.lang.is_none()) {
        args.lang = Some(Language::English);
    }
    assert!(
        !args.model.is_english_only() || args.lang == Some(Language::English),
        "The selected model only supports English."
    );
    let mut whisper = Whisper::new(Model::new(args.model), args.lang).await;
    let transcript = whisper
        .transcribe(audio, args.translate, args.karaoke)
        .unwrap();
    println!("time: {:?}", transcript.processing_time);
    if args.write {
        // Emit all three formats next to the source audio; the extension
        // is appended to the full input name (e.g. "song.wav.txt"), as
        // before.
        for (ext, contents) in [
            ("txt", transcript.as_text()),
            ("vtt", transcript.as_vtt()),
            ("srt", transcript.as_srt()),
        ] {
            write_to(audio.with_file_name(format!("{file_name}.{ext}")), &contents);
        }
    } else {
        println!();
        println!("🔊 {}", transcript.as_text());
    }
}