#![cfg_attr(not(test), no_std)]
use core::clone::Clone;
use core::cmp::{Eq, PartialEq};
use core::default::Default;
use core::fmt::Debug;
use core::format_args;
use core::iter::Iterator;
use core::marker::Copy;
use core::option::Option::{self, *};
use core::prelude::v1::derive;
use core::result::Result::{self, *};
use embassy_futures::select::{select, Either};
use embassy_sync::{
blocking_mutex::raw::CriticalSectionRawMutex,
pubsub::{PubSubChannel, WaitResult},
};
use embassy_time::{with_timeout, Duration};
use embedded_io_async::{Read, Write};
use heapless::{FnvIndexMap, String, Vec};
use serde::{Deserialize, Serialize};
use stackfuture::StackFuture;
#[cfg(feature = "defmt")]
use defmt::*;
pub mod stackfuture;
// ---- Capacity defaults for `RpcServer`'s const generic parameters ----

/// Default for `MAX_CLIENTS`: how many notification subscribers (concurrent `serve` calls) exist.
pub const DEFAULT_MAX_CLIENTS: usize = 4;
/// Default for `MAX_HANDLERS`: capacity of the method-name → handler table.
pub const DEFAULT_MAX_HANDLERS: usize = 8;
/// Default for `MAX_MESSAGE_LEN`: size in bytes of each request, response and notification buffer.
pub const DEFAULT_MAX_MESSAGE_LEN: usize = 1_460;
/// Default for `STACK_SIZE`: the `StackFuture` size parameter used for handler futures.
pub const DEFAULT_STACK_SIZE: usize = 256;

// ---- Timeouts (milliseconds) ----

/// Upper bound for every individual write or flush on the client stream.
pub const DEFAULT_WRITE_TIMEOUT_MS: u64 = 5_000;
/// Upper bound for a single handler invocation.
pub const DEFAULT_HANDLER_TIMEOUT_MS: u64 = 5_000;

// ---- Protocol ----

