use serde::{Deserialize, Serialize};
/// Discriminant naming the kind of an input.
///
/// The variants mirror those of `InputConfig`, whose `input_type` method
/// returns this enum. Serde serializes each variant as its lowercase name
/// (`audio`, `text`, `image`, `embedding`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum InputType {
    /// Serialized as `audio`.
    Audio,
    /// Serialized as `text`.
    Text,
    /// Serialized as `image`.
    Image,
    /// Serialized as `embedding`.
    Embedding,
}
impl InputType {
pub fn as_str(&self) -> &'static str {
match self {
InputType::Audio => "audio",
InputType::Text => "text",
InputType::Image => "image",
InputType::Embedding => "embedding",
}
}
}
/// Discriminant naming the kind of an output.
///
/// The variants mirror those of `OutputConfig`, whose `output_type` method
/// returns this enum. Serde serializes each variant as its lowercase name
/// (`audio`, `text`, `image`, `embedding`, `structured`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputType {
    /// Serialized as `audio`.
    Audio,
    /// Serialized as `text`.
    Text,
    /// Serialized as `image`.
    Image,
    /// Serialized as `embedding`.
    Embedding,
    /// Serialized as `structured`.
    Structured,
}
impl OutputType {
pub fn as_str(&self) -> &'static str {
match self {
OutputType::Audio => "audio",
OutputType::Text => "text",
OutputType::Image => "image",
OutputType::Embedding => "embedding",
OutputType::Structured => "structured",
}
}
}
/// Sample encoding used by the audio configs in this file.
///
/// Serde serializes each variant as its lowercase name (`pcm16`, `pcm32`,
/// `float32`). The `Default` value is [`AudioSampleFormat::Float32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AudioSampleFormat {
    /// Serialized as `pcm16`.
    Pcm16,
    /// Serialized as `pcm32`.
    Pcm32,
    /// Serialized as `float32`; this is the default variant.
    #[default]
    Float32,
}
/// Pixel format tag used by [`ImageInputConfig`].
///
/// Serde serializes each variant as its lowercase name (`rgb`, `bgr`,
/// `grayscale`). The `Default` value is [`ImageFormat::Rgb`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ImageFormat {
    /// Serialized as `rgb`; this is the default variant.
    #[default]
    Rgb,
    /// Serialized as `bgr`.
    Bgr,
    /// Serialized as `grayscale`.
    Grayscale,
}
/// Settings describing audio supplied as input.
///
/// When deserializing, `sample_rate` and `channels` are required, while
/// `format` and `streaming` fall back to their types' `Default` values
/// (`AudioSampleFormat::default()` and `false`) if omitted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AudioInputConfig {
    /// Sample rate of the audio (presumably in Hz — confirm with consumers).
    pub sample_rate: u32,
    /// Number of audio channels.
    pub channels: u32,
    /// Sample encoding; optional in serialized form.
    #[serde(default)]
    pub format: AudioSampleFormat,
    /// Streaming flag; optional in serialized form, `false` when omitted.
    #[serde(default)]
    pub streaming: bool,
}
impl AudioInputConfig {
pub fn asr_default() -> Self {
Self {
sample_rate: 16000,
channels: 1,
format: AudioSampleFormat::Float32,
streaming: false,
}
}
pub fn cd_quality() -> Self {
Self {
sample_rate: 44100,
channels: 2,
format: AudioSampleFormat::Pcm16,
streaming: false,
}
}
}
impl Default for AudioInputConfig {
fn default() -> Self {
Self::asr_default()
}
}
/// Settings describing text supplied as input.
///
/// Both fields are optional when deserializing.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TextInputConfig {
    /// Optional maximum length; `None` when omitted. The unit (characters,
    /// bytes, tokens) is not established in this file — confirm with consumers.
    #[serde(default)]
    pub max_length: Option<u32>,
    /// Name of the text encoding; filled in by `default_encoding` when omitted.
    #[serde(default = "default_encoding")]
    pub encoding: String,
}
/// Serde default for `TextInputConfig::encoding`: the string `"utf8"`.
fn default_encoding() -> String {
    String::from("utf8")
}
impl Default for TextInputConfig {
fn default() -> Self {
Self {
max_length: None,
encoding: "utf8".to_string(),
}
}
}
/// Settings describing an image supplied as input.
///
/// When deserializing, `width` and `height` are required; `channels` and
/// `format` are optional.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ImageInputConfig {
    /// Image width (presumably in pixels — confirm with consumers).
    pub width: u32,
    /// Image height (presumably in pixels — confirm with consumers).
    pub height: u32,
    /// Number of channels; filled in by `default_image_channels` when omitted.
    #[serde(default = "default_image_channels")]
    pub channels: u32,
    /// Pixel format; `ImageFormat::default()` when omitted.
    #[serde(default)]
    pub format: ImageFormat,
}
/// Serde default for `ImageInputConfig::channels`: three channels.
fn default_image_channels() -> u32 {
    const CHANNELS: u32 = 3;
    CHANNELS
}
impl Default for ImageInputConfig {
fn default() -> Self {
Self {
width: 224,
height: 224,
channels: 3,
format: ImageFormat::Rgb,
}
}
}
/// Settings describing an embedding supplied as input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingInputConfig {
    /// Number of dimensions of the embedding; required when deserializing.
    pub dimensions: u32,
}
impl Default for EmbeddingInputConfig {
fn default() -> Self {
Self { dimensions: 384 }
}
}
/// Input configuration, with one variant per `InputType`.
///
/// The serde representation is internally tagged: a `type` field holds the
/// lowercase variant name, and the fields of the variant's config struct are
/// flattened next to it (for example `type: audio` alongside `sample_rate`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum InputConfig {
    /// Audio input; tagged `type: audio`.
    Audio {
        /// Audio settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: AudioInputConfig,
    },
    /// Text input; tagged `type: text`.
    Text {
        /// Text settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: TextInputConfig,
    },
    /// Image input; tagged `type: image`.
    Image {
        /// Image settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: ImageInputConfig,
    },
    /// Embedding input; tagged `type: embedding`.
    Embedding {
        /// Embedding settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: EmbeddingInputConfig,
    },
}
impl InputConfig {
    /// Returns the [`InputType`] discriminant matching this variant.
    pub fn input_type(&self) -> InputType {
        match *self {
            Self::Audio { .. } => InputType::Audio,
            Self::Text { .. } => InputType::Text,
            Self::Image { .. } => InputType::Image,
            Self::Embedding { .. } => InputType::Embedding,
        }
    }

