use serde::{Deserialize, Serialize};
use std::pin::Pin;
use std::str::FromStr;
use tokio_stream::Stream;
use crate::error::DashScopeError;
/// Header portion of a WebSocket event message.
///
/// The `TryFrom<String>` impl for [`WebSocketEvent`] reads the `event` field of this header to
/// decide which variant to build.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WebSocketEventHeader {
/// Identifier of the task this event refers to.
pub task_id: String,
/// Event name as it appears on the wire (e.g. `"task-started"`, `"task-failed"`).
pub event: String,
/// Arbitrary JSON attached to the event; kept untyped.
pub attributes: serde_json::Value,
/// Error code, if one was supplied; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error_code: Option<String>,
/// Error message, if one was supplied; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error_message: Option<String>,
}
/// Payload portion of a WebSocket event message.
///
/// Both fields are optional, so an empty JSON object deserializes successfully.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WebSocketEventPayload {
/// Recognition output, when present; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub output: Option<AsrOutput>,
/// Usage information, when present; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<AsrUsage>,
}
/// A parsed WebSocket event.
///
/// Each variant corresponds to one event name recognised by the `TryFrom<String>` impl:
/// `"task-started"`, `"result-generated"`, `"task-finished"` and `"task-failed"`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum WebSocketEvent {
/// Built from a `"task-started"` message; carries only the header.
TaskStarted { header: WebSocketEventHeader },
/// Built from a `"result-generated"` message; carries header and payload.
ResultGenerated {
header: WebSocketEventHeader,
payload: WebSocketEventPayload,
},
/// Built from a `"task-finished"` message; carries header and payload.
TaskFinished {
header: WebSocketEventHeader,
payload: WebSocketEventPayload,
},
/// Built from a `"task-failed"` message; carries only the header, which holds the
/// optional `error_code` / `error_message`.
TaskFailed { header: WebSocketEventHeader },
}
impl WebSocketEvent {
    /// Returns the [`EventType`] that corresponds to this variant.
    pub fn event_type(&self) -> EventType {
        match self {
            Self::TaskStarted { .. } => EventType::TaskStarted,
            Self::ResultGenerated { .. } => EventType::ResultGenerated,
            Self::TaskFinished { .. } => EventType::TaskFinished,
            Self::TaskFailed { .. } => EventType::TaskFailed,
        }
    }

    /// `true` when this is a `TaskStarted` event.
    pub fn is_task_started(&self) -> bool {
        matches!(self, Self::TaskStarted { .. })
    }

    /// `true` when this is a `ResultGenerated` event.
    pub fn is_result_generated(&self) -> bool {
        matches!(self, Self::ResultGenerated { .. })
    }

    /// `true` when this is a `TaskFinished` event.
    pub fn is_task_finished(&self) -> bool {
        matches!(self, Self::TaskFinished { .. })
    }

    /// `true` when this is a `TaskFailed` event.
    pub fn is_task_failed(&self) -> bool {
        matches!(self, Self::TaskFailed { .. })
    }

    /// Borrows the task id stored in the event header; every variant has one.
    pub fn task_id(&self) -> &str {
        match self {
            Self::TaskStarted { header } | Self::TaskFailed { header } => &header.task_id,
            Self::ResultGenerated { header, .. } | Self::TaskFinished { header, .. } => {
                &header.task_id
            }
        }
    }

    /// Borrows the usage record from the payload.
    ///
    /// Only `ResultGenerated` and `TaskFinished` carry a payload; the other variants, and
    /// payloads without a usage entry, yield `None`.
    pub fn get_usage(&self) -> Option<&AsrUsage> {
        match self {
            Self::ResultGenerated { payload, .. } | Self::TaskFinished { payload, .. } => {
                payload.usage.as_ref()
            }
            Self::TaskStarted { .. } | Self::TaskFailed { .. } => None,
        }
    }

