1use std::path::PathBuf;
4use std::process::{Child, Command, Stdio};
5use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
6use std::sync::Mutex;
7use std::time::{Duration, Instant};
8
9use fs4::fs_std::FileExt;
10use nng::options::Options;
11use nng::{Protocol, Socket};
12use serde_json::json;
13
14use crate::engine::fetch::EngineFetcher;
15use crate::engine::multiprocess::{
16 pid_is_alive, pid_is_engine, read_pid_file, reap_engine_process, stop_engine_process,
17 write_pid_file,
18};
19use crate::error::{Error, Result};
20use crate::ipc::endpoints::{management_url, response_url, EVENT_TOPIC_PREFIX};
21
/// Default time to wait for the engine's first telemetry heartbeat before
/// declaring startup failed.
const DEFAULT_STARTUP_TIMEOUT_SECS: u64 = 60;
/// Maximum time to poll for the cross-process engine lock before giving up.
const LOCK_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5);
/// Initial sleep of the exponential backoff used while polling the lock.
const LOCK_BACKOFF_INITIAL: Duration = Duration::from_millis(10);
/// Cap on the exponential backoff sleep between lock attempts.
const LOCK_BACKOFF_MAX: Duration = Duration::from_millis(250);
26
27fn remove_if_exists(path: &std::path::Path) {
29 if let Err(e) = std::fs::remove_file(path) {
30 if e.kind() != std::io::ErrorKind::NotFound {
31 tracing::warn!("Failed to remove {}: {}", path.display(), e);
32 }
33 }
34}
35
36fn lock_exclusive_with_timeout(lock_file: &std::fs::File) -> Result<()> {
37 let deadline = Instant::now() + LOCK_ACQUIRE_TIMEOUT;
38 let mut sleep = LOCK_BACKOFF_INITIAL;
39
40 loop {
41 match lock_file.try_lock_exclusive() {
42 Ok(()) => return Ok(()),
43 Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
44 if Instant::now() >= deadline {
45 return Err(Error::LockFailed(format!(
46 "Timed out acquiring engine lock after {:?}",
47 LOCK_ACQUIRE_TIMEOUT
48 )));
49 }
50 std::thread::sleep(sleep);
51 sleep = std::cmp::min(sleep * 2, LOCK_BACKOFF_MAX);
52 }
53 Err(e) => return Err(Error::LockFailed(e.to_string())),
54 }
55 }
56}
57
/// Filesystem locations used to coordinate one shared engine instance
/// across processes. Everything lives under the per-user cache directory.
#[derive(Debug, Clone)]
pub struct EnginePaths {
    /// Root cache directory (".../com.theproxycompany").
    pub cache_dir: PathBuf,
    /// File recording the PID of the running engine process.
    pub pid_file: PathBuf,
    /// Advisory lock file guarding engine start/stop across processes.
    pub lock_file: PathBuf,
    /// Readiness marker — NOTE(review): only removed in this file, presumably
    /// written by the engine itself; confirm against the engine side.
    pub ready_file: PathBuf,
    /// Destination for the engine's stdout/stderr.
    pub engine_log_file: PathBuf,
    /// Client-side log file — NOTE(review): not written in this file;
    /// confirm its consumer elsewhere in the crate.
    pub client_log_file: PathBuf,
}
68
69impl EnginePaths {
70 pub fn new() -> Result<Self> {
71 let cache_dir = dirs::cache_dir()
72 .ok_or_else(|| Error::Internal("Cannot determine cache directory".into()))?
73 .join("com.theproxycompany");
74
75 Ok(Self {
76 pid_file: cache_dir.join("engine.pid"),
77 lock_file: cache_dir.join("engine.lock"),
78 ready_file: cache_dir.join("engine.ready"),
79 engine_log_file: cache_dir.join("engine.log"),
80 client_log_file: cache_dir.join("client.log"),
81 cache_dir,
82 })
83 }
84}
85
/// Process-wide bookkeeping shared by every `InferenceEngine` handle in
/// this process.
struct GlobalContext {
    // Number of live leases currently held by this process.
    ref_count: AtomicU32,
    // Set once a lease has been acquired; cleared on release/shutdown.
    initialized: AtomicBool,
    // PID-file path of the engine the current lease points at, if any.
    pid_file: Mutex<Option<PathBuf>>,
}

