//! Unix-socket daemon server for `agent_procs`: accepts client connections
//! and dispatches newline-delimited JSON protocol requests to the
//! process-manager actor. (File: daemon/server.rs)
1use crate::daemon::actor::{PmHandle, ProcessManagerActor, ProxyState};
2use crate::daemon::wait_engine;
3use crate::protocol::{self, ErrorCode, Request, Response, PROTOCOL_VERSION};
4use std::path::Path;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicUsize, Ordering};
7use std::time::Duration;
8use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
9use tokio::net::UnixListener;
10use tokio::sync::Mutex;
11use tokio::sync::{broadcast, watch};
12
/// Maximum concurrent client connections.  Prevents accidental fork-bomb
/// loops where each connection spawns more connections.
///
/// Connections beyond this cap are closed immediately without an error
/// response being written back (see the rejection path in [`run`]).
const MAX_CONCURRENT_CONNECTIONS: usize = 64;
16
/// Bind the daemon's Unix socket at `socket_path` and serve clients for
/// `session` until a `Shutdown` request is handled.
///
/// Responsibilities:
/// * spawns the [`ProcessManagerActor`] event loop,
/// * accepts connections inside `select!` so a shutdown notification can
///   interrupt a pending `accept()`,
/// * caps simultaneous connections at [`MAX_CONCURRENT_CONNECTIONS`],
/// * serves each connection on its own task, reading one JSON request per
///   line and writing one JSON response per line.
///
/// Returns early (after logging an error) if the socket cannot be bound.
pub async fn run(session: &str, socket_path: &Path) {
    let (handle, proxy_state_rx, actor) = ProcessManagerActor::new(session);

    // Spawn the actor loop; it runs for the lifetime of the daemon.
    tokio::spawn(actor.run());

    let listener = match UnixListener::bind(socket_path) {
        Ok(l) => l,
        Err(e) => {
            tracing::error!(path = %socket_path.display(), error = %e, "failed to bind socket");
            return;
        }
    };

    // Shutdown signal: set to true when a Shutdown request is handled
    let shutdown = Arc::new(tokio::sync::Notify::new());
    let active_connections = Arc::new(AtomicUsize::new(0));

    loop {
        let (stream, _) = tokio::select! {
            result = listener.accept() => match result {
                Ok(conn) => conn,
                Err(e) => {
                    tracing::warn!(error = %e, "accept error");
                    continue;
                }
            },
            () = shutdown.notified() => break,
        };

        // Rate limiting: atomically increment then check to avoid TOCTOU race
        let prev = active_connections.fetch_add(1, Ordering::AcqRel);
        if prev >= MAX_CONCURRENT_CONNECTIONS {
            active_connections.fetch_sub(1, Ordering::AcqRel);
            tracing::warn!(
                current = prev,
                max = MAX_CONCURRENT_CONNECTIONS,
                "connection rejected: too many concurrent connections"
            );
            // Rejected clients get no error response; the connection is
            // simply closed.
            drop(stream);
            continue;
        }

        let handle = handle.clone();
        let shutdown = Arc::clone(&shutdown);
        let conn_counter = Arc::clone(&active_connections);
        let proxy_state_rx = proxy_state_rx.clone();
        tokio::spawn(async move {
            // Releases the connection slot when this task exits for any
            // reason (EOF, protocol error, shutdown, panic unwinding).
            let _guard = ConnectionGuard(conn_counter);
            let (reader, writer) = stream.into_split();
            // Writer is shared behind a mutex so streamed follow-mode lines
            // and regular responses cannot interleave mid-message.
            let writer = Arc::new(Mutex::new(writer));
            // Wrap reader in a size-limited adapter so read_line cannot
            // allocate more than MAX_MESSAGE_SIZE bytes.
            let limited = reader.take(protocol::MAX_MESSAGE_SIZE as u64);
            let mut reader = BufReader::new(limited);

            loop {
                let mut line = String::new();
                match reader.read_line(&mut line).await {
                    Ok(0) | Err(_) => break, // EOF or error
                    // The `take` adapter reports EOF once its limit is
                    // consumed, so an unterminated message of
                    // MAX_MESSAGE_SIZE bytes comes back as a read of exactly
                    // that length.  NOTE(review): a legitimate line of
                    // exactly MAX_MESSAGE_SIZE bytes (newline included) is
                    // indistinguishable from a truncated one and is also
                    // rejected here — confirm the limit is sized with
                    // headroom.
                    Ok(n) if n >= protocol::MAX_MESSAGE_SIZE => {
                        let resp = Response::Error {
                            code: ErrorCode::General,
                            message: format!(
                                "message too large: {} bytes (max {})",
                                n,
                                protocol::MAX_MESSAGE_SIZE
                            ),
                        };
                        let _ = send_response(&writer, &resp).await;
                        break; // disconnect oversized clients
                    }
                    Ok(_) => {}
                }
                // Reset the take limit for the next message
                reader
                    .get_mut()
                    .set_limit(protocol::MAX_MESSAGE_SIZE as u64);

                let request: Request = match serde_json::from_str(&line) {
                    Ok(r) => r,
                    Err(e) => {
                        // Malformed JSON is reported but does not end the
                        // connection.
                        let resp = Response::Error {
                            code: ErrorCode::General,
                            message: format!("invalid request: {}", e),
                        };
                        let _ = send_response(&writer, &resp).await;
                        continue;
                    }
                };

                // Handle follow requests with streaming (before handle_request)
                if let Request::Logs {
                    follow: true,
                    ref target,
                    all,
                    timeout_secs,
                    lines,
                    ..
                } = request
                {
                    let output_rx = handle.subscribe().await;
                    let max_lines = lines;
                    let target_filter = target.clone();
                    let show_all = all;

                    // Streams until timeout, max_lines, disconnect, or
                    // channel close, then resumes the request loop.
                    handle_follow_stream(
                        &writer,
                        output_rx,
                        target_filter,
                        show_all,
                        timeout_secs,
                        max_lines,
                    )
                    .await;
                    continue; // Don't call handle_request
                }

                // Remember before the request is moved into handle_request.
                let is_shutdown = matches!(request, Request::Shutdown);

                let response =
                    handle_request(&handle, &shutdown, &proxy_state_rx, request).await;
                let _ = send_response(&writer, &response).await;

                if is_shutdown {
                    // Respond first, then wake the accept loop so it breaks.
                    shutdown.notify_one();
                    return;
                }
            }
        });
    }
}
149
/// RAII guard that decrements the active connection counter when dropped.
///
/// Created at the top of each connection task so the slot claimed by the
/// accept loop's `fetch_add` is always released, even if the task exits
/// early or panics.
struct ConnectionGuard(Arc<AtomicUsize>);

