1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
use crate::IntoRequest;
use derive_builder::Builder;
use reqwest::{
    multipart::{Form, Part},
    Client, RequestBuilder,
};
use serde::Deserialize;
use strum::{Display, EnumString};

/// Parameters for a Whisper audio request: either a transcription or a translation.
///
/// Use [`WhisperRequest::transcription`] / [`WhisperRequest::translation`] for the common case,
/// or [`WhisperRequestBuilder`] to set the optional fields. Only `file` is required on the
/// builder; every other field falls back to its default.
#[derive(Debug, Clone, Builder)]
#[builder(pattern = "mutable")]
pub struct WhisperRequest {
    /// The audio file object (not file name) to transcribe/translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.
    file: Vec<u8>,
    /// ID of the model to use. Only whisper-1 is currently available.
    #[builder(default)]
    model: WhisperModel,
    /// The language of the input audio. Supplying the input language in ISO-639-1 format will improve accuracy and latency. Should not use this for translation
    /// (it is not sent for translation requests).
    #[builder(default, setter(strip_option, into))]
    language: Option<String>,
    /// An optional text to guide the model's style or continue a previous audio segment. The prompt should match the audio language for transcription, and should be English only for translation.
    #[builder(default, setter(strip_option, into))]
    prompt: Option<String>,
    /// The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
    #[builder(default)]
    pub(crate) response_format: WhisperResponseFormat,
    /// The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit.
    #[builder(default, setter(strip_option))]
    temperature: Option<f32>,

    /// Selects the endpoint: `/audio/transcriptions` or `/audio/translations`.
    /// Defaults to [`WhisperRequestType::Transcription`] when not set on the builder.
    #[builder(default)]
    request_type: WhisperRequestType,
}

/// Whisper model identifier, sent as the `model` multipart form field.
///
/// strum's `Display`/`FromStr` use the wire name `whisper-1` rather than the variant name.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
pub enum WhisperModel {
    /// `whisper-1`: the default and only variant.
    #[default]
    #[strum(serialize = "whisper-1")]
    Whisper1,
}

/// Output format requested through the `response_format` multipart form field.
///
/// `serialize_all = "snake_case"` makes `Display`/`FromStr` use the wire names
/// `json`, `text`, `srt`, `verbose_json` and `vtt`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
#[strum(serialize_all = "snake_case")]
pub enum WhisperResponseFormat {
    /// `json` (the default).
    #[default]
    Json,
    /// `text`: plain transcript.
    Text,
    /// `srt`: SubRip subtitles.
    Srt,
    /// `verbose_json`.
    VerboseJson,
    /// `vtt`: WebVTT subtitles.
    Vtt,
}

/// Which Whisper operation a [`WhisperRequest`] performs; it decides the endpoint it is posted to.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
pub enum WhisperRequestType {
    /// Posted to `{base_url}/audio/transcriptions`; the default.
    #[default]
    Transcription,
    /// Posted to `{base_url}/audio/translations`; the `language` field is not sent.
    Translation,
}

/// Result of a Whisper request; only the `text` field is kept.
///
/// NOTE(review): for the `text`/`srt`/`vtt` formats the tests in this file expect `text` to hold
/// the raw response body, so the SDK's `whisper` call presumably wraps non-JSON bodies itself
/// rather than deserializing them — confirm there.
#[derive(Debug, Clone, Deserialize)]
pub struct WhisperResponse {
    /// Transcribed or translated text (or the raw subtitle/text body for non-JSON formats).
    pub text: String,
}

impl WhisperRequest {
    pub fn transcription(data: Vec<u8>) -> Self {
        WhisperRequestBuilder::default()
            .file(data)
            .request_type(WhisperRequestType::Transcription)
            .build()
            .unwrap()
    }

    pub fn translation(data: Vec<u8>) -> Self {
        WhisperRequestBuilder::default()
            .file(data)
            .request_type(WhisperRequestType::Translation)
            .build()
            .unwrap()
    }

    fn into_form(self) -> Form {
        let part = Part::bytes(self.file)
            .file_name("file")
            .mime_str("audio/mp3")
            .unwrap();
        let mut form = Form::new()
            .part("file", part)
            .text("model", self.model.to_string())
            .text("response_format", self.response_format.to_string());

        // translation doesn't need language
        form = match (self.request_type, self.language) {
            (WhisperRequestType::Transcription, Some(language)) => form.text("language", language),
            _ => form,
        };
        form = if let Some(prompt) = self.prompt {
            form.text("prompt", prompt)
        } else {
            form
        };
        if let Some(temperature) = self.temperature {
            form.text("temperature", temperature.to_string())
        } else {
            form
        }
    }
}

impl IntoRequest for WhisperRequest {
    /// Turns the request into a multipart `POST` against the endpoint for its `request_type`.
    ///
    /// Transcriptions target `{base_url}/audio/transcriptions` and translations target
    /// `{base_url}/audio/translations`. `base_url` is joined verbatim, with no slash trimming.
    fn into_request(self, base_url: &str, client: Client) -> RequestBuilder {
        // Resolve the endpoint before `into_form` consumes `self`.
        let operation = match self.request_type {
            WhisperRequestType::Translation => "translations",
            WhisperRequestType::Transcription => "transcriptions",
        };
        let endpoint = format!("{}/audio/{}", base_url, operation);
        let body = self.into_form();
        client.post(endpoint).multipart(body)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use std::fs;

    /// Loads an audio fixture and builds a request with an explicit format and request type.
    fn request_for(
        fixture: &str,
        format: WhisperResponseFormat,
        kind: WhisperRequestType,
    ) -> Result<WhisperRequest> {
        let audio = fs::read(fixture)?;
        let request = WhisperRequestBuilder::default()
            .file(audio)
            .response_format(format)
            .request_type(kind)
            .build()?;
        Ok(request)
    }

    // These call the real service through `crate::SDK`; they presumably need network access
    // and credentials.

    #[tokio::test]
    async fn transcription_should_work() -> Result<()> {
        let audio = fs::read("fixtures/speech.mp3")?;
        let response = SDK.whisper(WhisperRequest::transcription(audio)).await?;
        assert_eq!(response.text, "The quick brown fox jumped over the lazy dog.");
        Ok(())
    }

    #[tokio::test]
    async fn transcription_with_response_format_should_work() -> Result<()> {
        let request = request_for(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Text,
            WhisperRequestType::Transcription,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(response.text, "The quick brown fox jumped over the lazy dog.\n");
        Ok(())
    }

    #[tokio::test]
    async fn transcription_with_vtt_response_format_should_work() -> Result<()> {
        let request = request_for(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Vtt,
            WhisperRequestType::Transcription,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(response.text, "WEBVTT\n\n00:00:00.000 --> 00:00:02.800\nThe quick brown fox jumped over the lazy dog.\n\n");
        Ok(())
    }

    #[tokio::test]
    async fn translate_should_work() -> Result<()> {
        let request = request_for(
            "fixtures/chinese.mp3",
            WhisperResponseFormat::Srt,
            WhisperRequestType::Translation,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(response.text, "1\n00:00:00,000 --> 00:00:03,000\nThe red scarf hangs on the chest, the motherland is always in my heart.\n\n\n");
        Ok(())
    }
}