use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap};
/// Message envelope pairing a [`TaskHeader`] with a [`RunTaskPayload`].
///
/// `create_asr_run_task` and `create_tts_run_task` build values of this type
/// with the `TaskAction::RunTask` action; `RunTaskParametersBuilder` (generated
/// by `derive_builder`) requires both fields to be set.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct RunTaskParameters {
    /// Action, task id and streaming mode of the message.
    header: TaskHeader,
    /// Task routing, model, input and parameters of the message.
    payload: RunTaskPayload,
}
impl TryFrom<RunTaskParameters> for String {
type Error = crate::error::DashScopeError;
fn try_from(value: RunTaskParameters) -> Result<Self, Self::Error> {
serde_json::to_string(&value)
.map_err(|e| crate::error::DashScopeError::SerializationError(e.to_string()))
}
}
/// Action carried in the `action` field of a [`TaskHeader`].
///
/// Variants are (de)serialized in kebab-case, i.e. `"run-task"`,
/// `"finish-task"` and `"continue-task"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TaskAction {
    /// Wire name `"run-task"`.
    RunTask,
    /// Wire name `"finish-task"`.
    FinishTask,
    /// Wire name `"continue-task"`.
    ContinueTask,
}
/// Header shared by the run / finish / continue messages in this file.
///
/// With `TaskHeaderBuilder` only `action` is mandatory; the other two fields
/// fall back to the defaults described below.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct TaskHeader {
    /// What the message asks for; has no builder default.
    action: TaskAction,
    /// Task identifier. The builder default is a freshly generated v4 UUID
    /// rendered as a string; there is no serde default, so the field must be
    /// present when deserializing.
    #[builder(default = "uuid::Uuid::new_v4().to_string()")]
    task_id: String,
    /// Streaming mode. Defaults to `default_streaming()` both in the builder
    /// and when the field is missing during deserialization.
    #[builder(default = "default_streaming()")]
    #[serde(default = "default_streaming")]
    streaming: String,
}
impl TaskHeader {
pub fn run_task() -> Self {
Self {
action: TaskAction::RunTask,
task_id: uuid::Uuid::new_v4().to_string(),
streaming: default_streaming(),
}
}
pub fn finish_task() -> Self {
Self {
action: TaskAction::FinishTask,
task_id: uuid::Uuid::new_v4().to_string(),
streaming: default_streaming(),
}
}
pub fn continue_task() -> Self {
Self {
action: TaskAction::ContinueTask,
task_id: uuid::Uuid::new_v4().to_string(),
streaming: default_streaming(),
}
}
}
/// Default value of the `streaming` header field: `"duplex"`.
///
/// Referenced by both the serde and the builder defaults of `TaskHeader`.
fn default_streaming() -> String {
    String::from("duplex")
}
/// Value of the `task` field of a [`RunTaskPayload`].
///
/// Variants are (de)serialized in lowercase: `"asr"` and `"tts"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RunTaskType {
    /// Wire name `"asr"`; used by `create_asr_run_task`.
    Asr,
    /// Wire name `"tts"`; used by `create_tts_run_task`.
    Tts,
}
/// Value of the `function` field of a [`RunTaskPayload`].
///
/// Note the mixed casing on the wire: `"recognition"` is lowercase while
/// `"SpeechSynthesizer"` keeps the variant's own spelling.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RunTaskFunction {
    /// Wire name `"recognition"`; used by `create_asr_run_task`.
    #[serde(rename = "recognition")]
    Recognition,
    /// Wire name `"SpeechSynthesizer"`, which is exactly the variant name, so
    /// no explicit rename is needed. Used by `create_tts_run_task`.
    SpeechSynthesizer,
}
/// Payload half of a [`RunTaskParameters`] message.
///
/// With `RunTaskPayloadBuilder`, `task_group`, `task`, `function` and `model`
/// are mandatory; `input` and `parameters` fall back to their `Default`
/// values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct RunTaskPayload {
    /// Task group name; the helper constructors in this file use `"audio"`.
    task_group: String,
    /// Kind of task to run.
    task: RunTaskType,
    /// Function to invoke for the task.
    function: RunTaskFunction,
    /// Model name; the builder setter accepts anything `Into<String>`.
    #[builder(setter(into))]
    model: String,
    /// Free-form JSON input map; empty by default.
    #[builder(default)]
    input: HashMap<String, serde_json::Value>,
    /// Task parameters; `TaskParameters::default()` by default.
    #[builder(default)]
    parameters: TaskParameters,
}
/// Value of the `text_type` field of [`TaskParameters`].
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub enum TextType {
    /// The default variant. Its wire name `"PlainText"` is exactly the variant
    /// name, so no explicit rename is needed.
    #[default]
    PlainText,
}
/// Parameters placed in the `parameters` object of a [`RunTaskPayload`].
///
/// `format` is the only mandatory field when using `TaskParametersBuilder`;
/// every other field is optional. All `Option` fields are omitted from the
/// serialized JSON while they are `None`, so unset parameters are never sent
/// as explicit `null` values. Missing `Option` fields deserialize to `None`,
/// with the exception of `volume` (see its note).
///
/// The values are passed through as-is; nothing in this file validates or
/// interprets them.
#[derive(Debug, Clone, Builder, Serialize, Deserialize, PartialEq, Default)]
pub struct TaskParameters {
    /// Set to `TextType::PlainText` by `create_tts_run_task`.
    #[builder(setter(into, strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    text_type: Option<TextType>,
    /// Set from the `voice` argument of `create_tts_run_task`.
    #[builder(setter(into, strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    voice: Option<String>,
    /// Always serialized. `Default` yields an empty string, and the builder
    /// has no default for it, so it must be set explicitly there.
    format: String,
    /// Set from the `sample_rate` argument of `create_asr_run_task`.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    sample_rate: Option<u32>,
    /// NOTE: when this field is missing from the input, deserialization fills
    /// in `default_volume()` (`Some(50)`) rather than `None`, so a `None`
    /// volume does not survive a serialize -> deserialize round trip.
    #[builder(setter(strip_option), default)]
    #[serde(default = "default_volume")]
    #[serde(skip_serializing_if = "Option::is_none")]
    volume: Option<u8>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    rate: Option<f32>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    pitch: Option<f32>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    enable_ssml: Option<bool>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    bit_rate: Option<u32>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    word_timestamp_enabled: Option<bool>,
    #[builder(setter(into, strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    seed: Option<u32>,
    // The fields below previously lacked `skip_serializing_if` and were
    // therefore emitted as `null` when unset, unlike the fields above.
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    language_hints: Option<Vec<String>>,
    #[builder(setter(into, strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    instruction: Option<String>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    enable_aigc_tag: Option<bool>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    aigc_propagator: Option<String>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    aigc_propagate_id: Option<String>,
    #[builder(setter(into, strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    vocabulary_id: Option<String>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    semantic_punctuation_enabled: Option<bool>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    max_sentence_silence: Option<u32>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    multi_threshold_mode_enabled: Option<bool>,
    #[builder(setter(strip_option), default)]
    #[serde(skip_serializing_if = "Option::is_none")]
    heartbeat: Option<bool>,
}
pub fn create_asr_run_task(
task_id: &str,
model: &str,
format: &str,
sample_rate: Option<u32>,
) -> RunTaskParameters {
RunTaskParameters {
header: TaskHeader {
action: TaskAction::RunTask,
task_id: task_id.to_string(),
streaming: "duplex".into(),
},
payload: RunTaskPayload {
task_group: "audio".into(),
task: RunTaskType::Asr,
function: RunTaskFunction::Recognition,
model: model.into(),
input: HashMap::new(),
parameters: TaskParameters {
format: format.into(),
sample_rate,
..Default::default()
},
},
}
}
pub fn create_tts_run_task(
task_id: &str,
model: &str,
voice: Option< &str>,
format: &str,
text: Option<&str>,
) -> RunTaskParameters {
let mut input = HashMap::new();
if let Some(t) = text {
input.insert("text".to_string(), t.into());
}
RunTaskParameters {
header: TaskHeader {
action: TaskAction::RunTask,
task_id: task_id.to_string(),
streaming: "duplex".into(),
},
payload: RunTaskPayload {
task_group: "audio".into(),
task: RunTaskType::Tts,
function: RunTaskFunction::SpeechSynthesizer,
model: model.into(),
input,
parameters: TaskParameters {
text_type: Some(TextType::PlainText),
voice: voice.map(|v| v.into()),
format: format.to_string(),
..Default::default()
},
},
}
}
/// Serde default for the `volume` field of `TaskParameters`: `Some(50)`.
///
/// Applied only when `volume` is missing from the deserialized input.
fn default_volume() -> Option<u8> {
    const DEFAULT_VOLUME: u8 = 50;
    Some(DEFAULT_VOLUME)
}
/// Header type used by [`FinishTaskParameters`]; a plain alias of
/// [`TaskHeader`].
pub type FinishTaskHeader = TaskHeader;

