use super::super::ScopeFileSystem;
use crate::{Error, Result};
/// File name under which the state lock is stored in the scope file system.
const STATE_LOCK_FILE: &str = "state.lock";

/// Record of which process holds the state lock.
///
/// `save` persists it as the PID in big-endian order followed by the raw UTF-8
/// bytes of `process_name`; `load` reads that layout back.
#[derive(Eq, PartialEq, Debug)]
pub struct StateLock {
    /// PID of the process that wrote the lock.
    pid: u32,
    /// Name of that process at the time the lock was created. It is compared
    /// against the live process name to detect PID reuse.
    process_name: String,
}
impl std::fmt::Display for StateLock {
    /// Renders the lock as `<pid>:<process_name>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { pid, process_name } = self;
        write!(f, "{pid}:{process_name}")
    }
}
impl Default for StateLock {
    /// Builds a lock describing the current process.
    ///
    /// If the process name cannot be looked up, the literal name `"unknown"` is
    /// recorded instead.
    fn default() -> Self {
        let current_pid = std::process::id();
        let process_name = match get_process_name(current_pid) {
            Some(name) => name,
            None => "unknown".to_string(),
        };
        Self {
            pid: current_pid,
            process_name,
        }
    }
}
impl StateLock {
pub async fn load(fs: &ScopeFileSystem) -> Result<Self> {
let mut reader = fs.stream_read(STATE_LOCK_FILE).await?;
let pid_bytes = reader.read(4).await.map_err(|e| {
Error::InvalidFormat(format!("Failed to read PID from '{STATE_LOCK_FILE}': {e}"))
})?;
if pid_bytes.len() != 4 {
return Err(Error::InvalidFormat(format!(
"Invalid PID in '{}': expected 4 bytes, got {}",
STATE_LOCK_FILE,
pid_bytes.len()
)));
}
let pid = u32::from_be_bytes([pid_bytes[0], pid_bytes[1], pid_bytes[2], pid_bytes[3]]);
let name_bytes = reader.read_to_end().await.map_err(|e| {
Error::InvalidFormat(format!(
"Failed to read process name from '{STATE_LOCK_FILE}': {e}"
))
})?;
let process_name = String::from_utf8(name_bytes).map_err(|e| {
Error::InvalidFormat(format!(
"Invalid UTF-8 in process name in '{STATE_LOCK_FILE}': {e}"
))
})?;
Ok(Self { pid, process_name })
}
pub async fn save(&self, fs: &ScopeFileSystem) -> Result<()> {
let mut writer = fs.stream_write(STATE_LOCK_FILE).await?;
writer.write(&self.pid.to_be_bytes()).await?;
writer.write(self.process_name.as_bytes()).await?;
writer.flush().await?;
Ok(())
}
pub async fn remove(fs: &ScopeFileSystem) -> Result<()> {
fs.remove_file(STATE_LOCK_FILE).await
}
pub fn is_running(&self) -> bool {
let Some(actual_name) = get_process_name(self.pid) else {
return false;
};
actual_name == self.process_name
}
pub fn is_current(&self) -> bool {
self.pid == std::process::id()
}
}
/// Returns the command name that `ps` reports for `pid`.
///
/// Yields `None` if `ps` cannot be spawned, exits unsuccessfully, or prints only
/// whitespace. Surrounding whitespace is trimmed from the reported name.
#[cfg(unix)]
fn get_process_name(pid: u32) -> Option<String> {
    let pid_arg = pid.to_string();
    // An empty header (`comm=`) asks `ps` to omit the header line, so stdout
    // should contain only the command name.
    let output = std::process::Command::new("ps")
        .args(["-p", pid_arg.as_str(), "-o", "comm="])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    let name = String::from_utf8_lossy(&output.stdout).trim().to_owned();
    (!name.is_empty()).then_some(name)
}
/// Looks up the image name of the process with the given `pid` using `tasklist`.
///
/// Returns `None` in these cases:
/// - `tasklist` cannot be spawned or exits unsuccessfully;
/// - no output row has a PID field exactly equal to `pid`.
#[cfg(windows)]
fn get_process_name(pid: u32) -> Option<String> {
    use std::process::Command;
    let output = Command::new("tasklist")
        .arg("/FI")
        .arg(format!("PID eq {}", pid))
        .arg("/FO")
        .arg("CSV")
        .arg("/NH")
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    parse_tasklist_csv(&String::from_utf8_lossy(&output.stdout), pid)
}

/// Extracts the image name for `pid` from `tasklist /FO CSV /NH` output.
///
/// Expects rows of the form `"name.exe","1234",...`, where the first two quoted
/// fields are the image name and the PID.
///
/// Matching rules:
/// - Rows match only on the exact PID field, so PID 12 never matches a row for
///   PID 123.
/// - Surrounding whitespace is trimmed from each line.
/// - Lines that do not start with a quote are skipped, e.g. blank lines or an
///   informational message.
// Pure string parsing, so it is compiled on every platform to stay unit-testable;
// only the Windows `get_process_name` calls it.
#[cfg_attr(not(windows), allow(dead_code))]
fn parse_tasklist_csv(stdout: &str, pid: u32) -> Option<String> {
    let pid = pid.to_string();
    stdout.lines().find_map(|line| {
        let mut fields = line.trim().strip_prefix('"')?.split("\",\"");
        let name = fields.next()?;
        // The PID is the last field when a row has only two columns, so drop a
        // trailing closing quote if present.
        let row_pid = fields.next()?.trim_end_matches('"');
        (row_pid == pid.as_str() && !name.is_empty()).then(|| name.to_string())
    })
}
/// Fallback for platforms with no supported process-listing tool.
///
/// Process names are never available here, so `StateLock::is_running` always
/// returns `false` on these targets.
// `not(any(unix, windows))` instead of `target_family = "wasm"`: Emscripten targets
// belong to both the `unix` and `wasm` families, which previously produced a
// duplicate definition. This form also gives every other non-unix, non-windows
// target a definition.
#[cfg(not(any(unix, windows)))]
fn get_process_name(_pid: u32) -> Option<String> {
    None
}
#[cfg(test)]
mod tests {
    use super::{Result, ScopeFileSystem, StateLock};

    /// Round-trips a lock for the current process through an in-memory scope
    /// file system.
    #[tokio::test]
    #[cfg_attr(miri, ignore)]
    async fn test_state_lock() -> Result<()> {
        let scope_fs = ScopeFileSystem::new_memory_fs("/bucket1".into());
        scope_fs.ensure_exist().await?;

        // Nothing has been saved yet, so loading must fail.
        let missing = StateLock::load(&scope_fs).await;
        assert!(missing.is_err());

        let original = StateLock::default();
        assert!(original.is_running());
        assert!(original.is_current());

        original.save(&scope_fs).await?;
        let reloaded = StateLock::load(&scope_fs).await?;
        assert_eq!(original, reloaded);

        Ok(())
    }
}