procref/
service.rs

1//! SharedService - High-level API for managing shared service lifecycles.
2
3use crate::types::ServiceInfoFile;
4use crate::{Error, PlatformRefCounter, RefCounter, Result, ServiceInfo};
5use parking_lot::RwLock;
6use std::future::Future;
7use std::path::PathBuf;
8use std::pin::Pin;
9use std::sync::Arc;
10
/// Type alias for async callbacks.
///
/// A pinned, heap-allocated, `Send` future that yields `T` and is valid for
/// the lifetime `'a`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Callback for first acquire (service startup).
///
/// Takes no arguments and resolves to the [`ServiceInfo`] describing the
/// service, or to an error.
pub type OnFirstAcquire = Box<dyn Fn() -> BoxFuture<'static, Result<ServiceInfo>> + Send + Sync>;

/// Callback for last release (service shutdown).
///
/// Receives the [`ServiceInfo`] by value and resolves to `Ok(())` or an error.
///
/// NOTE(review): the synchronous release path in this file only checks whether
/// this callback is registered; it does not invoke it.
pub type OnLastRelease =
    Box<dyn Fn(ServiceInfo) -> BoxFuture<'static, Result<()>> + Send + Sync>;

/// Callback for health check.
///
/// Receives a borrowed [`ServiceInfo`] and resolves to `true` when the service
/// is healthy. The returned future is `'static`, so it cannot borrow from the
/// argument.
pub type OnHealthCheck = Box<dyn Fn(&ServiceInfo) -> BoxFuture<'static, bool> + Send + Sync>;

/// Callback for recovery.
///
/// Receives the previous [`ServiceInfo`] by value and resolves to the
/// [`ServiceInfo`] of the recovered service, or to an error.
pub type OnRecover =
    Box<dyn Fn(ServiceInfo) -> BoxFuture<'static, Result<ServiceInfo>> + Send + Sync>;
