owb_core/utils/connection/
server.rs

//! WebSocket Server Module
//!
//! This module defines the WebSocket server implementation using the
//! `picoserve` framework. It serves the static frontend (HTML, CSS and
//! JavaScript), manages incoming WebSocket connections, decodes JSON-encoded
//! I2C and LED commands, and forwards them to the embedded control system
//! through channel interfaces.
7
8extern crate alloc;
9
10use alloc::{string::String, vec::Vec};
11
12use embassy_net::Stack;
13use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex};
14use embassy_time::Duration;
15use embedded_io_async::Read;
16use hashbrown::HashMap;
17use lazy_static::lazy_static;
18use picoserve::{
19    extract::FromRequest,
20    io::embedded_io_async as embedded_aio,
21    request::{RequestBody, RequestParts},
22    response::{
23        ws::{Message, ReadMessageError, SocketRx, SocketTx, WebSocketCallback, WebSocketUpgrade},
24        StatusCode,
25    },
26    url_encoded::deserialize_form,
27    Router,
28};
29use serde::Deserialize;
30
31use crate::utils::{
32    controllers::{SystemCommand, I2C_CHANNEL, LED_CHANNEL},
33    frontend::{CSS, HTML, JAVA},
34};
35
/// Timer used by `picoserve` to bound request reads and writes.
pub struct ServerTimer;

/// Stateless handler invoked once a `/ws` request has been upgraded to a
/// WebSocket connection.
pub struct WebSocket;

/// Book-keeping stored for a single client session.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Timestamp of the session's most recent activity. The unit is whatever
    /// the caller supplies; the `/ws` route in this file passes
    /// `embassy_time::Instant::now().as_secs()`.
    pub last_seen: u64,
}

