use anyhow::{bail, Result};
use cloudpub_client::commands::PublishArgs;
pub use cloudpub_client::config::ClientConfig;
use cloudpub_client::ping;
use cloudpub_client::shell::get_cache_dir;
use cloudpub_common::logging::WorkerGuard;
use cloudpub_common::protocol::message::Message;
use cloudpub_common::protocol::{
Break, EndpointClear, EndpointList, EndpointRemove, EndpointStart, EndpointStartAll,
EndpointStop, ServerEndpoint,
};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, mpsc, Mutex};
use tokio::time::{sleep, timeout};
use tracing::debug;
use crate::builder::ConnectionBuilder;
/// Callback polled periodically by `Connection::wait_for_event` while no event
/// has arrived; returning an `Err` aborts the wait with that error.
pub type CheckSignalFn = Arc<dyn Fn() -> Result<()> + Sync + Send>;
/// An event observed on the connection, derived from a message reported by the client.
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionEvent {
    /// Idle state. May be unused in this crate, hence the lint allowance.
    #[allow(dead_code)]
    Idle,
    /// Authentication is in progress.
    Authenticating,
    /// The client reported `ConnectState::Connected`.
    Connected,
    /// The server acknowledged an endpoint (`EndpointAck`).
    Endpoint(Box<ServerEndpoint>),
    /// The server returned its endpoint list (`EndpointListAck`).
    List(Vec<ServerEndpoint>),
    /// A stop, remove or clear request was acknowledged.
    Acknowledged,
    /// The server reported an error with the given message.
    Error(String),
    /// The connection was closed (disconnect state or `Break`).
    Closed,
}
/// Handle to a cloudpub client session.
///
/// Protocol commands are sent through `command_tx`. A background task spawned in
/// `Connection::new` converts the client's replies into [`ConnectionEvent`]s and
/// broadcasts them through `event_tx`, where `wait_for_event` picks them up.
///
/// NOTE: struct fields are dropped in declaration order after `Drop::drop` runs,
/// so reordering these fields changes teardown order.
pub struct Connection {
/// Shared client configuration, read and written by the config accessors.
pub config: Arc<RwLock<ClientConfig>>,
/// Sender for protocol messages addressed to the client.
pub command_tx: mpsc::Sender<Message>,
/// Broadcast sender for connection events; each waiter subscribes to it.
pub(crate) event_tx: broadcast::Sender<ConnectionEvent>,
/// Background task translating client messages into events; aborted on drop.
pub(crate) receiver_handle: Option<tokio::task::JoinHandle<()>>,
/// Maximum time `wait_for_event` waits for a matching event.
pub(crate) timeout: Arc<Mutex<Duration>>,
/// Optional callback polled while waiting; an `Err` aborts the wait.
pub(crate) check_signal_fn: Option<CheckSignalFn>,
/// Logging guard held for the lifetime of the connection; never read directly.
/// NOTE(review): presumably flushes buffered logs when dropped — confirm.
pub(crate) _worker_guard: WorkerGuard,
/// Task running the client itself, if any; aborted on drop.
pub(crate) client_handle: Option<tokio::task::JoinHandle<()>>,
}
impl Connection {
    /// How often the optional signal-check callback is polled while a wait is idle.
    const SIGNAL_POLL_INTERVAL: Duration = Duration::from_millis(100);

    /// Capacity of the broadcast channel that fans events out to waiters.
    const EVENT_CHANNEL_CAPACITY: usize = 100;

    /// Returns a builder used to configure and open a [`Connection`].
    pub fn builder() -> ConnectionBuilder {
        ConnectionBuilder::new()
    }

