1use std::{
2 io,
3 path::PathBuf,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use async_trait::async_trait;
9use futures::future::BoxFuture;
10use tokio::{
11 io::{AsyncRead, AsyncWrite},
12 net::{UnixListener, UnixStream},
13};
14use tracing::debug;
15
16use crate::{Acceptor, PeerAddress, Transport, TransportExt};
17
18use msg_common::async_error;
19
/// Configuration for the [`Ipc`] transport.
///
/// This is a field-less unit struct; it is accepted by `Ipc::new` and stored on the transport.
#[derive(Default, Debug)]
pub struct Config;
22
23#[derive(Debug, Default)]
36pub struct Ipc {
37 #[allow(unused)]
38 config: Config,
39 listener: Option<UnixListener>,
40 path: Option<PathBuf>,
41}
42
43impl Ipc {
44 pub fn new(config: Config) -> Self {
46 Self { config, listener: None, path: None }
47 }
48}
49
50pub struct IpcStream {
52 peer: PathBuf,
53 stream: UnixStream,
54}
55
56impl IpcStream {
57 pub async fn connect(peer: PathBuf) -> io::Result<Self> {
59 let stream = UnixStream::connect(&peer).await?;
60 Ok(Self { peer, stream })
61 }
62}
63
64impl AsyncRead for IpcStream {
65 fn poll_read(
66 self: Pin<&mut Self>,
67 cx: &mut Context<'_>,
68 buf: &mut tokio::io::ReadBuf<'_>,
69 ) -> Poll<io::Result<()>> {
70 Pin::new(&mut self.get_mut().stream).poll_read(cx, buf)
71 }
72}
73
74impl AsyncWrite for IpcStream {
75 fn poll_write(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 buf: &[u8],
79 ) -> Poll<io::Result<usize>> {
80 Pin::new(&mut self.get_mut().stream).poll_write(cx, buf)
81 }
82
83 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
84 Pin::new(&mut self.get_mut().stream).poll_flush(cx)
85 }
86
87 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
88 Pin::new(&mut self.get_mut().stream).poll_shutdown(cx)
89 }
90}
91
92impl PeerAddress<PathBuf> for IpcStream {
93 fn peer_addr(&self) -> Result<PathBuf, io::Error> {
94 Ok(self.peer.clone())
95 }
96}
97
98type NoStats = ();
99
100impl From<&IpcStream> for NoStats {
101 fn from(_: &IpcStream) -> Self {}
103}
104
105#[async_trait]
106impl Transport<PathBuf> for Ipc {
107 type Stats = NoStats;
109 type Io = IpcStream;
110
111 type Control = ();
112
113 type Error = io::Error;
114
115 type Connect = BoxFuture<'static, Result<Self::Io, Self::Error>>;
116 type Accept = BoxFuture<'static, Result<Self::Io, Self::Error>>;
117
118 fn local_addr(&self) -> Option<PathBuf> {
119 self.path.clone()
120 }
121
122 async fn bind(&mut self, addr: PathBuf) -> Result<(), Self::Error> {
123 if addr.exists() {
124 debug!("Socket file already exists. Attempting to remove.");
125 if let Err(e) = std::fs::remove_file(&addr) {
126 return Err(io::Error::other(format!(
127 "Failed to remove existing socket file, {e:?}"
128 )));
129 }
130 }
131
132 let listener = UnixListener::bind(&addr)?;
133 self.listener = Some(listener);
134 self.path = Some(addr);
135 Ok(())
136 }
137
138 fn connect(&mut self, addr: PathBuf) -> Self::Connect {
139 Box::pin(async move { IpcStream::connect(addr).await })
140 }
141
142 fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Accept> {
143 let this = self.get_mut();
144
145 let Some(ref listener) = this.listener else {
146 return Poll::Ready(async_error(io::ErrorKind::NotConnected.into()));
147 };
148
149 match listener.poll_accept(cx) {
150 Poll::Ready(Ok((io, _addr))) => {
151 debug!("accepted IPC connection");
152 let stream = IpcStream {
153 peer: this.path.clone().expect("listener not bound"),
155 stream: io,
156 };
157 Poll::Ready(Box::pin(async move { Ok(stream) }))
158 }
159 Poll::Ready(Err(e)) => Poll::Ready(async_error(e)),
160 Poll::Pending => Poll::Pending,
161 }
162 }
163}
164
165#[async_trait]
166impl TransportExt<PathBuf> for Ipc {
167 fn accept(&mut self) -> Acceptor<'_, Self, PathBuf>
168 where
169 Self: Sized + Unpin,
170 {
171 Acceptor::new(self)
172 }
173}