// neuronbox_runtime/server.rs
1//! Unix socket server (newline-delimited JSON).
2
3use std::path::Path;
4
5use anyhow::{Context, Result};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::net::{UnixListener, UnixStream};
8
9use crate::gpu_manager::GpuManager;
10use crate::host::{compute_apps_display_lines, compute_apps_pid_memory_mb};
11use crate::model_loader::ModelLoader;
12use crate::protocol::{DaemonRequest, DaemonResponse, SessionInfo, SWAP_SIGNAL_FILE_VERSION};
13use crate::vram_watch;
14
15const PROTOCOL_VERSION: u32 = 1;
16
17/// Max bytes per incoming line (excluding the newline). Larger lines are rejected with an error response.
18pub const MAX_REQUEST_LINE_BYTES: usize = 256 * 1024;
19
/// Returns `true` when the operator has opted out of soft VRAM enforcement
/// via the `NEURONBOX_DISABLE_VRAM_WATCH` environment variable.
///
/// Accepted (case-insensitive) truthy values: "1", "true", "yes".
/// Unset or any other value means the watcher stays enabled.
fn vram_watch_disabled() -> bool {
    match std::env::var_os("NEURONBOX_DISABLE_VRAM_WATCH") {
        Some(raw) => {
            let flag = raw.to_string_lossy().to_ascii_lowercase();
            flag == "1" || flag == "true" || flag == "yes"
        }
        None => false,
    }
}
28
/// Ways reading a single newline-delimited request line can fail.
enum RequestLineError {
    /// The line grew past `MAX_REQUEST_LINE_BYTES` before a newline arrived.
    TooLong,
    /// A complete line was read but its bytes were not valid UTF-8.
    BadUtf8,
    /// Underlying socket I/O failed (also used for EOF mid-line).
    Io(std::io::Error),
}
34
35async fn read_request_line<R: tokio::io::AsyncRead + Unpin>(
36    reader: &mut BufReader<R>,
37) -> Result<Option<String>, RequestLineError> {
38    let mut line = Vec::new();
39    loop {
40        if line.len() >= MAX_REQUEST_LINE_BYTES {
41            return Err(RequestLineError::TooLong);
42        }
43        let mut b = [0u8; 1];
44        let n = reader.read(&mut b).await.map_err(RequestLineError::Io)?;
45        if n == 0 {
46            return if line.is_empty() {
47                Ok(None)
48            } else {
49                Err(RequestLineError::Io(std::io::Error::new(
50                    std::io::ErrorKind::UnexpectedEof,
51                    "EOF before newline",
52                )))
53            };
54        }
55        if b[0] == b'\n' {
56            break;
57        }
58        line.push(b[0]);
59    }
60    String::from_utf8(line)
61        .map(Some)
62        .map_err(|_| RequestLineError::BadUtf8)
63}
64
65pub async fn run_socket_server(
66    socket_path: &Path,
67    gpu_manager: GpuManager,
68    model_loader: ModelLoader,
69) -> Result<()> {
70    if socket_path.exists() {
71        std::fs::remove_file(socket_path).ok();
72    }
73    if let Some(dir) = socket_path.parent() {
74        std::fs::create_dir_all(dir).with_context(|| format!("create_dir_all {:?}", dir))?;
75    }
76    let listener = UnixListener::bind(socket_path)
77        .with_context(|| format!("bind unix socket {:?}", socket_path))?;
78
79    // Restrict socket permissions to owner only (security: local user isolation)
80    #[cfg(unix)]
81    {
82        use std::os::unix::fs::PermissionsExt;
83        std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600))
84            .with_context(|| format!("chmod 600 {:?}", socket_path))?;
85    }
86
87    if !vram_watch_disabled() {
88        let gm_watch = gpu_manager.clone();
89        tokio::spawn(vram_watch::run_soft_vram_enforcement(gm_watch));
90    }
91
92    loop {
93        let (stream, _) = listener.accept().await?;
94        let gm = gpu_manager.clone();
95        let ml = model_loader.clone();
96        tokio::spawn(async move {
97            if let Err(e) = handle_connection(stream, gm, ml).await {
98                tracing::warn!("connection error: {e:#}");
99            }
100        });
101    }
102}
103
104async fn handle_connection(
105    stream: UnixStream,
106    gpu_manager: GpuManager,
107    model_loader: ModelLoader,
108) -> Result<()> {
109    let (read_half, mut write_half) = stream.into_split();
110    let mut reader = BufReader::new(read_half);
111
112    loop {
113        let line = match read_request_line(&mut reader).await {
114            Ok(None) => break,
115            Ok(Some(s)) => s,
116            Err(RequestLineError::TooLong) => {
117                let err = DaemonResponse::Error {
118                    message: format!(
119                        "request line exceeds maximum size ({MAX_REQUEST_LINE_BYTES} bytes)"
120                    ),
121                };
122                write_response(&mut write_half, &err).await?;
123                break;
124            }
125            Err(RequestLineError::BadUtf8) => {
126                let err = DaemonResponse::Error {
127                    message: "invalid UTF-8 in request line".to_string(),
128                };
129                write_response(&mut write_half, &err).await?;
130                break;
131            }
132            Err(RequestLineError::Io(e)) => return Err(e.into()),
133        };
134        let trimmed = line.trim();
135        if trimmed.is_empty() {
136            continue;
137        }
138
139        let req: DaemonRequest = match serde_json::from_str(trimmed) {
140            Ok(r) => r,
141            Err(e) => {
142                let err = DaemonResponse::Error {
143                    message: format!("invalid JSON request: {e}"),
144                };
145                write_response(&mut write_half, &err).await?;
146                continue;
147            }
148        };
149
150        let resp = dispatch(req, &gpu_manager, &model_loader).await;
151        write_response(&mut write_half, &resp).await?;
152    }
153    Ok(())
154}
155
156async fn write_response(
157    w: &mut tokio::net::unix::OwnedWriteHalf,
158    resp: &DaemonResponse,
159) -> Result<()> {
160    let mut s = serde_json::to_string(resp)?;
161    s.push('\n');
162    w.write_all(s.as_bytes()).await?;
163    Ok(())
164}
165
166async fn dispatch(
167    req: DaemonRequest,
168    gpu_manager: &GpuManager,
169    model_loader: &ModelLoader,
170) -> DaemonResponse {
171    match req {
172        DaemonRequest::Ping => DaemonResponse::Pong,
173        DaemonRequest::Version { v } => {
174            if v != PROTOCOL_VERSION {
175                DaemonResponse::Error {
176                    message: format!("protocol mismatch: client {v}, daemon {PROTOCOL_VERSION}"),
177                }
178            } else {
179                DaemonResponse::VersionInfo {
180                    v: PROTOCOL_VERSION,
181                }
182            }
183        }
184        DaemonRequest::RegisterSession {
185            name,
186            estimated_vram_mb,
187            pid,
188            tokens_per_sec,
189        } => {
190            gpu_manager
191                .register(SessionInfo {
192                    name,
193                    pid,
194                    estimated_vram_mb,
195                    tokens_per_sec,
196                })
197                .await;
198            DaemonResponse::Registered { pid }
199        }
200        DaemonRequest::UnregisterSession { pid } => {
201            let ok = gpu_manager.unregister(pid).await;
202            if ok {
203                DaemonResponse::Unregistered
204            } else {
205                DaemonResponse::Error {
206                    message: format!("pid {pid} not registered"),
207                }
208            }
209        }
210        DaemonRequest::ListSessions => {
211            let sessions = gpu_manager.list().await;
212            DaemonResponse::Sessions { sessions }
213        }
214        DaemonRequest::Stats => {
215            let sessions = gpu_manager.list().await;
216            let (gpu_lines, vram_used_by_pid) = nvidia_stats_bundle().await;
217            let note = if gpu_lines.is_empty() {
218                Some(
219                    "tokens/s are shown only when the session reports them (RegisterSession)."
220                        .to_string(),
221                )
222            } else {
223                None
224            };
225            let am = model_loader.get().await;
226            let active_model = if am.model_ref.is_empty() {
227                None
228            } else {
229                Some(crate::protocol::ActiveModelInfo {
230                    model_ref: am.model_ref,
231                    quantization: am.quantization,
232                })
233            };
234            DaemonResponse::Stats {
235                sessions,
236                gpu_lines,
237                note,
238                active_model,
239                vram_used_by_pid,
240            }
241        }
242        DaemonRequest::SwapModel {
243            model_ref,
244            quantization,
245        } => {
246            model_loader
247                .swap(model_ref.clone(), quantization.clone())
248                .await;
249            let swap_path = dirs::home_dir()
250                .unwrap_or_else(|| std::path::PathBuf::from("."))
251                .join(".neuronbox")
252                .join("swap_signal.json");
253            let payload = serde_json::json!({
254                "signal_version": SWAP_SIGNAL_FILE_VERSION,
255                "model_ref": model_ref.clone(),
256                "quantization": quantization.clone(),
257                "ts": std::time::SystemTime::now()
258                    .duration_since(std::time::UNIX_EPOCH)
259                    .map(|d| d.as_secs())
260                    .unwrap_or(0),
261            });
262            if let Ok(bytes) = serde_json::to_vec(&payload) {
263                let _ = tokio::fs::write(&swap_path, bytes).await;
264            }
265            DaemonResponse::Swapped {
266                model_ref,
267                quantization,
268            }
269        }
270    }
271}
272
273async fn nvidia_stats_bundle() -> (Vec<String>, std::collections::HashMap<u32, u64>) {
274    tokio::task::spawn_blocking(|| {
275        let lines = compute_apps_display_lines();
276        let map = compute_apps_pid_memory_mb().unwrap_or_default();
277        (lines, map)
278    })
279    .await
280    .unwrap_or_else(|_| (Vec::new(), std::collections::HashMap::new()))
281}