use std::io::{self, ErrorKind, Read, Write};
use std::net::{Shutdown, SocketAddr, TcpStream};
use std::sync::{Arc, Mutex};
use log::debug;
use malloc_size_of_derive::MallocSizeOf;
use serde::Serialize;
use serde_json::{self, Value, json};
use crate::actor::ActorError;
/// Description of an actor type: its category, its type name and the methods it lists.
///
/// Serialized with camelCase keys (`category`, `typeName`, `methods`); field order here is the
/// key order in the emitted JSON.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub(crate) struct ActorDescription {
    /// Category string of the actor.
    pub category: &'static str,
    /// Name of the actor type; serialized as `typeName`.
    pub type_name: &'static str,
    /// Methods listed for this actor type.
    pub methods: Vec<Method>,
}
/// One method entry of an [`ActorDescription`].
#[derive(Serialize)]
pub(crate) struct Method {
    /// Method name.
    pub name: &'static str,
    /// Arbitrary JSON describing the request. NOTE(review): presumably a request template or
    /// schema for clients — confirm against the code that builds these.
    pub request: Value,
    /// Arbitrary JSON describing the response; same caveat as `request`.
    pub response: Value,
}
/// A bidirectional stream of JSON packets.
///
/// The implementations in this module frame each packet as `<byte length>:<json>`.
pub trait JsonPacketStream {
    /// Serializes `message` to JSON and writes it as a single packet.
    fn write_json_packet<T: Serialize>(&mut self, message: &T) -> Result<(), ActorError>;
    /// Reads and parses the next packet. `Ok(None)` signals that the peer closed the
    /// connection; `Err` carries a human-readable description of what went wrong.
    fn read_json_packet(&mut self) -> Result<Option<Value>, String>;
}
impl JsonPacketStream for TcpStream {
    /// Serializes `message` to JSON and writes it as one `<byte length>:<json>` packet.
    ///
    /// # Errors
    /// Returns [`ActorError::Internal`] if serialization or the socket write fails.
    fn write_json_packet<T: Serialize>(&mut self, message: &T) -> Result<(), ActorError> {
        let s = serde_json::to_string(message).map_err(|_| ActorError::Internal)?;
        debug!("<- {}", s);
        // `s.len()` is the UTF-8 byte length, which is what the length prefix must count.
        // Framing into one buffer lets the packet go out through a single `write_all` instead
        // of the several small writes `write!` issues on an unbuffered socket.
        let packet = format!("{}:{}", s.len(), s);
        self.write_all(packet.as_bytes())
            .map_err(|_| ActorError::Internal)
    }

