//! Unix-socket control server for the process-manager daemon
//! (`agent_procs/daemon/server.rs`): accepts client connections,
//! decodes newline-delimited JSON requests, and dispatches them to the
//! process-manager actor.

use crate::daemon::actor::{PmHandle, ProcessManagerActor, ProxyState};
use crate::daemon::wait_engine;
use crate::protocol::{self, ErrorCode, PROTOCOL_VERSION, Request, Response};
use std::path::Path;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::UnixListener;
use tokio::sync::Mutex;
use tokio::sync::{broadcast, watch};

/// Maximum concurrent client connections.  Prevents accidental fork-bomb
/// loops where each connection spawns more connections; connections over
/// the limit are rejected and immediately dropped by the accept loop.
const MAX_CONCURRENT_CONNECTIONS: usize = 64;

17pub async fn run(session: &str, socket_path: &Path) -> std::io::Result<()> {
18    let (handle, proxy_state_rx, actor) = ProcessManagerActor::new(session);
19
20    // Spawn the actor loop
21    tokio::spawn(actor.run());
22
23    let listener = UnixListener::bind(socket_path).map_err(|e| {
24        tracing::error!(path = %socket_path.display(), error = %e, "failed to bind socket");
25        e
26    })?;
27
28    // Shutdown signal: set to true when a Shutdown request is handled
29    let shutdown = Arc::new(tokio::sync::Notify::new());
30    let active_connections = Arc::new(AtomicUsize::new(0));
31
32    loop {
33        let (stream, _) = tokio::select! {
34            result = listener.accept() => match result {
35                Ok(conn) => conn,
36                Err(e) => {
37                    tracing::warn!(error = %e, "accept error");
38                    continue;
39                }
40            },
41            () = shutdown.notified() => break,
42        };
43
44        // Rate limiting: atomically increment then check to avoid TOCTOU race
45        let prev = active_connections.fetch_add(1, Ordering::AcqRel);
46        if prev >= MAX_CONCURRENT_CONNECTIONS {
47            active_connections.fetch_sub(1, Ordering::AcqRel);
48            tracing::warn!(
49                current = prev,
50                max = MAX_CONCURRENT_CONNECTIONS,
51                "connection rejected: too many concurrent connections"
52            );
53            drop(stream);
54            continue;
55        }
56
57        let handle = handle.clone();
58        let shutdown = Arc::clone(&shutdown);
59        let conn_counter = Arc::clone(&active_connections);
60        let proxy_state_rx = proxy_state_rx.clone();
61        tokio::spawn(async move {
62            let _guard = ConnectionGuard(conn_counter);
63            let (reader, writer) = stream.into_split();
64            let writer = Arc::new(Mutex::new(writer));
65            // Wrap reader in a size-limited adapter so read_line cannot
66            // allocate more than MAX_MESSAGE_SIZE bytes.
67            let limited = reader.take(protocol::MAX_MESSAGE_SIZE as u64);
68            let mut reader = BufReader::new(limited);
69
70            loop {
71                let mut line = String::new();
72                match reader.read_line(&mut line).await {
73                    Ok(0) | Err(_) => break, // EOF or error
74                    Ok(n) if n >= protocol::MAX_MESSAGE_SIZE => {
75                        let resp = Response::Error {
76                            code: ErrorCode::General,
77                            message: format!(
78                                "message too large: {} bytes (max {})",
79                                n,
80                                protocol::MAX_MESSAGE_SIZE
81                            ),
82                        };
83                        let _ = send_response(&writer, &resp).await;
84                        break; // disconnect oversized clients
85                    }
86                    Ok(_) => {}
87                }
88                // Reset the take limit for the next message
89                reader
90                    .get_mut()
91                    .set_limit(protocol::MAX_MESSAGE_SIZE as u64);
92
93                let request: Request = match serde_json::from_str(&line) {
94                    Ok(r) => r,
95                    Err(e) => {
96                        let resp = Response::Error {
97                            code: ErrorCode::General,
98                            message: format!("invalid request: {}", e),
99                        };
100                        let _ = send_response(&writer, &resp).await;
101                        continue;
102                    }
103                };
104
105                // Handle follow requests with streaming (before handle_request)
106                if let Request::Logs {
107                    follow: true,
108                    ref target,
109                    all,
110                    stderr,
111                    timeout_secs,
112                    lines,
113                    ..
114                } = request
115                {
116                    let output_rx = handle.subscribe().await;
117                    let max_lines = lines;
118                    let target_filter = target.clone();
119                    let show_all = all;
120
121                    handle_follow_stream(
122                        &writer,
123                        output_rx,
124                        target_filter,
125                        show_all,
126                        stderr,
127                        timeout_secs,
128                        max_lines,
129                    )
130                    .await;
131                    continue; // Don't call handle_request
132                }
133
134                let is_shutdown = matches!(request, Request::Shutdown);
135
136                let response = handle_request(&handle, &shutdown, &proxy_state_rx, request).await;
137                let _ = send_response(&writer, &response).await;
138
139                if is_shutdown {
140                    shutdown.notify_one();
141                    return;
142                }
143            }
144        });
145    }
146
147    Ok(())
148}
149
/// RAII guard that decrements the active connection counter when dropped.
///
/// Created once per accepted connection; dropping it — on any exit path
/// of the connection task, including panics — releases the connection
/// slot.  `#[must_use]` ensures the guard is bound to a variable rather
/// than dropped immediately at the creation site.
#[must_use]
struct ConnectionGuard(Arc<AtomicUsize>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Relaxed is sufficient: RMW operations on the same atomic form a
        // single modification order, and the counter only gates admission —
        // it never synchronizes other memory.
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}

