use serde::{de::DeserializeOwned, Serialize};
use std::{
error::Error as StdError,
fmt,
io::{self, Read, Write},
};
/// Largest serialized JSON payload, in bytes, that this host will send to the browser (1 MiB).
///
/// [`encode_message`] rejects anything larger with [`NmError::OutgoingTooLarge`].
/// NOTE(review): presumably mirrors Chrome's host-to-browser native-messaging limit; confirm.
pub const MAX_TO_BROWSER: usize = 1024 * 1024;
/// Hard upper bound, in bytes, on a payload accepted from the browser (64 MiB).
///
/// [`decode_message_opt`] clamps any caller-supplied limit to this value.
pub const MAX_FROM_BROWSER: usize = 64 * 1024 * 1024;
/// Errors produced by the native-messaging framing and I/O helpers in this module.
#[derive(Debug)]
pub enum NmError {
    /// The peer is gone. Either stdin ended before a new message began, or the outgoing
    /// writer channel is closed.
    Disconnected,
    /// A serialized outgoing payload exceeded the permitted size.
    OutgoingTooLarge {
        /// Serialized payload length in bytes.
        len: usize,
        /// Maximum permitted length in bytes.
        max: usize,
    },
    /// An incoming length header declared more bytes than the effective limit.
    IncomingTooLarge {
        /// Length declared by the header, in bytes.
        len: usize,
        /// Effective limit that was exceeded, in bytes.
        max: usize,
    },
    /// An incoming payload was not valid UTF-8.
    IncomingNotUtf8(std::string::FromUtf8Error),
    /// `serde_json` could not serialize an outgoing value.
    SerializeJson(serde_json::Error),
    /// `serde_json` could not parse an incoming payload into the requested type.
    DeserializeJson(serde_json::Error),
    /// Joining a Tokio task failed.
    /// NOTE(review): nothing visible in this module constructs this variant; confirm it is used.
    TokioJoin(tokio::task::JoinError),
    /// A blocking helper task finished without delivering its result over its oneshot channel.
    OneshotRecv(tokio::sync::oneshot::error::RecvError),
    /// An underlying I/O operation failed.
    Io(std::io::Error),
}
impl fmt::Display for NmError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use NmError::*;
match self {
Disconnected => write!(f, "native messaging disconnected (stdin closed)"),
OutgoingTooLarge { len, max } => write!(
f,
"outgoing native message is {len} bytes (max {max}); \
reduce size (chunk/compress) before sending"
),
IncomingTooLarge { len, max } => write!(
f,
"incoming native message is {len} bytes (max {max}); \
extension must send smaller messages (chunk/compress)"
),
IncomingNotUtf8(e) => write!(f, "incoming native message is not valid UTF-8: {e}"),
SerializeJson(e) => write!(f, "failed to serialize JSON: {e}"),
DeserializeJson(e) => write!(f, "failed to deserialize JSON: {e}"),
TokioJoin(e) => write!(f, "internal task join error: {e}"),
OneshotRecv(e) => write!(f, "internal oneshot receive error: {e}"),
Io(e) => write!(f, "I/O error: {e}"),
}
}
}
impl StdError for NmError {
fn source(&self) -> Option<&(dyn StdError + 'static)> {
use NmError::*;
match self {
IncomingNotUtf8(e) => Some(e),
SerializeJson(e) => Some(e),
DeserializeJson(e) => Some(e),
TokioJoin(e) => Some(e),
OneshotRecv(e) => Some(e),
Io(e) => Some(e),
_ => None,
}
}
}
impl From<io::Error> for NmError {
fn from(e: io::Error) -> Self {
NmError::Io(e)
}
}
/// Serializes `msg` to JSON and frames it for the wire.
///
/// The frame is a 4-byte payload length in native byte order, followed by the JSON bytes.
///
/// # Errors
///
/// - [`NmError::SerializeJson`] if `serde_json` cannot serialize `msg`.
/// - [`NmError::OutgoingTooLarge`] if the JSON is longer than [`MAX_TO_BROWSER`] bytes.
pub fn encode_message<T: Serialize>(msg: &T) -> Result<Vec<u8>, NmError> {
    let payload = serde_json::to_vec(msg).map_err(NmError::SerializeJson)?;
    let len = payload.len();
    if len > MAX_TO_BROWSER {
        return Err(NmError::OutgoingTooLarge {
            len,
            max: MAX_TO_BROWSER,
        });
    }
    // Cannot truncate: `len` is at most MAX_TO_BROWSER, far below u32::MAX.
    let header = (len as u32).to_ne_bytes();
    let mut frame = Vec::with_capacity(header.len() + len);
    frame.extend_from_slice(&header);
    frame.extend_from_slice(&payload);
    Ok(frame)
}
pub fn decode_message_opt<R: Read>(
reader: &mut R,
max_size: usize,
) -> Result<Option<String>, NmError> {
let cap = max_size.min(MAX_FROM_BROWSER);
let mut len_buf = [0u8; 4];
match reader.read_exact(&mut len_buf) {
Ok(()) => {}
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None),
Err(e) => return Err(NmError::Io(e)),
}
let len = u32::from_ne_bytes(len_buf) as usize;
if len > cap {
return Err(NmError::IncomingTooLarge { len, max: cap });
}
let mut buf = vec![0u8; len];
reader.read_exact(&mut buf).map_err(NmError::Io)?;
String::from_utf8(buf)
.map(Some)
.map_err(NmError::IncomingNotUtf8)
}
/// Reads one native message from `reader`, treating end-of-stream as an error.
///
/// # Errors
///
/// - [`NmError::Disconnected`] if the stream ends cleanly before a new message.
/// - Any error produced by [`decode_message_opt`].
pub fn decode_message<R: Read>(reader: &mut R, max_size: usize) -> Result<String, NmError> {
    match decode_message_opt(reader, max_size)? {
        Some(text) => Ok(text),
        None => Err(NmError::Disconnected),
    }
}
/// Reads one native message from `reader` and deserializes its JSON into `T`.
///
/// # Errors
///
/// - Any error produced by [`decode_message`], including [`NmError::Disconnected`].
/// - [`NmError::DeserializeJson`] if the payload is not valid JSON for `T`.
pub fn recv_json<T: DeserializeOwned, R: Read>(
    reader: &mut R,
    max_size: usize,
) -> Result<T, NmError> {
    decode_message(reader, max_size)
        .and_then(|text| serde_json::from_str::<T>(&text).map_err(NmError::DeserializeJson))
}
/// Writes an already-encoded frame to `writer` and then flushes it.
///
/// The flush is attempted only if the write succeeded.
///
/// # Errors
///
/// [`NmError::Io`] if either the write or the flush fails.
pub fn send_frame<W: Write>(writer: &mut W, frame: &[u8]) -> Result<(), NmError> {
    writer
        .write_all(frame)
        .and_then(|()| writer.flush())
        .map_err(NmError::Io)
}
/// Encodes `msg` with [`encode_message`] and writes the frame to `writer`.
///
/// Nothing is written if encoding fails.
///
/// # Errors
///
/// Any error from [`encode_message`] or [`send_frame`].
pub fn send_json<T: Serialize, W: Write>(writer: &mut W, msg: &T) -> Result<(), NmError> {
    encode_message(msg).and_then(|frame| send_frame(writer, &frame))
}
/// Spawns a blocking task that decodes messages from stdin and forwards them over a channel.
///
/// The channel holds up to 32 items. The final item delivered is always an `Err`:
/// [`NmError::Disconnected`] on a clean end of stdin, or the decode error that stopped the
/// task. The task also stops once the receiver is dropped, but only after its current
/// blocking read returns.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime (via `tokio::task::spawn_blocking`).
pub fn spawn_reader(max_size: usize) -> tokio::sync::mpsc::Receiver<Result<String, NmError>> {
    let (sender, receiver) = tokio::sync::mpsc::channel::<Result<String, NmError>>(32);
    tokio::task::spawn_blocking(move || {
        let mut input = io::stdin();
        loop {
            let item = match decode_message_opt(&mut input, max_size) {
                Ok(Some(text)) => Ok(text),
                Ok(None) => Err(NmError::Disconnected),
                Err(err) => Err(err),
            };
            // Any error ends the stream; a failed send means nobody is listening any more.
            let terminal = item.is_err();
            if sender.blocking_send(item).is_err() || terminal {
                break;
            }
        }
    });
    receiver
}
/// Spawns a blocking task that writes each received frame to stdout, flushing after every
/// frame.
///
/// Frames must already be encoded, e.g. by [`encode_message`]. On the first write or flush
/// error the task stops without reporting the error. Its receiver is then dropped, so
/// frames still queued are discarded and later sends on the returned sender fail.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime (via `tokio::task::spawn_blocking`).
pub fn spawn_writer() -> tokio::sync::mpsc::Sender<Vec<u8>> {
    let (sender, mut receiver) = tokio::sync::mpsc::channel::<Vec<u8>>(32);
    tokio::task::spawn_blocking(move || {
        let mut out = io::stdout();
        while let Some(bytes) = receiver.blocking_recv() {
            // Flush only runs if the write succeeded.
            let written = out.write_all(&bytes).and_then(|()| out.flush());
            if written.is_err() {
                break;
            }
        }
    });
    sender
}
/// Reads a single native message from stdin on Tokio's blocking pool.
///
/// Dropping the returned future does not stop the blocking read. A message read after that
/// point is discarded. Using this alongside [`spawn_reader`] or [`event_loop`] means two
/// tasks read from stdin at once.
///
/// # Errors
///
/// - Any error from [`decode_message`], including [`NmError::Disconnected`].
/// - [`NmError::OneshotRecv`] if the blocking task ends without replying (for example,
///   if it panicked).
///
/// # Panics
///
/// Panics if called outside a Tokio runtime (via `tokio::task::spawn_blocking`).
pub async fn get_message() -> Result<String, NmError> {
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel::<Result<String, NmError>>();
    tokio::task::spawn_blocking(move || {
        let outcome = decode_message(&mut io::stdin(), MAX_FROM_BROWSER);
        // If the caller stopped waiting, the result is simply dropped.
        let _ = reply_tx.send(outcome);
    });
    match reply_rx.await {
        Ok(outcome) => outcome,
        Err(recv_err) => Err(NmError::OneshotRecv(recv_err)),
    }
}
/// Encodes `msg` and writes it to stdout from Tokio's blocking pool.
///
/// Encoding happens before any task is spawned, so size and serialization errors are
/// returned without touching stdout.
///
/// # Errors
///
/// - Any error from [`encode_message`] or [`send_frame`].
/// - [`NmError::OneshotRecv`] if the blocking task ends without replying.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime (via `tokio::task::spawn_blocking`).
pub async fn send_message<T: Serialize>(msg: &T) -> Result<(), NmError> {
    let frame = encode_message(msg)?;
    let (done_tx, done_rx) = tokio::sync::oneshot::channel::<Result<(), NmError>>();
    tokio::task::spawn_blocking(move || {
        let outcome = send_frame(&mut io::stdout(), &frame);
        let _ = done_tx.send(outcome);
    });
    match done_rx.await {
        Ok(outcome) => outcome,
        Err(recv_err) => Err(NmError::OneshotRecv(recv_err)),
    }
}
/// Cloneable handle that queues outgoing messages for the stdout writer task.
///
/// [`event_loop`] passes a clone to its handler for each incoming message. All clones feed
/// the same writer channel.
#[derive(Clone, Debug)]
pub struct Sender {
    /// Channel into the writer task; each item is one fully-encoded frame.
    writer: tokio::sync::mpsc::Sender<Vec<u8>>,
}
impl Sender {
    /// Encodes `msg` and queues the frame for the writer task.
    ///
    /// `Ok(())` means the frame was queued, not that it reached stdout. Write failures
    /// inside the writer task are not reported through this call.
    ///
    /// # Errors
    ///
    /// - Any error from [`encode_message`].
    /// - [`NmError::Disconnected`] if the writer task's channel is closed.
    pub async fn send<T: Serialize>(&self, msg: &T) -> Result<(), NmError> {
        let frame = encode_message(msg)?;
        match self.writer.send(frame).await {
            Ok(()) => Ok(()),
            Err(_closed) => Err(NmError::Disconnected),
        }
    }
}
/// Runs the receive loop: reads messages from stdin and hands each one to `handler`.
///
/// A stdin reader task and a stdout writer task are spawned. Each handler call gets the
/// message text and a [`Sender`] for replies. Handlers run one at a time; each is awaited
/// before the next message is taken.
///
/// Returns `Ok(())` when stdin closes cleanly.
///
/// NOTE(review): if this returns early on an error, the stdin reader task stays blocked in
/// its read until stdin yields data or closes.
///
/// # Errors
///
/// The first non-disconnect error from the reader, or the first error returned by `handler`.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime (the spawned tasks use `spawn_blocking`).
pub async fn event_loop<F, Fut>(mut handler: F) -> Result<(), NmError>
where
    F: FnMut(String, Sender) -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<(), NmError>> + Send,
{
    let mut incoming = spawn_reader(MAX_FROM_BROWSER);
    let sender = Sender {
        writer: spawn_writer(),
    };
    while let Some(item) = incoming.recv().await {
        let msg = match item {
            Ok(msg) => msg,
            // A clean end of stdin is the normal way for the loop to finish.
            Err(NmError::Disconnected) => return Ok(()),
            Err(other) => return Err(other),
        };
        handler(msg, sender.clone()).await?;
    }
    Ok(())
}