use std::io::{self, Read, Write};
use std::pin::pin;
use std::process::Stdio;
use bifrostlink::Port;
use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt as _, AsyncWrite, AsyncWriteExt as _};
use tokio::process::{ChildStdout, Command};
use tokio::task::spawn_blocking;
use tokio::{join, spawn};
use tracing::{debug, error};
// Windows-only binding to the C runtime's `_setmode`, linked from `msvcrt`.
// It is used below to change the translation mode of a file descriptor.
#[cfg(target_os = "windows")]
#[link(name = "msvcrt")]
extern "C" {
/// Sets the translation mode of `fd` to `mode`.
///
/// Callers in this file treat a return value of `-1` as failure and any other value as the
/// previously active mode (see `set_mode`).
fn _setmode(fd: i32, mode: i32) -> i32;
}
/// Guard returned by `set_mode`; when dropped it calls `_setmode(fd, mode)` to put the
/// descriptor back into the mode recorded here.
#[cfg(target_os = "windows")]
struct ResetMode {
/// File descriptor whose mode was changed.
fd: i32,
/// Mode to re-apply on drop (the value `_setmode` returned when the mode was changed).
mode: i32,
}
#[cfg(target_os = "windows")]
impl Drop for ResetMode {
    /// Re-applies the mode stored in this guard to its file descriptor.
    fn drop(&mut self) {
        // SAFETY: `_setmode` is a plain C runtime call taking two integers by value. `fd` and
        // `mode` are the descriptor passed to `set_mode` and the mode `_setmode` returned for
        // it, so they form a combination the runtime has already accepted once.
        //
        // The result is discarded on purpose: `drop` must evaluate to `()` (leaving the call as
        // the tail expression is a type error) and a failure cannot be reported from here.
        let _ = unsafe { _setmode(self.fd, self.mode) };
    }
}
/// Switches `fd` to the translation mode `mode` and returns a guard that restores the
/// previous mode when dropped.
///
/// # Panics
///
/// Panics if `_setmode` reports failure (`-1`), i.e. the descriptor or mode was rejected.
#[cfg(target_os = "windows")]
fn set_mode(fd: i32, mode: i32) -> ResetMode {
    // SAFETY: `_setmode` only receives two integers by value; an invalid descriptor or mode is
    // reported through the `-1` return value checked right below.
    let previous = unsafe { _setmode(fd, mode) };
    assert_ne!(previous, -1, "invalid fd/mode");
    ResetMode {
        fd,
        mode: previous,
    }
}
/// Writes one length-prefixed message to `stdin` and flushes the writer.
///
/// The frame is the payload length as a 4-byte `u32` in **native** byte order, followed by the
/// payload itself. Native order is what `read_bytes_sync` / `read_bytes` decode with
/// `u32::from_ne_bytes`, so the writer has to use the same encoding for frames to round-trip.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] if `msg` is longer than `u32::MAX` bytes, or any
/// error reported by the underlying writer while writing or flushing.
fn write_bytes_sync<W: Write>(mut stdin: W, msg: Bytes) -> io::Result<()> {
    let len = u32::try_from(msg.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "message shouldn't be larger than 4GB",
        )
    })?;
    stdin.write_all(&len.to_ne_bytes())?;
    stdin.write_all(&msg)?;
    // Buffered writers (notably the line-buffered process stdout) would otherwise hold the
    // frame back until a newline or a full buffer, stalling the peer waiting for the message.
    stdin.flush()
}
/// Asynchronously writes one length-prefixed message to `stdin` and flushes the writer.
///
/// The frame is the payload length as a 4-byte `u32` in **native** byte order, followed by the
/// payload itself. Native order is what `read_bytes` / `read_bytes_sync` decode with
/// `u32::from_ne_bytes`, so the writer has to use the same encoding for frames to round-trip.
///
/// # Errors
///
/// Returns [`io::ErrorKind::InvalidInput`] if `msg` is longer than `u32::MAX` bytes, or any
/// error reported by the underlying writer while writing or flushing.
async fn write_bytes<W: AsyncWrite>(stdin: W, msg: Bytes) -> io::Result<()> {
    // Pin locally so the `AsyncWriteExt` helpers can be used without requiring `W: Unpin`.
    let mut stdin = pin!(stdin);
    let len = u32::try_from(msg.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "message shouldn't be larger than 4GB",
        )
    })?;
    stdin.write_all(&len.to_ne_bytes()).await?;
    stdin.write_all(&msg).await?;
    // Make sure a buffering writer actually hands the complete frame to the peer.
    stdin.flush().await
}
/// Reads one length-prefixed message from `stdout`, blocking until it is complete.
///
/// The frame starts with a 4-byte `u32` length in native byte order; exactly that many payload
/// bytes are then read and returned.
///
/// # Errors
///
/// Propagates any error from `read_exact`, including the one raised when the stream ends
/// before a full header or payload was received.
fn read_bytes_sync<R: Read>(mut stdout: R) -> io::Result<BytesMut> {
    let mut header = [0u8; 4];
    stdout.read_exact(&mut header)?;

    // Widening `u32` to `usize`, exactly as the length was declared on the wire.
    let payload_len = u32::from_ne_bytes(header) as usize;

    let mut payload = BytesMut::zeroed(payload_len);
    stdout.read_exact(&mut payload)?;
    Ok(payload)
}
/// Asynchronously reads one length-prefixed message from the child's `stdout`.
///
/// The frame starts with a 4-byte `u32` length in native byte order; exactly that many payload
/// bytes are then read and returned.
///
/// # Errors
///
/// Propagates any error from `read_exact`, including the one raised when the stream ends
/// before a full header or payload was received.
async fn read_bytes(stdout: &mut ChildStdout) -> io::Result<BytesMut> {
    let mut header = [0u8; 4];
    stdout.read_exact(&mut header).await?;

    // Widening `u32` to `usize`, exactly as the length was declared on the wire.
    let payload_len = u32::from_ne_bytes(header) as usize;

    let mut payload = BytesMut::zeroed(payload_len);
    stdout.read_exact(&mut payload).await?;
    Ok(payload)
}
pub fn native_messaging_port() -> Port {
Port::new(|mut rx, tx| async move {
let stdout_printer = spawn_blocking(move || {
let mut stdout = std::io::stdout().lock();
#[cfg(target_os = "windows")]
let _stdout_guard = set_mode(0, 0x8000);
while let Some(out) = rx.blocking_recv() {
if let Err(e) = write_bytes_sync(&mut stdout, out) {
error!("stdout write failed: {e}");
break;
}
}
error!("output stream end")
});
let stdin_reader = spawn_blocking(move || {
let mut stdin = std::io::stdin().lock();
#[cfg(target_os = "windows")]
let _stdin_guard = set_mode(1, 0x8000);
loop {
match read_bytes_sync(&mut stdin) {
Ok(buf) => {
if tx.send(buf.freeze()).is_err() {
break;
}
}
Err(e) => {
error!("stdin read failed: {e}");
break;
}
}
}
error!("input stream end");
});
let (a, b) = join!(stdout_printer, stdin_reader);
a.unwrap();
b.unwrap();
})
}
/// Spawns `cmd` with piped stdin/stdout (stderr inherited) and returns a [`Port`] that talks
/// to it using length-prefixed messages.
///
/// Two tasks are spawned when the port runs:
/// * one takes outgoing messages from `rx` and writes them to the child's stdin;
/// * one decodes messages from the child's stdout and forwards them through `tx`.
///
/// Each task stops on the first I/O error or when its channel is closed. The child handle
/// itself is a local of this function and is dropped when it returns; only the two pipe ends
/// are kept.
///
/// # Errors
///
/// Returns the error from `Command::spawn` if the process cannot be started.
///
/// # Panics
///
/// The port future panics if either of its tasks panicked.
pub async fn start_native_messaging_extension(mut cmd: Command) -> io::Result<Port> {
    cmd.stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::inherit());

    let mut child = cmd.spawn()?;
    // Both streams were configured as piped just above, so they are present after a spawn.
    let mut stdout = child.stdout.take().expect("stdout piped");
    let mut stdin = child.stdin.take().expect("stdin piped");

    let port = Port::new(|mut rx, tx| async move {
        // Outgoing direction: channel -> child's stdin.
        let stdin_printer = spawn(async move {
            loop {
                let Some(outgoing) = rx.recv().await else {
                    break;
                };
                if let Err(e) = write_bytes(&mut stdin, outgoing).await {
                    error!("stdin write failed: {e}");
                    break;
                }
            }
            debug!("output stream end");
        });

        // Incoming direction: child's stdout -> channel.
        let stdout_reader = spawn(async move {
            loop {
                let frame = match read_bytes(&mut stdout).await {
                    Ok(frame) => frame,
                    Err(e) => {
                        error!("stdout read failed: {e}");
                        break;
                    }
                };
                // A send error means the receiving side is gone; stop reading.
                if tx.send(frame.freeze()).is_err() {
                    break;
                }
            }
            debug!("input stream end");
        });

        // Wait for both directions, then surface a panic from either task.
        let (writer_result, reader_result) = join!(stdin_printer, stdout_reader);
        writer_result.unwrap();
        reader_result.unwrap();
    });
    Ok(port)
}