use crate::types::*;
use std::time::Duration;
use thiserror::Error;
/// Errors returned by the methods of `AggregationServiceClient`.
#[derive(Debug, Error)]
pub enum ClientError {
/// An error raised by `reqwest`; the `#[from]` conversion lets request and
/// body-decoding failures be propagated with `?`.
#[error("HTTP error: {0}")]
Http(#[from] reqwest::Error),
/// The service reported a failure. The string is either the error message
/// carried in the response body or a description of the HTTP status.
#[error("Server error: {0}")]
Server(String),
/// NOTE(review): this variant is not constructed anywhere in the visible
/// client code; presumably reserved for malformed responses — confirm.
#[error("Invalid response: {0}")]
InvalidResponse(String),
/// The status response reported that the requested task does not exist.
#[error("Task not found")]
NotFound,
/// Waiting for the aggregate timed out; `collected` and `required` are
/// copied from the status response fetched at that point.
#[error("Threshold not met: {collected}/{required}")]
ThresholdNotMet { collected: usize, required: usize },
}
/// HTTP client for the aggregation service endpoints (`/health`,
/// `/v1/tasks/*` and `/v1/stats`), all addressed relative to `base_url`.
#[derive(Debug, Clone)]
pub struct AggregationServiceClient {
// Underlying `reqwest` client used to send every request.
client: reqwest::Client,
// Service root URL; both constructors strip trailing `/` characters before
// storing it, so endpoint paths can be appended with a single `/`.
base_url: String,
}
impl AggregationServiceClient {
pub fn new(base_url: impl Into<String>) -> Self {
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client");
Self {
client,
base_url: base_url.into().trim_end_matches('/').to_string(),
}
}
/// Creates a client that sends requests through the supplied
/// `reqwest::Client` instead of building one.
///
/// Trailing `/` characters are removed from `base_url` before it is stored.
pub fn with_client(client: reqwest::Client, base_url: impl Into<String>) -> Self {
    let raw: String = base_url.into();
    let base_url = raw.trim_end_matches('/').to_string();
    Self { client, base_url }
}
/// Sends `GET {base_url}/health` and reports whether the service answered
/// with a success (2xx) status.
///
/// # Errors
///
/// Returns [`ClientError::Http`] if the request itself fails. A non-success
/// status is not an error; it yields `Ok(false)`.
pub async fn health(&self) -> Result<bool, ClientError> {
    let endpoint = format!("{}/health", self.base_url);
    let status = self.client.get(&endpoint).send().await?.status();
    Ok(status.is_success())
}
/// Registers a task by posting an `InitTaskRequest` to
/// `{base_url}/v1/tasks/init`.
///
/// `output` is copied into the request body; the remaining arguments are
/// forwarded unchanged.
///
/// # Errors
///
/// * [`ClientError::Http`] if the request fails or the response body cannot
///   be decoded as an `InitTaskResponse`.
/// * [`ClientError::Server`] if the response has `success == false`; the
///   message is the response's `error` field, or `"Unknown error"` when that
///   field is absent.
pub async fn init_task(
    &self,
    service_id: u64,
    call_id: u64,
    output: &[u8],
    operator_count: u32,
    threshold: ThresholdConfig,
) -> Result<(), ClientError> {
    let endpoint = format!("{}/v1/tasks/init", self.base_url);
    let payload = InitTaskRequest {
        service_id,
        call_id,
        operator_count,
        threshold,
        output: output.to_vec(),
    };

    let http_response = self.client.post(&endpoint).json(&payload).send().await?;
    let reply: InitTaskResponse = http_response.json().await?;

    if !reply.success {
        let message = reply.error.unwrap_or_else(|| "Unknown error".to_string());
        return Err(ClientError::Server(message));
    }
    Ok(())
}
/// Posts `request` to `{base_url}/v1/tasks/submit` and returns the decoded
/// response when the service accepted the signature.
///
/// # Errors
///
/// * [`ClientError::Http`] if the request fails or the response body cannot
///   be decoded as a `SubmitSignatureResponse`.
/// * [`ClientError::Server`] if the response has `accepted == false`; the
///   message is the response's `error` field, or `"Signature rejected"` when
///   that field is absent.
pub async fn submit_signature(
    &self,
    request: SubmitSignatureRequest,
) -> Result<SubmitSignatureResponse, ClientError> {
    let endpoint = format!("{}/v1/tasks/submit", self.base_url);

    let http_response = self.client.post(&endpoint).json(&request).send().await?;
    let reply: SubmitSignatureResponse = http_response.json().await?;

    if reply.accepted {
        return Ok(reply);
    }
    let reason = reply
        .error
        .unwrap_or_else(|| "Signature rejected".to_string());
    Err(ClientError::Server(reason))
}
/// Fetches the status of the task identified by `service_id` and `call_id`
/// by posting a `GetStatusRequest` to `{base_url}/v1/tasks/status`.
///
/// # Errors
///
/// * [`ClientError::Http`] if the request fails or the response body cannot
///   be decoded as a `GetStatusResponse`.
/// * [`ClientError::NotFound`] if the response has `exists == false`.
pub async fn get_status(
    &self,
    service_id: u64,
    call_id: u64,
) -> Result<GetStatusResponse, ClientError> {
    let endpoint = format!("{}/v1/tasks/status", self.base_url);
    let payload = GetStatusRequest {
        service_id,
        call_id,
    };

    let http_response = self.client.post(&endpoint).json(&payload).send().await?;
    let reply: GetStatusResponse = http_response.json().await?;

    if reply.exists {
        Ok(reply)
    } else {
        Err(ClientError::NotFound)
    }
}
/// Requests the aggregated result for a task by posting a
/// `GetStatusRequest` to `{base_url}/v1/tasks/aggregate`.
///
/// Returns `Ok(None)` when the service answers `404 Not Found`, or when a
/// successful response body decodes to `None`.
///
/// # Errors
///
/// * [`ClientError::Http`] if the request fails or a successful response
///   body cannot be decoded.
/// * [`ClientError::Server`] for any other non-success status; the message
///   is `"Server returned <status>"`.
pub async fn get_aggregated(
    &self,
    service_id: u64,
    call_id: u64,
) -> Result<Option<AggregatedResultResponse>, ClientError> {
    let endpoint = format!("{}/v1/tasks/aggregate", self.base_url);
    let payload = GetStatusRequest {
        service_id,
        call_id,
    };

    let http_response = self.client.post(&endpoint).json(&payload).send().await?;
    let status = http_response.status();

    if status == reqwest::StatusCode::NOT_FOUND {
        return Ok(None);
    }
    if !status.is_success() {
        return Err(ClientError::Server(format!("Server returned {}", status)));
    }

    let aggregated = http_response
        .json::<Option<AggregatedResultResponse>>()
        .await?;
    Ok(aggregated)
}
/// Marks a task as submitted by posting a `GetStatusRequest` to
/// `{base_url}/v1/tasks/mark-submitted`.
///
/// Returns `Ok(())` when the service answers with a success (2xx) status.
///
/// # Errors
///
/// * [`ClientError::Http`] if the request fails or the error body cannot be
///   read from the connection.
/// * [`ClientError::Server`] for any non-success status. The message is the
///   string under the `error` key of the JSON body, `"Unknown error"` when
///   the body is JSON without such a string, or `"Server returned <status>"`
///   when the body is not valid JSON at all.
pub async fn mark_submitted(&self, service_id: u64, call_id: u64) -> Result<(), ClientError> {
    let url = format!("{}/v1/tasks/mark-submitted", self.base_url);
    let request = GetStatusRequest {
        service_id,
        call_id,
    };
    let response = self.client.post(&url).json(&request).send().await?;

    // Capture the status first: decoding the body consumes the response.
    let status = response.status();
    if status.is_success() {
        return Ok(());
    }

    let message = match response.json::<serde_json::Value>().await {
        Ok(body) => body
            .get("error")
            .and_then(serde_json::Value::as_str)
            .unwrap_or("Unknown error")
            .to_string(),
        // A plain-text, HTML or empty error body must not hide the failing
        // HTTP status behind a JSON decode error.
        Err(err) if err.is_decode() => format!("Server returned {}", status),
        // Anything else is a genuine transport problem while reading the body.
        Err(err) => return Err(ClientError::Http(err)),
    };
    Err(ClientError::Server(message))
}
/// Fetches service statistics with `GET {base_url}/v1/stats`.
///
/// # Errors
///
/// * [`ClientError::Http`] if the request fails or the response body cannot
///   be decoded as `crate::ServiceStats`.
/// * [`ClientError::Server`] if the status is not a success (2xx); the
///   message is `"Server returned <status>"`.
pub async fn get_stats(&self) -> Result<crate::ServiceStats, ClientError> {
    let endpoint = format!("{}/v1/stats", self.base_url);
    let http_response = self.client.get(&endpoint).send().await?;

    let status = http_response.status();
    if status.is_success() {
        let stats = http_response.json::<crate::ServiceStats>().await?;
        return Ok(stats);
    }
    Err(ClientError::Server(format!("Server returned {}", status)))
}
/// Polls [`Self::get_aggregated`] until an aggregated result is available
/// or `timeout` has elapsed.
///
/// The aggregate is always queried at least once, even when `timeout` is
/// zero. Between polls the task sleeps for `poll_interval`, shortened to the
/// time remaining so the call does not overshoot `timeout` by a whole
/// interval; a final poll is made once the deadline is reached.
///
/// # Errors
///
/// * [`ClientError::ThresholdNotMet`] if no result is available by the
///   deadline; `collected` and `required` come from a status request made at
///   that point.
/// * Any error returned by [`Self::get_aggregated`] or, on timeout, by
///   [`Self::get_status`] (including [`ClientError::NotFound`]).
pub async fn wait_for_threshold(
    &self,
    service_id: u64,
    call_id: u64,
    poll_interval: Duration,
    timeout: Duration,
) -> Result<AggregatedResultResponse, ClientError> {
    let start = std::time::Instant::now();
    loop {
        // Poll before looking at the clock so a short timeout still gets one attempt.
        if let Some(result) = self.get_aggregated(service_id, call_id).await? {
            return Ok(result);
        }

        let remaining = timeout.saturating_sub(start.elapsed());
        if remaining.is_zero() {
            let status = self.get_status(service_id, call_id).await?;
            return Err(ClientError::ThresholdNotMet {
                collected: status.signatures_collected,
                required: status.threshold_required,
            });
        }

        // Never sleep past the deadline.
        tokio::time::sleep(poll_interval.min(remaining)).await;
    }
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The stored base URL has no trailing slash, whether or not the input
    /// had one.
    #[test]
    fn test_client_url_normalization() {
        let inputs = ["http://localhost:8080/", "http://localhost:8080"];
        for input in inputs.iter() {
            let client = AggregationServiceClient::new(*input);
            assert_eq!(client.base_url, "http://localhost:8080");
        }
    }
}