/// Zero-sized namespace grouping the async helpers that operate on the global
/// session table.
pub struct SessionManager;
43
44lazy_static! {
45    pub static ref SESSION_STORE: Mutex<CriticalSectionRawMutex, HashMap<String, SessionState>> =
46        Mutex::new(HashMap::new());
47}
48
49/// Manages timeouts for the WebSocket server.
50#[allow(unused_qualifications)]
51impl picoserve::Timer for ServerTimer {
52    type Duration = embassy_time::Duration;
53    type TimeoutError = embassy_time::TimeoutError;
54
55    //noinspection ALL
56    /// Runs a future with a timeout.
57    async fn run_with_timeout<F: core::future::Future>(
58        &mut self,
59        duration: Self::Duration,
60        future: F,
61    ) -> Result<F::Output, Self::TimeoutError> {
62        embassy_time::with_timeout(duration, future).await
63    }
64}
65
66/// Handles incoming WebSocket connections.
67impl WebSocketCallback for WebSocket {
68    async fn run<Reader, Writer>(
69        self,
70        mut rx: SocketRx<Reader>,
71        mut tx: SocketTx<Writer>,
72    ) -> Result<(), Writer::Error>
73    where
74        Reader: embedded_aio::Read,
75        Writer: embedded_aio::Write<Error = Reader::Error>,
76    {
77        let mut buffer = [0; 1024];
78
79        tx.send_text("Connected").await?;
80
81        let close_reason = loop {
82            match rx.next_message(&mut buffer).await {
83                Ok(Message::Pong(_)) => continue,
84                Ok(Message::Ping(data)) => tx.send_pong(data).await?,
85                Ok(Message::Close(reason)) => {
86                    tracing::info!(?reason, "websocket closed");
87                    break None;
88                }
89                Ok(Message::Text(data)) => match serde_json::from_str::<SystemCommand>(data) {
90                    Ok(SystemCommand::I(i2c_cmd)) => {
91                        I2C_CHANNEL.send(i2c_cmd).await;
92                        tx.send_text("I2C command received and forwarded").await?;
93                    }
94                    Ok(SystemCommand::L(led_cmd)) => {
95                        LED_CHANNEL.send(led_cmd).await;
96                        tx.send_text("LED command received and forwarded").await?;
97                    }
98                    Err(error) => {
99                        tracing::error!(?error, "error deserializing SystemCommand");
100                        tx.send_text("Invalid command format").await?
101                    }
102                },
103                Ok(Message::Binary(data)) => match serde_json::from_slice::<SystemCommand>(data) {
104                    Ok(SystemCommand::I(i2c_cmd)) => {
105                        I2C_CHANNEL.send(i2c_cmd).await;
106                        tx.send_binary(b"I2C command received and forwarded")
107                            .await?
108                    }
109                    Ok(SystemCommand::L(led_cmd)) => {
110                        LED_CHANNEL.send(led_cmd).await;
111                        tx.send_binary(b"LED command received and forwarded")
112                            .await?
113                    }
114                    Err(error) => {
115                        tracing::error!(?error, "error deserializing incoming message");
116                        tx.send_binary(b"Invalid command format").await?
117                    }
118                },
119                Err(error) => {
120                    tracing::error!(?error, "websocket error");
121                    let code = match error {
122                        ReadMessageError::TextIsNotUtf8 => 1007,
123                        ReadMessageError::ReservedOpcode(_) => 1003,
124                        ReadMessageError::ReadFrameError(_)
125                        | ReadMessageError::UnexpectedMessageStart
126                        | ReadMessageError::MessageStartsWithContinuation => 1002,
127                        ReadMessageError::Io(err) => return Err(err),
128                    };
129                    break Some((code, "Websocket Error"));
130                }
131            };
132        };
133
134        tx.close(close_reason).await
135    }
136}
137
138#[allow(dead_code)]
139impl SessionManager {
140    /// Creates a new session with the given session ID and timestamp.
141    pub async fn create_session(
142        session_id: String,
143        timestamp: u64,
144    ) {
145        SESSION_STORE.lock().await.insert(
146            session_id,
147            SessionState {
148                last_seen: timestamp,
149            },
150        );
151    }
152
153    /// Retrieves a copy of the session state for the given session ID.
154    /// Returns None if the session does not exist.
155    pub async fn get_session(session_id: &str) -> Option<SessionState> {
156        SESSION_STORE.lock().await.get(session_id).cloned()
157    }
158
159    //noinspection ALL
160    //noinspection ALL
161    /// Updates the last seen timestamp of the session identified by session_id.
162    /// Returns true if the session was found and updated.
163    pub async fn update_session(
164        session_id: &str,
165        timestamp: u64,
166    ) -> bool {
167        if let Some(session) = SESSION_STORE.lock().await.get_mut(session_id) {
168            session.last_seen = timestamp;
169            true
170        } else {
171            false
172        }
173    }
174
175    /// Removes the session identified by session_id.
176    /// Returns true if a session was removed.
177    pub async fn remove_session(session_id: &str) -> bool {
178        SESSION_STORE.lock().await.remove(session_id).is_some()
179    }
180
181    //noinspection ALL
182    //noinspection ALL
183    /// Purges sessions that have not been updated since the provided threshold.
184    /// For example, pass in a timestamp and any session with last_seen less
185    /// than that value will be removed.
186    pub async fn purge_stale_sessions(threshold: u64) {
187        // Retain sessions that have a last_seen timestamp >= threshold.
188        SESSION_STORE
189            .lock()
190            .await
191            .retain(|_id, session| session.last_seen >= threshold);
192    }
193
194    /// Returns a list of active session IDs.
195    pub async fn list_sessions() -> Vec<String> {
196        SESSION_STORE.lock().await.keys().cloned().collect()
197    }
198}
199
200//noinspection ALL
201//noinspection ALL
202//noinspection ALL
203//noinspection ALL
204//noinspection ALL
205/// Creates WS Server
206pub async fn run(
207    id: usize,
208    port: u16,
209    stack: Stack<'static>,
210    config: Option<&'static picoserve::Config<Duration>>,
211) -> ! {
212    let default_config = picoserve::Config::new(picoserve::Timeouts {
213        start_read_request: Some(Duration::from_secs(5)),
214        persistent_start_read_request: None,
215        read_request: Some(Duration::from_secs(1)),
216        write: Some(Duration::from_secs(5)),
217    });
218
219    let config = config.unwrap_or(&default_config);
220
221    let router = Router::new()
222        // Serve the HTML file at "/"
223        .route(
224            "/",
225            picoserve::routing::get(|| async {
226                // Serve HTML content
227                picoserve::response::Response::new(
228                    StatusCode::OK,
229                    HTML, // Static HTML content
230                )
231                .with_headers([
232                    ("Content-Type", "text/html; charset=utf-8"),
233                    ("Content-Encoding", "gzip"),
234                ])
235            }),
236        )
237        // Serve the CSS file at "/style.css"
238        .route(
239            "/style.css",
240            picoserve::routing::get(|| async {
241                // Serve CSS content
242                picoserve::response::Response::new(
243                    StatusCode::OK,
244                    CSS, // Static CSS content
245                )
246                .with_headers([
247                    ("Content-Type", "text/css; charset=utf-8"),
248                    ("Content-Encoding", "gzip"),
249                ])
250            }),
251        )
252        // Serve the JS file at "/script.js"
253        .route(
254            "/script.js",
255            picoserve::routing::get(|| async {
256                // Serve JS content
257                picoserve::response::Response::new(
258                    StatusCode::OK,
259                    JAVA, // Static JS content
260                )
261                .with_headers([
262                    ("Content-Type", "application/javascript; charset=utf-8"),
263                    ("Content-Encoding", "gzip"),
264                ])
265            }),
266        )
267        // WebSocket communication on "/ws"
268        .route(
269            "/ws",
270            picoserve::routing::get(|params: WsConnectionParams| async move {
271                let session_id = params.query.session;
272                tracing::info!("New WebSocket connection with session id: {}", session_id);
273                let now = embassy_time::Instant::now().as_secs();
274                SessionManager::create_session(session_id.clone(), now).await;
275                params
276                    .upgrade
277                    .on_upgrade(WebSocket)
278                    .with_protocol("messages")
279            }),
280        );
281
282    // Print out the IP and port before starting the server.
283    if let Some(ip_cfg) = stack.config_v4() {
284        tracing::info!("Starting server at {}:{}", ip_cfg.address, port);
285    } else {
286        tracing::warn!(
287            "Starting WebSocket server on port {port}, but no IPv4 address is assigned yet!"
288        );
289    }
290
291    let (mut rx_buffer, mut tx_buffer, mut http_buffer) = ([0; 1024], [0; 1024], [0; 4096]);
292
293    picoserve::listen_and_serve_with_state(
294        id,
295        &router,
296        config,
297        stack,
298        port,
299        &mut rx_buffer,
300        &mut tx_buffer,
301        &mut http_buffer,
302        &(),
303    )
304    .await
305}
306
307#[derive(Debug, Deserialize)]
308pub struct QueryParams {
309    session: String,
310}
311
312pub struct WsConnectionParams {
313    pub upgrade: WebSocketUpgrade,
314    pub query: QueryParams,
315}
316
317impl<'r, S> FromRequest<'r, S> for WsConnectionParams {
318    type Rejection = &'static str; // Or a custom error type
319
320    async fn from_request<R: Read>(
321        state: &'r S,
322        parts: RequestParts<'r>,
323        body: RequestBody<'r, R>,
324    ) -> Result<Self, Self::Rejection> {
325        // First extract the WebSocketUpgrade as usual.
326        let upgrade = WebSocketUpgrade::from_request(state, parts.clone(), body)
327            .await
328            .map_err(|_| "Failed to extract WebSocketUpgrade")?;
329
330        // Then extract the query string for QueryParams.
331        let query_str = parts.query().ok_or("Missing query parameters")?;
332        let query =
333            deserialize_form::<QueryParams>(query_str).map_err(|_| "Invalid query parameters")?;
334
335        if query.session.is_empty() {
336            return Err("Session ID is required");
337        }
338
339        Ok(WsConnectionParams { upgrade, query })
340    }
341}