1#![cfg_attr(not(test), no_std)]
78
79use crate::glob_match::glob_match;
80use crate::stackfuture::StackFuture;
81use core::clone::Clone;
82use core::cmp::{Eq, PartialEq};
83use core::default::Default;
84use core::fmt::Debug;
85use core::format_args;
86use core::iter::Iterator;
87use core::marker::Copy;
88use core::option::Option::{self, *};
89use core::prelude::v1::derive;
90use core::result::Result::{self, *};
91use embassy_futures::select::{select, Either};
92use embassy_sync::{
93 blocking_mutex::raw::CriticalSectionRawMutex,
94 mutex::Mutex,
95 pubsub::{PubSubChannel, WaitResult},
96};
97use embedded_io_async::{Read, Write};
98use heapless::{FnvIndexMap, String, Vec};
99use serde::{Deserialize, Serialize};
100
101#[cfg(feature = "defmt")]
102use defmt::{debug, error, warn};
103
104#[cfg(feature = "embassy-time")]
105use embassy_time::{with_timeout, Duration};
106
107mod glob_match;
108pub mod stackfuture;
109
/// Default number of concurrent [`RpcServer::serve`] calls, i.e. the subscriber capacity of
/// the notification channel.
pub const DEFAULT_MAX_CLIENTS: usize = 4;
/// Default capacity of the method-glob -> handler table.
pub const DEFAULT_MAX_HANDLERS: usize = 8;
/// Default maximum size, in bytes, of a framed message (`Content-Length` header plus JSON
/// body); also the size of each per-connection request and response buffer.
pub const DEFAULT_MAX_MESSAGE_LEN: usize = 1_460;
/// Default storage size of each handler's [`StackFuture`] (presumably bytes — confirm against
/// the `stackfuture` module).
pub const DEFAULT_HANDLER_STACK_SIZE: usize = 256;
/// Default capacity of the notification channel's message queue.
pub const DEFAULT_NOTIFICATION_QUEUE_SIZE: usize = 1;
/// Timeout, in milliseconds, for each stream write or flush (only with `embassy-time`).
pub const DEFAULT_WRITE_TIMEOUT_MS: u64 = 5_000;
/// Timeout, in milliseconds, for one handler invocation (only with `embassy-time`).
pub const DEFAULT_HANDLER_TIMEOUT_MS: u64 = 5_000;

/// The only `jsonrpc` protocol version the server accepts, and the one it sends.
pub const JSONRPC_VERSION: &str = "2.0";
134
/// A JSON-RPC request (or, with `id: None`, a notification) whose parameters have type `T`.
///
/// Fields are serialized in declaration order (`jsonrpc`, `id`, `method`, `params`) and
/// `None` values are written as `null`; the crate's tests compare against that exact layout,
/// so do not reorder the fields.
#[derive(Debug, Deserialize, Serialize)]
pub struct RpcRequest<'a, T> {
    /// Protocol version string; the server only accepts [`JSONRPC_VERSION`].
    pub jsonrpc: &'a str,
    /// Request id echoed back in the response.
    pub id: Option<u64>,
    /// Name of the method to invoke.
    pub method: &'a str,
    /// Method parameters, if any.
    pub params: Option<T>,
}
143
/// The routing fields of a request, parsed by the server before dispatch.
///
/// `params` is not captured here: the handler is given the whole request JSON and parses
/// its own parameter type from it.
#[derive(Debug, Deserialize)]
struct RpcRequestMetadata<'a> {
    /// Protocol version; must equal [`JSONRPC_VERSION`] or the request is rejected.
    pub jsonrpc: &'a str,
    /// Request id, echoed in the response.
    pub id: Option<u64>,
    /// Method name, matched against the registered globs.
    pub method: &'a str,
}
150
/// A JSON-RPC response carrying either a `result` of type `T` or an `error`.
///
/// Fields are serialized in declaration order (`jsonrpc`, `id`, `error`, `result`) and
/// `None` members are written as `null` rather than omitted; the crate's tests compare
/// against that exact layout, so do not reorder the fields.
#[derive(Debug, Deserialize, Serialize)]
pub struct RpcResponse<'a, T> {
    /// Protocol version; the server always sends [`JSONRPC_VERSION`].
    pub jsonrpc: &'a str,
    /// Id of the request being answered (`None` when the request could not be parsed).
    pub id: Option<u64>,
    /// Error details when the request failed.
    pub error: Option<RpcError>,
    /// Method result when the request succeeded.
    pub result: Option<T>,
}
159
/// Predefined JSON-RPC 2.0 error codes.
///
/// The discriminants are the wire values; `RpcErrorCode` serializes as `code as i32`.
/// `PartialEq`/`Eq` are derived so callers can compare a received code against a variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RpcErrorCode {
    /// The request body was not valid JSON or did not have the expected request shape.
    ParseError = -32700,
    /// The request was not an acceptable JSON-RPC request (e.g. wrong `jsonrpc` version).
    InvalidRequest = -32600,
    /// No registered handler matches the requested method.
    MethodNotFound = -32601,
    /// The method parameters were invalid; not produced by the server itself, available to
    /// handlers.
    InvalidParams = -32602,
    /// An internal error occurred; not produced by the server itself, available to handlers.
    InternalError = -32603,
}

