use crate::generic_task_aggregation::{SignedTaskResponse, TaskResponse};
use blueprint_core::{debug, error, info};
use reqwest::{Client, Url};
use serde::{Serialize, de::DeserializeOwned};
use serde_json::{Value, json};
use std::fmt::Debug;
use std::marker::PhantomData;
use std::time::Duration;
use thiserror::Error;
use tokio::time::sleep;
/// Errors produced by the aggregator JSON-RPC client.
#[derive(Error, Debug)]
pub enum ClientError {
    /// The aggregator address could not be turned into a valid URL.
    #[error("URL parsing error: {0}")]
    UrlParseError(#[from] url::ParseError),
    /// A value could not be converted to or from JSON (e.g. the RPC `result` did not match
    /// the expected type).
    #[error("JSON error: {0}")]
    JsonError(#[from] serde_json::Error),
    /// The HTTP request failed, or the response body could not be read/decoded as JSON.
    #[error("HTTP error: {0}")]
    HttpError(#[from] reqwest::Error),
    /// The aggregator answered with a JSON-RPC error, or with a response that lacks a
    /// `result`. The payload is the rendered error object or a description of the problem.
    #[error("RPC error: {0}")]
    RpcError(String),
    /// Every submission attempt failed; carries the attempt limit that was reached.
    #[error("Retry limit exceeded after {0} attempts")]
    RetryLimitExceeded(u32),
}
/// Convenience result type for client operations, using [`ClientError`] as the error.
pub type Result<T> = std::result::Result<T, ClientError>;
/// Retry policy applied when submitting task responses to the aggregator.
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Number of submission attempts, counting the initial attempt (not only the retries).
    pub max_retries: u32,
    /// Delay before the first retry; also the constant delay when backoff is disabled.
    pub initial_retry_delay: Duration,
    /// When `true`, the delay doubles after each failed attempt.
    pub use_exponential_backoff: bool,
}

impl Default for ClientConfig {
    /// Five attempts, starting from a one-second delay that doubles after every failure.
    fn default() -> Self {
        let initial_retry_delay = Duration::from_millis(1_000);
        ClientConfig {
            use_exponential_backoff: true,
            initial_retry_delay,
            max_retries: 5,
        }
    }
}
/// JSON-RPC client that submits signed task responses to an aggregator over HTTP.
///
/// - `E`: the externally-defined response type that is serialized and sent to the aggregator.
/// - `R`: the task-response payload carried inside the converted [`SignedTaskResponse`].
/// - `C`: converter from `E` into `SignedTaskResponse<R>`; the client uses its output to
///   report the task index in logs.
///
/// NOTE: the derived `Debug`/`Clone` impls add `Debug`/`Clone` bounds on every type
/// parameter, so e.g. `Debug` is not available when `C` is a plain closure.
#[derive(Debug, Clone)]
pub struct AggregatorClient<E, R, C>
where
    E: Serialize + Clone + Send + Sync + 'static,
    R: TaskResponse + DeserializeOwned + Debug + Send,
    C: Fn(E) -> SignedTaskResponse<R> + Clone + Send + Sync + 'static,
{
    /// HTTP client reused for every request.
    client: Client,
    /// Aggregator endpoint that JSON-RPC requests are POSTed to.
    server_url: Url,
    /// Maps an `E` into the generic signed-response form.
    converter: C,
    /// Retry policy for submissions.
    config: ClientConfig,
    /// Ties `E` and `R` to the type without storing values of them.
    _phantom: PhantomData<(E, R)>,
}
impl<E, R, C> AggregatorClient<E, R, C>
where
E: Serialize + Clone + Send + Sync + 'static,
R: TaskResponse + DeserializeOwned + Debug + Send,
C: Fn(E) -> SignedTaskResponse<R> + Clone + Send + Sync + 'static,
{
pub fn new(aggregator_address: &str, converter: C) -> Result<Self> {
let url = Url::parse(&format!("http://{}", aggregator_address))?;
let client = Client::new();
Ok(Self {
client,
server_url: url,
converter,
config: ClientConfig::default(),
_phantom: PhantomData,
})
}
pub fn with_config(
aggregator_address: &str,
converter: C,
config: ClientConfig,
) -> Result<Self> {
let url = Url::parse(&format!("http://{}", aggregator_address))?;
let client = Client::new();
Ok(Self {
client,
server_url: url,
converter,
config,
_phantom: PhantomData,
})
}
async fn json_rpc_request<T, U>(&self, method: &str, params: T) -> Result<U>
where
T: Serialize,
U: DeserializeOwned,
{
let request = json!({
"jsonrpc": "2.0",
"method": method,
"params": params,
"id": 1
});
let response = self
.client
.post(self.server_url.clone())
.json(&request)
.send()
.await?
.json::<Value>()
.await?;
if let Some(error) = response.get("error") {
return Err(ClientError::RpcError(error.to_string()));
}
let result = response
.get("result")
.ok_or_else(|| ClientError::RpcError("Missing 'result' in response".to_string()))?;
let result = serde_json::from_value(result.clone())?;
Ok(result)
}
pub async fn send_signed_task_response(&self, external_response: E) -> Result<()> {
let generic_response = (self.converter)(external_response.clone());
for attempt in 1..=self.config.max_retries {
match self
.json_rpc_request::<E, bool>(
"process_signed_task_response",
external_response.clone(),
)
.await
{
Ok(true) => {
info!(
"Task response accepted by aggregator for task {}",
generic_response.response.reference_task_index()
);
return Ok(());
}
Ok(false) => {
debug!("Task response not accepted, retrying...");
}
Err(e) => {
error!("Error sending task response: {}", e);
}
}
if attempt < self.config.max_retries {
let delay = if self.config.use_exponential_backoff {
self.config.initial_retry_delay * 2u32.pow(attempt - 1)
} else {
self.config.initial_retry_delay
};
info!("Retrying in {} seconds...", delay.as_secs());
sleep(delay).await;
}
}
debug!(
"Failed to send signed task response after {} attempts",
self.config.max_retries
);
Err(ClientError::RetryLimitExceeded(self.config.max_retries))
}
}