herolib-ai 0.3.13

AI client with multi-provider support (Groq, OpenRouter, SambaNova) and automatic failover
Documentation
//! Voice transcription model definitions and types.
//!
//! This module defines the available transcription models and the types for transcription requests/responses.

use serde::{Deserialize, Serialize};

use crate::provider::Provider;

/// The speech-to-text models this crate can request.
///
/// Every model is served by at least one provider; call
/// [`TranscriptionModel::info`] to see which providers and model IDs apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum TranscriptionModel {
    /// Whisper Large V3 Turbo: the faster multilingual transcription variant
    /// (advertised at roughly 216x real-time speed).
    WhisperLargeV3Turbo,
    /// Whisper Large V3: the higher-accuracy multilingual model, which also
    /// supports translation.
    WhisperLargeV3,
}

/// Everything known about one transcription model: a description, whether it
/// can translate, and the providers that serve it.
#[derive(Debug, Clone)]
pub struct TranscriptionModelInfo {
    /// The model these details describe.
    pub model: TranscriptionModel,
    /// Short human-readable summary of the model.
    pub description: &'static str,
    /// `true` when the model can translate audio in addition to transcribing it.
    pub supports_translation: bool,
    /// Providers serving this model, most preferred first.
    pub providers: Vec<TranscriptionProviderMapping>,
}

/// Pairs a transcription model with one provider and the identifier that
/// provider uses for it.
#[derive(Debug, Clone)]
pub struct TranscriptionProviderMapping {
    /// Provider that serves the model.
    pub provider: Provider,
    /// Provider-specific model identifier, e.g. `"whisper-large-v3"`.
    pub model_id: &'static str,
}

impl TranscriptionProviderMapping {
    /// Creates a new provider mapping.
    pub const fn new(provider: Provider, model_id: &'static str) -> Self {
        Self { provider, model_id }
    }
}

impl TranscriptionModel {
    /// Returns this model's description, translation support, and the
    /// providers that serve it (in order of preference).
    ///
    /// A new `TranscriptionModelInfo`, with a freshly allocated provider list,
    /// is built on every call.
    pub fn info(&self) -> TranscriptionModelInfo {
        match self {
            TranscriptionModel::WhisperLargeV3Turbo => TranscriptionModelInfo {
                model: *self,
                description: "Whisper Large V3 Turbo - Fast multilingual transcription",
                supports_translation: false,
                providers: vec![TranscriptionProviderMapping::new(
                    Provider::Groq,
                    "whisper-large-v3-turbo",
                )],
            },
            TranscriptionModel::WhisperLargeV3 => TranscriptionModelInfo {
                model: *self,
                description: "Whisper Large V3 - High accuracy transcription and translation",
                supports_translation: true,
                providers: vec![TranscriptionProviderMapping::new(
                    Provider::Groq,
                    "whisper-large-v3",
                )],
            },
        }
    }

    /// Returns the human-readable name of the model (also used by `Display`).
    pub fn name(&self) -> &'static str {
        match self {
            TranscriptionModel::WhisperLargeV3Turbo => "Whisper Large V3 Turbo",
            TranscriptionModel::WhisperLargeV3 => "Whisper Large V3",
        }
    }

    /// Returns every available transcription model, in declaration order.
    pub fn all() -> &'static [TranscriptionModel] {
        &[
            TranscriptionModel::WhisperLargeV3Turbo,
            TranscriptionModel::WhisperLargeV3,
        ]
    }
}

impl Default for TranscriptionModel {
    /// Returns the default transcription model, [`TranscriptionModel::WhisperLargeV3Turbo`].
    ///
    /// This is the standard `Default` trait rather than an inherent `default()`
    /// function, so the type works with `#[derive(Default)]` containers,
    /// `unwrap_or_default()` and `..Default::default()`.
    /// `TranscriptionModel::default()` still resolves here via the prelude.
    fn default() -> Self {
        TranscriptionModel::WhisperLargeV3Turbo
    }
}

/// Formats the model as its human-readable name (see [`TranscriptionModel::name`]).
impl std::fmt::Display for TranscriptionModel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write!(f, "{}", ..)`) honours the caller's width, fill,
        // alignment and precision flags, so `format!("{:<24}|", model)` pads
        // the name just as it would for a `&str`.
        f.pad(self.name())
    }
}

/// Output format requested from the transcription endpoint.
///
/// Variants (de)serialize in `snake_case`: `json`, `text`, `verbose_json`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TranscriptionResponseFormat {
    /// A JSON object carrying only the transcribed text.
    Json,
    /// The transcript as plain text.
    Text,
    /// JSON that also includes timestamps and other metadata.
    VerboseJson,
}

impl Default for TranscriptionResponseFormat {
    fn default() -> Self {
        TranscriptionResponseFormat::Json
    }
}

/// Level of detail at which timestamps are requested.
///
/// Variants (de)serialize in `snake_case`: `word`, `segment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TimestampGranularity {
    /// A timestamp for every word.
    Word,
    /// A timestamp for every segment.
    Segment,
}

