use futures::{SinkExt, StreamExt};
use std::path::Path;
use tokio::io::{BufReader, BufWriter};
use tokio::net::UnixStream;
use tokio_util::bytes::{Buf, BytesMut};
use tokio_util::codec::{Decoder, Encoder, FramedRead, FramedWrite, LinesCodec};
/// Bridges this process's stdio to the Unix socket at `socket_path`.
///
/// Messages read from stdin use `Content-Length` framing (see `ContentLengthCodec`)
/// and are forwarded to the socket one per line (`LinesCodec`). Lines read from the
/// socket are written back to stdout with `Content-Length` framing.
///
/// This function does not return on success. The process exits with status 0 as
/// soon as either stdin or the socket reaches end-of-stream.
///
/// # Errors
/// Returns an error if the socket cannot be connected, a message on stdin cannot be
/// decoded, or a message cannot be written to the socket. Failures in the
/// socket-to-stdout direction, which runs in a detached task, are reported on stderr
/// and terminate the process with status 1.
pub async fn connection(socket_path: &Path) -> anyhow::Result<()> {
    let stream = UnixStream::connect(socket_path).await?;
    let (socket_read, socket_write) = stream.into_split();
    let mut socket_read = FramedRead::new(socket_read, LinesCodec::new());
    let mut socket_write = FramedWrite::new(socket_write, LinesCodec::new());
    let mut stdin = FramedRead::new(BufReader::new(tokio::io::stdin()), ContentLengthCodec);
    let mut stdout = FramedWrite::new(BufWriter::new(tokio::io::stdout()), ContentLengthCodec);

    // Socket -> stdout. This task is detached, so it cannot hand errors back to the
    // caller. A panic here would only end the task and leave the process running
    // one-way, so failures are reported and the whole process is terminated instead.
    tokio::spawn(async move {
        while let Some(line) = socket_read.next().await {
            let line = match line {
                Ok(line) => line,
                Err(err) => {
                    eprintln!("Failed to read from socket: {err}");
                    std::process::exit(1);
                }
            };
            // `SinkExt::send` flushes after writing, so each message reaches stdout
            // promptly.
            if let Err(err) = stdout.send(line).await {
                eprintln!("Failed to write to stdout: {err}");
                std::process::exit(1);
            }
        }
        // The socket closed cleanly.
        std::process::exit(0);
    });

    // Stdin -> socket.
    // NOTE(review): `LinesCodec` delimits on '\n', so a body containing a raw newline
    // (e.g. pretty-printed JSON) would reach the socket as several lines. Confirm
    // that clients always send single-line bodies.
    while let Some(message) = stdin.next().await {
        let message =
            message.map_err(|err| err.context("Failed to decode message from stdin"))?;
        socket_write.send(message).await?;
    }
    // Stdin closed cleanly.
    std::process::exit(0);
}
/// Stateless codec for `Content-Length: <bytes>\r\n\r\n<body>` framed messages,
/// as used on this process's stdin and stdout.
#[derive(Debug, Clone, Copy, Default)]
struct ContentLengthCodec;
impl Encoder<String> for ContentLengthCodec {
    type Error = std::io::Error;

    /// Appends `item` to `dst` as a single frame: a `Content-Length` header giving
    /// the body size in bytes (not characters), a blank line, then the body itself.
    /// This never fails.
    fn encode(&mut self, item: String, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = item.as_bytes();
        let content_length = body.len();
        let header = format!("Content-Length: {content_length}\r\n\r\n");
        for chunk in [header.as_bytes(), body] {
            dst.extend_from_slice(chunk);
        }
        Ok(())
    }
}
impl Decoder for ContentLengthCodec {
    type Item = String;
    type Error = anyhow::Error;

    /// Decodes one `Content-Length`-framed message from `src`.
    ///
    /// Bytes before the `Content-Length` header name are skipped. The name is
    /// matched ASCII-case-insensitively, and whitespace around its value is ignored.
    /// The header block ends at the first blank line, terminated by either
    /// `\r\n\r\n` or `\n\n`. Any other headers in it (e.g. `Content-Type`) are
    /// ignored.
    ///
    /// Returns `Ok(None)` until the header block and the whole body are buffered.
    /// On success, the header and body bytes are consumed from `src` and the body
    /// is returned as a `String`.
    ///
    /// # Errors
    /// Fails if the `Content-Length` value is not a decimal `usize`, if the frame
    /// size overflows `usize`, or if the body is not valid UTF-8.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_NAME: &[u8] = b"Content-Length:";

        let Some(name_start) = src
            .windows(HEADER_NAME.len())
            .position(|window| window.eq_ignore_ascii_case(HEADER_NAME))
        else {
            return Ok(None);
        };
        let value_start = name_start + HEADER_NAME.len();

        let Some((terminator_offset, terminator_len)) = find_header_end(&src[value_start..])
        else {
            return Ok(None);
        };
        let block_end = value_start + terminator_offset;

        // The value runs only to the end of its own line, so headers that follow
        // `Content-Length` (e.g. `Content-Type`) do not leak into it.
        let value = &src[value_start..block_end];
        let value_end = value
            .iter()
            .position(|&byte| byte == b'\r' || byte == b'\n')
            .unwrap_or(value.len());
        let content_length: usize = std::str::from_utf8(&value[..value_end])?.trim().parse()?;

        let content_start = block_end + terminator_len;
        let frame_end = content_start
            .checked_add(content_length)
            .ok_or_else(|| anyhow::anyhow!("Content-Length {content_length} is too large"))?;
        if src.len() < frame_end {
            // `reserve` takes the *additional* capacity wanted, so ask only for the
            // bytes still missing from this frame.
            src.reserve(frame_end - src.len());
            return Ok(None);
        }

        src.advance(content_start);
        let content = src.split_to(content_length);
        Ok(Some(std::str::from_utf8(&content)?.to_string()))
    }
}

/// Finds the blank line that ends a header block at the start of `buf`.
///
/// Returns `(offset, len)` of the earliest `\r\n\r\n` or `\n\n` terminator, or
/// `None` if neither has been buffered yet. The scan stops at the first match, so
/// terminator-like bytes in a following body are never considered.
fn find_header_end(buf: &[u8]) -> Option<(usize, usize)> {
    (0..buf.len()).find_map(|offset| {
        let rest = &buf[offset..];
        if rest.starts_with(b"\r\n\r\n") {
            Some((offset, 4))
        } else if rest.starts_with(b"\n\n") {
            Some((offset, 2))
        } else {
            None
        }
    })
}