1mod messages;
2mod types;
3
4use crate::messages::{ClientMessage, ServerMessage};
5use crate::types::AppIdentifiers;
6use futures_util::SinkExt;
7use futures_util::StreamExt;
8use std::error::Error;
9use tokio_tungstenite::connect_async;
10use url::Url;
11
12pub struct TeamsWebsocket {
45 identifier: AppIdentifiers,
46 socket: Option<
47 tokio_tungstenite::WebSocketStream<
48 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
49 >,
50 >,
51 token: Option<String>,
52 request_id: u32,
53 url: String,
54}
55
56const SOCKET_NOT_CONNECTED: &str = "socket not connected";
57
58impl TeamsWebsocket {
59 pub async fn new(
60 identifier: AppIdentifiers,
61 token: Option<String>,
62 url: Option<String>,
63 ) -> Self {
64 Self {
65 identifier,
66 socket: None,
67 token,
68 request_id: 0,
69 url: url.unwrap_or_else(|| "ws://127.0.0.1:8124".to_string()),
70 }
71 }
72
73 pub async fn connect(&mut self) -> Result<(), Box<dyn Error>> {
89 let url = Url::parse_with_params(
90 &self.url,
91 &[
92 ("protocol-version", self.identifier.protocol_version),
93 ("manufacturer", self.identifier.manufacturer),
94 ("device", self.identifier.device),
95 ("app", self.identifier.app),
96 ("app-version", self.identifier.app_version),
97 ("token", self.token.as_deref().unwrap_or("")),
98 ],
99 );
100 if let Err(e) = url {
101 log::warn!("Error parsing url: {}", e);
102 return Err(Box::new(e));
103 }
104 let url = url.unwrap();
105
106 let (socket, response) = match connect_async(url.as_str()).await {
107 Ok((socket, response)) => (socket, response),
108 Err(e) => {
109 log::warn!("Error: {}", e);
110 return Err(Box::new(e));
111 }
112 };
113
114 if log::log_enabled!(log::Level::Debug) {
115 log::debug!("Connected to the server");
116 log::debug!("Response HTTP code: {}", response.status());
117 log::debug!("Response contains the following headers:");
118 for (header, _value) in response.headers() {
119 log::trace!("* {header}");
120 }
121 }
122 self.socket = Some(socket);
123 Ok(())
124 }
125
126 pub async fn send(&mut self, message: ClientMessage) -> Result<(), Box<dyn Error>> {
140 if let Some(socket) = &mut self.socket {
141 let mut message = message;
142 message.request_id = Some(self.request_id);
143 self.request_id += 1;
144 let serialized_message = serde_json::to_string(&message);
145 log::debug!("Sending message: {:?}", serialized_message);
146 match serialized_message {
147 Ok(msg) => {
148 if let Err(e) = socket
149 .send(tungstenite::Message::Text(msg))
150 .await
151 {
152 log::warn!("Error sending message: {}", e);
153 return Err(Box::new(e));
154 }
155 }
156 Err(e) => {
157 log::warn!("Error serializing message: {}", e);
158 return Err(Box::new(e));
159 }
160 }
161 return Ok(());
162 }
163 log::warn!("{}", SOCKET_NOT_CONNECTED);
164 Err(Box::from(SOCKET_NOT_CONNECTED))
165
166 }
167
168 pub async fn receive(&mut self) -> Result<ServerMessage, Box<dyn Error>> {
169 if let Some(socket) = &mut self.socket {
170 match socket.next().await {
171 Some(Ok(msg)) => {
172 let server_message =
173 serde_json::from_str::<ServerMessage>(msg.to_text().unwrap());
174 match server_message {
175 Ok(json) => {
176 Ok(json)
177 }
178 Err(e) => {
179 log::warn!("Error parsing json : {}", e);
180 Err(Box::new(e))
181 }
182 }
183 },
184 Some(Err(e)) => {
185 log::warn!("Error reading from socket {}", e);
186 Err(Box::new(e))
187 }
188 None => {
189 log::info!("Socket closed");
190 Err(Box::from("socket closed"))
191 }
192 }
193 } else {
194 log::warn!("{}", SOCKET_NOT_CONNECTED);
195 Err(Box::from(SOCKET_NOT_CONNECTED))
196 }
197 }
198
199 pub async fn close(&mut self) -> Result<(), Box<dyn Error>> {
200 if let Some(socket) = &mut self.socket {
201 if let Err(e) = socket.close(None).await {
202 log::warn!("Error closing socket: {}", e);
203 return Err(Box::new(e));
204 }
205 log::info!("Connection closed");
206 Ok(())
207 } else {
208 log::warn!("{}", SOCKET_NOT_CONNECTED);
209 Err(Box::from(SOCKET_NOT_CONNECTED))
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::net::TcpListener;
    use tokio::runtime::Runtime;
    use tokio_tungstenite::accept_async;
    use tokio_tungstenite::tungstenite::protocol::Message;

    /// Builds the identifier set shared by every test.
    fn test_identifier() -> AppIdentifiers {
        AppIdentifiers {
            protocol_version: "1.0",
            manufacturer: "TestManufacturer",
            device: "TestDevice",
            app: "TestApp",
            app_version: "1.0",
        }
    }

    /// Starts a local echo server and returns the address it listens on.
    ///
    /// Every text frame is parsed as a `ClientMessage` and answered with a `ServerMessage`
    /// that copies its `request_id` and sets `response` to `"Echo: "` followed by the
    /// received text.
    async fn start_test_server() -> SocketAddr {
        // Port 0 lets the OS pick a free port; a randomly chosen port may already be in use
        // and make the test fail spuriously.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            while let Ok((stream, _)) = listener.accept().await {
                // Handshake inside the per-connection task so a failed handshake cannot stop
                // the accept loop.
                tokio::spawn(async move {
                    let ws_stream = match accept_async(stream).await {
                        Ok(ws_stream) => ws_stream,
                        Err(_) => return,
                    };
                    let (mut write, mut read) = ws_stream.split();
                    while let Some(Ok(msg)) = read.next().await {
                        if let Message::Text(text) = msg {
                            let client_message: ClientMessage =
                                serde_json::from_str(&text).unwrap();
                            let server_message = ServerMessage {
                                request_id: client_message.request_id,
                                response: Some(format!("Echo: {}", text)),
                                error_msg: None,
                                token_refresh: None,
                                meeting_update: None,
                            };
                            let response = serde_json::to_string(&server_message).unwrap();
                            if write.send(Message::Text(response)).await.is_err() {
                                // Client went away; nothing left to echo to.
                                break;
                            }
                        }
                    }
                });
            }
        });
        addr
    }

    #[test]
    fn test_teams_websocket_new() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let websocket = TeamsWebsocket::new(test_identifier(), None, None).await;
            assert_eq!(websocket.identifier, test_identifier());
            assert!(websocket.socket.is_none());
            assert!(websocket.token.is_none());
            assert_eq!(websocket.request_id, 0);
            assert_eq!(websocket.url, "ws://127.0.0.1:8124");
        });
    }

    #[test]
    fn test_teams_websocket_connect() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let addr = start_test_server().await;
            let url = format!("ws://{}", addr);
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, Some(url)).await;
            let result = websocket.connect().await;
            assert!(result.is_ok());
            assert!(websocket.socket.is_some());
        });
    }

    #[test]
    fn test_teams_websocket_send_receive() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let addr = start_test_server().await;
            let url = format!("ws://{}", addr);
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, Some(url)).await;
            websocket.connect().await.unwrap();

            let client_message = ClientMessage::new(messages::MeetingAction::BlurBackground, None);
            websocket.send(client_message).await.unwrap();

            let server_message = websocket.receive().await.unwrap();
            assert_eq!(
                server_message.response,
                Some(
                    "Echo: {\"action\":\"blur-background\",\"parameters\":null,\"requestId\":0}"
                        .to_string()
                )
            );
        });
    }

    #[test]
    fn test_send_without_connect_fails() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, None).await;
            let client_message = ClientMessage::new(messages::MeetingAction::BlurBackground, None);
            let err = websocket
                .send(client_message)
                .await
                .err()
                .expect("send must fail without a connection");
            assert_eq!(err.to_string(), SOCKET_NOT_CONNECTED);
            // No id is consumed when nothing was sent.
            assert_eq!(websocket.request_id, 0);
        });
    }

    #[test]
    fn test_receive_and_close_without_connect_fail() {
        let rt = Runtime::new().unwrap();
        rt.block_on(async {
            let mut websocket = TeamsWebsocket::new(test_identifier(), None, None).await;
            let receive_err = websocket
                .receive()
                .await
                .err()
                .expect("receive must fail without a connection");
            assert_eq!(receive_err.to_string(), SOCKET_NOT_CONNECTED);
            let close_err = websocket
                .close()
                .await
                .err()
                .expect("close must fail without a connection");
            assert_eq!(close_err.to_string(), SOCKET_NOT_CONNECTED);
        });
    }
}