agents 0.3.0

Facade crate for building typed Rust agents
Documentation
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

use crate::llm::completion::{ModelSelector, ProviderType};

/// Where the audio for a transcription request comes from.
///
/// `Serialize` and `Deserialize` are implemented by hand for this type
/// (see the impls below) rather than derived.
#[derive(Debug, Clone)]
pub enum AudioSource {
    /// Audio supplied as an in-memory byte buffer.
    Data(Vec<u8>),
    /// Audio referenced by a URL string.
    Url(String),
    /// Audio referenced by a filesystem path.
    Path(PathBuf),
}

impl Serialize for AudioSource {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            AudioSource::Data(data) => {
                #[derive(Serialize)]
                struct DataVariant {
                    data: Vec<u8>,
                }
                DataVariant { data: data.clone() }.serialize(serializer)
            }
            AudioSource::Url(url) => {
                #[derive(Serialize)]
                struct UrlVariant {
                    url: String,
                }
                UrlVariant { url: url.clone() }.serialize(serializer)
            }
            AudioSource::Path(path) => {
                #[derive(Serialize)]
                struct PathVariant {
                    path: String,
                }
                PathVariant {
                    path: path.to_string_lossy().to_string(),
                }
                .serialize(serializer)
            }
        }
    }
}

/// Deserializes an [`AudioSource`] from a map carrying exactly one of the
/// keys `data`, `url`, or `path`.
impl<'de> Deserialize<'de> for AudioSource {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Untagged helper: serde tries the shapes in declaration order and
        // picks the first one that matches the input.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum RawAudioSource {
            Data { data: Vec<u8> },
            Url { url: String },
            Path { path: String },
        }

        RawAudioSource::deserialize(deserializer).map(|shape| match shape {
            RawAudioSource::Data { data } => AudioSource::Data(data),
            RawAudioSource::Url { url } => AudioSource::Url(url),
            // The path travels as a string and is turned back into a `PathBuf`.
            RawAudioSource::Path { path } => AudioSource::Path(PathBuf::from(path)),
        })
    }
}

/// A request to transcribe a piece of audio.
///
/// Field names are serialized in camelCase (for example `response_format`
/// becomes `responseFormat`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AudioTranscriptionRequest {
    /// The audio to transcribe.
    pub audio: AudioSource,
    /// Language setting: auto-detect or an explicit language string.
    pub language: TranscriptionLanguage,
    /// Optional text hint accompanying the request.
    pub prompt: TranscriptionPrompt,
    /// Which model should handle the request.
    pub model: ModelSelector,
    /// Requested output format.
    pub response_format: TranscriptionFormat,
}

impl AudioTranscriptionRequest {
    /// Creates a request for `audio` with every other setting at its default:
    /// auto-detected language, no prompt, any model, provider-default format.
    pub fn new(audio: AudioSource) -> Self {
        let language = TranscriptionLanguage::AutoDetect;
        let prompt = TranscriptionPrompt::None;
        let model = ModelSelector::Any;
        let response_format = TranscriptionFormat::ProviderDefault;

        Self {
            audio,
            language,
            prompt,
            model,
            response_format,
        }
    }

    /// Returns a builder with default settings and no audio source yet.
    pub fn builder() -> AudioTranscriptionRequestBuilder {
        Default::default()
    }

    /// Returns the request with its model selector replaced by `model`.
    pub fn with_model(self, model: ModelSelector) -> Self {
        Self { model, ..self }
    }

    /// Returns the request with an explicit transcription language.
    pub fn with_language(self, language: impl Into<String>) -> Self {
        Self {
            language: TranscriptionLanguage::explicit(language),
            ..self
        }
    }

    /// Returns the request with `prompt` attached as a hint.
    pub fn with_prompt(self, prompt: impl Into<String>) -> Self {
        Self {
            prompt: TranscriptionPrompt::hint(prompt),
            ..self
        }
    }

    /// Returns the request with its response format replaced.
    pub fn with_response_format(self, response_format: TranscriptionFormat) -> Self {
        Self {
            response_format,
            ..self
        }
    }
}

/// Builder for [`AudioTranscriptionRequest`] in which the audio source may
/// still be missing.
///
/// NOTE(review): unlike `AudioTranscriptionRequest`, this struct carries no
/// `#[serde(rename_all = "camelCase")]`, so `response_format` keeps its
/// snake_case name when serialized — confirm this difference is intentional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionRequestBuilder {
    /// Audio source; `None` until `with_audio` is called.
    pub audio: Option<AudioSource>,
    /// Language setting: auto-detect or an explicit language string.
    pub language: TranscriptionLanguage,
    /// Optional text hint accompanying the request.
    pub prompt: TranscriptionPrompt,
    /// Which model should handle the request.
    pub model: ModelSelector,
    /// Requested output format.
    pub response_format: TranscriptionFormat,
}

