1use std::ffi::OsStr;
4#[cfg(all(test, unix))]
5use std::ffi::OsString;
6#[cfg(all(test, unix))]
7use std::fs;
8use std::io::{self, Read, Write};
9#[cfg(all(test, unix))]
10use std::os::unix::ffi::{OsStrExt, OsStringExt};
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
14use crate::ClientError;
15use rmux_ipc::{connect_blocking, BlockingLocalStream, LocalEndpoint};
16use rmux_proto::{
17 encode_frame, AttachSessionResponse, ControlMode, ControlModeResponse, FrameDecoder, Request,
18 Response,
19};
20
/// Size in bytes of the scratch buffer used for each socket read in `Connection::read_response`.
const READ_BUFFER_SIZE: usize = 8 * 1024;
/// Read and write timeout applied to every request/response connection.
const SOCKET_IO_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Root directory tried when no non-empty tmpdir override is supplied (test helpers only).
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
/// Prefix of the per-user socket directory name, `<prefix>-<uid>` (test helpers only).
#[cfg(all(test, unix))]
const SOCKET_DIR_PREFIX: &str = "rmux";
31
32pub fn default_socket_path() -> Result<PathBuf, ClientError> {
37 rmux_ipc::default_endpoint()
38 .map(LocalEndpoint::into_path)
39 .map_err(ClientError::Io)
40}
41
42pub fn socket_path_for_label(label: impl AsRef<OsStr>) -> Result<PathBuf, ClientError> {
44 rmux_ipc::endpoint_for_label(label)
45 .map(LocalEndpoint::into_path)
46 .map_err(ClientError::Io)
47}
48
49pub fn resolve_socket_path(
53 socket_name: Option<&OsStr>,
54 socket_path: Option<&Path>,
55) -> Result<PathBuf, ClientError> {
56 rmux_ipc::resolve_endpoint(socket_name, socket_path)
57 .map(LocalEndpoint::into_path)
58 .map_err(ClientError::Io)
59}
60
/// Outcome of [`connect_or_absent`].
#[derive(Debug)]
pub enum ConnectResult {
    /// The connection succeeded and the default I/O timeouts were applied to it.
    Connected(Connection),
    /// Connecting failed with `NotFound` or `ConnectionRefused`; presumably no
    /// server is listening at the socket path.
    Absent,
}
69
70pub fn connect_or_absent(socket_path: &Path) -> Result<ConnectResult, ClientError> {
78 connect_or_absent_with_timeout_using(
79 socket_path,
80 SOCKET_IO_TIMEOUT,
81 connect_stream_with_timeout,
82 )
83}
84
85pub fn connect(socket_path: &Path) -> Result<Connection, ClientError> {
87 connect_with_timeout_using(socket_path, SOCKET_IO_TIMEOUT, connect_stream_with_timeout)
88}
89
/// A blocking, frame-based request/response connection over a local IPC stream.
#[derive(Debug)]
pub struct Connection {
    /// The connected local stream (concrete transport is defined by `rmux_ipc`).
    stream: BlockingLocalStream,
    /// Accumulates bytes read from `stream` until a whole frame can be decoded.
    /// Because reads are buffered, it may hold bytes beyond the last decoded frame.
    decoder: FrameDecoder,
}
96
/// Outcome of an attempt to upgrade a connection for attach-session
/// (constructed elsewhere in the crate).
#[derive(Debug)]
pub enum AttachTransition {
    /// The connection was upgraded to a raw attach stream.
    Upgraded(AttachSessionUpgrade),
    /// No upgrade happened; carries the response that was received instead.
    Rejected(Response),
}
105
/// Outcome of an attempt to upgrade a connection into control mode
/// (constructed elsewhere in the crate).
#[derive(Debug)]
pub enum ControlTransition {
    /// The connection was upgraded to a control-mode stream.
    Upgraded(ControlModeUpgrade),
    /// No upgrade happened; carries the response that was received instead.
    Rejected(Response),
}
114
/// A connection that has left request/response mode for an attached session.
#[derive(Debug)]
pub struct AttachSessionUpgrade {
    /// The attach response that accompanied the upgrade.
    response: AttachSessionResponse,
    /// The raw stream; `Connection::into_attach_upgrade` clears its read/write timeouts.
    stream: BlockingLocalStream,
    /// Bytes that were still buffered in the connection's frame decoder at upgrade
    /// time; presumably they must be consumed before reading further from `stream`.
    initial_bytes: Vec<u8>,
}
122
/// A connection that has left request/response mode for control mode.
#[derive(Debug)]
pub struct ControlModeUpgrade {
    /// The control-mode response that accompanied the upgrade.
    pub(crate) response: ControlModeResponse,
    /// The raw stream; `Connection::into_control_upgrade` clears its read/write timeouts.
    pub(crate) stream: BlockingLocalStream,
}
129
130impl AttachSessionUpgrade {
131 #[must_use]
133 pub const fn response(&self) -> &AttachSessionResponse {
134 &self.response
135 }
136
137 #[must_use]
139 pub fn into_stream(self) -> BlockingLocalStream {
140 self.stream
141 }
142
143 #[must_use]
146 pub fn into_parts(self) -> (BlockingLocalStream, Vec<u8>) {
147 (self.stream, self.initial_bytes)
148 }
149}
150
151impl ControlModeUpgrade {
152 #[must_use]
154 pub const fn response(&self) -> &ControlModeResponse {
155 &self.response
156 }
157
158 #[must_use]
160 pub const fn mode(&self) -> ControlMode {
161 self.response.mode
162 }
163
164 #[must_use]
166 pub fn into_stream(self) -> BlockingLocalStream {
167 self.stream
168 }
169}
170
171impl Connection {
172 pub(crate) fn new(stream: BlockingLocalStream) -> Result<Self, ClientError> {
173 set_read_timeout(&stream, Some(SOCKET_IO_TIMEOUT)).map_err(ClientError::Io)?;
174 set_write_timeout(&stream, Some(SOCKET_IO_TIMEOUT)).map_err(ClientError::Io)?;
175
176 Ok(Self {
177 stream,
178 decoder: FrameDecoder::new(),
179 })
180 }
181
182 pub fn roundtrip(&mut self, request: &Request) -> Result<Response, ClientError> {
188 self.write_request(request)?;
189 self.read_response()
190 }
191
192 pub(crate) fn roundtrip_without_read_timeout(
197 &mut self,
198 request: &Request,
199 ) -> Result<Response, ClientError> {
200 let previous_timeout = read_timeout(&self.stream).map_err(ClientError::Io)?;
201 set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
202 let result = self.roundtrip(request);
203 set_read_timeout(&self.stream, previous_timeout).map_err(ClientError::Io)?;
204 result
205 }
206
207 pub(crate) fn write_request(&mut self, request: &Request) -> Result<(), ClientError> {
208 let frame = encode_frame(request).map_err(ClientError::Protocol)?;
209 self.stream.write_all(&frame).map_err(ClientError::Io)
210 }
211
212 pub(crate) fn read_response(&mut self) -> Result<Response, ClientError> {
213 let mut buffer = [0u8; READ_BUFFER_SIZE];
214
215 loop {
216 match self.decoder.next_frame::<Response>() {
217 Ok(Some(response)) => return Ok(response),
218 Ok(None) => {}
219 Err(error) => return Err(ClientError::Protocol(error)),
220 }
221
222 let bytes_read = match self.stream.read(&mut buffer) {
223 Ok(bytes_read) => bytes_read,
224 Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
225 Err(error) => return Err(ClientError::Io(error)),
226 };
227
228 if bytes_read == 0 {
229 return Err(ClientError::UnexpectedEof);
230 }
231
232 self.decoder.push_bytes(&buffer[..bytes_read]);
233 }
234 }
235
236 pub(crate) fn stream_mut(&mut self) -> &mut BlockingLocalStream {
237 &mut self.stream
238 }
239
240 pub(crate) fn into_attach_upgrade(
241 self,
242 response: AttachSessionResponse,
243 ) -> Result<AttachSessionUpgrade, ClientError> {
244 set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
245 set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
246 let initial_bytes = self.decoder.remaining_bytes().to_vec();
247
248 Ok(AttachSessionUpgrade {
249 response,
250 stream: self.stream,
251 initial_bytes,
252 })
253 }
254
255 pub(crate) fn into_control_upgrade(
256 self,
257 response: ControlModeResponse,
258 ) -> Result<ControlModeUpgrade, ClientError> {
259 set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
260 set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
261
262 Ok(ControlModeUpgrade {
263 response,
264 stream: self.stream,
265 })
266 }
267}
268
269pub(crate) fn read_response_frame_exact(
270 stream: &mut BlockingLocalStream,
271) -> Result<Response, ClientError> {
272 let mut decoder = FrameDecoder::new();
273 let mut byte = [0_u8; 1];
274
275 loop {
276 match decoder.next_frame::<Response>() {
277 Ok(Some(response)) => return Ok(response),
278 Ok(None) => {}
279 Err(error) => return Err(ClientError::Protocol(error)),
280 }
281
282 read_exact_or_eof(stream, &mut byte)?;
283 decoder.push_bytes(&byte);
284 }
285}
286
287fn read_exact_or_eof(
288 stream: &mut BlockingLocalStream,
289 buffer: &mut [u8],
290) -> Result<(), ClientError> {
291 match stream.read_exact(buffer) {
292 Ok(()) => Ok(()),
293 Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
294 Err(ClientError::UnexpectedEof)
295 }
296 Err(error) => Err(ClientError::Io(error)),
297 }
298}
299
/// Test-only: builds `<root>/<SOCKET_DIR_PREFIX>-<user_id>/<label>`.
///
/// `root` comes from `socket_root_from_parts`. The label's raw bytes are appended
/// after a literal `/`. Unlike `Path::join`, this means an absolute `label` does not
/// replace the base directory.
///
/// # Errors
///
/// Propagates the error from `socket_root_from_parts` when no root directory resolves.
#[cfg(all(test, unix))]
fn socket_path_from_parts(
    rmux_tmpdir: Option<&OsStr>,
    user_id: u32,
    label: &OsStr,
) -> io::Result<PathBuf> {
    let socket_root = socket_root_from_parts(rmux_tmpdir)?;
    let user_dir = socket_root.join(format!("{SOCKET_DIR_PREFIX}-{user_id}"));

    let raw_path: Vec<u8> = user_dir
        .into_os_string()
        .into_vec()
        .into_iter()
        .chain(std::iter::once(b'/'))
        .chain(label.as_bytes().iter().copied())
        .collect();

    Ok(PathBuf::from(OsString::from_vec(raw_path)))
}
314
/// Test-only: picks the directory under which per-user socket directories live.
///
/// A non-empty `rmux_tmpdir` is tried first, then `FALLBACK_SOCKET_ROOT`. The first
/// candidate that `fs::canonicalize` accepts is returned in canonical form.
///
/// # Errors
///
/// Returns an `ErrorKind::NotFound` error if no candidate can be canonicalized.
#[cfg(all(test, unix))]
fn socket_root_from_parts(rmux_tmpdir: Option<&OsStr>) -> io::Result<PathBuf> {
    let preferred = rmux_tmpdir
        .filter(|dir| !dir.is_empty())
        .map(PathBuf::from);
    let fallback = PathBuf::from(FALLBACK_SOCKET_ROOT);

    preferred
        .into_iter()
        .chain(Some(fallback))
        .find_map(|candidate| fs::canonicalize(candidate).ok())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "no suitable rmux socket directory",
            )
        })
}
335
336fn connect_or_absent_with_timeout_using<F>(
337 socket_path: &Path,
338 timeout: Duration,
339 connect_stream: F,
340) -> Result<ConnectResult, ClientError>
341where
342 F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
343{
344 match connect_stream(socket_path, timeout) {
345 Ok(stream) => Ok(ConnectResult::Connected(Connection::new(stream)?)),
346 Err(error) if is_absent_error(&error) => Ok(ConnectResult::Absent),
347 Err(error) => Err(ClientError::Io(error)),
348 }
349}
350
351fn connect_with_timeout_using<F>(
352 socket_path: &Path,
353 timeout: Duration,
354 connect_stream: F,
355) -> Result<Connection, ClientError>
356where
357 F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
358{
359 let stream = connect_stream(socket_path, timeout).map_err(ClientError::Io)?;
360 Connection::new(stream)
361}
362
363fn connect_stream_with_timeout(
364 socket_path: &Path,
365 timeout: Duration,
366) -> io::Result<BlockingLocalStream> {
367 connect_blocking(
368 &LocalEndpoint::from_path(socket_path.to_path_buf()),
369 timeout,
370 )
371}
372
/// Returns the stream's current read timeout.
///
/// `None` presumably means "no timeout", matching std socket semantics; confirm
/// against `BlockingLocalStream`.
fn read_timeout(stream: &BlockingLocalStream) -> io::Result<Option<Duration>> {
    stream.read_timeout()
}
376
/// Sets the stream's read timeout; `None` is used in this file to disable it.
fn set_read_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
    stream.set_read_timeout(timeout)
}
380
/// Sets the stream's write timeout; `None` is used in this file to disable it.
fn set_write_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
    stream.set_write_timeout(timeout)
}
384
/// Reports whether a connect failure means the endpoint is absent rather than broken.
///
/// `NotFound` (nothing exists at the path) and `ConnectionRefused` (nothing is
/// accepting connections) are treated as absent. Every other error kind is not.
fn is_absent_error(error: &io::Error) -> bool {
    let kind = error.kind();
    kind == io::ErrorKind::NotFound || kind == io::ErrorKind::ConnectionRefused
}
392
/// Unit tests for this module (Unix only).
#[cfg(all(test, unix))]
mod tests {
    // The test source lives in a separate file and is textually included here, so it
    // compiles as this child module and can use `super::` items, including the
    // `cfg(all(test, unix))` helpers above.
    include!("connection/tests.rs");
}