    /// Returns `(error_code, error_message)` for a `TaskFailed` event.
    ///
    /// Yields `None` for every other variant, and also when either of the two header fields
    /// is missing.
    pub fn get_error_info(&self) -> Option<(&str, &str)> {
        match self {
            Self::TaskFailed { header } => header
                .error_code
                .as_deref()
                .zip(header.error_message.as_deref()),
            _ => None,
        }
    }
}
impl TryFrom<String> for WebSocketEvent {
type Error = DashScopeError;
fn try_from(value: String) -> Result<Self, Self::Error> {
let json_value: serde_json::Value =
serde_json::from_str(&value).map_err(|e| DashScopeError::JSONDeserialize {
source: e,
raw_response: value.clone(),
})?;
let event_type = json_value
.get("header")
.and_then(|h| h.get("event"))
.and_then(|e| e.as_str())
.ok_or_else(|| DashScopeError::UnknownEventType {
event_type: "unknown".to_string(),
})?;
match event_type {
"task-started" => {
let event: WebSocketEventWithHeaderOnly =
serde_json::from_str(&value).map_err(|e| DashScopeError::JSONDeserialize {
source: e,
raw_response: value,
})?;
Ok(WebSocketEvent::TaskStarted {
header: event.header,
})
}
"result-generated" => {
let event: WebSocketEventWithPayload =
serde_json::from_str(&value).map_err(|e| DashScopeError::JSONDeserialize {
source: e,
raw_response: value,
})?;
Ok(WebSocketEvent::ResultGenerated {
header: event.header,
payload: event.payload,
})
}
"task-finished" => {
let event: WebSocketEventWithPayload =
serde_json::from_str(&value).map_err(|e| DashScopeError::JSONDeserialize {
source: e,
raw_response: value,
})?;
Ok(WebSocketEvent::TaskFinished {
header: event.header,
payload: event.payload,
})
}
"task-failed" => {
let event: WebSocketEventWithHeaderOnly =
serde_json::from_str(&value).map_err(|e| DashScopeError::JSONDeserialize {
source: e,
raw_response: value,
})?;
Ok(WebSocketEvent::TaskFailed {
header: event.header,
})
}
_ => Err(DashScopeError::UnknownEventType {
event_type: event_type.to_string(),
}),
}
}
}
/// Private decoding shape for messages of which only the `header` is kept.
///
/// Used by `WebSocketEvent::try_from` for `"task-started"` and `"task-failed"` messages.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WebSocketEventWithHeaderOnly {
/// Event header.
pub header: WebSocketEventHeader,
}
/// Private decoding shape for messages that carry both `header` and `payload`.
///
/// Used by `WebSocketEvent::try_from` for `"result-generated"` and `"task-finished"` messages.
/// Both fields are required here, so a message without `payload` fails to decode.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
struct WebSocketEventWithPayload {
/// Event header.
pub header: WebSocketEventHeader,
/// Event payload.
pub payload: WebSocketEventPayload,
}
/// Transcription entry carried in [`AsrOutput::transcription`].
///
/// `sentence_id` and `words` are required; every other field is optional and is left out of
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Asrtranscription {
    /// Identifier of the sentence this transcription belongs to.
    pub sentence_id: u32,
    /// Start time of the sentence, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub begin_time: Option<u32>,
    /// End time of the sentence, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub end_time: Option<u32>,
    /// Transcribed text, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Language tag, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lang: Option<String>,
    /// Word-level entries for this sentence.
    pub words: Vec<AsrWord>,
    /// Sentence-end flag, accepted as a bool, number or string on the wire.
    ///
    /// `default` is required alongside `deserialize_with`: without it serde treats the field
    /// as mandatory and rejects any message that omits it, instead of yielding `None`.
    #[serde(
        default,
        skip_serializing_if = "Option::is_none",
        deserialize_with = "deserialize_bool_from_any"
    )]
    pub sentence_end: Option<bool>,
}
/// A single translation entry carried in [`AsrOutput::translations`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AsrTranslation {
/// Translated text, when present; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
}
/// Recognition output carried in an event payload.
///
/// All three parts are optional and are left out of serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AsrOutput {
/// Recognised sentence, when present.
#[serde(skip_serializing_if = "Option::is_none")]
pub sentence: Option<AsrSentence>,
/// Translation entries, when present.
#[serde(skip_serializing_if = "Option::is_none")]
pub translations: Option<Vec<AsrTranslation>>,
/// Transcription entry, when present.
#[serde(skip_serializing_if = "Option::is_none")]
pub transcription: Option<Asrtranscription>,
}
/// A recognised sentence carried in [`AsrOutput::sentence`].
///
/// `words` is the only required field. `heartbeat` and `sentence_end` have no
/// `skip_serializing_if`, so they serialize as `null` when `None`; the remaining optional
/// fields are left out instead.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AsrSentence {
/// Sentence index, when reported.
#[serde(skip_serializing_if = "Option::is_none")]
pub index: Option<u32>,
/// Start time of the sentence, when reported.
#[serde(skip_serializing_if = "Option::is_none")]
pub begin_time: Option<u32>,
/// End time of the sentence; `is_intermediate` / `is_final` are derived from its presence.
#[serde(skip_serializing_if = "Option::is_none")]
pub end_time: Option<u32>,
/// Recognised text, when reported.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Word-level entries for this sentence.
pub words: Vec<AsrWord>,
/// Heartbeat flag, when reported.
pub heartbeat: Option<bool>,
/// Sentence-end flag, decoded by `deserialize_bool_from_any`; `default` makes a missing
/// field decode as `None`.
#[serde(deserialize_with = "deserialize_bool_from_any", default)]
pub sentence_end: Option<bool>,
/// Emotion tag, when reported.
#[serde(skip_serializing_if = "Option::is_none")]
pub emo_tag: Option<String>,
/// Confidence value accompanying `emo_tag`, when reported.
#[serde(skip_serializing_if = "Option::is_none")]
pub emo_confidence: Option<f32>,
}
/// A single word entry inside a sentence or transcription.
///
/// `begin_time`, `end_time` and `punctuation` are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AsrWord {
/// Start time of the word (unit not stated in this file — TODO confirm against the API docs).
pub begin_time: u32,
/// End time of the word (same unit as `begin_time`).
pub end_time: u32,
/// Word text, when present; left out of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Punctuation associated with the word.
pub punctuation: String,
}
/// Usage information carried in an event payload.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AsrUsage {
/// Reported duration, when present (unit not stated in this file — TODO confirm); left out
/// of serialized output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub duration: Option<u32>,
}
/// Complete speech-recognition response: request id, output and usage.
///
/// All three fields are required when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AutomaticSpeechRecognitionOutput {
/// Identifier of the request.
pub request_id: String,
/// Recognition output.
pub output: AsrOutput,
/// Usage information.
pub usage: AsrUsage,
}
/// Pinned, boxed, `Send` stream whose items are either a parsed [`WebSocketEvent`] or a
/// [`DashScopeError`].
pub type AutomaticSpeechRecognitionOutputStream =
Pin<Box<dyn Stream<Item = Result<WebSocketEvent, DashScopeError>> + Send>>;
/// Kind of a WebSocket event, without its data.
///
/// The enum has no fields, so it also derives `Copy`, `Eq` and `Hash`: values can be passed by
/// value, compared with full equality, and used as `HashMap` / `HashSet` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EventType {
    /// Wire name `"task-started"`.
    TaskStarted,
    /// Wire name `"result-generated"`.
    ResultGenerated,
    /// Wire name `"task-finished"`.
    TaskFinished,
    /// Wire name `"task-failed"`.
    TaskFailed,
}
impl EventType {
pub fn as_str(&self) -> &'static str {
match self {
EventType::TaskStarted => "task-started",
EventType::ResultGenerated => "result-generated",
EventType::TaskFinished => "task-finished",
EventType::TaskFailed => "task-failed",
}
}
}
impl FromStr for EventType {
type Err = ();
fn from_str(event: &str) -> Result<Self, Self::Err> {
match event {
"task-started" => Ok(EventType::TaskStarted),
"result-generated" => Ok(EventType::ResultGenerated),
"task-finished" => Ok(EventType::TaskFinished),
"task-failed" => Ok(EventType::TaskFailed),
_ => Err(()),
}
}
}
impl AsrSentence {
    /// `true` while the sentence has no `end_time`.
    pub fn is_intermediate(&self) -> bool {
        matches!(self.end_time, None)
    }