/// Stream live log output to a connected client (`logs --follow`).
///
/// Forwards lines from the broadcast channel, filtered by process name
/// and stream, until the channel closes, the client disconnects, an
/// optional `max_lines` budget is spent, or an optional timeout elapses.
/// A final `LogEnd` response is always sent so the client can tell the
/// stream is over.
///
/// Filtering rules as implemented below:
/// * `target` set and `all == false` — only lines from that process pass.
/// * `stderr_only` — only stderr lines pass.
/// * otherwise stdout only, EXCEPT when `all` is true and no `target` is
///   given, in which case both streams pass.
async fn handle_follow_stream(
    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
    mut output_rx: broadcast::Receiver<super::log_writer::OutputLine>,
    target: Option<String>,
    all: bool,
    stderr_only: bool,
    timeout_secs: Option<u64>,
    max_lines: Option<usize>,
) {
    // Lines actually forwarded to the client (drives the max_lines budget).
    let mut line_count: usize = 0;

    let stream_loop = async {
        loop {
            match output_rx.recv().await {
                Ok(output_line) => {
                    // Drop lines from other processes when a specific
                    // target was requested (let-chain: applies only when
                    // `all` is false and a target is present).
                    if !all
                        && let Some(ref t) = target
                        && output_line.process != *t
                    {
                        continue;
                    }
                    // stderr-only mode: drop anything that is not stderr.
                    if stderr_only && output_line.stream != protocol::Stream::Stderr {
                        continue;
                    }
                    // Default mode: drop non-stdout lines — unless the
                    // caller asked for all processes without naming a
                    // target, in which case both streams are forwarded.
                    if !stderr_only
                        && (!all || target.is_some())
                        && output_line.stream != protocol::Stream::Stdout
                    {
                        continue;
                    }

                    let resp = Response::LogLine {
                        process: output_line.process,
                        stream: output_line.stream,
                        line: output_line.line,
                    };
                    // A write error means the client went away; stop streaming.
                    if send_response(writer, &resp).await.is_err() {
                        return;
                    }

                    line_count += 1;
                    if let Some(max) = max_lines
                        && line_count >= max
                    {
                        return;
                    }
                }
                // We fell behind the broadcast buffer and missed some lines;
                // keep streaming rather than disconnecting the client.
                Err(broadcast::error::RecvError::Lagged(_)) => {}
                // All senders dropped: no more output will ever arrive.
                Err(broadcast::error::RecvError::Closed) => return,
            }
        }
    };

    // Apply timeout only if specified; otherwise stream indefinitely
    match timeout_secs {
        Some(secs) => {
            let _ = tokio::time::timeout(Duration::from_secs(secs), stream_loop).await;
        }
        None => {
            stream_loop.await;
        }
    }

    let _ = send_response(writer, &Response::LogEnd).await;
}

