// cnctd_server/socket/mod.rs

1pub mod client;
2
3use anyhow::anyhow;
4use client::{ClientFormat, ClientInfo, CnctdClient, QueryParams};
5use cnctd_redis::CnctdRedis;
6use futures_util::{SinkExt, StreamExt};
7use local_ip_address::local_ip;
8use serde::de::DeserializeOwned;
9use serde::Serialize;
10use state::InitCell;
11use warp::filters::ws::Ws;
12use warp::reject::Reject;
13use warp::ws::{Message as WebSocketMessage, WebSocket};
14use warp::Filter;
15use std::collections::HashMap;
16use std::future::Future;
17use std::pin::Pin;
18use tokio::sync::{mpsc, RwLock};
19use std::{sync::Arc, fmt::Debug};
20
21use crate::router::message::Message;
22use crate::router::SocketRouterFunction;
23use crate::server::server_info::ServerInfo;
24
/// Shared callback for binary WebSocket frames.
///
/// Invoked as `handler(client_id, raw_bytes)` for each binary frame a connected client sends.
/// The returned boxed future is awaited inside the connection's receive loop, so the next
/// incoming frame from that client is not read until it completes.
pub type OnBinaryHandler = Arc<
    dyn Fn(String, Vec<u8>) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
28
29#[derive(Debug)]
30struct NoClientId;
31
32impl Reject for NoClientId {}
33
34#[derive(Clone)]
35pub struct SocketConfig<R> {
36    pub router: R,
37    pub secret: Option<Vec<u8>>,
38    pub redis_url: Option<String>,
39    pub on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,
40    pub on_binary: Option<OnBinaryHandler>,
41}
42
43impl<R> SocketConfig<R> {
44    pub fn new(router: R, secret: Option<Vec<u8>>, redis_url: Option<String>, on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,) -> Self {
45        Self {
46            router,
47            secret,
48            redis_url,
49            on_disconnect,
50            on_binary: None,
51        }
52    }
53
54    pub fn with_on_binary(mut self, handler: OnBinaryHandler) -> Self {
55        self.on_binary = Some(handler);
56        self
57    }
58}
59
60
61
62pub static CLIENTS: InitCell<Arc<RwLock<HashMap<String, CnctdClient>>>> = InitCell::new();
63
64pub struct CnctdSocket;
65
66impl CnctdSocket {
67    pub fn build_routes<M, Resp, R>(config: SocketConfig<R>) -> warp::filters::BoxedFilter<(impl warp::Reply,)>
68    where
69        M: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
70        Resp: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static, 
71        R: SocketRouterFunction<M, Resp> + 'static,
72    {
73        CLIENTS.set(Arc::new(RwLock::new(HashMap::new())));
74
75        let redis;
76
77        match config.redis_url {
78            Some(url) => {
79                match cnctd_redis::CnctdRedis::start(&url) {
80                    Ok(_) => {
81                        println!("Redis started!");
82                        tokio::spawn(async {
83                            ServerInfo::set_redis_active(true).await;
84                        });
85                        redis = true
86                    },
87                    Err(e) => {
88                        println!("Error starting Redis pool: {:?}", e);
89                        redis = false
90                    }
91                }
92            }
93            None => redis = false
94        };
95
96        let websocket_route = warp::path("ws")
97            .and(warp::ws())
98            .and(warp::any().map(move || config.router.clone()))
99            .and(warp::query::<QueryParams>())
100            .and_then(move |ws: Ws, router: R, params: QueryParams| {
101                let on_disconnect = config.on_disconnect.clone();
102                let on_binary = config.on_binary.clone();
103
104                async move {
105                    // Resolve client_id: either from query param or via inline registration
106                    let client_id = match params.client_id {
107                        Some(id) => id,
108                        None => {
109                            // Support inline registration: if subscriptions are provided,
110                            // auto-register a new client (used by ESP32 and other lightweight clients
111                            // that don't have an HTTP client for the REST registration step).
112                            if let Some(ref subs_str) = params.subscriptions {
113                                let subscriptions: Vec<String> = subs_str
114                                    .split(',')
115                                    .map(|s| s.trim().to_string())
116                                    .filter(|s| !s.is_empty())
117                                    .collect();
118                                let format = ClientFormat::from_str_opt(params.format.as_deref());
119                                match CnctdClient::register_client_with_format(
120                                    subscriptions,
121                                    None,
122                                    format,
123                                ).await {
124                                    Ok(id) => {
125                                        println!("Inline-registered client: {}", id);
126                                        id
127                                    }
128                                    Err(e) => {
129                                        eprintln!("Inline registration failed: {:?}", e);
130                                        return Err(warp::reject::custom(NoClientId));
131                                    }
132                                }
133                            } else {
134                                return Err(warp::reject::custom(NoClientId));
135                            }
136                        },
137                    };
138
139                    // Proceed with connection setup
140                    Ok(ws.on_upgrade(move |socket| {
141                        Self::handle_connection(socket, router, client_id, redis, on_disconnect.clone(), on_binary)
142                    }))
143                }
144            });
145    
146              
147        let routes = websocket_route;
148
149        routes.boxed()
150
151    }
152    pub async fn start<M, Resp, R>(port: &str, router: R, secret: Option<Vec<u8>>, redis_url: Option<String>, on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,) -> anyhow::Result<()>
153    where
154        M: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
155        Resp: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static, 
156        R: SocketRouterFunction<M, Resp> + 'static,
157    {
158        CLIENTS.set(Arc::new(RwLock::new(HashMap::new())));
159    
160        let my_local_ip = local_ip()?;
161        println!("WebSocket server running at ws://{}:{}", my_local_ip, port);
162        let ip_address: [u8; 4] = [0, 0, 0, 0];
163        let parsed_port = port.parse::<u16>()?;
164        let socket_addr = std::net::SocketAddr::from((ip_address, parsed_port));
165        let config = SocketConfig::new(router, secret, redis_url, on_disconnect);
166        let routes = Self::build_routes(config);
167
168        warp::serve(routes).run(socket_addr).await;
169    
170        Ok(())
171        
172    }
173
174    pub async fn broadcast_message(msg: &Message) -> anyhow::Result<()> {
175        let clients = CLIENTS.try_get().ok_or_else(|| anyhow!("Clients not initialized"))?.read().await;
176        
177        for (client_id, client) in clients.iter() {
178            if client.subscriptions.contains(&msg.channel) {
179                CnctdClient::message_client(&client_id, msg).await?;
180            }
181        }
182    
183        Ok(())
184    }
185
186   
187    
188    async fn handle_connection<M, Resp, R>(
189        websocket: WebSocket,
190        router: R,
191        client_id: String,
192        redis: bool,
193        on_disconnect: Option<Arc<dyn Fn(ClientInfo) + Send + Sync>>,
194        on_binary: Option<OnBinaryHandler>,
195    ) where
196        M: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
197        Resp: Serialize + DeserializeOwned + Send + Sync + Debug + Clone + 'static,
198        R: SocketRouterFunction<M, Resp> + 'static,
199    {
200        let (mut ws_tx, mut ws_rx) = websocket.split();
201        let (resp_tx, mut resp_rx) = mpsc::unbounded_channel::<Result<WebSocketMessage, warp::Error>>();
202    
203        {
204            let clients = CLIENTS.get();
205            let mut clients_lock = clients.write().await;
206    
207            if let Some(client) = clients_lock.get_mut(&client_id.clone()) {
208                // Update the sender for the client
209                client.sender = Some(resp_tx.clone());
210
211                if redis {
212                    match Self::push_client_to_redis(&client_id, &client.clone()).await {
213                        Ok(_) => println!("pajama party"),
214                        Err(e) => eprintln!("Error pushing client to Redis: {:?}", e),
215                    }
216                }
217                println!("Updated client sender: {:?}", client);
218            } else {
219                // Log error or handle case where client_id is not found
220                eprintln!("Client with id {} not found.", client_id);
221                return;
222            }
223        }
224        
225        let client_id_clone = client_id.clone();
226        // Incoming message handling
227        let process_incoming = async move {
228            while let Some(result) = ws_rx.next().await {
229                match result {
230                    Ok(msg) => {
231                        if msg.is_binary() {
232                            if let Some(ref handler) = on_binary {
233                                let bytes = msg.into_bytes();
234                                handler(client_id_clone.clone(), bytes).await;
235                            }
236                        } else if let Ok(message_str) = msg.to_str() {
237                            if let Ok(message) = serde_json::from_str::<M>(message_str) {
238                                match router.route(message, client_id_clone.clone()).await {
239                                    Some(response) => {
240                                        if let Ok(response_str) = serde_json::to_string(&response) {
241                                            let _ = resp_tx.send(Ok(WebSocketMessage::text(response_str)));
242                                        }
243                                    },
244                                    None => {}
245                                }
246                            }
247                        }
248                    },
249                    Err(e) => eprintln!("WebSocket receive error: {:?}", e),
250                }
251            }
252        };
253    
254        // Outgoing message handling
255        let send_responses = async move {
256            while let Some(response) = resp_rx.recv().await {
257                if let Ok(msg) = response {
258                    if ws_tx.send(msg).await.is_err() {
259                        eprintln!("WebSocket send error");
260                        break;
261                    }
262                }
263            }
264        };
265    
266        tokio::select! {
267            _ = process_incoming => {},
268            _ = send_responses => {},
269        };
270
271        if let Some(callback) = on_disconnect {
272            let client_info = CnctdClient::get_client_info(&client_id).await.unwrap();
273            callback(client_info);
274        }
275    
276        // Clean up after disconnection
277        match Self::remove_client(&client_id).await {
278            Ok(_) => {},
279            Err(e) => eprintln!("Error removing client: {:?}", e),
280        };
281
282        if redis {
283            match Self::remove_client_from_redis(&client_id).await {    
284                Ok(_) => {},
285                Err(e) => eprintln!("Error removing client from Redis: {:?}", e),
286            }
287        }
288
289
290        
291    }
292
293    pub async fn remove_client(client_id: &str) -> anyhow::Result<()> {
294        let clients = CLIENTS.try_get().ok_or_else(|| anyhow!("Clients not initialized"))?;
295        let mut clients_lock = clients.write().await;
296    
297        if let Some(client) = clients_lock.get(client_id) {
298            let should_remove = client.sender.as_ref().map_or(true, |sender| sender.is_closed());
299    
300            if should_remove {
301                println!("Removing client: {}", client_id);
302                clients_lock.remove(client_id);
303            } else {
304                println!("Client {} is active; no removal necessary.", client_id);
305            }
306        }
307
308        Ok(())
309    }
310
311
312
313    pub async fn push_client_to_redis(client_id: &str, client: &CnctdClient) -> anyhow::Result<()> {
314        let client_info = client.to_client_info(client_id).await;
315        CnctdRedis::hset("clients", &client_id, client_info)?;
316
317        Ok(())
318    }
319
320    pub async fn remove_client_from_redis(client_id: &str) -> anyhow::Result<()> {
321        CnctdRedis::hset("clients", client_id, ())?;
322
323        Ok(())
324    }
325
326}
327