    /// Reads one `<byte length>:<json>` packet and parses its JSON payload.
    ///
    /// Returns `Ok(None)` if the peer closes the connection (EOF or connection reset) while the
    /// length prefix is being read.
    ///
    /// # Errors
    /// Returns a description of the problem if the length prefix is not valid UTF-8 or not a
    /// `u64`, if the connection ends before the announced number of payload bytes arrives, if
    /// the payload is not valid UTF-8 or JSON, or on any other I/O error.
    fn read_json_packet(&mut self) -> Result<Option<Value>, String> {
        // Bytes of the decimal length prefix read so far.
        let mut buffer = vec![];
        loop {
            // Read a single byte at a time so nothing past the `:` separator is consumed.
            let mut buf = [0];
            let byte = match self.read(&mut buf) {
                Ok(0) => return Ok(None),
                // `buf` has length 1, so any non-zero count means exactly one byte.
                Ok(_) => buf[0],
                // A signal interrupted the read; it is non-fatal and should be retried.
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(e) if e.kind() == ErrorKind::ConnectionReset => return Ok(None),
                Err(e) => return Err(e.to_string()),
            };
            if byte != b':' {
                buffer.push(byte);
                continue;
            }

            let packet_len_str = String::from_utf8(buffer)
                .map_err(|_| "nonvalid UTF8 in packet length".to_owned())?;
            let packet_len = packet_len_str
                .parse::<u64>()
                .map_err(|_| "packet length missing / not parsable".to_owned())?;

            let mut packet = String::new();
            let received = self
                .take(packet_len)
                .read_to_string(&mut packet)
                .map_err(|e| e.to_string())?;
            // `take` stops silently at EOF, so a short count means the peer hung up mid-packet;
            // report that directly instead of as a JSON parse error on a truncated payload.
            if (received as u64) < packet_len {
                return Err(format!(
                    "connection closed after {} of {} packet bytes",
                    received, packet_len
                ));
            }
            debug!("{}", packet);
            return serde_json::from_str(&packet)
                .map(Some)
                .map_err(|e| e.to_string());
        }
    }
}
/// A devtools client connection with separately locked reading and writing handles.
///
/// Both handles refer to the same socket (`From<TcpStream>` stores a `try_clone` of the stream
/// as `receiver`), so reads and writes do not contend for the same lock. Cloning this type
/// shares both handles through their `Arc`s rather than opening new ones.
#[derive(Clone, MallocSizeOf)]
pub(crate) struct DevtoolsConnection {
    /// Handle used by `read_json_packet`.
    #[conditional_malloc_size_of]
    receiver: Arc<Mutex<TcpStream>>,
    /// Handle used by `write_json_packet`.
    #[conditional_malloc_size_of]
    sender: Arc<Mutex<TcpStream>>,
}
impl From<TcpStream> for DevtoolsConnection {
    /// Wraps `stream` as the writing handle and a duplicate of it as the reading handle.
    ///
    /// # Panics
    /// Panics if the socket handle cannot be duplicated (`TcpStream::try_clone` fails).
    fn from(stream: TcpStream) -> Self {
        // Duplicate first: `stream` itself is moved into the sender below.
        let reader = stream.try_clone().unwrap();
        let receiver = Arc::new(Mutex::new(reader));
        let sender = Arc::new(Mutex::new(stream));
        Self { receiver, sender }
    }
}
impl DevtoolsConnection {
    /// Returns the address of the remote end of the connection.
    ///
    /// Queried through the `sender` handle: `read_json_packet` keeps `receiver` locked for the
    /// whole of a blocking read, so locking `receiver` here could stall until the client sends
    /// something. Both handles refer to the same socket, so the answer is the same.
    pub(crate) fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.sender.lock().unwrap().peer_addr()
    }

    /// Shuts down the read half, write half, or both halves of the underlying socket.
    ///
    /// Goes through the `sender` handle for the same reason as [`Self::peer_addr`]: a reader
    /// blocked while holding `receiver` must not prevent the shutdown that would unblock it.
    /// `shutdown` acts on the socket shared by both handles, so the reader observes it too.
    /// NOTE(review): this can still wait for an in-progress write to finish.
    pub(crate) fn shutdown(&self, how: Shutdown) -> io::Result<()> {
        self.sender.lock().unwrap().shutdown(how)
    }
}
impl JsonPacketStream for DevtoolsConnection {
fn write_json_packet<T: serde::Serialize>(&mut self, message: &T) -> Result<(), ActorError> {
let s = serde_json::to_string(message).map_err(|_| ActorError::Internal)?;
log::debug!("<- {}", s);
let mut stream = self.sender.lock().unwrap();
write!(*stream, "{}:{}", s.len(), s).map_err(|_| ActorError::Internal)
}
fn read_json_packet(&mut self) -> Result<Option<Value>, String> {
let mut buffer = vec![];
let mut stream = self.receiver.lock().unwrap();
loop {
let mut buf = [0];
match (*stream).read(&mut buf) {
Ok(0) => return Ok(None), Ok(1) if buf[0] == b':' => {
let packet_len_str = String::from_utf8(buffer)
.map_err(|_| "nonvalid UTF8 in packet length".to_owned())?;
let packet_len = packet_len_str
.parse::<u64>()
.map_err(|_| "packet length missing / not parsable".to_owned())?;
let mut packet = String::new();
stream
.try_clone()
.unwrap()
.take(packet_len)
.read_to_string(&mut packet)
.map_err(|e| e.to_string())?;
log::debug!("{}", packet);
return serde_json::from_str(&packet)
.map(Some)
.map_err(|e| e.to_string());
},
Ok(1) => buffer.push(buf[0]),
Ok(_) => unreachable!(),
Err(e) if e.kind() == ErrorKind::ConnectionReset => return Ok(None), Err(e) => return Err(e.to_string()),
}
}
}
}
/// One incoming request addressed to `actor_name`, handed to a handler by
/// [`ClientRequest::handle`].
///
/// `handled` points at a flag owned by `handle`; the flag is set when the handler replies
/// (`reply` / `reply_final`) or calls `mark_handled`.
pub(crate) struct ClientRequest<'req, 'handled> {
    /// Connection that replies and other packets are written to.
    stream: DevtoolsConnection,
    /// Name of the actor the request is addressed to; a valid reply carries it in `from`.
    actor_name: &'req str,
    /// Set to `true` once the request has been answered or explicitly marked handled.
    handled: &'handled mut bool,
}
impl ClientRequest<'_, '_> {
    /// Builds a [`ClientRequest`] for `actor_name` on `stream`, runs `handler` on it, and checks
    /// that the handler dealt with the request.
    ///
    /// A request counts as dealt with once the handler called `reply`, `reply_final` or
    /// `mark_handled` on it.
    ///
    /// # Errors
    /// Propagates any error returned by `handler`. If the handler succeeded without marking
    /// the request handled, returns [`ActorError::UnrecognizedPacketType`].
    pub fn handle<'req>(
        stream: DevtoolsConnection,
        actor_name: &'req str,
        handler: impl FnOnce(ClientRequest<'req, '_>) -> Result<(), ActorError>,
    ) -> Result<(), ActorError> {
        let mut handled = false;
        handler(ClientRequest {
            stream,
            actor_name,
            handled: &mut handled,
        })?;
        if !handled {
            return Err(ActorError::UnrecognizedPacketType);
        }
        Ok(())
    }
}
impl<'req> ClientRequest<'req, '_> {
pub fn reply<T: Serialize>(mut self, reply: &T) -> Result<Self, ActorError> {
debug_assert!(self.is_valid_reply(reply), "Message is not a valid reply");
self.stream.write_json_packet(reply)?;
*self.handled = true;
Ok(self)
}
pub fn reply_final<T: Serialize>(self, reply: &T) -> Result<(), ActorError> {
let _stream = self.reply(reply)?;
Ok(())
}
fn is_valid_reply<T: Serialize>(&self, message: &T) -> bool {
let reply = json!(message);
reply.get("from").and_then(|from| from.as_str()) == Some(self.actor_name) &&
reply.get("to").is_none() &&
reply.get("type").is_none()
}
pub fn mark_handled(self) -> Self {
*self.handled = true;
self
}
pub fn stream(&self) -> DevtoolsConnection {
self.stream.clone()
}
}
impl JsonPacketStream for ClientRequest<'_, '_> {
fn write_json_packet<T: Serialize>(&mut self, message: &T) -> Result<(), ActorError> {
debug_assert!(
!self.is_valid_reply(message),
"Replies must use reply() or reply_final()"
);
self.stream.write_json_packet(message)
}
fn read_json_packet(&mut self) -> Result<Option<Value>, String> {
self.stream.read_json_packet()
}
}