// orchard/engine/lifecycle.rs
1//! Inference engine process lifecycle management.
2
3use std::path::PathBuf;
4use std::process::{Child, Command, Stdio};
5use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
6use std::sync::Mutex;
7use std::time::{Duration, Instant};
8
9use fs4::fs_std::FileExt;
10use nng::options::Options;
11use nng::{Protocol, Socket};
12use serde_json::json;
13
14use crate::engine::fetch::EngineFetcher;
15use crate::engine::multiprocess::{
16    pid_is_alive, pid_is_engine, read_pid_file, reap_engine_process, stop_engine_process,
17    write_pid_file,
18};
19use crate::error::{Error, Result};
20use crate::ipc::endpoints::{management_url, response_url, EVENT_TOPIC_PREFIX};
21
22const DEFAULT_STARTUP_TIMEOUT_SECS: u64 = 60;
23const LOCK_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(5);
24const LOCK_BACKOFF_INITIAL: Duration = Duration::from_millis(10);
25const LOCK_BACKOFF_MAX: Duration = Duration::from_millis(250);
26
27/// Remove a file, logging non-NotFound errors.
28fn remove_if_exists(path: &std::path::Path) {
29    if let Err(e) = std::fs::remove_file(path) {
30        if e.kind() != std::io::ErrorKind::NotFound {
31            tracing::warn!("Failed to remove {}: {}", path.display(), e);
32        }
33    }
34}
35
36fn lock_exclusive_with_timeout(lock_file: &std::fs::File) -> Result<()> {
37    let deadline = Instant::now() + LOCK_ACQUIRE_TIMEOUT;
38    let mut sleep = LOCK_BACKOFF_INITIAL;
39
40    loop {
41        match lock_file.try_lock_exclusive() {
42            Ok(()) => return Ok(()),
43            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
44                if Instant::now() >= deadline {
45                    return Err(Error::LockFailed(format!(
46                        "Timed out acquiring engine lock after {:?}",
47                        LOCK_ACQUIRE_TIMEOUT
48                    )));
49                }
50                std::thread::sleep(sleep);
51                sleep = std::cmp::min(sleep * 2, LOCK_BACKOFF_MAX);
52            }
53            Err(e) => return Err(Error::LockFailed(e.to_string())),
54        }
55    }
56}
57
/// File paths used by the engine lifecycle manager.
#[derive(Debug, Clone)]
pub struct EnginePaths {
    /// Root cache directory; every other path below lives inside it.
    pub cache_dir: PathBuf,
    /// Records the PID of the running engine process (read/written here).
    pub pid_file: PathBuf,
    /// Advisory lock file taken while mutating shared engine state.
    pub lock_file: PathBuf,
    /// Readiness marker; this module only removes it (presumably written by
    /// the engine itself — TODO confirm).
    pub ready_file: PathBuf,
    /// Destination for the spawned engine's stdout/stderr.
    pub engine_log_file: PathBuf,
    /// Client-side log file (not used within this module).
    pub client_log_file: PathBuf,
}
68
69impl EnginePaths {
70    pub fn new() -> Result<Self> {
71        let cache_dir = dirs::cache_dir()
72            .ok_or_else(|| Error::Internal("Cannot determine cache directory".into()))?
73            .join("com.theproxycompany");
74
75        Ok(Self {
76            pid_file: cache_dir.join("engine.pid"),
77            lock_file: cache_dir.join("engine.lock"),
78            ready_file: cache_dir.join("engine.ready"),
79            engine_log_file: cache_dir.join("engine.log"),
80            client_log_file: cache_dir.join("client.log"),
81            cache_dir,
82        })
83    }
84}
85
/// Global context shared across all InferenceEngine instances in a process.
struct GlobalContext {
    // Number of live engine leases held by this process.
    ref_count: AtomicU32,
    // True once a lease has been acquired; cleared on release/shutdown.
    initialized: AtomicBool,
    // PID file backing the current lease, if any. Mutex (not atomic) because
    // PathBuf is heap data.
    pid_file: Mutex<Option<PathBuf>>,
}