    /// Wraps `config` in the `Audio` variant.
    pub fn audio(config: AudioInputConfig) -> Self {
        Self::Audio { config }
    }

    /// Wraps `config` in the `Text` variant.
    pub fn text(config: TextInputConfig) -> Self {
        Self::Text { config }
    }

    /// Wraps `config` in the `Image` variant.
    pub fn image(config: ImageInputConfig) -> Self {
        Self::Image { config }
    }

    /// Wraps `config` in the `Embedding` variant.
    pub fn embedding(config: EmbeddingInputConfig) -> Self {
        Self::Embedding { config }
    }
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AudioOutputConfig {
#[serde(default = "default_output_sample_rate")]
pub sample_rate: u32,
#[serde(default)]
pub format: AudioSampleFormat,
}
/// Serde default for `AudioOutputConfig::sample_rate`: 22050.
fn default_output_sample_rate() -> u32 {
    const SAMPLE_RATE: u32 = 22_050;
    SAMPLE_RATE
}
impl Default for AudioOutputConfig {
fn default() -> Self {
Self {
sample_rate: 22050,
format: AudioSampleFormat::Pcm16,
}
}
}
/// Settings describing an embedding produced as output.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingOutputConfig {
    /// Number of dimensions of the embedding; required when deserializing.
    pub dimensions: u32,
}
/// Settings describing structured output, carried as an arbitrary JSON value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StructuredOutputConfig {
    /// Schema value; optional when deserializing.
    ///
    /// NOTE(review): with the field-level `#[serde(default)]`, an omitted
    /// `schema` takes `serde_json::Value::default()` (documented as `Null`),
    /// whereas this type's `Default` impl builds an empty JSON object —
    /// confirm the difference is intended.
    #[serde(default)]
    pub schema: serde_json::Value,
}
impl Default for StructuredOutputConfig {
fn default() -> Self {
Self {
schema: serde_json::Value::Object(serde_json::Map::new()),
}
}
}
/// Output configuration, with one variant per `OutputType`.
///
/// The serde representation is internally tagged: a `type` field holds the
/// lowercase variant name, and for variants that carry a config struct its
/// fields are flattened next to the tag. `Text` and `Image` carry no
/// settings. The `Default` value is [`OutputConfig::Text`].
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum OutputConfig {
    /// Audio output; tagged `type: audio`.
    Audio {
        /// Audio settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: AudioOutputConfig,
    },
    /// Text output; tagged `type: text`. This is the default variant.
    #[default]
    Text,
    /// Image output; tagged `type: image`.
    Image,
    /// Embedding output; tagged `type: embedding`.
    Embedding {
        /// Embedding settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: EmbeddingOutputConfig,
    },
    /// Structured output; tagged `type: structured`.
    Structured {
        /// Structured-output settings, flattened into the enclosing map.
        #[serde(flatten)]
        config: StructuredOutputConfig,
    },
}
impl OutputConfig {
    /// Returns the [`OutputType`] discriminant matching this variant.
    pub fn output_type(&self) -> OutputType {
        match *self {
            Self::Audio { .. } => OutputType::Audio,
            Self::Text => OutputType::Text,
            Self::Image => OutputType::Image,
            Self::Embedding { .. } => OutputType::Embedding,
            Self::Structured { .. } => OutputType::Structured,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An audio input config is tagged `type: audio` in YAML and survives a
    /// serialize/deserialize round trip unchanged.
    #[test]
    fn test_audio_input_config_serde() {
        let audio = AudioInputConfig {
            sample_rate: 16000,
            channels: 1,
            format: AudioSampleFormat::Float32,
            streaming: false,
        };
        let original = InputConfig::Audio { config: audio };

        let serialized = serde_yaml::to_string(&original).unwrap();
        assert!(serialized.contains("type: audio"));
        assert!(serialized.contains("sample_rate: 16000"));

        let round_tripped = serde_yaml::from_str::<InputConfig>(&serialized).unwrap();
        assert_eq!(original, round_tripped);
    }

    /// A text input config survives a YAML round trip unchanged.
    #[test]
    fn test_text_input_config_serde() {
        let text = TextInputConfig {
            max_length: Some(4096),
            encoding: "utf8".to_string(),
        };
        let original = InputConfig::Text { config: text };

        let serialized = serde_yaml::to_string(&original).unwrap();
        let round_tripped = serde_yaml::from_str::<InputConfig>(&serialized).unwrap();
        assert_eq!(original, round_tripped);
    }

    /// A document holding only the `text` tag parses as a text output.
    #[test]
    fn test_output_config_text() {
        let document = "type: text";
        let parsed = serde_yaml::from_str::<OutputConfig>(document).unwrap();
        assert_eq!(parsed.output_type(), OutputType::Text);
    }

    /// A fully specified audio output document parses as an audio output.
    #[test]
    fn test_output_config_audio() {
        // The YAML body stays at column 0 so the literal's bytes are unchanged.
        let document = r#"
type: audio
sample_rate: 22050
format: pcm16
"#;
        let parsed = serde_yaml::from_str::<OutputConfig>(document).unwrap();
        assert_eq!(parsed.output_type(), OutputType::Audio);
    }

    /// Every input type maps to its lowercase name.
    #[test]
    fn test_input_type_as_str() {
        let expected_names = [
            (InputType::Audio, "audio"),
            (InputType::Text, "text"),
            (InputType::Image, "image"),
            (InputType::Embedding, "embedding"),
        ];
        for (input_type, name) in expected_names {
            assert_eq!(input_type.as_str(), name);
        }
    }
}