225async fn handle_request(
226    handle: &PmHandle,
227    shutdown: &Arc<tokio::sync::Notify>,
228    proxy_state_rx: &watch::Receiver<ProxyState>,
229    request: Request,
230) -> Response {
231    match request {
232        Request::Run {
233            command,
234            name,
235            cwd,
236            env,
237            port,
238            restart,
239            watch,
240        } => {
241            handle
242                .spawn_process_supervised(command, name, cwd, env, port, restart, watch)
243                .await
244        }
245        Request::Stop { target } => handle.stop_process(&target).await,
246        Request::StopAll => handle.stop_all().await,
247        Request::Restart { target } => handle.restart_process(&target).await,
248        Request::Status => handle.status().await,
249        Request::Wait {
250            target,
251            until,
252            regex,
253            exit,
254            timeout_secs,
255        } => {
256            // Check process exists
257            if !handle.has_process(&target).await {
258                return Response::Error {
259                    code: ErrorCode::NotFound,
260                    message: format!("process not found: {}", target),
261                };
262            }
263
264            let session_name = handle.session_name().await;
265
266            // Subscribe BEFORE checking historical logs to avoid missing lines
267            let output_rx = handle.subscribe().await;
268
269            // Check historical log output for the pattern
270            if let Some(ref pattern) = until {
271                let log_path =
272                    crate::paths::log_dir(&session_name).join(format!("{}.stdout", target));
273                if let Ok(content) = std::fs::read_to_string(&log_path) {
274                    let compiled_re = if regex {
275                        regex::Regex::new(pattern).ok()
276                    } else {
277                        None
278                    };
279                    let matched_line = content.lines().find(|line| {
280                        if let Some(ref re) = compiled_re {
281                            re.is_match(line)
282                        } else {
283                            line.contains(pattern.as_str())
284                        }
285                    });
286                    if let Some(line) = matched_line {
287                        return Response::WaitMatch {
288                            line: line.to_string(),
289                        };
290                    }
291                }
292            }
293            let timeout = Duration::from_secs(timeout_secs.unwrap_or(30));
294            wait_engine::wait_for(
295                output_rx,
296                &target,
297                until.as_deref(),
298                regex,
299                exit,
300                timeout,
301                handle.clone(),
302            )
303            .await
304        }
305        Request::Logs { follow: false, .. } => Response::Error {
306            code: ErrorCode::General,
307            message: "non-follow logs are read directly from disk by CLI".into(),
308        },
309        Request::Logs { follow: true, .. } => Response::Error {
310            code: ErrorCode::General,
311            message: "follow requests handled in connection loop".into(),
312        },
313        Request::Shutdown => {
314            let _ = handle.stop_all().await;
315            Response::Ok {
316                message: "daemon shutting down".into(),
317            }
318        }
319        Request::EnableProxy { proxy_port } => {
320            let (listener, port) = match super::proxy::bind_proxy_port(proxy_port) {
321                Ok(pair) => pair,
322                Err(e) => {
323                    return Response::Error {
324                        code: ErrorCode::General,
325                        message: e.to_string(),
326                    };
327                }
328            };
329
330            if let Some(existing) = handle.enable_proxy(port).await {
331                return Response::Ok {
332                    message: format!("Proxy already listening on http://localhost:{}", existing),
333                };
334            }
335
336            let proxy_handle = handle.clone();
337            let proxy_shutdown = Arc::clone(shutdown);
338            let proxy_rx = proxy_state_rx.clone();
339            tokio::spawn(async move {
340                if let Err(e) = super::proxy::start_proxy(
341                    listener,
342                    port,
343                    proxy_handle,
344                    proxy_rx,
345                    proxy_shutdown,
346                )
347                .await
348                {
349                    tracing::error!(error = %e, "proxy error");
350                }
351            });
352
353            Response::Ok {
354                message: format!("Proxy listening on http://localhost:{}", port),
355            }
356        }
357        Request::Hello { .. } => Response::Hello {
358            version: PROTOCOL_VERSION,
359        },
360        Request::Unknown => Response::Error {
361            code: ErrorCode::General,
362            message: "unknown request type".into(),
363        },
364    }
365}
366
367async fn send_response(
368    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
369    response: &Response,
370) -> std::io::Result<()> {
371    let mut w = writer.lock().await;
372    let mut json = serde_json::to_string(response)
373        .expect("Response serialization should never fail for well-typed enums");
374    json.push('\n');
375    w.write_all(json.as_bytes()).await?;
376    w.flush().await
377}