Skip to main content

cognee_llm/
transcriber.rs

1//! Audio transcription trait and types.
2//!
3//! Provides an async trait for converting audio bytes to text, with a default
4//! implementation targeting the OpenAI Whisper API (`POST /v1/audio/transcriptions`).
5
6use async_trait::async_trait;
7
8use crate::error::{LlmError, LlmResult};
9
/// Audio formats accepted by the OpenAI Whisper API.
///
/// Entries are lowercase file extensions without the leading dot; input is
/// compared against this list case-insensitively by [`validate_audio_format`].
//
// NOTE(review): OpenAI's API reference for `POST /v1/audio/transcriptions`
// appears to list `flac` and `ogg` as accepted too, while the tests in this
// file treat both as invalid — confirm the exclusion is intentional.
const SUPPORTED_AUDIO_FORMATS: &[&str] = &["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"];
12
13/// Validate that `format` is a supported audio format.
14///
15/// Returns `Ok(())` if the format (case-insensitive) is in the whitelist,
16/// or `Err(LlmError::InvalidAudioFormat)` otherwise.
17pub fn validate_audio_format(format: &str) -> LlmResult<()> {
18    let lower = format.to_ascii_lowercase();
19    if SUPPORTED_AUDIO_FORMATS.contains(&lower.as_str()) {
20        Ok(())
21    } else {
22        Err(LlmError::InvalidAudioFormat(format.to_string()))
23    }
24}
25
/// Output of an audio transcription request.
///
/// `text` is always present; `language` and `duration` are optional and are
/// `None` when no value is available for them.
///
/// The type derives [`PartialEq`] so results can be compared directly (for
/// example with `assert_eq!` in tests and mock backends), and [`Default`] so
/// a value can be built with struct-update syntax
/// (`TranscriptionOutput { text, ..Default::default() }`). The default value
/// has an empty `text` and `None` for both optional fields.
///
/// Because `duration` is a float, two values whose `duration` is `Some(NaN)`
/// never compare equal.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TranscriptionOutput {
    /// The transcribed text.
    pub text: String,
    /// The detected or specified language (e.g. `"english"`).
    pub language: Option<String>,
    /// Audio duration in seconds.
    pub duration: Option<f32>,
}
36
/// Trait for audio transcription backends.
///
/// Separate from [`crate::Llm`] because the Whisper endpoint uses a different
/// request shape (multipart upload), response shape, and error semantics.
///
/// The `Send + Sync` supertraits require every implementor to be safe to
/// move and share across threads.
#[async_trait]
pub trait Transcriber: Send + Sync {
    /// Transcribe audio bytes to text.
    ///
    /// # Arguments
    /// * `audio` - Raw audio file bytes (must be < 25 MB for OpenAI Whisper).
    /// * `format` - File extension without the dot: `"mp3"`, `"wav"`, etc.
    ///   [`validate_audio_format`] checks a value against the formats accepted
    ///   by the OpenAI Whisper API.
    /// * `language_hint` - Optional ISO-639-1 language code (e.g. `"en"`).
    /// * `prompt_hint` - Optional vocabulary/context hint for the model.
    ///
    /// # Returns
    /// A [`TranscriptionOutput`] holding the transcribed text plus the
    /// optional language and duration fields.
    ///
    /// # Errors
    /// Failures are reported through the `Err` side of [`LlmResult`]; which
    /// error variants are produced is up to the implementation.
    async fn transcribe_audio(
        &self,
        audio: &[u8],
        format: &str,
        language_hint: Option<&str>,
        prompt_hint: Option<&str>,
    ) -> LlmResult<TranscriptionOutput>;

    /// Return the name of the transcription model in use.
    ///
    /// The returned string slice borrows from `self`.
    fn transcription_model(&self) -> &str;
}
61
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        reason = "test code — panics are acceptable"
    )]
    use super::*;

    /// Extensions the validator is expected to accept.
    const ACCEPTED: [&str; 7] = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"];

    /// Extensions the validator is expected to reject.
    const REJECTED: [&str; 7] = ["mid", "aiff", "amr", "ogg", "flac", "aac", "wma"];

    /// Every whitelisted extension passes validation.
    #[test]
    fn test_valid_formats() {
        for fmt in ACCEPTED {
            let outcome = validate_audio_format(fmt);
            assert!(outcome.is_ok(), "Expected {fmt} to be valid");
        }
    }

    /// Non-whitelisted extensions fail with `InvalidAudioFormat` specifically.
    #[test]
    fn test_invalid_formats() {
        for fmt in REJECTED {
            let outcome = validate_audio_format(fmt);
            assert!(outcome.is_err(), "Expected {fmt} to be invalid");
            assert!(
                matches!(outcome, Err(LlmError::InvalidAudioFormat(_))),
                "Expected InvalidAudioFormat for {fmt}"
            );
        }
    }

    /// Upper- and mixed-case spellings are accepted just like lowercase ones.
    #[test]
    fn test_format_case_insensitive() {
        for fmt in ["MP3", "Mp3", "WAV", "WebM"] {
            assert!(validate_audio_format(fmt).is_ok());
        }
    }
}