use std::sync::Arc;
use crate::chat::{ChatMessage, ChatProvider, ChatResponse, Tool};
use crate::completion::{CompletionProvider, CompletionRequest, CompletionResponse};
use crate::embedding::EmbeddingProvider;
#[cfg(feature = "elevenlabs")]
use crate::error::LLMError;
use crate::models::ModelsProvider;
use crate::stt::SpeechToTextProvider;
use crate::tts::TextToSpeechProvider;
use crate::LLMProvider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Connection and authentication settings for the ElevenLabs API.
#[derive(Debug)]
pub struct ElevenLabsConfig {
    /// Secret key sent as the `xi-api-key` header on every request.
    pub api_key: String,
    /// Model identifier submitted with requests (sent as the `model_id` form field for STT,
    /// and in the JSON body for TTS).
    pub model_id: String,
    /// API root URL; endpoint paths such as `/speech-to-text` are appended to it.
    pub base_url: String,
    /// Optional per-request timeout, in seconds. `None` means the client default.
    pub timeout_seconds: Option<u64>,
    /// Optional voice ID for text-to-speech; a hard-coded default voice is used when `None`.
    pub voice: Option<String>,
}
/// ElevenLabs provider handle.
///
/// Bundles the shared configuration (behind an `Arc`, so cloning the handle
/// does not copy the config) with the HTTP client used for all API calls.
#[derive(Debug, Clone)]
pub struct ElevenLabs {
    /// Shared, immutable configuration.
    pub config: Arc<ElevenLabsConfig>,
    /// HTTP client used for every request.
    pub client: Client,
}
/// One word segment as it appears in the speech-to-text response payload.
#[derive(Debug, Deserialize)]
struct ElevenLabsWord {
    /// The word (or fragment) text.
    text: String,
    /// Segment start time; defaults to 0.0 when absent from the payload.
    /// Presumably seconds — TODO confirm against the API docs.
    #[serde(default)]
    start: f32,
    /// Segment end time; defaults to 0.0 when absent from the payload.
    #[serde(default)]
    end: f32,
}
/// Public word-level transcription segment with timing information.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Word {
    /// The word text.
    pub text: String,
    /// Start time of the word (presumably seconds — TODO confirm).
    pub start: f32,
    /// End time of the word.
    pub end: f32,
}
/// Wire format of the ElevenLabs speech-to-text response (only the fields we consume).
///
/// Note: the former `#[serde(skip_serializing_if = "Option::is_none")]` attributes
/// were removed — they are no-ops on a `Deserialize`-only struct, and one was
/// attached to the non-`Option` `text` field, which would fail to compile if
/// `Serialize` were ever derived.
#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ElevenLabsResponse {
    /// Detected language code, when the API provides one.
    language_code: Option<String>,
    /// Confidence of the language detection, when provided.
    language_probability: Option<f32>,
    /// Full transcript text (required field).
    text: String,
    /// Per-word segments with timing, when present in the payload.
    words: Option<Vec<ElevenLabsWord>>,
}
impl ElevenLabs {
    /// Builds a provider with a freshly constructed HTTP client.
    pub fn new(
        api_key: String,
        model_id: String,
        base_url: String,
        timeout_seconds: Option<u64>,
        voice: Option<String>,
    ) -> Self {
        let client = Client::new();
        Self::with_client(client, api_key, model_id, base_url, timeout_seconds, voice)
    }

    /// Builds a provider around an existing `reqwest::Client`, allowing
    /// connection-pool reuse across providers.
    pub fn with_client(
        client: Client,
        api_key: String,
        model_id: String,
        base_url: String,
        timeout_seconds: Option<u64>,
        voice: Option<String>,
    ) -> Self {
        let config = ElevenLabsConfig {
            api_key,
            model_id,
            base_url,
            timeout_seconds,
            voice,
        };
        Self {
            config: Arc::new(config),
            client,
        }
    }

    /// API key used for authentication.
    pub fn api_key(&self) -> &str {
        self.config.api_key.as_str()
    }

    /// Configured model identifier.
    pub fn model_id(&self) -> &str {
        self.config.model_id.as_str()
    }

    /// API root URL.
    pub fn base_url(&self) -> &str {
        self.config.base_url.as_str()
    }

    /// Optional per-request timeout, in seconds.
    pub fn timeout_seconds(&self) -> Option<u64> {
        self.config.timeout_seconds
    }

    /// Optional voice ID for text-to-speech.
    pub fn voice(&self) -> Option<&str> {
        self.config.voice.as_deref()
    }

    /// Underlying HTTP client.
    pub fn client(&self) -> &Client {
        &self.client
    }
}
impl ElevenLabs {
    /// Shared request/parse path for both transcription entry points.
    ///
    /// Posts `form` to `{base_url}/speech-to-text` with the API key header,
    /// applying the configured timeout, and extracts the transcript from the
    /// JSON response.
    ///
    /// # Errors
    /// `LLMError` on transport failure, a non-success HTTP status, or a
    /// response body that does not match the expected schema.
    async fn send_transcription(
        &self,
        form: reqwest::multipart::Form,
    ) -> Result<String, LLMError> {
        let url = format!("{}/speech-to-text", self.config.base_url);
        let mut req = self
            .client
            .post(url)
            .header("xi-api-key", &self.config.api_key)
            .multipart(form);
        if let Some(secs) = self.config.timeout_seconds {
            req = req.timeout(Duration::from_secs(secs));
        }
        let resp = req.send().await?.error_for_status()?;
        let body = resp.text().await?;
        let parsed: ElevenLabsResponse =
            serde_json::from_str(&body).map_err(|e| LLMError::ResponseFormatError {
                message: e.to_string(),
                // Clone the raw body only on the error path (the original
                // cloned it unconditionally before parsing).
                raw_response: body.clone(),
            })?;
        // Prefer the canonical full transcript. The previous implementation
        // ignored `text`, concatenated word fragments with no separators, and
        // returned "" whenever `words` was absent.
        if parsed.text.is_empty() {
            // Legacy fallback: rebuild from the word list when `text` is empty.
            let joined: String = parsed
                .words
                .unwrap_or_default()
                .into_iter()
                .map(|w| w.text)
                .collect();
            Ok(joined)
        } else {
            Ok(parsed.text)
        }
    }
}

