1use std::ffi::OsStr;
4#[cfg(all(test, unix))]
5use std::ffi::OsString;
6#[cfg(all(test, unix))]
7use std::fs;
8use std::io::{self, Read, Write};
9#[cfg(all(test, unix))]
10use std::os::unix::ffi::{OsStrExt, OsStringExt};
11use std::path::{Path, PathBuf};
12use std::time::Duration;
13
14use crate::ClientError;
15use rmux_ipc::{connect_blocking, BlockingLocalStream, LocalEndpoint};
16use rmux_proto::{
17 encode_frame, AttachSessionResponse, ControlMode, ControlModeResponse, FrameDecoder, Request,
18 Response,
19};
20
/// Size in bytes of the scratch buffer used for each socket read while a
/// response frame is being assembled.
const READ_BUFFER_SIZE: usize = 8 * 1024;
/// Timeout handed to the connect helper when opening the local socket.
const SOCKET_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Write timeout installed on request/response connections.
const SOCKET_WRITE_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Read timeout installed on request/response connections while awaiting a reply.
const SOCKET_RESPONSE_TIMEOUT: Duration = Duration::from_millis(15_000);

/// Directory tried after `RMUX_TMPDIR`-style overrides when resolving a socket root (tests only).
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
/// Prefix of the per-user socket directory name, `<prefix>-<uid>` (tests only).
#[cfg(all(test, unix))]
const SOCKET_DIR_PREFIX: &str = "rmux";
34
35pub fn default_socket_path() -> Result<PathBuf, ClientError> {
40 rmux_ipc::default_endpoint()
41 .map(LocalEndpoint::into_path)
42 .map_err(ClientError::Io)
43}
44
45pub fn socket_path_for_label(label: impl AsRef<OsStr>) -> Result<PathBuf, ClientError> {
47 rmux_ipc::endpoint_for_label(label)
48 .map(LocalEndpoint::into_path)
49 .map_err(ClientError::Io)
50}
51
52pub fn resolve_socket_path(
56 socket_name: Option<&OsStr>,
57 socket_path: Option<&Path>,
58) -> Result<PathBuf, ClientError> {
59 rmux_ipc::resolve_endpoint(socket_name, socket_path)
60 .map(LocalEndpoint::into_path)
61 .map_err(ClientError::Io)
62}
63
64#[derive(Debug)]
66pub enum ConnectResult {
67 Connected(Connection),
69 Absent,
71}
72
73pub fn connect_or_absent(socket_path: &Path) -> Result<ConnectResult, ClientError> {
81 connect_or_absent_with_timeout_using(
82 socket_path,
83 SOCKET_CONNECT_TIMEOUT,
84 connect_stream_with_timeout,
85 )
86}
87
88pub fn connect(socket_path: &Path) -> Result<Connection, ClientError> {
90 connect_with_timeout_using(
91 socket_path,
92 SOCKET_CONNECT_TIMEOUT,
93 connect_stream_with_timeout,
94 )
95}
96
97#[derive(Debug)]
99pub struct Connection {
100 stream: BlockingLocalStream,
101 decoder: FrameDecoder,
102}
103
104#[derive(Debug)]
106pub enum AttachTransition {
107 Upgraded(AttachSessionUpgrade),
109 Rejected(Response),
111}
112
113#[derive(Debug)]
115pub enum ControlTransition {
116 Upgraded(ControlModeUpgrade),
118 Rejected(Response),
120}
121
122#[derive(Debug)]
124pub struct AttachSessionUpgrade {
125 response: AttachSessionResponse,
126 stream: BlockingLocalStream,
127 initial_bytes: Vec<u8>,
128}
129
130#[derive(Debug)]
132pub struct ControlModeUpgrade {
133 pub(crate) response: ControlModeResponse,
134 pub(crate) stream: BlockingLocalStream,
135}
136
137impl AttachSessionUpgrade {
138 #[must_use]
140 pub const fn response(&self) -> &AttachSessionResponse {
141 &self.response
142 }
143
144 #[must_use]
146 pub fn into_stream(self) -> BlockingLocalStream {
147 self.stream
148 }
149
150 #[must_use]
153 pub fn into_parts(self) -> (BlockingLocalStream, Vec<u8>) {
154 (self.stream, self.initial_bytes)
155 }
156}
157
158impl ControlModeUpgrade {
159 #[must_use]
161 pub const fn response(&self) -> &ControlModeResponse {
162 &self.response
163 }
164
165 #[must_use]
167 pub const fn mode(&self) -> ControlMode {
168 self.response.mode
169 }
170
171 #[must_use]
173 pub fn into_stream(self) -> BlockingLocalStream {
174 self.stream
175 }
176}
177
178impl Connection {
179 pub(crate) fn new(stream: BlockingLocalStream) -> Result<Self, ClientError> {
180 set_read_timeout(&stream, Some(SOCKET_RESPONSE_TIMEOUT)).map_err(ClientError::Io)?;
181 set_write_timeout(&stream, Some(SOCKET_WRITE_TIMEOUT)).map_err(ClientError::Io)?;
182
183 Ok(Self {
184 stream,
185 decoder: FrameDecoder::new(),
186 })
187 }
188
189 pub fn roundtrip(&mut self, request: &Request) -> Result<Response, ClientError> {
195 self.write_request(request)?;
196 self.read_response()
197 }
198
199 pub(crate) fn roundtrip_without_read_timeout(
204 &mut self,
205 request: &Request,
206 ) -> Result<Response, ClientError> {
207 let previous_timeout = read_timeout(&self.stream).map_err(ClientError::Io)?;
208 set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
209 let result = self.roundtrip(request);
210 set_read_timeout(&self.stream, previous_timeout).map_err(ClientError::Io)?;
211 result
212 }
213
214 pub(crate) fn write_request(&mut self, request: &Request) -> Result<(), ClientError> {
215 let frame = encode_frame(request).map_err(ClientError::Protocol)?;
216 self.stream.write_all(&frame).map_err(ClientError::Io)
217 }
218
219 pub(crate) fn read_response(&mut self) -> Result<Response, ClientError> {
220 let mut buffer = [0u8; READ_BUFFER_SIZE];
221
222 loop {
223 match self.decoder.next_frame::<Response>() {
224 Ok(Some(response)) => return Ok(response),
225 Ok(None) => {}
226 Err(error) => return Err(ClientError::Protocol(error)),
227 }
228
229 let bytes_read = match self.stream.read(&mut buffer) {
230 Ok(bytes_read) => bytes_read,
231 Err(error) if error.kind() == io::ErrorKind::Interrupted => continue,
232 Err(error) => return Err(ClientError::Io(error)),
233 };
234
235 if bytes_read == 0 {
236 return Err(ClientError::UnexpectedEof);
237 }
238
239 self.decoder.push_bytes(&buffer[..bytes_read]);
240 }
241 }
242
243 pub(crate) fn stream_mut(&mut self) -> &mut BlockingLocalStream {
244 &mut self.stream
245 }
246
247 pub(crate) fn into_attach_upgrade(
248 self,
249 response: AttachSessionResponse,
250 ) -> Result<AttachSessionUpgrade, ClientError> {
251 set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
252 set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
253 let initial_bytes = self.decoder.remaining_bytes().to_vec();
254
255 Ok(AttachSessionUpgrade {
256 response,
257 stream: self.stream,
258 initial_bytes,
259 })
260 }
261
262 pub(crate) fn into_control_upgrade(
263 self,
264 response: ControlModeResponse,
265 ) -> Result<ControlModeUpgrade, ClientError> {
266 set_read_timeout(&self.stream, None).map_err(ClientError::Io)?;
267 set_write_timeout(&self.stream, None).map_err(ClientError::Io)?;
268
269 Ok(ControlModeUpgrade {
270 response,
271 stream: self.stream,
272 })
273 }
274}
275
276pub(crate) fn read_response_frame_exact(
277 stream: &mut BlockingLocalStream,
278) -> Result<Response, ClientError> {
279 let mut decoder = FrameDecoder::new();
280 let mut byte = [0_u8; 1];
281
282 loop {
283 match decoder.next_frame::<Response>() {
284 Ok(Some(response)) => return Ok(response),
285 Ok(None) => {}
286 Err(error) => return Err(ClientError::Protocol(error)),
287 }
288
289 read_exact_or_eof(stream, &mut byte)?;
290 decoder.push_bytes(&byte);
291 }
292}
293
294fn read_exact_or_eof(
295 stream: &mut BlockingLocalStream,
296 buffer: &mut [u8],
297) -> Result<(), ClientError> {
298 match stream.read_exact(buffer) {
299 Ok(()) => Ok(()),
300 Err(error) if error.kind() == io::ErrorKind::UnexpectedEof => {
301 Err(ClientError::UnexpectedEof)
302 }
303 Err(error) => Err(ClientError::Io(error)),
304 }
305}
306
/// Test helper: builds `<root>/<SOCKET_DIR_PREFIX>-<user_id>/<label>`.
///
/// `<root>` comes from `socket_root_from_parts`.
///
/// # Errors
///
/// Propagates the error from `socket_root_from_parts`.
#[cfg(all(test, unix))]
fn socket_path_from_parts(
    rmux_tmpdir: Option<&OsStr>,
    user_id: u32,
    label: &OsStr,
) -> io::Result<PathBuf> {
    let directory =
        socket_root_from_parts(rmux_tmpdir)?.join(format!("{SOCKET_DIR_PREFIX}-{user_id}"));

    // Append the label byte-wise instead of via `Path::join`. `join` would
    // replace the base when given an absolute path. This form always puts the
    // label verbatim after a single `/`.
    let mut raw = directory.into_os_string().into_vec();
    raw.push(b'/');
    raw.extend_from_slice(label.as_bytes());

    Ok(OsString::from_vec(raw).into())
}
321
/// Test helper: picks the first usable socket root directory.
///
/// A non-empty `rmux_tmpdir` is tried first, then `FALLBACK_SOCKET_ROOT`.
/// The first candidate that `fs::canonicalize` accepts is returned in its
/// canonical form.
///
/// # Errors
///
/// Returns an `io::ErrorKind::NotFound` error if no candidate canonicalizes.
#[cfg(all(test, unix))]
fn socket_root_from_parts(rmux_tmpdir: Option<&OsStr>) -> io::Result<PathBuf> {
    let preferred = rmux_tmpdir
        .filter(|value| !value.is_empty())
        .map(PathBuf::from);

    preferred
        .into_iter()
        .chain(std::iter::once(PathBuf::from(FALLBACK_SOCKET_ROOT)))
        .find_map(|candidate| fs::canonicalize(candidate).ok())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "no suitable rmux socket directory",
            )
        })
}
342
343fn connect_or_absent_with_timeout_using<F>(
344 socket_path: &Path,
345 timeout: Duration,
346 connect_stream: F,
347) -> Result<ConnectResult, ClientError>
348where
349 F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
350{
351 match connect_stream(socket_path, timeout) {
352 Ok(stream) => Ok(ConnectResult::Connected(Connection::new(stream)?)),
353 Err(error) if is_absent_error(&error) => Ok(ConnectResult::Absent),
354 Err(error) => Err(ClientError::Io(error)),
355 }
356}
357
358fn connect_with_timeout_using<F>(
359 socket_path: &Path,
360 timeout: Duration,
361 connect_stream: F,
362) -> Result<Connection, ClientError>
363where
364 F: FnOnce(&Path, Duration) -> io::Result<BlockingLocalStream>,
365{
366 let stream = connect_stream(socket_path, timeout).map_err(ClientError::Io)?;
367 Connection::new(stream)
368}
369
370fn connect_stream_with_timeout(
371 socket_path: &Path,
372 timeout: Duration,
373) -> io::Result<BlockingLocalStream> {
374 connect_blocking(
375 &LocalEndpoint::from_path(socket_path.to_path_buf()),
376 timeout,
377 )
378}
379
380fn read_timeout(stream: &BlockingLocalStream) -> io::Result<Option<Duration>> {
381 stream.read_timeout()
382}
383
384fn set_read_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
385 stream.set_read_timeout(timeout)
386}
387
388fn set_write_timeout(stream: &BlockingLocalStream, timeout: Option<Duration>) -> io::Result<()> {
389 stream.set_write_timeout(timeout)
390}
391
/// Reports whether a connect error means "no server is listening".
///
/// Only `NotFound` and `ConnectionRefused` qualify. Every other kind is a
/// genuine failure.
fn is_absent_error(error: &io::Error) -> bool {
    const ABSENT_KINDS: [io::ErrorKind; 2] =
        [io::ErrorKind::NotFound, io::ErrorKind::ConnectionRefused];
    ABSENT_KINDS.contains(&error.kind())
}
399
// Unit tests are kept in a separate file and spliced into this module with
// `include!`. As a child module, they can reach this file's private items
// (for example the injectable `*_using` connect helpers) through `super::`.
#[cfg(all(test, unix))]
mod tests {
    include!("connection/tests.rs");
}