use std::time::{Duration, Instant};
use reqwest::{header, StatusCode};
use serde::de::DeserializeOwned;
use tracing::{debug, info, warn};
use crate::{
error::{Result, VeilError},
types::{
Health, Job, JobStatus, Proof, RegisterModelRequest, RegisterModelResponse,
SubmitJobRequest, SubmitJobResponse, VerifyResult,
},
};
/// Builder for [`VeilClient`].
///
/// Obtain one with `VeilClient::builder()` (or `VeilClientBuilder::default()`),
/// adjust settings with the chained setters, then call `build()`.
#[derive(Debug)]
pub struct VeilClientBuilder {
/// Base URL of the Veil server; defaults to `http://localhost:8080`.
base_url: String,
/// Overall time budget used as the deadline of `VeilClient::verify_inference`;
/// defaults to 600 seconds.
timeout: Duration,
/// Delay between job-status polls in `VeilClient::verify_inference`;
/// defaults to 3 seconds.
poll_interval: Duration,
}
impl Default for VeilClientBuilder {
fn default() -> Self {
Self {
base_url: "http://localhost:8080".to_string(),
timeout: Duration::from_secs(600),
poll_interval: Duration::from_secs(3),
}
}
}
impl VeilClientBuilder {
    /// Sets the Veil server base URL (scheme, host, optional port and path prefix).
    ///
    /// Trailing `/` characters are stripped so endpoint paths can be appended directly.
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.base_url = url.into().trim_end_matches('/').to_string();
        self
    }

    /// Sets the overall time budget.
    ///
    /// It is the deadline for `VeilClient::verify_inference` and is also installed as
    /// the HTTP client's total per-request timeout, so no single request can outlive it.
    pub fn timeout(mut self, d: Duration) -> Self {
        self.timeout = d;
        self
    }

    /// Sets how long `VeilClient::verify_inference` waits between job-status polls.
    pub fn poll_interval(mut self, d: Duration) -> Self {
        self.poll_interval = d;
        self
    }

    /// Validates the configuration and builds a [`VeilClient`].
    ///
    /// The HTTP client sends `Content-Type`/`Accept: application/json` by default. It
    /// uses a 10 s connect timeout and a total per-request timeout equal to the
    /// configured `timeout`.
    ///
    /// # Errors
    /// * `VeilError::InvalidUrl` if the base URL does not parse, or if its scheme is not
    ///   `http`/`https` (e.g. `"localhost:8080"` parses with the scheme `localhost`);
    /// * `VeilError::Http` if the underlying `reqwest` client cannot be constructed.
    pub fn build(self) -> Result<VeilClient> {
        let parsed = reqwest::Url::parse(&self.base_url)
            .map_err(|e| VeilError::InvalidUrl(format!("{}: {e}", self.base_url)))?;
        // Only http(s) endpoints are usable; reject anything else now rather than on
        // the first request, where the failure is much harder to diagnose.
        if !matches!(parsed.scheme(), "http" | "https") {
            return Err(VeilError::InvalidUrl(format!(
                "{}: unsupported scheme `{}` (expected http or https)",
                self.base_url,
                parsed.scheme()
            )));
        }

        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );
        headers.insert(
            header::ACCEPT,
            header::HeaderValue::from_static("application/json"),
        );

        let http = reqwest::Client::builder()
            .default_headers(headers)
            .connect_timeout(Duration::from_secs(10))
            // Without a total timeout a stalled server could hold any call open forever.
            .timeout(self.timeout)
            .build()
            .map_err(VeilError::Http)?;

        Ok(VeilClient {
            http,
            base_url: self.base_url,
            timeout: self.timeout,
            poll_interval: self.poll_interval,
        })
    }
}
/// Async HTTP client for the Veil job API (health, jobs, proofs, model registration).
///
/// Construct it with `VeilClient::builder()`. `Clone` is derived: cloning duplicates the
/// configuration and clones the inner `reqwest::Client` handle.
#[derive(Debug, Clone)]
pub struct VeilClient {
/// HTTP client built by `VeilClientBuilder::build` (JSON default headers,
/// 10 s connect timeout).
http: reqwest::Client,
/// Base URL that endpoint paths are appended to.
base_url: String,
/// Overall budget for `verify_inference`, measured from before job submission.
timeout: Duration,
/// Delay between status polls in `verify_inference`.
poll_interval: Duration,
}
impl VeilClient {
pub fn builder() -> VeilClientBuilder {
VeilClientBuilder::default()
}
fn url(&self, path: &str) -> String {
format!("{}{path}", self.base_url)
}
async fn parse<T: DeserializeOwned>(&self, res: reqwest::Response) -> Result<T> {
let status = res.status();
if status.is_success() {
Ok(res.json::<T>().await?)
} else {
let message = res
.json::<serde_json::Value>()
.await
.ok()
.and_then(|v| v.get("error").and_then(|e| e.as_str()).map(String::from))
.unwrap_or_else(|| status.to_string());
Err(VeilError::Api {
status: status.as_u16(),
message,
})
}
}
pub async fn health_check(&self) -> Result<Health> {
debug!("GET /healthz");
let res = self.http.get(self.url("/healthz")).send().await?;
self.parse(res).await
}
pub async fn submit_job(
&self,
model_id: impl Into<String>,
input_data: Vec<Vec<f64>>,
) -> Result<String> {
let body = SubmitJobRequest {
input_data,
model_id: model_id.into(),
};
debug!(model_id = %body.model_id, "POST /v1/jobs");
let res = self
.http
.post(self.url("/v1/jobs"))
.json(&body)
.send()
.await?;
let resp: SubmitJobResponse = self.parse(res).await?;
info!(job_id = %resp.job_id, "job submitted");
Ok(resp.job_id)
}
pub async fn get_job(&self, job_id: &str) -> Result<Job> {
debug!(%job_id, "GET /v1/jobs/{job_id}");
let res = self
.http
.get(self.url(&format!("/v1/jobs/{job_id}")))
.send()
.await?;
self.parse(res).await
}
pub async fn get_proof(&self, job_id: &str) -> Result<Proof> {
debug!(%job_id, "GET /v1/jobs/{job_id}/proof");
let res = self
.http
.get(self.url(&format!("/v1/jobs/{job_id}/proof")))
.send()
.await?;
self.parse(res).await
}
pub async fn register_model(&self, req: RegisterModelRequest) -> Result<RegisterModelResponse> {
debug!(name = %req.name, version = %req.version, "POST /v1/models");
let res = self
.http
.post(self.url("/v1/models"))
.json(&req)
.send()
.await?;
self.parse(res).await
}
pub async fn verify_inference(
&self,
model_id: impl Into<String>,
input_data: Vec<Vec<f64>>,
) -> Result<VerifyResult> {
let model_id = model_id.into();
let started = Instant::now();
let job_id = self.submit_job(&model_id, input_data).await?;
info!(%job_id, %model_id, "job submitted — polling until terminal state");
let deadline = started + self.timeout;
let mut last_status = String::from("queued");
loop {
tokio::time::sleep(self.poll_interval).await;
if Instant::now() >= deadline {
return Err(VeilError::Timeout {
job_id,
elapsed_ms: started.elapsed().as_millis() as u64,
last_status,
});
}
let job = match self.get_job(&job_id).await {
Ok(j) => j,
Err(e) => {
warn!(%job_id, "poll error (will retry): {e}");
continue;
}
};
last_status = job.status.to_string();
debug!(%job_id, status = %last_status, "poll");
match &job.status {
JobStatus::Failed => {
return Err(VeilError::JobFailed {
job_id,
reason: job.reason,
});
}
s if s.is_terminal() => {
let elapsed_ms = started.elapsed().as_millis() as u64;
info!(%job_id, status = %last_status, elapsed_ms, "job complete");
return Ok(VerifyResult {
job_id,
status: job.status,
tx_hash: job.tx_hash,
attestation_hash: job.attestation_hash,
elapsed_ms,
});
}
_ => {} }
}
}
}