// neuronbox_runtime/server.rs

use std::path::Path;

use anyhow::{Context, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::{UnixListener, UnixStream};

use crate::gpu_manager::GpuManager;
use crate::host::{compute_apps_display_lines, compute_apps_pid_memory_mb};
use crate::model_loader::ModelLoader;
use crate::protocol::{DaemonRequest, DaemonResponse, SessionInfo, SWAP_SIGNAL_FILE_VERSION};
use crate::vram_watch;

/// Wire protocol version checked against `DaemonRequest::Version`.
const PROTOCOL_VERSION: u32 = 1;

/// Upper bound on a single newline-terminated request line, in bytes.
pub const MAX_REQUEST_LINE_BYTES: usize = 256 * 1024;

/// Reports whether the soft-VRAM watcher is disabled via the environment.
///
/// Returns `true` only when `NEURONBOX_DISABLE_VRAM_WATCH` is set to a
/// truthy value: "1", "true", or "yes" (case-insensitive). Unset or any
/// other value means the watcher stays enabled.
fn vram_watch_disabled() -> bool {
    match std::env::var_os("NEURONBOX_DISABLE_VRAM_WATCH") {
        Some(raw) => {
            let normalized = raw.to_string_lossy().to_ascii_lowercase();
            normalized == "1" || normalized == "true" || normalized == "yes"
        }
        None => false,
    }
}

/// Ways reading a single request line off the socket can fail.
enum RequestLineError {
    /// The line grew past `MAX_REQUEST_LINE_BYTES` without a terminating newline.
    TooLong,
    /// The line's bytes were not valid UTF-8.
    BadUtf8,
    /// An underlying I/O error, including EOF arriving mid-line.
    Io(std::io::Error),
}

35async fn read_request_line<R: tokio::io::AsyncRead + Unpin>(
36 reader: &mut BufReader<R>,
37) -> Result<Option<String>, RequestLineError> {
38 let mut line = Vec::new();
39 loop {
40 if line.len() >= MAX_REQUEST_LINE_BYTES {
41 return Err(RequestLineError::TooLong);
42 }
43 let mut b = [0u8; 1];
44 let n = reader.read(&mut b).await.map_err(RequestLineError::Io)?;
45 if n == 0 {
46 return if line.is_empty() {
47 Ok(None)
48 } else {
49 Err(RequestLineError::Io(std::io::Error::new(
50 std::io::ErrorKind::UnexpectedEof,
51 "EOF before newline",
52 )))
53 };
54 }
55 if b[0] == b'\n' {
56 break;
57 }
58 line.push(b[0]);
59 }
60 String::from_utf8(line)
61 .map(Some)
62 .map_err(|_| RequestLineError::BadUtf8)
63}
64
65pub async fn run_socket_server(
66 socket_path: &Path,
67 gpu_manager: GpuManager,
68 model_loader: ModelLoader,
69) -> Result<()> {
70 if socket_path.exists() {
71 std::fs::remove_file(socket_path).ok();
72 }
73 if let Some(dir) = socket_path.parent() {
74 std::fs::create_dir_all(dir).with_context(|| format!("create_dir_all {:?}", dir))?;
75 }
76 let listener = UnixListener::bind(socket_path)
77 .with_context(|| format!("bind unix socket {:?}", socket_path))?;
78
79 #[cfg(unix)]
81 {
82 use std::os::unix::fs::PermissionsExt;
83 std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600))
84 .with_context(|| format!("chmod 600 {:?}", socket_path))?;
85 }
86
87 if !vram_watch_disabled() {
88 let gm_watch = gpu_manager.clone();
89 tokio::spawn(vram_watch::run_soft_vram_enforcement(gm_watch));
90 }
91
92 loop {
93 let (stream, _) = listener.accept().await?;
94 let gm = gpu_manager.clone();
95 let ml = model_loader.clone();
96 tokio::spawn(async move {
97 if let Err(e) = handle_connection(stream, gm, ml).await {
98 tracing::warn!("connection error: {e:#}");
99 }
100 });
101 }
102}
103
104async fn handle_connection(
105 stream: UnixStream,
106 gpu_manager: GpuManager,
107 model_loader: ModelLoader,
108) -> Result<()> {
109 let (read_half, mut write_half) = stream.into_split();
110 let mut reader = BufReader::new(read_half);
111
112 loop {
113 let line = match read_request_line(&mut reader).await {
114 Ok(None) => break,
115 Ok(Some(s)) => s,
116 Err(RequestLineError::TooLong) => {
117 let err = DaemonResponse::Error {
118 message: format!(
119 "request line exceeds maximum size ({MAX_REQUEST_LINE_BYTES} bytes)"
120 ),
121 };
122 write_response(&mut write_half, &err).await?;
123 break;
124 }
125 Err(RequestLineError::BadUtf8) => {
126 let err = DaemonResponse::Error {
127 message: "invalid UTF-8 in request line".to_string(),
128 };
129 write_response(&mut write_half, &err).await?;
130 break;
131 }
132 Err(RequestLineError::Io(e)) => return Err(e.into()),
133 };
134 let trimmed = line.trim();
135 if trimmed.is_empty() {
136 continue;
137 }
138
139 let req: DaemonRequest = match serde_json::from_str(trimmed) {
140 Ok(r) => r,
141 Err(e) => {
142 let err = DaemonResponse::Error {
143 message: format!("invalid JSON request: {e}"),
144 };
145 write_response(&mut write_half, &err).await?;
146 continue;
147 }
148 };
149
150 let resp = dispatch(req, &gpu_manager, &model_loader).await;
151 write_response(&mut write_half, &resp).await?;
152 }
153 Ok(())
154}
155
156async fn write_response(
157 w: &mut tokio::net::unix::OwnedWriteHalf,
158 resp: &DaemonResponse,
159) -> Result<()> {
160 let mut s = serde_json::to_string(resp)?;
161 s.push('\n');
162 w.write_all(s.as_bytes()).await?;
163 Ok(())
164}
165
/// Routes one decoded `DaemonRequest` to the matching subsystem and builds
/// the `DaemonResponse`. Failures are reported in-band as
/// `DaemonResponse::Error`; this function itself never returns an error.
async fn dispatch(
    req: DaemonRequest,
    gpu_manager: &GpuManager,
    model_loader: &ModelLoader,
) -> DaemonResponse {
    match req {
        // Liveness probe.
        DaemonRequest::Ping => DaemonResponse::Pong,
        // Strict equality against PROTOCOL_VERSION — no version range is
        // negotiated.
        DaemonRequest::Version { v } => {
            if v != PROTOCOL_VERSION {
                DaemonResponse::Error {
                    message: format!("protocol mismatch: client {v}, daemon {PROTOCOL_VERSION}"),
                }
            } else {
                DaemonResponse::VersionInfo {
                    v: PROTOCOL_VERSION,
                }
            }
        }
        DaemonRequest::RegisterSession {
            name,
            estimated_vram_mb,
            pid,
            tokens_per_sec,
        } => {
            gpu_manager
                .register(SessionInfo {
                    name,
                    pid,
                    estimated_vram_mb,
                    tokens_per_sec,
                })
                .await;
            // Echo the pid back so the client can confirm registration.
            DaemonResponse::Registered { pid }
        }
        DaemonRequest::UnregisterSession { pid } => {
            // `unregister` reports whether the pid was actually tracked.
            let ok = gpu_manager.unregister(pid).await;
            if ok {
                DaemonResponse::Unregistered
            } else {
                DaemonResponse::Error {
                    message: format!("pid {pid} not registered"),
                }
            }
        }
        DaemonRequest::ListSessions => {
            let sessions = gpu_manager.list().await;
            DaemonResponse::Sessions { sessions }
        }
        DaemonRequest::Stats => {
            let sessions = gpu_manager.list().await;
            let (gpu_lines, vram_used_by_pid) = nvidia_stats_bundle().await;
            // With no GPU lines available, attach an explanatory note so
            // clients don't mistake the absence for an error.
            let note = if gpu_lines.is_empty() {
                Some(
                    "tokens/s are shown only when the session reports them (RegisterSession)."
                        .to_string(),
                )
            } else {
                None
            };
            let am = model_loader.get().await;
            // An empty model_ref means no model is currently active.
            let active_model = if am.model_ref.is_empty() {
                None
            } else {
                Some(crate::protocol::ActiveModelInfo {
                    model_ref: am.model_ref,
                    quantization: am.quantization,
                })
            };
            DaemonResponse::Stats {
                sessions,
                gpu_lines,
                note,
                active_model,
                vram_used_by_pid,
            }
        }
        DaemonRequest::SwapModel {
            model_ref,
            quantization,
        } => {
            // Update the loader's in-memory active model first, ...
            model_loader
                .swap(model_ref.clone(), quantization.clone())
                .await;
            // ... then drop a best-effort signal file under ~/.neuronbox
            // (presumably watched by an external process — confirm consumer).
            // Falls back to the current directory when $HOME is unknown.
            let swap_path = dirs::home_dir()
                .unwrap_or_else(|| std::path::PathBuf::from("."))
                .join(".neuronbox")
                .join("swap_signal.json");
            let payload = serde_json::json!({
                "signal_version": SWAP_SIGNAL_FILE_VERSION,
                "model_ref": model_ref.clone(),
                "quantization": quantization.clone(),
                // Seconds since the Unix epoch; 0 if the clock is before it.
                "ts": std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_secs())
                    .unwrap_or(0),
            });
            // Write failures are deliberately ignored: the in-memory swap
            // has already taken effect.
            if let Ok(bytes) = serde_json::to_vec(&payload) {
                let _ = tokio::fs::write(&swap_path, bytes).await;
            }
            DaemonResponse::Swapped {
                model_ref,
                quantization,
            }
        }
    }
}

273async fn nvidia_stats_bundle() -> (Vec<String>, std::collections::HashMap<u32, u64>) {
274 tokio::task::spawn_blocking(|| {
275 let lines = compute_apps_display_lines();
276 let map = compute_apps_pid_memory_mb().unwrap_or_default();
277 (lines, map)
278 })
279 .await
280 .unwrap_or_else(|_| (Vec::new(), std::collections::HashMap::new()))
281}