extern crate alloc;
use alloc::{string::String, vec::Vec};
use embassy_net::Stack;
use embassy_sync::{blocking_mutex::raw::CriticalSectionRawMutex, mutex::Mutex};
use embassy_time::Duration;
use embedded_io_async::Read;
use hashbrown::HashMap;
use lazy_static::lazy_static;
use picoserve::{
extract::FromRequest,
io::embedded_io_async as embedded_aio,
request::{RequestBody, RequestParts},
response::{
ws::{Message, ReadMessageError, SocketRx, SocketTx, WebSocketCallback, WebSocketUpgrade},
StatusCode,
},
url_encoded::deserialize_form,
Router,
};
use serde::Deserialize;
use crate::utils::{
controllers::{SystemCommand, I2C_CHANNEL, LED_CHANNEL},
frontend::{CSS, HTML, JAVA},
};
/// [`picoserve::Timer`] implementation backed by `embassy_time`.
pub struct ServerTimer;

/// WebSocket connection handler handed to `on_upgrade`; it carries no per-connection state.
pub struct WebSocket;

/// Bookkeeping stored for every known session id.
#[derive(Debug, Clone)]
pub struct SessionState {
    /// Timestamp of the most recent activity recorded for the session. The `/ws` handler fills
    /// it with `embassy_time::Instant::now().as_secs()`.
    pub last_seen: u64,
}

/// Namespace type for the async helpers that read and mutate [`SESSION_STORE`].
pub struct SessionManager;
lazy_static! {
    /// Global map from session id to its [`SessionState`], guarded by an `embassy_sync`
    /// async mutex (callers lock it with `.lock().await`).
    // Initialised lazily, presumably because building the `HashMap` is not a `const`
    // operation and so cannot appear in a plain `static`.
    pub static ref SESSION_STORE: Mutex<CriticalSectionRawMutex, HashMap<String, SessionState>> =
        Mutex::new(HashMap::default());
}
// The associated types and the timeout call are spelled with explicit `embassy_time::` paths
// (even though `Duration` is imported) so the mapping is unambiguous; silence the lint for it.
#[allow(unused_qualifications)]
impl picoserve::Timer for ServerTimer {
    type Duration = embassy_time::Duration;
    type TimeoutError = embassy_time::TimeoutError;

    /// Races `future` against `duration` via `embassy_time::with_timeout`.
    ///
    /// Returns the future's output, or `Err(TimeoutError)` if the deadline elapses first.
    async fn run_with_timeout<F>(
        &mut self,
        duration: Self::Duration,
        future: F,
    ) -> Result<F::Output, Self::TimeoutError>
    where
        F: core::future::Future,
    {
        embassy_time::with_timeout(duration, future).await
    }
}
impl WebSocketCallback for WebSocket {
    /// Drives one WebSocket connection until the peer closes it or a read error occurs.
    ///
    /// After greeting the client with `"Connected"`, every text or binary frame is decoded as
    /// a JSON [`SystemCommand`]. The command is forwarded to `I2C_CHANNEL` or `LED_CHANNEL`,
    /// and an acknowledgement (or `"Invalid command format"`) is sent back in the frame type
    /// the request arrived in. Pings are answered with pongs, and pongs are ignored.
    ///
    /// Protocol-level read errors close the socket with the matching close code (1007, 1003 or
    /// 1002). I/O errors are returned immediately without sending a close frame.
    ///
    /// NOTE(review): `WebSocket` carries no session id, so the session inserted by the `/ws`
    /// handler is not removed here when the socket closes.
    async fn run<Reader, Writer>(
        self,
        mut rx: SocketRx<Reader>,
        mut tx: SocketTx<Writer>,
    ) -> Result<(), Writer::Error>
    where
        Reader: embedded_aio::Read,
        Writer: embedded_aio::Write<Error = Reader::Error>,
    {
        // Sends `command` to its hardware channel and returns the acknowledgement text.
        async fn forward(command: SystemCommand) -> &'static str {
            match command {
                SystemCommand::I(i2c_cmd) => {
                    I2C_CHANNEL.send(i2c_cmd).await;
                    "I2C command received and forwarded"
                }
                SystemCommand::L(led_cmd) => {
                    LED_CHANNEL.send(led_cmd).await;
                    "LED command received and forwarded"
                }
            }
        }