/// Message envelope pairing a [`FinishTaskHeader`] with a
/// [`FinishTaskPayload`].
///
/// `create_finish_task` and `FinishTaskParameters::new` build it with the
/// `TaskAction::FinishTask` action; `FinishTaskParametersBuilder` requires both
/// fields to be set.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct FinishTaskParameters {
    /// Action, task id and streaming mode of the message.
    header: FinishTaskHeader,
    /// Payload of the message.
    payload: FinishTaskPayload,
}
impl TryFrom<FinishTaskParameters> for String {
type Error = crate::error::DashScopeError;
fn try_from(value: FinishTaskParameters) -> Result<Self, Self::Error> {
serde_json::to_string(&value)
.map_err(|e| crate::error::DashScopeError::SerializationError(e.to_string()))
}
}
/// Builds a `finish-task` message for `task_id`.
///
/// The header uses the `"duplex"` streaming mode and the payload carries an
/// empty `input` map.
pub fn create_finish_task(task_id: &str) -> FinishTaskParameters {
    let header = TaskHeader {
        action: TaskAction::FinishTask,
        task_id: task_id.to_owned(),
        streaming: String::from("duplex"),
    };
    let payload = FinishTaskPayload {
        input: HashMap::new(),
    };

    FinishTaskParameters { header, payload }
}
impl FinishTaskParameters {
pub fn new(task_id: String) -> Self {
Self {
header: TaskHeaderBuilder::default()
.action(TaskAction::FinishTask)
.task_id(task_id)
.build()
.unwrap(),
payload: FinishTaskPayloadBuilder::default().build().unwrap(),
}
}
}
/// Payload half of a [`FinishTaskParameters`] message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct FinishTaskPayload {
    /// Free-form JSON input map; empty by default in the builder, and left
    /// empty by the constructors in this file.
    #[builder(default)]
    input: HashMap<String, serde_json::Value>,
}
/// Header type used by [`ContinueTaskParameters`]; a module-private alias of
/// [`TaskHeader`].
type ContinueTaskHeader = TaskHeader;