    /// `true` once the sentence has an `end_time`; always the opposite of
    /// [`is_intermediate`](Self::is_intermediate).
    pub fn is_final(&self) -> bool {
        !self.is_intermediate()
    }
}
use serde::de::{self, Deserializer};
fn deserialize_bool_from_any<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
where
D: Deserializer<'de>,
{
let value = serde_json::Value::deserialize(deserializer)?;
match value {
serde_json::Value::Bool(b) => Ok(Some(b)),
serde_json::Value::Number(n) => {
if n.is_u64() || n.is_i64() {
Ok(Some(n.as_i64().unwrap_or(0) != 0))
} else {
Ok(Some(n.as_f64().unwrap_or(0.0) != 0.0))
}
}
serde_json::Value::String(s) => {
if s == "true" || s == "1" || s == "yes" {
Ok(Some(true))
} else if s == "false" || s == "0" || s == "no" {
Ok(Some(false))
} else {
match s.parse::<i64>() {
Ok(n) => Ok(Some(n != 0)),
Err(_) => Err(de::Error::custom(format!(
"unable to parse '{}' as bool",
s
))),
}
}
}
serde_json::Value::Null => Ok(Some(false)),
_ => Err(de::Error::custom(format!(
"unexpected value type for bool: {:?}",
value
))),
}
}