impl RpcErrorCode {
    /// Returns the human-readable text sent as the `message` of an error response.
    pub fn message(self) -> &'static str {
        match self {
            Self::ParseError => "Invalid JSON.",
            Self::InvalidRequest => "Invalid request.",
            Self::MethodNotFound => "Method not found.",
            Self::InvalidParams => "Invalid parameters.",
            Self::InternalError => "Internal error.",
        }
    }
}
184
185impl Serialize for RpcErrorCode {
186 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
187 where
188 S: serde::ser::Serializer,
189 {
190 (*self as i32).serialize(serializer)
191 }
192}
193
194impl<'a> Deserialize<'a> for RpcErrorCode {
195 fn deserialize<D>(deserializer: D) -> Result<RpcErrorCode, D::Error>
196 where
197 D: serde::de::Deserializer<'a>,
198 {
199 let code = i32::deserialize(deserializer)?;
200 match code {
201 -32700 => Ok(RpcErrorCode::ParseError),
202 -32600 => Ok(RpcErrorCode::InvalidRequest),
203 -32601 => Ok(RpcErrorCode::MethodNotFound),
204 -32602 => Ok(RpcErrorCode::InvalidParams),
205 -32603 => Ok(RpcErrorCode::InternalError),
206 _ => Err(serde::de::Error::custom("Invalid error code")),
207 }
208 }
209}
210
/// The `error` member of a JSON-RPC response.
#[derive(Debug, Deserialize, Serialize)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct RpcError {
    /// Error code, serialized as its integer value.
    pub code: RpcErrorCode,
    /// Human-readable message, limited to 32 bytes by its `heapless::String` capacity.
    pub message: String<32>,
}
218
219impl From<RpcErrorCode> for RpcError {
220 fn from(code: RpcErrorCode) -> Self {
221 RpcError {
222 code,
223 message: String::try_from(code.message()).unwrap(),
224 }
225 }
226}
227
/// Errors returned by [`RpcServer`] operations; `E` is the stream's I/O error type.
#[derive(PartialEq, Eq, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RpcServerError<E> {
    /// A framed message (header plus JSON body) does not fit in `MAX_MESSAGE_LEN` bytes.
    BufferOverflow,
    /// Reading from or writing to the stream failed.
    IoError(E),
    /// The framing header could not be parsed (e.g. no valid `Content-Length`).
    ///
    /// Invalid JSON inside a well-framed message is answered with a JSON-RPC error
    /// response instead of this error.
    ParseError,
    /// A handler was registered while the handler table was already full.
    TooManyHandlers,
    /// With the `embassy-time` feature, a stream write or a handler exceeded its timeout.
    TimeoutError,
}
245
/// A method implementation that the server dispatches matching requests to.
///
/// `STACK_SIZE` is the storage size of the returned [`StackFuture`]; it must equal the
/// `HANDLER_STACK_SIZE` of the [`RpcServer`] the handler is registered with.
pub trait RpcHandler<const STACK_SIZE: usize = DEFAULT_HANDLER_STACK_SIZE>: Sync {
    /// Handles one request.
    ///
    /// - `id`: the request's `id`, to be echoed in the response.
    /// - `method`: the concrete method name from the request (not the matched glob).
    /// - `request_json`: the complete request JSON body, including `params`.
    /// - `response_json`: buffer to serialize the complete JSON-RPC response into; the server
    ///   sends these bytes verbatim as the response body.
    ///
    /// Returns the number of bytes written to `response_json`. On `Err`, the server discards
    /// the buffer contents and sends an error response carrying the returned [`RpcError`].
    fn handle<'a>(
        &'a self,
        id: Option<u64>,
        method: &'a str,
        request_json: &'a [u8],
        response_json: &'a mut [u8],
    ) -> StackFuture<'a, Result<usize, RpcError>, STACK_SIZE>;
}
258
/// A JSON-RPC 2.0 server speaking `Content-Length`-framed messages over async streams.
///
/// Handlers and method globs are borrowed for `'a`. `StreamError` is the error type of the
/// streams passed to [`RpcServer::serve`]. The const parameters bound every buffer, so the
/// server needs no heap.
pub struct RpcServer<
    'a,
    StreamError,
    const MAX_CLIENTS: usize = DEFAULT_MAX_CLIENTS,
    const MAX_HANDLERS: usize = DEFAULT_MAX_HANDLERS,
    const MAX_MESSAGE_LEN: usize = DEFAULT_MAX_MESSAGE_LEN,
    const HANDLER_STACK_SIZE: usize = DEFAULT_HANDLER_STACK_SIZE,
    const NOTIFICATION_QUEUE_SIZE: usize = DEFAULT_NOTIFICATION_QUEUE_SIZE,
