use crate::types::ServiceInfoFile;
use crate::{Error, PlatformRefCounter, RefCounter, Result, ServiceInfo};
use parking_lot::RwLock;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
/// Pinned, heap-allocated, `Send` future that may borrow for `'a`.
///
/// Every callback type below returns one of these with `'a = 'static`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Starts the service when it must be (re)started and reports its [`ServiceInfo`].
pub type OnFirstAcquire =
    Box<dyn Fn() -> BoxFuture<'static, Result<ServiceInfo>> + Sync + Send>;

/// Callback registered for when the last reference is released.
///
/// NOTE(review): the release path in this file only checks whether one is
/// registered and never calls it — confirm intent.
pub type OnLastRelease =
    Box<dyn Fn(ServiceInfo) -> BoxFuture<'static, Result<()>> + Sync + Send>;

/// Reports whether an already-published service is healthy; when absent,
/// `ServiceInfo::is_alive` is used instead.
pub type OnHealthCheck =
    Box<dyn Fn(&ServiceInfo) -> BoxFuture<'static, bool> + Sync + Send>;

/// Replaces an unhealthy service, given its previous info, and reports the new info.
pub type OnRecover =
    Box<dyn Fn(ServiceInfo) -> BoxFuture<'static, Result<ServiceInfo>> + Sync + Send>;
/// A client's reference to the shared service.
///
/// Returned by `SharedService::acquire`; dropping it releases the reference
/// through `SharedServiceInner::release_sync` (see the `Drop` impl).
pub struct ServiceHandle {
/// Shared state, used on drop to release this handle's reference.
service: Arc<SharedServiceInner>,
/// Service info as resolved when this handle was acquired.
info: ServiceInfo,
}
impl ServiceHandle {
    /// Service info captured when this handle was acquired.
    pub fn info(&self) -> &ServiceInfo {
        &self.info
    }

    /// Port recorded in the service info.
    pub fn port(&self) -> u16 {
        self.info().port()
    }

