use std::fmt::{Display, Formatter};
use std::str::FromStr;
use subtp::srt::SubRip;
use subtp::vtt::WebVtt;
use crate::audio::TextFormatError;
use crate::audio::TextFormatResult;
use crate::macros::{
impl_display_for_serialize, impl_enum_string_serialization,
};
/// Associates a response type with the string identifier of its format.
///
/// The implementations in this file return `"json"`, `"text"`,
/// `"verbose_json"`, `"srt"` and `"vtt"`.
// NOTE(review): presumably this identifier is the value sent to the API as
// the requested response format — confirm against the callers.
pub trait TextResponseFormat {
/// Returns the format identifier for the implementing type.
fn format() -> &'static str;
}
/// Converts the raw text of a response into a typed value `T`.
///
/// Conversion is exposed as an associated function, so implementors do not
/// need to be instantiated to be used.
pub trait TextResponseFormatter<T> {
/// Converts `raw_text` into `T`, taking ownership of the string.
///
/// # Errors
///
/// Returns the error side of [`TextFormatResult`] when the text cannot be
/// converted into `T`.
fn format(raw_text: String) -> TextFormatResult<T>;
}
/// Response body consisting of a single `text` field.
///
/// This is the type associated with the `"json"` format identifier.
#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct JsonResponse {
/// Text carried by the response.
// NOTE(review): presumably the transcription/translation output — confirm.
pub text: String,
}
impl TextResponseFormat for JsonResponse {
    /// Returns `"json"`, the format identifier associated with
    /// [`JsonResponse`].
    fn format() -> &'static str {
        const NAME: &str = "json";
        NAME
    }
}
// NOTE(review): presumably generates a `Display` impl for `JsonResponse`
// based on its `Serialize` impl — confirm in `crate::macros`.
impl_display_for_serialize!(JsonResponse);
/// Marker type whose [`TextResponseFormatter`] implementation deserializes
/// raw text as JSON into a [`JsonResponse`].
///
/// Declared as a unit struct, like the other formatter markers in this
/// module. The `JsonResponseFormatter {}` expression form remains valid, so
/// existing construction sites keep compiling.
pub struct JsonResponseFormatter;
impl TextResponseFormatter<JsonResponse> for JsonResponseFormatter {
fn format(raw_text: String) -> TextFormatResult<JsonResponse> {
serde_json::from_str(raw_text.as_str()).map_err(|error| {
TextFormatError::FormatJsonFailed {
error,
text: raw_text,
}
})
}
}
impl TextResponseFormat for String {
    /// Returns `"text"`, the format identifier associated with plain
    /// [`String`] responses.
    fn format() -> &'static str {
        const NAME: &str = "text";
        NAME
    }
}
/// Marker type whose [`TextResponseFormatter`] implementation returns the raw
/// text unchanged as a `String`.
pub struct PlainTextResponseFormatter;
impl TextResponseFormatter<String> for PlainTextResponseFormatter {
/// Returns `raw_text` unchanged, wrapped in `Ok`; this implementation has
/// no failure path.
fn format(raw_text: String) -> TextFormatResult<String> {
Ok(raw_text)
}
}
/// Detailed response body; the type associated with the `"verbose_json"`
/// format identifier.
///
/// `segments` and `words` are optional and are left out of the serialized
/// output when they are `None`.
// NOTE(review): the per-field descriptions below are inferred from the field
// names — confirm meanings and units against the upstream API reference.
#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct VerboseJsonResponse {
/// Task label reported in the response.
pub task: String,
/// Language reported in the response.
pub language: String,
/// Duration value reported in the response (presumably seconds — confirm).
pub duration: f32,
/// Full text of the response.
pub text: String,
/// Optional per-segment details; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub segments: Option<Vec<VerboseJsonResponseSegment>>,
/// Optional per-word details; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub words: Option<Vec<VerboseJsonResponseWord>>,
}
// NOTE(review): presumably generates a `Display` impl based on `Serialize` —
// confirm in `crate::macros`.
impl_display_for_serialize!(VerboseJsonResponse);
/// One entry of [`VerboseJsonResponse::segments`].
// NOTE(review): the per-field descriptions below are inferred from the field
// names — confirm meanings and units against the upstream API reference.
#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct VerboseJsonResponseSegment {
/// Identifier of the segment.
pub id: u32,
/// Seek value of the segment.
pub seek: u32,
/// Start of the segment (presumably seconds — confirm).
pub start: f32,
/// End of the segment (presumably seconds — confirm).
pub end: f32,
/// Text of the segment.
pub text: String,
/// Token ids of the segment text.
pub tokens: Vec<u32>,
/// Temperature value reported for the segment.
pub temperature: f32,
/// Average log-probability value reported for the segment.
pub avg_logprob: f32,
/// Compression-ratio value reported for the segment.
pub compression_ratio: f32,
/// No-speech probability value reported for the segment.
pub no_speech_prob: f32,
}
// NOTE(review): presumably generates a `Display` impl based on `Serialize` —
// confirm in `crate::macros`.
impl_display_for_serialize!(VerboseJsonResponseSegment);
/// One entry of [`VerboseJsonResponse::words`].
#[derive(Debug, Clone, PartialEq, serde::Deserialize, serde::Serialize)]
pub struct VerboseJsonResponseWord {
/// Text of the word.
pub word: String,
/// Start of the word (presumably seconds — confirm).
pub start: f32,
/// End of the word (presumably seconds — confirm).
pub end: f32,
}
// NOTE(review): presumably generates a `Display` impl based on `Serialize` —
// confirm in `crate::macros`.
impl_display_for_serialize!(VerboseJsonResponseWord);
impl TextResponseFormat for VerboseJsonResponse {
    /// Returns `"verbose_json"`, the format identifier associated with
    /// [`VerboseJsonResponse`].
    fn format() -> &'static str {
        const NAME: &str = "verbose_json";
        NAME
    }
}
/// Marker type whose [`TextResponseFormatter`] implementation deserializes
/// raw text as JSON into a [`VerboseJsonResponse`].
pub struct VerboseJsonResponseFormatter;
impl TextResponseFormatter<VerboseJsonResponse>
for VerboseJsonResponseFormatter
{
fn format(raw_text: String) -> TextFormatResult<VerboseJsonResponse> {
serde_json::from_str(raw_text.as_str()).map_err(|error| {
TextFormatError::FormatJsonFailed {
error,
text: raw_text,
}
})
}
}
impl TextResponseFormat for SubRip {
    /// Returns `"srt"`, the format identifier associated with [`SubRip`].
    fn format() -> &'static str {
        const NAME: &str = "srt";
        NAME
    }
}
/// Marker type whose [`TextResponseFormatter`] implementation parses raw text
/// into a [`SubRip`] value.
pub struct SrtResponseFormatter;
impl TextResponseFormatter<SubRip> for SrtResponseFormatter {
fn format(raw_text: String) -> TextFormatResult<SubRip> {
SubRip::parse(raw_text.as_str()).map_err(|error| {
TextFormatError::ParseSrtFailed {
error,
text: raw_text,
}
})
}
}
impl TextResponseFormat for WebVtt {
    /// Returns `"vtt"`, the format identifier associated with [`WebVtt`].
    fn format() -> &'static str {
        const NAME: &str = "vtt";
        NAME
    }
}
/// Marker type whose [`TextResponseFormatter`] implementation parses raw text
/// into a [`WebVtt`] value.
pub struct VttResponseFormatter;
impl TextResponseFormatter<WebVtt> for VttResponseFormatter {
fn format(raw_text: String) -> TextFormatResult<WebVtt> {
WebVtt::parse(raw_text.as_str()).map_err(|error| {
TextFormatError::ParseVttFailed {
error,
text: raw_text,
}
})
}
}
/// Format selector for speech responses.
///
/// Each variant corresponds to the lowercase string of the same name
/// (`"mp3"`, `"opus"`, `"aac"`, `"flac"`), which is what the `Display` and
/// `FromStr` implementations in this module produce and accept.
///
/// The default is [`SpeechResponseFormat::Mp3`]; `Default` is derived via
/// `#[default]` rather than written by hand.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum SpeechResponseFormat {
    /// `"mp3"` — the default format.
    #[default]
    Mp3,
    /// `"opus"`.
    Opus,
    /// `"aac"`.
    Aac,
    /// `"flac"`.
    Flac,
}
impl Display for SpeechResponseFormat {
    /// Writes the lowercase name of the variant (`mp3`, `opus`, `aac` or
    /// `flac`).
    ///
    /// The name is written as-is; width and alignment flags on the formatter
    /// are not applied.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Mp3 => "mp3",
            Self::Opus => "opus",
            Self::Aac => "aac",
            Self::Flac => "flac",
        };
        f.write_str(name)
    }
}
impl FromStr for SpeechResponseFormat {
    type Err = crate::ValidationError<String>;

    /// Parses one of the exact strings `"mp3"`, `"opus"`, `"aac"` or
    /// `"flac"`.
    ///
    /// Matching is case-sensitive and no trimming is performed.
    ///
    /// # Errors
    ///
    /// Returns a `ValidationError` whose `value` is the rejected input when
    /// `s` is not one of the four accepted strings.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let recognized = match s {
            "mp3" => Some(Self::Mp3),
            "opus" => Some(Self::Opus),
            "aac" => Some(Self::Aac),
            "flac" => Some(Self::Flac),
            _ => None,
        };

        // The error is built lazily so the success path allocates nothing.
        recognized.ok_or_else(|| crate::ValidationError {
            type_name: String::from("SpeechResponseFormat"),
            reason: String::from("Unknown speech response format"),
            value: s.to_owned(),
        })
    }
}
// NOTE(review): presumably generates string-based serialization impls for
// `SpeechResponseFormat` from this variant-to-string table — confirm in
// `crate::macros`.
// The strings listed here are the same ones used by the `Display` and
// `FromStr` impls for this enum; keep all three in sync when adding a variant.
impl_enum_string_serialization!(
SpeechResponseFormat,
Mp3 => "mp3",
Opus => "opus",
Aac => "aac",
Flac => "flac"
);