1use crate::types::ServiceInfoFile;
4use crate::{Error, PlatformRefCounter, RefCounter, Result, ServiceInfo};
5use parking_lot::RwLock;
6use std::future::Future;
7use std::path::PathBuf;
8use std::pin::Pin;
9use std::sync::Arc;
10
11pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
13
14pub type OnFirstAcquire = Box<dyn Fn() -> BoxFuture<'static, Result<ServiceInfo>> + Send + Sync>;
16
17pub type OnLastRelease =
19 Box<dyn Fn(ServiceInfo) -> BoxFuture<'static, Result<()>> + Send + Sync>;
20
21pub type OnHealthCheck = Box<dyn Fn(&ServiceInfo) -> BoxFuture<'static, bool> + Send + Sync>;
23
24pub type OnRecover =
26 Box<dyn Fn(ServiceInfo) -> BoxFuture<'static, Result<ServiceInfo>> + Send + Sync>;
27
28pub struct ServiceHandle {
33 service: Arc<SharedServiceInner>,
34 info: ServiceInfo,
35}
36
37impl ServiceHandle {
38 pub fn info(&self) -> &ServiceInfo {
40 &self.info
41 }
42
43 pub fn port(&self) -> u16 {
45 self.info.port()
46 }
47
48 pub fn pid(&self) -> u32 {
50 self.info.pid()
51 }
52}
53
54impl Drop for ServiceHandle {
55 fn drop(&mut self) {
56 let _ = self.service.release_sync();
59 }
60}
61
62struct SharedServiceInner {
64 name: String,
65 ref_counter: PlatformRefCounter,
66 info_path: PathBuf,
67 current_info: RwLock<Option<ServiceInfo>>,
68 on_first_acquire: Option<OnFirstAcquire>,
69 on_last_release: Option<OnLastRelease>,
70 on_health_check: Option<OnHealthCheck>,
71 on_recover: Option<OnRecover>,
72}
73
74impl SharedServiceInner {
75 fn release_sync(&self) -> Result<()> {
77 let count = self.ref_counter.release()?;
78
79 if count == 0 {
80 if let Some(ref _callback) = self.on_last_release {
82 if let Some(info) = self.current_info.read().clone() {
83 tracing::info!(
86 "Last client released, service {} should be stopped",
87 self.name
88 );
89
90 crate::process::stop(info.pid(), 5000);
92
93 let _ = std::fs::remove_file(&self.info_path);
95 }
96 }
97 }
98
99 Ok(())
100 }
101}
102
103pub struct SharedService {
108 inner: Arc<SharedServiceInner>,
109}
110
111impl SharedService {
112 pub fn builder(name: &str) -> SharedServiceBuilder {
114 SharedServiceBuilder::new(name)
115 }
116
117 pub async fn acquire(&self) -> Result<ServiceHandle> {
122 let count = self.inner.ref_counter.acquire()?;
124 tracing::debug!("Acquired reference, count={}", count);
125
126 if count == 1 {
128 if self.inner.ref_counter.try_lock()? {
130 let info = self.start_service().await?;
132 self.inner.ref_counter.unlock()?;
133 return Ok(ServiceHandle {
134 service: Arc::clone(&self.inner),
135 info,
136 });
137 }
138 self.wait_for_service().await?;
140 }
141
142 let info = self.get_or_recover_service().await?;
144
145 Ok(ServiceHandle {
146 service: Arc::clone(&self.inner),
147 info,
148 })
149 }
150
151 async fn start_service(&self) -> Result<ServiceInfo> {
153 tracing::info!("Starting service {}", self.inner.name);
154
155 let info = if let Some(ref callback) = self.inner.on_first_acquire {
156 callback().await?
157 } else {
158 return Err(Error::ServiceStart(
159 "No on_first_acquire callback registered".to_string(),
160 ));
161 };
162
163 self.save_info(&info)?;
165 *self.inner.current_info.write() = Some(info.clone());
166
167 Ok(info)
168 }
169
170 async fn wait_for_service(&self) -> Result<()> {
172 let start = std::time::Instant::now();
173 let timeout = std::time::Duration::from_secs(30);
174
175 while start.elapsed() < timeout {
176 if self.inner.info_path.exists() {
177 if let Ok(info) = self.load_info() {
178 if info.is_alive() {
179 return Ok(());
180 }
181 }
182 }
183 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
184 }
185
186 Err(Error::ServiceStart(
187 "Timeout waiting for service to start".to_string(),
188 ))
189 }
190
191 async fn get_or_recover_service(&self) -> Result<ServiceInfo> {
193 let info = match self.load_info() {
195 Ok(info) => info,
196 Err(_) => {
197 return self.recover_service(None).await;
199 }
200 };
201
202 let is_healthy = if let Some(ref check) = self.inner.on_health_check {
204 check(&info).await
205 } else {
206 info.is_alive()
208 };
209
210 if is_healthy {
211 *self.inner.current_info.write() = Some(info.clone());
212 Ok(info)
213 } else {
214 self.recover_service(Some(info)).await
216 }
217 }
218
219 async fn recover_service(&self, old_info: Option<ServiceInfo>) -> Result<ServiceInfo> {
221 tracing::warn!("Service {} needs recovery", self.inner.name);
222
223 if !self.inner.ref_counter.try_lock()? {
225 return self.wait_for_service().await.and_then(|_| {
227 self.load_info()
228 });
229 }
230
231 let info = if let Some(ref callback) = self.inner.on_recover {
232 if let Some(old) = old_info {
233 callback(old).await?
234 } else if let Some(ref start) = self.inner.on_first_acquire {
235 start().await?
236 } else {
237 self.inner.ref_counter.unlock()?;
238 return Err(Error::ServiceRecovery(
239 "No recovery or startup callback".to_string(),
240 ));
241 }
242 } else if let Some(ref start) = self.inner.on_first_acquire {
243 if let Some(old) = old_info {
245 crate::process::stop(old.pid(), 5000);
246 }
247 start().await?
248 } else {
249 self.inner.ref_counter.unlock()?;
250 return Err(Error::ServiceRecovery(
251 "No recovery or startup callback".to_string(),
252 ));
253 };
254
255 self.save_info(&info)?;
256 *self.inner.current_info.write() = Some(info.clone());
257 self.inner.ref_counter.unlock()?;
258
259 Ok(info)
260 }
261
262 fn save_info(&self, info: &ServiceInfo) -> Result<()> {
264 if let Some(parent) = self.inner.info_path.parent() {
265 std::fs::create_dir_all(parent)?;
266 }
267
268 let file_info = ServiceInfoFile::from(info);
269 let content = serde_json::to_string_pretty(&file_info)
270 .map_err(|e| Error::ServiceInfo(format!("Serialization failed: {}", e)))?;
271
272 std::fs::write(&self.inner.info_path, content)?;
273 Ok(())
274 }
275
276 fn load_info(&self) -> Result<ServiceInfo> {
278 let content = std::fs::read_to_string(&self.inner.info_path)?;
279 let file_info: ServiceInfoFile = serde_json::from_str(&content)
280 .map_err(|e| Error::ServiceInfo(format!("Deserialization failed: {}", e)))?;
281 Ok(ServiceInfo::from(file_info))
282 }
283
284 pub fn count(&self) -> Result<u32> {
286 self.inner.ref_counter.count()
287 }
288
289 pub fn name(&self) -> &str {
291 &self.inner.name
292 }
293}
294
295pub struct SharedServiceBuilder {
297 name: String,
298 base_dir: Option<PathBuf>,
299 on_first_acquire: Option<OnFirstAcquire>,
300 on_last_release: Option<OnLastRelease>,
301 on_health_check: Option<OnHealthCheck>,
302 on_recover: Option<OnRecover>,
303}
304
305impl SharedServiceBuilder {
306 pub fn new(name: &str) -> Self {
308 Self {
309 name: name.to_string(),
310 base_dir: None,
311 on_first_acquire: None,
312 on_last_release: None,
313 on_health_check: None,
314 on_recover: None,
315 }
316 }
317
318 pub fn base_dir(mut self, dir: impl Into<PathBuf>) -> Self {
320 self.base_dir = Some(dir.into());
321 self
322 }
323
324 pub fn on_first_acquire<F, Fut>(mut self, f: F) -> Self
329 where
330 F: Fn() -> Fut + Send + Sync + 'static,
331 Fut: Future<Output = Result<ServiceInfo>> + Send + 'static,
332 {
333 self.on_first_acquire = Some(Box::new(move || Box::pin(f())));
334 self
335 }
336
337 pub fn on_last_release<F, Fut>(mut self, f: F) -> Self
342 where
343 F: Fn(ServiceInfo) -> Fut + Send + Sync + 'static,
344 Fut: Future<Output = Result<()>> + Send + 'static,
345 {
346 self.on_last_release = Some(Box::new(move |info| Box::pin(f(info))));
347 self
348 }
349
350 pub fn on_health_check<F, Fut>(mut self, f: F) -> Self
355 where
356 F: Fn(&ServiceInfo) -> Fut + Send + Sync + 'static,
357 Fut: Future<Output = bool> + Send + 'static,
358 {
359 let f = Arc::new(f);
360 self.on_health_check = Some(Box::new(move |info| {
361 let info = info.clone();
362 let f = Arc::clone(&f);
363 Box::pin(async move { f(&info).await })
364 }));
365 self
366 }
367
368 pub fn on_recover<F, Fut>(mut self, f: F) -> Self
373 where
374 F: Fn(ServiceInfo) -> Fut + Send + Sync + 'static,
375 Fut: Future<Output = Result<ServiceInfo>> + Send + 'static,
376 {
377 self.on_recover = Some(Box::new(move |info| Box::pin(f(info))));
378 self
379 }
380
381 pub fn build(self) -> Result<SharedService> {
383 let base_dir = self.base_dir.unwrap_or_else(|| {
384 dirs::home_dir()
385 .unwrap_or_else(|| PathBuf::from("."))
386 .join(".procref")
387 });
388
389 std::fs::create_dir_all(&base_dir)?;
390
391 let info_path = base_dir.join(format!("{}.json", self.name));
392 let ref_counter = PlatformRefCounter::new(&self.name)?;
393
394 let inner = SharedServiceInner {
395 name: self.name,
396 ref_counter,
397 info_path,
398 current_info: RwLock::new(None),
399 on_first_acquire: self.on_first_acquire,
400 on_last_release: self.on_last_release,
401 on_health_check: self.on_health_check,
402 on_recover: self.on_recover,
403 };
404
405 Ok(SharedService {
406 inner: Arc::new(inner),
407 })
408 }
409}