use std::path::{Path, PathBuf};
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use crate::client::{parse_response, DEFAULT_SOCKET_PATH, DEFAULT_TIMEOUT};
use crate::error::{GetMyIdError, Result};
use crate::types::{Identity, RunnerRequest};
/// Async client that requests an `Identity` over a Unix domain socket.
///
/// Use `AsyncClient::new` for the default socket path and timeout, or
/// `AsyncClient::builder` to customise them.
#[derive(Clone, Debug)]
pub struct AsyncClient {
    /// Filesystem path of the Unix domain socket to connect to.
    socket_path: PathBuf,
    /// Limit applied to a whole request/response exchange; `None` means no limit.
    timeout: Option<Duration>,
}
impl Default for AsyncClient {
fn default() -> Self {
Self::new()
}
}
impl AsyncClient {
    /// Creates a client using `DEFAULT_SOCKET_PATH` and a timeout of `DEFAULT_TIMEOUT`.
    pub fn new() -> Self {
        Self {
            socket_path: PathBuf::from(DEFAULT_SOCKET_PATH),
            timeout: Some(DEFAULT_TIMEOUT),
        }
    }

    /// Returns a builder for configuring the socket path and timeout.
    pub fn builder() -> AsyncClientBuilder {
        AsyncClientBuilder::new()
    }

    /// Fetches the identity without sending a runner request.
    ///
    /// Shorthand for `get_identity_with_runner(None)`.
    ///
    /// # Errors
    ///
    /// Same as [`AsyncClient::get_identity_with_runner`].
    pub async fn get_identity(&self) -> Result<Identity> {
        self.get_identity_with_runner(None).await
    }

    /// Connects to the socket, optionally sends a runner request, and parses the reply.
    ///
    /// When `runner` is `Some`, the JSON object `{"runner": <runner>}` is written and the
    /// write half of the connection is shut down. When it is `None`, nothing is written.
    /// In both cases the reply is read until the peer closes the connection and is then
    /// handed to `parse_response`.
    ///
    /// If a timeout is configured, it bounds the whole exchange (connect, write, read and
    /// parse) rather than each step individually.
    ///
    /// # Errors
    ///
    /// * `GetMyIdError::SocketNotFound` — connecting reported that the socket path does
    ///   not exist (`io::ErrorKind::NotFound`).
    /// * `GetMyIdError::ConnectionFailed` — connecting failed for any other reason.
    /// * `GetMyIdError::InvalidJson` — the runner request could not be serialized.
    /// * `GetMyIdError::WriteError` / `GetMyIdError::ReadError` — socket I/O failed.
    /// * `GetMyIdError::Timeout` — the configured timeout elapsed.
    /// * Any error produced by `parse_response` for the received reply.
    pub async fn get_identity_with_runner(
        &self,
        runner: Option<RunnerRequest>,
    ) -> Result<Identity> {
        let exchange = async {
            // Let `connect` report a missing socket instead of stat-ing the path up front.
            // A `Path::exists` pre-check blocks the async executor, is racy (the socket can
            // change between the check and the connect), and returns `false` on permission
            // errors, which would be misreported as "socket not found".
            let mut stream = UnixStream::connect(&self.socket_path)
                .await
                .map_err(|e| {
                    if e.kind() == std::io::ErrorKind::NotFound {
                        GetMyIdError::SocketNotFound(self.socket_path.clone())
                    } else {
                        GetMyIdError::ConnectionFailed {
                            path: self.socket_path.clone(),
                            source: e,
                        }
                    }
                })?;

            if let Some(runner_req) = &runner {
                let request = serde_json::json!({ "runner": runner_req });
                let request_str =
                    serde_json::to_string(&request).map_err(GetMyIdError::InvalidJson)?;
                stream
                    .write_all(request_str.as_bytes())
                    .await
                    .map_err(GetMyIdError::WriteError)?;
                stream.flush().await.map_err(GetMyIdError::WriteError)?;
                // Half-close so the peer sees end-of-request. This is deliberately
                // best-effort: a shutdown failure is ignored rather than reported.
                stream.shutdown().await.ok();
            }

            // NOTE(review): the reply is read without a size cap; only the optional
            // timeout bounds how long (and therefore how much) is read.
            let mut response = String::new();
            stream
                .read_to_string(&mut response)
                .await
                .map_err(GetMyIdError::ReadError)?;
            parse_response(&response)
        };

        match self.timeout {
            Some(limit) => tokio::time::timeout(limit, exchange)
                .await
                .map_err(|_| GetMyIdError::Timeout(limit))?,
            None => exchange.await,
        }
    }

    /// Path of the Unix domain socket this client connects to.
    pub fn socket_path(&self) -> &Path {
        &self.socket_path
    }

    /// Timeout applied to each exchange, or `None` if exchanges are unbounded.
    pub fn timeout(&self) -> Option<Duration> {
        self.timeout
    }
}
/// Builder for `AsyncClient`, starting from the default socket path and timeout.
#[derive(Clone, Debug)]
pub struct AsyncClientBuilder {
    /// Socket path that the built client will connect to.
    socket_path: PathBuf,
    /// Timeout that the built client will apply; `None` disables it.
    timeout: Option<Duration>,
}
impl Default for AsyncClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl AsyncClientBuilder {
    /// Starts from `DEFAULT_SOCKET_PATH` and a timeout of `Some(DEFAULT_TIMEOUT)`.
    pub fn new() -> Self {
        let socket_path = PathBuf::from(DEFAULT_SOCKET_PATH);
        let timeout = Some(DEFAULT_TIMEOUT);
        Self { socket_path, timeout }
    }

    /// Overrides the Unix socket path that the built client will connect to.
    pub fn socket_path<P: AsRef<Path>>(self, path: P) -> Self {
        Self {
            socket_path: path.as_ref().to_owned(),
            ..self
        }
    }

    /// Sets the timeout. Accepts either a `Duration` or an `Option<Duration>`;
    /// passing `None` disables the timeout.
    pub fn timeout(self, timeout: impl Into<Option<Duration>>) -> Self {
        Self {
            timeout: timeout.into(),
            ..self
        }
    }

    /// Consumes the builder and returns an `AsyncClient` with the configured settings.
    pub fn build(self) -> AsyncClient {
        let Self { socket_path, timeout } = self;
        AsyncClient { socket_path, timeout }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_async_client_builder() {
        let wanted_path = Path::new("/tmp/test.sock");
        let wanted_timeout = Duration::from_secs(10);

        let client = AsyncClient::builder()
            .socket_path(wanted_path)
            .timeout(wanted_timeout)
            .build();

        assert_eq!(client.socket_path(), wanted_path);
        assert_eq!(client.timeout(), Some(wanted_timeout));
    }

    #[test]
    fn test_async_client_builder_no_timeout() {
        let client = AsyncClient::builder().timeout(None).build();
        assert!(client.timeout().is_none());
    }

    #[test]
    fn test_default_async_client() {
        let AsyncClient { socket_path, timeout } = AsyncClient::new();
        assert_eq!(socket_path, PathBuf::from(DEFAULT_SOCKET_PATH));
        assert_eq!(timeout, Some(DEFAULT_TIMEOUT));
    }
}