// veil_sdk/client.rs
1use std::time::{Duration, Instant};
2
3use reqwest::{header, StatusCode};
4use serde::de::DeserializeOwned;
5use tracing::{debug, info, warn};
6
7use crate::{
8 error::{Result, VeilError},
9 types::{
10 Health, Job, JobStatus, Proof, RegisterModelRequest, RegisterModelResponse,
11 SubmitJobRequest, SubmitJobResponse, VerifyResult,
12 },
13};
14
// ── Builder ───────────────────────────────────────────────────────────────────
16
/// Fluent builder that collects the settings for a [`VeilClient`].
///
/// Obtain one from [`VeilClient::builder()`], chain the setters you need, and
/// finish with `build()`.
///
/// ```rust
/// use veil_sdk::VeilClient;
/// use std::time::Duration;
///
/// let client = VeilClient::builder()
///     .base_url("http://localhost:8080")
///     .timeout(Duration::from_secs(600))
///     .poll_interval(Duration::from_secs(3))
///     .build()
///     .unwrap();
/// ```
#[derive(Debug)]
pub struct VeilClientBuilder {
    /// Gateway root URL that request paths are appended to.
    base_url: String,
    /// Upper bound on how long a job may be waited for.
    timeout: Duration,
    /// Pause between two consecutive job-status polls.
    poll_interval: Duration,
}
36
37impl Default for VeilClientBuilder {
38 fn default() -> Self {
39 Self {
40 base_url: "http://localhost:8080".to_string(),
41 timeout: Duration::from_secs(600),
42 poll_interval: Duration::from_secs(3),
43 }
44 }
45}
46
47impl VeilClientBuilder {
48 /// Base URL of the Veil gateway, e.g. `"https://api.mugen.network"`.
49 /// Trailing slashes are stripped automatically.
50 pub fn base_url(mut self, url: impl Into<String>) -> Self {
51 self.base_url = url.into().trim_end_matches('/').to_string();
52 self
53 }
54
55 /// Maximum wall-clock time to wait for a job to reach a terminal state.
56 /// Applies to `verify_inference`. Defaults to 600 seconds.
57 pub fn timeout(mut self, d: Duration) -> Self {
58 self.timeout = d;
59 self
60 }
61
62 /// How often to poll `GET /v1/jobs/{id}` while waiting.
63 /// Defaults to 3 seconds.
64 pub fn poll_interval(mut self, d: Duration) -> Self {
65 self.poll_interval = d;
66 self
67 }
68
69 /// Consume the builder and construct a [`VeilClient`].
70 ///
71 /// # Errors
72 /// Returns [`VeilError::InvalidUrl`] if the base URL cannot be parsed by
73 /// `reqwest`.
74 pub fn build(self) -> Result<VeilClient> {
75 // Validate the URL by attempting to construct a reqwest client
76 // with a test request (parse-only, no network).
77 reqwest::Url::parse(&self.base_url)
78 .map_err(|e| VeilError::InvalidUrl(format!("{}: {e}", self.base_url)))?;
79
80 let http = reqwest::Client::builder()
81 .default_headers({
82 let mut h = header::HeaderMap::new();
83 h.insert(
84 header::CONTENT_TYPE,
85 header::HeaderValue::from_static("application/json"),
86 );
87 h.insert(
88 header::ACCEPT,
89 header::HeaderValue::from_static("application/json"),
90 );
91 h
92 })
93 // reqwest's own connection timeout — separate from our poll timeout.
94 .connect_timeout(Duration::from_secs(10))
95 .build()
96 .map_err(VeilError::Http)?;
97
98 Ok(VeilClient {
99 http,
100 base_url: self.base_url,
101 timeout: self.timeout,
102 poll_interval: self.poll_interval,
103 })
104 }
105}
106
// ── Client ────────────────────────────────────────────────────────────────────
108
109/// Async client for the Mugen Veil verifiable inference gateway.
110///
111/// Construct via [`VeilClient::builder()`].
112///
113/// `VeilClient` is cheap to clone — the underlying `reqwest::Client` uses an
114/// `Arc` internally and shares the connection pool across clones.
115#[derive(Debug, Clone)]
116pub struct VeilClient {
117 http: reqwest::Client,
118 base_url: String,
119 timeout: Duration,
120 poll_interval: Duration,
121}
122
123impl VeilClient {
124 /// Begin building a client. See [`VeilClientBuilder`].
125 pub fn builder() -> VeilClientBuilder {
126 VeilClientBuilder::default()
127 }
128
// ── Private helpers ───────────────────────────────────────────────────────
130
131 fn url(&self, path: &str) -> String {
132 format!("{}{path}", self.base_url)
133 }
134
135 /// Parse a response, surfacing gateway error bodies as [`VeilError::Api`].
136 async fn parse<T: DeserializeOwned>(&self, res: reqwest::Response) -> Result<T> {
137 let status = res.status();
138 if status.is_success() {
139 Ok(res.json::<T>().await?)
140 } else {
141 // Best-effort extraction of an `error` field from the JSON body.
142 let message = res
143 .json::<serde_json::Value>()
144 .await
145 .ok()
146 .and_then(|v| v.get("error").and_then(|e| e.as_str()).map(String::from))
147 .unwrap_or_else(|| status.to_string());
148
149 Err(VeilError::Api {
150 status: status.as_u16(),
151 message,
152 })
153 }
154 }
155
// ── Primitive methods ─────────────────────────────────────────────────────
157
158 /// `GET /healthz` — Returns gateway health information.
159 ///
160 /// Use [`Health::is_healthy()`] to check overall readiness.
161 pub async fn health_check(&self) -> Result<Health> {
162 debug!("GET /healthz");
163 let res = self.http.get(self.url("/healthz")).send().await?;
164 self.parse(res).await
165 }
166
167 /// `POST /v1/jobs` — Submit an inference job and return immediately.
168 ///
169 /// Returns the `job_id`. Use [`get_job`](Self::get_job) to poll status,
170 /// or [`verify_inference`](Self::verify_inference) to submit and wait.
171 ///
172 /// # Arguments
173 /// - `model_id` — the model name registered with the gateway (e.g. `"tiny_mlp_v1"`)
174 /// - `input_data` — row-major input tensor, e.g. `vec![vec![0.1, 0.2, 0.3, 0.4]]`
175 pub async fn submit_job(
176 &self,
177 model_id: impl Into<String>,
178 input_data: Vec<Vec<f64>>,
179 ) -> Result<String> {
180 let body = SubmitJobRequest {
181 input_data,
182 model_id: model_id.into(),
183 };
184
185 debug!(model_id = %body.model_id, "POST /v1/jobs");
186
187 let res = self
188 .http
189 .post(self.url("/v1/jobs"))
190 .json(&body)
191 .send()
192 .await?;
193
194 let resp: SubmitJobResponse = self.parse(res).await?;
195 info!(job_id = %resp.job_id, "job submitted");
196 Ok(resp.job_id)
197 }
198
199 /// `GET /v1/jobs/{id}` — Poll the status of a job.
200 pub async fn get_job(&self, job_id: &str) -> Result<Job> {
201 debug!(%job_id, "GET /v1/jobs/{job_id}");
202 let res = self
203 .http
204 .get(self.url(&format!("/v1/jobs/{job_id}")))
205 .send()
206 .await?;
207 self.parse(res).await
208 }
209
210 /// `GET /v1/jobs/{id}/proof` — Fetch the raw proof bytes for a completed job.
211 ///
212 /// The gateway returns `HTTP 202` if the job is not yet complete.
213 /// Returns [`VeilError::Api`] with status 202 in that case — callers should
214 /// poll [`get_job`](Self::get_job) first.
215 pub async fn get_proof(&self, job_id: &str) -> Result<Proof> {
216 debug!(%job_id, "GET /v1/jobs/{job_id}/proof");
217 let res = self
218 .http
219 .get(self.url(&format!("/v1/jobs/{job_id}/proof")))
220 .send()
221 .await?;
222 self.parse(res).await
223 }
224
225 /// `POST /v1/models` — Register an ONNX model artifact with the gateway.
226 ///
227 /// Pins the artifact to IPFS and registers it on-chain.
228 /// Requires `PINATA_JWT` to be configured on the gateway.
229 pub async fn register_model(&self, req: RegisterModelRequest) -> Result<RegisterModelResponse> {
230 debug!(name = %req.name, version = %req.version, "POST /v1/models");
231 let res = self
232 .http
233 .post(self.url("/v1/models"))
234 .json(&req)
235 .send()
236 .await?;
237 self.parse(res).await
238 }
239
// ── High-level method ─────────────────────────────────────────────────────
241
242 /// Submit an inference job and block until it reaches a terminal state.
243 ///
244 /// Polls `GET /v1/jobs/{id}` at the configured `poll_interval` until the
245 /// job status is one of `done`, `settled`, or `failed`, or until the
246 /// configured `timeout` elapses.
247 ///
248 /// # Arguments
249 /// - `model_id` — registered model name, e.g. `"tiny_mlp_v1"`
250 /// - `input_data` — row-major input tensor
251 ///
252 /// # Errors
253 /// - [`VeilError::Timeout`] — polling exceeded the configured timeout
254 /// - [`VeilError::JobFailed`] — the gateway reported the job as failed
255 /// - [`VeilError::Api`] — gateway returned a non-2xx response
256 /// - [`VeilError::Http`] — network-level failure
257 ///
258 /// # Example
259 /// ```rust,no_run
260 /// # use veil_sdk::VeilClient;
261 /// # #[tokio::main] async fn main() -> veil_sdk::error::Result<()> {
262 /// let client = VeilClient::builder()
263 /// .base_url("http://localhost:8080")
264 /// .build()?;
265 ///
266 /// let result = client
267 /// .verify_inference("tiny_mlp_v1", vec![vec![0.1, 0.2, 0.3, 0.4]])
268 /// .await?;
269 ///
270 /// println!("tx_hash: {:?}", result.tx_hash);
271 /// # Ok(())
272 /// # }
273 /// ```
274 pub async fn verify_inference(
275 &self,
276 model_id: impl Into<String>,
277 input_data: Vec<Vec<f64>>,
278 ) -> Result<VerifyResult> {
279 let model_id = model_id.into();
280 let started = Instant::now();
281
282 // 1. Submit
283 let job_id = self.submit_job(&model_id, input_data).await?;
284 info!(%job_id, %model_id, "job submitted — polling until terminal state");
285
286 // 2. Poll
287 let deadline = started + self.timeout;
288 let mut last_status = String::from("queued");
289
290 loop {
291 tokio::time::sleep(self.poll_interval).await;
292
293 if Instant::now() >= deadline {
294 return Err(VeilError::Timeout {
295 job_id,
296 elapsed_ms: started.elapsed().as_millis() as u64,
297 last_status,
298 });
299 }
300
301 let job = match self.get_job(&job_id).await {
302 Ok(j) => j,
303 Err(e) => {
304 // Transient network errors during polling are logged and
305 // retried rather than propagated immediately.
306 warn!(%job_id, "poll error (will retry): {e}");
307 continue;
308 }
309 };
310
311 last_status = job.status.to_string();
312 debug!(%job_id, status = %last_status, "poll");
313
314 match &job.status {
315 JobStatus::Failed => {
316 return Err(VeilError::JobFailed {
317 job_id,
318 reason: job.reason,
319 });
320 }
321 s if s.is_terminal() => {
322 let elapsed_ms = started.elapsed().as_millis() as u64;
323 info!(%job_id, status = %last_status, elapsed_ms, "job complete");
324 return Ok(VerifyResult {
325 job_id,
326 status: job.status,
327 tx_hash: job.tx_hash,
328 attestation_hash: job.attestation_hash,
329 elapsed_ms,
330 });
331 }
332 _ => {} // still in progress — keep polling
333 }
334 }
335 }
336}