    /// Wraps an already-running client in a `Connection`.
    ///
    /// Spawns a background task that reads client messages from `result_rx`,
    /// translates the relevant ones into [`ConnectionEvent`]s and broadcasts them
    /// to every active waiter. When the client reports `ConnectState::Connected`,
    /// the task also asks it to start all endpoints.
    ///
    /// Must be called from within a Tokio runtime (uses `tokio::spawn`).
    pub(crate) fn new(
        config: Arc<RwLock<ClientConfig>>,
        command_tx: mpsc::Sender<Message>,
        mut result_rx: mpsc::Receiver<Message>,
        timeout: Duration,
        check_signal_fn: Option<CheckSignalFn>,
        worker_guard: WorkerGuard,
        client_handle: Option<tokio::task::JoinHandle<()>>,
    ) -> Self {
        let (event_tx, _) = broadcast::channel(Self::EVENT_CHANNEL_CAPACITY);
        let event_tx_clone = event_tx.clone();
        let command_tx_clone = command_tx.clone();
        let receiver_handle = tokio::spawn(async move {
            while let Some(msg) = result_rx.recv().await {
                debug!("Received message: {:?}", msg);
                let new_event = match msg {
                    Message::ConnectState(st) => {
                        if st == cloudpub_common::protocol::ConnectState::Connected as i32 {
                            // Re-activate every registered endpoint once (re)connected.
                            command_tx_clone
                                .send(Message::EndpointStartAll(EndpointStartAll {}))
                                .await
                                .ok();
                            ConnectionEvent::Connected
                        } else if st
                            == cloudpub_common::protocol::ConnectState::Disconnected as i32
                        {
                            ConnectionEvent::Closed
                        } else {
                            continue;
                        }
                    }
                    Message::EndpointAck(endpoint) => {
                        ConnectionEvent::Endpoint(Box::new(endpoint))
                    }
                    Message::EndpointListAck(list) => ConnectionEvent::List(list.endpoints),
                    Message::EndpointStopAck(_)
                    | Message::EndpointRemoveAck(_)
                    | Message::EndpointClearAck(_) => ConnectionEvent::Acknowledged,
                    Message::Error(err) => ConnectionEvent::Error(err.message),
                    Message::Break(_) => ConnectionEvent::Closed,
                    _ => continue,
                };
                debug!("New event: {:?}", new_event);
                // A send error only means nobody is currently waiting; that is fine.
                let _ = event_tx_clone.send(new_event);
            }
        });
        Connection {
            config,
            command_tx,
            event_tx,
            receiver_handle: Some(receiver_handle),
            timeout: Arc::new(Mutex::new(timeout)),
            check_signal_fn,
            _worker_guard: worker_guard,
            client_handle,
        }
    }

    /// Waits for the next event (broadcast after this call subscribes) for which
    /// `target_event` returns `true`.
    ///
    /// # Errors
    /// Fails if an `Error` or `Closed` event arrives first, if the signal-check
    /// callback returns an error, or if the configured timeout elapses.
    pub async fn wait_for_event(
        &self,
        target_event: impl Fn(&ConnectionEvent) -> bool + Send,
    ) -> Result<ConnectionEvent> {
        let event_rx = self.event_tx.subscribe();
        self.wait_with_receiver(event_rx, target_event).await
    }

    /// Core of [`Connection::wait_for_event`], reading from an existing subscription.
    ///
    /// Taking the receiver as a parameter lets callers subscribe *before* sending a
    /// request, so a reply broadcast in the meantime is not missed.
    async fn wait_with_receiver(
        &self,
        mut event_rx: broadcast::Receiver<ConnectionEvent>,
        target_event: impl Fn(&ConnectionEvent) -> bool + Send,
    ) -> Result<ConnectionEvent> {
        let timeout_duration = *self.timeout.lock().await;
        let check_signal = self.check_signal_fn.clone();
        timeout(timeout_duration, async {
            loop {
                let received = if let Some(ref check_fn) = check_signal {
                    tokio::select! {
                        res = event_rx.recv() => res,
                        _ = sleep(Self::SIGNAL_POLL_INTERVAL) => {
                            check_fn()?;
                            continue;
                        }
                    }
                } else {
                    event_rx.recv().await
                };
                // Both paths share this handling so that a closed channel fails fast
                // instead of idling until the timeout.
                let current_event = match received {
                    Ok(event) => {
                        debug!("Received event: {:?}", event);
                        event
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        bail!("Канал событий закрыт");
                    }
                    // The buffer overflowed and older events were dropped; keep
                    // reading from the oldest one still retained.
                    Err(broadcast::error::RecvError::Lagged(_)) => continue,
                };
                if let ConnectionEvent::Error(ref msg) = current_event {
                    bail!("Операция не удалась: {}", msg);
                }
                if current_event == ConnectionEvent::Closed {
                    bail!("Соединение закрыто");
                }
                if target_event(&current_event) {
                    return Ok(current_event);
                }
                debug!("Событие не соответствует целевому, продолжаем...");
            }
        })
        .await
        .map_err(|_| anyhow::anyhow!("Таймаут ожидания события"))?
    }

