1#[cfg(all(test, unix))]
2use std::fs;
3use std::io;
4#[cfg(windows)]
5use std::io::{Read, Write};
6use std::path::{Path, PathBuf};
7use std::sync::{Arc, Mutex as StdMutex};
8#[cfg(windows)]
9use std::time::Duration;
10
11use tokio::sync::oneshot;
12use tokio::task::JoinHandle;
13
14use rmux_core::events::SubscriptionLimits;
15#[cfg(windows)]
16use rmux_ipc::connect_blocking;
17use rmux_ipc::LocalEndpoint;
18#[cfg(windows)]
19use rmux_ipc::LocalListener;
20#[cfg(windows)]
21use rmux_proto::{
22 encode_frame, FrameDecoder, HasSessionRequest, Request, Response, RmuxError, SessionName,
23};
24
25use crate::listener;
26use crate::listener_options::ServeOptions;
27#[cfg(windows)]
28use crate::server_access::current_owner_uid;
29#[cfg(unix)]
30use crate::unix_socket::bind_unix_listener_at;
31#[cfg(unix)]
32use crate::unix_socket::real_user_id;
33#[cfg(all(test, unix))]
34use crate::unix_socket::{
35 ensure_parent_directory, indicates_stale_socket, remove_stale_socket_if_needed,
36};
37
// Last-resort directory tried by `socket_root_from_env` when the provided
// TMPDIR value is absent, empty, or cannot be canonicalized (tests only).
#[cfg(all(test, unix))]
const FALLBACK_SOCKET_ROOT: &str = "/tmp";
40
41pub fn default_socket_path() -> io::Result<PathBuf> {
46 rmux_ipc::default_endpoint().map(LocalEndpoint::into_path)
47}
48
/// Pick a usable socket root directory (tests only).
///
/// Tries the supplied TMPDIR value first (skipping an empty string), then
/// falls back to [`FALLBACK_SOCKET_ROOT`]; the first candidate that
/// canonicalizes successfully wins.
///
/// # Errors
/// Returns `NotFound` when no candidate directory can be canonicalized.
#[cfg(all(test, unix))]
fn socket_root_from_env(tmpdir: Option<&std::ffi::OsStr>) -> io::Result<PathBuf> {
    // Empty TMPDIR is treated the same as an unset one.
    let preferred = tmpdir
        .filter(|value| !value.is_empty())
        .map(PathBuf::from);

    preferred
        .into_iter()
        .chain(std::iter::once(PathBuf::from(FALLBACK_SOCKET_ROOT)))
        .find_map(|candidate| fs::canonicalize(candidate).ok())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::NotFound,
                "no suitable rmux socket directory",
            )
        })
}
68
/// Configuration for a [`ServerDaemon`]: the endpoint to bind, how config
/// files are loaded at startup, and event-subscription limits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DaemonConfig {
    // Path the server binds: a Unix socket path, or (on Windows) the path
    // form of a named-pipe endpoint (see `ServerDaemon::bind`).
    socket_path: PathBuf,
    // Whether/which configuration files are loaded, and with what options.
    config_load: ConfigLoadOptions,
    // Limits forwarded into `ServeOptions` for client event subscriptions.
    subscription_limits: SubscriptionLimits,
}
76
77impl DaemonConfig {
78 #[must_use]
80 pub fn new(socket_path: PathBuf) -> Self {
81 Self {
82 socket_path,
83 config_load: ConfigLoadOptions::disabled(),
84 subscription_limits: SubscriptionLimits::default(),
85 }
86 }
87
88 pub fn with_default_socket_path() -> io::Result<Self> {
90 Ok(Self::new(default_socket_path()?))
91 }
92
93 #[must_use]
95 pub fn socket_path(&self) -> &Path {
96 &self.socket_path
97 }
98
99 #[must_use]
101 pub const fn config_load(&self) -> &ConfigLoadOptions {
102 &self.config_load
103 }
104
105 #[must_use]
107 pub fn subscription_limits(&self) -> SubscriptionLimits {
108 self.subscription_limits
109 }
110
111 #[must_use]
113 pub fn with_default_config_load(mut self, quiet: bool, cwd: Option<PathBuf>) -> Self {
114 self.config_load = ConfigLoadOptions {
115 selection: ConfigFileSelection::Default,
116 quiet,
117 cwd,
118 };
119 self
120 }
121
122 #[must_use]
124 pub fn with_subscription_limits(mut self, subscription_limits: SubscriptionLimits) -> Self {
125 self.subscription_limits = subscription_limits;
126 self
127 }
128
129 #[must_use]
131 pub fn with_config_files(
132 mut self,
133 files: Vec<PathBuf>,
134 quiet: bool,
135 cwd: Option<PathBuf>,
136 ) -> Self {
137 self.config_load = ConfigLoadOptions {
138 selection: ConfigFileSelection::Files(files),
139 quiet,
140 cwd,
141 };
142 self
143 }
144}
145
/// Options controlling configuration-file loading at daemon startup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConfigLoadOptions {
    // Which configuration files (if any) to load.
    selection: ConfigFileSelection,
    // NOTE(review): presumably suppresses load-time diagnostics when true —
    // confirm against the consumer in `listener`/`ServeOptions`.
    quiet: bool,
    // Working directory to use during loading, when overridden.
    cwd: Option<PathBuf>,
}
153
154impl ConfigLoadOptions {
155 #[must_use]
157 pub const fn disabled() -> Self {
158 Self {
159 selection: ConfigFileSelection::Disabled,
160 quiet: true,
161 cwd: None,
162 }
163 }
164
165 #[must_use]
167 pub const fn selection(&self) -> &ConfigFileSelection {
168 &self.selection
169 }
170
171 #[must_use]
173 pub const fn quiet(&self) -> bool {
174 self.quiet
175 }
176
177 #[must_use]
179 pub fn cwd(&self) -> Option<&Path> {
180 self.cwd.as_deref()
181 }
182}
183
/// Which configuration files the daemon loads at startup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConfigFileSelection {
    /// Do not load any configuration files.
    Disabled,
    /// Use the default configuration-file lookup
    /// (see [`DaemonConfig::with_default_config_load`]).
    Default,
    /// Load exactly this list of files
    /// (see [`DaemonConfig::with_config_files`]).
    Files(Vec<PathBuf>),
}
194
/// Unbound daemon: wraps a [`DaemonConfig`] until [`ServerDaemon::bind`]
/// turns it into a running [`ServerHandle`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerDaemon {
    config: DaemonConfig,
}
200
/// Clonable one-shot shutdown trigger.
///
/// All clones share the same `oneshot::Sender` behind an `Arc<Mutex<_>>`;
/// `request_shutdown` takes the sender out on first use, so only the first
/// call (from any clone) actually fires the signal.
#[derive(Debug, Clone)]
pub(crate) struct ShutdownHandle {
    sender: Arc<StdMutex<Option<oneshot::Sender<()>>>>,
}
205
206impl ShutdownHandle {
207 pub(crate) fn new() -> (Self, oneshot::Receiver<()>) {
208 let (sender, receiver) = oneshot::channel();
209 (
210 Self {
211 sender: Arc::new(StdMutex::new(Some(sender))),
212 },
213 receiver,
214 )
215 }
216
217 pub(crate) fn request_shutdown(&self) {
218 if let Some(sender) = self.sender.lock().expect("shutdown sender").take() {
219 let _ = sender.send(());
220 }
221 }
222}
223
impl ServerDaemon {
    /// Wrap a [`DaemonConfig`] ready to be bound.
    #[must_use]
    pub fn new(config: DaemonConfig) -> Self {
        Self { config }
    }

