Skip to main content

neuronbox_runtime/
server.rs

1//! Unix socket server (newline-delimited JSON).
2
3use std::path::Path;
4
5use anyhow::{Context, Result};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::net::{UnixListener, UnixStream};
8
9use crate::gpu_manager::GpuManager;
10use crate::host::{compute_apps_display_lines, compute_apps_pid_memory_mb};
11use crate::model_loader::ModelLoader;
12use crate::protocol::{DaemonRequest, DaemonResponse, SessionInfo, SWAP_SIGNAL_FILE_VERSION};
13use crate::vram_watch;
14
/// Wire-protocol revision spoken by this daemon. A `Version` request carrying any
/// other value is answered with a protocol-mismatch error.
const PROTOCOL_VERSION: u32 = 1;

/// Max bytes per incoming line (excluding the newline). Larger lines are rejected with
/// an error response. The value is 256 KiB.
pub const MAX_REQUEST_LINE_BYTES: usize = 256 << 10;
19
/// Reports whether the soft VRAM enforcement task should be skipped.
///
/// Returns `true` only when `NEURONBOX_DISABLE_VRAM_WATCH` is set to `1`, `true`, or
/// `yes`, compared case-insensitively. An unset variable or any other value yields `false`.
fn vram_watch_disabled() -> bool {
    match std::env::var_os("NEURONBOX_DISABLE_VRAM_WATCH") {
        Some(raw) => {
            let normalized = raw.to_string_lossy().to_ascii_lowercase();
            ["1", "true", "yes"].contains(&normalized.as_str())
        }
        None => false,
    }
}
28
/// Reasons `read_request_line` could not produce a request line.
#[derive(Debug)]
enum RequestLineError {
    /// The line hit the `MAX_REQUEST_LINE_BYTES` size limit before its newline.
    TooLong,
    /// The complete line was not valid UTF-8.
    BadUtf8,
    /// Reading from the stream failed, or the stream ended in the middle of a line.
    Io(std::io::Error),
}
34
35async fn read_request_line<R: tokio::io::AsyncRead + Unpin>(
36    reader: &mut BufReader<R>,
37) -> Result<Option<String>, RequestLineError> {
38    let mut line = Vec::new();
39    loop {
40        if line.len() >= MAX_REQUEST_LINE_BYTES {
41            return Err(RequestLineError::TooLong);
42        }
43        let mut b = [0u8; 1];
44        let n = reader.read(&mut b).await.map_err(RequestLineError::Io)?;
45        if n == 0 {
46            return if line.is_empty() {
47                Ok(None)
48            } else {
49                Err(RequestLineError::Io(std::io::Error::new(
50                    std::io::ErrorKind::UnexpectedEof,
51                    "EOF before newline",
52                )))
53            };
54        }
55        if b[0] == b'\n' {
56            break;
57        }
58        line.push(b[0]);
59    }
60    String::from_utf8(line)
61        .map(Some)
62        .map_err(|_| RequestLineError::BadUtf8)
63}
64
65pub async fn run_socket_server(
66    socket_path: &Path,
67    gpu_manager: GpuManager,
68    model_loader: ModelLoader,
69) -> Result<()> {
70    if socket_path.exists() {
71        std::fs::remove_file(socket_path).ok();
72    }
73    if let Some(dir) = socket_path.parent() {
74        std::fs::create_dir_all(dir).with_context(|| format!("create_dir_all {:?}", dir))?;
75    }
76    let listener = UnixListener::bind(socket_path)
77        .with_context(|| format!("bind unix socket {:?}", socket_path))?;
78
79    // Restrict socket permissions to owner only (security: local user isolation)
80    #[cfg(unix)]
81    {
82        use std::os::unix::fs::PermissionsExt;
83        std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600))
84            .with_context(|| format!("chmod 600 {:?}", socket_path))?;
85    }
86
87    if !vram_watch_disabled() {
88        let gm_watch = gpu_manager.clone();
89        tokio::spawn(vram_watch::run_soft_vram_enforcement(gm_watch));
90    }
91
92    loop {
93        let (stream, _) = listener.accept().await?;
94        let gm = gpu_manager.clone();
95        let ml = model_loader.clone();
96        tokio::spawn(async move {
97            if let Err(e) = handle_connection(stream, gm, ml).await {
98                tracing::warn!("connection error: {e:#}");
99            }
100        });
101    }
102}
103
104async fn handle_connection(
105    stream: UnixStream,
106    gpu_manager: GpuManager,
107    model_loader: ModelLoader,
108) -> Result<()> {
109    let (read_half, mut write_half) = stream.into_split();
110    let mut reader = BufReader::new(read_half);
111
112    loop {
113        let line = match read_request_line(&mut reader).await {
114            Ok(None) => break,
115            Ok(Some(s)) => s,
116            Err(RequestLineError::TooLong) => {
117                let err = DaemonResponse::Error {
118                    message: format!(
119                        "request line exceeds maximum size ({MAX_REQUEST_LINE_BYTES} bytes)"
120                    ),
121                };
122                write_response(&mut write_half, &err).await?;
123                break;
124            }
125            Err(RequestLineError::BadUtf8) => {
126                let err = DaemonResponse::Error {
127                    message: "invalid UTF-8 in request line".to_string(),
128                };
129                write_response(&mut write_half, &err).await?;
130                break;
131            }
132            Err(RequestLineError::Io(e)) => return Err(e.into()),
133        };
134        let trimmed = line.trim();
135        if trimmed.is_empty() {
136            continue;
137        }
138
139        let req: DaemonRequest = match serde_json::from_str(trimmed) {
140            Ok(r) => r,
141            Err(e) => {
142                let err = DaemonResponse::Error {
143                    message: format!("invalid JSON request: {e}"),
144                };
145                write_response(&mut write_half, &err).await?;
146                continue;
147            }
148        };
149
150        let resp = dispatch(req, &gpu_manager, &model_loader).await;
151        write_response(&mut write_half, &resp).await?;
152    }
153    Ok(())
154}
155
156async fn write_response(
157    w: &mut tokio::net::unix::OwnedWriteHalf,
158    resp: &DaemonResponse,
159) -> Result<()> {
160    let mut s = serde_json::to_string(resp)?;
161    s.push('\n');
162    w.write_all(s.as_bytes()).await?;
163    Ok(())
164}
165
166async fn dispatch(
167    req: DaemonRequest,
168    gpu_manager: &GpuManager,
169    model_loader: &ModelLoader,
170) -> DaemonResponse {
171    match req {
172        DaemonRequest::Ping => DaemonResponse::Pong,
173        DaemonRequest::Version { v } => {
174            if v != PROTOCOL_VERSION {
175                DaemonResponse::Error {
176                    message: format!("protocol mismatch: client {v}, daemon {PROTOCOL_VERSION}"),
177                }
178            } else {
179                DaemonResponse::VersionInfo {
180                    v: PROTOCOL_VERSION,
181                }
182            }
183        }
184        DaemonRequest::RegisterSession {
185            name,
186            estimated_vram_mb,
187            pid,
188            tokens_per_sec,
189            model_dir,
190        } => {
191            gpu_manager
192                .register(SessionInfo {
193                    name,
194                    pid,
195                    estimated_vram_mb,
196                    tokens_per_sec,
197                    model_dir,
198                })
199                .await;
200            DaemonResponse::Registered { pid }
201        }
202        DaemonRequest::UnregisterSession { pid } => {
203            let ok = gpu_manager.unregister(pid).await;
204            if ok {
205                DaemonResponse::Unregistered
206            } else {
207                DaemonResponse::Error {
208                    message: format!("pid {pid} not registered"),
209                }
210            }
211        }
212        DaemonRequest::ListSessions => {
213            let sessions = gpu_manager.list().await;
214            DaemonResponse::Sessions { sessions }
215        }
216        DaemonRequest::Stats => {
217            let sessions = gpu_manager.list().await;
218            let (gpu_lines, vram_used_by_pid) = nvidia_stats_bundle().await;
219            let note = if gpu_lines.is_empty() {
220                Some(
221                    "tokens/s are shown only when the session reports them (RegisterSession)."
222                        .to_string(),
223                )
224            } else {
225                None
226            };
227            let am = model_loader.get().await;
228            let active_model = if am.model_ref.is_empty() {
229                None
230            } else {
231                Some(crate::protocol::ActiveModelInfo {
232                    model_ref: am.model_ref,
233                    quantization: am.quantization,
234                })
235            };
236            DaemonResponse::Stats {
237                sessions,
238                gpu_lines,
239                note,
240                active_model,
241                vram_used_by_pid,
242            }
243        }
244        DaemonRequest::SwapModel {
245            model_ref,
246            quantization,
247        } => {
248            model_loader
249                .swap(model_ref.clone(), quantization.clone())
250                .await;
251            let swap_path = dirs::home_dir()
252                .unwrap_or_else(|| std::path::PathBuf::from("."))
253                .join(".neuronbox")
254                .join("swap_signal.json");
255            let payload = serde_json::json!({
256                "signal_version": SWAP_SIGNAL_FILE_VERSION,
257                "model_ref": model_ref.clone(),
258                "quantization": quantization.clone(),
259                "ts": std::time::SystemTime::now()
260                    .duration_since(std::time::UNIX_EPOCH)
261                    .map(|d| d.as_secs())
262                    .unwrap_or(0),
263            });
264            if let Ok(bytes) = serde_json::to_vec(&payload) {
265                let _ = tokio::fs::write(&swap_path, bytes).await;
266            }
267            DaemonResponse::Swapped {
268                model_ref,
269                quantization,
270            }
271        }
272    }
273}
274
275async fn nvidia_stats_bundle() -> (Vec<String>, std::collections::HashMap<u32, u64>) {
276    tokio::task::spawn_blocking(|| {
277        let lines = compute_apps_display_lines();
278        let map = compute_apps_pid_memory_mb().unwrap_or_default();
279        (lines, map)
280    })
281    .await
282    .unwrap_or_else(|_| (Vec::new(), std::collections::HashMap::new()))
283}