    /// Sends `message` to the client and waits for the first matching event.
    ///
    /// Subscribes before sending: a broadcast receiver only sees values sent after
    /// it subscribed, so subscribing afterwards could lose a fast reply.
    async fn send_and_wait(
        &self,
        message: Message,
        target_event: impl Fn(&ConnectionEvent) -> bool + Send,
    ) -> Result<ConnectionEvent> {
        let event_rx = self.event_tx.subscribe();
        self.command_tx.send(message).await?;
        self.wait_with_receiver(event_rx, target_event).await
    }

    /// Updates configuration option `key` to `value` via `ClientConfig::set`.
    pub fn set(&self, key: &str, value: &str) -> Result<()> {
        self.config.write().set(key, value)?;
        Ok(())
    }

    /// Reads configuration option `key` via `ClientConfig::get`.
    pub fn get(&self, key: &str) -> Result<String> {
        self.config.read().get(key)
    }

    /// Returns all configuration options as a key/value map.
    pub fn options(&self) -> HashMap<String, String> {
        self.config.read().get_all_options().into_iter().collect()
    }

    /// Clears the stored authentication token and saves the configuration.
    pub fn logout(&self) -> Result<()> {
        // Hold one write guard so no other writer can interleave between
        // clearing the token and persisting the config.
        let mut config = self.config.write();
        config.token = None;
        config.save()?;
        Ok(())
    }

    /// Sends an `EndpointStart` request built from the arguments and returns the
    /// endpoint acknowledged by the server.
    ///
    /// # Errors
    /// Fails if the arguments cannot be parsed, the command channel is closed, or
    /// the wait fails (server error, closed connection, timeout).
    pub async fn register(
        &mut self,
        protocol: cloudpub_common::protocol::Protocol,
        address: String,
        name: Option<String>,
        auth: Option<cloudpub_common::protocol::Auth>,
        acl: Option<Vec<cloudpub_common::protocol::Acl>>,
        headers: Option<Vec<cloudpub_common::protocol::Header>>,
        rules: Option<Vec<cloudpub_common::protocol::FilterRule>>,
    ) -> Result<cloudpub_common::protocol::ServerEndpoint> {
        let publish_args = PublishArgs {
            protocol,
            address,
            username: None,
            password: None,
            name,
            auth,
            acl: acl.unwrap_or_default(),
            headers: headers.unwrap_or_default(),
            rules: rules.unwrap_or_default(),
        };
        let endpoint_start = publish_args.parse()?;
        match self
            .send_and_wait(Message::EndpointStart(endpoint_start), |event| {
                matches!(event, ConnectionEvent::Endpoint(_))
            })
            .await?
        {
            ConnectionEvent::Endpoint(endpoint) => Ok(*endpoint),
            _ => bail!("Unexpected state"),
        }
    }

    /// Registers an endpoint (see [`Connection::register`]) and then starts it by guid.
    pub async fn publish(
        &mut self,
        protocol: cloudpub_common::protocol::Protocol,
        address: String,
        name: Option<String>,
        auth: Option<cloudpub_common::protocol::Auth>,
        acl: Option<Vec<cloudpub_common::protocol::Acl>>,
        headers: Option<Vec<cloudpub_common::protocol::Header>>,
        rules: Option<Vec<cloudpub_common::protocol::FilterRule>>,
    ) -> Result<cloudpub_common::protocol::ServerEndpoint> {
        let endpoint = self
            .register(protocol, address, name, auth, acl, headers, rules)
            .await?;
        self.start(endpoint.guid.clone()).await?;
        Ok(endpoint)
    }

