posemesh_compute_node/dms/
client.rs

1use crate::auth::token_manager::TokenProvider;
2use crate::dms::types::{
3    CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
4};
5use anyhow::{anyhow, Context, Result};
6use reqwest::Client;
7use reqwest::{
8    header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
9    StatusCode,
10};
11use serde::Serialize;
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::Level;
15use url::Url;
16use uuid::Uuid;
17
18/// Minimal DMS HTTP client using rustls with sensitive Authorization header.
19#[derive(Clone)]
20pub struct DmsClient {
21    base: Url,
22    http: Client,
23    auth: Arc<dyn TokenProvider>,
24}
25impl DmsClient {
26    /// Create client with base URL, timeout, and a token provider for Authorization.
27    pub fn new(base: Url, timeout: Duration, auth: Arc<dyn TokenProvider>) -> Result<Self> {
28        let http = Client::builder()
29            .use_rustls_tls()
30            .timeout(timeout)
31            .build()
32            .context("build dms reqwest client")?;
33        Ok(Self { base, http, auth })
34    }
35
36    async fn auth_headers(&self) -> Result<HeaderMap> {
37        let mut h = HeaderMap::new();
38        let b = self
39            .auth
40            .bearer()
41            .await
42            .map_err(|e| anyhow!("token provider: {e}"))?;
43        let token = format!("Bearer {}", b);
44        let mut v = HeaderValue::from_str(&token)
45            .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
46        v.set_sensitive(true);
47        h.insert(AUTHORIZATION, v);
48        Ok(h)
49    }
50
51    /// Lease a task: GET /tasks
52    ///
53    /// `capability` is accepted for optional filter but not implemented yet.
54    pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
55        let url = self.base.join("tasks").context("join /tasks")?;
56        if tracing::enabled!(Level::DEBUG) {
57            tracing::debug!(
58                endpoint = %url,
59                "Sending DMS lease request"
60            );
61        }
62        // First attempt
63        let mut headers = self.auth_headers().await?;
64        let mut res = self
65            .http
66            .get(url.clone())
67            .headers(headers.clone())
68            .send()
69            .await
70            .context("send GET /tasks")?;
71        let mut status = res.status();
72        let mut bytes = res.bytes().await.context("read lease body")?;
73        // Retry once on 401
74        if status == StatusCode::UNAUTHORIZED {
75            let body_preview = String::from_utf8_lossy(&bytes);
76            tracing::warn!(
77                status = %status,
78                body = %body_preview,
79                "DMS lease unauthorized; refreshing token and retrying"
80            );
81            self.auth.on_unauthorized().await;
82            headers = self.auth_headers().await?;
83            res = self
84                .http
85                .get(url)
86                .headers(headers)
87                .send()
88                .await
89                .context("retry GET /tasks")?;
90            status = res.status();
91            bytes = res.bytes().await.context("read lease body (retry)")?;
92        }
93        if status == StatusCode::NO_CONTENT {
94            tracing::debug!("DMS lease returned 204 (no work available)");
95            return Ok(None);
96        }
97        let body_preview = String::from_utf8_lossy(&bytes);
98        if !status.is_success() {
99            tracing::warn!(
100                status = %status,
101                body = %body_preview,
102                "DMS lease request returned non-success status"
103            );
104            return Err(anyhow!("/tasks status: {}", status));
105        }
106        let lease: LeaseResponse = serde_json::from_slice(&bytes)
107            .map_err(|err| {
108                tracing::error!(
109                    error = %err,
110                    body = %body_preview,
111                    "Failed to decode DMS lease response"
112                );
113                err
114            })
115            .context("decode lease")?;
116
117        if tracing::enabled!(Level::DEBUG) {
118            tracing::debug!(
119                status = %status,
120                body = %body_preview,
121                "Decoded DMS lease response"
122            );
123        }
124
125        Ok(Some(lease))
126    }
127
128    /// Complete task: POST /tasks/{id}/complete
129    pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
130        let url = self
131            .base
132            .join(&format!("tasks/{}/complete", task_id))
133            .context("join /complete")?;
134        let mut headers = self.auth_headers().await?;
135        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
136        if let Some(preview) = json_debug_preview(body) {
137            tracing::debug!(
138                endpoint = %url,
139                task_id = %task_id,
140                body = %preview,
141                "Sending DMS complete request"
142            );
143        }
144        // First attempt
145        let mut res = self
146            .http
147            .post(url.clone())
148            .headers(headers.clone())
149            .json(body)
150            .send()
151            .await
152            .context("send POST /complete")?;
153        let mut status = res.status();
154        let mut body_text = res
155            .text()
156            .await
157            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
158        if status == StatusCode::UNAUTHORIZED {
159            let preview = truncate_preview(&body_text);
160            tracing::warn!(
161                status = %status,
162                body = %preview,
163                task_id = %task_id,
164                "DMS complete unauthorized; refreshing token and retrying"
165            );
166            self.auth.on_unauthorized().await;
167            headers = self.auth_headers().await?;
168            res = self
169                .http
170                .post(url)
171                .headers(headers)
172                .json(body)
173                .send()
174                .await
175                .context("retry POST /complete")?;
176            status = res.status();
177            body_text = res
178                .text()
179                .await
180                .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
181        }
182        let preview = truncate_preview(&body_text);
183        if tracing::enabled!(Level::DEBUG) {
184            tracing::debug!(
185                status = %status,
186                body = %preview,
187                task_id = %task_id,
188                "DMS complete response"
189            );
190        }
191        if !status.is_success() {
192            tracing::error!(
193                status = %status,
194                body = %preview,
195                task_id = %task_id,
196                "DMS complete endpoint returned non-success status"
197            );
198            return Err(anyhow!(
199                "POST /tasks/{task_id}/complete status {status}; body: {preview}"
200            ));
201        }
202        Ok(())
203    }
204
205    /// Fail task: POST /tasks/{id}/fail
206    pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
207        let url = self
208            .base
209            .join(&format!("tasks/{}/fail", task_id))
210            .context("join /fail")?;
211        let mut headers = self.auth_headers().await?;
212        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
213        if let Some(preview) = json_debug_preview(body) {
214            tracing::debug!(
215                endpoint = %url,
216                task_id = %task_id,
217                body = %preview,
218                "Sending DMS fail request"
219            );
220        }
221        // First attempt
222        let mut res = self
223            .http
224            .post(url.clone())
225            .headers(headers.clone())
226            .json(body)
227            .send()
228            .await
229            .context("send POST /fail")?;
230        let mut status = res.status();
231        let mut body_text = res
232            .text()
233            .await
234            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
235        if status == StatusCode::UNAUTHORIZED {
236            let preview = truncate_preview(&body_text);
237            tracing::warn!(
238                status = %status,
239                body = %preview,
240                task_id = %task_id,
241                "DMS fail unauthorized; refreshing token and retrying"
242            );
243            self.auth.on_unauthorized().await;
244            headers = self.auth_headers().await?;
245            res = self
246                .http
247                .post(url)
248                .headers(headers)
249                .json(body)
250                .send()
251                .await
252                .context("retry POST /fail")?;
253            status = res.status();
254            body_text = res
255                .text()
256                .await
257                .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
258        }
259        let preview = truncate_preview(&body_text);
260        if tracing::enabled!(Level::DEBUG) {
261            tracing::debug!(
262                status = %status,
263                body = %preview,
264                task_id = %task_id,
265                "DMS fail response"
266            );
267        }
268        if !status.is_success() {
269            tracing::error!(
270                status = %status,
271                body = %preview,
272                task_id = %task_id,
273                "DMS fail endpoint returned non-success status"
274            );
275            return Err(anyhow!(
276                "POST /tasks/{task_id}/fail status {status}; body: {preview}"
277            ));
278        }
279        Ok(())
280    }
281
282    /// Heartbeat: POST /tasks/{id}/heartbeat with progress payload.
283    /// Returns potential new access token for storage.
284    pub async fn heartbeat(
285        &self,
286        task_id: Uuid,
287        body: &HeartbeatRequest,
288    ) -> Result<HeartbeatResponse> {
289        let url = self
290            .base
291            .join(&format!("tasks/{}/heartbeat", task_id))
292            .context("join /heartbeat")?;
293        let mut headers = self.auth_headers().await?;
294        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
295        if let Some(preview) = json_debug_preview(body) {
296            tracing::debug!(
297                endpoint = %url,
298                task_id = %task_id,
299                body = %preview,
300                "Sending DMS heartbeat request"
301            );
302        }
303        // First attempt
304        let mut res = self
305            .http
306            .post(url.clone())
307            .headers(headers.clone())
308            .json(body)
309            .send()
310            .await
311            .context("send POST /heartbeat")?;
312        let mut status = res.status();
313        let mut bytes = res.bytes().await.context("read heartbeat response body")?;
314        if status == StatusCode::UNAUTHORIZED {
315            let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
316            tracing::warn!(
317                status = %status,
318                body = %preview,
319                task_id = %task_id,
320                "DMS heartbeat unauthorized; refreshing token and retrying"
321            );
322            self.auth.on_unauthorized().await;
323            headers = self.auth_headers().await?;
324            res = self
325                .http
326                .post(url)
327                .headers(headers)
328                .json(body)
329                .send()
330                .await
331                .context("retry POST /heartbeat")?;
332            status = res.status();
333            bytes = res
334                .bytes()
335                .await
336                .context("read heartbeat response body (retry)")?;
337        }
338        let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
339        if tracing::enabled!(Level::DEBUG) {
340            tracing::debug!(
341                status = %status,
342                body = %preview,
343                task_id = %task_id,
344                "DMS heartbeat response"
345            );
346        }
347        if !status.is_success() {
348            return Err(anyhow!(
349                "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
350            ));
351        }
352        let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
353            .context("decode heartbeat response")?;
354        Ok(hb)
355    }
356}
357
/// Return `body` unchanged if it has at most 512 characters; otherwise return its
/// first 512 characters followed by `"… (truncated)"`.
///
/// The limit is measured in `char`s, and the cut always falls on a UTF-8 boundary,
/// so multibyte text is never split or falsely reported as truncated.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    // `nth(MAX)` yields the byte offset of the first char past the limit, if any.
    match body.char_indices().nth(MAX) {
        None => body.to_string(),
        Some((cut, _)) => format!("{}… (truncated)", &body[..cut]),
    }
}
368
369fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
370    if !tracing::enabled!(Level::DEBUG) {
371        return None;
372    }
373    serde_json::to_string(value)
374        .map(|s| truncate_preview(&s))
375        .ok()
376}