1use std::path::Path;
4
5use anyhow::{Context, Result};
6use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
7use tokio::net::{UnixListener, UnixStream};
8
9use crate::gpu_manager::GpuManager;
10use crate::host::{compute_apps_display_lines, compute_apps_pid_memory_mb};
11use crate::model_loader::ModelLoader;
12use crate::protocol::{DaemonRequest, DaemonResponse, SessionInfo, SWAP_SIGNAL_FILE_VERSION};
13use crate::vram_watch;
14
/// Wire protocol version; `dispatch` rejects a `Version { v }` request whose
/// `v` differs from this.
const PROTOCOL_VERSION: u32 = 1;

/// Upper bound (256 KiB) on a single newline-terminated request line; guards
/// against unbounded buffering from a misbehaving or malicious client.
pub const MAX_REQUEST_LINE_BYTES: usize = 256 * 1024;
19
/// Returns true when the `NEURONBOX_DISABLE_VRAM_WATCH` environment variable
/// is set to "1", "true", or "yes" (case-insensitive). Unset or any other
/// value means the VRAM watcher stays enabled.
fn vram_watch_disabled() -> bool {
    match std::env::var_os("NEURONBOX_DISABLE_VRAM_WATCH") {
        Some(raw) => {
            let val = raw.to_string_lossy().to_ascii_lowercase();
            val == "1" || val == "true" || val == "yes"
        }
        None => false,
    }
}
28
/// Failure modes when reading one newline-terminated request line from a client.
enum RequestLineError {
    /// Line grew to `MAX_REQUEST_LINE_BYTES` before a newline arrived.
    TooLong,
    /// Bytes before the newline were not valid UTF-8.
    BadUtf8,
    /// Underlying socket read failed (also used for EOF mid-line).
    Io(std::io::Error),
}
34
35async fn read_request_line<R: tokio::io::AsyncRead + Unpin>(
36 reader: &mut BufReader<R>,
37) -> Result<Option<String>, RequestLineError> {
38 let mut line = Vec::new();
39 loop {
40 if line.len() >= MAX_REQUEST_LINE_BYTES {
41 return Err(RequestLineError::TooLong);
42 }
43 let mut b = [0u8; 1];
44 let n = reader.read(&mut b).await.map_err(RequestLineError::Io)?;
45 if n == 0 {
46 return if line.is_empty() {
47 Ok(None)
48 } else {
49 Err(RequestLineError::Io(std::io::Error::new(
50 std::io::ErrorKind::UnexpectedEof,
51 "EOF before newline",
52 )))
53 };
54 }
55 if b[0] == b'\n' {
56 break;
57 }
58 line.push(b[0]);
59 }
60 String::from_utf8(line)
61 .map(Some)
62 .map_err(|_| RequestLineError::BadUtf8)
63}
64
65pub async fn run_socket_server(
66 socket_path: &Path,
67 gpu_manager: GpuManager,
68 model_loader: ModelLoader,
69) -> Result<()> {
70 if socket_path.exists() {
71 std::fs::remove_file(socket_path).ok();
72 }
73 if let Some(dir) = socket_path.parent() {
74 std::fs::create_dir_all(dir).with_context(|| format!("create_dir_all {:?}", dir))?;
75 }
76 let listener = UnixListener::bind(socket_path)
77 .with_context(|| format!("bind unix socket {:?}", socket_path))?;
78
79 #[cfg(unix)]
81 {
82 use std::os::unix::fs::PermissionsExt;
83 std::fs::set_permissions(socket_path, std::fs::Permissions::from_mode(0o600))
84 .with_context(|| format!("chmod 600 {:?}", socket_path))?;
85 }
86
87 if !vram_watch_disabled() {
88 let gm_watch = gpu_manager.clone();
89 tokio::spawn(vram_watch::run_soft_vram_enforcement(gm_watch));
90 }
91
92 loop {
93 let (stream, _) = listener.accept().await?;
94 let gm = gpu_manager.clone();
95 let ml = model_loader.clone();
96 tokio::spawn(async move {
97 if let Err(e) = handle_connection(stream, gm, ml).await {
98 tracing::warn!("connection error: {e:#}");
99 }
100 });
101 }
102}
103
104async fn handle_connection(
105 stream: UnixStream,
106 gpu_manager: GpuManager,
107 model_loader: ModelLoader,
108) -> Result<()> {
109 let (read_half, mut write_half) = stream.into_split();
110 let mut reader = BufReader::new(read_half);
111
112 loop {
113 let line = match read_request_line(&mut reader).await {
114 Ok(None) => break,
115 Ok(Some(s)) => s,
116 Err(RequestLineError::TooLong) => {
117 let err = DaemonResponse::Error {
118 message: format!(
119 "request line exceeds maximum size ({MAX_REQUEST_LINE_BYTES} bytes)"
120 ),
121 };
122 write_response(&mut write_half, &err).await?;
123 break;
124 }
125 Err(RequestLineError::BadUtf8) => {
126 let err = DaemonResponse::Error {
127 message: "invalid UTF-8 in request line".to_string(),
128 };
129 write_response(&mut write_half, &err).await?;
130 break;
131 }
132 Err(RequestLineError::Io(e)) => return Err(e.into()),
133 };
134 let trimmed = line.trim();
135 if trimmed.is_empty() {
136 continue;
137 }
138
139 let req: DaemonRequest = match serde_json::from_str(trimmed) {
140 Ok(r) => r,
141 Err(e) => {
142 let err = DaemonResponse::Error {
143 message: format!("invalid JSON request: {e}"),
144 };
145 write_response(&mut write_half, &err).await?;
146 continue;
147 }
148 };
149
150 let resp = dispatch(req, &gpu_manager, &model_loader).await;
151 write_response(&mut write_half, &resp).await?;
152 }
153 Ok(())
154}
155
156async fn write_response(
157 w: &mut tokio::net::unix::OwnedWriteHalf,
158 resp: &DaemonResponse,
159) -> Result<()> {
160 let mut s = serde_json::to_string(resp)?;
161 s.push('\n');
162 w.write_all(s.as_bytes()).await?;
163 Ok(())
164}
165
/// Map one decoded client request onto the daemon's services and build the
/// response. Never returns a transport error: all failures are reported
/// in-band as `DaemonResponse::Error`.
async fn dispatch(
    req: DaemonRequest,
    gpu_manager: &GpuManager,
    model_loader: &ModelLoader,
) -> DaemonResponse {
    match req {
        // Liveness probe.
        DaemonRequest::Ping => DaemonResponse::Pong,
        // Strict version check: any skew from PROTOCOL_VERSION is rejected.
        DaemonRequest::Version { v } => {
            if v != PROTOCOL_VERSION {
                DaemonResponse::Error {
                    message: format!("protocol mismatch: client {v}, daemon {PROTOCOL_VERSION}"),
                }
            } else {
                DaemonResponse::VersionInfo {
                    v: PROTOCOL_VERSION,
                }
            }
        }
        // Record a new session with the GPU manager; the client's pid is
        // echoed back as the registration handle.
        DaemonRequest::RegisterSession {
            name,
            estimated_vram_mb,
            pid,
            tokens_per_sec,
            model_dir,
        } => {
            gpu_manager
                .register(SessionInfo {
                    name,
                    pid,
                    estimated_vram_mb,
                    tokens_per_sec,
                    model_dir,
                })
                .await;
            DaemonResponse::Registered { pid }
        }
        // Unregister by pid; unknown pids yield an in-band error.
        DaemonRequest::UnregisterSession { pid } => {
            let ok = gpu_manager.unregister(pid).await;
            if ok {
                DaemonResponse::Unregistered
            } else {
                DaemonResponse::Error {
                    message: format!("pid {pid} not registered"),
                }
            }
        }
        DaemonRequest::ListSessions => {
            let sessions = gpu_manager.list().await;
            DaemonResponse::Sessions { sessions }
        }
        // Aggregate snapshot: sessions, per-GPU display lines, per-pid VRAM
        // usage, and the currently active model (if any).
        DaemonRequest::Stats => {
            let sessions = gpu_manager.list().await;
            let (gpu_lines, vram_used_by_pid) = nvidia_stats_bundle().await;
            // Only attach the explanatory note when no GPU lines are
            // available (e.g. no NVIDIA tooling on the host).
            let note = if gpu_lines.is_empty() {
                Some(
                    "tokens/s are shown only when the session reports them (RegisterSession)."
                        .to_string(),
                )
            } else {
                None
            };
            let am = model_loader.get().await;
            // An empty model_ref is treated as "no active model".
            let active_model = if am.model_ref.is_empty() {
                None
            } else {
                Some(crate::protocol::ActiveModelInfo {
                    model_ref: am.model_ref,
                    quantization: am.quantization,
                })
            };
            DaemonResponse::Stats {
                sessions,
                gpu_lines,
                note,
                active_model,
                vram_used_by_pid,
            }
        }
        // Record the swap with the model loader, then best-effort write a
        // swap-signal file under ~/.neuronbox for out-of-process observers.
        DaemonRequest::SwapModel {
            model_ref,
            quantization,
        } => {
            model_loader
                .swap(model_ref.clone(), quantization.clone())
                .await;
            // Falls back to the current directory if no home dir is known.
            let swap_path = dirs::home_dir()
                .unwrap_or_else(|| std::path::PathBuf::from("."))
                .join(".neuronbox")
                .join("swap_signal.json");
            let payload = serde_json::json!({
                "signal_version": SWAP_SIGNAL_FILE_VERSION,
                "model_ref": model_ref.clone(),
                "quantization": quantization.clone(),
                // Seconds since the Unix epoch; 0 if the clock is pre-epoch.
                "ts": std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_secs())
                    .unwrap_or(0),
            });
            // Write failures are deliberately ignored: the signal file is a
            // best-effort side channel, and the swap itself already succeeded.
            if let Ok(bytes) = serde_json::to_vec(&payload) {
                let _ = tokio::fs::write(&swap_path, bytes).await;
            }
            DaemonResponse::Swapped {
                model_ref,
                quantization,
            }
        }
    }
}
274
275async fn nvidia_stats_bundle() -> (Vec<String>, std::collections::HashMap<u32, u64>) {
276 tokio::task::spawn_blocking(|| {
277 let lines = compute_apps_display_lines();
278 let map = compute_apps_pid_memory_mb().unwrap_or_default();
279 (lines, map)
280 })
281 .await
282 .unwrap_or_else(|_| (Vec::new(), std::collections::HashMap::new()))
283}