#[async_trait]
impl SpeechToTextProvider for ElevenLabs {
    /// Transcribes in-memory audio bytes via the ElevenLabs speech-to-text endpoint.
    ///
    /// # Errors
    /// `LLMError` on transport failure, a non-success HTTP status, or a
    /// malformed response body.
    async fn transcribe(&self, audio: Vec<u8>) -> Result<String, LLMError> {
        // The API expects a multipart form with the audio under `file`.
        let part = reqwest::multipart::Part::bytes(audio).file_name("audio.wav");
        let form = reqwest::multipart::Form::new()
            .text("model_id", self.config.model_id.clone())
            .part("file", part);
        self.send_transcription(form).await
    }

    /// Transcribes an audio file on disk via the ElevenLabs speech-to-text endpoint.
    ///
    /// # Errors
    /// `LLMError` if the file cannot be read, on transport failure, a
    /// non-success HTTP status, or a malformed response body.
    async fn transcribe_file(&self, file_path: &str) -> Result<String, LLMError> {
        let form = reqwest::multipart::Form::new()
            .text("model_id", self.config.model_id.clone())
            .file("file", file_path)
            .await
            .map_err(|e| LLMError::HttpError(e.to_string()))?;
        self.send_transcription(form).await
    }
}
#[async_trait]
impl CompletionProvider for ElevenLabs {
    /// Text completion is not offered by ElevenLabs; returns a fixed notice
    /// rather than an error.
    async fn complete(&self, _req: &CompletionRequest) -> Result<CompletionResponse, LLMError> {
        let text = String::from("ElevenLabs completion not implemented.");
        Ok(CompletionResponse { text })
    }
}
#[async_trait]
impl EmbeddingProvider for ElevenLabs {
    /// Embeddings are not available from ElevenLabs; always returns a provider error.
    async fn embed(&self, _text: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
        let msg = String::from("Embedding not supported");
        Err(LLMError::ProviderError(msg))
    }
}
#[async_trait]
impl ChatProvider for ElevenLabs {
    /// Chat is not available from ElevenLabs; always returns a provider error.
    async fn chat(&self, _messages: &[ChatMessage]) -> Result<Box<dyn ChatResponse>, LLMError> {
        let msg = String::from("Chat not supported");
        Err(LLMError::ProviderError(msg))
    }

    /// Tool-augmented chat is likewise unsupported.
    async fn chat_with_tools(
        &self,
        _messages: &[ChatMessage],
        _tools: Option<&[Tool]>,
    ) -> Result<Box<dyn ChatResponse>, LLMError> {
        let msg = String::from("Chat with tools not supported");
        Err(LLMError::ProviderError(msg))
    }
}
// Relies entirely on `ModelsProvider`'s default method implementations;
// nothing is overridden for ElevenLabs.
#[async_trait]
impl ModelsProvider for ElevenLabs {}
impl LLMProvider for ElevenLabs {
    /// ElevenLabs exposes no tool definitions.
    fn tools(&self) -> Option<&[Tool]> {
        None
    }
}
/// Voice ID used when the configuration does not name one.
/// NOTE(review): presumably a stock ElevenLabs voice — confirm it is still valid.
const DEFAULT_VOICE_ID: &str = "JBFqnCBsd6RMkjVDRZzb";

#[async_trait]
impl TextToSpeechProvider for ElevenLabs {
    /// Synthesizes `text` into MP3 audio (44.1 kHz / 128 kbps, per the
    /// `output_format` query parameter) via the text-to-speech endpoint.
    ///
    /// Uses the configured voice when present, otherwise [`DEFAULT_VOICE_ID`].
    ///
    /// # Errors
    /// `LLMError` on transport failure or a non-success HTTP status.
    async fn speech(&self, text: &str) -> Result<Vec<u8>, LLMError> {
        // Borrow the configured voice instead of cloning it; the fallback
        // string is only used (and never allocated) when `voice` is None.
        let voice = self.config.voice.as_deref().unwrap_or(DEFAULT_VOICE_ID);
        let url = format!(
            "{}/text-to-speech/{}?output_format=mp3_44100_128",
            self.config.base_url, voice
        );
        let body = serde_json::json!({
            "text": text,
            "model_id": self.config.model_id
        });
        let mut req = self
            .client
            .post(url)
            .header("xi-api-key", &self.config.api_key)
            .header("Content-Type", "application/json")
            .json(&body);
        if let Some(secs) = self.config.timeout_seconds {
            req = req.timeout(Duration::from_secs(secs));
        }
        let resp = req.send().await?.error_for_status()?;
        let audio = resp.bytes().await?;
        Ok(audio.to_vec())
    }
}