use anyhow::{Context, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::rate_limiter::RateLimiter;
/// REST client for the Azure Speech service in a single region.
///
/// Covers text-to-speech, short-audio speech-to-text and voice listing.
///
/// NOTE(review): `Debug` is not derived here; if it ever is, make sure
/// `subscription_key` is not printed.
pub struct AzureSpeechClient {
    /// Key sent in the `Ocp-Apim-Subscription-Key` header of every request.
    subscription_key: String,
    /// Region name (e.g. `eastus`) interpolated into every endpoint host.
    region: String,
    /// HTTP client used to send all requests.
    http_client: Client,
    /// Optional client-side limiter awaited before each API call.
    rate_limiter: Option<std::sync::Arc<RateLimiter>>,
}
impl AzureSpeechClient {
pub fn new(subscription_key: impl Into<String>, region: impl Into<String>) -> Self {
Self {
subscription_key: subscription_key.into(),
region: region.into(),
http_client: Client::new(),
rate_limiter: None,
}
}
pub fn with_rate_limit(mut self, requests_per_minute: u32) -> Self {
self.rate_limiter = Some(std::sync::Arc::new(RateLimiter::new(requests_per_minute)));
self
}
async fn acquire_rate_limit(&self) {
if let Some(ref limiter) = self.rate_limiter {
limiter.acquire().await;
}
}
fn tts_endpoint(&self) -> String {
format!(
"https://{}.tts.speech.microsoft.com/cognitiveservices/v1",
self.region
)
}
fn stt_endpoint(&self) -> String {
format!(
"https://{}.stt.speech.microsoft.com/speech/recognition/conversation/cognitiveservices/v1",
self.region
)
}
fn voices_endpoint(&self) -> String {
format!(
"https://{}.tts.speech.microsoft.com/cognitiveservices/voices/list",
self.region
)
}
pub async fn synthesize(&self, ssml: &str, output_format: &str) -> Result<Vec<u8>> {
self.acquire_rate_limit().await;
let response = self
.http_client
.post(self.tts_endpoint())
.header("Ocp-Apim-Subscription-Key", &self.subscription_key)
.header("Content-Type", "application/ssml+xml")
.header("X-Microsoft-OutputFormat", output_format)
.body(ssml.to_string())
.send()
.await
.context("Failed to send Azure TTS request")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Azure TTS API error ({}): {}", status, body);
}
let bytes = response
.bytes()
.await
.context("Failed to read Azure TTS response")?;
Ok(bytes.to_vec())
}
pub async fn synthesize_text(
&self,
text: &str,
voice_name: &str,
output_format: &str,
) -> Result<Vec<u8>> {
let ssml = format!(
r#"<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="en-US">
<voice name="{voice_name}">{text}</voice>
</speak>"#,
voice_name = voice_name,
text = text,
);
self.synthesize(&ssml, output_format).await
}
pub async fn recognize(
&self,
audio_data: Vec<u8>,
req: &AzureSttRequest,
) -> Result<AzureSttResponse> {
self.acquire_rate_limit().await;
let mut url = self.stt_endpoint();
let lang = req.language.as_deref().unwrap_or("en-US");
url = format!("{}?language={}", url, lang);
let content_type = req
.content_type
.as_deref()
.unwrap_or("audio/wav; codecs=audio/pcm; samplerate=16000");
let response = self
.http_client
.post(&url)
.header("Ocp-Apim-Subscription-Key", &self.subscription_key)
.header("Content-Type", content_type)
.body(audio_data)
.send()
.await
.context("Failed to send Azure STT request")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Azure STT API error ({}): {}", status, body);
}
response
.json()
.await
.context("Failed to parse Azure STT response")
}
pub async fn list_voices(&self) -> Result<Vec<AzureVoice>> {
self.acquire_rate_limit().await;
let response = self
.http_client
.get(self.voices_endpoint())
.header("Ocp-Apim-Subscription-Key", &self.subscription_key)
.send()
.await
.context("Failed to list Azure voices")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Azure voices API error ({}): {}", status, body);
}
response
.json()
.await
.context("Failed to parse Azure voices response")
}
}
/// Optional parameters for [`AzureSpeechClient::recognize`].
///
/// A field left as `None` makes `recognize` use its built-in fallback:
/// `"en-US"` for the language, and
/// `"audio/wav; codecs=audio/pcm; samplerate=16000"` for the content type.
#[derive(Default, Clone, Debug)]
pub struct AzureSttRequest {
    /// Recognition language, sent as the `language` query parameter.
    pub language: Option<String>,
    /// `Content-Type` header value describing the audio payload.
    pub content_type: Option<String>,
}
/// JSON body returned by the short-audio speech-to-text endpoint.
///
/// JSON keys are PascalCase (`RecognitionStatus`, `DisplayText`, `Offset`,
/// `Duration`), mapped onto snake_case fields by `rename_all`.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AzureSttResponse {
    /// Outcome reported by the service, e.g. `"Success"`.
    pub recognition_status: String,
    /// Recognized text; `None` when the key is absent.
    pub display_text: Option<String>,
    /// Start offset of the recognized speech, if reported.
    /// NOTE(review): Azure documents this in 100-ns ticks; confirm before converting.
    pub offset: Option<u64>,
    /// Length of the recognized speech, if reported (same units as `offset`).
    pub duration: Option<u64>,
}
/// One entry of the voice-list response.
///
/// JSON keys are PascalCase (`Name`, `DisplayName`, `ShortName`, `Gender`,
/// `Locale`) in both directions, via `rename_all`. Keys not listed here are
/// ignored during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub struct AzureVoice {
    /// Full voice name as reported by the service.
    pub name: String,
    /// Human-readable display name.
    pub display_name: String,
    /// Short voice identifier, presumably what `synthesize_text` expects as
    /// `voice_name` (verify against the service docs).
    pub short_name: String,
    /// Voice gender as reported by the service.
    pub gender: String,
    /// Locale tag of the voice, e.g. `en-US`.
    pub locale: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The region passed to `new` appears in the TTS endpoint.
    #[test]
    fn test_client_creation() {
        let client = AzureSpeechClient::new("test-key", "eastus");
        assert!(client.tts_endpoint().contains("eastus"));
    }

    /// All three endpoints are built exactly from the region.
    #[test]
    fn test_endpoints_use_region() {
        let client = AzureSpeechClient::new("test-key", "westeurope");
        assert_eq!(
            client.tts_endpoint(),
            "https://westeurope.tts.speech.microsoft.com/cognitiveservices/v1"
        );
        assert_eq!(
            client.stt_endpoint(),
            "https://westeurope.stt.speech.microsoft.com/speech/recognition/conversation/cognitiveservices/v1"
        );
        assert_eq!(
            client.voices_endpoint(),
            "https://westeurope.tts.speech.microsoft.com/cognitiveservices/voices/list"
        );
    }

    /// A full success payload maps every PascalCase key onto its field.
    #[test]
    fn test_stt_response_deserialization() {
        let json = r#"{
            "RecognitionStatus": "Success",
            "DisplayText": "Hello world.",
            "Offset": 0,
            "Duration": 10000000
        }"#;
        let resp: AzureSttResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.recognition_status, "Success");
        assert_eq!(resp.display_text, Some("Hello world.".to_string()));
        assert_eq!(resp.offset, Some(0));
        assert_eq!(resp.duration, Some(10_000_000));
    }

    /// Non-success payloads may omit the optional keys; they become `None`.
    #[test]
    fn test_stt_response_without_optional_fields() {
        let json = r#"{ "RecognitionStatus": "InitialSilenceTimeout" }"#;
        let resp: AzureSttResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.recognition_status, "InitialSilenceTimeout");
        assert_eq!(resp.display_text, None);
        assert_eq!(resp.offset, None);
        assert_eq!(resp.duration, None);
    }

    /// Voice-list entries parse even when the service sends extra keys.
    #[test]
    fn test_voice_deserialization_ignores_extra_fields() {
        let json = r#"[{
            "Name": "Microsoft Server Speech Text to Speech Voice (en-US, JennyNeural)",
            "DisplayName": "Jenny",
            "LocalName": "Jenny",
            "ShortName": "en-US-JennyNeural",
            "Gender": "Female",
            "Locale": "en-US",
            "VoiceType": "Neural"
        }]"#;
        let voices: Vec<AzureVoice> = serde_json::from_str(json).unwrap();
        assert_eq!(voices.len(), 1);
        assert_eq!(voices[0].short_name, "en-US-JennyNeural");
        assert_eq!(voices[0].display_name, "Jenny");
        assert_eq!(voices[0].gender, "Female");
        assert_eq!(voices[0].locale, "en-US");
    }
}