asterisk_ari/ws/
client.rs

1use crate::config::Config;
2use crate::errors::AriError;
3use crate::ws::{models, params};
4use futures_util::{SinkExt, StreamExt as _};
5use rand::random;
6use std::time::Duration;
7use tokio::time::interval;
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::Stream;
10use tokio_tungstenite::connect_async;
11use tokio_tungstenite::tungstenite::Message;
12use tokio_util::sync::CancellationToken;
13use tracing::{debug, trace, warn};
14use url::Url;
15
16#[derive(Debug)]
17pub struct Client {
18    config: Config,
19    stop_signal: CancellationToken,
20    _ws_join_handle: Option<tokio::task::JoinHandle<Result<(), AriError>>>,
21}
22
23impl Drop for Client {
24    fn drop(&mut self) {
25        self.stop_signal.cancel();
26    }
27}
28
29impl Client {
30    pub fn with_config(config: Config) -> Self {
31        Self {
32            config,
33            stop_signal: CancellationToken::new(),
34            _ws_join_handle: None,
35        }
36    }
37
38    /// Disconnects the WebSocket client and waits for the join handler to finish.
39    pub async fn disconnect(&mut self) -> Result<(), AriError> {
40        self.stop_signal.cancel();
41
42        // Acquire the lock and take the join handle out of the Option.
43        if let Some(handle) = self._ws_join_handle.take() {
44            return handle.await.unwrap_or_else(|e| {
45                warn!("error when waiting for ws join handle: {:#?}", e);
46                Err(AriError::Internal(e.to_string()))
47            });
48        };
49
50        Ok(())
51    }
52
53    /// Connects to the ARI WebSocket and starts listening for events.
54    pub async fn connect(
55        &mut self,
56        request: params::ListenRequest,
57    ) -> Result<impl Stream<Item = models::Event>, AriError> {
58        let mut url = Url::parse(self.config.api_base.as_str())?;
59
60        url.set_scheme(if url.scheme().starts_with("https") {
61            "wss"
62        } else {
63            "ws"
64        })
65        .unwrap();
66
67        url.set_path("/ari/events");
68
69        url.query_pairs_mut()
70            .append_pair(
71                "api_key",
72                &format!("{}:{}", self.config.username, self.config.password),
73            )
74            .append_pair("app", request.app.as_str())
75            .append_pair(
76                "subscribeAll",
77                request.subscribe_all.unwrap_or(true).to_string().as_str(),
78            );
79
80        debug!("connecting to ws_url: {}", url);
81
82        let ws_stream = match connect_async(url.to_string()).await {
83            Ok((ws_stream, _)) => ws_stream,
84            Err(e) => {
85                warn!("error when connecting to the websocket: {:#?}", e);
86                return Err(e.into());
87            }
88        };
89        debug!("websocket connected");
90
91        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
92        let mut refresh_interval = interval(Duration::from_millis(5000));
93        let cancel_token = self.stop_signal.child_token();
94        let (tx, rx) = tokio::sync::mpsc::channel(100);
95
96        // Store the join handle.
97        self._ws_join_handle = Some(tokio::spawn(async move {
98            loop {
99                tokio::select! {
100                    _ = cancel_token.cancelled() => {
101                        match ws_sender.close().await {
102                                Ok(_) => {
103                                    debug!("WS connection closed");
104                                    break;
105                                },
106                                Err(e) => return Err(AriError::from(e)),
107                            }
108                    },
109                    msg = ws_receiver.next() => {
110                        match msg {
111                            Some(msg) => {
112                                match msg {
113                                    Ok(Message::Close(close_frame)) => {
114                                        debug!("Close message received, leaving the loop! {:#?}", close_frame);
115                                        break;
116                                    }
117                                    Ok(Message::Pong(_)) => {},
118                                    Ok(Message::Ping(data)) => {
119                                        let _ = ws_sender.send(Message::Pong(data)).await;
120                                    }
121                                    Ok(Message::Text(string_msg)) => {
122                                        trace!("WS Ari Event: {:#?}", string_msg);
123                                        match serde_json::from_str::<models::Event>(&string_msg) {
124                                            Ok(event) => {
125                                                if tx.send(event).await.is_err() {
126                                                    warn!("error when sending ARI event to the channel");
127                                                    break;
128                                                }
129                                            }
130                                            Err(e) => warn!("error when deserializing ARI event: {:#?}. Event: {:#?}", e, string_msg),
131                                        }
132                                    }
133                                    Err(e) => {
134                                        warn!("Error when receiving websocket message: {:#?}", e);
135                                        break;
136                                    }
137                                    _ => {
138                                        warn!("Unknown websocket message received: {:#?}", msg);
139                                    }
140                                }
141                            }
142                            None => break,
143                        }
144                    },
145                    _ = refresh_interval.tick() => {
146                        let _ = ws_sender.send(Message::Ping(random::<[u8; 32]>().to_vec().into())).await;
147                        debug!("ari connection ping sent");
148                    }
149                }
150            }
151
152            Ok(())
153        }));
154
155        Ok(ReceiverStream::new(rx))
156    }
157}