ms_teams_ws/
lib.rs

1mod messages;
2mod types;
3
4use crate::messages::{ClientMessage, ServerMessage};
5use crate::types::AppIdentifiers;
6use futures_util::SinkExt;
7use futures_util::StreamExt;
8use std::error::Error;
9use tokio_tungstenite::connect_async;
10use url::Url;
11
12/// A struct representing a WebSocket connection to a Microsoft Teams server.
13///
14/// # Fields
15/// - `identifier`: An `AppIdentifiers` struct containing information about the app.
16/// - `socket`: An optional WebSocket stream.
17/// - `token`: An optional authentication token.
18/// - `request_id`: A counter for request IDs.
19/// - `url`: The URL of the WebSocket server.
20///
21/// # Methods
22/// - `new`: Creates a new `TeamsWebsocket` instance.
23/// - `connect`: Connects to the WebSocket server.
24/// - `send`: Sends a `ClientMessage` to the server.
25/// - `receive`: Receives a `ServerMessage` from the server.
26/// - `close`: Closes the WebSocket connection.
27///
28/// # Example
29/// ```rust
30/// let identifier = AppIdentifiers {
31///     protocol_version: "1.0",
32///     manufacturer: "TestManufacturer",
33///     device: "TestDevice",
34///     app: "TestApp",
35///     app_version: "1.0",
36/// };
37/// let mut websocket = TeamsWebsocket::new(identifier, None, None).await;
38/// websocket.connect().await.unwrap();
39/// let client_message = ClientMessage::new(messages::MeetingAction::BlurBackground, None);
40/// websocket.send(client_message).await.unwrap();
41/// let server_message = websocket.receive().await.unwrap();
42/// websocket.close().await.unwrap();
43/// ```
44pub struct TeamsWebsocket {
45    identifier: AppIdentifiers,
46    socket: Option<
47        tokio_tungstenite::WebSocketStream<
48            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
49        >,
50    >,
51    token: Option<String>,
52    request_id: u32,
53    url: String,
54}
55
/// Error text logged and returned when an operation needs an open socket but none is stored.
const SOCKET_NOT_CONNECTED: &str = "socket not connected";
57
58impl TeamsWebsocket {
59    pub async fn new(
60        identifier: AppIdentifiers,
61        token: Option<String>,
62        url: Option<String>,
63    ) -> Self {
64        Self {
65            identifier,
66            socket: None,
67            token,
68            request_id: 0,
69            url: url.unwrap_or_else(|| "ws://127.0.0.1:8124".to_string()),
70        }
71    }
72
73    /// Connects to the WebSocket server using the provided URL and parameters.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if the URL cannot be parsed or if the connection attempt fails.
78    ///
79    /// # Examples
80    ///
81    /// ```
82    /// let mut websocket = TeamsWebsocket::new(identifier, token, url).await;
83    /// match websocket.connect().await {
84    ///     Ok(_) => println!("Connected successfully"),
85    ///     Err(e) => eprintln!("Failed to connect: {}", e),
86    /// }
87    /// ```
88    pub async fn connect(&mut self) -> Result<(), Box<dyn Error>> {
89        let url = Url::parse_with_params(
90            &self.url,
91            &[
92                ("protocol-version", self.identifier.protocol_version),
93                ("manufacturer", self.identifier.manufacturer),
94                ("device", self.identifier.device),
95                ("app", self.identifier.app),
96                ("app-version", self.identifier.app_version),
97                ("token", self.token.as_deref().unwrap_or("")),
98            ],
99        );
100        if let Err(e) = url {
101            log::warn!("Error parsing url: {}", e);
102            return Err(Box::new(e));
103        }
104        let url = url.unwrap();
105
106        let (socket, response) = match connect_async(url.as_str()).await {
107            Ok((socket, response)) => (socket, response),
108            Err(e) => {
109                log::warn!("Error: {}", e);
110                return Err(Box::new(e));
111            }
112        };
113
114        if log::log_enabled!(log::Level::Debug) {
115            log::debug!("Connected to the server");
116            log::debug!("Response HTTP code: {}", response.status());
117            log::debug!("Response contains the following headers:");
118            for (header, _value) in response.headers() {
119                log::trace!("* {header}");
120            }
121        }
122        self.socket = Some(socket);
123        Ok(())
124    }
125    
126    /// Sends a `ClientMessage` to Teams.
127    ///
128    /// # Arguments
129    ///
130    /// * `message` - The `ClientMessage` to be sent.
131    ///
132    /// # Errors
133    ///
134    /// Returns an error if the WebSocket connection is not established, if the message cannot be serialized, or if there is an error sending the message.
135    ///
136    /// # Examples
137    ///
138    /// 
139    pub async fn send(&mut self, message: ClientMessage) -> Result<(), Box<dyn Error>> {
140        if let Some(socket) = &mut self.socket {
141            let mut message = message;
142            message.request_id = Some(self.request_id);
143            self.request_id += 1;
144            let serialized_message = serde_json::to_string(&message);
145            log::debug!("Sending message: {:?}", serialized_message);
146            match serialized_message {
147                Ok(msg) => {
148                    if let Err(e) = socket
149                    .send(tungstenite::Message::Text(msg))
150                    .await
151                    {
152                        log::warn!("Error sending message: {}", e);
153                        return Err(Box::new(e));
154                    }
155                }
156                Err(e) => {
157                    log::warn!("Error serializing message: {}", e);
158                    return Err(Box::new(e));
159                }
160            } 
161            return Ok(());
162        }
163        log::warn!("{}", SOCKET_NOT_CONNECTED);
164        Err(Box::from(SOCKET_NOT_CONNECTED))
165        
166    }
167
168    pub async fn receive(&mut self) -> Result<ServerMessage, Box<dyn Error>> {
169        if let Some(socket) = &mut self.socket {
170            match socket.next().await {
171                Some(Ok(msg)) => {
172                    let server_message =
173                        serde_json::from_str::<ServerMessage>(msg.to_text().unwrap());
174                    match server_message {
175                        Ok(json) => {
176                            Ok(json)
177                        }
178                        Err(e) => {
179                            log::warn!("Error parsing json : {}", e);
180                            Err(Box::new(e))
181                        }
182                    }
183                },
184                Some(Err(e)) => {
185                    log::warn!("Error reading from socket {}", e);
186                    Err(Box::new(e))
187                }
188                None => {
189                    log::info!("Socket closed");
190                    Err(Box::from("socket closed"))
191                }
192            }
193        } else {
194            log::warn!("{}", SOCKET_NOT_CONNECTED);
195            Err(Box::from(SOCKET_NOT_CONNECTED))
196        }
197    }
198
199    pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
200        if let Some(socket) = &mut self.socket {
201            if let Err(e) = socket.close(None).await {
202                log::warn!("Error closing socket: {}", e);
203                return Err(Box::new(e));
204            }
205            log::info!("Connection closed");
206            Ok(())
207        } else {
208            log::warn!("{}", SOCKET_NOT_CONNECTED);
209            Err(Box::from(SOCKET_NOT_CONNECTED))
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::runtime::Runtime;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::protocol::Message;

    /// Builds the identifier fixture shared by every test.
    fn test_identifier() -> AppIdentifiers {
        AppIdentifiers {
            protocol_version: "1.0",
            manufacturer: "TestManufacturer",
            device: "TestDevice",
            app: "TestApp",
            app_version: "1.0",
        }
    }

    /// A freshly constructed client is disconnected and keeps the values it was given.
    #[test]
    fn test_teams_websocket_new() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let identifier = test_identifier();
            let websocket = TeamsWebsocket::new(identifier.clone(), None, None).await;
            assert_eq!(websocket.identifier, identifier);
            assert!(websocket.socket.is_none());
            assert!(websocket.token.is_none());
            assert_eq!(websocket.request_id, 0);
            assert_eq!(websocket.url, "ws://127.0.0.1:8124");
        });
    }

    /// Every socket operation reports "socket not connected" before `connect` is called.
    #[test]
    fn test_teams_websocket_not_connected() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, None).await;
            let expected = Some(SOCKET_NOT_CONNECTED.to_string());

            let client_message = ClientMessage::new(messages::MeetingAction::BlurBackground, None);
            let send_error = websocket.send(client_message).await.err();
            assert_eq!(send_error.map(|e| e.to_string()), expected);
            // A message that was never sent must not consume a request id.
            assert_eq!(websocket.request_id, 0);

            let receive_error = websocket.receive().await.err();
            assert_eq!(receive_error.map(|e| e.to_string()), expected);

            let close_error = websocket.close().await.err();
            assert_eq!(close_error.map(|e| e.to_string()), expected);
        });
    }

    /// Starts an echo WebSocket server on a free local port and returns its address.
    ///
    /// Each text frame is parsed as a `ClientMessage` and answered with a
    /// `ServerMessage` carrying the same request id and `"Echo: <text>"` as response.
    async fn start_test_server() -> SocketAddr {
        // Port 0 lets the OS pick an unused port, so the bind cannot collide with a
        // port that is already taken.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            while let Ok((stream, _)) = listener.accept().await {
                let ws_stream = accept_async(stream).await.unwrap();
                let (mut write, mut read) = ws_stream.split();
                tokio::spawn(async move {
                    while let Some(Ok(msg)) = read.next().await {
                        if let Message::Text(text) = msg {
                            let client_message: ClientMessage =
                                serde_json::from_str(&text).unwrap();
                            let server_message = ServerMessage {
                                request_id: client_message.request_id,
                                response: Some(format!("Echo: {}", text)),
                                error_msg: None,
                                token_refresh: None,
                                meeting_update: None,
                            };
                            let response = serde_json::to_string(&server_message).unwrap();
                            write.send(Message::Text(response)).await.unwrap();
                        }
                    }
                });
            }
        });
        addr
    }

    /// `connect` succeeds against the echo server and stores the socket.
    #[test]
    fn test_teams_websocket_connect() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let addr = start_test_server().await;
            let url = format!("ws://{}", addr);
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, Some(url)).await;
            let result = websocket.connect().await;
            assert!(result.is_ok());
            assert!(websocket.socket.is_some());
        });
    }

    /// A sent message comes back echoed, stamped with request id 0.
    #[test]
    fn test_teams_websocket_send_receive() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let addr = start_test_server().await;
            let url = format!("ws://{}", addr);
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, Some(url)).await;
            websocket.connect().await.unwrap();

            let client_message = ClientMessage::new(messages::MeetingAction::BlurBackground, None);
            websocket.send(client_message).await.unwrap();

            let server_message = websocket.receive().await.unwrap();
            assert_eq!(
                server_message.response,
                Some(
                    "Echo: {\"action\":\"blur-background\",\"parameters\":null,\"requestId\":0}"
                        .to_string()
                )
            );
        });
    }
}