/// Optional settings for a transcription request.
///
/// Every field is an `Option`; `None` leaves that setting unspecified.
#[derive(Debug, Clone, Default)]
pub struct TranscriptionOptions {
    /// Language of the input audio as an ISO-639-1 code (e.g. `"en"`).
    pub language: Option<String>,
    /// Text that guides the model's style or continues a previous audio segment.
    pub prompt: Option<String>,
    /// Desired format of the transcript output.
    pub response_format: Option<TranscriptionResponseFormat>,
    /// Sampling temperature, expected in the range 0 to 1.
    pub temperature: Option<f32>,
    /// Timestamp granularities the response should populate.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

impl TranscriptionOptions {
    /// Creates options with every setting unset (`None`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the language of the input audio (ISO-639-1 code such as `"en"`).
    ///
    /// Like every builder method here, this consumes `self` and returns the
    /// updated options; ignoring the return value discards the change, hence
    /// `#[must_use]`.
    #[must_use]
    pub fn with_language(mut self, language: impl Into<String>) -> Self {
        self.language = Some(language.into());
        self
    }

    /// Sets the prompt used to guide the model's style or continue prior audio.
    #[must_use]
    pub fn with_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.prompt = Some(prompt.into());
        self
    }

    /// Sets the response format.
    #[must_use]
    pub fn with_response_format(mut self, format: TranscriptionResponseFormat) -> Self {
        self.response_format = Some(format);
        self
    }

    /// Sets the sampling temperature (expected range 0 to 1; not validated here).
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Sets the timestamp granularities to populate.
    ///
    /// Accepts any iterable of [`TimestampGranularity`], including a `Vec`,
    /// an array such as `[TimestampGranularity::Word]`, or an iterator. The
    /// items are collected into the stored `Vec` in iteration order.
    #[must_use]
    pub fn with_timestamp_granularities(
        mut self,
        granularities: impl IntoIterator<Item = TimestampGranularity>,
    ) -> Self {
        self.timestamp_granularities = Some(granularities.into_iter().collect());
        self
    }
}

/// Body returned by the transcription API in the plain `json` format.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionResponse {
    /// The transcript (required field).
    pub text: String,
}

/// Body returned by the transcription API in the `verbose_json` format.
///
/// Only `text` is required; every other field is `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct VerboseTranscriptionResponse {
    /// The full transcript.
    pub text: String,
    /// Language detected in the audio.
    #[serde(default)]
    pub language: Option<String>,
    /// Audio duration, in seconds.
    #[serde(default)]
    pub duration: Option<f64>,
    /// Segments of the transcript with their timing.
    #[serde(default)]
    pub segments: Option<Vec<TranscriptionSegment>>,
    /// Per-word timestamps, present when word granularity was requested.
    #[serde(default)]
    pub words: Option<Vec<TranscriptionWord>>,
}

/// One timed segment of a verbose transcription.
///
/// `id`, `start`, `end` and `text` are required; the remaining fields are
/// `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionSegment {
    /// Identifier of the segment.
    pub id: i32,
    /// Seek position reported for the segment.
    #[serde(default)]
    pub seek: Option<i32>,
    /// Start of the segment, in seconds.
    pub start: f64,
    /// End of the segment, in seconds.
    pub end: f64,
    /// Transcript of this segment.
    pub text: String,
    /// Token IDs making up the segment.
    #[serde(default)]
    pub tokens: Option<Vec<i32>>,
    /// Mean log probability; values nearer 0 indicate higher confidence.
    #[serde(default)]
    pub avg_logprob: Option<f64>,
    /// Compression ratio reported for the segment.
    #[serde(default)]
    pub compression_ratio: Option<f64>,
    /// Probability that the segment contains no speech.
    #[serde(default)]
    pub no_speech_prob: Option<f64>,
}

/// A single transcribed word and when it was spoken.
#[derive(Debug, Clone, Deserialize)]
pub struct TranscriptionWord {
    /// The word text.
    pub word: String,
    /// When the word starts, in seconds.
    pub start: f64,
    /// When the word ends, in seconds.
    pub end: f64,
}

