Skip to main content

rmux_client/
connection.rs

1//! Blocking Unix-socket transport for detached RPC traffic.
2
3use std::ffi::OsStr;
4#[cfg(all(test, unix))]
5use std::ffi::OsString;
6#[cfg(all(test, unix))]
7use std::fs;
8use std::io::{self, Read, Write};
9#[cfg(all(test, unix))]
10use std::os::unix::ffi::{OsStrExt, OsStringExt};
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
14use crate::ClientError;
15use rmux_ipc::{connect_blocking, BlockingLocalStream, LocalEndpoint};
16use rmux_proto::{
17    encode_frame, AttachSessionResponse, ControlMode, ControlModeResponse, FrameDecoder, Request,
18    Response,
19};
20
/// Size, in bytes, of the stack buffer used for each blocking socket read.
const READ_BUFFER_SIZE: usize = 8 * 1024;
/// Upper bound applied to detached connects, request writes, and ordinary
/// response reads.
const SOCKET_IO_TIMEOUT: Duration = Duration::from_millis(5_000);

/// Socket root used by the test-only path helpers when the explicit directory
/// is unset, empty, or cannot be canonicalized.
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
/// Prefix of the per-user socket directory name (`<prefix>-<uid>`).
#[cfg(all(test, unix))]
const SOCKET_DIR_PREFIX: &str = "rmux";
31
32/// Computes the default RMUX client socket path.
33///
34/// The path uses an rmux-specific per-user directory so an rmux client never
35/// speaks the rmux wire protocol to a real tmux server.
36pub fn default_socket_path() -> Result<PathBuf, ClientError> {
37    rmux_ipc::default_endpoint()
38        .map(LocalEndpoint::into_path)
39        .map_err(ClientError::Io)
40}
41
42/// Computes an rmux socket path for a top-level `-L` socket name.
43pub fn socket_path_for_label(label: impl AsRef<OsStr>) -> Result<PathBuf, ClientError> {
44    rmux_ipc::endpoint_for_label(label)
45        .map(LocalEndpoint::into_path)
46        .map_err(ClientError::Io)
47}
48
49/// Resolves the top-level socket path from `-L`, `-S`, `$RMUX`, or defaults.
50///
51/// `-S` wins over `-L`; both command-line forms win over `$RMUX`.
52pub fn resolve_socket_path(
53    socket_name: Option<&OsStr>,
54    socket_path: Option<&Path>,
55) -> Result<PathBuf, ClientError> {
56    rmux_ipc::resolve_endpoint(socket_name, socket_path)
57        .map(LocalEndpoint::into_path)
58        .map_err(ClientError::Io)
59}
60
61/// Result of attempting to connect to the RMUX server.
62#[derive(Debug)]
63pub enum ConnectResult {
64    /// Successfully connected to the server.
65    Connected(Connection),
66    /// The server is absent (socket does not exist or connection refused).
67    Absent,
68}
69
70/// Attempts to connect to the RMUX server, distinguishing absent servers from
71/// real connection errors.
72///
73/// Returns [`ConnectResult::Absent`] when the socket does not exist or the
74/// connection is refused, which lets callers like `kill-session` succeed with
75/// exit code `0` for an absent server. Returns an error only for unexpected
76/// transport failures.
77pub fn connect_or_absent(socket_path: &Path) -> Result<ConnectResult, ClientError> {
78    connect_or_absent_with_timeout_using(
79        socket_path,
80        SOCKET_IO_TIMEOUT,
81        connect_stream_with_timeout,
82    )
83}
84
85/// Connects to the RMUX server, returning an error if the server is absent.
86pub fn connect(socket_path: &Path) -> Result<Connection, ClientError> {
87    connect_with_timeout_using(socket_path, SOCKET_IO_TIMEOUT, connect_stream_with_timeout)
88}
89
90/// A blocking connection to the RMUX server that exchanges typed frames.
91#[derive(Debug)]
92pub struct Connection {
93    stream: BlockingLocalStream,
94    decoder: FrameDecoder,
95}
96
97/// The explicit result of requesting an attach-stream upgrade.
98#[derive(Debug)]
99pub enum AttachTransition {
100    /// The server accepted the attach request and switched protocols.
101    Upgraded(AttachSessionUpgrade),
102    /// The server responded without switching protocols.
103    Rejected(Response),
104}
105
106/// The explicit result of requesting a control-mode upgrade.
107#[derive(Debug)]
108pub enum ControlTransition {
109    /// The server accepted the control-mode request and switched protocols.
110    Upgraded(ControlModeUpgrade),
111    /// The server responded without switching protocols.
112    Rejected(Response),
113}
114
115/// A detached connection that has transitioned into attach-stream mode.
116#[derive(Debug)]
117pub struct AttachSessionUpgrade {
118    response: AttachSessionResponse,
119    stream: BlockingLocalStream,
120    initial_bytes: Vec<u8>,
121}
122
123/// A detached connection that has transitioned into control-mode streaming.
124#[derive(Debug)]
125pub struct ControlModeUpgrade {
126    pub(crate) response: ControlModeResponse,
127    pub(crate) stream: BlockingLocalStream,
128}
129
130impl AttachSessionUpgrade {
131    /// Returns the upgrade response sent by the server.
132    #[must_use]
133    pub const fn response(&self) -> &AttachSessionResponse {
134        &self.response
135    }
136
137    /// Consumes the upgrade and returns the raw attach-stream socket.
138    #[must_use]
139    pub fn into_stream(self) -> BlockingLocalStream {
140        self.stream
141    }
142
143    /// Consumes the upgrade and returns the raw attach-stream socket plus any
144    /// bytes already read beyond the detached response frame.
145    #[must_use]
146    pub fn into_parts(self) -> (BlockingLocalStream, Vec<u8>) {
147        (self.stream, self.initial_bytes)
148    }
149}
150
151impl ControlModeUpgrade {
152    /// Returns the upgrade response sent by the server.
153    #[must_use]
154    pub const fn response(&self) -> &ControlModeResponse {
155        &self.response
156    }
157
158    /// Returns the negotiated control-mode flavor.
159    #[must_use]
160    pub const fn mode(&self) -> ControlMode {
161        self.response.mode
162    }
163
164    /// Consumes the upgrade and returns the raw control-mode socket.
165    #[must_use]
166    pub fn into_stream(self) -> BlockingLocalStream {
167        self.stream
168    }
169}
170
171impl Connection {
172    pub(crate) fn new(stream: BlockingLocalStream) -> Result<Self, ClientError> {
173        set_read_timeout(&stream, Some(SOCKET_IO_TIMEOUT)).map_err(ClientError::Io)?;
174        set_write_timeout(&stream, Some(SOCKET_IO_TIMEOUT)).map_err(ClientError::Io)?;
175
176        Ok(Self {
177            stream,
178            decoder: FrameDecoder::new(),
179        })
180    }
181
182    /// Sends a request and reads the server's response.
183    ///
184    /// Server-side `Response::Error` payloads are returned as-is in the `Ok`
185    /// variant so callers can pattern-match on them. Only transport and framing
186    /// failures produce `Err`.
187    pub fn roundtrip(&mut self, request: &Request) -> Result<Response, ClientError> {
188        self.write_request(request)?;
189        self.read_response()
190    }
191
192    /// Sends a request without a detached response read timeout.
193    ///
194    /// This is reserved for scripting requests whose server-side completion can
195    /// legitimately block beyond the normal five-second detached RPC bound.
196    pub(crate) fn roundtrip_without_read_timeout(
197        &mut self,
198        request: &Request,
199    ) -> Result<Response, ClientError> {
200        let previous_timeout = read_timeout(&self.stream).map_err(ClientError::Io)?;
201        set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
202        let result = self.roundtrip(request);
203        set_read_timeout(&self.stream, previous_timeout).map_err(ClientError::Io)?;
204        result
205    }
206
207    pub(crate) fn write_request(&mut self, request: &Request) -> Result<(), ClientError> {
208        let frame = encode_frame(request).map_err(ClientError::Protocol)?;
209        self.stream.write_all(&frame).map_err(ClientError::Io)
210    }
211
212    pub(crate) fn read_response(&mut self) -> Result<Response, ClientError> {
213        let mut buffer = [0u8; READ_BUFFER_SIZE];
214
215        loop {
216            match self.decoder.next_frame::<Response>() {
217                Ok(Some(response)) => return Ok(response),
218                Ok(None) => {}
219                Err(error) => return Err(ClientError::Protocol(error)),
220            }
221
222            let bytes_read = match self.stream.read(&mut buffer) {
223                Ok(bytes_read) => bytes_read,
224                Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
225                Err(error) => return Err(ClientError::Io(error)),
226            };
227
228            if bytes_read == 0 {
229                return Err(ClientError::UnexpectedEof);
230            }
231
232            self.decoder.push_bytes(&buffer[..bytes_read]);
233        }
234    }
235
236    pub(crate) fn stream_mut(&mut self) -> &mut BlockingLocalStream {
237        &mut self.stream
238    }
239
240    pub(crate) fn into_attach_upgrade(
241        self,
242        response: AttachSessionResponse,
243    ) -> Result<AttachSessionUpgrade, ClientError> {
244        set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
245        set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
246        let initial_bytes = self.decoder.remaining_bytes().to_vec();
247
248        Ok(AttachSessionUpgrade {
249            response,
250            stream: self.stream,
251            initial_bytes,
252        })
253    }
254
255    pub(crate) fn into_control_upgrade(
256        self,
257        response: ControlModeResponse,
258    ) -> Result<ControlModeUpgrade, ClientError> {
259        set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
260        set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
261
262        Ok(ControlModeUpgrade {
263            response,
264            stream: self.stream,
265        })
266    }
267}
268
269pub(crate) fn read_response_frame_exact(
270    stream: &mut BlockingLocalStream,
271) -> Result<Response, ClientError> {
272    let mut decoder = FrameDecoder::new();
273    let mut byte = [0_u8; 1];
274
275    loop {
276        match decoder.next_frame::<Response>() {
277            Ok(Some(response)) => return Ok(response),
278            Ok(None) => {}
279            Err(error) => return Err(ClientError::Protocol(error)),
280        }
281
282        read_exact_or_eof(stream, &mut byte)?;
283        decoder.push_bytes(&byte);
284    }
285}
286
287fn read_exact_or_eof(
288    stream: &mut BlockingLocalStream,
289    buffer: &mut [u8],
290) -> Result<(), ClientError> {
291    match stream.read_exact(buffer) {
292        Ok(()) => Ok(()),
293        Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
294            Err(ClientError::UnexpectedEof)
295        }
296        Err(error) => Err(ClientError::Io(error)),
297    }
298}
299
/// Test helper: builds `<root>/<prefix>-<user_id>/<label>` from explicit parts.
///
/// The label is appended as raw bytes after a `/` separator rather than via
/// `Path::join`.
#[cfg(all(test, unix))]
fn socket_path_from_parts(
    rmux_tmpdir: Option<&OsStr>,
    user_id: u32,
    label: &OsStr,
) -> io::Result<PathBuf> {
    let user_dir_name = format!("{SOCKET_DIR_PREFIX}-{user_id}");
    let user_dir = socket_root_from_parts(rmux_tmpdir)?.join(user_dir_name);

    let mut raw = user_dir.into_os_string().into_vec();
    raw.push(b'/');
    raw.extend_from_slice(label.as_bytes());

    Ok(OsString::from_vec(raw).into())
}
314
/// Test helper: picks the canonical socket root directory.
///
/// A non-empty `rmux_tmpdir` is tried first, then `FALLBACK_SOCKET_ROOT`. The
/// first candidate that canonicalizes successfully wins.
///
/// # Errors
///
/// Returns a `NotFound` error when no candidate can be canonicalized.
#[cfg(all(test, unix))]
fn socket_root_from_parts(rmux_tmpdir: Option<&OsStr>) -> io::Result<PathBuf> {
    let explicit = rmux_tmpdir.and_then(|dir| (!dir.is_empty()).then(|| PathBuf::from(dir)));

    explicit
        .into_iter()
        .chain(std::iter::once(PathBuf::from(FALLBACK_SOCKET_ROOT)))
        .find_map(|candidate| fs::canonicalize(&candidate).ok())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "no suitable rmux socket directory",
            )
        })
}
335
336fn connect_or_absent_with_timeout_using<F>(
337    socket_path: &Path,
338    timeout: Duration,
339    connect_stream: F,
340) -> Result<ConnectResult, ClientError>
341where
342    F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
343{
344    match connect_stream(socket_path, timeout) {
345        Ok(stream) => Ok(ConnectResult::Connected(Connection::new(stream)?)),
346        Err(error) if is_absent_error(&error) => Ok(ConnectResult::Absent),
347        Err(error) => Err(ClientError::Io(error)),
348    }
349}
350
351fn connect_with_timeout_using<F>(
352    socket_path: &Path,
353    timeout: Duration,
354    connect_stream: F,
355) -> Result<Connection, ClientError>
356where
357    F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
358{
359    let stream = connect_stream(socket_path, timeout).map_err(ClientError::Io)?;
360    Connection::new(stream)
361}
362
363fn connect_stream_with_timeout(
364    socket_path: &Path,
365    timeout: Duration,
366) -> io::Result<BlockingLocalStream> {
367    connect_blocking(
368        &LocalEndpoint::from_path(socket_path.to_path_buf()),
369        timeout,
370    )
371}
372
373fn read_timeout(stream: &BlockingLocalStream) -> io::Result<Option<Duration>> {
374    stream.read_timeout()
375}
376
377fn set_read_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
378    stream.set_read_timeout(timeout)
379}
380
381fn set_write_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
382    stream.set_write_timeout(timeout)
383}
384
/// Reports whether a connect error means no server is listening.
///
/// A missing socket path (`NotFound`) and a refused connection
/// (`ConnectionRefused`) both count as absent. Every other error kind is a
/// genuine transport failure.
fn is_absent_error(error: &io::Error) -> bool {
    let kind = error.kind();
    kind == io::ErrorKind::NotFound || kind == io::ErrorKind::ConnectionRefused
}
392
#[cfg(all(test, unix))]
mod tests {
    // The unit tests for this module live in a separate file and are
    // textually included here.
    include!("connection/tests.rs");
}