use crate::types::{
ErebusPackedRequest, ObjectCacheRequest, UrlCacheRequest, UrlCacheResponse, HANDSHAKE_TOKEN,
};
use crate::{ErebusResponse, ObjectCacheResponse};
use futures_util::{SinkExt, StreamExt};
use serde::de::DeserializeOwned;
use serde::Serialize;
use tokio::net::TcpStream;
use tokio_tungstenite::{connect_async, tungstenite, MaybeTlsStream, WebSocketStream};
use tokio::time::{timeout, Duration};
/// Client for the Erebus service over a single WebSocket connection.
///
/// `Default` yields the same unconnected, unauthenticated state as `new()`.
#[derive(Default)]
pub struct ErebusWebSocketClient {
    /// The open connection, or `None` when no connection is currently held.
    socket: Option<WebSocketStream<MaybeTlsStream<TcpStream>>>,
    /// Whether the handshake has succeeded on the current `socket`.
    authenticated: bool,
}
impl ErebusWebSocketClient {
pub fn new() -> Self {
Self {
socket: None,
authenticated: false,
}
}
async fn connect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let _ = rustls::crypto::ring::default_provider().install_default();
self.connect_internal().await?;
Box::pin(self.authenticate()).await
}
async fn authenticate(&mut self) -> Result<(), Box<dyn std::error::Error>> {
if !self.authenticated {
self.connect_internal().await?;
}
let request = ErebusPackedRequest::Authenticate(HANDSHAKE_TOKEN.to_vec());
let response = self.send_request::<()>(request).await?;
match response {
ErebusResponse::Proceed => {
self.authenticated = true;
Ok(())
}
_ => Err("Authentication failed".into()),
}
}
async fn connect_internal(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let (socket, _) = connect_async(std::env::var("EREBUS_WSS_URL")?).await?;
self.socket = Some(socket);
self.authenticated = false;
Ok(())
}
pub async fn object_cache<T: Serialize + DeserializeOwned + ?Sized>(
&mut self,
name: String,
tags: Vec<String>,
body: Option<T>,
) -> Result<Option<T>, Box<dyn std::error::Error>> {
if !self.authenticated {
self.connect().await?;
}
let packed_request = ErebusPackedRequest::ObjectCache(ObjectCacheRequest {
name,
tags,
body: body.map(|body| serde_json::to_string(&body).unwrap()),
});
let response = self.send_request::<T>(packed_request).await?;
match response {
ErebusResponse::Success(mut data) => {
Ok(rmp_serde::from_slice::<ObjectCacheResponse>(&mut data[..])?
.body
.map(|body| serde_json::from_str(&body).ok())
.flatten())
}
_ => Err("Invalid response from server".into()),
}
}
pub async fn url_cache(
&mut self,
request: UrlCacheRequest,
) -> Result<UrlCacheResponse, Box<dyn std::error::Error>> {
if !self.authenticated {
self.connect().await?;
}
let packed_request = ErebusPackedRequest::UrlCache(request);
let response = self.send_request::<String>(packed_request).await?;
match response {
ErebusResponse::Success(data) => {
let result: UrlCacheResponse = rmp_serde::from_slice(&data[..])?;
Ok(result)
}
_ => Err("Invalid response from server".into()),
}
}
async fn send_request<T: Serialize + DeserializeOwned + ?Sized>(
&mut self,
request: ErebusPackedRequest,
) -> Result<ErebusResponse, Box<dyn std::error::Error>> {
const MAX_ATTEMPTS: u8 = 5;
const TIMEOUT_DURATION: Duration = Duration::from_secs(5);
let buf = rmp_serde::to_vec(&request)?;
for attempt in 0..MAX_ATTEMPTS {
let socket = match self.socket.as_mut() {
Some(s) => s,
None => {
self.connect_internal().await?;
self.socket.as_mut().unwrap()
}
};
if let Err(_) = socket.send(tungstenite::Message::Binary(buf.clone())).await {
self.socket = None;
continue;
}
match timeout(TIMEOUT_DURATION, socket.next()).await {
Ok(Some(Ok(tungstenite::Message::Binary(data)))) => {
return Ok(rmp_serde::from_slice(&data)?);
}
Ok(Some(Ok(_))) => return Err("Unexpected message type".into()),
_ => {
self.socket = None;
if attempt < MAX_ATTEMPTS - 1 {
tokio::time::sleep(Duration::from_millis(100 * (2_u64.pow(attempt as u32)))).await;
}
}
}
}
Err("Failed to send request after multiple attempts".into())
}
async fn close(mut self) {
if let Some(socket) = &mut self.socket {
let _ = socket.close(None).await;
}
}
}