openllm/api/
whisper.rs

1use crate::IntoRequest;
2use derive_builder::Builder;
3use reqwest::multipart::{Form, Part};
4use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
5use serde::Deserialize;
6use strum::{Display, EnumString};
7
8#[derive(Debug, Clone, Builder)]
9#[builder(pattern = "mutable")]
10pub struct WhisperRequest {
11    /// The audio file object (not file name) to transcribe/translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
12    file: Vec<u8>,
13    /// ID of the model to use. Only whisper-1 is currently available.
14    #[builder(default)]
15    model: WhisperModel,
16    /// The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency. Should not use this for translation
17    #[builder(default, setter(strip_option, into))]
18    language: Option<String>,
19    /// An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language for transcription, and should be English only for translation.
20    #[builder(default, setter(strip_option, into))]
21    prompt: Option<String>,
22    /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
23    #[builder(default)]
24    pub(crate) response_format: WhisperResponseFormat,
25    /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
26    #[builder(default, setter(strip_option))]
27    temperature: Option<f32>,
28
29    request_type: WhisperRequestType,
30}
31
32#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
33pub enum WhisperModel {
34    #[default]
35    #[strum(serialize = "whisper-1")]
36    Whisper1,
37}
38
39#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
40#[strum(serialize_all = "snake_case")]
41pub enum WhisperResponseFormat {
42    #[default]
43    Json,
44    Text,
45    Srt,
46    VerboseJson,
47    Vtt,
48}
49
50#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
51pub enum WhisperRequestType {
52    #[default]
53    Transcription,
54    Translation,
55}
56
57#[derive(Debug, Clone, Deserialize)]
58pub struct WhisperResponse {
59    pub text: String,
60}
61
62impl WhisperRequest {
63    pub fn transcription(data: Vec<u8>) -> Self {
64        WhisperRequestBuilder::default()
65            .file(data)
66            .request_type(WhisperRequestType::Transcription)
67            .build()
68            .unwrap()
69    }
70
71    pub fn translation(data: Vec<u8>) -> Self {
72        WhisperRequestBuilder::default()
73            .file(data)
74            .request_type(WhisperRequestType::Translation)
75            .build()
76            .unwrap()
77    }
78
79    fn into_form(self) -> Form {
80        let part = Part::bytes(self.file)
81            .file_name("file")
82            .mime_str("audio/mp3")
83            .unwrap();
84        let mut form = Form::new()
85            .part("file", part)
86            .text("model", self.model.to_string())
87            .text("response_format", self.response_format.to_string());
88
89        // translation doesn't need language
90        form = match (self.request_type, self.language) {
91            (WhisperRequestType::Transcription, Some(language)) => form.text("language", language),
92            _ => form,
93        };
94        form = if let Some(prompt) = self.prompt {
95            form.text("prompt", prompt)
96        } else {
97            form
98        };
99        if let Some(temperature) = self.temperature {
100            form.text("temperature", temperature.to_string())
101        } else {
102            form
103        }
104    }
105}
106
107impl IntoRequest for WhisperRequest {
108    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
109        let url = match self.request_type {
110            WhisperRequestType::Transcription => format!("{}/audio/transcriptions", base_url),
111            WhisperRequestType::Translation => format!("{}/audio/translations", base_url),
112        };
113        client.post(url).multipart(self.into_form())
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use std::fs;

    // NOTE(review): every test here is `#[ignore]`d, presumably because it goes
    // through `SDK` to a live service and needs the audio fixtures on disk;
    // confirm before enabling in CI.

    /// Reads `fixture` and builds a request with an explicit response format and
    /// request type, leaving the other fields at their defaults.
    fn build_request(
        fixture: &str,
        format: WhisperResponseFormat,
        kind: WhisperRequestType,
    ) -> Result<WhisperRequest> {
        let audio = fs::read(fixture)?;
        let mut builder = WhisperRequestBuilder::default();
        builder
            .file(audio)
            .response_format(format)
            .request_type(kind);
        Ok(builder.build()?)
    }

    /// Default (JSON) transcription of the English fixture.
    #[tokio::test]
    #[ignore]
    async fn transcription_should_work() -> Result<()> {
        let audio = fs::read("fixtures/speech.mp3")?;
        let response = SDK.whisper(WhisperRequest::transcription(audio)).await?;
        assert_eq!(
            response.text,
            "The quick brown fox jumped over the lazy dog."
        );
        Ok(())
    }

    /// Transcription with the plain-text response format.
    #[tokio::test]
    #[ignore]
    async fn transcription_with_response_format_should_work() -> Result<()> {
        let request = build_request(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Text,
            WhisperRequestType::Transcription,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(
            response.text,
            "The quick brown fox jumped over the lazy dog.\n"
        );
        Ok(())
    }

    /// Transcription with the WebVTT response format.
    #[tokio::test]
    #[ignore]
    async fn transcription_with_vtt_response_format_should_work() -> Result<()> {
        let request = build_request(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Vtt,
            WhisperRequestType::Transcription,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(
            response.text,
            "WEBVTT\n\n00:00:00.000 --> 00:00:02.800\nThe quick brown fox jumped over the lazy dog.\n\n"
        );
        Ok(())
    }

    /// Translation of the Chinese fixture, returned as SRT subtitles.
    #[tokio::test]
    #[ignore]
    async fn translate_should_work() -> Result<()> {
        let request = build_request(
            "fixtures/chinese.mp3",
            WhisperResponseFormat::Srt,
            WhisperRequestType::Translation,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(
            response.text,
            "1\n00:00:00,000 --> 00:00:03,000\nThe red scarf hangs on the chest, the motherland is always in my heart.\n\n\n"
        );
        Ok(())
    }
}