impl Drop for ConnectionGuard {
    fn drop(&mut self) {
        // Use AcqRel to match the orderings used on this same atomic by the
        // accept loop's increment/decrement.  Relaxed would suffice for a
        // pure counter, but mixing orderings on one atomic invites subtle
        // reasoning errors during review.
        self.0.fetch_sub(1, Ordering::AcqRel);
    }
}
158
159async fn handle_follow_stream(
160    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
161    mut output_rx: broadcast::Receiver<super::log_writer::OutputLine>,
162    target: Option<String>,
163    all: bool,
164    timeout_secs: Option<u64>,
165    max_lines: Option<usize>,
166) {
167    let mut line_count: usize = 0;
168
169    let stream_loop = async {
170        loop {
171            match output_rx.recv().await {
172                Ok(output_line) => {
173                    if !all {
174                        if let Some(ref t) = target {
175                            if output_line.process != *t {
176                                continue;
177                            }
178                        }
179                    }
180
181                    let resp = Response::LogLine {
182                        process: output_line.process,
183                        stream: output_line.stream,
184                        line: output_line.line,
185                    };
186                    if send_response(writer, &resp).await.is_err() {
187                        return;
188                    }
189
190                    line_count += 1;
191                    if let Some(max) = max_lines {
192                        if line_count >= max {
193                            return;
194                        }
195                    }
196                }
197                Err(broadcast::error::RecvError::Lagged(_)) => {}
198                Err(broadcast::error::RecvError::Closed) => return,
199            }
200        }
201    };
202
203    // Apply timeout only if specified; otherwise stream indefinitely
204    match timeout_secs {
205        Some(secs) => {
206            let _ = tokio::time::timeout(Duration::from_secs(secs), stream_loop).await;
207        }
208        None => {
209            stream_loop.await;
210        }
211    }
212
213    let _ = send_response(writer, &Response::LogEnd).await;
214}
215
/// Dispatch one parsed [`Request`] to the process-manager actor and build
/// the [`Response`] to send back to the client.
///
/// `Logs { follow: true }` is normally intercepted by the connection loop
/// for streaming before this function is reached, so both `Logs` arms here
/// only return explanatory errors.
async fn handle_request(
    handle: &PmHandle,
    shutdown: &Arc<tokio::sync::Notify>,
    proxy_state_rx: &watch::Receiver<ProxyState>,
    request: Request,
) -> Response {
    match request {
        Request::Run {
            command,
            name,
            cwd,
            env,
            port,
        } => handle.spawn_process(command, name, cwd, env, port).await,
        Request::Stop { target } => handle.stop_process(&target).await,
        Request::StopAll => handle.stop_all().await,
        Request::Restart { target } => handle.restart_process(&target).await,
        Request::Status => handle.status().await,
        Request::Wait {
            target,
            until,
            regex,
            exit,
            timeout_secs,
        } => {
            // Check process exists
            if !handle.has_process(&target).await {
                return Response::Error {
                    code: ErrorCode::NotFound,
                    message: format!("process not found: {}", target),
                };
            }

            let session_name = handle.session_name().await;

            // Subscribe BEFORE checking historical logs to avoid missing lines
            let output_rx = handle.subscribe().await;

            // Check historical log output for the pattern.
            // NOTE(review): only the `.stdout` log file is scanned here;
            // stderr history is not consulted — confirm that is intentional.
            if let Some(ref pattern) = until {
                let log_path =
                    crate::paths::log_dir(&session_name).join(format!("{}.stdout", target));
                if let Ok(content) = std::fs::read_to_string(&log_path) {
                    // NOTE(review): if `regex` is true but the pattern fails
                    // to compile, `compiled_re` is None and the search below
                    // silently degrades to substring matching — confirm this
                    // matches wait_engine's handling of invalid patterns.
                    let compiled_re = if regex {
                        regex::Regex::new(pattern).ok()
                    } else {
                        None
                    };
                    let matched_line = content.lines().find(|line| {
                        if let Some(ref re) = compiled_re {
                            re.is_match(line)
                        } else {
                            line.contains(pattern.as_str())
                        }
                    });
                    // Historical hit: answer immediately without waiting on
                    // live output.
                    if let Some(line) = matched_line {
                        return Response::WaitMatch {
                            line: line.to_string(),
                        };
                    }
                }
            }
            // No historical match: wait on live output (default 30 seconds).
            let timeout = Duration::from_secs(timeout_secs.unwrap_or(30));
            wait_engine::wait_for(
                output_rx,
                &target,
                until.as_deref(),
                regex,
                exit,
                timeout,
                handle.clone(),
            )
            .await
        }
        Request::Logs { follow: false, .. } => Response::Error {
            code: ErrorCode::General,
            message: "non-follow logs are read directly from disk by CLI".into(),
        },
        Request::Logs { follow: true, .. } => Response::Error {
            code: ErrorCode::General,
            message: "follow requests handled in connection loop".into(),
        },
        Request::Shutdown => {
            // Stop all children first; the accept loop itself is notified by
            // the connection task after this response is sent.
            let _ = handle.stop_all().await;
            Response::Ok {
                message: "daemon shutting down".into(),
            }
        }
        Request::EnableProxy { proxy_port } => {
            // Bind first so a port conflict is reported synchronously to the
            // requesting client.
            let (listener, port) = match super::proxy::bind_proxy_port(proxy_port) {
                Ok(pair) => pair,
                Err(e) => {
                    return Response::Error {
                        code: ErrorCode::General,
                        message: e.to_string(),
                    };
                }
            };

            // enable_proxy returns the already-active port when a proxy is
            // running; the freshly bound listener is dropped in that case.
            if let Some(existing) = handle.enable_proxy(port).await {
                return Response::Ok {
                    message: format!(
                        "Proxy already listening on http://localhost:{}",
                        existing
                    ),
                };
            }

            let proxy_handle = handle.clone();
            let proxy_shutdown = Arc::clone(shutdown);
            let proxy_rx = proxy_state_rx.clone();
            // The proxy runs on its own task until it errors or daemon
            // shutdown is signaled.
            tokio::spawn(async move {
                if let Err(e) =
                    super::proxy::start_proxy(listener, port, proxy_handle, proxy_rx, proxy_shutdown)
                        .await
                {
                    tracing::error!(error = %e, "proxy error");
                }
            });

            Response::Ok {
                message: format!("Proxy listening on http://localhost:{}", port),
            }
        }
        Request::Hello { .. } => Response::Hello {
            version: PROTOCOL_VERSION,
        },
        Request::Unknown => Response::Error {
            code: ErrorCode::General,
            message: "unknown request type".into(),
        },
    }
}
349
350async fn send_response(
351    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
352    response: &Response,
353) -> std::io::Result<()> {
354    let mut w = writer.lock().await;
355    let mut json = serde_json::to_string(response)
356        .expect("Response serialization should never fail for well-typed enums");
357    json.push('\n');
358    w.write_all(json.as_bytes()).await?;
359    w.flush().await
360}