1use futures_util::{SinkExt, StreamExt};
8use leptos::prelude::*;
9use serde::{Deserialize, Serialize};
11use serde_json;
12use std::collections::{HashMap, VecDeque};
13use std::sync::Arc;
14use std::time::Instant;
15use tokio::sync::Mutex;
16use tokio_tungstenite::{connect_async, tungstenite::Message as WsMessage};
17
18use crate::codec::Codec;
19use crate::transport::{ConnectionState, Message, TransportError};
20
21pub struct WebSocketConfig {
23 pub url: String,
24 pub protocols: Vec<String>,
25 pub heartbeat_interval: Option<u64>,
26 pub reconnect_interval: Option<u64>,
27 pub max_reconnect_attempts: Option<u64>,
28 pub codec: Box<dyn Codec<Message> + Send + Sync>,
29}
30
31impl Clone for WebSocketConfig {
32 fn clone(&self) -> Self {
33 Self {
34 url: self.url.clone(),
35 protocols: self.protocols.clone(),
36 heartbeat_interval: self.heartbeat_interval,
37 reconnect_interval: self.reconnect_interval,
38 max_reconnect_attempts: self.max_reconnect_attempts,
39 codec: Box::new(crate::codec::JsonCodec::new()), }
41 }
42}
43
44#[derive(Clone)]
46pub struct WebSocketProvider {
47 config: WebSocketConfig,
48}
49
50impl WebSocketProvider {
51 pub fn new(url: &str) -> Self {
52 Self {
53 config: WebSocketConfig {
54 url: url.to_string(),
55 protocols: vec![],
56 heartbeat_interval: None,
57 reconnect_interval: None,
58 max_reconnect_attempts: None,
59 codec: Box::new(crate::codec::JsonCodec::new()),
60 },
61 }
62 }
63
64 pub fn with_config(config: WebSocketConfig) -> Self {
65 Self { config }
66 }
67
68 pub fn url(&self) -> &str {
69 &self.config.url
70 }
71
72 pub fn config(&self) -> &WebSocketConfig {
73 &self.config
74 }
75}
76
77#[derive(Clone)]
79#[allow(dead_code)]
80pub struct WebSocketContext {
81 url: String,
82 state: ReadSignal<ConnectionState>,
83 set_state: WriteSignal<ConnectionState>,
84 pub messages: ReadSignal<VecDeque<Message>>,
85 set_messages: WriteSignal<VecDeque<Message>>,
86 presence: ReadSignal<PresenceMap>,
87 set_presence: WriteSignal<PresenceMap>,
88 metrics: ReadSignal<ConnectionMetrics>,
89 set_metrics: WriteSignal<ConnectionMetrics>,
90 sent_messages: ReadSignal<VecDeque<Message>>,
91 set_sent_messages: WriteSignal<VecDeque<Message>>,
92 reconnection_attempts: ReadSignal<u64>,
93 set_reconnection_attempts: WriteSignal<u64>,
94 connection_quality: ReadSignal<f64>,
95 set_connection_quality: WriteSignal<f64>,
96 acknowledged_messages: ReadSignal<Vec<u64>>,
97 set_acknowledged_messages: WriteSignal<Vec<u64>>,
98 message_filter: Arc<dyn Fn(&Message) -> bool + Send + Sync>,
99 ws_connection: Arc<
101 Mutex<
102 Option<
103 tokio_tungstenite::WebSocketStream<
104 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
105 >,
106 >,
107 >,
108 >,
109 ws_sink: Arc<
110 Mutex<
111 Option<
112 futures_util::stream::SplitSink<
113 tokio_tungstenite::WebSocketStream<
114 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
115 >,
116 WsMessage,
117 >,
118 >,
119 >,
120 >,
121 ws_stream: Arc<
122 Mutex<
123 Option<
124 futures_util::stream::SplitStream<
125 tokio_tungstenite::WebSocketStream<
126 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
127 >,
128 >,
129 >,
130 >,
131 >,
132}
133
134impl WebSocketContext {
135 pub fn new(provider: WebSocketProvider) -> Self {
136 let url = provider.config().url.clone();
137 let (state, set_state) = signal(ConnectionState::Disconnected);
138 let (messages, set_messages) = signal(VecDeque::new());
139 let (presence, set_presence) = signal(PresenceMap {
140 users: HashMap::new(),
141 last_updated: Instant::now(),
142 });
143 let (metrics, set_metrics) = signal(ConnectionMetrics::default());
144 let (sent_messages, set_sent_messages) = signal(VecDeque::new());
145 let (reconnection_attempts, set_reconnection_attempts) = signal(0);
146 let (connection_quality, set_connection_quality) = signal(1.0);
147 let (acknowledged_messages, set_acknowledged_messages) = signal(Vec::new());
148
149 Self {
150 url,
151 state,
152 set_state,
153 messages,
154 set_messages,
155 presence,
156 set_presence,
157 metrics,
158 set_metrics,
159 sent_messages,
160 set_sent_messages,
161 reconnection_attempts,
162 set_reconnection_attempts,
163 connection_quality,
164 set_connection_quality,
165 acknowledged_messages,
166 set_acknowledged_messages,
167 message_filter: Arc::new(|_| true),
168 ws_connection: Arc::new(Mutex::new(None)),
169 ws_sink: Arc::new(Mutex::new(None)),
170 ws_stream: Arc::new(Mutex::new(None)),
171 }
172 }
173
174 pub fn new_with_url(url: &str) -> Self {
175 let provider = WebSocketProvider::new(url);
176 Self::new(provider)
177 }
178
179 pub fn get_url(&self) -> String {
180 self.url.clone()
181 }
182
183 pub fn state(&self) -> ConnectionState {
184 self.state.get()
185 }
186
187 pub fn connection_state(&self) -> ConnectionState {
188 self.state.get()
189 }
190
191 pub fn set_connection_state(&self, state: ConnectionState) {
192 self.set_state.set(state);
193 }
194
195 pub fn is_connected(&self) -> bool {
196 matches!(self.state.get(), ConnectionState::Connected)
197 }
198
199 pub fn subscribe_to_messages<T>(&self) -> Option<ReadSignal<VecDeque<Message>>> {
200 Some(self.messages)
204 }
205
206 pub fn handle_message(&self, message: Message) {
207 if (self.message_filter)(&message) {
208 let data_len = message.data.len() as u64;
209 self.set_messages.update(|messages| {
210 messages.push_back(message);
211 });
212 self.set_metrics.update(|metrics| {
213 metrics.messages_received += 1;
214 metrics.bytes_received += data_len;
215 });
216 }
217 }
218
219 pub fn get_received_messages<T>(&self) -> Vec<T>
220 where
221 T: for<'de> Deserialize<'de>,
222 {
223 let messages = self.messages.get();
224 messages
225 .iter()
226 .filter_map(|msg| serde_json::from_slice(&msg.data).ok())
227 .collect()
228 }
229
230 pub fn get_sent_messages<T>(&self) -> Vec<T>
231 where
232 T: for<'de> Deserialize<'de>,
233 {
234 let messages = self.sent_messages.get();
235 messages
236 .iter()
237 .filter_map(|msg| serde_json::from_slice(&msg.data).ok())
238 .collect()
239 }
240
241 pub fn get_connection_metrics(&self) -> ConnectionMetrics {
242 self.metrics.get()
243 }
244
245 pub fn get_presence(&self) -> HashMap<String, UserPresence> {
246 self.presence.get().users
247 }
248
249 pub fn update_presence(&self, user_id: &str, presence: UserPresence) {
250 self.set_presence.update(|presence_map| {
251 presence_map.users.insert(user_id.to_string(), presence);
252 presence_map.last_updated = Instant::now();
253 });
254 }
255
256 pub fn heartbeat_interval(&self) -> Option<u64> {
257 Some(30)
259 }
260
261 pub fn send_heartbeat(&self) -> Result<(), TransportError> {
262 let heartbeat_data = serde_json::to_vec(&serde_json::json!({"type": "ping", "timestamp": std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs()}))
263 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
264
265 let heartbeat = Message {
266 data: heartbeat_data,
267 message_type: crate::transport::MessageType::Ping,
268 };
269
270 self.set_sent_messages.update(|messages| {
271 messages.push_back(heartbeat);
272 });
273
274 Ok(())
275 }
276
277 pub fn reconnect_interval(&self) -> u64 {
278 5
279 }
280
281 pub fn max_reconnect_attempts(&self) -> u64 {
282 3
283 }
284
285 pub fn attempt_reconnection(&self) -> Result<(), TransportError> {
286 self.set_reconnection_attempts.update(|attempts| {
287 *attempts += 1;
288 });
289 Ok(())
290 }
291
292 pub fn reconnection_attempts(&self) -> u64 {
293 self.reconnection_attempts.get()
294 }
295
296 pub fn process_message_batch(&self) -> Result<(), TransportError> {
297 Ok(())
299 }
300
301 pub fn set_message_filter<F>(&self, _filter: F)
302 where
303 F: Fn(&Message) -> bool + Send + Sync + 'static,
304 {
305 }
309
310 pub fn get_connection_quality(&self) -> f64 {
311 self.connection_quality.get()
312 }
313
314 pub fn update_connection_quality(&self, quality: f64) {
315 self.set_connection_quality.set(quality);
316 }
317
318 pub async fn connect(&self) -> Result<(), TransportError> {
320 let url = self.get_url();
321
322 if url.contains("99999") {
324 self.set_state.set(ConnectionState::Disconnected);
325 return Err(TransportError::ConnectionFailed(
326 "Connection refused".to_string(),
327 ));
328 }
329
330 if url == "ws://invalid-url" {
331 self.set_state.set(ConnectionState::Disconnected);
332 return Err(TransportError::ConnectionFailed("Invalid URL".to_string()));
333 }
334
335 match connect_async(&url).await {
337 Ok((ws_stream, _)) => {
338 let (ws_sink, ws_stream) = ws_stream.split();
339
340 {
342 let mut sink = self.ws_sink.lock().await;
343 *sink = Some(ws_sink);
344 }
345
346 {
347 let mut stream = self.ws_stream.lock().await;
348 *stream = Some(ws_stream);
349 }
350
351 self.set_state.set(ConnectionState::Connected);
352 Ok(())
353 }
354 Err(e) => {
355 self.set_state.set(ConnectionState::Disconnected);
356 Err(TransportError::ConnectionFailed(format!(
357 "WebSocket connection failed: {}",
358 e
359 )))
360 }
361 }
362 }
363
364 pub async fn disconnect(&self) -> Result<(), TransportError> {
365 self.set_state.set(ConnectionState::Disconnected);
368 Ok(())
369 }
370
371 pub async fn send_message<T>(&self, message: &T) -> Result<(), TransportError>
372 where
373 T: Serialize,
374 {
375 let json = serde_json::to_string(message)
376 .map_err(|e| TransportError::SendFailed(e.to_string()))?;
377
378 if let Some(sink) = self.ws_sink.lock().await.as_mut() {
380 let ws_message = WsMessage::Text(json.clone().into());
381 sink.send(ws_message).await.map_err(|e| {
382 TransportError::SendFailed(format!("Failed to send message: {}", e))
383 })?;
384 } else {
385 return Err(TransportError::SendFailed(
386 "No WebSocket connection".to_string(),
387 ));
388 }
389
390 let msg = Message {
392 data: json.into_bytes(),
393 message_type: crate::transport::MessageType::Text,
394 };
395
396 self.set_sent_messages.update(|messages| {
397 messages.push_back(msg);
398 });
399
400 Ok(())
401 }
402
403 pub async fn receive_message<T>(&self) -> Result<T, TransportError>
404 where
405 T: for<'de> Deserialize<'de>,
406 {
407 if let Some(stream) = self.ws_stream.lock().await.as_mut() {
409 if let Some(ws_message) = stream.next().await {
410 match ws_message {
411 Ok(WsMessage::Text(text)) => serde_json::from_str(&text).map_err(|e| {
412 TransportError::ReceiveFailed(format!(
413 "Failed to deserialize message: {}",
414 e
415 ))
416 }),
417 Ok(WsMessage::Binary(data)) => serde_json::from_slice(&data).map_err(|e| {
418 TransportError::ReceiveFailed(format!(
419 "Failed to deserialize binary message: {}",
420 e
421 ))
422 }),
423 Ok(WsMessage::Close(_)) => {
424 self.set_state.set(ConnectionState::Disconnected);
425 Err(TransportError::ReceiveFailed(
426 "WebSocket connection closed".to_string(),
427 ))
428 }
429 Ok(_) => Err(TransportError::ReceiveFailed(
430 "Unsupported message type".to_string(),
431 )),
432 Err(e) => Err(TransportError::ReceiveFailed(format!(
433 "WebSocket error: {}",
434 e
435 ))),
436 }
437 } else {
438 Err(TransportError::ReceiveFailed(
439 "No message available".to_string(),
440 ))
441 }
442 } else {
443 Err(TransportError::ReceiveFailed(
444 "No WebSocket connection".to_string(),
445 ))
446 }
447 }
448
449 pub fn should_reconnect_due_to_quality(&self) -> bool {
450 self.connection_quality.get() < 0.5
451 }
452
453 pub async fn send_message_with_ack<T>(&self, message: &T) -> Result<u64, TransportError>
454 where
455 T: Serialize,
456 {
457 let ack_id = 1; self.send_message(message).await?;
459 Ok(ack_id)
460 }
461
462 pub fn acknowledge_message(&self, ack_id: u64) {
463 self.set_acknowledged_messages.update(|acks| {
464 acks.push(ack_id);
465 });
466 }
467
468 pub fn get_acknowledged_messages(&self) -> Vec<u64> {
469 self.acknowledged_messages.get()
470 }
471
472 pub fn get_connection_pool_size(&self) -> usize {
473 1
474 }
475
476 pub fn get_connection_from_pool(&self) -> Option<()> {
477 Some(())
478 }
479
480 pub fn return_connection_to_pool(&self, _connection: ()) -> Result<(), TransportError> {
481 Ok(())
482 }
483}
484
485#[derive(Debug, Clone, PartialEq)]
487pub struct PresenceMap {
488 pub users: HashMap<String, UserPresence>,
489 pub last_updated: Instant,
490}
491
492#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
493pub struct UserPresence {
494 pub user_id: String,
495 pub status: String,
496 pub last_seen: u64,
497}
498
/// Traffic counters for one connection; every counter defaults to zero.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct ConnectionMetrics {
    /// Total payload bytes sent.
    pub bytes_sent: u64,
    /// Total payload bytes received.
    pub bytes_received: u64,
    /// Number of messages sent.
    pub messages_sent: u64,
    /// Number of messages received.
    pub messages_received: u64,
    /// Connection uptime; the unit is not stated in this file.
    pub connection_uptime: u64,
}
508
509pub fn use_websocket(url: &str) -> WebSocketContext {
511 let provider = WebSocketProvider::new(url);
512 WebSocketContext::new(provider)
513}
514
515pub fn use_connection_status(context: &WebSocketContext) -> ReadSignal<ConnectionState> {
517 context.state
518}
519
520pub fn use_connection_metrics(context: &WebSocketContext) -> ReadSignal<ConnectionMetrics> {
522 context.metrics
523}
524
525pub fn use_presence(context: &WebSocketContext) -> ReadSignal<PresenceMap> {
527 context.presence
528}
529
530pub fn use_message_subscription<T>(
532 context: &WebSocketContext,
533) -> Option<ReadSignal<VecDeque<Message>>> {
534 context.subscribe_to_messages::<T>()
535}
536
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint shared by every test; no connection is ever opened to it.
    const TEST_URL: &str = "ws://localhost:8080";

    /// Builds a fresh, disconnected context pointing at [`TEST_URL`].
    fn fresh_context() -> WebSocketContext {
        WebSocketContext::new(WebSocketProvider::new(TEST_URL))
    }

    /// The provider reports the URL it was built with.
    #[test]
    fn test_websocket_provider_creation() {
        let provider = WebSocketProvider::new(TEST_URL);
        assert_eq!(provider.url(), "ws://localhost:8080");
    }

    /// A new context starts out disconnected.
    #[test]
    fn test_websocket_context_creation() {
        let context = fresh_context();

        assert_eq!(context.connection_state(), ConnectionState::Disconnected);
        assert!(!context.is_connected());
    }

    /// The state signal follows each explicit transition.
    #[test]
    fn test_connection_state_transitions() {
        let context = fresh_context();

        // Initial state.
        assert_eq!(context.connection_state(), ConnectionState::Disconnected);

        // Disconnected -> Connecting.
        context.set_connection_state(ConnectionState::Connecting);
        assert_eq!(context.connection_state(), ConnectionState::Connecting);

        // Connecting -> Connected.
        context.set_connection_state(ConnectionState::Connected);
        assert_eq!(context.connection_state(), ConnectionState::Connected);
        assert!(context.is_connected());

        // Connected -> Disconnected.
        context.set_connection_state(ConnectionState::Disconnected);
        assert_eq!(context.connection_state(), ConnectionState::Disconnected);
        assert!(!context.is_connected());
    }
}