    /// Bind the platform listener and spawn the serving task.
    ///
    /// On Unix this binds a Unix-domain socket at the configured path,
    /// installs a signal watcher wired to the shutdown handle, and passes the
    /// uid from `real_user_id` plus the bound socket identity into the serve
    /// options. On Windows it binds a named pipe derived from the same path
    /// (with a liveness probe on bind failure); there is no signal watcher.
    ///
    /// # Errors
    /// Returns any error from binding the listener, installing the signal
    /// watcher (Unix), or resolving the owner uid.
    pub async fn bind(self) -> io::Result<ServerHandle> {
        #[cfg(unix)]
        {
            let bound_listener = bind_unix_listener_at(self.config.socket_path())?;
            let (shutdown_handle, shutdown_receiver) = ShutdownHandle::new();
            // Channel forwarding process signals into the serve loop.
            let (server_signal_tx, server_signal_rx) = tokio::sync::mpsc::unbounded_channel();
            let signal_watcher =
                crate::signals::SignalWatcher::install(shutdown_handle.clone(), server_signal_tx)?;
            let socket_path = self.config.socket_path().to_path_buf();
            let owner_uid = real_user_id()?;
            let serve_options = ServeOptions::new(
                self.config.config_load().clone(),
                self.config.subscription_limits(),
                owner_uid,
            )
            .with_socket_identity(bound_listener.identity)
            .with_server_signals(server_signal_rx);

            let task = tokio::spawn(listener::serve(
                bound_listener.listener,
                socket_path.clone(),
                shutdown_handle.clone(),
                shutdown_receiver,
                serve_options,
            ));

            Ok(ServerHandle {
                socket_path,
                shutdown_handle,
                task: Some(task),
                signal_watcher: Some(signal_watcher),
            })
        }

        #[cfg(windows)]
        {
            // On Windows the configured "socket path" names a pipe endpoint.
            let endpoint = LocalEndpoint::from_path(self.config.socket_path().to_path_buf());
            let listener = bind_windows_listener(&endpoint)?;
            let (shutdown_handle, shutdown_receiver) = ShutdownHandle::new();
            let socket_path = self.config.socket_path().to_path_buf();
            let owner_uid = current_owner_uid();
            let serve_options = ServeOptions::new(
                self.config.config_load().clone(),
                self.config.subscription_limits(),
                owner_uid,
            );

            let task = tokio::spawn(listener::serve(
                listener,
                socket_path.clone(),
                shutdown_handle.clone(),
                shutdown_receiver,
                serve_options,
            ));

            Ok(ServerHandle {
                socket_path,
                shutdown_handle,
                task: Some(task),
            })
        }
    }
}
295
/// Bind a Windows named-pipe listener at `endpoint`, converting any bind
/// failure into a diagnosed error via [`windows_bind_error`].
#[cfg(windows)]
fn bind_windows_listener(endpoint: &LocalEndpoint) -> io::Result<LocalListener> {
    LocalListener::bind(endpoint)
        .map_err(|bind_error| windows_bind_error(endpoint, bind_error))
}
303
/// Turn a named-pipe bind failure into a diagnostic error.
///
/// Probes the existing pipe: if something rmux-compatible answers, report
/// `AddrInUse`; otherwise re-wrap the original bind error with context.
#[cfg(windows)]
fn windows_bind_error(endpoint: &LocalEndpoint, bind_error: io::Error) -> io::Error {
    if windows_pipe_responds(endpoint) {
        io::Error::new(
            io::ErrorKind::AddrInUse,
            format!(
                "Windows named pipe '{}' is already held by a responsive rmux-compatible server",
                endpoint.as_path().display()
            ),
        )
    } else {
        io::Error::new(
            bind_error.kind(),
            format!(
                "failed to bind Windows named pipe '{}': {bind_error}. Another process may still be holding this endpoint",
                endpoint.as_path().display()
            ),
        )
    }
}
324
/// Report whether whatever currently holds `endpoint` answers the rmux
/// protocol probe.
///
/// The probe runs on a freshly spawned thread so that a panic inside it is
/// contained: `join()` returns `Err` on panic, and both a panic and a probe
/// error collapse to `false` via `unwrap_or(false)`.
#[cfg(windows)]
fn windows_pipe_responds(endpoint: &LocalEndpoint) -> bool {
    let endpoint = endpoint.clone();
    std::thread::spawn(move || windows_protocol_probe(&endpoint).unwrap_or(false))
        .join()
        .unwrap_or(false)
}
332
/// Probe `endpoint` with a minimal rmux exchange and report whether the
/// other side speaks the protocol.
///
/// Connects with a 100 ms timeout (read/write timeouts likewise 100 ms),
/// sends a `HasSession` request for the sentinel name `__rmux_probe__`, then
/// reads until a frame decodes. `Ok(true)` only for a `HasSession` response;
/// EOF, a read timeout, any other response variant, or a decode error (other
/// than an incomplete frame, which just means "keep reading") yields
/// `Ok(false)`.
///
/// # Errors
/// Propagates connect/configure/write failures and non-timeout read errors.
#[cfg(windows)]
fn windows_protocol_probe(endpoint: &LocalEndpoint) -> io::Result<bool> {
    let mut stream = connect_blocking(endpoint, Duration::from_millis(100))?;
    stream.set_write_timeout(Some(Duration::from_millis(100)))?;
    stream.set_read_timeout(Some(Duration::from_millis(100)))?;

    // Sentinel session name. NOTE(review): assumes no real session is ever
    // named "__rmux_probe__" — confirm the server treats this as a plain
    // existence query with no side effects.
    let request = Request::HasSession(HasSessionRequest {
        target: SessionName::new("__rmux_probe__").map_err(io::Error::other)?,
    });
    let frame = encode_frame(&request).map_err(io::Error::other)?;
    stream.write_all(&frame)?;
    stream.flush()?;

    let mut decoder = FrameDecoder::new();
    let mut buffer = [0_u8; 512];
    loop {
        let bytes_read = match stream.read(&mut buffer) {
            // Peer closed without answering: not a responsive rmux server.
            Ok(0) => return Ok(false),
            Ok(bytes_read) => bytes_read,
            // No reply within the read timeout: treat as unresponsive.
            Err(error) if error.kind() == io::ErrorKind::TimedOut => return Ok(false),
            Err(error) => return Err(error),
        };
        decoder.push_bytes(&buffer[..bytes_read]);
        match decoder.next_frame::<Response>() {
            Ok(Some(Response::HasSession(_))) => return Ok(true),
            // A well-formed but different response: not answering our probe.
            Ok(Some(_response)) => return Ok(false),
            // Not enough bytes for a full frame yet — keep reading.
            Ok(None) => continue,
            Err(RmuxError::IncompleteFrame { .. }) => continue,
            // Undecodable bytes on the wire: not a compatible server.
            Err(_error) => return Ok(false),
        }
    }
}
365
/// Handle to a bound, running server.
///
/// Holds the bound path, a shutdown trigger, and the spawned serving task;
/// on Unix it also owns the installed signal watcher.
#[derive(Debug)]
pub struct ServerHandle {
    socket_path: PathBuf,
    // Shared trigger; also fired best-effort from `Drop`.
    shutdown_handle: ShutdownHandle,
    // Taken (set to `None`) by `wait`/`shutdown` when awaiting completion.
    task: Option<JoinHandle<io::Result<()>>>,
    #[cfg(unix)]
    // Dropped (taken) in `request_shutdown` before signaling the server.
    signal_watcher: Option<crate::signals::SignalWatcher>,
}
375
376impl ServerHandle {
377 #[must_use]
379 pub fn socket_path(&self) -> &Path {
380 &self.socket_path
381 }
382
383 pub async fn wait(mut self) -> io::Result<()> {
385 if let Some(task) = self.task.take() {
386 return task.await.map_err(io::Error::other)?;
387 }
388
389 Ok(())
390 }
391
392 pub async fn shutdown(mut self) -> io::Result<()> {
394 self.request_shutdown();
395
396 if let Some(task) = self.task.take() {
397 return task.await.map_err(io::Error::other)?;
398 }
399
400 Ok(())
401 }
402
403 fn request_shutdown(&mut self) {
404 #[cfg(unix)]
405 {
406 let _ = self.signal_watcher.take();
407 }
408 self.shutdown_handle.request_shutdown();
409 }
410}
411
impl Drop for ServerHandle {
    // Best-effort: fire the shutdown signal on drop. Does not (and cannot,
    // synchronously) await the spawned serving task.
    fn drop(&mut self) {
        self.request_shutdown();
    }
}
417
// Platform-specific test suites live in sibling files; exactly one `tests`
// module is compiled, selected by the target family.
#[cfg(all(test, unix))]
#[path = "daemon_tests/unix.rs"]
mod tests;

#[cfg(all(test, windows))]
#[path = "daemon_tests/windows.rs"]
mod tests;