/// Message envelope pairing a header with a [`ContinueTaskPayload`].
///
/// `ContinueTaskParametersBuilder` requires both fields to be set.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct ContinueTaskParameters {
    /// Action, task id and streaming mode of the message.
    header: ContinueTaskHeader,
    /// Payload of the message.
    payload: ContinueTaskPayload,
}
/// Payload half of a [`ContinueTaskParameters`] message.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct ContinueTaskPayload {
    /// Input of the message; mandatory in the builder.
    input: ContinueTaskInput,
}
/// Input object of a [`ContinueTaskPayload`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Builder)]
pub struct ContinueTaskInput {
    /// Text carried by the message; the builder setter accepts anything
    /// `Into<String>`.
    #[builder(setter(into))]
    text: String,
}
impl TryFrom<ContinueTaskParameters> for String {
type Error = crate::error::DashScopeError;
fn try_from(value: ContinueTaskParameters) -> Result<Self, Self::Error> {
serde_json::to_string(&value)
.map_err(|e| crate::error::DashScopeError::SerializationError(e.to_string()))
}
}
/// Builds a `continue-task` message carrying `text` for the task `task_id`.
///
/// * `task_id` - moved into the header's `task_id`.
/// * `text` - converted with `ToString` and stored in the payload's
///   `input.text`.
///
/// The header uses the `"duplex"` streaming mode.
pub fn create_continue_task<S: ToString>(task_id: String, text: S) -> ContinueTaskParameters {
    ContinueTaskParameters {
        header: TaskHeader {
            // Must be `ContinueTask`: this previously said `FinishTask`, which
            // serialized every continue message with the `finish-task` action.
            action: TaskAction::ContinueTask,
            task_id,
            streaming: "duplex".into(),
        },
        payload: ContinueTaskPayload {
            input: ContinueTaskInput {
                text: text.to_string(),
            },
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `run-task` header, leaving `task_id` to the builder default.
    fn run_task_header() -> TaskHeader {
        TaskHeaderBuilder::default()
            .action(TaskAction::RunTask)
            .streaming("duplex".to_string())
            .build()
            .expect("Failed to build header")
    }

    /// Wraps `task_params` into a complete ASR run-task message built entirely
    /// through the generated builders.
    fn asr_run_task(task_params: TaskParameters) -> RunTaskParameters {
        let payload = RunTaskPayloadBuilder::default()
            .task_group("audio".to_string())
            .task(RunTaskType::Asr)
            .function(RunTaskFunction::Recognition)
            .model("fun-asr-realtime".to_string())
            .parameters(task_params)
            .build()
            .expect("Failed to build payload");

        RunTaskParametersBuilder::default()
            .header(run_task_header())
            .payload(payload)
            .build()
            .expect("Failed to build run task parameters")
    }

    /// Builds a finish-task message with the fixed id `"test-task-id"` through
    /// the generated builders.
    fn finish_task_message() -> FinishTaskParameters {
        let header = TaskHeaderBuilder::default()
            .action(TaskAction::FinishTask)
            .task_id("test-task-id".to_string())
            .streaming("duplex".to_string())
            .build()
            .expect("Failed to build finish task header");
        let payload = FinishTaskPayloadBuilder::default()
            .build()
            .expect("Failed to build finish task payload");

        FinishTaskParametersBuilder::default()
            .header(header)
            .payload(payload)
            .build()
            .expect("Failed to build finish task parameters")
    }

    #[test]
    fn test_run_task_parameters_creation() {
        let task_params = TaskParametersBuilder::default()
            .text_type(TextType::PlainText)
            .format("pcm".to_string())
            .sample_rate(16000)
            .semantic_punctuation_enabled(true)
            .build()
            .expect("Failed to build recognition parameters");

        let params = asr_run_task(task_params);
        let payload = &params.payload;

        assert_eq!(params.header.action, TaskAction::RunTask);
        assert_eq!(payload.task_group, "audio");
        assert_eq!(payload.task, RunTaskType::Asr);
        assert_eq!(payload.function, RunTaskFunction::Recognition);
        assert_eq!(payload.model, "fun-asr-realtime");
        assert_eq!(payload.parameters.format, "pcm");
        assert_eq!(payload.parameters.sample_rate, Some(16000));
        assert_eq!(payload.parameters.semantic_punctuation_enabled, Some(true));
    }

    #[test]
    fn test_finish_task_parameters_creation() {
        let params = finish_task_message();

        assert_eq!(params.header.action, TaskAction::FinishTask);
        assert_eq!(params.header.task_id, "test-task-id");
        assert_eq!(params.header.streaming, "duplex");
    }

    #[test]
    fn test_default_recognition_parameters() {
        let params = TaskParameters::default();

        assert_eq!(params.format, "");
        assert_eq!(params.sample_rate, None);
        assert_eq!(params.semantic_punctuation_enabled, None);
        assert_eq!(params.max_sentence_silence, None);
        assert_eq!(params.multi_threshold_mode_enabled, None);
        assert_eq!(params.heartbeat, None);
        assert_eq!(params.language_hints, None);
    }

    #[test]
    fn test_serialization_deserialization() {
        // `volume` is set explicitly: a missing `volume` deserializes to
        // `Some(50)` via `default_volume`, so leaving it `None` would make the
        // round-trip comparison below fail.
        let task_params = TaskParametersBuilder::default()
            .text_type(TextType::PlainText)
            .format("wav".to_string())
            .sample_rate(44100)
            .volume(50)
            .semantic_punctuation_enabled(false)
            .build()
            .expect("Failed to build recognition parameters");

        let original_params = asr_run_task(task_params);

        let serialized = serde_json::to_string(&original_params).unwrap();
        let deserialized: RunTaskParameters = serde_json::from_str(&serialized).unwrap();
        assert_eq!(original_params, deserialized);
    }

    #[test]
    fn test_finish_task_serialization_deserialization() {
        let original_params = finish_task_message();

        let serialized = serde_json::to_string(&original_params).unwrap();
        let deserialized: FinishTaskParameters = serde_json::from_str(&serialized).unwrap();
        assert_eq!(original_params, deserialized);
    }
}