use dotenvy::dotenv;
use reqwest::StatusCode;
use reqwest::blocking::Client as HttpClient;
use serde::Serialize;
use std::env;
use std::io;
/// Blocking HTTP client for a lockserver.
///
/// Holds the server address, the owner name sent with every acquire/release
/// request, and the shared secret sent in a request header. The `Debug`
/// implementation redacts the secret so it cannot leak into logs.
#[derive(Clone)]
pub struct LockserverClient {
    /// Server address, e.g. `127.0.0.1:8080`.
    addr: String,
    /// Identity sent as the `owner` field of lock requests.
    owner: String,
    /// Shared secret sent with each request.
    secret: String,
}

impl std::fmt::Debug for LockserverClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LockserverClient")
            .field("addr", &self.addr)
            .field("owner", &self.owner)
            // Never print the credential itself.
            .field("secret", &"<redacted>")
            .finish()
    }
}
/// How an acquire call reacts when the server reports the resource as already locked.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LockMode {
    /// Keep retrying until the lock is granted.
    Blocking,
    /// Give up immediately with `io::ErrorKind::WouldBlock`.
    NonBlocking,
}
impl LockserverClient {
pub fn new_with_env(
addr: Option<impl Into<String>>,
owner: Option<impl Into<String>>,
secret: Option<impl Into<String>>,
) -> Self {
let _ = dotenv();
let addr = addr
.map(|a| a.into())
.or_else(|| env::var("LOCKSERVER_ADDR").ok())
.unwrap_or_else(|| "127.0.0.1:8080".to_string());
let owner = owner
.map(|o| o.into())
.or_else(|| env::var("LOCKSERVER_OWNER").ok())
.unwrap_or_else(|| "default_owner".to_string());
let secret = secret
.map(|s| s.into())
.or_else(|| env::var("LOCKSERVER_SECRET").ok())
.unwrap_or_else(|| "changeme".to_string());
Self {
addr,
owner,
secret,
}
}
pub fn new(
addr: impl Into<String>,
owner: impl Into<String>,
secret: impl Into<String>,
) -> Self {
Self {
addr: addr.into(),
owner: owner.into(),
secret: secret.into(),
}
}
pub fn acquire(&self, resource: &str) -> io::Result<()> {
self.acquire_with_mode_and_expire(resource, LockMode::Blocking, None)
}
pub fn acquire_with_mode(&self, resource: &str, mode: LockMode) -> io::Result<()> {
self.acquire_with_mode_and_expire(resource, mode, None)
}
pub fn acquire_with_mode_and_expire(&self, resource: &str, mode: LockMode, expire: Option<u64>) -> io::Result<()> {
#[derive(Serialize)]
struct LockRequest<'a> {
resource: &'a str,
owner: &'a str,
expire: Option<u64>,
}
let client = HttpClient::new();
let url = format!("http://{}/acquire", self.addr);
let req = LockRequest {
resource,
owner: &self.owner,
expire,
};
loop {
let resp = client
.post(&url)
.header("X-LOCKSERVER-SECRET", &self.secret)
.json(&req)
.send();
match resp {
Ok(r) if r.status() == StatusCode::OK => return Ok(()),
Ok(r) if r.status() == StatusCode::CONFLICT => {
if mode == LockMode::NonBlocking {
return Err(io::Error::new(
io::ErrorKind::WouldBlock,
"Resource is locked",
));
} else {
std::thread::sleep(std::time::Duration::from_millis(200));
}
}
Ok(r) => {
return Err(io::Error::other(format!("HTTP error: {}", r.status())));
}
Err(e) => {
return Err(io::Error::other(format!("Request error: {}", e)));
}
}
}
}
pub fn release(&self, resource: &str) -> io::Result<()> {
#[derive(Serialize)]
struct LockRequest<'a> {
resource: &'a str,
owner: &'a str,
}
let client = HttpClient::new();
let url = format!("http://{}/release", self.addr);
let req = LockRequest {
resource,
owner: &self.owner,
};
let resp = client
.post(&url)
.header("X-LOCKSERVER-SECRET", &self.secret)
.json(&req)
.send();
match resp {
Ok(r) if r.status() == StatusCode::OK => Ok(()),
Ok(r) => Err(io::Error::other(format!("HTTP error: {}", r.status()))),
Err(e) => Err(io::Error::other(format!("Request error: {}", e))),
}
}
}
/// Runs `$block` while holding the lock on `$resource`, then releases it.
///
/// The block runs inside a closure, so `return` within it leaves only the
/// block, not the caller. Its value is the value of the whole macro. A
/// [`LockGuard`] releases the lock when the scope ends, including during
/// unwinding from a panic.
///
/// `$client` must evaluate to something usable as `&LockserverClient`, and
/// `$resource` to something usable as `&str`. Each argument is evaluated
/// exactly once.
///
/// Forms:
/// - `lock_scope!(client, resource, { ... })` waits until the lock is granted.
/// - `lock_scope!(client, resource, non_blocking, { ... })` fails at once if
///   the resource is already locked.
///
/// # Panics
/// Panics if the lock cannot be acquired. In the `non_blocking` form this
/// includes the case where the resource is already held.
#[macro_export]
macro_rules! lock_scope {
    ($client:expr, $resource:expr, $block:block) => {{
        // Bind once so side-effecting argument expressions are not evaluated twice.
        let client = $client;
        let resource = $resource;
        client.acquire(resource).expect("Failed to acquire lock");
        let _guard = $crate::LockGuard::new(client, resource);
        let result = (|| $block)();
        result
    }};
    ($client:expr, $resource:expr, non_blocking, $block:block) => {{
        let client = $client;
        let resource = $resource;
        // `$crate` resolves correctly inside this crate and under a renamed dependency.
        client
            .acquire_with_mode(resource, $crate::client::LockMode::NonBlocking)
            .expect("Failed to acquire lock (non-blocking)");
        let _guard = $crate::LockGuard::new(client, resource);
        let result = (|| $block)();
        result
    }};
}
/// Guard that releases a lock on `resource` through `client` when dropped.
///
/// Creating a guard does **not** acquire the lock. Acquire it first (for
/// example with `LockserverClient::acquire`), then wrap it so it is released
/// on every exit path, including early returns and unwinding.
#[must_use = "dropping a LockGuard immediately releases the lock"]
pub struct LockGuard<'a> {
    /// Client used to send the release request.
    client: &'a LockserverClient,
    /// Name of the locked resource.
    resource: &'a str,
}
impl<'a> LockGuard<'a> {
pub fn new(client: &'a LockserverClient, resource: &'a str) -> Self {
Self { client, resource }
}
}
impl<'a> Drop for LockGuard<'a> {
fn drop(&mut self) {
let _ = self.client.release(self.resource);
}
}