#![cfg_attr(not(feature = "std"), no_std)]
use core::clone::Clone;
use core::cmp::{Eq, PartialEq};
use core::default::Default;
use core::fmt::Debug;
use core::format_args;
use core::iter::Iterator;
use core::marker::Copy;
use core::option::Option::{self, *};
use core::prelude::v1::derive;
use core::result::Result::{self, *};
use embassy_futures::select::{select, Either};
use embassy_sync::{
    blocking_mutex::raw::CriticalSectionRawMutex,
    pubsub::{PubSubChannel, WaitResult},
};
use embedded_io_async::{Read, Write};
use heapless::{FnvIndexMap, String, Vec};
use serde::{Deserialize, Serialize};
use stackfuture::StackFuture;
#[cfg(feature = "defmt")]
use defmt::*;
/// Protocol version string carried in the `jsonrpc` member of every request and response.
pub const JSONRPC_VERSION: &str = "2.0";
/// Envelope of a JSON-RPC request (or, with `id: None`, a notification).
///
/// `P` is the type `params` is decoded into / encoded from. Field order is significant:
/// serde emits the members in declaration order.
#[derive(Debug, Serialize, Deserialize)]
pub struct RpcRequest<'r, P> {
    /// Protocol version; the server only accepts [`JSONRPC_VERSION`].
    pub jsonrpc: &'r str,
    /// Request identifier echoed back in the response, if present.
    pub id: Option<u64>,
    /// Name of the method to invoke.
    pub method: &'r str,
    /// Method parameters, if any.
    pub params: Option<P>,
}
/// Envelope of a JSON-RPC response.
///
/// `R` is the type of the `result` payload. Field order is significant: serde emits the
/// members in declaration order.
#[derive(Debug, Serialize, Deserialize)]
pub struct RpcResponse<'r, R> {
    /// Protocol version; responses built by this crate use [`JSONRPC_VERSION`].
    pub jsonrpc: &'r str,
    /// Identifier of the request being answered, if it had one.
    pub id: Option<u64>,
    /// Error object, set when the request failed.
    pub error: Option<RpcError>,
    /// Result payload, set when the request succeeded.
    pub result: Option<R>,
}
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[allow(dead_code)]
pub enum RpcErrorCode {
    ParseError = -32700,
    InvalidRequest = -32600,
    MethodNotFound = -32601,
    InvalidParams = -32602,
    InternalError = -32603,
}
impl RpcErrorCode {
    pub fn message(self) -> &'static str {
        match self {
            RpcErrorCode::ParseError => "Invalid JSON.",
            RpcErrorCode::InvalidRequest => "Invalid request.",
            RpcErrorCode::MethodNotFound => "Method not found.",
            RpcErrorCode::InvalidParams => "Invalid parameters.",
            RpcErrorCode::InternalError => "Internal error.",
        }
    }
}
/// JSON-RPC error object placed in the `error` member of a response.
#[derive(Debug, Serialize, Deserialize)]
pub struct RpcError {
    /// Error code.
    pub code: RpcErrorCode,
    /// Short description, stored inline with a 32-byte capacity.
    pub message: String<32>,
}
impl RpcError {
    pub fn from_code(code: RpcErrorCode) -> Self {
        RpcError {
            code,
            message: String::try_from(code.message()).unwrap(),
        }
    }
}
/// Errors that end an [`RpcServer::serve`] session or reject an [`RpcServer::notify`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RpcServerError {
    /// A framed message does not fit in a `MAX_MESSAGE_LEN`-byte buffer.
    BufferOverflow,
    /// Reading from, writing to, or flushing the underlying stream failed.
    IoError,
    /// The framing headers were not valid UTF-8 or lacked a parseable `Content-Length`.
    ParseError,
}
/// Default number of subscriber slots on the notification channel, i.e. how many
/// connections may be served concurrently.
pub const DEFAULT_MAX_CLIENTS: usize = 4;
/// Default capacity of the method-name → handler table.
pub const DEFAULT_MAX_HANDLERS: usize = 8;
/// Default size in bytes of the request/response buffers and of a framed notification.
pub const DEFAULT_MAX_MESSAGE_LEN: usize = 512;
/// Default number of bytes reserved for a handler's [`StackFuture`].
pub const DEFAULT_STACK_SIZE: usize = 256;
/// A method implementation that can be registered with an [`RpcServer`].
///
/// `STACK_SIZE` is the number of bytes reserved for the [`StackFuture`] returned by
/// [`handle`](RpcHandler::handle).
pub trait RpcHandler<const STACK_SIZE: usize = DEFAULT_STACK_SIZE>: Debug + Sync {
    /// Handles one request.
    ///
    /// `id` is the request's `id` member. `request_json` is the complete raw request
    /// object, so the handler can decode its own `params`. The handler writes a full JSON
    /// response into `response_json` and resolves to the number of bytes written. On `Err`,
    /// the server writes an error response carrying `id` instead.
    fn handle<'buf>(
        &self,
        id: Option<u64>,
        request_json: &'buf [u8],
        response_json: &'buf mut [u8],
    ) -> StackFuture<'buf, Result<usize, RpcError>, STACK_SIZE>;
}
/// A JSON-RPC 2.0 server that reads `Content-Length`-framed requests, dispatches them to
/// registered [`RpcHandler`]s, and forwards notifications to every connection it serves.
///
/// * `MAX_CLIENTS` – subscriber slots on the notification channel (concurrent `serve` calls).
/// * `MAX_HANDLERS` – capacity of the method table.
/// * `MAX_MESSAGE_LEN` – size of the per-connection request/response buffers and of a
///   framed notification.
/// * `STACK_SIZE` – future size of the registered handlers.
pub struct RpcServer<
    's,
    const MAX_CLIENTS: usize = DEFAULT_MAX_CLIENTS,
    const MAX_HANDLERS: usize = DEFAULT_MAX_HANDLERS,
    const MAX_MESSAGE_LEN: usize = DEFAULT_MAX_MESSAGE_LEN,
    const STACK_SIZE: usize = DEFAULT_STACK_SIZE,
