Skip to main content

agent_procs/daemon/
server.rs

1use crate::daemon::wait_engine;
2use crate::protocol::{self, Request, Response};
3use std::path::Path;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::time::Duration;
7use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
8use tokio::net::UnixListener;
9use tokio::sync::Mutex;
10use tokio::sync::broadcast;
11
12use super::process_manager::ProcessManager;
13
/// Upper bound on the number of client connections served at the same time.
///
/// Connections accepted beyond this bound are dropped right away. The cap
/// protects the daemon from runaway loops in which every connection ends up
/// opening yet more connections.
const MAX_CONCURRENT_CONNECTIONS: usize = 64;
17
18pub struct DaemonState {
19    pub process_manager: ProcessManager,
20    pub proxy_port: Option<u16>,
21}
22
23pub async fn run(session: &str, socket_path: &Path) {
24    let state = Arc::new(Mutex::new(DaemonState {
25        process_manager: ProcessManager::new(session),
26        proxy_port: None,
27    }));
28
29    let listener = match UnixListener::bind(socket_path) {
30        Ok(l) => l,
31        Err(e) => {
32            tracing::error!(path = %socket_path.display(), error = %e, "failed to bind socket");
33            return;
34        }
35    };
36
37    // Shutdown signal: set to true when a Shutdown request is handled
38    let shutdown = Arc::new(tokio::sync::Notify::new());
39    let active_connections = Arc::new(AtomicUsize::new(0));
40
41    loop {
42        let (stream, _) = tokio::select! {
43            result = listener.accept() => match result {
44                Ok(conn) => conn,
45                Err(e) => {
46                    tracing::warn!(error = %e, "accept error");
47                    continue;
48                }
49            },
50            () = shutdown.notified() => break,
51        };
52
53        // Rate limiting: atomically increment then check to avoid TOCTOU race
54        let prev = active_connections.fetch_add(1, Ordering::AcqRel);
55        if prev >= MAX_CONCURRENT_CONNECTIONS {
56            active_connections.fetch_sub(1, Ordering::AcqRel);
57            tracing::warn!(
58                current = prev,
59                max = MAX_CONCURRENT_CONNECTIONS,
60                "connection rejected: too many concurrent connections"
61            );
62            drop(stream);
63            continue;
64        }
65
66        let state = Arc::clone(&state);
67        let shutdown = Arc::clone(&shutdown);
68        let conn_counter = Arc::clone(&active_connections);
69        tokio::spawn(async move {
70            let _guard = ConnectionGuard(conn_counter);
71            let (reader, writer) = stream.into_split();
72            let writer = Arc::new(Mutex::new(writer));
73            // Wrap reader in a size-limited adapter so read_line cannot
74            // allocate more than MAX_MESSAGE_SIZE bytes.
75            let limited = reader.take(protocol::MAX_MESSAGE_SIZE as u64);
76            let mut reader = BufReader::new(limited);
77
78            loop {
79                let mut line = String::new();
80                match reader.read_line(&mut line).await {
81                    Ok(0) | Err(_) => break, // EOF or error
82                    Ok(n) if n >= protocol::MAX_MESSAGE_SIZE => {
83                        let resp = Response::Error {
84                            code: 1,
85                            message: format!(
86                                "message too large: {} bytes (max {})",
87                                n,
88                                protocol::MAX_MESSAGE_SIZE
89                            ),
90                        };
91                        let _ = send_response(&writer, &resp).await;
92                        break; // disconnect oversized clients
93                    }
94                    Ok(_) => {}
95                }
96                // Reset the take limit for the next message
97                reader
98                    .get_mut()
99                    .set_limit(protocol::MAX_MESSAGE_SIZE as u64);
100
101                let request: Request = match serde_json::from_str(&line) {
102                    Ok(r) => r,
103                    Err(e) => {
104                        let resp = Response::Error {
105                            code: 1,
106                            message: format!("invalid request: {}", e),
107                        };
108                        let _ = send_response(&writer, &resp).await;
109                        continue;
110                    }
111                };
112
113                // Handle follow requests with streaming (before handle_request)
114                if let Request::Logs {
115                    follow: true,
116                    ref target,
117                    all,
118                    timeout_secs,
119                    lines,
120                    ..
121                } = request
122                {
123                    let output_rx = state.lock().await.process_manager.output_tx.subscribe();
124                    let max_lines = lines;
125                    let target_filter = target.clone();
126                    let show_all = all;
127
128                    handle_follow_stream(
129                        &writer,
130                        output_rx,
131                        target_filter,
132                        show_all,
133                        timeout_secs,
134                        max_lines,
135                    )
136                    .await;
137                    continue; // Don't call handle_request
138                }
139
140                let is_shutdown = matches!(request, Request::Shutdown);
141
142                let response = handle_request(&state, &shutdown, request).await;
143                let _ = send_response(&writer, &response).await;
144
145                if is_shutdown {
146                    shutdown.notify_one();
147                    return;
148                }
149            }
150        });
151    }
152}
153
154/// RAII guard that decrements the active connection counter when dropped.
155struct ConnectionGuard(Arc<AtomicUsize>);
156
157impl Drop for ConnectionGuard {
158    fn drop(&mut self) {
159        self.0.fetch_sub(1, Ordering::Relaxed);
160    }
161}
162
163async fn handle_follow_stream(
164    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
165    mut output_rx: broadcast::Receiver<super::log_writer::OutputLine>,
166    target: Option<String>,
167    all: bool,
168    timeout_secs: Option<u64>,
169    max_lines: Option<usize>,
170) {
171    let mut line_count: usize = 0;
172
173    let stream_loop = async {
174        loop {
175            match output_rx.recv().await {
176                Ok(output_line) => {
177                    if !all {
178                        if let Some(ref t) = target {
179                            if output_line.process != *t {
180                                continue;
181                            }
182                        }
183                    }
184
185                    let resp = Response::LogLine {
186                        process: output_line.process,
187                        stream: output_line.stream,
188                        line: output_line.line,
189                    };
190                    if send_response(writer, &resp).await.is_err() {
191                        return;
192                    }
193
194                    line_count += 1;
195                    if let Some(max) = max_lines {
196                        if line_count >= max {
197                            return;
198                        }
199                    }
200                }
201                Err(broadcast::error::RecvError::Lagged(_)) => {}
202                Err(broadcast::error::RecvError::Closed) => return,
203            }
204        }
205    };
206
207    // Apply timeout only if specified; otherwise stream indefinitely
208    match timeout_secs {
209        Some(secs) => {
210            let _ = tokio::time::timeout(Duration::from_secs(secs), stream_loop).await;
211        }
212        None => {
213            stream_loop.await;
214        }
215    }
216
217    let _ = send_response(writer, &Response::LogEnd).await;
218}
219
220async fn handle_request(
221    state: &Arc<Mutex<DaemonState>>,
222    shutdown: &Arc<tokio::sync::Notify>,
223    request: Request,
224) -> Response {
225    match request {
226        Request::Run {
227            command,
228            name,
229            cwd,
230            env,
231            port,
232        } => {
233            let mut s = state.lock().await;
234            let proxy_port = s.proxy_port;
235            let mut resp = s
236                .process_manager
237                .spawn_process(&command, name, cwd.as_deref(), env.as_ref(), port)
238                .await;
239            drop(s);
240            if let Response::RunOk {
241                ref name,
242                ref mut url,
243                port: Some(_),
244                ..
245            } = resp
246            {
247                if let Some(pp) = proxy_port {
248                    *url = Some(format!("http://{}.localhost:{}", name, pp));
249                }
250            }
251            resp
252        }
253        Request::Stop { target } => {
254            state
255                .lock()
256                .await
257                .process_manager
258                .stop_process(&target)
259                .await
260        }
261        Request::StopAll => state.lock().await.process_manager.stop_all().await,
262        Request::Restart { target } => {
263            state
264                .lock()
265                .await
266                .process_manager
267                .restart_process(&target)
268                .await
269        }
270        Request::Status => {
271            let mut s = state.lock().await;
272            let proxy_port = s.proxy_port;
273            let mut resp = s.process_manager.status();
274            if let Some(pp) = proxy_port {
275                if let Response::Status { ref mut processes } = resp {
276                    for p in processes.iter_mut() {
277                        if p.port.is_some() {
278                            p.url = Some(format!("http://{}.localhost:{}", p.name, pp));
279                        }
280                    }
281                }
282            }
283            resp
284        }
285        Request::Wait {
286            target,
287            until,
288            regex,
289            exit,
290            timeout_secs,
291        } => {
292            // Check process exists
293            let session_name = {
294                let s = state.lock().await;
295                if !s.process_manager.has_process(&target) {
296                    return Response::Error {
297                        code: 2,
298                        message: format!("process not found: {}", target),
299                    };
300                }
301                s.process_manager.session_name().to_string()
302            };
303
304            // Subscribe BEFORE checking historical logs to avoid missing lines
305            // emitted between the historical scan and the subscription.
306            let output_rx = state.lock().await.process_manager.output_tx.subscribe();
307
308            // Check historical log output for the pattern (fixes race where
309            // fast processes emit the pattern before Wait subscribes).
310            if let Some(ref pattern) = until {
311                let log_path =
312                    crate::paths::log_dir(&session_name).join(format!("{}.stdout", target));
313                if let Ok(content) = std::fs::read_to_string(&log_path) {
314                    // Compile regex once for the entire scan
315                    let compiled_re = if regex {
316                        regex::Regex::new(pattern).ok()
317                    } else {
318                        None
319                    };
320                    // Single-pass: find returns the first match (no need for any + find)
321                    let matched_line = content.lines().find(|line| {
322                        if let Some(ref re) = compiled_re {
323                            re.is_match(line)
324                        } else {
325                            line.contains(pattern.as_str())
326                        }
327                    });
328                    if let Some(line) = matched_line {
329                        return Response::WaitMatch {
330                            line: line.to_string(),
331                        };
332                    }
333                }
334            }
335            let timeout = Duration::from_secs(timeout_secs.unwrap_or(30));
336            let state_clone = Arc::clone(state);
337            let target_clone = target.clone();
338            wait_engine::wait_for(
339                output_rx,
340                &target,
341                until.as_deref(),
342                regex,
343                exit,
344                timeout,
345                move || {
346                    // This is called synchronously from the wait loop
347                    // We can't hold the lock across the whole wait, so we check briefly
348                    let state = state_clone.clone();
349                    let target = target_clone.clone();
350                    // Use try_lock to avoid deadlock
351                    match state.try_lock() {
352                        Ok(mut s) => s.process_manager.is_process_exited(&target),
353                        Err(_) => None,
354                    }
355                },
356            )
357            .await
358        }
359        Request::Logs { follow: false, .. } => {
360            // Non-follow logs are read directly from files by the CLI — no daemon involvement needed
361            Response::Error {
362                code: 1,
363                message: "non-follow logs are read directly from disk by CLI".into(),
364            }
365        }
366        Request::Logs { follow: true, .. } => {
367            // Handled separately in connection loop (needs streaming)
368            Response::Error {
369                code: 1,
370                message: "follow requests handled in connection loop".into(),
371            }
372        }
373        Request::Shutdown => {
374            let _ = state.lock().await.process_manager.stop_all().await;
375            Response::Ok {
376                message: "daemon shutting down".into(),
377            }
378        }
379        Request::EnableProxy { proxy_port } => {
380            let mut s = state.lock().await;
381            if let Some(existing_port) = s.proxy_port {
382                return Response::Ok {
383                    message: format!(
384                        "Proxy already listening on http://localhost:{}",
385                        existing_port
386                    ),
387                };
388            }
389
390            let (listener, port) = match super::proxy::bind_proxy_port(proxy_port) {
391                Ok(pair) => pair,
392                Err(e) => {
393                    return Response::Error {
394                        code: 1,
395                        message: e.to_string(),
396                    };
397                }
398            };
399
400            s.proxy_port = Some(port);
401            s.process_manager.enable_proxy();
402            drop(s);
403
404            let proxy_state = Arc::clone(state);
405            let proxy_shutdown = Arc::clone(shutdown);
406            tokio::spawn(async move {
407                if let Err(e) =
408                    super::proxy::start_proxy(listener, port, proxy_state, proxy_shutdown).await
409                {
410                    tracing::error!(error = %e, "proxy error");
411                }
412            });
413
414            Response::Ok {
415                message: format!("Proxy listening on http://localhost:{}", port),
416            }
417        }
418    }
419}
420
421async fn send_response(
422    writer: &Arc<Mutex<tokio::net::unix::OwnedWriteHalf>>,
423    response: &Response,
424) -> std::io::Result<()> {
425    let mut w = writer.lock().await;
426    let mut json = serde_json::to_string(response)
427        .expect("Response serialization should never fail for well-typed enums");
428    json.push('\n');
429    w.write_all(json.as_bytes()).await?;
430    w.flush().await
431}