use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;
use super::config::DeepgramConfig;
use super::error::DeepgramErrorMapper;
use super::stt::{self, DeepgramResponse, OpenAITranscriptionResponse, TranscriptionRequest};
use crate::core::providers::base::{GlobalPoolManager, HttpErrorMapper};
use crate::core::providers::unified_provider::ProviderError;
use crate::core::traits::provider::ProviderConfig as _;
use crate::core::types::health::HealthStatus;
use crate::core::types::{model::ModelInfo, model::ProviderCapability};
/// Provider identifier attached to every `ProviderError` raised here and returned by `name()`.
const PROVIDER_NAME: &str = "deepgram";
/// Capabilities advertised by this provider: audio transcription only.
const DEEPGRAM_CAPABILITIES: &[ProviderCapability] = &[ProviderCapability::AudioTranscription];
#[derive(Debug, Clone)]
pub struct DeepgramProvider {
config: DeepgramConfig,
models: Vec<ModelInfo>,
}
impl DeepgramProvider {
pub async fn new(config: DeepgramConfig) -> Result<Self, ProviderError> {
config
.validate()
.map_err(|e| ProviderError::configuration(PROVIDER_NAME, e))?;
let _pool_manager = Arc::new(GlobalPoolManager::new().map_err(|e| {
ProviderError::configuration(
PROVIDER_NAME,
format!("Failed to create pool manager: {}", e),
)
})?);
let models = Self::build_model_list();
Ok(Self { config, models })
}
pub async fn with_api_key(api_key: impl Into<String>) -> Result<Self, ProviderError> {
let config = DeepgramConfig::from_env().with_api_key(api_key);
Self::new(config).await
}
pub fn build_model_list() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "nova-2".to_string(),
name: "Nova 2".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "nova-2-general".to_string(),
name: "Nova 2 General".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "nova-2-meeting".to_string(),
name: "Nova 2 Meeting".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "nova-2-phonecall".to_string(),
name: "Nova 2 Phone Call".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "nova-2-medical".to_string(),
name: "Nova 2 Medical".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "enhanced".to_string(),
name: "Enhanced".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
ModelInfo {
id: "base".to_string(),
name: "Base".to_string(),
provider: "deepgram".to_string(),
max_context_length: 0,
max_output_length: None,
supports_streaming: true,
supports_tools: false,
supports_multimodal: true,
input_cost_per_1k_tokens: None,
output_cost_per_1k_tokens: None,
currency: "USD".to_string(),
capabilities: vec![ProviderCapability::AudioTranscription],
created_at: None,
updated_at: None,
metadata: HashMap::new(),
},
]
}
pub fn name(&self) -> &'static str {
PROVIDER_NAME
}
pub fn capabilities(&self) -> &'static [ProviderCapability] {
DEEPGRAM_CAPABILITIES
}
pub fn models(&self) -> &[ModelInfo] {
&self.models
}
pub fn get_error_mapper(&self) -> DeepgramErrorMapper {
DeepgramErrorMapper
}
pub async fn transcribe_audio(
&self,
request: TranscriptionRequest,
) -> Result<OpenAITranscriptionResponse, ProviderError> {
debug!("Deepgram STT request: model={}", request.model);
let url = stt::build_stt_url(&self.config.get_api_base(), &request);
let api_key = self
.config
.get_api_key()
.ok_or_else(|| ProviderError::authentication(PROVIDER_NAME, "API key is required"))?;
let content_type = request
.filename
.as_ref()
.map(|f| stt::detect_audio_mime_type(f))
.unwrap_or("audio/mpeg");
let client = reqwest::Client::new();
let response = client
.post(&url)
.header("Authorization", format!("Token {}", api_key))
.header("Content-Type", content_type)
.body(request.file)
.send()
.await
.map_err(|e| ProviderError::network(PROVIDER_NAME, e.to_string()))?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.ok();
return Err(Self::map_http_error(status, body.as_deref()));
}
let response_text = response.text().await.map_err(|e| {
ProviderError::response_parsing(
PROVIDER_NAME,
format!("Failed to read response: {}", e),
)
})?;
let deepgram_response: DeepgramResponse =
serde_json::from_str(&response_text).map_err(|e| {
ProviderError::response_parsing(
PROVIDER_NAME,
format!(
"Failed to parse response: {}\nResponse: {}",
e, response_text
),
)
})?;
deepgram_response.try_into()
}
pub async fn transcribe_simple(
&self,
file: Vec<u8>,
model: Option<String>,
language: Option<String>,
diarize: Option<bool>,
punctuate: Option<bool>,
filename: Option<String>,
) -> Result<OpenAITranscriptionResponse, ProviderError> {
let request = TranscriptionRequest {
file,
model: model.unwrap_or_else(|| "nova-2".to_string()),
language,
smart_format: Some(true),
punctuate,
diarize,
paragraphs: diarize, words: Some(true),
filename,
..Default::default()
};
self.transcribe_audio(request).await
}
pub fn map_http_error(status: u16, body: Option<&str>) -> ProviderError {
let message = body.unwrap_or("Unknown error").to_string();
match status {
400 => ProviderError::invalid_request(PROVIDER_NAME, message),
401 => ProviderError::authentication(PROVIDER_NAME, "Invalid API key"),
402 => ProviderError::quota_exceeded(PROVIDER_NAME, "Usage quota exceeded"),
403 => ProviderError::authentication(PROVIDER_NAME, "Access forbidden"),
404 => ProviderError::model_not_found(PROVIDER_NAME, "Model not found"),
429 => ProviderError::rate_limit(PROVIDER_NAME, Some(60)),
500 => ProviderError::api_error(PROVIDER_NAME, 500, "Internal server error"),
502 | 503 => ProviderError::api_error(PROVIDER_NAME, status, "Service unavailable"),
_ => HttpErrorMapper::map_status_code(
PROVIDER_NAME,
status,
&format!("HTTP error {}: {}", status, message),
),
}
}
pub async fn health_check(&self) -> HealthStatus {
let url = format!("{}/projects", self.config.get_api_base().replace("/v1", ""));
let api_key = match self.config.get_api_key() {
Some(key) => key,
None => return HealthStatus::Unhealthy,
};
let client = reqwest::Client::new();
match client
.get(&url)
.header("Authorization", format!("Token {}", api_key))
.send()
.await
{
Ok(response) if response.status().is_success() => HealthStatus::Healthy,
_ => HealthStatus::Unhealthy,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The catalogue must be non-empty, contain the `nova-2` and `enhanced`
    /// models, and tag every entry as a Deepgram transcription model.
    #[test]
    fn test_build_model_list() {
        let catalogue = DeepgramProvider::build_model_list();
        assert!(!catalogue.is_empty());

        for wanted in ["nova-2", "enhanced"].iter() {
            assert!(catalogue.iter().any(|info| info.id == *wanted));
        }

        catalogue.iter().for_each(|info| {
            assert_eq!(info.provider, "deepgram");
            assert!(info
                .capabilities
                .contains(&ProviderCapability::AudioTranscription));
        });
    }

    /// Each documented status code must map to its dedicated error variant.
    #[test]
    fn test_map_http_error() {
        type VariantCheck = fn(&ProviderError) -> bool;
        let cases: [(u16, Option<&str>, VariantCheck); 8] = [
            (400, Some("Bad request"), |e| {
                matches!(e, ProviderError::InvalidRequest { .. })
            }),
            (401, None, |e| matches!(e, ProviderError::Authentication { .. })),
            (402, Some("Quota"), |e| {
                matches!(e, ProviderError::QuotaExceeded { .. })
            }),
            (403, None, |e| matches!(e, ProviderError::Authentication { .. })),
            (404, None, |e| matches!(e, ProviderError::ModelNotFound { .. })),
            (429, None, |e| matches!(e, ProviderError::RateLimit { .. })),
            (500, None, |e| matches!(e, ProviderError::ApiError { .. })),
            (503, None, |e| matches!(e, ProviderError::ApiError { .. })),
        ];
        for &(status, body, is_expected_variant) in cases.iter() {
            let mapped = DeepgramProvider::map_http_error(status, body);
            assert!(
                is_expected_variant(&mapped),
                "status {} mapped to an unexpected variant",
                status
            );
        }
    }

    /// The provider advertises transcription and nothing else checked here.
    #[test]
    fn test_capabilities() {
        let supports = |cap: ProviderCapability| DEEPGRAM_CAPABILITIES.contains(&cap);
        assert!(supports(ProviderCapability::AudioTranscription));
        assert!(!supports(ProviderCapability::ChatCompletion));
        assert!(!supports(ProviderCapability::TextToSpeech));
    }
}