Skip to main content

edge_tts_rust/
client.rs

1use std::collections::{HashSet, VecDeque};
2use std::pin::Pin;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::{Arc, Mutex};
5use std::time::{Duration, Instant, SystemTime};
6
7use async_stream::try_stream;
8use bytes::Bytes;
9use futures_util::{SinkExt, Stream, StreamExt};
10use reqwest::Client;
11use tokio::fs;
12use tokio::sync::{Mutex as AsyncMutex, OwnedMutexGuard};
13use tokio::time::timeout;
14use tokio_tungstenite::tungstenite::client::IntoClientRequest;
15use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
16use url::Url;
17
18use crate::constants::{TEXT_CHUNK_LIMIT, voice_list_url, websocket_url};
19use crate::error::{Error, Result};
20use crate::options::SpeakOptions;
21use crate::protocol::{
22    generate_connection_id, generate_muid, generate_sec_ms_gec, offset_from_audio_bytes,
23    parse_binary_headers, parse_headers, parse_metadata, sec_ms_gec_version, speech_config_message,
24    split_text, ssml_message, voice_headers, websocket_headers,
25};
26use crate::subtitles::{filter_boundaries, to_srt};
27use crate::types::{BoundaryEvent, SynthesisEvent, SynthesisResult, Voice};
28
29type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
30
31pub type EventStream = Pin<Box<dyn Stream<Item = Result<SynthesisEvent>> + Send + Sync + 'static>>;
32
33#[derive(Debug, Clone)]
34pub struct EdgeTtsClient {
35    http: Client,
36    connect_timeout: Duration,
37    receive_timeout: Duration,
38    request_chunk_reuse: bool,
39    ws_pool: Arc<WsPool>,
40}
41
/// Builder for [`EdgeTtsClient`]; obtain one via `EdgeTtsClient::builder()`.
#[derive(Clone, Debug)]
pub struct EdgeTtsClientBuilder {
    /// Connection timeout for both the HTTP client and websocket connects.
    connect_timeout: Duration,
    /// Overall HTTP request timeout and per-message websocket receive timeout.
    receive_timeout: Duration,
    /// Target number of idle pooled websockets; `0` disables pooling.
    ws_pool_size: usize,
    /// How long a returned websocket may stay idle before it is discarded.
    ws_idle_ttl: Duration,
    /// Whether idle websockets are opened proactively in the background.
    ws_warmup: bool,
    /// Whether the chunks of a single request share one websocket.
    request_chunk_reuse: bool,
}
51
52#[derive(Debug)]
53struct WsPool {
54    target_idle: usize,
55    idle_ttl: Duration,
56    warmup: bool,
57    next_id: AtomicU64,
58    state: Mutex<WsPoolState>,
59}
60
61#[derive(Debug, Default)]
62struct WsPoolState {
63    entries: Vec<Arc<PoolEntry>>,
64    warming: usize,
65}
66
67#[derive(Debug)]
68struct PoolEntry {
69    id: u64,
70    stream: Arc<AsyncMutex<WsStream>>,
71    state: Mutex<PoolEntryState>,
72}
73
/// Lease status of a pooled connection.
#[derive(Clone, Copy, Debug)]
enum PoolEntryState {
    /// Available for leasing; `returned_at` drives idle-TTL expiry.
    Idle { returned_at: Instant },
    /// Currently leased (or reserved) by a caller.
    Busy,
}
79
80#[derive(Debug)]
81struct PooledWebsocket {
82    entry: Option<Arc<PoolEntry>>,
83    stream: Option<OwnedMutexGuard<WsStream>>,
84    reusable: bool,
85    pool: Arc<WsPool>,
86}
87
88#[derive(Debug)]
89struct ChunkFailure {
90    err: Error,
91    retryable_on_fresh_connection: bool,
92}
93
94#[derive(Debug)]
95enum ChunkFrame {
96    Event(SynthesisEvent),
97    Continue,
98    TurnEnd,
99}
100
101impl Default for EdgeTtsClientBuilder {
102    fn default() -> Self {
103        Self {
104            connect_timeout: Duration::from_secs(10),
105            receive_timeout: Duration::from_secs(60),
106            ws_pool_size: 1,
107            ws_idle_ttl: Duration::from_secs(15),
108            ws_warmup: true,
109            request_chunk_reuse: true,
110        }
111    }
112}
113
114impl EdgeTtsClientBuilder {
115    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
116        self.connect_timeout = timeout;
117        self
118    }
119
120    pub fn receive_timeout(mut self, timeout: Duration) -> Self {
121        self.receive_timeout = timeout;
122        self
123    }
124
125    pub fn ws_pool_size(mut self, size: usize) -> Self {
126        self.ws_pool_size = size;
127        self
128    }
129
130    pub fn ws_idle_ttl(mut self, ttl: Duration) -> Self {
131        self.ws_idle_ttl = ttl;
132        self
133    }
134
135    pub fn ws_warmup(mut self, enabled: bool) -> Self {
136        self.ws_warmup = enabled;
137        self
138    }
139
140    pub fn request_chunk_reuse(mut self, enabled: bool) -> Self {
141        self.request_chunk_reuse = enabled;
142        self
143    }
144
145    pub fn build(self) -> Result<EdgeTtsClient> {
146        let http = Client::builder()
147            .connect_timeout(self.connect_timeout)
148            .timeout(self.receive_timeout)
149            .use_rustls_tls()
150            .build()?;
151        let client = EdgeTtsClient {
152            http,
153            connect_timeout: self.connect_timeout,
154            receive_timeout: self.receive_timeout,
155            request_chunk_reuse: self.request_chunk_reuse,
156            ws_pool: Arc::new(WsPool {
157                target_idle: self.ws_pool_size,
158                idle_ttl: self.ws_idle_ttl,
159                warmup: self.ws_warmup,
160                next_id: AtomicU64::new(1),
161                state: Mutex::new(WsPoolState::default()),
162            }),
163        };
164        client.ensure_warm_pool();
165        Ok(client)
166    }
167}
168
169impl EdgeTtsClient {
170    pub fn builder() -> EdgeTtsClientBuilder {
171        EdgeTtsClientBuilder::default()
172    }
173
174    pub fn new() -> Result<Self> {
175        Self::builder().build()
176    }
177
178    pub async fn list_voices(&self) -> Result<Vec<Voice>> {
179        let sec_ms_gec = generate_sec_ms_gec(SystemTime::now());
180        let muid = generate_muid();
181        let mut request = self
182            .http
183            .get(format!(
184                "{}&Sec-MS-GEC={sec_ms_gec}&Sec-MS-GEC-Version={}",
185                voice_list_url(),
186                sec_ms_gec_version()
187            ))
188            .header("Cookie", format!("muid={muid};"))
189            .header("Accept-Encoding", "gzip, deflate, br, zstd")
190            .header("Accept-Language", "en-US,en;q=0.9");
191
192        for (name, value) in voice_headers() {
193            request = request.header(name, value);
194        }
195
196        Ok(request.send().await?.error_for_status()?.json().await?)
197    }
198
    /// Synthesizes `text` with `options` and returns a stream of audio and
    /// boundary events.
    ///
    /// `options` is validated and `text` is split into chunks of at most
    /// `TEXT_CHUNK_LIMIT` before this returns, so those errors surface
    /// immediately. All websocket work happens lazily while the stream is polled.
    ///
    /// With `request_chunk_reuse` enabled, every chunk is sent over one leased
    /// websocket. If a chunk other than the first fails with
    /// `retryable_on_fresh_connection` set, that chunk and all later ones are
    /// re-streamed with a separate `acquire_websocket` lease per chunk. With it
    /// disabled, every chunk takes its own lease from the start.
    ///
    /// Boundary offsets of later chunks are compensated for the audio bytes
    /// already received (`offset_from_audio_bytes`). The stream yields an error
    /// and ends on the first failure that is not handled by the fallback. It
    /// yields `Error::NoAudioReceived` if no audio frame arrived at all.
    ///
    /// NOTE(review): a lease is returned to the idle pool on drop unless it was
    /// marked dirty, so a "fresh socket" here may be a pooled connection that
    /// has already served a turn. Confirm this is intended, especially for the
    /// fallback path and for `request_chunk_reuse == false`.
    ///
    /// # Errors
    /// Returns an error immediately if `options` fails validation or `text`
    /// cannot be split into chunks.
    pub async fn stream(
        &self,
        text: impl Into<String>,
        options: SpeakOptions,
    ) -> Result<EventStream> {
        options.validate()?;
        let text = text.into();
        let chunks = split_text(&text, TEXT_CHUNK_LIMIT)?;
        let client = self.clone();

        Ok(Box::pin(try_stream! {
            // Streams one chunk over its own lease from `acquire_websocket`.
            // On any failure it records `$pending_error` and `break`s out of the
            // caller's enclosing loop. Mutable state is passed in as idents so the
            // macro body can refer to locals declared after the macro itself.
            macro_rules! stream_chunk_with_fresh_socket {
                (
                    $chunk:expr,
                    $cumulative_audio_bytes:ident,
                    $audio_received:ident,
                    $pending_error:ident,
                    $buffered_events:expr
                ) => {{
                    // Boundary offsets for this chunk start after all audio so far.
                    let offset_compensation = offset_from_audio_bytes($cumulative_audio_bytes);
                    let mut socket = match client.acquire_websocket().await {
                        Ok(socket) => socket,
                        Err(err) => {
                            $pending_error = Some(err);
                            break;
                        }
                    };
                    match client.send_chunk_request(socket.stream_mut(), &options, $chunk).await {
                        Ok(()) => {
                            // Read frames until the server ends this chunk's turn.
                            loop {
                                match client
                                    .read_chunk_frame(
                                        socket.stream_mut(),
                                        offset_compensation,
                                        $buffered_events,
                                    )
                                    .await
                                {
                                    Ok(ChunkFrame::Event(event)) => {
                                        if let SynthesisEvent::Audio(chunk) = &event {
                                            $cumulative_audio_bytes += chunk.len();
                                            $audio_received = true;
                                        }
                                        yield event;
                                    }
                                    Ok(ChunkFrame::Continue) => {}
                                    Ok(ChunkFrame::TurnEnd) => break,
                                    Err(failure) => {
                                        pool_log(&format!(
                                            "fresh socket read failure retryable={}",
                                            failure.retryable_on_fresh_connection
                                        ));
                                        // A dirty lease is removed from the pool on drop.
                                        socket.mark_dirty();
                                        $pending_error = Some(failure.err);
                                        break;
                                    }
                                }
                            }
                        }
                        Err(failure) => {
                            pool_log(&format!(
                                "fresh socket send failure retryable={}",
                                failure.retryable_on_fresh_connection
                            ));
                            socket.mark_dirty();
                            $pending_error = Some(failure.err);
                            break;
                        }
                    }

                    // A read failure only left the inner loop; stop the chunk loop too.
                    if $pending_error.is_some() {
                        break;
                    }
                }};
            }

            // Total audio bytes yielded so far; offsets boundaries of later chunks.
            let mut cumulative_audio_bytes = 0usize;
            let mut audio_received = false;
            // First unrecovered failure; surfaced as the stream's last item.
            let mut pending_error = None;
            // Leftover metadata events; read_chunk_frame drains these before reading the socket.
            let mut buffered_events = VecDeque::new();

            if client.request_chunk_reuse {
                // One shared lease for every chunk of this request.
                let mut shared_socket = match client.acquire_websocket().await {
                    Ok(socket) => Some(socket),
                    Err(err) => {
                        pending_error = Some(err);
                        None
                    }
                };

                if let Some(socket) = shared_socket.as_mut() {
                    // Index of the first chunk to re-stream on separate leases, if any.
                    let mut fallback_at = None;

                    for (index, chunk) in chunks.iter().enumerate() {
                        let offset_compensation = offset_from_audio_bytes(cumulative_audio_bytes);
                        match client.send_chunk_request(socket.stream_mut(), &options, chunk).await {
                            Ok(()) => {
                                loop {
                                    match client
                                        .read_chunk_frame(
                                            socket.stream_mut(),
                                            offset_compensation,
                                            &mut buffered_events,
                                        )
                                        .await
                                    {
                                        Ok(ChunkFrame::Event(event)) => {
                                            if let SynthesisEvent::Audio(chunk) = &event {
                                                cumulative_audio_bytes += chunk.len();
                                                audio_received = true;
                                            }
                                            yield event;
                                        }
                                        Ok(ChunkFrame::Continue) => {}
                                        Ok(ChunkFrame::TurnEnd) => break,
                                        Err(failure) => {
                                            pool_log(&format!(
                                                "reused socket read failure at chunk={index} retryable={}",
                                                failure.retryable_on_fresh_connection
                                            ));
                                            socket.mark_dirty();
                                            // NOTE(review): if a mid-read failure were ever marked
                                            // retryable, events already yielded for this chunk would
                                            // be emitted again by the fallback. read_chunk_frame
                                            // currently never marks its failures retryable.
                                            if index > 0 && failure.retryable_on_fresh_connection {
                                                pool_log(&format!(
                                                    "fallback retryable frame failure at chunk={index}"
                                                ));
                                                fallback_at = Some(index);
                                            } else {
                                                pending_error = Some(failure.err);
                                            }
                                            break;
                                        }
                                    }
                                }
                            }
                            Err(failure) => {
                                pool_log(&format!(
                                    "reused socket send failure at chunk={index} retryable={}",
                                    failure.retryable_on_fresh_connection
                                ));
                                socket.mark_dirty();
                                // Only chunks after the first fall back. NOTE(review): a
                                // retryable send failure on chunk 0 (e.g. a stale pooled
                                // socket) is reported instead of retried; confirm intended.
                                if index > 0 && failure.retryable_on_fresh_connection {
                                    pool_log(&format!(
                                        "fallback retryable send failure at chunk={index}"
                                    ));
                                    fallback_at = Some(index);
                                    break;
                                }
                                pending_error = Some(failure.err);
                                break;
                            }
                        }

                        if fallback_at.is_some() || pending_error.is_some() {
                            break;
                        }
                    }

                    // Release the shared lease (discarded if marked dirty) before falling back.
                    drop(shared_socket);

                    if let Some(start_index) = fallback_at {
                        // Re-stream the failed chunk and everything after it.
                        for chunk in chunks.iter().skip(start_index) {
                            stream_chunk_with_fresh_socket!(
                                chunk,
                                cumulative_audio_bytes,
                                audio_received,
                                pending_error,
                                &mut buffered_events
                            );
                        }
                    }
                }
            } else {
                for chunk in &chunks {
                    stream_chunk_with_fresh_socket!(
                        chunk,
                        cumulative_audio_bytes,
                        audio_received,
                        pending_error,
                        &mut buffered_events
                    );
                }
            }

            if let Some(err) = pending_error {
                Err(err)?;
            }

            // A run that yielded no audio at all is reported as an error.
            if !audio_received {
                Err(Error::NoAudioReceived)?;
            }
        }))
    }
391
392    pub async fn synthesize(
393        &self,
394        text: impl Into<String>,
395        options: SpeakOptions,
396    ) -> Result<SynthesisResult> {
397        let mut stream = self.stream(text, options).await?;
398        let mut audio = Vec::new();
399        let mut boundaries = Vec::new();
400
401        while let Some(event) = stream.next().await {
402            match event? {
403                SynthesisEvent::Audio(chunk) => audio.extend_from_slice(&chunk),
404                SynthesisEvent::Boundary(boundary) => boundaries.push(boundary),
405            }
406        }
407
408        Ok(SynthesisResult { audio, boundaries })
409    }
410
411    pub async fn save(
412        &self,
413        text: impl Into<String>,
414        options: SpeakOptions,
415        audio_path: impl AsRef<std::path::Path>,
416        srt_path: Option<impl AsRef<std::path::Path>>,
417    ) -> Result<SynthesisResult> {
418        let result = self.synthesize(text, options.clone()).await?;
419        fs::write(audio_path, &result.audio).await?;
420        if let Some(path) = srt_path {
421            let filtered = filter_boundaries(&result.boundaries, options.boundary);
422            fs::write(path, to_srt(&filtered)).await?;
423        }
424        Ok(result)
425    }
426
427    async fn send_chunk_request(
428        &self,
429        websocket: &mut WsStream,
430        options: &SpeakOptions,
431        chunk: &str,
432    ) -> std::result::Result<(), ChunkFailure> {
433        let config_message = speech_config_message(options.boundary);
434        let ssml_message = ssml_message(options, chunk).map_err(|err| ChunkFailure {
435            err,
436            retryable_on_fresh_connection: false,
437        })?;
438
439        debug_frame("send-config", config_message.as_bytes());
440        websocket
441            .send(tokio_tungstenite::tungstenite::Message::Text(
442                config_message.into(),
443            ))
444            .await
445            .map_err(|err| ChunkFailure {
446                err: err.into(),
447                retryable_on_fresh_connection: true,
448            })?;
449        debug_frame("send-ssml", ssml_message.as_bytes());
450        websocket
451            .send(tokio_tungstenite::tungstenite::Message::Text(
452                ssml_message.into(),
453            ))
454            .await
455            .map_err(|err| ChunkFailure {
456                err: err.into(),
457                retryable_on_fresh_connection: true,
458            })?;
459        Ok(())
460    }
461
462    async fn read_chunk_frame(
463        &self,
464        websocket: &mut WsStream,
465        offset_compensation: u64,
466        buffered_events: &mut VecDeque<SynthesisEvent>,
467    ) -> std::result::Result<ChunkFrame, ChunkFailure> {
468        if let Some(event) = buffered_events.pop_front() {
469            return Ok(ChunkFrame::Event(event));
470        }
471
472        let next = timeout(self.receive_timeout, websocket.next())
473            .await
474            .map_err(|_| ChunkFailure {
475                err: Error::UnexpectedResponse("websocket receive timeout"),
476                retryable_on_fresh_connection: false,
477            })?;
478        let Some(message) = next else {
479            return Err(ChunkFailure {
480                err: Error::UnexpectedResponse("websocket closed before turn end"),
481                retryable_on_fresh_connection: false,
482            });
483        };
484
485        match message {
486            Ok(tokio_tungstenite::tungstenite::Message::Text(text_frame)) => {
487                let data = text_frame.as_bytes();
488                debug_frame("text", data);
489                let header_end = data
490                    .windows(4)
491                    .position(|window| window == b"\r\n\r\n")
492                    .ok_or(ChunkFailure {
493                        err: Error::MissingHeaders,
494                        retryable_on_fresh_connection: false,
495                    })?;
496                let (headers, payload) =
497                    parse_headers(data, header_end).map_err(|err| ChunkFailure {
498                        err,
499                        retryable_on_fresh_connection: false,
500                    })?;
501                match headers.get("Path").map(String::as_str) {
502                    Some("audio.metadata") => {
503                        let events =
504                            parse_metadata(payload, offset_compensation).map_err(|err| {
505                                ChunkFailure {
506                                    err,
507                                    retryable_on_fresh_connection: false,
508                                }
509                            })?;
510                        if events.is_empty() {
511                            Ok(ChunkFrame::Continue)
512                        } else {
513                            buffered_events.extend(events);
514                            Ok(ChunkFrame::Event(
515                                buffered_events
516                                    .pop_front()
517                                    .expect("metadata buffer populated"),
518                            ))
519                        }
520                    }
521                    Some("turn.end") => Ok(ChunkFrame::TurnEnd),
522                    Some("response") | Some("turn.start") => Ok(ChunkFrame::Continue),
523                    Some(other) => Err(ChunkFailure {
524                        err: Error::UnknownPath(other.to_owned()),
525                        retryable_on_fresh_connection: false,
526                    }),
527                    None => Err(ChunkFailure {
528                        err: Error::MissingHeaders,
529                        retryable_on_fresh_connection: false,
530                    }),
531                }
532            }
533            Ok(tokio_tungstenite::tungstenite::Message::Binary(frame)) => {
534                debug_frame("binary", &frame);
535                if frame.len() < 2 {
536                    return Err(ChunkFailure {
537                        err: Error::UnexpectedResponse("binary frame too short"),
538                        retryable_on_fresh_connection: false,
539                    });
540                }
541                let header_length = u16::from_be_bytes([frame[0], frame[1]]) as usize;
542                let (headers, payload) =
543                    parse_binary_headers(&frame, header_length).map_err(|err| ChunkFailure {
544                        err,
545                        retryable_on_fresh_connection: false,
546                    })?;
547                if headers.get("Path").map(String::as_str) != Some("audio") {
548                    return Err(ChunkFailure {
549                        err: Error::UnexpectedResponse("binary frame path was not audio"),
550                        retryable_on_fresh_connection: false,
551                    });
552                }
553                match headers.get("Content-Type").map(String::as_str) {
554                    Some("audio/mpeg") => {
555                        if payload.is_empty() {
556                            return Err(ChunkFailure {
557                                err: Error::UnexpectedResponse("audio frame missing payload"),
558                                retryable_on_fresh_connection: false,
559                            });
560                        }
561                        Ok(ChunkFrame::Event(SynthesisEvent::Audio(
562                            Bytes::copy_from_slice(payload),
563                        )))
564                    }
565                    None if payload.is_empty() => Ok(ChunkFrame::Continue),
566                    None => Err(ChunkFailure {
567                        err: Error::UnexpectedResponse(
568                            "binary frame had payload without content type",
569                        ),
570                        retryable_on_fresh_connection: false,
571                    }),
572                    Some(_) => Err(ChunkFailure {
573                        err: Error::UnexpectedResponse("unexpected binary content type"),
574                        retryable_on_fresh_connection: false,
575                    }),
576                }
577            }
578            Ok(tokio_tungstenite::tungstenite::Message::Close(frame)) => {
579                if std::env::var_os("EDGE_TTS_DEBUG").is_some() {
580                    eprintln!("[edge-tts-debug] close: {frame:?}");
581                }
582                Err(ChunkFailure {
583                    err: Error::UnexpectedResponse("websocket closed before turn end"),
584                    retryable_on_fresh_connection: false,
585                })
586            }
587            Ok(
588                tokio_tungstenite::tungstenite::Message::Ping(_)
589                | tokio_tungstenite::tungstenite::Message::Pong(_)
590                | tokio_tungstenite::tungstenite::Message::Frame(_),
591            ) => Ok(ChunkFrame::Continue),
592            Err(err) => Err(ChunkFailure {
593                err: err.into(),
594                retryable_on_fresh_connection: false,
595            }),
596        }
597    }
598
599    async fn acquire_websocket(&self) -> Result<PooledWebsocket> {
600        if let Some(entry) = self.take_idle_websocket() {
601            pool_log("ws_pool hit");
602            self.ensure_warm_pool();
603            return Ok(PooledWebsocket {
604                stream: Some(entry.stream.clone().lock_owned().await),
605                entry: Some(entry),
606                reusable: true,
607                pool: Arc::clone(&self.ws_pool),
608            });
609        }
610
611        pool_log("ws_pool miss");
612        let stream = self.connect_websocket_fresh().await?;
613        let entry = self.ws_pool.insert_busy(stream);
614        self.ensure_warm_pool();
615        Ok(PooledWebsocket {
616            stream: Some(entry.stream.clone().lock_owned().await),
617            entry: Some(entry),
618            reusable: true,
619            pool: Arc::clone(&self.ws_pool),
620        })
621    }
622
623    fn take_idle_websocket(&self) -> Option<Arc<PoolEntry>> {
624        if self.ws_pool.target_idle == 0 {
625            pool_log("ws_pool disabled");
626            return None;
627        }
628
629        let mut state = self.ws_pool.state.lock().expect("websocket pool poisoned");
630        self.ws_pool.cleanup_expired_locked(&mut state, Instant::now());
631        for entry in &state.entries {
632            let mut entry_state = entry.state.lock().expect("pool entry poisoned");
633            if matches!(*entry_state, PoolEntryState::Idle { .. }) {
634                *entry_state = PoolEntryState::Busy;
635                pool_log("ws_pool took idle socket candidate");
636                return Some(Arc::clone(entry));
637            }
638        }
639        pool_log("ws_pool empty");
640        None
641    }
642
643    fn ensure_warm_pool(&self) {
644        if !self.ws_pool.warmup || self.ws_pool.target_idle == 0 {
645            return;
646        }
647
648        let to_spawn = {
649            let mut state = self.ws_pool.state.lock().expect("websocket pool poisoned");
650            self.ws_pool.cleanup_expired_locked(&mut state, Instant::now());
651            let missing = self.ws_pool.replenishment_needed(
652                self.ws_pool.idle_count_locked(&state),
653                state.warming,
654            );
655            state.warming += missing;
656            missing
657        };
658
659        for _ in 0..to_spawn {
660            let client = self.clone();
661            tokio::spawn(async move {
662                let stream = client.connect_websocket_fresh().await.ok();
663                {
664                    let mut state = client
665                        .ws_pool
666                        .state
667                        .lock()
668                        .expect("websocket pool poisoned");
669                    state.warming = state.warming.saturating_sub(1);
670                    if let Some(stream) = stream {
671                        if client.ws_pool.idle_count_locked(&state) < client.ws_pool.target_idle {
672                            pool_log("ws_pool warmup added idle socket");
673                            state.entries.push(client.ws_pool.new_idle_entry(stream));
674                        }
675                    }
676                }
677                client.ensure_warm_pool();
678            });
679        }
680    }
681
682    async fn connect_websocket_fresh(&self) -> Result<WsStream> {
683        let sec_ms_gec = generate_sec_ms_gec(SystemTime::now());
684        let muid = generate_muid();
685        let url = Url::parse(&format!(
686            "{}&ConnectionId={}&Sec-MS-GEC={sec_ms_gec}&Sec-MS-GEC-Version={}",
687            websocket_url(),
688            generate_connection_id(),
689            sec_ms_gec_version(),
690        ))
691        .map_err(|_| Error::UnexpectedResponse("invalid websocket url"))?;
692
693        let mut request = url.as_str().into_client_request()?;
694        for (name, value) in websocket_headers(&muid) {
695            request.headers_mut().insert(
696                http::header::HeaderName::from_bytes(name.as_bytes())
697                    .map_err(|_| Error::UnexpectedResponse("invalid header name"))?,
698                http::HeaderValue::from_str(&value)
699                    .map_err(|_| Error::UnexpectedResponse("invalid header value"))?,
700            );
701        }
702
703        let connect = timeout(self.connect_timeout, connect_async(request))
704            .await
705            .map_err(|_| Error::UnexpectedResponse("websocket connect timeout"))?;
706        let (stream, _) = connect?;
707        Ok(stream)
708    }
709}
710
711impl PooledWebsocket {
712    fn stream_mut(&mut self) -> &mut WsStream {
713        self.stream
714            .as_mut()
715            .expect("pooled websocket missing stream")
716    }
717
718    fn mark_dirty(&mut self) {
719        self.reusable = false;
720    }
721}
722
723impl Drop for PooledWebsocket {
724    fn drop(&mut self) {
725        let Some(_stream) = self.stream.take() else {
726            return;
727        };
728        let Some(entry) = self.entry.take() else {
729            return;
730        };
731        if !self.reusable || self.pool.target_idle == 0 {
732            self.pool.remove_entry(entry.id);
733            return;
734        }
735
736        let returned_at = Instant::now();
737        {
738            let mut entry_state = entry.state.lock().expect("pool entry poisoned");
739            *entry_state = PoolEntryState::Idle { returned_at };
740        }
741
742        let mut state = self.pool.state.lock().expect("websocket pool poisoned");
743        self.pool.cleanup_expired_locked(&mut state, returned_at);
744        let replaced = self.pool.trim_idle_locked(&mut state, entry.id);
745        if replaced {
746            pool_log("ws_pool replace oldest idle socket with recently used socket");
747        } else {
748            pool_log("ws_pool return socket to idle");
749        }
750    }
751}
752
753impl WsPool {
754    fn new_idle_entry(&self, stream: WsStream) -> Arc<PoolEntry> {
755        Arc::new(PoolEntry {
756            id: self.next_id.fetch_add(1, Ordering::Relaxed),
757            stream: Arc::new(AsyncMutex::new(stream)),
758            state: Mutex::new(PoolEntryState::Idle {
759                returned_at: Instant::now(),
760            }),
761        })
762    }
763
764    fn insert_busy(&self, stream: WsStream) -> Arc<PoolEntry> {
765        let entry = Arc::new(PoolEntry {
766            id: self.next_id.fetch_add(1, Ordering::Relaxed),
767            stream: Arc::new(AsyncMutex::new(stream)),
768            state: Mutex::new(PoolEntryState::Busy),
769        });
770        let mut state = self.state.lock().expect("websocket pool poisoned");
771        state.entries.push(Arc::clone(&entry));
772        entry
773    }
774
775    fn remove_entry(&self, entry_id: u64) {
776        let mut state = self.state.lock().expect("websocket pool poisoned");
777        state.entries.retain(|entry| entry.id != entry_id);
778    }
779
780    fn is_expired(&self, returned_at: Instant, now: Instant) -> bool {
781        now.saturating_duration_since(returned_at) >= self.idle_ttl
782    }
783
784    fn idle_count_locked(&self, state: &WsPoolState) -> usize {
785        state
786            .entries
787            .iter()
788            .filter(|entry| {
789                matches!(
790                    *entry.state.lock().expect("pool entry poisoned"),
791                    PoolEntryState::Idle { .. }
792                )
793            })
794            .count()
795    }
796
797    fn cleanup_expired_locked(&self, state: &mut WsPoolState, now: Instant) {
798        state.entries.retain(|entry| {
799            let entry_state = entry.state.lock().expect("pool entry poisoned");
800            match *entry_state {
801                PoolEntryState::Idle { returned_at } if self.is_expired(returned_at, now) => {
802                    pool_log("ws_pool drop expired idle socket");
803                    false
804                }
805                _ => true,
806            }
807        });
808    }
809
810    fn trim_idle_locked(&self, state: &mut WsPoolState, keep_entry_id: u64) -> bool {
811        let mut idle_entries = state
812            .entries
813            .iter()
814            .filter_map(|entry| {
815                let entry_state = entry.state.lock().expect("pool entry poisoned");
816                match *entry_state {
817                    PoolEntryState::Idle { returned_at } => Some((entry.id, returned_at)),
818                    PoolEntryState::Busy => None,
819                }
820            })
821            .collect::<Vec<_>>();
822
823        if idle_entries.len() <= self.target_idle {
824            return false;
825        }
826
827        idle_entries.sort_by_key(|(_, returned_at)| *returned_at);
828        let mut removed = false;
829        let overflow = idle_entries.len().saturating_sub(self.target_idle);
830        let mut to_remove = HashSet::with_capacity(overflow);
831        for (entry_id, _) in idle_entries {
832            if to_remove.len() == overflow {
833                break;
834            }
835            if entry_id == keep_entry_id {
836                continue;
837            }
838            to_remove.insert(entry_id);
839        }
840
841        if !to_remove.is_empty() {
842            state.entries.retain(|entry| {
843                let should_keep = !to_remove.contains(&entry.id);
844                if !should_keep {
845                    removed = true;
846                }
847                should_keep
848            });
849        }
850
851        removed
852    }
853
854    fn replenishment_needed(&self, idle_len: usize, warming: usize) -> usize {
855        self.target_idle
856            .saturating_sub(idle_len.saturating_add(warming))
857    }
858}
859
860pub fn subtitles(events: &[BoundaryEvent]) -> String {
861    to_srt(events)
862}
863
/// Prints a websocket frame to stderr for debugging, tagged with `kind`.
///
/// Silent unless the `EDGE_TTS_DEBUG` environment variable is present. The variable is
/// re-checked on every call. Non-UTF-8 bytes in `payload` are shown as replacement
/// characters rather than failing.
fn debug_frame(kind: &str, payload: &[u8]) {
    if std::env::var_os("EDGE_TTS_DEBUG").is_none() {
        return;
    }
    let text = String::from_utf8_lossy(payload);
    eprintln!("[edge-tts-debug] {kind}: {}", text);
}
872
/// Emits a websocket-pool diagnostic line to stderr.
///
/// Output is opt-in through the same `EDGE_TTS_DEBUG` environment variable that gates
/// `debug_frame`. Previously every debug build printed these lines unconditionally. That
/// spilled library-internal chatter into the stderr and test output of any application
/// embedding this crate.
fn pool_log(message: &str) {
    if std::env::var_os("EDGE_TTS_DEBUG").is_some() {
        eprintln!("{message}");
    }
}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a pool with no entries, enough to exercise the pure bookkeeping helpers.
    fn empty_pool(target_idle: usize) -> WsPool {
        WsPool {
            target_idle,
            idle_ttl: Duration::from_secs(15),
            warmup: true,
            next_id: AtomicU64::new(1),
            state: Mutex::new(WsPoolState::default()),
        }
    }

    #[test]
    fn builder_defaults_enable_pooling_and_chunk_reuse() {
        let builder = EdgeTtsClientBuilder::default();
        assert_eq!(builder.ws_pool_size, 1);
        assert_eq!(builder.ws_idle_ttl, Duration::from_secs(15));
        assert!(builder.ws_warmup);
        assert!(builder.request_chunk_reuse);
    }

    #[test]
    fn pool_replenishment_respects_idle_and_warming_counts() {
        let pool = empty_pool(2);

        assert_eq!(pool.replenishment_needed(0, 0), 2);
        assert_eq!(pool.replenishment_needed(1, 0), 1);
        assert_eq!(pool.replenishment_needed(1, 1), 0);
        assert_eq!(pool.replenishment_needed(2, 0), 0);
        // Counts beyond the target (or near usize::MAX) must saturate, not underflow.
        assert_eq!(pool.replenishment_needed(3, 0), 0);
        assert_eq!(pool.replenishment_needed(usize::MAX, usize::MAX), 0);
    }

    #[test]
    fn idle_connection_ttl_only_applies_after_expiration() {
        let pool = empty_pool(1);
        // Build timestamps forward from `returned_at`. Subtracting from `Instant::now()`
        // panics if the monotonic clock origin is closer than the subtracted duration
        // (e.g. on a freshly booted CI machine).
        let returned_at = Instant::now();

        assert!(!pool.is_expired(returned_at, returned_at + Duration::from_secs(14)));
        assert!(pool.is_expired(returned_at, returned_at + Duration::from_secs(15)));
        // A `now` earlier than `returned_at` saturates to zero elapsed time.
        assert!(!pool.is_expired(returned_at + Duration::from_secs(1), returned_at));
    }

    #[test]
    fn bookkeeping_on_empty_pool_is_a_no_op() {
        let pool = empty_pool(1);
        let mut state = pool.state.lock().expect("websocket pool poisoned");

        assert_eq!(pool.idle_count_locked(&state), 0);
        pool.cleanup_expired_locked(&mut state, Instant::now());
        assert!(!pool.trim_idle_locked(&mut state, 0));
        assert!(state.entries.is_empty());
    }
}