/// The default builder has no audio source; every other setting matches the
/// values used by `AudioTranscriptionRequest::new`.
impl Default for AudioTranscriptionRequestBuilder {
    fn default() -> Self {
        let language = TranscriptionLanguage::AutoDetect;
        let prompt = TranscriptionPrompt::None;
        let model = ModelSelector::Any;
        let response_format = TranscriptionFormat::ProviderDefault;

        Self {
            audio: None,
            language,
            prompt,
            model,
            response_format,
        }
    }
}

impl AudioTranscriptionRequestBuilder {
    pub fn with_audio(mut self, audio: AudioSource) -> Self {
        self.audio = Some(audio);
        self
    }

    pub fn with_model(mut self, model: ModelSelector) -> Self {
        self.model = model;
        self
    }

    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = TranscriptionLanguage::explicit(language);
        self
    }

    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = TranscriptionPrompt::hint(prompt);
        self
    }

    pub fn with_response_format(mut self, response_format: TranscriptionFormat) -> Self {
        self.response_format = response_format;
        self
    }

    pub fn build(self) -> crate::llm::error::LlmResult<AudioTranscriptionRequest> {
        let audio = self
            .audio
            .ok_or_else(|| crate::llm::error::Error::InvalidRequest {
                reason: "Audio transcription request requires an audio source".to_string(),
            })?;

        Ok(AudioTranscriptionRequest {
            audio,
            language: self.language,
            prompt: self.prompt,
            model: self.model,
            response_format: self.response_format,
        })
    }
}

/// Language setting for a transcription request.
///
/// Serialized as an internally tagged enum: the variant name is stored in
/// camelCase under the `type` key.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum TranscriptionLanguage {
    /// No language specified by the caller.
    AutoDetect,
    /// Caller-specified language, stored as a free-form string.
    Explicit { language: String },
}

impl TranscriptionLanguage {
    /// Builds the `Explicit` variant from anything convertible into a `String`.
    pub fn explicit(language: impl Into<String>) -> Self {
        let language: String = language.into();
        Self::Explicit { language }
    }
}

/// Optional prompt attached to a transcription request.
///
/// Serialized as an internally tagged enum: the variant name is stored in
/// camelCase under the `type` key.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", tag = "type")]
pub enum TranscriptionPrompt {
    /// No prompt supplied.
    None,
    /// A text hint supplied by the caller.
    Hint { text: String },
}

impl TranscriptionPrompt {
    /// Builds the `Hint` variant from anything convertible into a `String`.
    pub fn hint(text: impl Into<String>) -> Self {
        let text: String = text.into();
        Self::Hint { text }
    }
}

/// Output format requested for a transcription.
///
/// No serde rename attribute is applied here, so variants serialize under
/// their Rust names. See `as_openai_str` for the OpenAI identifier of each
/// format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TranscriptionFormat {
    /// No explicit format; `as_openai_str` yields `None` for this variant.
    ProviderDefault,
    /// Maps to `"text"`.
    Text,
    /// Maps to `"json"`.
    Json,
    /// Maps to `"verbose_json"`.
    VerboseJson,
    /// Maps to `"diarized_json"`.
    DiarizedJson,
    /// Maps to `"srt"`.
    Srt,
    /// Maps to `"vtt"`.
    Vtt,
}

impl TranscriptionFormat {
    pub fn as_openai_str(&self) -> Option<&'static str> {
        match self {
            TranscriptionFormat::ProviderDefault => None,
            TranscriptionFormat::Text => Some("text"),
            TranscriptionFormat::Json => Some("json"),
            TranscriptionFormat::VerboseJson => Some("verbose_json"),
            TranscriptionFormat::DiarizedJson => Some("diarized_json"),
            TranscriptionFormat::Srt => Some("srt"),
            TranscriptionFormat::Vtt => Some("vtt"),
        }
    }
}

/// Result of an audio transcription call.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AudioTranscriptionResponse {
    /// Provider associated with this response.
    pub provider: ProviderType,
    /// Model name, as a plain string.
    pub model: String,
    /// Text carried by the response (presumably the transcript — confirm
    /// against the provider adapters).
    pub text: String,
}

#[cfg(test)]
mod tests {
    use super::{
        AudioSource, AudioTranscriptionRequest, TranscriptionFormat, TranscriptionLanguage,
        TranscriptionPrompt,
    };
    use crate::llm::completion::ModelSelector;
    use crate::llm::error::Error;

    /// `new` must leave every optional setting at its default value.
    #[test]
    fn transcription_request_defaults_capture_intent() {
        let AudioTranscriptionRequest {
            model,
            language,
            prompt,
            response_format,
            ..
        } = AudioTranscriptionRequest::new(AudioSource::Data(vec![1, 2, 3]));

        assert!(matches!(model, ModelSelector::Any));
        assert_eq!(language, TranscriptionLanguage::AutoDetect);
        assert_eq!(prompt, TranscriptionPrompt::None);
        assert_eq!(response_format, TranscriptionFormat::ProviderDefault);
    }

    /// Building without an audio source must be rejected as an invalid request.
    #[test]
    fn transcription_request_builder_requires_audio() {
        let outcome = AudioTranscriptionRequest::builder().build();
        let error = outcome.expect_err("missing audio should fail");

        assert!(matches!(error, Error::InvalidRequest { .. }));
    }
}