// agent_procs/daemon/server.rs

1use crate::daemon::actor::{PmHandle, ProcessManagerActor, ProxyState};
2use crate::daemon::wait_engine;
3use crate::protocol::{self, ErrorCode, PROTOCOL_VERSION, Request, Response};
4use std::path::Path;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::time::Duration;
8use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
9use tokio::net::UnixListener;
10use tokio::sync::Mutex;
11use tokio::sync::{broadcast, watch};
12
/// Maximum concurrent client connections.  Prevents accidental fork-bomb
/// loops where each connection spawns more connections.  Connections over
/// this limit are rejected and immediately closed by the accept loop.
const MAX_CONCURRENT_CONNECTIONS: usize = 64;
16
/// Accept loop for the daemon's Unix-domain control socket.
///
/// Spawns the `ProcessManagerActor`, binds a listener at `socket_path`,
/// and serves each client connection on its own task using a
/// newline-delimited JSON protocol (one request per line).  Runs until a
/// `Shutdown` request is handled; logs and returns early if the socket
/// cannot be bound.
pub async fn run(session: &str, socket_path: &Path) {
    let (handle, proxy_state_rx, actor) = ProcessManagerActor::new(session);

    // Spawn the actor loop; `handle` is the connection tasks' only channel
    // to it.
    tokio::spawn(actor.run());

    let listener = match UnixListener::bind(socket_path) {
        Ok(l) => l,
        Err(e) => {
            tracing::error!(path = %socket_path.display(), error = %e, "failed to bind socket");
            return;
        }
    };

    // Shutdown signal: notified when a Shutdown request is handled.
    // `Notify` stores a permit, so the signal is not lost even if it fires
    // while the accept loop is busy in the other `select!` arm.
    let shutdown = Arc::new(tokio::sync::Notify::new());
    let active_connections = Arc::new(AtomicUsize::new(0));

    loop {
        let (stream, _) = tokio::select! {
            result = listener.accept() => match result {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::warn!(error = %e, "accept error");
                    continue;
                }
            },
            () = shutdown.notified() => break,
        };

        // Rate limiting: atomically increment then check to avoid TOCTOU race
        let prev = active_connections.fetch_add(1, Ordering::AcqRel);
        if prev >= MAX_CONCURRENT_CONNECTIONS {
            active_connections.fetch_sub(1, Ordering::AcqRel);
            tracing::warn!(
                current = prev,
                max = MAX_CONCURRENT_CONNECTIONS,
                "connection rejected: too many concurrent connections"
            );
            drop(stream);
            continue;
        }

        let handle = handle.clone();
        let shutdown = Arc::clone(&shutdown);
        let conn_counter = Arc::clone(&active_connections);
        let proxy_state_rx = proxy_state_rx.clone();
        tokio::spawn(async move {
            // Decrements the connection counter on every exit path.
            let _guard = ConnectionGuard(conn_counter);
            let (reader, writer) = stream.into_split();
            // Shared writer: both the follow-stream path and the normal
            // request path send responses, so serialize access with a mutex.
            let writer = Arc::new(Mutex::new(writer));
            // Wrap reader in a size-limited adapter so read_line cannot
            // allocate more than MAX_MESSAGE_SIZE bytes.
            let limited = reader.take(protocol::MAX_MESSAGE_SIZE as u64);
            let mut reader = BufReader::new(limited);

            loop {
                let mut line = String::new();
                match reader.read_line(&mut line).await {
                    Ok(0) | Err(_) => break, // EOF or error
                    // When the `take` limit is exhausted mid-line, read_line
                    // returns the truncated bytes without a trailing newline,
                    // so n reaching MAX_MESSAGE_SIZE signals an oversized
                    // message rather than a complete one.
                    Ok(n) if n >= protocol::MAX_MESSAGE_SIZE => {
                        let resp = Response::Error {
                            code: ErrorCode::General,
                            message: format!(
                                "message too large: {} bytes (max {})",
                                n,
                                protocol::MAX_MESSAGE_SIZE
                            ),
                        };
                        let _ = send_response(&writer, &resp).await;
                        break; // disconnect oversized clients
                    }
                    Ok(_) => {}
                }
                // Reset the take limit for the next message
                reader
                    .get_mut()
                    .set_limit(protocol::MAX_MESSAGE_SIZE as u64);

                let request: Request = match serde_json::from_str(&line) {
                    Ok(r) => r,
                    Err(e) => {
                        // Malformed JSON is reported, but the connection
                        // stays open for subsequent requests.
                        let resp = Response::Error {
                            code: ErrorCode::General,
                            message: format!("invalid request: {}", e),
                        };
                        let _ = send_response(&writer, &resp).await;
                        continue;
                    }
                };

                // Handle follow requests with streaming (before handle_request)
                if let Request::Logs {
                    follow: true,
                    ref target,
                    all,
                    timeout_secs,
                    lines,
                    ..
                } = request
                {
                    let output_rx = handle.subscribe().await;
                    let max_lines = lines;
                    let target_filter = target.clone();
                    let show_all = all;

                    handle_follow_stream(
                        &writer,
                        output_rx,
                        target_filter,
                        show_all,
                        timeout_secs,
                        max_lines,
                    )
                    .await;
                    continue; // Don't call handle_request
                }

                // Remember before `request` is moved into handle_request.
                let is_shutdown = matches!(request, Request::Shutdown);

                let response = handle_request(&handle, &shutdown, &proxy_state_rx, request).await;
                let _ = send_response(&writer, &response).await;

                if is_shutdown {
                    // Wake the accept loop so the daemon stops accepting;
                    // this task ends after answering the client.
                    shutdown.notify_one();
                    return;
                }
            }
        });
    }
}
148
/// RAII guard that decrements the active connection counter when dropped.
///
/// `#[must_use]` ensures a guard that is constructed but never bound (and
/// would therefore drop — and decrement — immediately) is flagged at
/// compile time.
#[must_use]
struct ConnectionGuard(Arc<AtomicUsize>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // `Relaxed` is sufficient: all operations target a single atomic,
        // whose modification order is total, and the accept loop only acts
        // on values returned by its own read-modify-write operations.
        self.0.fetch_sub(1, Ordering::Relaxed);
    }
}
157
158async fn handle_follow_stream(
159    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
160    mut output_rx: broadcast::Receiver<super::log_writer::OutputLine>,
161    target: Option<String>,
162    all: bool,
163    timeout_secs: Option<u64>,
164    max_lines: Option<usize>,
165) {
166    let mut line_count: usize = 0;
167
168    let stream_loop = async {
169        loop {
170            match output_rx.recv().await {
171                Ok(output_line) => {
172                    if !all
173                        && let Some(ref t) = target
174                        && output_line.process != *t
175                    {
176                        continue;
177                    }
178
179                    let resp = Response::LogLine {
180                        process: output_line.process,
181                        stream: output_line.stream,
182                        line: output_line.line,
183                    };
184                    if send_response(writer, &resp).await.is_err() {
185                        return;
186                    }
187
188                    line_count += 1;
189                    if let Some(max) = max_lines
190                        && line_count >= max
191                    {
192                        return;
193                    }
194                }
195                Err(broadcast::error::RecvError::Lagged(_)) => {}
196                Err(broadcast::error::RecvError::Closed) => return,
197            }
198        }
199    };
200
201    // Apply timeout only if specified; otherwise stream indefinitely
202    match timeout_secs {
203        Some(secs) => {
204            let _ = tokio::time::timeout(Duration::from_secs(secs), stream_loop).await;
205        }
206        None => {
207            stream_loop.await;
208        }
209    }
210
211    let _ = send_response(writer, &Response::LogEnd).await;
212}
213
214async fn handle_request(
215    handle: &PmHandle,
216    shutdown: &Arc<tokio::sync::Notify>,
217    proxy_state_rx: &watch::Receiver<ProxyState>,
218    request: Request,
219) -> Response {
220    match request {
221        Request::Run {
222            command,
223            name,
224            cwd,
225            env,
226            port,
227        } => handle.spawn_process(command, name, cwd, env, port).await,
228        Request::Stop { target } => handle.stop_process(&target).await,
229        Request::StopAll => handle.stop_all().await,
230        Request::Restart { target } => handle.restart_process(&target).await,
231        Request::Status => handle.status().await,
232        Request::Wait {
233            target,
234            until,
235            regex,
236            exit,
237            timeout_secs,
238        } => {
239            // Check process exists
240            if !handle.has_process(&target).await {
241                return Response::Error {
242                    code: ErrorCode::NotFound,
243                    message: format!("process not found: {}", target),
244                };
245            }
246
247            let session_name = handle.session_name().await;
248
249            // Subscribe BEFORE checking historical logs to avoid missing lines
250            let output_rx = handle.subscribe().await;
251
252            // Check historical log output for the pattern
253            if let Some(ref pattern) = until {
254                let log_path =
255                    crate::paths::log_dir(&session_name).join(format!("{}.stdout", target));
256                if let Ok(content) = std::fs::read_to_string(&log_path) {
257                    let compiled_re = if regex {
258                        regex::Regex::new(pattern).ok()
259                    } else {
260                        None
261                    };
262                    let matched_line = content.lines().find(|line| {
263                        if let Some(ref re) = compiled_re {
264                            re.is_match(line)
265                        } else {
266                            line.contains(pattern.as_str())
267                        }
268                    });
269                    if let Some(line) = matched_line {
270                        return Response::WaitMatch {
271                            line: line.to_string(),
272                        };
273                    }
274                }
275            }
276            let timeout = Duration::from_secs(timeout_secs.unwrap_or(30));
277            wait_engine::wait_for(
278                output_rx,
279                &target,
280                until.as_deref(),
281                regex,
282                exit,
283                timeout,
284                handle.clone(),
285            )
286            .await
287        }
288        Request::Logs { follow: false, .. } => Response::Error {
289            code: ErrorCode::General,
290            message: "non-follow logs are read directly from disk by CLI".into(),
291        },
292        Request::Logs { follow: true, .. } => Response::Error {
293            code: ErrorCode::General,
294            message: "follow requests handled in connection loop".into(),
295        },
296        Request::Shutdown => {
297            let _ = handle.stop_all().await;
298            Response::Ok {
299                message: "daemon shutting down".into(),
300            }
301        }
302        Request::EnableProxy { proxy_port } => {
303            let (listener, port) = match super::proxy::bind_proxy_port(proxy_port) {
304                Ok(pair) => pair,
305                Err(e) => {
306                    return Response::Error {
307                        code: ErrorCode::General,
308                        message: e.to_string(),
309                    };
310                }
311            };
312
313            if let Some(existing) = handle.enable_proxy(port).await {
314                return Response::Ok {
315                    message: format!("Proxy already listening on http://localhost:{}", existing),
316                };
317            }
318
319            let proxy_handle = handle.clone();
320            let proxy_shutdown = Arc::clone(shutdown);
321            let proxy_rx = proxy_state_rx.clone();
322            tokio::spawn(async move {
323                if let Err(e) = super::proxy::start_proxy(
324                    listener,
325                    port,
326                    proxy_handle,
327                    proxy_rx,
328                    proxy_shutdown,
329                )
330                .await
331                {
332                    tracing::error!(error = %e, "proxy error");
333                }
334            });
335
336            Response::Ok {
337                message: format!("Proxy listening on http://localhost:{}", port),
338            }
339        }
340        Request::Hello { .. } => Response::Hello {
341            version: PROTOCOL_VERSION,
342        },
343        Request::Unknown => Response::Error {
344            code: ErrorCode::General,
345            message: "unknown request type".into(),
346        },
347    }
348}
349
350async fn send_response(
351    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
352    response: &Response,
353) -> std::io::Result<()> {
354    let mut w = writer.lock().await;
355    let mut json = serde_json::to_string(response)
356        .expect("Response serialization should never fail for well-typed enums");
357    json.push('\n');
358    w.write_all(json.as_bytes()).await?;
359    w.flush().await
360}