27
/// A handle to a shared service.
///
/// Dropping this handle releases the reference.
/// If it's the last reference, the service may be shut down.
pub struct ServiceHandle {
    /// Shared state of the owning service; used to release the reference on drop.
    service: Arc<SharedServiceInner>,
    /// Service information this handle was created with.
    info: ServiceInfo,
}
36
37impl ServiceHandle {
38    /// Get information about the service.
39    pub fn info(&self) -> &ServiceInfo {
40        &self.info
41    }
42
43    /// Get the service port.
44    pub fn port(&self) -> u16 {
45        self.info.port()
46    }
47
48    /// Get the service PID.
49    pub fn pid(&self) -> u32 {
50        self.info.pid()
51    }
52}
53
54impl Drop for ServiceHandle {
55    fn drop(&mut self) {
56        // Release is handled by the service
57        // We can't do async in drop, so we spawn a task or use sync release
58        let _ = self.service.release_sync();
59    }
60}
61
/// Inner state for SharedService.
///
/// Shared via `Arc` between a `SharedService` and every `ServiceHandle` it
/// hands out.
struct SharedServiceInner {
    /// Service name; used in log messages and to create the reference counter.
    name: String,
    /// Reference counter that also provides the startup lock
    /// (`try_lock` / `unlock`).
    ref_counter: PlatformRefCounter,
    /// Path of the JSON file holding the serialized service info.
    info_path: PathBuf,
    /// Service info most recently started, loaded or recovered by this instance.
    current_info: RwLock<Option<ServiceInfo>>,
    /// Startup callback, if registered.
    on_first_acquire: Option<OnFirstAcquire>,
    /// Shutdown callback, if registered.
    on_last_release: Option<OnLastRelease>,
    /// Health-check callback, if registered.
    on_health_check: Option<OnHealthCheck>,
    /// Recovery callback, if registered.
    on_recover: Option<OnRecover>,
}
73
74impl SharedServiceInner {
75    /// Synchronous release (for Drop).
76    fn release_sync(&self) -> Result<()> {
77        let count = self.ref_counter.release()?;
78
79        if count == 0 {
80            // We're the last client
81            if let Some(ref _callback) = self.on_last_release {
82                if let Some(info) = self.current_info.read().clone() {
83                    // We can't call async in sync context easily
84                    // For now, just log. Proper async drop needs runtime support.
85                    tracing::info!(
86                        "Last client released, service {} should be stopped",
87                        self.name
88                    );
89
90                    // Try to stop the process directly
91                    crate::process::stop(info.pid(), 5000);
92
93                    // Clean up info file
94                    let _ = std::fs::remove_file(&self.info_path);
95                }
96            }
97        }
98
99        Ok(())
100    }
101}
102
/// Shared service manager.
///
/// Manages a service that is shared across multiple processes.
/// Uses kernel-level reference counting to track clients.
pub struct SharedService {
    /// State shared with every `ServiceHandle` created by `acquire`.
    inner: Arc<SharedServiceInner>,
}
110
111impl SharedService {
112    /// Create a new builder for SharedService.
113    pub fn builder(name: &str) -> SharedServiceBuilder {
114        SharedServiceBuilder::new(name)
115    }
116
117    /// Acquire a reference to the service.
118    ///
119    /// If this is the first client, `on_first_acquire` is called to start the service.
120    /// Otherwise, the existing service info is returned after health check.
121    pub async fn acquire(&self) -> Result<ServiceHandle> {
122        // Step 1: Increment reference count
123        let count = self.inner.ref_counter.acquire()?;
124        tracing::debug!("Acquired reference, count={}", count);
125
126        // Step 2: Check if we're the first client
127        if count == 1 {
128            // Try to get startup lock
129            if self.inner.ref_counter.try_lock()? {
130                // We have the lock, start the service
131                let info = self.start_service().await?;
132                self.inner.ref_counter.unlock()?;
133                return Ok(ServiceHandle {
134                    service: Arc::clone(&self.inner),
135                    info,
136                });
137            }
138            // Someone else has the lock, wait for them to finish
139            self.wait_for_service().await?;
140        }
141
142        // Step 3: Get existing service info
143        let info = self.get_or_recover_service().await?;
144
145        Ok(ServiceHandle {
146            service: Arc::clone(&self.inner),
147            info,
148        })
149    }
150
151    /// Start the service (called when we're the first client).
152    async fn start_service(&self) -> Result<ServiceInfo> {
153        tracing::info!("Starting service {}", self.inner.name);
154
155        let info = if let Some(ref callback) = self.inner.on_first_acquire {
156            callback().await?
157        } else {
158            return Err(Error::ServiceStart(
159                "No on_first_acquire callback registered".to_string(),
160            ));
161        };
162
163        // Save service info to file
164        self.save_info(&info)?;
165        *self.inner.current_info.write() = Some(info.clone());
166
167        Ok(info)
168    }
169
170    /// Wait for another process to start the service.
171    async fn wait_for_service(&self) -> Result<()> {
172        let start = std::time::Instant::now();
173        let timeout = std::time::Duration::from_secs(30);
174
175        while start.elapsed() < timeout {
176            if self.inner.info_path.exists() {
177                if let Ok(info) = self.load_info() {
178                    if info.is_alive() {
179                        return Ok(());
180                    }
181                }
182            }
183            tokio::time::sleep(std::time::Duration::from_millis(100)).await;
184        }
185
186        Err(Error::ServiceStart(
187            "Timeout waiting for service to start".to_string(),
188        ))
189    }
190
191    /// Get existing service info, or recover if unhealthy.
192    async fn get_or_recover_service(&self) -> Result<ServiceInfo> {
193        // Try to load from file
194        let info = match self.load_info() {
195            Ok(info) => info,
196            Err(_) => {
197                // No info file, try to recover
198                return self.recover_service(None).await;
199            }
200        };
201
202        // Check if healthy
203        let is_healthy = if let Some(ref check) = self.inner.on_health_check {
204            check(&info).await
205        } else {
206            // Default health check: process is alive
207            info.is_alive()
208        };
209
210        if is_healthy {
211            *self.inner.current_info.write() = Some(info.clone());
212            Ok(info)
213        } else {
214            // Try to recover
215            self.recover_service(Some(info)).await
216        }
217    }
218
219    /// Recover an unhealthy service.
220    async fn recover_service(&self, old_info: Option<ServiceInfo>) -> Result<ServiceInfo> {
221        tracing::warn!("Service {} needs recovery", self.inner.name);
222
223        // Try to get startup lock for recovery
224        if !self.inner.ref_counter.try_lock()? {
225            // Someone else is recovering, wait
226            return self.wait_for_service().await.and_then(|_| {
227                self.load_info()
228            });
229        }
230
231        let info = if let Some(ref callback) = self.inner.on_recover {
232            if let Some(old) = old_info {
233                callback(old).await?
234            } else if let Some(ref start) = self.inner.on_first_acquire {
235                start().await?
236            } else {
237                self.inner.ref_counter.unlock()?;
238                return Err(Error::ServiceRecovery(
239                    "No recovery or startup callback".to_string(),
240                ));
241            }
242        } else if let Some(ref start) = self.inner.on_first_acquire {
243            // Fall back to restart
244            if let Some(old) = old_info {
245                crate::process::stop(old.pid(), 5000);
246            }
247            start().await?
248        } else {
249            self.inner.ref_counter.unlock()?;
250            return Err(Error::ServiceRecovery(
251                "No recovery or startup callback".to_string(),
252            ));
253        };
254
255        self.save_info(&info)?;
256        *self.inner.current_info.write() = Some(info.clone());
257        self.inner.ref_counter.unlock()?;
258
259        Ok(info)
260    }
261
262    /// Save service info to file.
263    fn save_info(&self, info: &ServiceInfo) -> Result<()> {
264        if let Some(parent) = self.inner.info_path.parent() {
265            std::fs::create_dir_all(parent)?;
266        }
267
268        let file_info = ServiceInfoFile::from(info);
269        let content = serde_json::to_string_pretty(&file_info)
270            .map_err(|e| Error::ServiceInfo(format!("Serialization failed: {}", e)))?;
271
272        std::fs::write(&self.inner.info_path, content)?;
273        Ok(())
274    }
275
276    /// Load service info from file.
277    fn load_info(&self) -> Result<ServiceInfo> {
278        let content = std::fs::read_to_string(&self.inner.info_path)?;
279        let file_info: ServiceInfoFile = serde_json::from_str(&content)
280            .map_err(|e| Error::ServiceInfo(format!("Deserialization failed: {}", e)))?;
281        Ok(ServiceInfo::from(file_info))
282    }
283
284    /// Get current reference count.
285    pub fn count(&self) -> Result<u32> {
286        self.inner.ref_counter.count()
287    }
288
289    /// Get service name.
290    pub fn name(&self) -> &str {
291        &self.inner.name
292    }
293}
294
/// Builder for SharedService.
///
/// Collects the service name, an optional base directory and the optional
/// lifecycle callbacks; `build` turns them into a `SharedService`.
pub struct SharedServiceBuilder {
    /// Service name; also used for the info file name (`<name>.json`).
    name: String,
    /// Directory for the info file; `build` falls back to `.procref` under the
    /// home directory (or under `.` when no home directory is available).
    base_dir: Option<PathBuf>,
    /// Startup callback, if set.
    on_first_acquire: Option<OnFirstAcquire>,
    /// Shutdown callback, if set.
    on_last_release: Option<OnLastRelease>,
    /// Health-check callback, if set.
    on_health_check: Option<OnHealthCheck>,
    /// Recovery callback, if set.
    on_recover: Option<OnRecover>,
}
304
305impl SharedServiceBuilder {
306    /// Create a new builder.
307    pub fn new(name: &str) -> Self {
308        Self {
309            name: name.to_string(),
310            base_dir: None,
311            on_first_acquire: None,
312            on_last_release: None,
313            on_health_check: None,
314            on_recover: None,
315        }
316    }
317
318    /// Set base directory for service info files.
319    pub fn base_dir(mut self, dir: impl Into<PathBuf>) -> Self {
320        self.base_dir = Some(dir.into());
321        self
322    }
323
324    /// Set callback for first client (service startup).
325    ///
326    /// This is called when the first client acquires a reference.
327    /// It should start the service and return ServiceInfo.
328    pub fn on_first_acquire<F, Fut>(mut self, f: F) -> Self
329    where
330        F: Fn() -> Fut + Send + Sync + 'static,
331        Fut: Future<Output = Result<ServiceInfo>> + Send + 'static,
332    {
333        self.on_first_acquire = Some(Box::new(move || Box::pin(f())));
334        self
335    }
336
337    /// Set callback for last client (service shutdown).
338    ///
339    /// This is called when the last client releases their reference.
340    /// It should stop the service.
341    pub fn on_last_release<F, Fut>(mut self, f: F) -> Self
342    where
343        F: Fn(ServiceInfo) -> Fut + Send + Sync + 'static,
344        Fut: Future<Output = Result<()>> + Send + 'static,
345    {
346        self.on_last_release = Some(Box::new(move |info| Box::pin(f(info))));
347        self
348    }
349
350    /// Set callback for health check.
351    ///
352    /// This is called to verify the service is healthy.
353    /// Return true if healthy, false if recovery is needed.
354    pub fn on_health_check<F, Fut>(mut self, f: F) -> Self
355    where
356        F: Fn(&ServiceInfo) -> Fut + Send + Sync + 'static,
357        Fut: Future<Output = bool> + Send + 'static,
358    {
359        let f = Arc::new(f);
360        self.on_health_check = Some(Box::new(move |info| {
361            let info = info.clone();
362            let f = Arc::clone(&f);
363            Box::pin(async move { f(&info).await })
364        }));
365        self
366    }
367
368    /// Set callback for recovery.
369    ///
370    /// This is called when health check fails.
371    /// It should recover the service and return new ServiceInfo.
372    pub fn on_recover<F, Fut>(mut self, f: F) -> Self
373    where
374        F: Fn(ServiceInfo) -> Fut + Send + Sync + 'static,
375        Fut: Future<Output = Result<ServiceInfo>> + Send + 'static,
376    {
377        self.on_recover = Some(Box::new(move |info| Box::pin(f(info))));
378        self
379    }
380
381    /// Build the SharedService.
382    pub fn build(self) -> Result<SharedService> {
383        let base_dir = self.base_dir.unwrap_or_else(|| {
384            dirs::home_dir()
385                .unwrap_or_else(|| PathBuf::from("."))
386                .join(".procref")
387        });
388
389        std::fs::create_dir_all(&base_dir)?;
390
391        let info_path = base_dir.join(format!("{}.json", self.name));
392        let ref_counter = PlatformRefCounter::new(&self.name)?;
393
394        let inner = SharedServiceInner {
395            name: self.name,
396            ref_counter,
397            info_path,
398            current_info: RwLock::new(None),
399            on_first_acquire: self.on_first_acquire,
400            on_last_release: self.on_last_release,
401            on_health_check: self.on_health_check,
402            on_recover: self.on_recover,
403        };
404
405        Ok(SharedService {
406            inner: Arc::new(inner),
407        })
408    }
409}