> {
    /// Method-name glob -> handler, searched in iteration order when dispatching.
    handlers: FnvIndexMap<&'a str, &'a dyn RpcHandler<HANDLER_STACK_SIZE>, MAX_HANDLERS>,
    /// Broadcast channel of already-framed notifications: up to `MAX_CLIENTS` subscribers
    /// (one per `serve` call) and a single publisher.
    notifications: PubSubChannel<
        CriticalSectionRawMutex,
        Vec<u8, MAX_MESSAGE_LEN>,
        NOTIFICATION_QUEUE_SIZE,
        MAX_CLIENTS,
        1,
    >,
    /// Serializes `notify` calls so the single publisher slot is never contended.
    notification_publisher_mutex: Mutex<CriticalSectionRawMutex, ()>,
    /// Ties the struct to `StreamError` without storing a value of it.
    _phantom: core::marker::PhantomData<StreamError>,
}
280
281impl<
282 StreamError,
283 const MAX_CLIENTS: usize,
284 const MAX_HANDLERS: usize,
285 const MAX_MESSAGE_LEN: usize,
286 const HANDLER_STACK_SIZE: usize,
287 const NOTIFICATION_QUEUE_SIZE: usize,
288 > Default
289 for RpcServer<
290 '_,
291 StreamError,
292 MAX_CLIENTS,
293 MAX_HANDLERS,
294 MAX_MESSAGE_LEN,
295 HANDLER_STACK_SIZE,
296 NOTIFICATION_QUEUE_SIZE,
297 >
298{
299 fn default() -> Self {
300 Self::new()
301 }
302}
303
304impl<
305 'a,
306 StreamError,
307 const MAX_CLIENTS: usize,
308 const MAX_HANDLERS: usize,
309 const MAX_MESSAGE_LEN: usize,
310 const HANDLER_STACK_SIZE: usize,
311 const NOTIFICATION_QUEUE_SIZE: usize,
312 >
313 RpcServer<
314 'a,
315 StreamError,
316 MAX_CLIENTS,
317 MAX_HANDLERS,
318 MAX_MESSAGE_LEN,
319 HANDLER_STACK_SIZE,
320 NOTIFICATION_QUEUE_SIZE,
321 >
322{
323 pub fn new() -> Self {
325 #[cfg(feature = "defmt")]
326 debug!("Initializing new RPC server");
327
328 Self {
329 handlers: FnvIndexMap::new(),
330 notifications: PubSubChannel::new(),
331 notification_publisher_mutex: Mutex::new(()),
332 _phantom: core::marker::PhantomData,
333 }
334 }
335
336 pub fn register_handler(
338 &mut self,
339 method_name_glob: &'a str,
340 handler: &'a dyn RpcHandler<HANDLER_STACK_SIZE>,
341 ) -> Result<(), RpcServerError<StreamError>> {
342 #[cfg(feature = "defmt")]
343 debug!("Registering method: {}", method_name_glob);
344
345 if self.handlers.insert(method_name_glob, handler).is_err() {
346 #[cfg(feature = "defmt")]
347 warn!(
348 "Failed to register method (too many handlers): {}",
349 method_name_glob
350 );
351 return Err(RpcServerError::TooManyHandlers);
352 }
353
354 Ok(())
355 }
356
357 pub async fn notify(
359 &self,
360 notification_json: &[u8],
361 ) -> Result<(), RpcServerError<StreamError>> {
362 #[cfg(feature = "defmt")]
363 debug!("Broadcasting notification");
364
365 let mut headers: String<32> = String::new();
366 core::fmt::write(
367 &mut headers,
368 format_args!("Content-Length: {}\r\n\r\n", notification_json.len()),
369 )
370 .unwrap();
371
372 if headers.len() + notification_json.len() > MAX_MESSAGE_LEN {
373 #[cfg(feature = "defmt")]
374 error!("Broadcast message too large");
375 return Err(RpcServerError::BufferOverflow);
376 }
377
378 let mut framed_message: heapless::Vec<u8, MAX_MESSAGE_LEN> = heapless::Vec::new();
379 framed_message
380 .extend_from_slice(headers.as_bytes())
381 .unwrap();
382 framed_message.extend_from_slice(notification_json).unwrap();
383
384 {
385 let _lock = self.notification_publisher_mutex.lock().await;
386 let notifications = self.notifications.publisher().unwrap();
387 notifications.publish(framed_message).await;
388 }
389
390 Ok(())
391 }
392
393 pub async fn serve<Stream: Read<Error = StreamError> + Write<Error = StreamError>>(
395 &self,
396 stream: &mut Stream,
397 ) -> Result<(), RpcServerError<StreamError>> {
398 #[cfg(feature = "defmt")]
399 debug!("Starting RPC server");
400
401 let mut notifications = self.notifications.subscriber().unwrap();
402 let mut request_buffer = [0u8; MAX_MESSAGE_LEN];
403 let mut response_json = [0u8; MAX_MESSAGE_LEN];
404 let mut read_offset = 0;
405
406 loop {
407 #[cfg(feature = "defmt")]
408 debug!("Waiting for data from client");
409
410 let result = select(
411 notifications.next_message(),
412 stream.read(&mut request_buffer[read_offset..]),
413 )
414 .await;
415
416 match result {
417 Either::First(WaitResult::Message(notification_json)) => {
418 #[cfg(feature = "defmt")]
419 debug!("Writing notification");
420
421 #[cfg(feature = "embassy-time")]
422 {
423 with_timeout(
424 Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
425 stream.write_all(¬ification_json),
426 )
427 .await
428 .map_err(|_| RpcServerError::TimeoutError)?
429 .map_err(RpcServerError::IoError)?;
430
431 with_timeout(
432 Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
433 stream.flush(),
434 )
435 .await
436 .map_err(|_| RpcServerError::TimeoutError)?
437 .map_err(RpcServerError::IoError)?;
438 }
439
440 #[cfg(not(feature = "embassy-time"))]
441 {
442 stream
443 .write_all(¬ification_json)
444 .await
445 .map_err(RpcServerError::IoError)?;
446
447 stream.flush().await.map_err(RpcServerError::IoError)?;
448 }
449
450 #[cfg(feature = "defmt")]
451 debug!("Notification sent to client");
452
453 continue;
454 }
455 Either::First(WaitResult::Lagged(x)) => {
456 #[cfg(feature = "defmt")]
457 warn!("Dropped {} notifications due to lag", x);
458 }
459 Either::Second(Ok(0)) => {
460 #[cfg(feature = "defmt")]
461 debug!("Client disconnected");
462 return Ok(());
463 }
464 Either::Second(Ok(n)) => {
465 #[cfg(feature = "defmt")]
466 debug!("Received {} bytes from client", n);
467
468 read_offset += n;
469 while let Some(headers_len) =
470 Self::parse_headers(&request_buffer[..read_offset])
471 {
472 let content_len =
473 Self::parse_content_length(&mut request_buffer[..headers_len])?;
474 let total_message_len = headers_len + content_len;
475
476 if read_offset < total_message_len {
477 #[cfg(feature = "defmt")]
478 debug!("Incomplete message, waiting for more data");
479 break;
480 }
481
482 #[cfg(feature = "defmt")]
483 debug!("Received complete message, handling request");
484
485 let request_json = &request_buffer[headers_len..headers_len + content_len];
486 let response_json_len = self
487 .handle_request(request_json, &mut response_json)
488 .await?;
489
490 #[cfg(feature = "defmt")]
491 debug!("Sending response to client");
492
493 let mut headers: String<32> = String::new();
494 core::fmt::write(
495 &mut headers,
496 format_args!("Content-Length: {}\r\n\r\n", response_json_len),
497 )
498 .unwrap();
499
500 if headers.len() + response_json_len > MAX_MESSAGE_LEN {
501 #[cfg(feature = "defmt")]
502 error!("Response message too large");
503 return Err(RpcServerError::BufferOverflow);
504 }
505
506 #[cfg(feature = "defmt")]
507 debug!("Writing response");
508
509 #[cfg(feature = "embassy-time")]
510 {
511 with_timeout(
512 Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
513 stream.write_all(headers.as_bytes()),
514 )
515 .await
516 .map_err(|_| RpcServerError::TimeoutError)?
517 .map_err(RpcServerError::IoError)?;
518
519 with_timeout(
520 Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
521 stream.write_all(&response_json[..response_json_len]),
522 )
523 .await
524 .map_err(|_| RpcServerError::TimeoutError)?
525 .map_err(RpcServerError::IoError)?;
526
527 with_timeout(
528 Duration::from_millis(DEFAULT_WRITE_TIMEOUT_MS),
529 stream.flush(),
530 )
531 .await
532 .map_err(|_| RpcServerError::TimeoutError)?
533 .map_err(RpcServerError::IoError)?;
534 }
535
536 #[cfg(not(feature = "embassy-time"))]
537 {
538 stream
539 .write_all(headers.as_bytes())
540 .await
541 .map_err(RpcServerError::IoError)?;
542
543 stream
544 .write_all(&response_json[..response_json_len])
545 .await
546 .map_err(RpcServerError::IoError)?;
547
548 stream.flush().await.map_err(RpcServerError::IoError)?;
549 }
550
551 #[cfg(feature = "defmt")]
552 debug!("Response sent to client");
553 let remaining = read_offset - total_message_len;
554 request_buffer.copy_within(total_message_len..read_offset, 0);
555 read_offset = remaining;
556 }
557 }
558 Either::Second(Err(e)) => {
559 #[cfg(feature = "defmt")]
560 error!("IO error during stream read");
561 return Err(RpcServerError::IoError(e));
562 }
563 }
564 }
565 }
566
567 async fn handle_request(
569 &self,
570 request_json: &'a [u8],
571 response_json: &'a mut [u8],
572 ) -> Result<usize, RpcServerError<StreamError>> {
573 #[cfg(feature = "defmt")]
574 debug!("Handling request");
575
576 let request: RpcRequestMetadata = match serde_json_core::from_slice(request_json) {
577 Ok((request, _remainder)) => request,
578 Err(_) => {
579 #[cfg(feature = "defmt")]
580 warn!("Failed to parse request JSON");
581
582 let response: RpcResponse<'_, ()> = RpcResponse {
583 jsonrpc: JSONRPC_VERSION,
584 error: Some(RpcErrorCode::ParseError.into()),
585 id: None,
586 result: None,
587 };
588
589 return Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap());
590 }
591 };
592
593 let id = request.id;
594
595 if request.jsonrpc != JSONRPC_VERSION {
596 #[cfg(feature = "defmt")]
597 warn!("Unsupported JSON-RPC version");
598
599 let response: RpcResponse<'_, ()> = RpcResponse {
600 jsonrpc: JSONRPC_VERSION,
601 error: Some(RpcErrorCode::InvalidRequest.into()),
602 result: None,
603 id,
604 };
605
606 return Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap());
607 }
608
609 #[cfg(feature = "defmt")]
610 debug!("Dispatching method: {}", request.method);
611
612 let mut handler: Option<&dyn RpcHandler<HANDLER_STACK_SIZE>> = None;
613 for (method_name_glob, h) in self.handlers.iter() {
614 if glob_match(method_name_glob, request.method) {
615 #[cfg(feature = "defmt")]
616 debug!("Matched method: {}", method_name_glob);
617
618 handler = Some(*h);
619 }
620 }
621
622 if handler.is_none() {
623 #[cfg(feature = "defmt")]
624 warn!("Method not found: {}", request.method);
625
626 let response: RpcResponse<'_, ()> = RpcResponse {
627 jsonrpc: JSONRPC_VERSION,
628 error: Some(RpcErrorCode::MethodNotFound.into()),
629 result: None,
630 id,
631 };
632
633 return Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap());
634 }
635
636 #[cfg(feature = "embassy-time")]
637 let result = with_timeout(
638 Duration::from_millis(DEFAULT_HANDLER_TIMEOUT_MS),
639 handler.handle(id, request.method, request_json, response_json),
640 )
641 .await
642 .map_err(|_| RpcServerError::TimeoutError)?;
643
644 #[cfg(not(feature = "embassy-time"))]
645 let result = handler
646 .unwrap()
647 .handle(id, request.method, request_json, response_json)
648 .await;
649
650 match result {
651 Ok(response_len) => Ok(response_len),
652 Err(e) => {
653 #[cfg(feature = "defmt")]
654 error!("Handler returned error: {:?}", e);
655
656 let response: RpcResponse<'_, ()> = RpcResponse {
657 jsonrpc: JSONRPC_VERSION,
658 error: Some(e),
659 result: None,
660 id,
661 };
662
663 Ok(serde_json_core::to_slice(&response, &mut response_json[..]).unwrap())
664 }
665 }
666 }
667
668 fn parse_headers(buffer: &[u8]) -> Option<usize> {
670 buffer
671 .windows(4)
672 .position(|window| window == b"\r\n\r\n")
673 .map(|i| i + 4)
674 }
675
676 fn parse_content_length(buffer: &mut [u8]) -> Result<usize, RpcServerError<StreamError>> {
678 let headers = core::str::from_utf8_mut(buffer).map_err(|_| RpcServerError::ParseError)?;
679 headers.make_ascii_lowercase();
680 for line in headers.lines() {
681 if let Some(value) = line.strip_prefix("content-length:") {
682 return value.trim().parse().map_err(|_| RpcServerError::ParseError);
683 }
684 }
685 Err(RpcServerError::ParseError)
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;
    use memory_pipe::MemoryPipe;
    use std::sync::Arc;

    #[cfg(feature = "defmt")]
    use defmt_logger_tcp as _;

    mod memory_pipe;

    /// Reads from `stream` into `buffer` until one complete `Content-Length`-framed message
    /// has arrived, and returns the number of bytes read.
    ///
    /// The server may send a header and its body with separate writes, so a single `read`
    /// is not guaranteed to return the whole message.
    async fn read_framed_message<S: Read>(stream: &mut S, buffer: &mut [u8]) -> usize {
        let mut filled = 0;
        loop {
            let n = stream.read(&mut buffer[filled..]).await.unwrap();
            assert_ne!(n, 0, "stream closed before a complete message arrived");
            filled += n;

            let received = &buffer[..filled];
            if let Some(header_end) = received.windows(4).position(|w| w == b"\r\n\r\n") {
                let header = core::str::from_utf8(&received[..header_end]).unwrap();
                let body_len: usize = header
                    .strip_prefix("Content-Length: ")
                    .expect("message must start with a Content-Length header")
                    .parse()
                    .unwrap();
                if filled >= header_end + 4 + body_len {
                    return filled;
                }
            }
        }
    }

    #[tokio::test]
    async fn test_request_response() {
        let mut server: RpcServer<'_, _> = RpcServer::new();
        server.register_handler("echo", &EchoHandler).unwrap();

        let (mut stream1, mut stream2) = MemoryPipe::new();

        tokio::spawn(async move {
            server.serve(&mut stream2).await.unwrap();
        });

        let request: RpcRequest<'_, ()> = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            id: Some(1),
            method: "echo",
            params: None,
        };

        let mut request_json = [0u8; 256];
        let request_len = serde_json_core::to_slice(&request, &mut request_json).unwrap();

        let request_message = format!(
            "Content-Length: {}\r\n\r\n{}",
            request_len,
            core::str::from_utf8(&request_json[..request_len]).unwrap()
        );
        stream1.write_all(request_message.as_bytes()).await.unwrap();

        let mut response_buffer = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let response_len = read_framed_message(&mut stream1, &mut response_buffer).await;

        let response = core::str::from_utf8(&response_buffer[..response_len]).unwrap();

        assert_eq!(
            response,
            "Content-Length: 51\r\n\r\n{\"jsonrpc\":\"2.0\",\"id\":1,\"error\":null,\"result\":null}"
        );
    }

    #[tokio::test]
    async fn test_notify() {
        let server: Arc<RpcServer<'_, _>> = Arc::new(RpcServer::new());

        let server_clone = Arc::clone(&server);
        let (mut stream1, mut stream2) = MemoryPipe::new();

        tokio::spawn(async move {
            server_clone.serve(&mut stream2).await.unwrap();
        });

        // Give `serve` time to subscribe before publishing.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;

        let notification: RpcRequest<'_, ()> = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: "notify",
            id: None,
            params: None,
        };

        let mut notification_json = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let notification_len =
            serde_json_core::to_slice(&notification, &mut notification_json).unwrap();

        server
            .notify(&notification_json[..notification_len])
            .await
            .unwrap();

        let mut received = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let received_len = read_framed_message(&mut stream1, &mut received).await;

        let received = core::str::from_utf8(&received[..received_len]).unwrap();

        assert_eq!(
            received,
            "Content-Length: 59\r\n\r\n{\"jsonrpc\":\"2.0\",\"id\":null,\"method\":\"notify\",\"params\":null}",
        );
    }

    struct EchoHandler;

    impl RpcHandler for EchoHandler {
        fn handle<'a>(
            &self,
            id: Option<u64>,
            _method: &'a str,
            _request_json: &'a [u8],
            response_json: &'a mut [u8],
        ) -> StackFuture<'a, Result<usize, RpcError>, DEFAULT_HANDLER_STACK_SIZE> {
            StackFuture::from(async move {
                let response: RpcResponse<'static, ()> = RpcResponse {
                    jsonrpc: JSONRPC_VERSION,
                    error: None,
                    result: None,
                    id,
                };

                Ok(serde_json_core::to_slice(&response, response_json).unwrap())
            })
        }
    }
}