asterisk_ari/ws/
client.rs

1use crate::config::Config;
2use crate::errors::AriError;
3use crate::ws::{models, params};
4use futures_util::{SinkExt, StreamExt as _};
5use rand::random;
6use std::time::Duration;
7use tokio_stream::wrappers::ReceiverStream;
8use tokio_stream::Stream;
9use tokio_tungstenite::connect_async;
10use tokio_tungstenite::tungstenite::Message;
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, error, info, trace, warn};
13use url::Url;
14#[derive(Debug)]
15pub struct Client {
16    config: Config,
17    stop_signal: CancellationToken,
18    _ws_join_handle: Option<tokio::task::JoinHandle<Result<(), AriError>>>,
19}
20
21impl Drop for Client {
22    fn drop(&mut self) {
23        self.stop_signal.cancel();
24    }
25}
26
27impl Client {
28    pub fn with_config(config: Config) -> Self {
29        Self {
30            config,
31            stop_signal: CancellationToken::new(),
32            _ws_join_handle: None,
33        }
34    }
35
36    /// Disconnects the WebSocket client and waits for the join handler to finish.
37    pub async fn disconnect(&mut self) -> Result<(), AriError> {
38        self.stop_signal.cancel();
39
40        if let Some(handle) = self._ws_join_handle.take() {
41            return handle.await.unwrap_or_else(|e| {
42                warn!("error when waiting for ws join handle: {:#?}", e);
43                Err(AriError::Internal(e.to_string()))
44            });
45        }
46
47        Ok(())
48    }
49
50    /// Connects to the ARI WebSocket and starts listening for events.
51    pub async fn connect(
52        &mut self,
53        request: params::ListenRequest,
54    ) -> Result<impl Stream<Item = models::Event>, AriError> {
55        let mut url = Url::parse(self.config.api_base.as_str())?;
56
57        url.set_scheme(if url.scheme().starts_with("https") {
58            "wss"
59        } else {
60            "ws"
61        })
62        .unwrap();
63
64        url.set_path("/ari/events");
65
66        url.query_pairs_mut()
67            .append_pair(
68                "api_key",
69                &format!("{}:{}", self.config.username, self.config.password),
70            )
71            .append_pair("app", request.app.as_str())
72            .append_pair(
73                "subscribeAll",
74                request.subscribe_all.unwrap_or(true).to_string().as_str(),
75            );
76
77        debug!("connecting to ws_url: {}", url);
78
79        let ws_stream = match connect_async(url.to_string()).await {
80            Ok((ws_stream, _)) => ws_stream,
81            Err(e) => {
82                warn!("error when connecting to the websocket: {:#?}", e);
83                return Err(e.into());
84            }
85        };
86        debug!("websocket connected");
87
88        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
89        let mut refresh_interval = tokio::time::interval(Duration::from_millis(5000));
90        let cancel_token = self.stop_signal.child_token();
91        let (tx, rx) = tokio::sync::mpsc::channel(100);
92
93        self._ws_join_handle = Some(tokio::spawn(async move {
94            let mut connected = true;
95
96            'outer: loop {
97                while connected {
98                    tokio::select! {
99                        _ = cancel_token.cancelled() => {
100                                if let Err(e) = ws_sender.close().await {
101                                    return Err(AriError::from(e));
102                                }
103                                debug!("WS connection closed due to cancellation");
104                                break 'outer;
105
106                        },
107                        msg = ws_receiver.next() =>  {
108
109                                let Some(msg) = msg else {
110                                    // If the receiver returns None, mark connection as lost.
111                                    connected = false;
112                                    continue;
113                                };
114
115                                match msg {
116                                    Ok(Message::Close(close_frame)) => {
117                                        warn!("Close message received: {:#?}", close_frame);
118                                        connected = false;
119                                        continue;
120                                    }
121                                    Ok(Message::Pong(_)) => {},
122                                    Ok(Message::Ping(data)) => {
123                                        let _ = ws_sender.send(Message::Pong(data)).await;
124                                    }
125                                    Ok(Message::Text(string_msg)) => {
126                                        trace!("WS Ari Event: {:#?}", string_msg);
127                                        match serde_json::from_str::<models::Event>(&string_msg) {
128                                            Ok(event) => {
129                                                if tx.send(event).await.is_err() {
130                                                    debug!("Receiver closed the connection. Stopping WS client");
131                                                    break;
132                                                }
133                                            }
134                                            Err(e) => warn!("error when deserializing ARI event: {:#?}. Event: {:#?}", e, string_msg),
135                                        }
136                                    }
137                                    Err(e) => {
138                                        warn!("Error when receiving websocket message: {:#?}", e);
139                                        connected = false;
140                                        continue;
141                                    }
142                                    _ => {}
143                                }
144
145                        },
146                        _ = refresh_interval.tick() => {
147
148                                let _ = ws_sender.send(Message::Ping(random::<[u8; 32]>().to_vec().into())).await;
149                                debug!("ARI connection ping sent");
150
151                        }
152                    }
153                }
154
155                let mut i = 0;
156                loop {
157                    i += 1;
158                    if cancel_token.is_cancelled() {
159                        debug!("Cancellation detected during reconnection attempts");
160                        break 'outer;
161                    }
162                    info!("Attempting to reconnect ({i})");
163
164                    match connect_async(url.to_string()).await {
165                        Ok((ws_stream, _)) => {
166                            info!("Reconnected successfully");
167                            connected = true;
168                            let (new_ws_sender, new_ws_receiver) = ws_stream.split();
169                            ws_sender = new_ws_sender;
170                            ws_receiver = new_ws_receiver;
171                            continue 'outer;
172                        }
173                        Err(e) => {
174                            error!("Failed to reconnect ({i}): {e}");
175                        }
176                    }
177                    tokio::time::sleep(std::cmp::min(
178                        Duration::from_millis(500 * i),
179                        Duration::from_secs(90),
180                    ))
181                    .await;
182                }
183            }
184
185            Ok(())
186        }));
187
188        Ok(ReceiverStream::new(rx))
189    }
190}