> {
    /// Method name → handler lookup table.
    handlers: FnvIndexMap<&'s str, &'s dyn RpcHandler<STACK_SIZE>, MAX_HANDLERS>,
    /// Framed notifications fanned out to each connection: queue depth 2,
    /// `MAX_CLIENTS` subscribers, one publisher.
    notifications:
        PubSubChannel<CriticalSectionRawMutex, Vec<u8, MAX_MESSAGE_LEN>, 2, MAX_CLIENTS, 1>,
}
impl<'a, const MAX_CLIENTS: usize, const MAX_HANDLERS: usize, const MAX_MESSAGE_LEN: usize> Default
    for RpcServer<'a, MAX_CLIENTS, MAX_HANDLERS, MAX_MESSAGE_LEN>
{
    fn default() -> Self {
        Self::new()
    }
}
impl<'a, const MAX_CLIENTS: usize, const MAX_HANDLERS: usize, const MAX_MESSAGE_LEN: usize>
    RpcServer<'a, MAX_CLIENTS, MAX_HANDLERS, MAX_MESSAGE_LEN>
{
    pub fn new() -> Self {
        Self {
            handlers: FnvIndexMap::new(),
            notifications: PubSubChannel::new(),
        }
    }
    pub fn register_method(&mut self, name: &'a str, handler: &'a dyn RpcHandler) {
        self.handlers.insert(name, handler).unwrap();
    }
    pub async fn notify(&self, notification_json: &[u8]) -> Result<(), RpcServerError> {
        let mut headers: String<32> = String::new();
        core::fmt::write(
            &mut headers,
            format_args!("Content-Length: {}\r\n\r\n", notification_json.len()),
        )
        .unwrap();
        if headers.len() + notification_json.len() > MAX_MESSAGE_LEN {
            return Err(RpcServerError::BufferOverflow);
        }
        let mut framed_message: heapless::Vec<u8, MAX_MESSAGE_LEN> = heapless::Vec::new();
        framed_message
            .extend_from_slice(headers.as_bytes())
            .unwrap();
        framed_message.extend_from_slice(notification_json).unwrap();
        let notifications = self.notifications.publisher().unwrap();
        notifications.publish(framed_message).await;
        Ok(())
    }
    pub async fn serve<T: Read + Write>(&self, stream: &mut T) -> Result<(), RpcServerError> {
        let mut notifications = self.notifications.subscriber().unwrap();
        let mut request_buffer = [0u8; MAX_MESSAGE_LEN];
        let mut response_json = [0u8; MAX_MESSAGE_LEN];
        let mut read_offset = 0;
        loop {
            let result = select(
                notifications.next_message(),
                stream.read(&mut request_buffer[read_offset..]),
            )
            .await;
            match result {
                Either::First(WaitResult::Message(notification_json)) => {
                    stream
                        .write_all(¬ification_json)
                        .await
                        .map_err(|_| RpcServerError::IoError)?;
                    stream.flush().await.map_err(|_| RpcServerError::IoError)?;
                    continue;
                }
                Either::First(WaitResult::Lagged(x)) => {
                    #[cfg(feature = "defmt")]
                    warn!("Dropped {:?} notifications", x);
                }
                Either::Second(Ok(0)) => return Ok(()),
                Either::Second(Ok(n)) => {
                    read_offset += n;
                    while let Some(headers_len) =
                        Self::parse_headers(&request_buffer[..read_offset])
                    {
                        let content_len: usize =
                            Self::parse_content_length(&mut request_buffer[..headers_len])?;
                        let total_message_len = headers_len + content_len;
                        if read_offset < total_message_len {
                            break;
                        }
                        let request_json = &request_buffer[headers_len..headers_len + content_len];
                        let response_json_len =
                            self.handle_request(request_json, &mut response_json).await;
                        let mut headers: String<32> = String::new();
                        core::fmt::write(
                            &mut headers,
                            format_args!("Content-Length: {}\r\n\r\n", response_json_len),
                        )
                        .unwrap();
                        if headers.len() + response_json_len > MAX_MESSAGE_LEN {
                            return Err(RpcServerError::BufferOverflow);
                        }
                        stream
                            .write_all(headers.as_bytes())
                            .await
                            .map_err(|_| RpcServerError::IoError)?;
                        stream
                            .write_all(&response_json[..response_json_len])
                            .await
                            .map_err(|_| RpcServerError::IoError)?;
                        stream.flush().await.map_err(|_| RpcServerError::IoError)?;
                        let remaining = read_offset - total_message_len;
                        request_buffer.copy_within(total_message_len..read_offset, 0);
                        read_offset = remaining;
                    }
                }
                Either::Second(Err(_)) => return Err(RpcServerError::IoError),
            }
        }
    }
    async fn handle_request(&self, request_json: &'a [u8], response_json: &'a mut [u8]) -> usize {
        let request: RpcRequest<'_, ()> = match serde_json_core::from_slice(request_json) {
            Ok((request, _remainder)) => request,
            Err(_) => {
                if let Ok(json_str) = core::str::from_utf8(request_json) {
                    #[cfg(feature = "defmt")]
                    warn!("Invalid JSON-RPC request: {}", json_str)
                } else {
                    #[cfg(feature = "defmt")]
                    warn!("Invalid JSON-RPC request: [non-UTF8 data]")
                }
                let response: RpcResponse<'_, ()> = RpcResponse {
                    jsonrpc: JSONRPC_VERSION,
                    error: Some(RpcError::from_code(RpcErrorCode::ParseError)),
                    id: None,
                    result: None,
                };
                return serde_json_core::to_slice(&response, &mut response_json[..]).unwrap();
            }
        };
        let id = request.id;
        if request.jsonrpc != JSONRPC_VERSION {
            let response: RpcResponse<'_, ()> = RpcResponse {
                jsonrpc: JSONRPC_VERSION,
                error: Some(RpcError::from_code(RpcErrorCode::InvalidRequest)),
                result: None,
                id,
            };
            return serde_json_core::to_slice(&response, &mut response_json[..]).unwrap();
        }
        return match self.handlers.get(request.method) {
            Some(handler) => match handler.handle(id, request_json, response_json).await {
                Ok(response_len) => response_len,
                Err(e) => {
                    let response: RpcResponse<'_, ()> = RpcResponse {
                        jsonrpc: JSONRPC_VERSION,
                        error: Some(e),
                        result: None,
                        id,
                    };
                    serde_json_core::to_slice(&response, &mut response_json[..]).unwrap()
                }
            },
            None => {
                let response: RpcResponse<'_, ()> = RpcResponse {
                    jsonrpc: JSONRPC_VERSION,
                    error: Some(RpcError::from_code(RpcErrorCode::MethodNotFound)),
                    result: None,
                    id,
                };
                serde_json_core::to_slice(&response, &mut response_json[..]).unwrap()
            }
        };
    }
    fn parse_headers(buffer: &[u8]) -> Option<usize> {
        return buffer
            .windows(4)
            .position(|window| window == b"\r\n\r\n")
            .map(|i| i + 4);
    }
    fn parse_content_length(buffer: &mut [u8]) -> Result<usize, RpcServerError> {
        let headers = core::str::from_utf8_mut(buffer).map_err(|_| RpcServerError::ParseError)?;
        headers.make_ascii_lowercase();
        for line in headers.lines() {
            if let Some(value) = line.strip_prefix("content-length:") {
                return value.trim().parse().map_err(|_| RpcServerError::ParseError);
            }
        }
        Err(RpcServerError::ParseError)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use memory_pipe::MemoryPipe;
    use std::sync::Arc;

    mod memory_pipe;

    /// Server type with every const parameter at its default, used to reach the private
    /// framing helpers.
    type DefaultServer = RpcServer<'static>;

    #[tokio::test]
    async fn test_request_response() {
        let mut server: RpcServer<'_> = RpcServer::new();
        server.register_method("echo", &EchoHandler);
        let (mut stream1, mut stream2) = MemoryPipe::new();
        tokio::spawn(async move {
            server.serve(&mut stream2).await.unwrap();
        });
        let request: RpcRequest<'_, ()> = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            id: Some(1),
            method: "echo",
            params: None,
        };
        let mut request_json = [0u8; 256];
        let request_len = serde_json_core::to_slice(&request, &mut request_json).unwrap();
        let request_message = format!(
            "Content-Length: {}\r\n\r\n{}",
            request_len,
            core::str::from_utf8(&request_json[..request_len]).unwrap()
        );
        stream1.write_all(request_message.as_bytes()).await.unwrap();
        let mut response_buffer = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let response_len = stream1.read(&mut response_buffer).await.unwrap();
        let response = core::str::from_utf8(&response_buffer[..response_len]).unwrap();
        assert_eq!(
            response,
            "Content-Length: 51\r\n\r\n{\"jsonrpc\":\"2.0\",\"id\":1,\"error\":null,\"result\":null}"
        );
    }

    #[tokio::test]
    async fn test_notify() {
        let server: Arc<RpcServer<'_>> = Arc::new(RpcServer::new());
        let server_clone = Arc::clone(&server);
        let (mut stream1, mut stream2) = MemoryPipe::new();
        tokio::spawn(async move {
            server_clone.serve(&mut stream2).await.unwrap();
        });
        // Give the serving task time to subscribe before publishing.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let notification: RpcRequest<'_, ()> = RpcRequest {
            jsonrpc: JSONRPC_VERSION,
            method: "notify",
            id: None,
            params: None,
        };
        let mut notification_json = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let notification_len =
            serde_json_core::to_slice(&notification, &mut notification_json).unwrap();
        server
            .notify(&notification_json[..notification_len])
            .await
            .unwrap();
        let mut notification_json = [0u8; DEFAULT_MAX_MESSAGE_LEN];
        let notification_len = stream1.read(&mut notification_json).await.unwrap();
        let notification_json =
            core::str::from_utf8(&notification_json[..notification_len]).unwrap();
        assert_eq!(
            notification_json,
            "Content-Length: 59\r\n\r\n{\"jsonrpc\":\"2.0\",\"id\":null,\"method\":\"notify\",\"params\":null}",
        );
    }

    #[test]
    fn test_framing_helpers() {
        let mut message = *b"content-LENGTH: 12\r\nX-Other: 1\r\n\r\n{}";
        let headers_len = DefaultServer::parse_headers(&message).unwrap();
        assert_eq!(headers_len, 34);
        assert_eq!(
            DefaultServer::parse_content_length(&mut message[..headers_len]),
            Ok(12)
        );
        assert_eq!(DefaultServer::parse_headers(b"Content-Length: 2\r\n"), None);
    }

    #[derive(Debug)]
    struct EchoHandler;

    impl RpcHandler for EchoHandler {
        fn handle<'a>(
            &self,
            id: Option<u64>,
            _request_json: &'a [u8],
            response_json: &'a mut [u8],
        ) -> StackFuture<'a, Result<usize, RpcError>, DEFAULT_STACK_SIZE> {
            StackFuture::from(async move {
                let response: RpcResponse<'static, ()> = RpcResponse {
                    jsonrpc: JSONRPC_VERSION,
                    error: None,
                    result: None,
                    id,
                };
                Ok(serde_json_core::to_slice(&response, response_json).unwrap())
            })
        }
    }
}