owb_core/utils/connection/
server.rs1extern crate alloc;
9
10use alloc::{string::String, vec::Vec};
11
12use embassy_net::Stack;
13use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex};
14use embassy_time::Duration;
15use embedded_io_async::Read;
16use hashbrown::HashMap;
17use lazy_static::lazy_static;
18use picoserve::{
19 extract::FromRequest,
20 io::embedded_io_async as embedded_aio,
21 request::{RequestBody, RequestParts},
22 response::{
23 ws::{Message, ReadMessageError, SocketRx, SocketTx, WebSocketCallback, WebSocketUpgrade},
24 StatusCode,
25 },
26 url_encoded::deserialize_form,
27 Router,
28};
29use serde::Deserialize;
30
31use crate::utils::{
32 controllers::{SystemCommand, I2C_CHANNEL, LED_CHANNEL},
33 frontend::{CSS, HTML, JAVA},
34};
35
/// Implements `picoserve::Timer` on top of `embassy_time`.
pub struct ServerTimer;

/// WebSocket handler handed to `on_upgrade` by the `/ws` route.
pub struct WebSocket;

/// Namespace for async helpers that operate on the global `SESSION_STORE`.
pub struct SessionManager;

/// Per-session bookkeeping kept in `SESSION_STORE`.
#[derive(Clone, Debug)]
pub struct SessionState {
    /// Timestamp of the last recorded activity, exactly as supplied by the caller
    /// (the `/ws` route passes `embassy_time::Instant::now().as_secs()`).
    pub last_seen: u64,
}
43
44lazy_static! {
45 pub static ref SESSION_STORE: Mutex<CriticalSectionRawMutex, HashMap<String, SessionState>> =
46 Mutex::new(HashMap::new());
47}
48
49#[allow(unused_qualifications)]
51impl picoserve::Timer for ServerTimer {
52 type Duration = embassy_time::Duration;
53 type TimeoutError = embassy_time::TimeoutError;
54
55 async fn run_with_timeout<F: core::future::Future>(
58 &mut self,
59 duration: Self::Duration,
60 future: F,
61 ) -> Result<F::Output, Self::TimeoutError> {
62 embassy_time::with_timeout(duration, future).await
63 }
64}
65
66impl WebSocketCallback for WebSocket {
68 async fn run<Reader, Writer>(
69 self,
70 mut rx: SocketRx<Reader>,
71 mut tx: SocketTx<Writer>,
72 ) -> Result<(), Writer::Error>
73 where
74 Reader: embedded_aio::Read,
75 Writer: embedded_aio::Write<Error = Reader::Error>,
76 {
77 let mut buffer = [0; 1024];
78
79 tx.send_text("Connected").await?;
80
81 let close_reason = loop {
82 match rx.next_message(&mut buffer).await {
83 Ok(Message::Pong(_)) => continue,
84 Ok(Message::Ping(data)) => tx.send_pong(data).await?,
85 Ok(Message::Close(reason)) => {
86 tracing::info!(?reason, "websocket closed");
87 break None;
88 }
89 Ok(Message::Text(data)) => match serde_json::from_str::<SystemCommand>(data) {
90 Ok(SystemCommand::I(i2c_cmd)) => {
91 I2C_CHANNEL.send(i2c_cmd).await;
92 tx.send_text("I2C command received and forwarded").await?;
93 }
94 Ok(SystemCommand::L(led_cmd)) => {
95 LED_CHANNEL.send(led_cmd).await;
96 tx.send_text("LED command received and forwarded").await?;
97 }
98 Err(error) => {
99 tracing::error!(?error, "error deserializing SystemCommand");
100 tx.send_text("Invalid command format").await?
101 }
102 },
103 Ok(Message::Binary(data)) => match serde_json::from_slice::<SystemCommand>(data) {
104 Ok(SystemCommand::I(i2c_cmd)) => {
105 I2C_CHANNEL.send(i2c_cmd).await;
106 tx.send_binary(b"I2C command received and forwarded")
107 .await?
108 }
109 Ok(SystemCommand::L(led_cmd)) => {
110 LED_CHANNEL.send(led_cmd).await;
111 tx.send_binary(b"LED command received and forwarded")
112 .await?
113 }
114 Err(error) => {
115 tracing::error!(?error, "error deserializing incoming message");
116 tx.send_binary(b"Invalid command format").await?
117 }
118 },
119 Err(error) => {
120 tracing::error!(?error, "websocket error");
121 let code = match error {
122 ReadMessageError::TextIsNotUtf8 => 1007,
123 ReadMessageError::ReservedOpcode(_) => 1003,
124 ReadMessageError::ReadFrameError(_)
125 | ReadMessageError::UnexpectedMessageStart
126 | ReadMessageError::MessageStartsWithContinuation => 1002,
127 ReadMessageError::Io(err) => return Err(err),
128 };
129 break Some((code, "Websocket Error"));
130 }
131 };
132 };
133
134 tx.close(close_reason).await
135 }
136}
137
138#[allow(dead_code)]
139impl SessionManager {
140 pub async fn create_session(
142 session_id: String,
143 timestamp: u64,
144 ) {
145 SESSION_STORE.lock().await.insert(
146 session_id,
147 SessionState {
148 last_seen: timestamp,
149 },
150 );
151 }
152
153 pub async fn get_session(session_id: &str) -> Option<SessionState> {
156 SESSION_STORE.lock().await.get(session_id).cloned()
157 }
158
159 pub async fn update_session(
164 session_id: &str,
165 timestamp: u64,
166 ) -> bool {
167 if let Some(session) = SESSION_STORE.lock().await.get_mut(session_id) {
168 session.last_seen = timestamp;
169 true
170 } else {
171 false
172 }
173 }
174
175 pub async fn remove_session(session_id: &str) -> bool {
178 SESSION_STORE.lock().await.remove(session_id).is_some()
179 }
180
181 pub async fn purge_stale_sessions(threshold: u64) {
187 SESSION_STORE
189 .lock()
190 .await
191 .retain(|_id, session| session.last_seen >= threshold);
192 }
193
194 pub async fn list_sessions() -> Vec<String> {
196 SESSION_STORE.lock().await.keys().cloned().collect()
197 }
198}
199
200pub async fn run(
207 id: usize,
208 port: u16,
209 stack: Stack<'static>,
210 config: Option<&'static picoserve::Config<Duration>>,
211) -> ! {
212 let default_config = picoserve::Config::new(picoserve::Timeouts {
213 start_read_request: Some(Duration::from_secs(5)),
214 persistent_start_read_request: None,
215 read_request: Some(Duration::from_secs(1)),
216 write: Some(Duration::from_secs(5)),
217 });
218
219 let config = config.unwrap_or(&default_config);
220
221 let router = Router::new()
222 .route(
224 "/",
225 picoserve::routing::get(|| async {
226 picoserve::response::Response::new(
228 StatusCode::OK,
229 HTML, )
231 .with_headers([
232 ("Content-Type", "text/html; charset=utf-8"),
233 ("Content-Encoding", "gzip"),
234 ])
235 }),
236 )
237 .route(
239 "/style.css",
240 picoserve::routing::get(|| async {
241 picoserve::response::Response::new(
243 StatusCode::OK,
244 CSS, )
246 .with_headers([
247 ("Content-Type", "text/css; charset=utf-8"),
248 ("Content-Encoding", "gzip"),
249 ])
250 }),
251 )
252 .route(
254 "/script.js",
255 picoserve::routing::get(|| async {
256 picoserve::response::Response::new(
258 StatusCode::OK,
259 JAVA, )
261 .with_headers([
262 ("Content-Type", "application/javascript; charset=utf-8"),
263 ("Content-Encoding", "gzip"),
264 ])
265 }),
266 )
267 .route(
269 "/ws",
270 picoserve::routing::get(|params: WsConnectionParams| async move {
271 let session_id = params.query.session;
272 tracing::info!("New WebSocket connection with session id: {}", session_id);
273 let now = embassy_time::Instant::now().as_secs();
274 SessionManager::create_session(session_id.clone(), now).await;
275 params
276 .upgrade
277 .on_upgrade(WebSocket)
278 .with_protocol("messages")
279 }),
280 );
281
282 if let Some(ip_cfg) = stack.config_v4() {
284 tracing::info!("Starting server at {}:{}", ip_cfg.address, port);
285 } else {
286 tracing::warn!(
287 "Starting WebSocket server on port {port}, but no IPv4 address is assigned yet!"
288 );
289 }
290
291 let (mut rx_buffer, mut tx_buffer, mut http_buffer) = ([0; 1024], [0; 1024], [0; 4096]);
292
293 picoserve::listen_and_serve_with_state(
294 id,
295 &router,
296 config,
297 stack,
298 port,
299 &mut rx_buffer,
300 &mut tx_buffer,
301 &mut http_buffer,
302 &(),
303 )
304 .await
305}
306
307#[derive(Debug, Deserialize)]
308pub struct QueryParams {
309 session: String,
310}
311
312pub struct WsConnectionParams {
313 pub upgrade: WebSocketUpgrade,
314 pub query: QueryParams,
315}
316
317impl<'r, S> FromRequest<'r, S> for WsConnectionParams {
318 type Rejection = &'static str; async fn from_request<R: Read>(
321 state: &'r S,
322 parts: RequestParts<'r>,
323 body: RequestBody<'r, R>,
324 ) -> Result<Self, Self::Rejection> {
325 let upgrade = WebSocketUpgrade::from_request(state, parts.clone(), body)
327 .await
328 .map_err(|_| "Failed to extract WebSocketUpgrade")?;
329
330 let query_str = parts.query().ok_or("Missing query parameters")?;
332 let query =
333 deserialize_form::<QueryParams>(query_str).map_err(|_| "Invalid query parameters")?;
334
335 if query.session.is_empty() {
336 return Err("Session ID is required");
337 }
338
339 Ok(WsConnectionParams { upgrade, query })
340 }
341}