        let mut frame_buf = [0; 1024];
        tx.send_text("Connected").await?;

        let close_reason = loop {
            let message = match rx.next_message(&mut frame_buf).await {
                Ok(message) => message,
                Err(error) => {
                    tracing::error!(?error, "websocket error");
                    let code = match error {
                        ReadMessageError::TextIsNotUtf8 => 1007,
                        ReadMessageError::ReservedOpcode(_) => 1003,
                        ReadMessageError::ReadFrameError(_)
                        | ReadMessageError::UnexpectedMessageStart
                        | ReadMessageError::MessageStartsWithContinuation => 1002,
                        // Transport failure: nothing sensible can be sent, bail out directly.
                        ReadMessageError::Io(err) => return Err(err),
                    };
                    break Some((code, "Websocket Error"));
                }
            };

            match message {
                Message::Pong(_) => {}
                Message::Ping(payload) => tx.send_pong(payload).await?,
                Message::Close(reason) => {
                    tracing::info!(?reason, "websocket closed");
                    break None;
                }
                Message::Text(text) => match serde_json::from_str::<SystemCommand>(text) {
                    Ok(command) => tx.send_text(forward(command).await).await?,
                    Err(error) => {
                        tracing::error!(?error, "error deserializing SystemCommand");
                        tx.send_text("Invalid command format").await?
                    }
                },
                Message::Binary(bytes) => match serde_json::from_slice::<SystemCommand>(bytes) {
                    Ok(command) => tx.send_binary(forward(command).await.as_bytes()).await?,
                    Err(error) => {
                        tracing::error!(?error, "error deserializing incoming message");
                        tx.send_binary(b"Invalid command format").await?
                    }
                },
            }
        };

        tx.close(close_reason).await
    }
}
// Not every helper is called from this module yet; keep the full API without dead-code warnings.
#[allow(dead_code)]
impl SessionManager {
    /// Inserts an entry for `session_id` stamped with `timestamp`, replacing any existing one.
    pub async fn create_session(session_id: String, timestamp: u64) {
        let state = SessionState {
            last_seen: timestamp,
        };
        let mut store = SESSION_STORE.lock().await;
        store.insert(session_id, state);
    }

    /// Returns a copy of the state stored for `session_id`, or `None` if it is unknown.
    pub async fn get_session(session_id: &str) -> Option<SessionState> {
        let store = SESSION_STORE.lock().await;
        store.get(session_id).cloned()
    }

    /// Sets `last_seen` of an existing session to `timestamp`.
    ///
    /// Returns `true` if the session existed and was updated, `false` otherwise.
    pub async fn update_session(session_id: &str, timestamp: u64) -> bool {
        let mut store = SESSION_STORE.lock().await;
        match store.get_mut(session_id) {
            Some(state) => {
                state.last_seen = timestamp;
                true
            }
            None => false,
        }
    }

    /// Deletes `session_id`; returns `true` if an entry was actually removed.
    pub async fn remove_session(session_id: &str) -> bool {
        let mut store = SESSION_STORE.lock().await;
        store.remove(session_id).is_some()
    }

    /// Drops every session whose `last_seen` is strictly older than `threshold`.
    pub async fn purge_stale_sessions(threshold: u64) {
        let mut store = SESSION_STORE.lock().await;
        store.retain(|_, state| state.last_seen >= threshold);
    }

