1use crate::IntoRequest;
2use derive_builder::Builder;
3use reqwest::multipart::{Form, Part};
4use reqwest_middleware::{ClientWithMiddleware, RequestBuilder};
5use serde::Deserialize;
6use strum::{Display, EnumString};
7
8#[derive(Debug, Clone, Builder)]
9#[builder(pattern = "mutable")]
10pub struct WhisperRequest {
11 file: Vec<u8>,
13 #[builder(default)]
15 model: WhisperModel,
16 #[builder(default, setter(strip_option, into))]
18 language: Option<String>,
19 #[builder(default, setter(strip_option, into))]
21 prompt: Option<String>,
22 #[builder(default)]
24 pub(crate) response_format: WhisperResponseFormat,
25 #[builder(default, setter(strip_option))]
27 temperature: Option<f32>,
28
29 request_type: WhisperRequestType,
30}
31
32#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
33pub enum WhisperModel {
34 #[default]
35 #[strum(serialize = "whisper-1")]
36 Whisper1,
37}
38
39#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
40#[strum(serialize_all = "snake_case")]
41pub enum WhisperResponseFormat {
42 #[default]
43 Json,
44 Text,
45 Srt,
46 VerboseJson,
47 Vtt,
48}
49
50#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, EnumString, Display)]
51pub enum WhisperRequestType {
52 #[default]
53 Transcription,
54 Translation,
55}
56
57#[derive(Debug, Clone, Deserialize)]
58pub struct WhisperResponse {
59 pub text: String,
60}
61
62impl WhisperRequest {
63 pub fn transcription(data: Vec<u8>) -> Self {
64 WhisperRequestBuilder::default()
65 .file(data)
66 .request_type(WhisperRequestType::Transcription)
67 .build()
68 .unwrap()
69 }
70
71 pub fn translation(data: Vec<u8>) -> Self {
72 WhisperRequestBuilder::default()
73 .file(data)
74 .request_type(WhisperRequestType::Translation)
75 .build()
76 .unwrap()
77 }
78
79 fn into_form(self) -> Form {
80 let part = Part::bytes(self.file)
81 .file_name("file")
82 .mime_str("audio/mp3")
83 .unwrap();
84 let mut form = Form::new()
85 .part("file", part)
86 .text("model", self.model.to_string())
87 .text("response_format", self.response_format.to_string());
88
89 form = match (self.request_type, self.language) {
91 (WhisperRequestType::Transcription, Some(language)) => form.text("language", language),
92 _ => form,
93 };
94 form = if let Some(prompt) = self.prompt {
95 form.text("prompt", prompt)
96 } else {
97 form
98 };
99 if let Some(temperature) = self.temperature {
100 form.text("temperature", temperature.to_string())
101 } else {
102 form
103 }
104 }
105}
106
107impl IntoRequest for WhisperRequest {
108 fn into_request(self, base_url: &str, client: ClientWithMiddleware) -> RequestBuilder {
109 let url = match self.request_type {
110 WhisperRequestType::Transcription => format!("{}/audio/transcriptions", base_url),
111 WhisperRequestType::Translation => format!("{}/audio/translations", base_url),
112 };
113 client.post(url).multipart(self.into_form())
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;
    use crate::SDK;
    use anyhow::Result;
    use std::fs;

    // NOTE(review): these tests go through `SDK.whisper`, which presumably talks to the
    // real service and needs the audio fixtures on disk; confirm before running offline.

    /// Loads the fixture at `path` and builds a request with the given format and type.
    fn request_from_fixture(
        path: &str,
        format: WhisperResponseFormat,
        request_type: WhisperRequestType,
    ) -> Result<WhisperRequest> {
        let audio = fs::read(path)?;
        let request = WhisperRequestBuilder::default()
            .file(audio)
            .response_format(format)
            .request_type(request_type)
            .build()?;
        Ok(request)
    }

    #[tokio::test]
    async fn transcription_should_work() -> Result<()> {
        let audio = fs::read("fixtures/speech.mp3")?;
        let response = SDK.whisper(WhisperRequest::transcription(audio)).await?;
        assert_eq!(response.text, "The quick brown fox jumped over the lazy dog.");
        Ok(())
    }

    #[tokio::test]
    async fn transcription_with_response_format_should_work() -> Result<()> {
        let request = request_from_fixture(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Text,
            WhisperRequestType::Transcription,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(response.text, "The quick brown fox jumped over the lazy dog.\n");
        Ok(())
    }

    #[tokio::test]
    async fn transcription_with_vtt_response_format_should_work() -> Result<()> {
        let request = request_from_fixture(
            "fixtures/speech.mp3",
            WhisperResponseFormat::Vtt,
            WhisperRequestType::Transcription,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(response.text, "WEBVTT\n\n00:00:00.000 --> 00:00:02.800\nThe quick brown fox jumped over the lazy dog.\n\n");
        Ok(())
    }

    #[tokio::test]
    async fn translate_should_work() -> Result<()> {
        let request = request_from_fixture(
            "fixtures/chinese.mp3",
            WhisperResponseFormat::Srt,
            WhisperRequestType::Translation,
        )?;
        let response = SDK.whisper(request).await?;
        assert_eq!(response.text, "1\n00:00:00,000 --> 00:00:03,000\nThe red scarf hangs on the chest, the motherland is always in my heart.\n\n\n");
        Ok(())
    }
}