1use std::collections::{HashSet, VecDeque};
2use std::pin::Pin;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant, SystemTime};
6
7use async_stream::try_stream;
8use bytes::Bytes;
9use futures_util::{SinkExt, Stream, StreamExt};
10use reqwest::Client;
11use tokio::fs;
12use tokio::sync::{Mutex as AsyncMutex, OwnedMutexGuard};
13use tokio::time::timeout;
14use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
16use url::Url;
17
18use crate::constants::{TEXT_CHUNK_LIMIT, voice_list_url, websocket_url};
19use crate::error::{Error, Result};
20use crate::options::SpeakOptions;
21use crate::protocol::{
22 generate_connection_id, generate_muid, generate_sec_ms_gec, offset_from_audio_bytes,
23 parse_binary_headers, parse_headers, parse_metadata, sec_ms_gec_version, speech_config_message,
24 split_text, ssml_message, voice_headers, websocket_headers,
25};
26use crate::subtitles::{filter_boundaries, to_srt};
27use crate::types::{BoundaryEvent, SynthesisEvent, SynthesisResult, Voice};
28
/// Concrete websocket type returned by `connect_async` over a (possibly TLS) TCP stream.
type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;

/// Boxed, pinned stream of synthesis events (or errors) returned by `EdgeTtsClient::stream`.
pub type EventStream = Pin<Box<dyn Stream<Item = Result<SynthesisEvent>> + Send + Sync + 'static>>;
32
/// Client for the Edge TTS service: voice listing over HTTP and speech
/// synthesis over pooled websockets.
///
/// Clones share the same websocket pool because `ws_pool` is an `Arc`.
#[derive(Debug, Clone)]
pub struct EdgeTtsClient {
    /// HTTP client used by `list_voices`.
    http: Client,
    /// Timeout for establishing a websocket connection (also applied to the
    /// HTTP client's connect phase at build time).
    connect_timeout: Duration,
    /// Timeout for each websocket frame while reading a chunk (also the HTTP
    /// client's overall request timeout).
    receive_timeout: Duration,
    /// When `true`, every chunk of one `stream` call is sent over one socket.
    request_chunk_reuse: bool,
    /// Websocket pool shared by all clones of this client.
    ws_pool: Arc<WsPool>,
}
41
/// Builder for `EdgeTtsClient`; see `EdgeTtsClientBuilder::default` for the
/// starting values.
#[derive(Debug, Clone)]
pub struct EdgeTtsClientBuilder {
    /// Connection timeout for HTTP and websocket connections.
    connect_timeout: Duration,
    /// Per-frame websocket receive timeout; also the overall HTTP request timeout.
    receive_timeout: Duration,
    /// Number of idle websockets the pool aims to keep; `0` disables pooling.
    ws_pool_size: usize,
    /// How long an idle pooled websocket may go unused before it is evicted.
    ws_idle_ttl: Duration,
    /// Whether idle websockets are opened in the background ahead of demand.
    ws_warmup: bool,
    /// Whether all chunks of one request share a single websocket.
    request_chunk_reuse: bool,
}
51
/// Pool of websocket connections shared by clones of `EdgeTtsClient`.
#[derive(Debug)]
struct WsPool {
    /// Desired number of idle entries; `0` disables pooling entirely.
    target_idle: usize,
    /// Idle entries returned at least this long ago are evicted.
    idle_ttl: Duration,
    /// Whether background warm-up connections are opened.
    warmup: bool,
    /// Source of unique `PoolEntry::id` values.
    next_id: AtomicU64,
    /// Tracked entries plus warm-up bookkeeping, behind a `std::sync::Mutex`.
    state: Mutex<WsPoolState>,
}
60
/// Mutable pool contents guarded by `WsPool::state`.
#[derive(Debug, Default)]
struct WsPoolState {
    /// Every tracked connection, idle or busy.
    entries: Vec<Arc<PoolEntry>>,
    /// Number of warm-up connection attempts currently in flight.
    warming: usize,
}
66
/// One pooled websocket connection.
#[derive(Debug)]
struct PoolEntry {
    /// Unique id used to locate and remove this entry.
    id: u64,
    /// The connection; a `PooledWebsocket` holds this lock while it uses it.
    stream: Arc<AsyncMutex<WsStream>>,
    /// Whether the entry is idle (and since when) or checked out.
    state: Mutex<PoolEntryState>,
}
73
/// Checkout status of a `PoolEntry`.
#[derive(Debug, Clone, Copy)]
enum PoolEntryState {
    /// Available for reuse; `returned_at` drives TTL expiry and eviction order.
    Idle { returned_at: Instant },
    /// Checked out by a `PooledWebsocket`.
    Busy,
}
79
/// Handle to a checked-out websocket; its `Drop` impl returns the socket to the
/// pool as idle or removes it.
#[derive(Debug)]
struct PooledWebsocket {
    /// Pool entry backing this handle; taken in `Drop`.
    entry: Option<Arc<PoolEntry>>,
    /// Exclusive lock on the entry's stream; taken in `Drop`.
    stream: Option<OwnedMutexGuard<WsStream>>,
    /// Cleared by `mark_dirty`; a non-reusable socket is removed on drop.
    reusable: bool,
    /// Pool that owns the entry.
    pool: Arc<WsPool>,
}
87
/// Error raised while sending or reading one text chunk.
#[derive(Debug)]
struct ChunkFailure {
    err: Error,
    /// Set only for websocket send errors, where no response for the chunk has
    /// been read yet and retrying on another connection is considered safe.
    retryable_on_fresh_connection: bool,
}
93
/// Outcome of reading one websocket frame for a chunk.
#[derive(Debug)]
enum ChunkFrame {
    /// An audio or boundary event to yield to the caller.
    Event(SynthesisEvent),
    /// A frame with nothing to yield (e.g. `turn.start`, `response`, pings).
    Continue,
    /// `turn.end` was received; the chunk is complete.
    TurnEnd,
}
100
101impl Default for EdgeTtsClientBuilder {
102 fn default() -> Self {
103 Self {
104 connect_timeout: Duration::from_secs(10),
105 receive_timeout: Duration::from_secs(60),
106 ws_pool_size: 1,
107 ws_idle_ttl: Duration::from_secs(15),
108 ws_warmup: true,
109 request_chunk_reuse: true,
110 }
111 }
112}
113
114impl EdgeTtsClientBuilder {
115 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
116 self.connect_timeout = timeout;
117 self
118 }
119
120 pub fn receive_timeout(mut self, timeout: Duration) -> Self {
121 self.receive_timeout = timeout;
122 self
123 }
124
125 pub fn ws_pool_size(mut self, size: usize) -> Self {
126 self.ws_pool_size = size;
127 self
128 }
129
130 pub fn ws_idle_ttl(mut self, ttl: Duration) -> Self {
131 self.ws_idle_ttl = ttl;
132 self
133 }
134
135 pub fn ws_warmup(mut self, enabled: bool) -> Self {
136 self.ws_warmup = enabled;
137 self
138 }
139
140 pub fn request_chunk_reuse(mut self, enabled: bool) -> Self {
141 self.request_chunk_reuse = enabled;
142 self
143 }
144
145 pub fn build(self) -> Result<EdgeTtsClient> {
146 let http = Client::builder()
147 .connect_timeout(self.connect_timeout)
148 .timeout(self.receive_timeout)
149 .use_rustls_tls()
150 .build()?;
151 let client = EdgeTtsClient {
152 http,
153 connect_timeout: self.connect_timeout,
154 receive_timeout: self.receive_timeout,
155 request_chunk_reuse: self.request_chunk_reuse,
156 ws_pool: Arc::new(WsPool {
157 target_idle: self.ws_pool_size,
158 idle_ttl: self.ws_idle_ttl,
159 warmup: self.ws_warmup,
160 next_id: AtomicU64::new(1),
161 state: Mutex::new(WsPoolState::default()),
162 }),
163 };
164 client.ensure_warm_pool();
165 Ok(client)
166 }
167}
168
169impl EdgeTtsClient {
170 pub fn builder() -> EdgeTtsClientBuilder {
171 EdgeTtsClientBuilder::default()
172 }
173
174 pub fn new() -> Result<Self> {
175 Self::builder().build()
176 }
177
178 pub async fn list_voices(&self) -> Result<Vec<Voice>> {
179 let sec_ms_gec = generate_sec_ms_gec(SystemTime::now());
180 let muid = generate_muid();
181 let mut request = self
182 .http
183 .get(format!(
184 "{}&Sec-MS-GEC={sec_ms_gec}&Sec-MS-GEC-Version={}",
185 voice_list_url(),
186 sec_ms_gec_version()
187 ))
188 .header("Cookie", format!("muid={muid};"))
189 .header("Accept-Encoding", "gzip, deflate, br, zstd")
190 .header("Accept-Language", "en-US,en;q=0.9");
191
192 for (name, value) in voice_headers() {
193 request = request.header(name, value);
194 }
195
196 Ok(request.send().await?.error_for_status()?.json().await?)
197 }
198
199 pub async fn stream(
200 &self,
201 text: impl Into<String>,
202 options: SpeakOptions,
203 ) -> Result<EventStream> {
204 options.validate()?;
205 let text = text.into();
206 let chunks = split_text(&text, TEXT_CHUNK_LIMIT)?;
207 let client = self.clone();
208
209 Ok(Box::pin(try_stream! {
210 macro_rules! stream_chunk_with_fresh_socket {
211 (
212 $chunk:expr,
213 $cumulative_audio_bytes:ident,
214 $audio_received:ident,
215 $pending_error:ident,
216 $buffered_events:expr
217 ) => {{
218 let offset_compensation = offset_from_audio_bytes($cumulative_audio_bytes);
219 let mut socket = match client.acquire_websocket().await {
220 Ok(socket) => socket,
221 Err(err) => {
222 $pending_error = Some(err);
223 break;
224 }
225 };
226 match client.send_chunk_request(socket.stream_mut(), &options, $chunk).await {
227 Ok(()) => {
228 loop {
229 match client
230 .read_chunk_frame(
231 socket.stream_mut(),
232 offset_compensation,
233 $buffered_events,
234 )
235 .await
236 {
237 Ok(ChunkFrame::Event(event)) => {
238 if let SynthesisEvent::Audio(chunk) = &event {
239 $cumulative_audio_bytes += chunk.len();
240 $audio_received = true;
241 }
242 yield event;
243 }
244 Ok(ChunkFrame::Continue) => {}
245 Ok(ChunkFrame::TurnEnd) => break,
246 Err(failure) => {
247 pool_log(&format!(
248 "fresh socket read failure retryable={}",
249 failure.retryable_on_fresh_connection
250 ));
251 socket.mark_dirty();
252 $pending_error = Some(failure.err);
253 break;
254 }
255 }
256 }
257 }
258 Err(failure) => {
259 pool_log(&format!(
260 "fresh socket send failure retryable={}",
261 failure.retryable_on_fresh_connection
262 ));
263 socket.mark_dirty();
264 $pending_error = Some(failure.err);
265 break;
266 }
267 }
268
269 if $pending_error.is_some() {
270 break;
271 }
272 }};
273 }
274
275 let mut cumulative_audio_bytes = 0usize;
276 let mut audio_received = false;
277 let mut pending_error = None;
278 let mut buffered_events = VecDeque::new();
279
280 if client.request_chunk_reuse {
281 let mut shared_socket = match client.acquire_websocket().await {
282 Ok(socket) => Some(socket),
283 Err(err) => {
284 pending_error = Some(err);
285 None
286 }
287 };
288
289 if let Some(socket) = shared_socket.as_mut() {
290 let mut fallback_at = None;
291
292 for (index, chunk) in chunks.iter().enumerate() {
293 let offset_compensation = offset_from_audio_bytes(cumulative_audio_bytes);
294 match client.send_chunk_request(socket.stream_mut(), &options, chunk).await {
295 Ok(()) => {
296 loop {
297 match client
298 .read_chunk_frame(
299 socket.stream_mut(),
300 offset_compensation,
301 &mut buffered_events,
302 )
303 .await
304 {
305 Ok(ChunkFrame::Event(event)) => {
306 if let SynthesisEvent::Audio(chunk) = &event {
307 cumulative_audio_bytes += chunk.len();
308 audio_received = true;
309 }
310 yield event;
311 }
312 Ok(ChunkFrame::Continue) => {}
313 Ok(ChunkFrame::TurnEnd) => break,
314 Err(failure) => {
315 pool_log(&format!(
316 "reused socket read failure at chunk={index} retryable={}",
317 failure.retryable_on_fresh_connection
318 ));
319 socket.mark_dirty();
320 if index > 0 && failure.retryable_on_fresh_connection {
321 pool_log(&format!(
322 "fallback retryable frame failure at chunk={index}"
323 ));
324 fallback_at = Some(index);
325 } else {
326 pending_error = Some(failure.err);
327 }
328 break;
329 }
330 }
331 }
332 }
333 Err(failure) => {
334 pool_log(&format!(
335 "reused socket send failure at chunk={index} retryable={}",
336 failure.retryable_on_fresh_connection
337 ));
338 socket.mark_dirty();
339 if index > 0 && failure.retryable_on_fresh_connection {
340 pool_log(&format!(
341 "fallback retryable send failure at chunk={index}"
342 ));
343 fallback_at = Some(index);
344 break;
345 }
346 pending_error = Some(failure.err);
347 break;
348 }
349 }
350
351 if fallback_at.is_some() || pending_error.is_some() {
352 break;
353 }
354 }
355
356 drop(shared_socket);
357
358 if let Some(start_index) = fallback_at {
359 for chunk in chunks.iter().skip(start_index) {
360 stream_chunk_with_fresh_socket!(
361 chunk,
362 cumulative_audio_bytes,
363 audio_received,
364 pending_error,
365 &mut buffered_events
366 );
367 }
368 }
369 }
370 } else {
371 for chunk in &chunks {
372 stream_chunk_with_fresh_socket!(
373 chunk,
374 cumulative_audio_bytes,
375 audio_received,
376 pending_error,
377 &mut buffered_events
378 );
379 }
380 }
381
382 if let Some(err) = pending_error {
383 Err(err)?;
384 }
385
386 if !audio_received {
387 Err(Error::NoAudioReceived)?;
388 }
389 }))
390 }
391
392 pub async fn synthesize(
393 &self,
394 text: impl Into<String>,
395 options: SpeakOptions,
396 ) -> Result<SynthesisResult> {
397 let mut stream = self.stream(text, options).await?;
398 let mut audio = Vec::new();
399 let mut boundaries = Vec::new();
400
401 while let Some(event) = stream.next().await {
402 match event? {
403 SynthesisEvent::Audio(chunk) => audio.extend_from_slice(&chunk),
404 SynthesisEvent::Boundary(boundary) => boundaries.push(boundary),
405 }
406 }
407
408 Ok(SynthesisResult { audio, boundaries })
409 }
410
411 pub async fn save(
412 &self,
413 text: impl Into<String>,
414 options: SpeakOptions,
415 audio_path: impl AsRef<std::path::Path>,
416 srt_path: Option<impl AsRef<std::path::Path>>,
417 ) -> Result<SynthesisResult> {
418 let result = self.synthesize(text, options.clone()).await?;
419 fs::write(audio_path, &result.audio).await?;
420 if let Some(path) = srt_path {
421 let filtered = filter_boundaries(&result.boundaries, options.boundary);
422 fs::write(path, to_srt(&filtered)).await?;
423 }
424 Ok(result)
425 }
426
427 async fn send_chunk_request(
428 &self,
429 websocket: &mut WsStream,
430 options: &SpeakOptions,
431 chunk: &str,
432 ) -> std::result::Result<(), ChunkFailure> {
433 let config_message = speech_config_message(options.boundary);
434 let ssml_message = ssml_message(options, chunk).map_err(|err| ChunkFailure {
435 err,
436 retryable_on_fresh_connection: false,
437 })?;
438
439 debug_frame("send-config", config_message.as_bytes());
440 websocket
441 .send(tokio_tungstenite::tungstenite::Message::Text(
442 config_message.into(),
443 ))
444 .await
445 .map_err(|err| ChunkFailure {
446 err: err.into(),
447 retryable_on_fresh_connection: true,
448 })?;
449 debug_frame("send-ssml", ssml_message.as_bytes());
450 websocket
451 .send(tokio_tungstenite::tungstenite::Message::Text(
452 ssml_message.into(),
453 ))
454 .await
455 .map_err(|err| ChunkFailure {
456 err: err.into(),
457 retryable_on_fresh_connection: true,
458 })?;
459 Ok(())
460 }
461
462 async fn read_chunk_frame(
463 &self,
464 websocket: &mut WsStream,
465 offset_compensation: u64,
466 buffered_events: &mut VecDeque<SynthesisEvent>,
467 ) -> std::result::Result<ChunkFrame, ChunkFailure> {
468 if let Some(event) = buffered_events.pop_front() {
469 return Ok(ChunkFrame::Event(event));
470 }
471
472 let next = timeout(self.receive_timeout, websocket.next())
473 .await
474 .map_err(|_| ChunkFailure {
475 err: Error::UnexpectedResponse("websocket receive timeout"),
476 retryable_on_fresh_connection: false,
477 })?;
478 let Some(message) = next else {
479 return Err(ChunkFailure {
480 err: Error::UnexpectedResponse("websocket closed before turn end"),
481 retryable_on_fresh_connection: false,
482 });
483 };
484
485 match message {
486 Ok(tokio_tungstenite::tungstenite::Message::Text(text_frame)) => {
487 let data = text_frame.as_bytes();
488 debug_frame("text", data);
489 let header_end = data
490 .windows(4)
491 .position(|window| window == b"\r\n\r\n")
492 .ok_or(ChunkFailure {
493 err: Error::MissingHeaders,
494 retryable_on_fresh_connection: false,
495 })?;
496 let (headers, payload) =
497 parse_headers(data, header_end).map_err(|err| ChunkFailure {
498 err,
499 retryable_on_fresh_connection: false,
500 })?;
501 match headers.get("Path").map(String::as_str) {
502 Some("audio.metadata") => {
503 let events =
504 parse_metadata(payload, offset_compensation).map_err(|err| {
505 ChunkFailure {
506 err,
507 retryable_on_fresh_connection: false,
508 }
509 })?;
510 if events.is_empty() {
511 Ok(ChunkFrame::Continue)
512 } else {
513 buffered_events.extend(events);
514 Ok(ChunkFrame::Event(
515 buffered_events
516 .pop_front()
517 .expect("metadata buffer populated"),
518 ))
519 }
520 }
521 Some("turn.end") => Ok(ChunkFrame::TurnEnd),
522 Some("response") | Some("turn.start") => Ok(ChunkFrame::Continue),
523 Some(other) => Err(ChunkFailure {
524 err: Error::UnknownPath(other.to_owned()),
525 retryable_on_fresh_connection: false,
526 }),
527 None => Err(ChunkFailure {
528 err: Error::MissingHeaders,
529 retryable_on_fresh_connection: false,
530 }),
531 }
532 }
533 Ok(tokio_tungstenite::tungstenite::Message::Binary(frame)) => {
534 debug_frame("binary", &frame);
535 if frame.len() < 2 {
536 return Err(ChunkFailure {
537 err: Error::UnexpectedResponse("binary frame too short"),
538 retryable_on_fresh_connection: false,
539 });
540 }
541 let header_length = u16::from_be_bytes([frame[0], frame[1]]) as usize;
542 let (headers, payload) =
543 parse_binary_headers(&frame, header_length).map_err(|err| ChunkFailure {
544 err,
545 retryable_on_fresh_connection: false,
546 })?;
547 if headers.get("Path").map(String::as_str) != Some("audio") {
548 return Err(ChunkFailure {
549 err: Error::UnexpectedResponse("binary frame path was not audio"),
550 retryable_on_fresh_connection: false,
551 });
552 }
553 match headers.get("Content-Type").map(String::as_str) {
554 Some("audio/mpeg") => {
555 if payload.is_empty() {
556 return Err(ChunkFailure {
557 err: Error::UnexpectedResponse("audio frame missing payload"),
558 retryable_on_fresh_connection: false,
559 });
560 }
561 Ok(ChunkFrame::Event(SynthesisEvent::Audio(
562 Bytes::copy_from_slice(payload),
563 )))
564 }
565 None if payload.is_empty() => Ok(ChunkFrame::Continue),
566 None => Err(ChunkFailure {
567 err: Error::UnexpectedResponse(
568 "binary frame had payload without content type",
569 ),
570 retryable_on_fresh_connection: false,
571 }),
572 Some(_) => Err(ChunkFailure {
573 err: Error::UnexpectedResponse("unexpected binary content type"),
574 retryable_on_fresh_connection: false,
575 }),
576 }
577 }
578 Ok(tokio_tungstenite::tungstenite::Message::Close(frame)) => {
579 if std::env::var_os("EDGE_TTS_DEBUG").is_some() {
580 eprintln!("[edge-tts-debug] close: {frame:?}");
581 }
582 Err(ChunkFailure {
583 err: Error::UnexpectedResponse("websocket closed before turn end"),
584 retryable_on_fresh_connection: false,
585 })
586 }
587 Ok(
588 tokio_tungstenite::tungstenite::Message::Ping(_)
589 | tokio_tungstenite::tungstenite::Message::Pong(_)
590 | tokio_tungstenite::tungstenite::Message::Frame(_),
591 ) => Ok(ChunkFrame::Continue),
592 Err(err) => Err(ChunkFailure {
593 err: err.into(),
594 retryable_on_fresh_connection: false,
595 }),
596 }
597 }
598
599 async fn acquire_websocket(&self) -> Result<PooledWebsocket> {
600 if let Some(entry) = self.take_idle_websocket() {
601 pool_log("ws_pool hit");
602 self.ensure_warm_pool();
603 return Ok(PooledWebsocket {
604 stream: Some(entry.stream.clone().lock_owned().await),
605 entry: Some(entry),
606 reusable: true,
607 pool: Arc::clone(&self.ws_pool),
608 });
609 }
610
611 pool_log("ws_pool miss");
612 let stream = self.connect_websocket_fresh().await?;
613 let entry = self.ws_pool.insert_busy(stream);
614 self.ensure_warm_pool();
615 Ok(PooledWebsocket {
616 stream: Some(entry.stream.clone().lock_owned().await),
617 entry: Some(entry),
618 reusable: true,
619 pool: Arc::clone(&self.ws_pool),
620 })
621 }
622
623 fn take_idle_websocket(&self) -> Option<Arc<PoolEntry>> {
624 if self.ws_pool.target_idle == 0 {
625 pool_log("ws_pool disabled");
626 return None;
627 }
628
629 let mut state = self.ws_pool.state.lock().expect("websocket pool poisoned");
630 self.ws_pool.cleanup_expired_locked(&mut state, Instant::now());
631 for entry in &state.entries {
632 let mut entry_state = entry.state.lock().expect("pool entry poisoned");
633 if matches!(*entry_state, PoolEntryState::Idle { .. }) {
634 *entry_state = PoolEntryState::Busy;
635 pool_log("ws_pool took idle socket candidate");
636 return Some(Arc::clone(entry));
637 }
638 }
639 pool_log("ws_pool empty");
640 None
641 }
642
643 fn ensure_warm_pool(&self) {
644 if !self.ws_pool.warmup || self.ws_pool.target_idle == 0 {
645 return;
646 }
647
648 let to_spawn = {
649 let mut state = self.ws_pool.state.lock().expect("websocket pool poisoned");
650 self.ws_pool.cleanup_expired_locked(&mut state, Instant::now());
651 let missing = self.ws_pool.replenishment_needed(
652 self.ws_pool.idle_count_locked(&state),
653 state.warming,
654 );
655 state.warming += missing;
656 missing
657 };
658
659 for _ in 0..to_spawn {
660 let client = self.clone();
661 tokio::spawn(async move {
662 let stream = client.connect_websocket_fresh().await.ok();
663 {
664 let mut state = client
665 .ws_pool
666 .state
667 .lock()
668 .expect("websocket pool poisoned");
669 state.warming = state.warming.saturating_sub(1);
670 if let Some(stream) = stream {
671 if client.ws_pool.idle_count_locked(&state) < client.ws_pool.target_idle {
672 pool_log("ws_pool warmup added idle socket");
673 state.entries.push(client.ws_pool.new_idle_entry(stream));
674 }
675 }
676 }
677 client.ensure_warm_pool();
678 });
679 }
680 }
681
682 async fn connect_websocket_fresh(&self) -> Result<WsStream> {
683 let sec_ms_gec = generate_sec_ms_gec(SystemTime::now());
684 let muid = generate_muid();
685 let url = Url::parse(&format!(
686 "{}&ConnectionId={}&Sec-MS-GEC={sec_ms_gec}&Sec-MS-GEC-Version={}",
687 websocket_url(),
688 generate_connection_id(),
689 sec_ms_gec_version(),
690 ))
691 .map_err(|_| Error::UnexpectedResponse("invalid websocket url"))?;
692
693 let mut request = url.as_str().into_client_request()?;
694 for (name, value) in websocket_headers(&muid) {
695 request.headers_mut().insert(
696 http::header::HeaderName::from_bytes(name.as_bytes())
697 .map_err(|_| Error::UnexpectedResponse("invalid header name"))?,
698 http::HeaderValue::from_str(&value)
699 .map_err(|_| Error::UnexpectedResponse("invalid header value"))?,
700 );
701 }
702
703 let connect = timeout(self.connect_timeout, connect_async(request))
704 .await
705 .map_err(|_| Error::UnexpectedResponse("websocket connect timeout"))?;
706 let (stream, _) = connect?;
707 Ok(stream)
708 }
709}
710
711impl PooledWebsocket {
712 fn stream_mut(&mut self) -> &mut WsStream {
713 self.stream
714 .as_mut()
715 .expect("pooled websocket missing stream")
716 }
717
718 fn mark_dirty(&mut self) {
719 self.reusable = false;
720 }
721}
722
723impl Drop for PooledWebsocket {
724 fn drop(&mut self) {
725 let Some(_stream) = self.stream.take() else {
726 return;
727 };
728 let Some(entry) = self.entry.take() else {
729 return;
730 };
731 if !self.reusable || self.pool.target_idle == 0 {
732 self.pool.remove_entry(entry.id);
733 return;
734 }
735
736 let returned_at = Instant::now();
737 {
738 let mut entry_state = entry.state.lock().expect("pool entry poisoned");
739 *entry_state = PoolEntryState::Idle { returned_at };
740 }
741
742 let mut state = self.pool.state.lock().expect("websocket pool poisoned");
743 self.pool.cleanup_expired_locked(&mut state, returned_at);
744 let replaced = self.pool.trim_idle_locked(&mut state, entry.id);
745 if replaced {
746 pool_log("ws_pool replace oldest idle socket with recently used socket");
747 } else {
748 pool_log("ws_pool return socket to idle");
749 }
750 }
751}
752
753impl WsPool {
754 fn new_idle_entry(&self, stream: WsStream) -> Arc<PoolEntry> {
755 Arc::new(PoolEntry {
756 id: self.next_id.fetch_add(1, Ordering::Relaxed),
757 stream: Arc::new(AsyncMutex::new(stream)),
758 state: Mutex::new(PoolEntryState::Idle {
759 returned_at: Instant::now(),
760 }),
761 })
762 }
763
764 fn insert_busy(&self, stream: WsStream) -> Arc<PoolEntry> {
765 let entry = Arc::new(PoolEntry {
766 id: self.next_id.fetch_add(1, Ordering::Relaxed),
767 stream: Arc::new(AsyncMutex::new(stream)),
768 state: Mutex::new(PoolEntryState::Busy),
769 });
770 let mut state = self.state.lock().expect("websocket pool poisoned");
771 state.entries.push(Arc::clone(&entry));
772 entry
773 }
774
775 fn remove_entry(&self, entry_id: u64) {
776 let mut state = self.state.lock().expect("websocket pool poisoned");
777 state.entries.retain(|entry| entry.id != entry_id);
778 }
779
780 fn is_expired(&self, returned_at: Instant, now: Instant) -> bool {
781 now.saturating_duration_since(returned_at) >= self.idle_ttl
782 }
783
784 fn idle_count_locked(&self, state: &WsPoolState) -> usize {
785 state
786 .entries
787 .iter()
788 .filter(|entry| {
789 matches!(
790 *entry.state.lock().expect("pool entry poisoned"),
791 PoolEntryState::Idle { .. }
792 )
793 })
794 .count()
795 }
796
797 fn cleanup_expired_locked(&self, state: &mut WsPoolState, now: Instant) {
798 state.entries.retain(|entry| {
799 let entry_state = entry.state.lock().expect("pool entry poisoned");
800 match *entry_state {
801 PoolEntryState::Idle { returned_at } if self.is_expired(returned_at, now) => {
802 pool_log("ws_pool drop expired idle socket");
803 false
804 }
805 _ => true,
806 }
807 });
808 }
809
810 fn trim_idle_locked(&self, state: &mut WsPoolState, keep_entry_id: u64) -> bool {
811 let mut idle_entries = state
812 .entries
813 .iter()
814 .filter_map(|entry| {
815 let entry_state = entry.state.lock().expect("pool entry poisoned");
816 match *entry_state {
817 PoolEntryState::Idle { returned_at } => Some((entry.id, returned_at)),
818 PoolEntryState::Busy => None,
819 }
820 })
821 .collect::<Vec<_>>();
822
823 if idle_entries.len() <= self.target_idle {
824 return false;
825 }
826
827 idle_entries.sort_by_key(|(_, returned_at)| *returned_at);
828 let mut removed = false;
829 let overflow = idle_entries.len().saturating_sub(self.target_idle);
830 let mut to_remove = HashSet::with_capacity(overflow);
831 for (entry_id, _) in idle_entries {
832 if to_remove.len() == overflow {
833 break;
834 }
835 if entry_id == keep_entry_id {
836 continue;
837 }
838 to_remove.insert(entry_id);
839 }
840
841 if !to_remove.is_empty() {
842 state.entries.retain(|entry| {
843 let should_keep = !to_remove.contains(&entry.id);
844 if !should_keep {
845 removed = true;
846 }
847 should_keep
848 });
849 }
850
851 removed
852 }
853
854 fn replenishment_needed(&self, idle_len: usize, warming: usize) -> usize {
855 self.target_idle
856 .saturating_sub(idle_len.saturating_add(warming))
857 }
858}
859
860pub fn subtitles(events: &[BoundaryEvent]) -> String {
861 to_srt(events)
862}
863
/// Dumps a websocket frame to stderr (decoded lossily as UTF-8) when the
/// `EDGE_TTS_DEBUG` environment variable is set; otherwise does nothing.
fn debug_frame(kind: &str, payload: &[u8]) {
    if std::env::var_os("EDGE_TTS_DEBUG").is_none() {
        return;
    }
    let rendered = String::from_utf8_lossy(payload);
    eprintln!("[edge-tts-debug] {kind}: {rendered}");
}
872
/// Writes a connection-pool diagnostic line to stderr.
///
/// Output is opt-in through the `EDGE_TTS_DEBUG` environment variable, the same
/// switch `debug_frame` uses. It is deliberately not tied to `debug_assertions`.
/// That flag is on for dependencies in the default `dev` profile, so it would
/// make every downstream debug build print pool chatter.
fn pool_log(message: &str) {
    if std::env::var_os("EDGE_TTS_DEBUG").is_some() {
        eprintln!("[edge-tts-debug] pool: {message}");
    }
}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty pool with warm-up on and a 15 s idle TTL.
    fn pool_with_target(target_idle: usize) -> WsPool {
        WsPool {
            target_idle,
            idle_ttl: Duration::from_secs(15),
            warmup: true,
            next_id: AtomicU64::new(1),
            state: Mutex::new(WsPoolState::default()),
        }
    }

    #[test]
    fn builder_defaults_enable_pooling_and_chunk_reuse() {
        let EdgeTtsClientBuilder {
            ws_pool_size,
            ws_idle_ttl,
            ws_warmup,
            request_chunk_reuse,
            ..
        } = EdgeTtsClientBuilder::default();
        assert_eq!(ws_pool_size, 1);
        assert_eq!(ws_idle_ttl, Duration::from_secs(15));
        assert!(ws_warmup);
        assert!(request_chunk_reuse);
    }

    #[test]
    fn pool_replenishment_respects_idle_and_warming_counts() {
        let pool = pool_with_target(2);
        let cases: [(usize, usize, usize); 4] = [(0, 0, 2), (1, 0, 1), (1, 1, 0), (2, 0, 0)];
        for (idle, warming, expected) in cases {
            assert_eq!(
                pool.replenishment_needed(idle, warming),
                expected,
                "idle={idle} warming={warming}"
            );
        }
    }

    #[test]
    fn idle_connection_ttl_only_applies_after_expiration() {
        let pool = pool_with_target(1);
        let now = Instant::now();
        let just_before = now - Duration::from_secs(14);
        let at_ttl = now - Duration::from_secs(15);

        assert!(!pool.is_expired(just_before, now));
        assert!(pool.is_expired(at_ttl, now));
    }
}