llm_sdk/api/
whisper.rs

1use crate::IntoRequest;
2use derive_builder::Builder;
3use reqwest::multipart::{Form, Part};
4use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
5use serde::Deserialize;
6use strum::{Display, EnumString};
7
8#[derive(Debug, Clone, Builder)]
9#[builder(pattern = "mutable")]
10pub struct WhisperRequest {
11    /// The audio file object (not file name) to transcribe/translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
12    file: Vec<u8>,
13    /// ID of the model to use. Only whisper-1 is currently available.
14    #[builder(default)]
15    model: WhisperModel,
16    /// The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency. Should not use this for translation
17    #[builder(default, setter(strip_option, into))]
18    language: Option<String>,
19    /// An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language for transcription, and should be English only for translation.
20    #[builder(default, setter(strip_option, into))]
21    prompt: Option<String>,
22    /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
23    #[builder(default)]
24    pub(crate) response_format: WhisperResponseFormat,
25    /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
26    #[builder(default, setter(strip_option))]
27    temperature: Option<f32>,
28
29    request_type: WhisperRequestType,
30}
31
32#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
33pub enum WhisperModel {
34    #[default]
35    #[strum(serialize = "whisper-1")]
36    Whisper1,
37}
38
39#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
40#[strum(serialize_all = "snake_case")]
41pub enum WhisperResponseFormat {
42    #[default]
43    Json,
44    Text,
45    Srt,
46    VerboseJson,
47    Vtt,
48}
49
50#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
51pub enum WhisperRequestType {
52    #[default]
53    Transcription,
54    Translation,
55}
56
57#[derive(Debug, Clone, Deserialize)]
58pub struct WhisperResponse {
59    pub text: String,
60}
61
62impl WhisperRequest {
63    pub fn transcription(data: Vec<u8>) -> Self {
64        WhisperRequestBuilder::default()
65            .file(data)
66            .request_type(WhisperRequestType::Transcription)
67            .build()
68            .unwrap()
69    }
70
71    pub fn translation(data: Vec<u8>) -> Self {
72        WhisperRequestBuilder::default()
73            .file(data)
74            .request_type(WhisperRequestType::Translation)
75            .build()
76            .unwrap()
77    }
78
79    fn into_form(self) -> Form {
80        let part = Part::bytes(self.file)
81            .file_name("file")
82            .mime_str("audio/mp3")
83            .unwrap();
84        let mut form = Form::new()
85            .part("file", part)
86            .text("model", self.model.to_string())
87            .text("response_format", self.response_format.to_string());
88
89        // translation doesn't need language
90        form = match (self.request_type, self.language) {
91            (WhisperRequestType::Transcription, Some(language)) => form.text("language", language),
92            _ => form,
93        };
94        form = if let Some(prompt) = self.prompt {
95            form.text("prompt", prompt)
96        } else {
97            form
98        };
99        if let Some(temperature) = self.temperature {
100            form.text("temperature", temperature.to_string())
101        } else {
102            form
103        }
104    }
105}
106
107impl IntoRequest for WhisperRequest {
108    fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
109        let url = match self.request_type {
110            WhisperRequestType::Transcription => format!("{}/audio/transcriptions", base_url),
111            WhisperRequestType::Translation => format!("{}/audio/translations", base_url),
112        };
113        client.post(url).multipart(self.into_form())
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use std::fs;

    /// Reads the audio fixture at `path` and builds a request with the given format and kind.
    fn request_for(
        path: &str,
        response_format: WhisperResponseFormat,
        request_type: WhisperRequestType,
    ) -> Result<WhisperRequest> {
        let audio = fs::read(path)?;
        let req = WhisperRequestBuilder::default()
            .file(audio)
            .response_format(response_format)
            .request_type(request_type)
            .build()?;
        Ok(req)
    }

    #[tokio::test]
    async fn transcription_should_work() -> Result<()> {
        let audio = fs::read("fixtures/speech.mp3")?;
        let res = SDK.whisper(WhisperRequest::transcription(audio)).await?;
        assert_eq!(res.text, "The quick brown fox jumped over the lazy dog.");
        Ok(())
    }

    #[tokio::test]
    async fn transcription_with_response_format_should_work() -> Result<()> {
        let req = request_for(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Text,
            WhisperRequestType::Transcription,
        )?;
        let res = SDK.whisper(req).await?;
        assert_eq!(res.text, "The quick brown fox jumped over the lazy dog.\n");
        Ok(())
    }

    #[tokio::test]
    async fn transcription_with_vtt_response_format_should_work() -> Result<()> {
        let req = request_for(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Vtt,
            WhisperRequestType::Transcription,
        )?;
        let res = SDK.whisper(req).await?;
        assert_eq!(res.text, "WEBVTT\n\n00:00:00.000 --> 00:00:02.800\nThe quick brown fox jumped over the lazy dog.\n\n");
        Ok(())
    }

    #[tokio::test]
    async fn translate_should_work() -> Result<()> {
        let req = request_for(
            "fixtures/chinese.mp3",
            WhisperResponseFormat::Srt,
            WhisperRequestType::Translation,
        )?;
        let res = SDK.whisper(req).await?;
        assert_eq!(res.text, "1\n00:00:00,000 --> 00:00:03,000\nThe red scarf hangs on the chest, the motherland is always in my heart.\n\n\n");
        Ok(())
    }
}