    /// Returns the ids of all currently stored sessions, in map iteration order.
    pub async fn list_sessions() -> Vec<String> {
        let store = SESSION_STORE.lock().await;
        store.keys().cloned().collect()
    }
}
/// Serves the web UI and the `/ws` command endpoint on `port`; never returns.
///
/// * `id` – identifier forwarded to `picoserve::listen_and_serve_with_state`.
/// * `port` – TCP port to listen on.
/// * `stack` – network stack the listener runs on.
/// * `config` – optional server configuration. When `None`, the defaults are a 5 s
///   start-read timeout, no persistent start-read timeout, a 1 s request-read timeout and a
///   5 s write timeout.
///
/// `/`, `/style.css` and `/script.js` serve `HTML`, `CSS` and `JAVA` with
/// `Content-Encoding: gzip`. Those payloads are therefore assumed to be pre-compressed
/// (TODO confirm in `utils::frontend`).
pub async fn run(
    id: usize,
    port: u16,
    stack: Stack<'static>,
    config: Option<&'static picoserve::Config<Duration>>,
) -> ! {
    let default_config = picoserve::Config::new(picoserve::Timeouts {
        start_read_request: Some(Duration::from_secs(5)),
        persistent_start_read_request: None,
        read_request: Some(Duration::from_secs(1)),
        write: Some(Duration::from_secs(5)),
    });
    let config = config.unwrap_or(&default_config);

    let router = Router::new()
        .route(
            "/",
            picoserve::routing::get(|| async {
                picoserve::response::Response::new(StatusCode::OK, HTML).with_headers([
                    ("Content-Type", "text/html; charset=utf-8"),
                    ("Content-Encoding", "gzip"),
                ])
            }),
        )
        .route(
            "/style.css",
            picoserve::routing::get(|| async {
                picoserve::response::Response::new(StatusCode::OK, CSS).with_headers([
                    ("Content-Type", "text/css; charset=utf-8"),
                    ("Content-Encoding", "gzip"),
                ])
            }),
        )
        .route(
            "/script.js",
            picoserve::routing::get(|| async {
                picoserve::response::Response::new(StatusCode::OK, JAVA).with_headers([
                    ("Content-Type", "application/javascript; charset=utf-8"),
                    ("Content-Encoding", "gzip"),
                ])
            }),
        )
        .route(
            "/ws",
            picoserve::routing::get(|params: WsConnectionParams| async move {
                let session_id = params.query.session;
                tracing::info!("New WebSocket connection with session id: {}", session_id);
                let now = embassy_time::Instant::now().as_secs();
                // `session_id` is not used after this call, so move it instead of cloning
                // (saves a heap allocation per connection).
                // NOTE(review): nothing visible removes the session when the socket closes,
                // so the store grows with every distinct id. Consider
                // `SessionManager::remove_session` / `purge_stale_sessions`.
                SessionManager::create_session(session_id, now).await;
                params
                    .upgrade
                    .on_upgrade(WebSocket)
                    .with_protocol("messages")
            }),
        );

    if let Some(ip_cfg) = stack.config_v4() {
        tracing::info!("Starting server at {}:{}", ip_cfg.address, port);
    } else {
        tracing::warn!(
            "Starting WebSocket server on port {port}, but no IPv4 address is assigned yet!"
        );
    }

    let (mut rx_buffer, mut tx_buffer, mut http_buffer) = ([0; 1024], [0; 1024], [0; 4096]);
    picoserve::listen_and_serve_with_state(
        id,
        &router,
        config,
        stack,
        port,
        &mut rx_buffer,
        &mut tx_buffer,
        &mut http_buffer,
        &(),
    )
    .await
}
/// Query-string parameters accepted by the `/ws` endpoint, e.g. `/ws?session=abc123`.
#[derive(Debug, Deserialize)]
pub struct QueryParams {
    /// Session identifier supplied by the client; the extractor rejects an empty value.
    session: String,
}

/// Extractor for `/ws` that bundles the WebSocket upgrade with the parsed query string.
pub struct WsConnectionParams {
    /// Upgrade handle used to switch the connection to the WebSocket protocol.
    pub upgrade: WebSocketUpgrade,
    /// Parsed `/ws` query parameters.
    pub query: QueryParams,
}
impl<'r, S> FromRequest<'r, S> for WsConnectionParams {
type Rejection = &'static str;
async fn from_request<R: Read>(
state: &'r S,
parts: RequestParts<'r>,
body: RequestBody<'r, R>,
) -> Result<Self, Self::Rejection> {
let upgrade = WebSocketUpgrade::from_request(state, parts.clone(), body)
.await
.map_err(|_| "Failed to extract WebSocketUpgrade")?;
let query_str = parts.query().ok_or("Missing query parameters")?;
let query =
deserialize_form::<QueryParams>(query_str).map_err(|_| "Invalid query parameters")?;
if query.session.is_empty() {
return Err("Session ID is required");
}
Ok(WsConnectionParams { upgrade, query })
}
}