// Single static instance; all handles in the process coordinate through it.
static GLOBAL_CONTEXT: GlobalContext = GlobalContext {
    ref_count: AtomicU32::new(0),
    initialized: AtomicBool::new(false),
    pid_file: Mutex::new(None),
};
98
99pub(crate) fn current_engine_pid_file() -> Option<PathBuf> {
100 GLOBAL_CONTEXT
101 .pid_file
102 .lock()
103 .unwrap_or_else(|e| e.into_inner())
104 .clone()
105}
106
107fn set_current_engine_pid_file(pid_file: Option<PathBuf>) {
108 *GLOBAL_CONTEXT
109 .pid_file
110 .lock()
111 .unwrap_or_else(|e| e.into_inner()) = pid_file;
112}
113
/// Handle to the shared, lazily launched inference engine process.
///
/// Constructing a handle acquires a lease (launching the engine if it is
/// not already running); closing or dropping the handle releases the lease,
/// and the last lease in the process performs teardown.
pub struct InferenceEngine {
    // Filesystem coordination points (PID/lock/ready/log files).
    paths: EnginePaths,
    // Resolves the engine binary to launch.
    fetcher: EngineFetcher,
    // Maximum wait for the readiness heartbeat after launching.
    startup_timeout: Duration,
    // True while this handle holds a lease on the shared engine.
    lease_active: bool,
    // True once close() has run; makes close/drop idempotent.
    closed: bool,
    // Child handle when *this* handle spawned the engine process.
    launch_process: Option<Child>,
}
129
130impl InferenceEngine {
131 pub async fn new() -> Result<Self> {
133 Self::with_options(EnginePaths::new()?, None).await
134 }
135
136 pub async fn with_options(
138 paths: EnginePaths,
139 startup_timeout: Option<Duration>,
140 ) -> Result<Self> {
141 let fetcher = EngineFetcher::new();
142 let startup_timeout =
143 startup_timeout.unwrap_or(Duration::from_secs(DEFAULT_STARTUP_TIMEOUT_SECS));
144
145 let mut engine = Self {
146 paths,
147 fetcher,
148 startup_timeout,
149 lease_active: false,
150 closed: false,
151 launch_process: None,
152 };
153
154 engine.acquire_lease().await?;
155
156 Ok(engine)
157 }
158
159 pub fn close(&mut self) -> Result<()> {
164 if self.closed {
165 return Ok(());
166 }
167
168 let should_release = if self.lease_active {
169 let prev = GLOBAL_CONTEXT.ref_count.fetch_sub(1, Ordering::SeqCst);
170 prev == 1 } else {
172 false
173 };
174
175 if !self.lease_active || !should_release {
176 self.closed = true;
177 self.lease_active = false;
178 return Ok(());
179 }
180
181 let lock_file = std::fs::File::create(&self.paths.lock_file)?;
184 lock_exclusive_with_timeout(&lock_file)?;
185
186 let engine_pid = read_pid_file(&self.paths.pid_file);
187 let engine_running = engine_pid.map(pid_is_alive).unwrap_or(false);
188
189 if engine_running {
190 if let Err(e) =
191 self.send_client_lifecycle_command("client_deregister", Duration::from_secs(5))
192 {
193 tracing::warn!(
194 "Failed to deregister orchard-rs client PID {} from PIE: {}",
195 std::process::id(),
196 e
197 );
198 }
199 } else {
200 remove_if_exists(&self.paths.pid_file);
201 remove_if_exists(&self.paths.ready_file);
202 remove_if_exists(&self.paths.cache_dir.join("engine.refs"));
203 }
204
205 drop(lock_file);
206
207 self.lease_active = false;
208 self.closed = true;
209 GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
210 set_current_engine_pid_file(None);
211
212 Ok(())
213 }
214
    /// Stop the shared engine process system-wide, regardless of lease
    /// counts, and remove its on-disk runtime state.
    ///
    /// # Errors
    /// Returns an error if the lock cannot be acquired or the engine does
    /// not stop within `timeout`.
    pub fn shutdown(timeout: Duration) -> Result<()> {
        let paths = EnginePaths::new()?;

        // Serialize against concurrent start/stop in other processes.
        let lock_file = std::fs::File::create(&paths.lock_file)?;
        lock_exclusive_with_timeout(&lock_file)?;

        // Removes every runtime artifact the engine leaves behind.
        let cleanup_all = |paths: &EnginePaths| {
            remove_if_exists(&paths.pid_file);
            remove_if_exists(&paths.ready_file);
            remove_if_exists(&paths.cache_dir.join("engine.refs"));
            let ipc_dir = paths.cache_dir.join("ipc");
            if ipc_dir.exists() {
                remove_if_exists(&ipc_dir.join("pie_requests.ipc"));
                remove_if_exists(&ipc_dir.join("pie_responses.ipc"));
                remove_if_exists(&ipc_dir.join("pie_management.ipc"));
            }
        };

        // No live engine on record: nothing to stop, just tidy up.
        let pid = match read_pid_file(&paths.pid_file) {
            Some(p) if pid_is_alive(p) => p,
            _ => {
                tracing::info!("Engine is not running. Cleaning up stale files.");
                cleanup_all(&paths);
                GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
                set_current_engine_pid_file(None);
                return Ok(());
            }
        };

        // The PID may have been reused by an unrelated process; if so the
        // recorded state is stale and must not trigger a kill.
        if !pid_is_engine(pid) {
            tracing::warn!(
                "PID {} does not belong to proxy_inference_engine; cleaning stale files.",
                pid
            );
            cleanup_all(&paths);
            GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
            set_current_engine_pid_file(None);
            return Ok(());
        }
        tracing::info!("Sending shutdown signal to engine process {}", pid);

        if stop_engine_process(pid, timeout) {
            cleanup_all(&paths);
            reap_engine_process(pid);
            GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
            set_current_engine_pid_file(None);
            tracing::info!("Engine process {} terminated gracefully", pid);
            Ok(())
        } else {
            Err(Error::ShutdownFailed(format!(
                "Failed to stop engine process {}",
                pid
            )))
        }
    }
272
    /// Acquire this handle's lease: under the cross-process lock, launch
    /// the engine if it is not already running, register this client with
    /// it, and bump the process-wide reference count. No-op when the
    /// handle is already closed or already holds a lease.
    async fn acquire_lease(&mut self) -> Result<()> {
        if self.closed || self.lease_active {
            return Ok(());
        }

        std::fs::create_dir_all(&self.paths.cache_dir)?;

        // Serialize launch/registration against other processes.
        let lock_file = std::fs::File::create(&self.paths.lock_file)?;
        lock_exclusive_with_timeout(&lock_file)?;

        // The recorded PID must be alive *and* actually be the engine
        // binary (guards against OS PID reuse).
        let engine_pid = read_pid_file(&self.paths.pid_file);
        let engine_running = engine_pid
            .map(|pid| pid_is_alive(pid) && pid_is_engine(pid))
            .unwrap_or(false);
        let mut launched_engine = false;

        if !engine_running {
            tracing::debug!("Inference engine not running. Launching new instance.");

            // Sweep any stale state left by a previous engine instance
            // before spawning a fresh one.
            remove_if_exists(&self.paths.pid_file);
            remove_if_exists(&self.paths.ready_file);
            remove_if_exists(&self.paths.cache_dir.join("engine.refs"));

            let ipc_dir = self.paths.cache_dir.join("ipc");
            if ipc_dir.exists() {
                remove_if_exists(&ipc_dir.join("pie_requests.ipc"));
                remove_if_exists(&ipc_dir.join("pie_responses.ipc"));
                remove_if_exists(&ipc_dir.join("pie_management.ipc"));
            }

            self.launch_engine().await?;
            self.wait_for_engine_ready().await?;
            launched_engine = true;
        }

        // Register this client PID with the engine; if registration fails
        // for an engine we just launched, tear the orphan back down so it
        // does not linger with zero clients.
        if let Err(e) =
            self.send_client_lifecycle_command("client_register", Duration::from_secs(5))
        {
            if launched_engine {
                if let Some(pid) = read_pid_file(&self.paths.pid_file) {
                    if let Err(stop_err) = self.stop_engine_locked(pid) {
                        tracing::warn!(
                            "Failed to clean up newly launched engine {} after register failure: {}",
                            pid,
                            stop_err
                        );
                    }
                }
            }
            return Err(e);
        }

        GLOBAL_CONTEXT.ref_count.fetch_add(1, Ordering::SeqCst);
        GLOBAL_CONTEXT.initialized.store(true, Ordering::SeqCst);
        set_current_engine_pid_file(Some(self.paths.pid_file.clone()));

        drop(lock_file);

        self.lease_active = true;
        Ok(())
    }
340
341 fn send_client_lifecycle_command(&self, command_type: &str, timeout: Duration) -> Result<()> {
342 let socket = Socket::new(Protocol::Req0)?;
343 socket.set_opt::<nng::options::RecvTimeout>(Some(timeout))?;
344 socket.set_opt::<nng::options::SendTimeout>(Some(timeout))?;
345 socket.dial(&management_url())?;
346
347 let payload = json!({
348 "type": command_type,
349 "client_pid": std::process::id(),
350 });
351 let data = serde_json::to_vec(&payload)?;
352 let msg = nng::Message::from(data.as_slice());
353 socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
354
355 let response = socket.recv()?;
356 let json: serde_json::Value = serde_json::from_slice(&response)?;
357 let status = json.get("status").and_then(|value| value.as_str());
358 if matches!(status, Some("ok") | Some("accepted")) {
359 return Ok(());
360 }
361
362 let message = json
363 .get("message")
364 .and_then(|value| value.as_str())
365 .unwrap_or("unknown error");
366 Err(Error::Other(format!(
367 "Engine rejected {} for PID {}: {}",
368 command_type,
369 std::process::id(),
370 message
371 )))
372 }
373
374 async fn launch_engine(&mut self) -> Result<()> {
375 let engine_path = self.fetcher.get_engine_path().await?;
376
377 tracing::info!("Launching PIE from {:?}", engine_path);
378
379 let log_file = std::fs::File::create(&self.paths.engine_log_file)?;
381
382 let child = Command::new(&engine_path)
383 .stdout(Stdio::from(log_file.try_clone()?))
384 .stderr(Stdio::from(log_file))
385 .spawn()
386 .map_err(|e| Error::StartupFailed(format!("Failed to spawn engine: {}", e)))?;
387
388 self.launch_process = Some(child);
389 Ok(())
390 }
391
    /// Block (polling a SUB socket) until the freshly launched engine
    /// publishes its first telemetry heartbeat, then record its PID; fail
    /// after `startup_timeout`.
    async fn wait_for_engine_ready(&self) -> Result<()> {
        tracing::info!("Waiting for telemetry heartbeat from engine...");

        let telemetry_topic = [EVENT_TOPIC_PREFIX, b"telemetry"].concat();
        let response_url = response_url();

        let socket = nng::Socket::new(nng::Protocol::Sub0)
            .map_err(|e| Error::StartupFailed(format!("Failed to create Sub0 socket: {}", e)))?;

        socket
            .set_opt::<nng::options::protocol::pubsub::Subscribe>(telemetry_topic.clone())
            .map_err(|e| Error::StartupFailed(format!("Failed to subscribe: {}", e)))?;

        // Short recv timeout so the loop can re-check the deadline and the
        // child's liveness between polls.
        socket
            .set_opt::<nng::options::RecvTimeout>(Some(Duration::from_millis(250)))
            .map_err(|e| Error::StartupFailed(format!("Failed to set timeout: {}", e)))?;

        // Async dial: the engine's listener may not exist yet.
        socket
            .dial_async(&response_url)
            .map_err(|e| Error::StartupFailed(format!("Failed to dial {}: {}", response_url, e)))?;

        let deadline = Instant::now() + self.startup_timeout;

        while Instant::now() < deadline {
            // Fail fast if the process we spawned died before readiness.
            if let Some(ref child) = self.launch_process {
                if !pid_is_alive(child.id()) {
                    return Err(Error::StartupFailed(
                        "Engine process exited before signaling readiness; check the engine log"
                            .into(),
                    ));
                }
            }

            let msg = match socket.recv() {
                Ok(msg) => msg,
                Err(nng::Error::TimedOut) => continue,
                Err(e) => {
                    tracing::debug!("Error receiving telemetry: {}", e);
                    continue;
                }
            };

            // Events are framed as "<topic>\0<json body>".
            let bytes = msg.as_slice();
            let parts: Vec<&[u8]> = bytes.splitn(2, |&b| b == 0).collect();
            if parts.len() < 2 {
                tracing::warn!("Discarding malformed event message while waiting for telemetry");
                continue;
            }

            let (topic_part, json_body) = (parts[0], parts[1]);
            if topic_part != telemetry_topic.as_slice() {
                tracing::debug!(
                    "Ignoring unexpected startup topic '{}'",
                    String::from_utf8_lossy(topic_part)
                );
                continue;
            }

            let payload: serde_json::Value = match serde_json::from_slice(json_body) {
                Ok(v) => v,
                Err(e) => {
                    tracing::warn!("Discarding malformed telemetry payload: {}", e);
                    continue;
                }
            };

            // The heartbeat is expected to carry the engine's PID under
            // payload["health"]["pid"].
            let engine_pid = payload
                .get("health")
                .and_then(|h| h.get("pid"))
                .and_then(|p| p.as_u64())
                .map(|p| p as u32);

            match engine_pid {
                Some(pid) if pid > 0 => {
                    // Persist the PID so other handles/processes can find
                    // the engine; failure to write is non-fatal.
                    if let Err(e) = write_pid_file(&self.paths.pid_file, pid) {
                        tracing::warn!("Failed to write PID file: {}", e);
                    }
                    tracing::info!("Received telemetry heartbeat. Engine PID {} recorded.", pid);
                    return Ok(());
                }
                _ => {
                    tracing::warn!(
                        "Telemetry payload missing valid PID; waiting for next heartbeat"
                    );
                    continue;
                }
            }
        }

        Err(Error::StartupFailed(format!(
            "Timed out after {:?}s waiting for telemetry heartbeat from engine",
            self.startup_timeout.as_secs()
        )))
    }
493
494 fn stop_engine_locked(&mut self, pid: u32) -> Result<()> {
495 let cleanup_files = || {
496 remove_if_exists(&self.paths.pid_file);
497 remove_if_exists(&self.paths.ready_file);
498 let ipc_dir = self.paths.cache_dir.join("ipc");
499 if ipc_dir.exists() {
500 remove_if_exists(&ipc_dir.join("pie_requests.ipc"));
501 remove_if_exists(&ipc_dir.join("pie_responses.ipc"));
502 remove_if_exists(&ipc_dir.join("pie_management.ipc"));
503 }
504 };
505
506 if !pid_is_alive(pid) {
507 tracing::debug!("Engine PID {} already exited", pid);
508 cleanup_files();
509 return Ok(());
510 }
511
512 if !pid_is_engine(pid) {
513 tracing::warn!(
514 "PID {} does not belong to proxy_inference_engine; cleaning stale files.",
515 pid
516 );
517 cleanup_files();
518 return Ok(());
519 }
520
521 if !stop_engine_process(pid, Duration::from_secs(5)) {
522 return Err(Error::ShutdownFailed(format!(
523 "Failed to stop engine PID {}",
524 pid
525 )));
526 }
527
528 reap_engine_process(pid);
529 cleanup_files();
530
531 tracing::info!("Engine PID {} stopped", pid);
532 Ok(())
533 }
534
535 pub fn generate_response_channel_id() -> u64 {
540 use rand::Rng;
541
542 let pid = std::process::id() as u64 & 0xFFFFFFFF;
543 let random: u32 = rand::thread_rng().gen();
544
545 let channel_id = (pid << 32) | (random as u64);
546 if channel_id == 0 {
547 1
548 } else {
549 channel_id
550 }
551 }
552}
553
554impl Drop for InferenceEngine {
555 fn drop(&mut self) {
556 if !self.closed {
557 if let Err(e) = self.close() {
558 tracing::error!("Failed to close InferenceEngine: {}", e);
559 }
560 }
561 }
562}
563
564#[cfg(test)]
565mod tests {
566 use super::*;
567
    // Smoke test: the resolved cache directory is namespaced under the
    // vendor folder.
    #[test]
    fn test_engine_paths() {
        let paths = EnginePaths::new().expect("cache dir should be available");
        assert!(paths
            .cache_dir
            .to_string_lossy()
            .contains("com.theproxycompany"));
    }
576
    // Channel ids must be unique across rapid calls, never zero, and carry
    // the caller's PID in their upper 32 bits.
    #[test]
    fn test_generate_channel_id_uniqueness() {
        use std::collections::HashSet;

        // 1000 draws: the random lower 32 bits make collisions negligible.
        let ids: HashSet<u64> = (0..1000)
            .map(|_| InferenceEngine::generate_response_channel_id())
            .collect();

        assert_eq!(
            ids.len(),
            1000,
            "Channel IDs must be unique across rapid calls"
        );

        // Zero is reserved as a sentinel value.
        assert!(!ids.contains(&0), "Channel ID must never be zero");

        // Upper half encodes the current process id.
        let expected_pid = std::process::id() as u64 & 0xFFFFFFFF;
        for id in &ids {
            let id_pid = id >> 32;
            assert_eq!(id_pid, expected_pid, "Upper 32 bits must be current PID");
        }
    }
603}