use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::rate_limiter::RateLimiter;
/// Default base URL that `ElevenLabsClient::new` stores as the client's `base_url`.
const ELEVENLABS_API_BASE: &str = "https://api.elevenlabs.io/v1";
/// Client for the ElevenLabs HTTP API.
///
/// Holds the credentials, the base URL, a reusable `reqwest` client and an
/// optional rate limiter that is awaited before each request.
pub struct ElevenLabsClient {
    /// API key sent in the `xi-api-key` header of every request.
    api_key: String,
    /// URL prefix that endpoint paths (e.g. `/voices`) are appended to.
    base_url: String,
    /// HTTP client reused for all requests made by this instance.
    http_client: Client,
    /// Rate limiter awaited before each request; `None` means no throttling.
    rate_limiter: Option<std::sync::Arc<RateLimiter>>,
}
impl ElevenLabsClient {
    /// Creates a client that targets the default ElevenLabs base URL.
    ///
    /// The returned client has no rate limiter; chain [`Self::with_rate_limit`]
    /// to add one and [`Self::with_base_url`] to point it at another host.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            api_key: api_key.into(),
            base_url: ELEVENLABS_API_BASE.to_string(),
            http_client: Client::new(),
            rate_limiter: None,
        }
    }

    /// Replaces the base URL used for every request (builder style).
    ///
    /// Trailing `/` characters are stripped: endpoint paths are appended with a
    /// leading `/`, so keeping them would produce URLs such as `.../v1//voices`.
    pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
        let url: String = url.into();
        self.base_url = url.trim_end_matches('/').to_string();
        self
    }

    /// Installs a rate limiter built from `requests_per_minute` (builder style).
    pub fn with_rate_limit(mut self, requests_per_minute: u32) -> Self {
        self.rate_limiter = Some(std::sync::Arc::new(RateLimiter::new(requests_per_minute)));
        self
    }

    /// Awaits the configured rate limiter; returns immediately when none is set.
    async fn acquire_rate_limit(&self) {
        if let Some(limiter) = &self.rate_limiter {
            limiter.acquire().await;
        }
    }

    /// Passes a successful response through; turns any other status into an error.
    ///
    /// `api_name` is the label embedded in the error message (`"TTS"`, `"STT"`,
    /// `"voices"`). The error also carries the HTTP status and the response body.
    async fn ensure_success(
        response: reqwest::Response,
        api_name: &str,
    ) -> Result<reqwest::Response> {
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        // Best effort: an unreadable error body must not hide the status code,
        // so a failed read simply yields an empty body in the message.
        let body = response.text().await.unwrap_or_default();
        Err(anyhow::anyhow!(
            "ElevenLabs {} API error ({}): {}",
            api_name,
            status,
            body
        ))
    }

    /// Synthesizes speech for `req` with the voice `voice_id` and returns the raw
    /// audio bytes of the response body.
    ///
    /// `req` is sent as the JSON body. When `req.output_format` is set it is
    /// additionally sent as the `output_format` query parameter, which is where
    /// the ElevenLabs text-to-speech endpoint reads it from.
    ///
    /// # Errors
    ///
    /// Fails when the request URL cannot be parsed, the request cannot be sent,
    /// the API answers with a non-success status, or the body cannot be read.
    pub async fn text_to_speech(
        &self,
        voice_id: &str,
        req: &ElevenLabsTtsRequest,
    ) -> Result<Vec<u8>> {
        self.acquire_rate_limit().await;

        let endpoint = format!("{}/text-to-speech/{}", self.base_url, voice_id);
        let mut url = reqwest::Url::parse(&endpoint)
            .with_context(|| format!("Invalid ElevenLabs TTS URL: {}", endpoint))?;
        if let Some(output_format) = req.output_format.as_deref() {
            // `append_pair` percent-encodes the value, so it is safe to pass through.
            url.query_pairs_mut()
                .append_pair("output_format", output_format);
        }

        let response = self
            .http_client
            .post(url)
            .header("xi-api-key", &self.api_key)
            .header("Content-Type", "application/json")
            .json(req)
            .send()
            .await
            .context("Failed to send ElevenLabs TTS request")?;
        let response = Self::ensure_success(response, "TTS").await?;

        let bytes = response
            .bytes()
            .await
            .context("Failed to read ElevenLabs TTS response")?;
        Ok(bytes.to_vec())
    }

    /// Transcribes `audio_data` and returns the parsed transcription response.
    ///
    /// The audio is uploaded as a multipart form under the `file` field, the
    /// field name documented for the ElevenLabs speech-to-text endpoint.
    /// `req.model` and `req.language_code` are added as the `model_id` and
    /// `language_code` form fields when present.
    ///
    /// The upload is always labelled `audio.wav` / `audio/wav`, whatever the
    /// actual encoding of `audio_data` is.
    ///
    /// # Errors
    ///
    /// Fails when the multipart part cannot be built, the request cannot be
    /// sent, the API answers with a non-success status, or the JSON body cannot
    /// be parsed.
    pub async fn speech_to_text(
        &self,
        audio_data: Vec<u8>,
        req: &ElevenLabsSttRequest,
    ) -> Result<ElevenLabsSttResponse> {
        self.acquire_rate_limit().await;

        let url = format!("{}/speech-to-text", self.base_url);
        let file_part = reqwest::multipart::Part::bytes(audio_data)
            .file_name("audio.wav")
            .mime_str("audio/wav")
            .context("Failed to create multipart")?;

        let mut form = reqwest::multipart::Form::new().part("file", file_part);
        if let Some(model) = &req.model {
            form = form.text("model_id", model.clone());
        }
        if let Some(lang) = &req.language_code {
            form = form.text("language_code", lang.clone());
        }

        let response = self
            .http_client
            .post(&url)
            .header("xi-api-key", &self.api_key)
            .multipart(form)
            .send()
            .await
            .context("Failed to send ElevenLabs STT request")?;
        let response = Self::ensure_success(response, "STT").await?;

        response
            .json()
            .await
            .context("Failed to parse ElevenLabs STT response")
    }

    /// Fetches the voice list from the `/voices` endpoint.
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, the API answers with a
    /// non-success status, or the JSON body cannot be parsed.
    pub async fn list_voices(&self) -> Result<ElevenLabsVoicesResponse> {
        self.acquire_rate_limit().await;

        let url = format!("{}/voices", self.base_url);
        let response = self
            .http_client
            .get(&url)
            .header("xi-api-key", &self.api_key)
            .send()
            .await
            .context("Failed to list ElevenLabs voices")?;
        let response = Self::ensure_success(response, "voices").await?;

        response
            .json()
            .await
            .context("Failed to parse ElevenLabs voices response")
    }
}
/// JSON body of a text-to-speech request.
///
/// Every optional field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize)]
pub struct ElevenLabsTtsRequest {
    /// Text sent to the API under the `text` key.
    pub text: String,
    /// Optional model identifier, serialized as `model_id`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,
    /// Optional voice tuning parameters, serialized as `voice_settings`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub voice_settings: Option<ElevenLabsVoiceSettings>,
    /// Optional output format name, serialized as `output_format`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_format: Option<String>,
}
/// Voice tuning parameters that can be attached to an [`ElevenLabsTtsRequest`].
///
/// `style` and `use_speaker_boost` are omitted from the serialized JSON when
/// they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElevenLabsVoiceSettings {
    /// Value sent as `stability`.
    pub stability: f32,
    /// Value sent as `similarity_boost`.
    pub similarity_boost: f32,
    /// Optional value sent as `style`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub style: Option<f32>,
    /// Optional flag sent as `use_speaker_boost`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub use_speaker_boost: Option<bool>,
}
/// Optional parameters for a speech-to-text request.
///
/// `Default` yields a request with both fields unset.
#[derive(Debug, Clone, Default)]
pub struct ElevenLabsSttRequest {
    /// Model identifier; sent as the `model_id` form field when present.
    pub model: Option<String>,
    /// Language code; sent as the `language_code` form field when present.
    pub language_code: Option<String>,
}
/// Parsed body of a speech-to-text response.
#[derive(Debug, Clone, Deserialize)]
pub struct ElevenLabsSttResponse {
    /// Value of the `text` field of the response.
    pub text: String,
    /// Value of the `language_code` field; `None` when the field is absent.
    #[serde(default)]
    pub language_code: Option<String>,
}
/// Parsed body of the voice-listing response.
#[derive(Debug, Clone, Deserialize)]
pub struct ElevenLabsVoicesResponse {
    /// Entries of the response's `voices` array.
    pub voices: Vec<ElevenLabsVoice>,
}
/// A single entry of [`ElevenLabsVoicesResponse::voices`].
#[derive(Debug, Clone, Deserialize)]
pub struct ElevenLabsVoice {
    /// Value of the `voice_id` field.
    pub voice_id: String,
    /// Value of the `name` field.
    pub name: String,
    /// String-to-string labels; an empty map when the field is absent.
    #[serde(default)]
    pub labels: std::collections::HashMap<String, String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Set fields appear in the serialized TTS body; an unset `output_format` does not.
    #[test]
    fn test_tts_request_serialization() {
        let voice_settings = ElevenLabsVoiceSettings {
            stability: 0.5,
            similarity_boost: 0.75,
            style: None,
            use_speaker_boost: None,
        };
        let request = ElevenLabsTtsRequest {
            text: String::from("Hello world"),
            model_id: Some(String::from("eleven_multilingual_v2")),
            voice_settings: Some(voice_settings),
            output_format: None,
        };

        let serialized = serde_json::to_value(&request).unwrap();

        assert_eq!(serialized["text"], "Hello world");
        assert_eq!(serialized["model_id"], "eleven_multilingual_v2");
        assert!(serialized.get("output_format").is_none());
    }

    /// A freshly built client targets the default base URL.
    #[test]
    fn test_client_creation() {
        let fresh = ElevenLabsClient::new("test-key");
        assert_eq!(fresh.base_url, ELEVENLABS_API_BASE);
    }

    /// A voices payload with a single entry deserializes into one voice.
    #[test]
    fn test_voices_response_deserialization() {
        let payload = r#"{
"voices": [
{"voice_id": "abc123", "name": "Rachel", "labels": {"accent": "american"}}
]
}"#;

        let parsed: ElevenLabsVoicesResponse = serde_json::from_str(payload).unwrap();

        assert_eq!(parsed.voices.len(), 1);
        assert_eq!(parsed.voices[0].name, "Rachel");
    }
}