/// Unit tests for the transcription model catalogue, request options, and
/// response deserialization.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transcription_model_info() {
        let info = TranscriptionModel::WhisperLargeV3Turbo.info();
        assert!(!info.providers.is_empty());
        assert!(!info.supports_translation);

        let info = TranscriptionModel::WhisperLargeV3.info();
        assert!(info.supports_translation);
    }

    #[test]
    fn test_all_transcription_models_have_providers() {
        for model in TranscriptionModel::all() {
            let info = model.info();
            assert!(
                !info.providers.is_empty(),
                "Transcription model {} has no providers",
                model.name()
            );
        }
    }

    #[test]
    fn test_info_describes_its_own_model() {
        for &model in TranscriptionModel::all() {
            assert_eq!(model.info().model, model);
        }
    }

    #[test]
    fn test_defaults() {
        assert_eq!(
            TranscriptionModel::default(),
            TranscriptionModel::WhisperLargeV3Turbo
        );
        assert_eq!(
            TranscriptionResponseFormat::default(),
            TranscriptionResponseFormat::Json
        );
    }

    #[test]
    fn test_display_matches_name() {
        for model in TranscriptionModel::all() {
            assert_eq!(model.to_string(), model.name());
        }
    }

    #[test]
    fn test_enums_use_snake_case_wire_names() {
        // The API expects these exact strings, so pin the serde renaming.
        assert_eq!(
            serde_json::to_string(&TranscriptionResponseFormat::VerboseJson).unwrap(),
            r#""verbose_json""#
        );
        assert_eq!(
            serde_json::to_string(&TimestampGranularity::Segment).unwrap(),
            r#""segment""#
        );
        let parsed: TimestampGranularity = serde_json::from_str(r#""word""#).unwrap();
        assert_eq!(parsed, TimestampGranularity::Word);
    }

    #[test]
    fn test_transcription_options() {
        let options = TranscriptionOptions::new()
            .with_language("en")
            .with_prompt("Technical vocabulary")
            .with_temperature(0.2)
            .with_response_format(TranscriptionResponseFormat::VerboseJson)
            .with_timestamp_granularities(vec![TimestampGranularity::Word]);

        assert_eq!(options.language, Some("en".to_string()));
        assert_eq!(options.prompt.as_deref(), Some("Technical vocabulary"));
        assert_eq!(options.temperature, Some(0.2));
        assert_eq!(
            options.response_format,
            Some(TranscriptionResponseFormat::VerboseJson)
        );
        assert_eq!(
            options.timestamp_granularities,
            Some(vec![TimestampGranularity::Word])
        );
    }

    #[test]
    fn test_transcription_options_start_empty() {
        let options = TranscriptionOptions::new();
        assert!(options.language.is_none());
        assert!(options.prompt.is_none());
        assert!(options.response_format.is_none());
        assert!(options.temperature.is_none());
        assert!(options.timestamp_granularities.is_none());
    }

    #[test]
    fn test_transcription_response_parsing() {
        let json = r#"{"text": "Hello, world!"}"#;
        let response: TranscriptionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.text, "Hello, world!");

        // `text` is required, so an empty object must be rejected.
        assert!(serde_json::from_str::<TranscriptionResponse>("{}").is_err());
    }

    #[test]
    fn test_verbose_transcription_response_parsing() {
        let json = r#"{
            "text": "Hello, world!",
            "language": "en",
            "duration": 1.5,
            "segments": [{
                "id": 0,
                "start": 0.0,
                "end": 1.5,
                "text": "Hello, world!"
            }]
        }"#;

        let response: VerboseTranscriptionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.text, "Hello, world!");
        assert_eq!(response.language, Some("en".to_string()));
        assert_eq!(response.duration, Some(1.5));
        assert!(response.words.is_none());

        let segments = response.segments.expect("segments should be present");
        assert_eq!(segments.len(), 1);
        let segment = &segments[0];
        assert_eq!(segment.id, 0);
        assert_eq!(segment.end, 1.5);
        assert_eq!(segment.text, "Hello, world!");
        // Optional per-segment fields fall back to `None` when absent.
        assert!(segment.seek.is_none());
        assert!(segment.tokens.is_none());
        assert!(segment.avg_logprob.is_none());
        assert!(segment.compression_ratio.is_none());
        assert!(segment.no_speech_prob.is_none());
    }

    #[test]
    fn test_verbose_transcription_response_with_words() {
        let json = r#"{
            "text": "Hi there",
            "words": [
                {"word": "Hi", "start": 0.0, "end": 0.5},
                {"word": "there", "start": 0.5, "end": 1.0}
            ]
        }"#;

        let response: VerboseTranscriptionResponse = serde_json::from_str(json).unwrap();
        assert!(response.language.is_none());
        assert!(response.duration.is_none());
        assert!(response.segments.is_none());

        let words = response.words.expect("words should be present");
        let texts: Vec<&str> = words.iter().map(|w| w.word.as_str()).collect();
        assert_eq!(texts, ["Hi", "there"]);
        assert_eq!(words[1].start, 0.5);
        assert_eq!(words[1].end, 1.0);
    }
}