use async_trait::async_trait;
use futures::Stream;
use serde::de::{self, MapAccess, Visitor};
use serde::ser::SerializeMap;
use serde::{Deserialize, Serialize};
use std::fmt;
use std::pin::Pin;
use crate::Error;
use crate::protocol::{Notification, Request, Response};
/// A single JSON-RPC 2.0 message exchanged over a [`Transport`].
///
/// Incoming objects are classified into a variant by which members they carry
/// (see the `Deserialize` impl); outgoing messages are serialized as one map.
#[derive(Clone, Debug)]
pub enum Message {
    /// An object carrying both an `id` and a `method`.
    Request(Request),
    /// An object carrying an `id` plus a `result` and/or an `error`.
    Response(Response),
    /// An object carrying a `method` but no `id`.
    Notification(Notification),
}
impl Message {
fn message_type(&self) -> &'static str {
match self {
Message::Request(_) => "request",
Message::Response(_) => "response",
Message::Notification(_) => "notification",
}
}
}
struct MessageVisitor;
impl<'de> Visitor<'de> for MessageVisitor {
type Value = Message;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a valid JSON-RPC 2.0 message")
}
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
where
M: MapAccess<'de>,
{
tracing::debug!("Attempting to deserialize message...");
let mut obj = serde_json::Map::new();
while let Some(key) = map.next_key::<String>()? {
let value = map.next_value()?;
obj.insert(key, value);
}
let value = serde_json::Value::Object(obj);
if let Some(_id_val) = value.get("id") {
if value.get("method").is_some() {
tracing::debug!("Deserializing as request...");
tracing::debug!("Value: {:?}", value);
Ok(Message::Request(
Request::deserialize(value).map_err(de::Error::custom)?,
))
} else if value.get("result").is_some() || value.get("error").is_some() {
tracing::debug!("Deserializing as response...");
tracing::debug!("Value: {:?}", value);
Ok(Message::Response(
Response::deserialize(value).map_err(de::Error::custom)?,
))
} else {
Err(de::Error::custom(
"invalid message: 'id' present without 'method' or 'result/error'",
))
}
} else if value.get("method").is_some() {
tracing::debug!("Deserializing as notification...");
tracing::debug!("Value: {:?}", value);
Ok(Message::Notification(
Notification::deserialize(value).map_err(de::Error::custom)?,
))
} else {
Err(de::Error::custom(
"invalid message: missing 'id' and 'method'",
))
}
}
}
impl<'de> Deserialize<'de> for Message {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
deserializer.deserialize_map(MessageVisitor)
}
}
impl Serialize for Message {
    /// Serializes the message as a single flat map of JSON-RPC 2.0 members.
    ///
    /// Optional members (`params`, `result`, `error`) are omitted when `None`,
    /// except that a response carrying neither `result` nor `error` is written
    /// with `"result": null`, so the output is still a valid JSON-RPC 2.0
    /// success response that this module's `Deserialize` impl accepts.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut map = serializer.serialize_map(None)?;
        match self {
            Message::Request(req) => {
                map.serialize_entry("jsonrpc", &req.jsonrpc)?;
                map.serialize_entry("method", &req.method)?;
                if let Some(ref params) = req.params {
                    map.serialize_entry("params", params)?;
                }
                map.serialize_entry("id", &req.id)?;
            }
            Message::Response(resp) => {
                map.serialize_entry("jsonrpc", &resp.jsonrpc)?;
                map.serialize_entry("id", &resp.id)?;
                match resp.result {
                    Some(ref result) => map.serialize_entry("result", result)?,
                    // serde's standard `Option` handling turns `"result": null`
                    // into `None` (assuming `Response` derives `Deserialize` with
                    // a plain `Option` field — TODO confirm). Omitting the member
                    // here would yield an object with neither `result` nor
                    // `error`, which JSON-RPC 2.0 forbids and the visitor in this
                    // module rejects. `()` serializes as `null`.
                    None if resp.error.is_none() => map.serialize_entry("result", &())?,
                    None => {}
                }
                if let Some(ref error) = resp.error {
                    map.serialize_entry("error", error)?;
                }
            }
            Message::Notification(notif) => {
                // `&notif`, not the mis-encoded `¬if` that broke compilation.
                map.serialize_entry("jsonrpc", &notif.jsonrpc)?;
                map.serialize_entry("method", &notif.method)?;
                if let Some(ref params) = notif.params {
                    map.serialize_entry("params", params)?;
                }
            }
        }
        map.end()
    }
}
/// An asynchronous channel over which JSON-RPC [`Message`]s are sent and received.
///
/// The `Send + Sync + 'static` bounds allow a transport to be shared between
/// threads and moved into spawned tasks.
#[async_trait]
pub trait Transport: Send + Sync + 'static {
    /// Sends `message` over the transport.
    async fn send(&self, message: Message) -> Result<(), Error>;

    /// Closes the transport.
    ///
    /// NOTE(review): whether `send`/`receive` may still be used afterwards is
    /// up to each implementor — confirm against the implementations.
    async fn close(&self) -> Result<(), Error>;

    /// Returns a boxed stream of incoming items, each either a decoded
    /// [`Message`] or an [`Error`].
    fn receive(&self) -> Pin<Box<dyn Stream<Item = Result<Message, Error>> + Send>>;
}
pub mod stdio;