/// The JSON-RPC protocol version this server accepts in requests and writes in responses.
pub const JSONRPC_VERSION: &str = "2.0";
/// A JSON-RPC 2.0 request (or notification) envelope with typed parameters.
///
/// The string fields borrow from the input they were deserialized from. When
/// serialized, keys are emitted in declaration order (`jsonrpc`, `id`, `method`,
/// `params`) and `None` values are written as `null`, so do not reorder fields.
#[derive(Debug, Deserialize, Serialize)]
pub struct RpcRequest<'a, T> {
/// Protocol version string; the server answers any value other than
/// [`JSONRPC_VERSION`] with an `InvalidRequest` error.
pub jsonrpc: &'a str,
/// Request identifier; the server copies it into its response.
pub id: Option<u64>,
/// Name of the method to invoke.
pub method: &'a str,
/// Method parameters, if any.
pub params: Option<T>,
}
/// The envelope fields of an incoming request.
///
/// Parsed before dispatch so the server can check the version and pick a handler
/// without knowing the method's parameter type. Keys not declared here (such as
/// `params`) are not captured; the handler receives the full request body instead.
#[derive(Debug, Deserialize)]
struct RpcRequestMetadata<'a> {
/// Protocol version string, compared against [`JSONRPC_VERSION`].
pub jsonrpc: &'a str,
/// Request identifier, if any.
pub id: Option<u64>,
/// Name of the method to dispatch on.
pub method: &'a str,
}
/// A JSON-RPC 2.0 response envelope.
///
/// Both `error` and `result` are always serialized, with absent values written as
/// `null`, e.g. `{"jsonrpc":"2.0","id":1,"error":null,"result":null}`. Key order
/// follows field declaration order.
#[derive(Debug, Deserialize, Serialize)]
pub struct RpcResponse<'a, T> {
/// Protocol version string; responses built by the server use [`JSONRPC_VERSION`].
pub jsonrpc: &'a str,
/// Identifier of the request being answered; the server uses `None` when the
/// request could not be parsed.
pub id: Option<u64>,
/// Error details when the call failed.
pub error: Option<RpcError>,
/// Method result when the call succeeded.
pub result: Option<T>,
}
/// The predefined JSON-RPC 2.0 error codes.
///
/// Each variant's discriminant is the numeric code written on the wire (the
/// `Serialize` impl casts the variant to `i32`).
// NOTE(review): the reason for `allow(dead_code)` is not evident here — the enum is
// `pub`; confirm whether the attribute is still needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RpcErrorCode {
/// -32700: the request body is not valid JSON.
ParseError = -32700,
/// -32600: the request is not acceptable (the server uses it for a wrong `jsonrpc` version).
InvalidRequest = -32600,
/// -32601: no handler is registered for the requested method.
MethodNotFound = -32601,
/// -32602: the method parameters are invalid.
InvalidParams = -32602,
/// -32603: an internal error occurred.
InternalError = -32603,
}
impl RpcErrorCode {
pub fn message(self) -> &'static str {
match self {
RpcErrorCode::ParseError => "Invalid JSON.",
RpcErrorCode::InvalidRequest => "Invalid request.",
RpcErrorCode::MethodNotFound => "Method not found.",
RpcErrorCode::InvalidParams => "Invalid parameters.",
RpcErrorCode::InternalError => "Internal error.",
}
}
}
impl Serialize for RpcErrorCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::ser::Serializer,
{
(*self as i32).serialize(serializer)
}
}
impl<'a> Deserialize<'a> for RpcErrorCode {
fn deserialize<D>(deserializer: D) -> Result<RpcErrorCode, D::Error>
where
D: serde::de::Deserializer<'a>,
{
let code = i32::deserialize(deserializer)?;
match code {
-32700 => Ok(RpcErrorCode::ParseError),
-32600 => Ok(RpcErrorCode::InvalidRequest),
-32601 => Ok(RpcErrorCode::MethodNotFound),
-32602 => Ok(RpcErrorCode::InvalidParams),
-32603 => Ok(RpcErrorCode::InternalError),
_ => Err(serde::de::Error::custom("Invalid error code")),
}
}
}
/// Error details carried in the `error` member of a JSON-RPC response.
#[derive(Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct RpcError {
/// Numeric error code, serialized as an integer.
pub code: RpcErrorCode,
/// Human-readable description, limited to 32 bytes.
pub message: String<32>,
}
impl RpcError {
pub fn from_code(code: RpcErrorCode) -> Self {
RpcError {
code,
message: String::try_from(code.message()).unwrap(),
}
}
}
/// Errors returned by [`RpcServer`] methods.
///
/// These are local failures, distinct from JSON-RPC error responses, which are
/// serialized and sent to the client instead of being returned.
#[derive(PartialEq, Eq, Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RpcServerError {
/// A framed message does not fit in the `MAX_MESSAGE_LEN` buffer.
BufferOverflow,
/// Reading from, writing to, or flushing the stream failed.
IoError,
/// The message headers are not UTF-8 or lack a parseable `Content-Length`.
ParseError,
/// `register_method` was called while the handler table was full.
TooManyHandlers,
/// A stream write/flush or a handler did not complete within its timeout.
TimeoutError,
}
/// A method implementation that can be registered with [`RpcServer::register_method`].
///
/// `STACK_SIZE` is the `StackFuture` size parameter of the returned future; it must
/// equal the `STACK_SIZE` of the server the handler is registered with.
pub trait RpcHandler<const STACK_SIZE: usize = DEFAULT_STACK_SIZE>: Sync {
/// Handles one request that was dispatched on `method`.
///
/// * `id` — the request's identifier, if any.
/// * `method` — the method name used for dispatch.
/// * `request_json` — the complete request body (the whole JSON object, not only `params`).
/// * `response_json` — buffer that receives the serialized JSON-RPC response.
///
/// Returns the number of bytes written to `response_json`. On `Err`, the server
/// overwrites `response_json` with an error response built from the returned
/// [`RpcError`] and `id`, and sends that instead.
fn handle<'a>(
&'a self,
id: Option<u64>,
method: &'a str,
request_json: &'a [u8],
response_json: &'a mut [u8],
) -> StackFuture<'a, Result<usize, RpcError>, STACK_SIZE>;
}
/// A JSON-RPC 2.0 server that exchanges `Content-Length`-framed messages over an
/// `embedded_io_async` stream and can broadcast notifications to connected clients.
///
/// * `MAX_CLIENTS` — number of notification subscribers, i.e. concurrent `serve` calls.
/// * `MAX_HANDLERS` — capacity of the method table.
/// * `MAX_MESSAGE_LEN` — size in bytes of each request, response and notification buffer.
/// * `STACK_SIZE` — `StackFuture` size parameter for handler futures.
pub struct RpcServer<
'a,
const MAX_CLIENTS: usize = DEFAULT_MAX_CLIENTS,
const MAX_HANDLERS: usize = DEFAULT_MAX_HANDLERS,
const MAX_MESSAGE_LEN: usize = DEFAULT_MAX_MESSAGE_LEN,
const STACK_SIZE: usize = DEFAULT_STACK_SIZE,
> {
/// Registered handlers keyed by method name.
handlers: FnvIndexMap<&'a str, &'a dyn RpcHandler<STACK_SIZE>, MAX_HANDLERS>,
/// Broadcast channel of already-framed notification messages: queue capacity 1,
/// `MAX_CLIENTS` subscribers and a single publisher.
notifications:
PubSubChannel<CriticalSectionRawMutex, Vec<u8, MAX_MESSAGE_LEN>, 1, MAX_CLIENTS, 1>,
}
impl<
'a,
const MAX_CLIENTS: usize,
const MAX_HANDLERS: usize,
const MAX_MESSAGE_LEN: usize,
const STACK_SIZE: usize,
> Default for RpcServer<'a, MAX_CLIENTS, MAX_HANDLERS, MAX_MESSAGE_LEN, STACK_SIZE>
{
fn default() -> Self {
Self::new()
}
}
impl<
'a,
const MAX_CLIENTS: usize,
const MAX_HANDLERS: usize,
const MAX_MESSAGE_LEN: usize,
const STACK_SIZE: usize,
> RpcServer<'a, MAX_CLIENTS, MAX_HANDLERS, MAX_MESSAGE_LEN, STACK_SIZE>
{
pub fn new() -> Self {
#[cfg(feature = "defmt")]
debug!("Initializing new RPC server");
Self {
handlers: FnvIndexMap::new(),
notifications: PubSubChannel::new(),
}
}
pub fn register_method(
&mut self,
name: &'a str,
handler: &'a dyn RpcHandler<STACK_SIZE>,
) -> Result<(), RpcServerError> {
#[cfg(feature = "defmt")]
debug!("Registering method: {}", name);
if self.handlers.insert(name, handler).is_err() {
#[cfg(feature = "defmt")]
warn!("Failed to register method (too many handlers): {}", name);
return Err(RpcServerError::TooManyHandlers);
}
Ok(())
}
pub async fn notify(&self, notification_json: &[u8]) -> Result<(), RpcServerError> {
#[cfg(feature = "defmt")]
debug!("Broadcasting notification");
let mut headers: String<32> = String::new();
core::fmt::write(
&mut headers,
format_args!("Content-Length: {}\r\n\r\n", notification_json.len()),
)
.unwrap();
if headers.len() + notification_json.len() > MAX_MESSAGE_LEN {
#[cfg(feature = "defmt")]
error!("Broadcast message too large");
return Err(RpcServerError::BufferOverflow);
}
let mut framed_message: heapless::Vec<u8, MAX_MESSAGE_LEN> = heapless::Vec::new();
framed_message
.extend_from_slice(headers.as_bytes())
.unwrap();
framed_message.extend_from_slice(notification_json).unwrap();
let notifications = self.notifications.publisher().unwrap();
notifications.publish(framed_message).await;
Ok(())
}
pub async fn serve<T: Read + Write>(&self, stream: &mut T) -> Result<(), RpcServerError> {
#[cfg(feature = "defmt")]
debug!("Starting RPC server");
let mut notifications = self.notifications.subscriber().unwrap();
let mut request_buffer = [0u8; MAX_MESSAGE_LEN];
let mut response_json = [0u8; MAX_MESSAGE_LEN];
let mut read_offset = 0;
loop {
#[cfg(feature = "defmt")]
debug!("Waiting for data from client");
let result = select(
notifications.next_message(),
stream.read(&mut request_buffer[read_offset..]),
)
.await;
match result {
Either::First(WaitResult::Message(notification_json)) => {
#[cfg(feature = "defmt")]
debug!("Writing notification");
with_timeout(
Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
stream.write_all(¬ification_json),
)
.await
.map_err(|_| RpcServerError::TimeoutError)?
.map_err(|_| RpcServerError::IoError)?;
with_timeout(
Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
stream.flush(),
)
.await
.map_err(|_| RpcServerError::TimeoutError)?
.map_err(|_| RpcServerError::IoError)?;
#[cfg(feature = "defmt")]
debug!("Notification sent to client");
continue;
}
Either::First(WaitResult::Lagged(x)) => {
#[cfg(feature = "defmt")]
warn!("Dropped {} notifications due to lag", x);
}
Either::Second(Ok(0)) => {
#[cfg(feature = "defmt")]
debug!("Client disconnected");
return Ok(());
}
Either::Second(Ok(n)) => {
#[cfg(feature = "defmt")]
debug!("Received {} bytes from client", n);
read_offset += n;
while let Some(headers_len) =
Self::parse_headers(&request_buffer[..read_offset])
{
let content_len =
Self::parse_content_length(&mut request_buffer[..headers_len])?;
let total_message_len = headers_len + content_len;
if read_offset < total_message_len {
#[cfg(feature = "defmt")]
debug!("Incomplete message, waiting for more data");
break;
}
#[cfg(feature = "defmt")]
debug!("Received complete message, handling request");
let request_json = &request_buffer[headers_len..headers_len + content_len];
let response_json_len = self
.handle_request(request_json, &mut response_json)
.await?;
#[cfg(feature = "defmt")]
debug!("Sending response to client");
let mut headers: String<32> = String::new();
core::fmt::write(
&mut headers,
format_args!("Content-Length: {}\r\n\r\n", response_json_len),
)
.unwrap();
if headers.len() + response_json_len > MAX_MESSAGE_LEN {
#[cfg(feature = "defmt")]
error!("Response message too large");
return Err(RpcServerError::BufferOverflow);
}
#[cfg(feature = "defmt")]
debug!("Writing response");
with_timeout(
Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
stream.write_all(headers.as_bytes()),
)
.await
.map_err(|_| RpcServerError::TimeoutError)?
.map_err(|_| RpcServerError::IoError)?;
with_timeout(
Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
stream.write_all(&response_json[..response_json_len]),
)
.await
.map_err(|_| RpcServerError::TimeoutError)?
.map_err(|_| RpcServerError::IoError)?;
with_timeout(
Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
stream.flush(),
)
.await
.map_err(|_| RpcServerError::TimeoutError)?
.map_err(|_| RpcServerError::IoError)?;
#[cfg(feature = "defmt")]
debug!("Response sent to client");
let remaining = read_offset - total_message_len;
request_buffer.copy_within(total_message_len..read_offset, 0);
read_offset = remaining;
}
}
Either::Second(Err(_)) => {
#[cfg(feature = "defmt")]
error!("IO error during stream read");
return Err(RpcServerError::IoError);
}
}
}
}
async fn handle_request(
&self,
request_json: &'a [u8],
response_json: &'a mut [u8],
) -> Result<usize, RpcServerError> {
#[cfg(feature = "defmt")]
debug!("Handling request");
let request: RpcRequestMetadata = match serde_json_core::from_slice(request_json) {
Ok((request, _remainder)) => request,
Err(_) => {
#[cfg(feature = "defmt")]
warn!("Failed to parse request JSON");
let response: RpcResponse<'_, ()> = RpcResponse {
jsonrpc: JSONRPC_VERSION,
error: Some(RpcError::from_code(RpcErrorCode::ParseError)),
id: None,
result: None,
};
return Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap());
}
};
let id = request.id;
if request.jsonrpc != JSONRPC_VERSION {
#[cfg(feature = "defmt")]
warn!("Unsupported JSON-RPC version");
let response: RpcResponse<'_, ()> = RpcResponse {
jsonrpc: JSONRPC_VERSION,
error: Some(RpcError::from_code(RpcErrorCode::InvalidRequest)),
result: None,
id,
};
return Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap());
}
#[cfg(feature = "defmt")]
debug!("Dispatching method: {}", request.method);
match self.handlers.get(request.method) {
Some(handler) => match with_timeout(
Duration::from_millis(DEFAULT_HANDLER_TIMEOUT_MS),
handler.handle(id, request.method, request_json, response_json),
)
.await
.map_err(|_| RpcServerError::TimeoutError)?
{
Ok(response_len) => Ok(response_len),
Err(e) => {
#[cfg(feature = "defmt")]
error!("Handler returned error: {:?}", e);
let response: RpcResponse<'_, ()> = RpcResponse {
jsonrpc: JSONRPC_VERSION,
error: Some(e),
result: None,
id,
};
Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap())
}
},
None => {
#[cfg(feature = "defmt")]
warn!("Method not found: {}", request.method);
let response: RpcResponse<'_, ()> = RpcResponse {
jsonrpc: JSONRPC_VERSION,
error: Some(RpcError::from_code(RpcErrorCode::MethodNotFound)),
result: None,
id,
};
Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap())
}
}
}
fn parse_headers(buffer: &[u8]) -> Option<usize> {
return buffer
.windows(4)
.position(|window| window == b"\r\n\r\n")
.map(|i| i + 4);
}
fn parse_content_length(buffer: &mut [u8]) -> Result<usize, RpcServerError> {
let headers = core::str::from_utf8_mut(buffer).map_err(|_| RpcServerError::ParseError)?;
headers.make_ascii_lowercase();
for line in headers.lines() {
if let Some(value) = line.strip_prefix("content-length:") {
return value.trim().parse().map_err(|_| RpcServerError::ParseError);
}
}
Err(RpcServerError::ParseError)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use memory_pipe::MemoryPipe;
    use std::sync::Arc;

    mod memory_pipe;

    /// A request for a registered method is answered with a framed JSON-RPC response.
    #[tokio::test]
    async fn test_request_response() {
        let mut server: RpcServer<'_> = RpcServer::new();
        server.register_method("echo", &EchoHandler).unwrap();
        let (mut stream1, mut stream2) = MemoryPipe::new();
        tokio::spawn(async move {
            server.serve(&mut stream2).await.unwrap();
        });
        let request: RpcRequest<'_, ()> = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            id: Some(1),
            method: "echo",
            params: None,
        };
        let mut request_json = [0u8; 256];
        let request_len = serde_json_core::to_slice(&request, &mut request_json).unwrap();
        let request_message = format!(
            "Content-Length: {}\r\n\r\n{}",
            request_len,
            core::str::from_utf8(&request_json[..request_len]).unwrap()
        );
        stream1.write_all(request_message.as_bytes()).await.unwrap();
        let mut response_buffer = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let response_len = stream1.read(&mut response_buffer).await.unwrap();
        let response = core::str::from_utf8(&response_buffer[..response_len]).unwrap();
        assert_eq!(
            response,
            "Content-Length: 51\r\n\r\n{\"jsonrpc\":\"2.0\",\"id\":1,\"error\":null,\"result\":null}"
        );
    }

    /// A message published via `RpcServer::notify` is forwarded, framed, to a client.
    #[tokio::test]
    async fn test_notify() {
        let server: Arc<RpcServer<'_>> = Arc::new(RpcServer::new());
        let server_clone = Arc::clone(&server);
        let (mut stream1, mut stream2) = MemoryPipe::new();
        tokio::spawn(async move {
            server_clone.serve(&mut stream2).await.unwrap();
        });
        // Give the spawned `serve` time to take its subscriber before publishing.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let event: RpcRequest<'_, ()> = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: "notify",
            id: None,
            params: None,
        };
        let mut event_json = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let event_len = serde_json_core::to_slice(&event, &mut event_json).unwrap();
        server.notify(&event_json[..event_len]).await.unwrap();
        let mut received = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let received_len = stream1.read(&mut received).await.unwrap();
        let received = core::str::from_utf8(&received[..received_len]).unwrap();
        assert_eq!(
            received,
            "Content-Length: 59\r\n\r\n{\"jsonrpc\":\"2.0\",\"id\":null,\"method\":\"notify\",\"params\":null}",
        );
    }

    /// Test handler that replies with an empty successful response for the request `id`.
    struct EchoHandler;

    impl RpcHandler for EchoHandler {
        fn handle<'a>(
            &self,
            id: Option<u64>,
            _method: &'a str,
            _request_json: &'a [u8],
            response_json: &'a mut [u8],
        ) -> StackFuture<'a, Result<usize, RpcError>, DEFAULT_STACK_SIZE> {
            StackFuture::from(async move {
                let response: RpcResponse<'static, ()> = RpcResponse {
                    jsonrpc: JSONRPC_VERSION,
                    error: None,
                    result: None,
                    id,
                };
                Ok(serde_json_core::to_slice(&response, response_json).unwrap())
            })
        }
    }
}