// Process-wide singleton; all fields are const-initializable so no lazy
// initialization is needed.
static GLOBAL_CONTEXT: GlobalContext = GlobalContext {
    ref_count: AtomicU32::new(0),
    initialized: AtomicBool::new(false),
    pid_file: Mutex::new(None),
};
98
99pub(crate) fn current_engine_pid_file() -> Option<PathBuf> {
100    GLOBAL_CONTEXT
101        .pid_file
102        .lock()
103        .unwrap_or_else(|e| e.into_inner())
104        .clone()
105}
106
107fn set_current_engine_pid_file(pid_file: Option<PathBuf>) {
108    *GLOBAL_CONTEXT
109        .pid_file
110        .lock()
111        .unwrap_or_else(|e| e.into_inner()) = pid_file;
112}
113
/// Manages the PIE (Proxy Inference Engine) process lifecycle.
///
/// Handles:
/// - Binary fetching if not installed
/// - Process spawning with proper daemonization
/// - Per-process lease tracking for orchard-rs clients
/// - Registration with PIE's management plane
pub struct InferenceEngine {
    // Filesystem locations for PID/lock/ready/log files.
    paths: EnginePaths,
    // Locates (and if necessary fetches) the engine binary.
    fetcher: EngineFetcher,
    // How long to wait for the engine's first telemetry heartbeat.
    startup_timeout: Duration,
    // True while this instance holds a reference in GLOBAL_CONTEXT.
    lease_active: bool,
    // Set by close(); makes close() idempotent and disables Drop cleanup.
    closed: bool,
    // Child handle when *this* instance spawned the engine; None when it
    // attached to an already-running engine.
    launch_process: Option<Child>,
}
129
130impl InferenceEngine {
131    /// Create a new InferenceEngine and connect to (or spawn) the engine process.
132    pub async fn new() -> Result<Self> {
133        Self::with_options(EnginePaths::new()?, None).await
134    }
135
136    /// Create with custom options.
137    pub async fn with_options(
138        paths: EnginePaths,
139        startup_timeout: Option<Duration>,
140    ) -> Result<Self> {
141        let fetcher = EngineFetcher::new();
142        let startup_timeout =
143            startup_timeout.unwrap_or(Duration::from_secs(DEFAULT_STARTUP_TIMEOUT_SECS));
144
145        let mut engine = Self {
146            paths,
147            fetcher,
148            startup_timeout,
149            lease_active: false,
150            closed: false,
151            launch_process: None,
152        };
153
154        engine.acquire_lease().await?;
155
156        Ok(engine)
157    }
158
159    /// Close this engine instance.
160    ///
161    /// Releases this process's orchard-rs lease and best-effort deregisters
162    /// the client with PIE. Normal disconnect never sends signals to PIE.
163    pub fn close(&mut self) -> Result<()> {
164        if self.closed {
165            return Ok(());
166        }
167
168        let should_release = if self.lease_active {
169            let prev = GLOBAL_CONTEXT.ref_count.fetch_sub(1, Ordering::SeqCst);
170            prev == 1 // Was the last reference
171        } else {
172            false
173        };
174
175        if !self.lease_active || !should_release {
176            self.closed = true;
177            self.lease_active = false;
178            return Ok(());
179        }
180
181        // Acquire lock while checking the shared engine PID file so we don't race
182        // a concurrent startup or explicit shutdown in another process.
183        let lock_file = std::fs::File::create(&self.paths.lock_file)?;
184        lock_exclusive_with_timeout(&lock_file)?;
185
186        let engine_pid = read_pid_file(&self.paths.pid_file);
187        let engine_running = engine_pid.map(pid_is_alive).unwrap_or(false);
188
189        if engine_running {
190            if let Err(e) =
191                self.send_client_lifecycle_command("client_deregister", Duration::from_secs(5))
192            {
193                tracing::warn!(
194                    "Failed to deregister orchard-rs client PID {} from PIE: {}",
195                    std::process::id(),
196                    e
197                );
198            }
199        } else {
200            remove_if_exists(&self.paths.pid_file);
201            remove_if_exists(&self.paths.ready_file);
202            remove_if_exists(&self.paths.cache_dir.join("engine.refs"));
203        }
204
205        drop(lock_file);
206
207        self.lease_active = false;
208        self.closed = true;
209        GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
210        set_current_engine_pid_file(None);
211
212        Ok(())
213    }
214
    /// Force shutdown the engine regardless of reference count.
    ///
    /// Takes the cross-process engine lock, then stops the engine process
    /// (if one is genuinely running) and removes all shared state and IPC
    /// socket files.
    ///
    /// # Errors
    /// Returns `Error::ShutdownFailed` if a live engine process does not
    /// stop within `timeout`; lock and path errors are propagated as-is.
    pub fn shutdown(timeout: Duration) -> Result<()> {
        let paths = EnginePaths::new()?;

        let lock_file = std::fs::File::create(&paths.lock_file)?;
        lock_exclusive_with_timeout(&lock_file)?;

        // Helper to clean up all state and socket files
        let cleanup_all = |paths: &EnginePaths| {
            remove_if_exists(&paths.pid_file);
            remove_if_exists(&paths.ready_file);
            remove_if_exists(&paths.cache_dir.join("engine.refs"));
            let ipc_dir = paths.cache_dir.join("ipc");
            if ipc_dir.exists() {
                remove_if_exists(&ipc_dir.join("pie_requests.ipc"));
                remove_if_exists(&ipc_dir.join("pie_responses.ipc"));
                remove_if_exists(&ipc_dir.join("pie_management.ipc"));
            }
        };

        // No recorded PID, or the recorded PID is dead: nothing to signal.
        let pid = match read_pid_file(&paths.pid_file) {
            Some(p) if pid_is_alive(p) => p,
            _ => {
                tracing::info!("Engine is not running. Cleaning up stale files.");
                cleanup_all(&paths);
                GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
                set_current_engine_pid_file(None);
                return Ok(());
            }
        };

        // Recorded PID was reused by an unrelated process: never signal it.
        if !pid_is_engine(pid) {
            tracing::warn!(
                "PID {} does not belong to proxy_inference_engine; cleaning stale files.",
                pid
            );
            cleanup_all(&paths);
            GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
            set_current_engine_pid_file(None);
            return Ok(());
        }
        tracing::info!("Sending shutdown signal to engine process {}", pid);

        if stop_engine_process(pid, timeout) {
            cleanup_all(&paths);
            reap_engine_process(pid);
            GLOBAL_CONTEXT.initialized.store(false, Ordering::SeqCst);
            set_current_engine_pid_file(None);
            tracing::info!("Engine process {} terminated gracefully", pid);
            Ok(())
        } else {
            // Engine survived the stop attempt; leave state files intact so a
            // later retry can still find it.
            Err(Error::ShutdownFailed(format!(
                "Failed to stop engine process {}",
                pid
            )))
        }
    }
272
    /// Acquire this process's engine lease: ensure an engine is running
    /// (spawning one if necessary), register this client with PIE's
    /// management plane, and bump the process-wide reference count.
    ///
    /// No-op when this instance is already closed or already holds a lease.
    /// The whole sequence runs under the cross-process file lock.
    async fn acquire_lease(&mut self) -> Result<()> {
        if self.closed || self.lease_active {
            return Ok(());
        }

        // Ensure cache directory exists
        std::fs::create_dir_all(&self.paths.cache_dir)?;

        // Acquire file lock
        let lock_file = std::fs::File::create(&self.paths.lock_file)?;
        lock_exclusive_with_timeout(&lock_file)?;

        // An engine counts as running only if the recorded PID is alive AND
        // actually belongs to the engine binary (guards against PID reuse).
        let engine_pid = read_pid_file(&self.paths.pid_file);
        let engine_running = engine_pid
            .map(|pid| pid_is_alive(pid) && pid_is_engine(pid))
            .unwrap_or(false);
        let mut launched_engine = false;

        // Launch engine if needed
        if !engine_running {
            tracing::debug!("Inference engine not running. Launching new instance.");

            // Clean up stale state files
            remove_if_exists(&self.paths.pid_file);
            remove_if_exists(&self.paths.ready_file);
            remove_if_exists(&self.paths.cache_dir.join("engine.refs"));

            // Clean up stale IPC socket files (left over from crashed engine)
            let ipc_dir = self.paths.cache_dir.join("ipc");
            if ipc_dir.exists() {
                remove_if_exists(&ipc_dir.join("pie_requests.ipc"));
                remove_if_exists(&ipc_dir.join("pie_responses.ipc"));
                remove_if_exists(&ipc_dir.join("pie_management.ipc"));
            }

            self.launch_engine().await?;
            self.wait_for_engine_ready().await?;
            launched_engine = true;
        }

        // Register this client PID with the engine. If registration fails for
        // an engine *we* just launched, tear it down again so we don't leak
        // an orphaned engine process; failures of that rollback are logged
        // but the original registration error is what we return.
        if let Err(e) =
            self.send_client_lifecycle_command("client_register", Duration::from_secs(5))
        {
            if launched_engine {
                if let Some(pid) = read_pid_file(&self.paths.pid_file) {
                    if let Err(stop_err) = self.stop_engine_locked(pid) {
                        tracing::warn!(
                            "Failed to clean up newly launched engine {} after register failure: {}",
                            pid,
                            stop_err
                        );
                    }
                }
            }
            return Err(e);
        }

        // Update global context
        GLOBAL_CONTEXT.ref_count.fetch_add(1, Ordering::SeqCst);
        GLOBAL_CONTEXT.initialized.store(true, Ordering::SeqCst);
        set_current_engine_pid_file(Some(self.paths.pid_file.clone()));

        drop(lock_file);

        self.lease_active = true;
        Ok(())
    }
340
341    fn send_client_lifecycle_command(&self, command_type: &str, timeout: Duration) -> Result<()> {
342        let socket = Socket::new(Protocol::Req0)?;
343        socket.set_opt::<nng::options::RecvTimeout>(Some(timeout))?;
344        socket.set_opt::<nng::options::SendTimeout>(Some(timeout))?;
345        socket.dial(&management_url())?;
346
347        let payload = json!({
348            "type": command_type,
349            "client_pid": std::process::id(),
350        });
351        let data = serde_json::to_vec(&payload)?;
352        let msg = nng::Message::from(data.as_slice());
353        socket.send(msg).map_err(|(_, e)| Error::Nng(e))?;
354
355        let response = socket.recv()?;
356        let json: serde_json::Value = serde_json::from_slice(&response)?;
357        let status = json.get("status").and_then(|value| value.as_str());
358        if matches!(status, Some("ok") | Some("accepted")) {
359            return Ok(());
360        }
361
362        let message = json
363            .get("message")
364            .and_then(|value| value.as_str())
365            .unwrap_or("unknown error");
366        Err(Error::Other(format!(
367            "Engine rejected {} for PID {}: {}",
368            command_type,
369            std::process::id(),
370            message
371        )))
372    }
373
374    async fn launch_engine(&mut self) -> Result<()> {
375        let engine_path = self.fetcher.get_engine_path().await?;
376
377        tracing::info!("Launching PIE from {:?}", engine_path);
378
379        // Open log file for engine output
380        let log_file = std::fs::File::create(&self.paths.engine_log_file)?;
381
382        let child = Command::new(&engine_path)
383            .stdout(Stdio::from(log_file.try_clone()?))
384            .stderr(Stdio::from(log_file))
385            .spawn()
386            .map_err(|e| Error::StartupFailed(format!("Failed to spawn engine: {}", e)))?;
387
388        self.launch_process = Some(child);
389        Ok(())
390    }
391
    /// Block until the engine publishes its first telemetry heartbeat
    /// (which carries the engine PID), or fail after `self.startup_timeout`.
    /// On success the reported PID is persisted to the PID file.
    ///
    /// NOTE(review): the 250 ms blocking `recv()` calls below run on the
    /// current thread inside an async fn — acceptable only if callers run
    /// this future where blocking is tolerated; confirm against the runtime.
    async fn wait_for_engine_ready(&self) -> Result<()> {
        tracing::info!("Waiting for telemetry heartbeat from engine...");

        // Subscribe to telemetry topic via NNG
        let telemetry_topic = [EVENT_TOPIC_PREFIX, b"telemetry"].concat();
        let response_url = response_url();

        let socket = nng::Socket::new(nng::Protocol::Sub0)
            .map_err(|e| Error::StartupFailed(format!("Failed to create Sub0 socket: {}", e)))?;

        socket
            .set_opt::<nng::options::protocol::pubsub::Subscribe>(telemetry_topic.clone())
            .map_err(|e| Error::StartupFailed(format!("Failed to subscribe: {}", e)))?;

        // Short receive timeout so the loop can re-check the deadline and the
        // child's liveness between polls.
        socket
            .set_opt::<nng::options::RecvTimeout>(Some(Duration::from_millis(250)))
            .map_err(|e| Error::StartupFailed(format!("Failed to set timeout: {}", e)))?;

        // Non-blocking dial - nng will reconnect automatically in background
        socket
            .dial_async(&response_url)
            .map_err(|e| Error::StartupFailed(format!("Failed to dial {}: {}", response_url, e)))?;

        let deadline = Instant::now() + self.startup_timeout;

        while Instant::now() < deadline {
            // Check if launched process died
            // NOTE(review): pid_is_alive() cannot detect an exited-but-unreaped
            // child (zombie); Child::try_wait would, but needs &mut self.
            if let Some(ref child) = self.launch_process {
                if !pid_is_alive(child.id()) {
                    return Err(Error::StartupFailed(
                        "Engine process exited before signaling readiness; check the engine log"
                            .into(),
                    ));
                }
            }

            // Try to receive a message
            let msg = match socket.recv() {
                Ok(msg) => msg,
                Err(nng::Error::TimedOut) => continue,
                Err(e) => {
                    tracing::debug!("Error receiving telemetry: {}", e);
                    continue;
                }
            };

            // Parse message: topic\x00json_body
            let bytes = msg.as_slice();
            let parts: Vec<&[u8]> = bytes.splitn(2, |&b| b == 0).collect();
            if parts.len() < 2 {
                tracing::warn!("Discarding malformed event message while waiting for telemetry");
                continue;
            }

            let (topic_part, json_body) = (parts[0], parts[1]);
            if topic_part != telemetry_topic.as_slice() {
                tracing::debug!(
                    "Ignoring unexpected startup topic '{}'",
                    String::from_utf8_lossy(topic_part)
                );
                continue;
            }

            // Parse JSON payload
            let payload: serde_json::Value = match serde_json::from_slice(json_body) {
                Ok(v) => v,
                Err(e) => {
                    tracing::warn!("Discarding malformed telemetry payload: {}", e);
                    continue;
                }
            };

            // Extract health.pid
            let engine_pid = payload
                .get("health")
                .and_then(|h| h.get("pid"))
                .and_then(|p| p.as_u64())
                .map(|p| p as u32);

            match engine_pid {
                Some(pid) if pid > 0 => {
                    // Heartbeat observed: persist the engine PID for other
                    // processes before reporting readiness. A write failure
                    // is logged but not fatal — the engine itself is up.
                    if let Err(e) = write_pid_file(&self.paths.pid_file, pid) {
                        tracing::warn!("Failed to write PID file: {}", e);
                    }
                    tracing::info!("Received telemetry heartbeat. Engine PID {} recorded.", pid);
                    return Ok(());
                }
                _ => {
                    tracing::warn!(
                        "Telemetry payload missing valid PID; waiting for next heartbeat"
                    );
                    continue;
                }
            }
        }

        Err(Error::StartupFailed(format!(
            "Timed out after {:?}s waiting for telemetry heartbeat from engine",
            self.startup_timeout.as_secs()
        )))
    }
493
494    fn stop_engine_locked(&mut self, pid: u32) -> Result<()> {
495        let cleanup_files = || {
496            remove_if_exists(&self.paths.pid_file);
497            remove_if_exists(&self.paths.ready_file);
498            let ipc_dir = self.paths.cache_dir.join("ipc");
499            if ipc_dir.exists() {
500                remove_if_exists(&ipc_dir.join("pie_requests.ipc"));
501                remove_if_exists(&ipc_dir.join("pie_responses.ipc"));
502                remove_if_exists(&ipc_dir.join("pie_management.ipc"));
503            }
504        };
505
506        if !pid_is_alive(pid) {
507            tracing::debug!("Engine PID {} already exited", pid);
508            cleanup_files();
509            return Ok(());
510        }
511
512        if !pid_is_engine(pid) {
513            tracing::warn!(
514                "PID {} does not belong to proxy_inference_engine; cleaning stale files.",
515                pid
516            );
517            cleanup_files();
518            return Ok(());
519        }
520
521        if !stop_engine_process(pid, Duration::from_secs(5)) {
522            return Err(Error::ShutdownFailed(format!(
523                "Failed to stop engine PID {}",
524                pid
525            )));
526        }
527
528        reap_engine_process(pid);
529        cleanup_files();
530
531        tracing::info!("Engine PID {} stopped", pid);
532        Ok(())
533    }
534
535    /// Generate a unique response channel ID for this client.
536    ///
537    /// Format: (PID << 32) | random_32_bits
538    /// Uses true randomness to avoid collisions between rapid successive calls.
539    pub fn generate_response_channel_id() -> u64 {
540        use rand::Rng;
541
542        let pid = std::process::id() as u64 & 0xFFFFFFFF;
543        let random: u32 = rand::thread_rng().gen();
544
545        let channel_id = (pid << 32) | (random as u64);
546        if channel_id == 0 {
547            1
548        } else {
549            channel_id
550        }
551    }
552}
553
554impl Drop for InferenceEngine {
555    fn drop(&mut self) {
556        if !self.closed {
557            if let Err(e) = self.close() {
558                tracing::error!("Failed to close InferenceEngine: {}", e);
559            }
560        }
561    }
562}
563
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_engine_paths() {
        let paths = EnginePaths::new().expect("cache dir should be available");
        let cache = paths.cache_dir.to_string_lossy().into_owned();
        assert!(
            cache.contains("com.theproxycompany"),
            "cache dir should live under the vendor directory"
        );
    }

    #[test]
    fn test_generate_channel_id_uniqueness() {
        use std::collections::HashSet;

        const SAMPLES: usize = 1000;

        // Generate channel IDs in rapid succession; HashSet dedupes, so an
        // equal length proves every ID was unique.
        let mut ids = HashSet::with_capacity(SAMPLES);
        for _ in 0..SAMPLES {
            ids.insert(InferenceEngine::generate_response_channel_id());
        }
        assert_eq!(
            ids.len(),
            SAMPLES,
            "Channel IDs must be unique across rapid calls"
        );

        // Zero is reserved and must never be produced.
        assert!(!ids.contains(&0), "Channel ID must never be zero");

        // Every ID carries the current PID in its upper 32 bits.
        let expected_pid = std::process::id() as u64 & 0xFFFFFFFF;
        assert!(
            ids.iter().all(|id| id >> 32 == expected_pid),
            "Upper 32 bits must be current PID"
        );
    }
}