    /// Process id recorded in the service info.
    pub fn pid(&self) -> u32 {
        self.info().pid()
    }
}
impl Drop for ServiceHandle {
fn drop(&mut self) {
let _ = self.service.release_sync();
}
}
/// State shared by a `SharedService` and every `ServiceHandle` it hands out.
struct SharedServiceInner {
/// Service name; also used to name the info file and to create the ref counter.
name: String,
/// Platform-specific reference counter / startup lock for this service name.
ref_counter: PlatformRefCounter,
/// JSON file where the running service's info is published for other clients.
info_path: PathBuf,
/// Info of the service instance this process last started or accepted as healthy;
/// read on last release to decide what to stop.
current_info: RwLock<Option<ServiceInfo>>,
/// Starts the service (also used as the restart path during recovery).
on_first_acquire: Option<OnFirstAcquire>,
/// Last-release callback; only its presence is checked in `release_sync`.
on_last_release: Option<OnLastRelease>,
/// Optional health check; `ServiceInfo::is_alive` is used when absent.
on_health_check: Option<OnHealthCheck>,
/// Optional recovery hook given the previous info of an unhealthy service.
on_recover: Option<OnRecover>,
}
impl SharedServiceInner {
    /// Release one reference and, if it was the last, stop the service.
    ///
    /// Called synchronously from `ServiceHandle`'s `Drop`. Stopping happens only
    /// when the shared count reaches zero and an `on_last_release` callback is
    /// registered. In that case the process recorded in `current_info` is stopped
    /// via `crate::process::stop` (second argument `5000`, presumably a timeout
    /// in ms — confirm), and the info file is removed on a best-effort basis.
    ///
    /// NOTE(review): the `on_last_release` callback itself is never invoked;
    /// only its presence is checked. It is async and cannot be awaited from
    /// `Drop`. Confirm whether this is intended.
    ///
    /// # Errors
    ///
    /// Returns the ref counter's error if releasing the reference fails.
    fn release_sync(&self) -> Result<()> {
        let count = self.ref_counter.release()?;
        if count != 0 || self.on_last_release.is_none() {
            return Ok(());
        }
        // Clone the info out in its own statement so the read guard is dropped
        // here. Inside an `if let` scrutinee the guard would stay alive, and the
        // lock held, across the possibly blocking stop call below.
        let info = self.current_info.read().clone();
        if let Some(info) = info {
            tracing::info!(
                "Last client released, service {} should be stopped",
                self.name
            );
            crate::process::stop(info.pid(), 5000);
            // Best effort: the file may already be gone.
            let _ = std::fs::remove_file(&self.info_path);
        }
        Ok(())
    }
}
/// Entry point for acquiring references to a service shared between clients.
///
/// Construct with `SharedService::builder`; call `acquire` to obtain a
/// `ServiceHandle`.
pub struct SharedService {
/// State shared with every handle this service hands out.
inner: Arc<SharedServiceInner>,
}
impl SharedService {
pub fn builder(name: &str) -> SharedServiceBuilder {
SharedServiceBuilder::new(name)
}
pub async fn acquire(&self) -> Result<ServiceHandle> {
let count = self.inner.ref_counter.acquire()?;
tracing::debug!("Acquired reference, count={}", count);
if count == 1 {
if self.inner.ref_counter.try_lock()? {
let info = self.start_service().await?;
self.inner.ref_counter.unlock()?;
return Ok(ServiceHandle {
service: Arc::clone(&self.inner),
info,
});
}
self.wait_for_service().await?;
}
let info = self.get_or_recover_service().await?;
Ok(ServiceHandle {
service: Arc::clone(&self.inner),
info,
})
}
async fn start_service(&self) -> Result<ServiceInfo> {
tracing::info!("Starting service {}", self.inner.name);
let info = if let Some(ref callback) = self.inner.on_first_acquire {
callback().await?
} else {
return Err(Error::ServiceStart(
"No on_first_acquire callback registered".to_string(),
));
};
self.save_info(&info)?;
*self.inner.current_info.write() = Some(info.clone());
Ok(info)
}
async fn wait_for_service(&self) -> Result<()> {
let start = std::time::Instant::now();
let timeout = std::time::Duration::from_secs(30);
while start.elapsed() < timeout {
if self.inner.info_path.exists() {
if let Ok(info) = self.load_info() {
if info.is_alive() {
return Ok(());
}
}
}
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
Err(Error::ServiceStart(
"Timeout waiting for service to start".to_string(),
))
}
async fn get_or_recover_service(&self) -> Result<ServiceInfo> {
let info = match self.load_info() {
Ok(info) => info,
Err(_) => {
return self.recover_service(None).await;
}
};
let is_healthy = if let Some(ref check) = self.inner.on_health_check {
check(&info).await
} else {
info.is_alive()
};
if is_healthy {
*self.inner.current_info.write() = Some(info.clone());
Ok(info)
} else {
self.recover_service(Some(info)).await
}
}
async fn recover_service(&self, old_info: Option<ServiceInfo>) -> Result<ServiceInfo> {
tracing::warn!("Service {} needs recovery", self.inner.name);
if !self.inner.ref_counter.try_lock()? {
return self.wait_for_service().await.and_then(|_| {
self.load_info()
});
}
let info = if let Some(ref callback) = self.inner.on_recover {
if let Some(old) = old_info {
callback(old).await?
} else if let Some(ref start) = self.inner.on_first_acquire {
start().await?
} else {
self.inner.ref_counter.unlock()?;
return Err(Error::ServiceRecovery(
"No recovery or startup callback".to_string(),
));
}
} else if let Some(ref start) = self.inner.on_first_acquire {
if let Some(old) = old_info {
crate::process::stop(old.pid(), 5000);
}
start().await?
} else {
self.inner.ref_counter.unlock()?;
return Err(Error::ServiceRecovery(
"No recovery or startup callback".to_string(),
));
};
self.save_info(&info)?;
*self.inner.current_info.write() = Some(info.clone());
self.inner.ref_counter.unlock()?;
Ok(info)
}
fn save_info(&self, info: &ServiceInfo) -> Result<()> {
if let Some(parent) = self.inner.info_path.parent() {
std::fs::create_dir_all(parent)?;
}
let file_info = ServiceInfoFile::from(info);
let content = serde_json::to_string_pretty(&file_info)
.map_err(|e| Error::ServiceInfo(format!("Serialization failed: {}", e)))?;
std::fs::write(&self.inner.info_path, content)?;
Ok(())
}
fn load_info(&self) -> Result<ServiceInfo> {
let content = std::fs::read_to_string(&self.inner.info_path)?;
let file_info: ServiceInfoFile = serde_json::from_str(&content)
.map_err(|e| Error::ServiceInfo(format!("Deserialization failed: {}", e)))?;
Ok(ServiceInfo::from(file_info))
}
pub fn count(&self) -> Result<u32> {
self.inner.ref_counter.count()
}
pub fn name(&self) -> &str {
&self.inner.name
}
}
/// Builder for `SharedService`; configure callbacks, then call `build`.
pub struct SharedServiceBuilder {
/// Service name (used for the info file name and the ref counter).
name: String,
/// Directory for the info file; `build` defaults it to `~/.procref`.
base_dir: Option<PathBuf>,
/// Start callback.
on_first_acquire: Option<OnFirstAcquire>,
/// Last-release callback (only its presence is checked at release time).
on_last_release: Option<OnLastRelease>,
/// Optional health check callback.
on_health_check: Option<OnHealthCheck>,
/// Optional recovery callback.
on_recover: Option<OnRecover>,
}
impl SharedServiceBuilder {
    /// Create a builder for the service called `name`, with no callbacks and
    /// the default base directory.
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            base_dir: None,
            on_first_acquire: None,
            on_last_release: None,
            on_health_check: None,
            on_recover: None,
        }
    }

    /// Directory in which `<name>.json` is stored (default: `~/.procref`,
    /// or `./.procref` when no home directory is known).
    pub fn base_dir(mut self, dir: impl Into<PathBuf>) -> Self {
        self.base_dir = Some(dir.into());
        self
    }

    /// Callback that starts the service and returns its info.
    ///
    /// Without it, `acquire` fails whenever the service has to be started, and
    /// recovery has no fallback when `on_recover` is missing or there is no old info.
    pub fn on_first_acquire<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<ServiceInfo>> + Send + 'static,
    {
        self.on_first_acquire = Some(Box::new(move || Box::pin(f())));
        self
    }

    /// Register a last-release callback.
    ///
    /// NOTE(review): the release path only checks that a callback is registered
    /// (which enables stopping the recorded process); it never calls it.
    pub fn on_last_release<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(ServiceInfo) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<()>> + Send + 'static,
    {
        self.on_last_release = Some(Box::new(move |info| Box::pin(f(info))));
        self
    }

    /// Callback deciding whether an already-published service is healthy;
    /// `ServiceInfo::is_alive` is used when none is registered.
    ///
    /// `Fut` is a single `'static` type for every borrow of the argument, so the
    /// returned future cannot hold on to `&ServiceInfo`. `f` is therefore called
    /// directly with the borrowed info, with no clone of the info and no `Arc`
    /// around `f`. As a consequence, `f` runs when the callback is invoked
    /// rather than on the future's first poll.
    pub fn on_health_check<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(&ServiceInfo) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = bool> + Send + 'static,
    {
        self.on_health_check = Some(Box::new(move |info| Box::pin(f(info))));
        self
    }

    /// Callback that replaces an unhealthy service given its previous info.
    ///
    /// Without it, recovery stops the old process and restarts the service
    /// through `on_first_acquire`.
    pub fn on_recover<F, Fut>(mut self, f: F) -> Self
    where
        F: Fn(ServiceInfo) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<ServiceInfo>> + Send + 'static,
    {
        self.on_recover = Some(Box::new(move |info| Box::pin(f(info))));
        self
    }

    /// Create the base directory and ref counter, and assemble the service.
    ///
    /// # Errors
    ///
    /// Fails if the base directory cannot be created or the platform ref
    /// counter cannot be constructed.
    pub fn build(self) -> Result<SharedService> {
        let base_dir = self.base_dir.unwrap_or_else(|| {
            dirs::home_dir()
                .unwrap_or_else(|| PathBuf::from("."))
                .join(".procref")
        });
        std::fs::create_dir_all(&base_dir)?;
        let info_path = base_dir.join(format!("{}.json", self.name));
        let ref_counter = PlatformRefCounter::new(&self.name)?;
        let inner = SharedServiceInner {
            name: self.name,
            ref_counter,
            info_path,
            current_info: RwLock::new(None),
            on_first_acquire: self.on_first_acquire,
            on_last_release: self.on_last_release,
            on_health_check: self.on_health_check,
            on_recover: self.on_recover,
        };
        Ok(SharedService {
            inner: Arc::new(inner),
        })
    }
}