    /// Requests and returns the list of endpoints known to the server.
    pub async fn ls(&mut self) -> Result<Vec<cloudpub_common::protocol::ServerEndpoint>> {
        match self
            .send_and_wait(Message::EndpointList(EndpointList {}), |event| {
                matches!(event, ConnectionEvent::List(_))
            })
            .await?
        {
            ConnectionEvent::List(list) => Ok(list),
            _ => bail!("Unexpected state"),
        }
    }

    /// Starts the endpoint identified by `guid` and waits for an endpoint ack.
    ///
    /// NOTE(review): this matches any `Endpoint` event, not necessarily the one
    /// for `guid`; consider filtering on `guid` once the server reply is confirmed.
    pub async fn start(&mut self, guid: String) -> Result<()> {
        self.send_and_wait(Message::EndpointGuidStart(EndpointStart { guid }), |event| {
            matches!(event, ConnectionEvent::Endpoint(_))
        })
        .await?;
        Ok(())
    }

    /// Stops the endpoint identified by `guid` and waits for acknowledgement.
    pub async fn stop(&mut self, guid: String) -> Result<()> {
        self.send_and_wait(Message::EndpointStop(EndpointStop { guid }), |event| {
            matches!(event, ConnectionEvent::Acknowledged)
        })
        .await?;
        Ok(())
    }

    /// Removes the endpoint identified by `guid` and waits for acknowledgement.
    pub async fn unpublish(&mut self, guid: String) -> Result<()> {
        self.send_and_wait(Message::EndpointRemove(EndpointRemove { guid }), |event| {
            matches!(event, ConnectionEvent::Acknowledged)
        })
        .await?;
        Ok(())
    }

    /// Sends `EndpointClear` and waits for acknowledgement.
    pub async fn clean(&mut self) -> Result<()> {
        self.send_and_wait(Message::EndpointClear(EndpointClear {}), |event| {
            matches!(event, ConnectionEvent::Acknowledged)
        })
        .await?;
        Ok(())
    }

    /// Publishes the ping endpoint via `ping::publish`, waits for its ack, runs
    /// `ping::ping_test` against it and returns the result parsed as `u64`.
    ///
    /// NOTE(review): the unit of the returned value is defined by `ping_test` — confirm.
    pub async fn ping(&mut self) -> Result<u64> {
        // Subscribe before publishing so the endpoint ack cannot be missed.
        let event_rx = self.event_tx.subscribe();
        ping::publish(self.command_tx.clone()).await?;
        let endpoint = match self
            .wait_with_receiver(event_rx, |event| {
                matches!(event, ConnectionEvent::Endpoint(_))
            })
            .await?
        {
            ConnectionEvent::Endpoint(ep) => *ep,
            _ => bail!("Unexpected state"),
        };
        Ok(ping::ping_test(endpoint, true).await?.parse::<u64>()?)
    }

    /// Deletes the directory returned by `get_cache_dir("")`.
    ///
    /// Only failure to resolve the path is reported. Failure to delete it (for
    /// example because it does not exist) is deliberately ignored.
    pub fn purge(&self) -> Result<()> {
        let cache_dir = get_cache_dir("")?;
        std::fs::remove_dir_all(&cache_dir).ok();
        Ok(())
    }
}
impl Drop for Connection {
    /// Tears down the connection on a best-effort basis.
    ///
    /// It first asks the client to stop by queueing a `Break` with an empty guid.
    /// A full or closed channel is ignored. It then aborts the event-receiver task,
    /// followed by the client task, if they are still owned.
    fn drop(&mut self) {
        let shutdown = Message::Break(Break {
            guid: String::new(),
        });
        let _ = self.command_tx.try_send(shutdown);

        let tasks = self
            .receiver_handle
            .take()
            .into_iter()
            .chain(self.client_handle.take());
        for task in tasks {
            task.abort();
        }
    }
}