Skip to main content

rmux_client/
connection.rs

1//! Blocking Unix-socket transport for detached RPC traffic.
2
3use std::ffi::OsStr;
4#[cfg(all(test, unix))]
5use std::ffi::OsString;
6#[cfg(all(test, unix))]
7use std::fs;
8use std::io::{self, Read, Write};
9#[cfg(all(test, unix))]
10use std::os::unix::ffi::{OsStrExt, OsStringExt};
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
14use crate::ClientError;
15use rmux_ipc::{connect_blocking, BlockingLocalStream, LocalEndpoint};
16use rmux_proto::{
17    encode_frame, AttachSessionResponse, ControlMode, ControlModeResponse, FrameDecoder, Request,
18    Response,
19};
20
/// Size, in bytes, of the scratch buffer filled by each blocking socket read.
const READ_BUFFER_SIZE: usize = 8 * 1024;
/// Default time allowed for establishing a detached RPC connection.
const SOCKET_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Default write timeout applied when sending detached RPC requests.
const SOCKET_WRITE_TIMEOUT: Duration = Duration::from_secs(5);
/// Default read timeout applied while waiting for ordinary detached RPC responses.
const SOCKET_RESPONSE_TIMEOUT: Duration = Duration::from_secs(15);

/// Socket root tried by the test-only path helpers after any override directory.
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
/// Prefix of the per-user socket directory name built by the test-only path helpers.
#[cfg(all(test, unix))]
const SOCKET_DIR_PREFIX: &str = "rmux";
34
35/// Computes the default RMUX client socket path.
36///
37/// The path uses an rmux-specific per-user directory so an rmux client never
38/// speaks the rmux wire protocol to a real tmux server.
39pub fn default_socket_path() -> Result<PathBuf, ClientError> {
40    rmux_ipc::default_endpoint()
41        .map(LocalEndpoint::into_path)
42        .map_err(ClientError::Io)
43}
44
45/// Computes an rmux socket path for a top-level `-L` socket name.
46pub fn socket_path_for_label(label: impl AsRef<OsStr>) -> Result<PathBuf, ClientError> {
47    rmux_ipc::endpoint_for_label(label)
48        .map(LocalEndpoint::into_path)
49        .map_err(ClientError::Io)
50}
51
52/// Resolves the top-level socket path from `-L`, `-S`, `$RMUX`, or defaults.
53///
54/// `-S` wins over `-L`; both command-line forms win over `$RMUX`.
55pub fn resolve_socket_path(
56    socket_name: Option<&OsStr>,
57    socket_path: Option<&Path>,
58) -> Result<PathBuf, ClientError> {
59    rmux_ipc::resolve_endpoint(socket_name, socket_path)
60        .map(LocalEndpoint::into_path)
61        .map_err(ClientError::Io)
62}
63
64/// Result of attempting to connect to the RMUX server.
65#[derive(Debug)]
66pub enum ConnectResult {
67    /// Successfully connected to the server.
68    Connected(Connection),
69    /// The server is absent (socket does not exist or connection refused).
70    Absent,
71}
72
73/// Attempts to connect to the RMUX server, distinguishing absent servers from
74/// real connection errors.
75///
76/// Returns [`ConnectResult::Absent`] when the socket does not exist or the
77/// connection is refused, which lets callers like `kill-session` succeed with
78/// exit code `0` for an absent server. Returns an error only for unexpected
79/// transport failures.
80pub fn connect_or_absent(socket_path: &Path) -> Result<ConnectResult, ClientError> {
81    connect_or_absent_with_timeout_using(
82        socket_path,
83        SOCKET_CONNECT_TIMEOUT,
84        connect_stream_with_timeout,
85    )
86}
87
88/// Connects to the RMUX server, returning an error if the server is absent.
89pub fn connect(socket_path: &Path) -> Result<Connection, ClientError> {
90    connect_with_timeout_using(
91        socket_path,
92        SOCKET_CONNECT_TIMEOUT,
93        connect_stream_with_timeout,
94    )
95}
96
97/// A blocking connection to the RMUX server that exchanges typed frames.
98#[derive(Debug)]
99pub struct Connection {
100    stream: BlockingLocalStream,
101    decoder: FrameDecoder,
102}
103
104/// The explicit result of requesting an attach-stream upgrade.
105#[derive(Debug)]
106pub enum AttachTransition {
107    /// The server accepted the attach request and switched protocols.
108    Upgraded(AttachSessionUpgrade),
109    /// The server responded without switching protocols.
110    Rejected(Response),
111}
112
113/// The explicit result of requesting a control-mode upgrade.
114#[derive(Debug)]
115pub enum ControlTransition {
116    /// The server accepted the control-mode request and switched protocols.
117    Upgraded(ControlModeUpgrade),
118    /// The server responded without switching protocols.
119    Rejected(Response),
120}
121
122/// A detached connection that has transitioned into attach-stream mode.
123#[derive(Debug)]
124pub struct AttachSessionUpgrade {
125    response: AttachSessionResponse,
126    stream: BlockingLocalStream,
127    initial_bytes: Vec<u8>,
128}
129
130/// A detached connection that has transitioned into control-mode streaming.
131#[derive(Debug)]
132pub struct ControlModeUpgrade {
133    pub(crate) response: ControlModeResponse,
134    pub(crate) stream: BlockingLocalStream,
135}
136
137impl AttachSessionUpgrade {
138    /// Returns the upgrade response sent by the server.
139    #[must_use]
140    pub const fn response(&self) -> &AttachSessionResponse {
141        &self.response
142    }
143
144    /// Consumes the upgrade and returns the raw attach-stream socket.
145    #[must_use]
146    pub fn into_stream(self) -> BlockingLocalStream {
147        self.stream
148    }
149
150    /// Consumes the upgrade and returns the raw attach-stream socket plus any
151    /// bytes already read beyond the detached response frame.
152    #[must_use]
153    pub fn into_parts(self) -> (BlockingLocalStream, Vec<u8>) {
154        (self.stream, self.initial_bytes)
155    }
156}
157
158impl ControlModeUpgrade {
159    /// Returns the upgrade response sent by the server.
160    #[must_use]
161    pub const fn response(&self) -> &ControlModeResponse {
162        &self.response
163    }
164
165    /// Returns the negotiated control-mode flavor.
166    #[must_use]
167    pub const fn mode(&self) -> ControlMode {
168        self.response.mode
169    }
170
171    /// Consumes the upgrade and returns the raw control-mode socket.
172    #[must_use]
173    pub fn into_stream(self) -> BlockingLocalStream {
174        self.stream
175    }
176}
177
178impl Connection {
179    pub(crate) fn new(stream: BlockingLocalStream) -> Result<Self, ClientError> {
180        set_read_timeout(&stream, Some(SOCKET_RESPONSE_TIMEOUT)).map_err(ClientError::Io)?;
181        set_write_timeout(&stream, Some(SOCKET_WRITE_TIMEOUT)).map_err(ClientError::Io)?;
182
183        Ok(Self {
184            stream,
185            decoder: FrameDecoder::new(),
186        })
187    }
188
189    /// Sends a request and reads the server's response.
190    ///
191    /// Server-side `Response::Error` payloads are returned as-is in the `Ok`
192    /// variant so callers can pattern-match on them. Only transport and framing
193    /// failures produce `Err`.
194    pub fn roundtrip(&mut self, request: &Request) -> Result<Response, ClientError> {
195        self.write_request(request)?;
196        self.read_response()
197    }
198
199    /// Sends a request without a detached response read timeout.
200    ///
201    /// This is reserved for scripting requests whose server-side completion can
202    /// legitimately block beyond the normal five-second detached RPC bound.
203    pub(crate) fn roundtrip_without_read_timeout(
204        &mut self,
205        request: &Request,
206    ) -> Result<Response, ClientError> {
207        let previous_timeout = read_timeout(&self.stream).map_err(ClientError::Io)?;
208        set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
209        let result = self.roundtrip(request);
210        set_read_timeout(&self.stream, previous_timeout).map_err(ClientError::Io)?;
211        result
212    }
213
214    pub(crate) fn write_request(&mut self, request: &Request) -> Result<(), ClientError> {
215        let frame = encode_frame(request).map_err(ClientError::Protocol)?;
216        self.stream.write_all(&frame).map_err(ClientError::Io)
217    }
218
219    pub(crate) fn read_response(&mut self) -> Result<Response, ClientError> {
220        let mut buffer = [0u8; READ_BUFFER_SIZE];
221
222        loop {
223            match self.decoder.next_frame::<Response>() {
224                Ok(Some(response)) => return Ok(response),
225                Ok(None) => {}
226                Err(error) => return Err(ClientError::Protocol(error)),
227            }
228
229            let bytes_read = match self.stream.read(&mut buffer) {
230                Ok(bytes_read) => bytes_read,
231                Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
232                Err(error) => return Err(ClientError::Io(error)),
233            };
234
235            if bytes_read == 0 {
236                return Err(ClientError::UnexpectedEof);
237            }
238
239            self.decoder.push_bytes(&buffer[..bytes_read]);
240        }
241    }
242
243    pub(crate) fn stream_mut(&mut self) -> &mut BlockingLocalStream {
244        &mut self.stream
245    }
246
247    pub(crate) fn into_attach_upgrade(
248        self,
249        response: AttachSessionResponse,
250    ) -> Result<AttachSessionUpgrade, ClientError> {
251        set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
252        set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
253        let initial_bytes = self.decoder.remaining_bytes().to_vec();
254
255        Ok(AttachSessionUpgrade {
256            response,
257            stream: self.stream,
258            initial_bytes,
259        })
260    }
261
262    pub(crate) fn into_control_upgrade(
263        self,
264        response: ControlModeResponse,
265    ) -> Result<ControlModeUpgrade, ClientError> {
266        set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
267        set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
268
269        Ok(ControlModeUpgrade {
270            response,
271            stream: self.stream,
272        })
273    }
274}
275
276pub(crate) fn read_response_frame_exact(
277    stream: &mut BlockingLocalStream,
278) -> Result<Response, ClientError> {
279    let mut decoder = FrameDecoder::new();
280    let mut byte = [0_u8; 1];
281
282    loop {
283        match decoder.next_frame::<Response>() {
284            Ok(Some(response)) => return Ok(response),
285            Ok(None) => {}
286            Err(error) => return Err(ClientError::Protocol(error)),
287        }
288
289        read_exact_or_eof(stream, &mut byte)?;
290        decoder.push_bytes(&byte);
291    }
292}
293
294fn read_exact_or_eof(
295    stream: &mut BlockingLocalStream,
296    buffer: &mut [u8],
297) -> Result<(), ClientError> {
298    match stream.read_exact(buffer) {
299        Ok(()) => Ok(()),
300        Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
301            Err(ClientError::UnexpectedEof)
302        }
303        Err(error) => Err(ClientError::Io(error)),
304    }
305}
306
/// Builds `<root>/<SOCKET_DIR_PREFIX>-<user_id>/<label>` for tests.
///
/// The label is appended as raw bytes after a `/` separator rather than via
/// `Path::join`, so it is attached verbatim to the per-user directory.
#[cfg(all(test, unix))]
fn socket_path_from_parts(
    rmux_tmpdir: Option<&OsStr>,
    user_id: u32,
    label: &OsStr,
) -> io::Result<PathBuf> {
    let mut bytes = socket_root_from_parts(rmux_tmpdir)?
        .join(format!("{SOCKET_DIR_PREFIX}-{user_id}"))
        .into_os_string()
        .into_vec();
    bytes.push(b'/');
    bytes.extend_from_slice(label.as_bytes());

    Ok(PathBuf::from(OsString::from_vec(bytes)))
}
321
/// Picks the socket root directory for tests.
///
/// A non-empty `rmux_tmpdir` is tried first, then `FALLBACK_SOCKET_ROOT`; the
/// first candidate that canonicalizes successfully is returned in its
/// canonical form.
///
/// # Errors
///
/// Returns a `NotFound` error when no candidate can be canonicalized.
#[cfg(all(test, unix))]
fn socket_root_from_parts(rmux_tmpdir: Option<&OsStr>) -> io::Result<PathBuf> {
    let override_root = rmux_tmpdir
        .filter(|dir| !dir.is_empty())
        .map(PathBuf::from);

    override_root
        .into_iter()
        .chain(std::iter::once(PathBuf::from(FALLBACK_SOCKET_ROOT)))
        .find_map(|candidate| fs::canonicalize(&candidate).ok())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "no suitable rmux socket directory",
            )
        })
}
342
343fn connect_or_absent_with_timeout_using<F>(
344    socket_path: &Path,
345    timeout: Duration,
346    connect_stream: F,
347) -> Result<ConnectResult, ClientError>
348where
349    F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
350{
351    match connect_stream(socket_path, timeout) {
352        Ok(stream) => Ok(ConnectResult::Connected(Connection::new(stream)?)),
353        Err(error) if is_absent_error(&error) => Ok(ConnectResult::Absent),
354        Err(error) => Err(ClientError::Io(error)),
355    }
356}
357
358fn connect_with_timeout_using<F>(
359    socket_path: &Path,
360    timeout: Duration,
361    connect_stream: F,
362) -> Result<Connection, ClientError>
363where
364    F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
365{
366    let stream = connect_stream(socket_path, timeout).map_err(ClientError::Io)?;
367    Connection::new(stream)
368}
369
370fn connect_stream_with_timeout(
371    socket_path: &Path,
372    timeout: Duration,
373) -> io::Result<BlockingLocalStream> {
374    connect_blocking(
375        &LocalEndpoint::from_path(socket_path.to_path_buf()),
376        timeout,
377    )
378}
379
/// Returns the read timeout currently configured on `stream`.
///
/// Thin wrapper that forwards to the stream's own `read_timeout` method.
fn read_timeout(stream: &BlockingLocalStream) -> io::Result<Option<Duration>> {
    stream.read_timeout()
}
383
/// Sets the read timeout on `stream`.
///
/// Thin wrapper that forwards to the stream's own `set_read_timeout` method.
/// Callers in this file pass `None` to lift the timeout (presumably with the
/// same meaning as the standard-library socket APIs; confirm in `rmux_ipc`).
fn set_read_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
    stream.set_read_timeout(timeout)
}
387
/// Sets the write timeout on `stream`.
///
/// Thin wrapper that forwards to the stream's own `set_write_timeout` method.
/// Callers in this file pass `None` to lift the timeout (presumably with the
/// same meaning as the standard-library socket APIs; confirm in `rmux_ipc`).
fn set_write_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
    stream.set_write_timeout(timeout)
}
391
/// Reports whether an I/O error means that no server is running.
///
/// Only `NotFound` and `ConnectionRefused` count as "absent"; every other
/// kind is treated as a genuine transport failure.
fn is_absent_error(error: &io::Error) -> bool {
    let kind = error.kind();
    kind == io::ErrorKind::NotFound || kind == io::ErrorKind::ConnectionRefused
}
399
// Unit tests live in `connection/tests.rs` and are textually included into
// this module, so they are compiled only for test builds on Unix targets.
#[cfg(all(test, unix))]
mod tests {
    include!("connection/tests.rs");
}