posemesh_compute_node/dms/
client.rs

1use crate::auth::token_manager::TokenProvider;
2use crate::dms::types::{
3    CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
4};
5use anyhow::{anyhow, Context, Result};
6use reqwest::Client;
7use reqwest::{
8    header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
9    StatusCode,
10};
11use serde::Serialize;
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::Level;
15use url::Url;
16use uuid::Uuid;
17
18/// Minimal DMS HTTP client using rustls with sensitive Authorization header.
19#[derive(Clone)]
20pub struct DmsClient {
21    base: Url,
22    http: Client,
23    auth: Arc<dyn TokenProvider>,
24}
25impl DmsClient {
26    /// Create client with base URL, timeout, and a token provider for Authorization.
27    pub fn new(base: Url, timeout: Duration, auth: Arc<dyn TokenProvider>) -> Result<Self> {
28        let http = Client::builder()
29            .use_rustls_tls()
30            .timeout(timeout)
31            .build()
32            .context("build dms reqwest client")?;
33        Ok(Self { base, http, auth })
34    }
35
36    async fn auth_headers(&self) -> Result<HeaderMap> {
37        let mut h = HeaderMap::new();
38        let b = self
39            .auth
40            .bearer()
41            .await
42            .map_err(|e| anyhow!("token provider: {e}"))?;
43        let token = format!("Bearer {}", b);
44        let mut v = HeaderValue::from_str(&token)
45            .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
46        v.set_sensitive(true);
47        h.insert(AUTHORIZATION, v);
48        Ok(h)
49    }
50
51    fn join_segments(&self, segments: &[&str]) -> Result<Url> {
52        let mut url = self.base.clone();
53        url.path_segments_mut()
54            .map_err(|_| anyhow!("invalid DMS base URL; cannot be a base"))?
55            .extend(segments.iter().copied());
56        Ok(url)
57    }
58
59    /// Lease a task: GET /tasks
60    ///
61    /// `capability` is accepted for optional filter but not implemented yet.
62    pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
63        let url = self.join_segments(&["tasks"]).context("join /tasks")?;
64        if tracing::enabled!(Level::DEBUG) {
65            tracing::debug!(
66                endpoint = %url,
67                "Sending DMS lease request"
68            );
69        }
70        // First attempt
71        let mut headers = self.auth_headers().await?;
72        let mut res = self
73            .http
74            .get(url.clone())
75            .headers(headers.clone())
76            .send()
77            .await
78            .context("send GET /tasks")?;
79        let mut status = res.status();
80        let mut bytes = res.bytes().await.context("read lease body")?;
81        // Retry once on 401
82        if status == StatusCode::UNAUTHORIZED {
83            let body_preview = String::from_utf8_lossy(&bytes);
84            tracing::warn!(
85                status = %status,
86                body = %body_preview,
87                "DMS lease unauthorized; refreshing token and retrying"
88            );
89            self.auth.on_unauthorized().await;
90            headers = self.auth_headers().await?;
91            res = self
92                .http
93                .get(url)
94                .headers(headers)
95                .send()
96                .await
97                .context("retry GET /tasks")?;
98            status = res.status();
99            bytes = res.bytes().await.context("read lease body (retry)")?;
100        }
101        if status == StatusCode::NO_CONTENT {
102            tracing::debug!("DMS lease returned 204 (no work available)");
103            return Ok(None);
104        }
105        let body_preview = String::from_utf8_lossy(&bytes);
106        if !status.is_success() {
107            tracing::warn!(
108                status = %status,
109                body = %body_preview,
110                "DMS lease request returned non-success status"
111            );
112            return Err(anyhow!("/tasks status: {}", status));
113        }
114        let lease: LeaseResponse = serde_json::from_slice(&bytes)
115            .map_err(|err| {
116                tracing::error!(
117                    error = %err,
118                    body = %body_preview,
119                    "Failed to decode DMS lease response"
120                );
121                err
122            })
123            .context("decode lease")?;
124
125        if tracing::enabled!(Level::DEBUG) {
126            tracing::debug!(
127                status = %status,
128                body = %body_preview,
129                "Decoded DMS lease response"
130            );
131        }
132
133        Ok(Some(lease))
134    }
135
136    /// Complete task: POST /tasks/{id}/complete
137    pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
138        let url = self
139            .join_segments(&["tasks", &task_id.to_string(), "complete"])
140            .context("join /complete")?;
141        let mut headers = self.auth_headers().await?;
142        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
143        if let Some(preview) = json_debug_preview(body) {
144            tracing::debug!(
145                endpoint = %url,
146                task_id = %task_id,
147                body = %preview,
148                "Sending DMS complete request"
149            );
150        }
151        // First attempt
152        let mut res = self
153            .http
154            .post(url.clone())
155            .headers(headers.clone())
156            .json(body)
157            .send()
158            .await
159            .context("send POST /complete")?;
160        let mut status = res.status();
161        let mut body_text = res
162            .text()
163            .await
164            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
165        if status == StatusCode::UNAUTHORIZED {
166            let preview = truncate_preview(&body_text);
167            tracing::warn!(
168                status = %status,
169                body = %preview,
170                task_id = %task_id,
171                "DMS complete unauthorized; refreshing token and retrying"
172            );
173            self.auth.on_unauthorized().await;
174            headers = self.auth_headers().await?;
175            res = self
176                .http
177                .post(url)
178                .headers(headers)
179                .json(body)
180                .send()
181                .await
182                .context("retry POST /complete")?;
183            status = res.status();
184            body_text = res
185                .text()
186                .await
187                .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
188        }
189        let preview = truncate_preview(&body_text);
190        if tracing::enabled!(Level::DEBUG) {
191            tracing::debug!(
192                status = %status,
193                body = %preview,
194                task_id = %task_id,
195                "DMS complete response"
196            );
197        }
198        if !status.is_success() {
199            tracing::error!(
200                status = %status,
201                body = %preview,
202                task_id = %task_id,
203                "DMS complete endpoint returned non-success status"
204            );
205            return Err(anyhow!(
206                "POST /tasks/{task_id}/complete status {status}; body: {preview}"
207            ));
208        }
209        Ok(())
210    }
211
212    /// Fail task: POST /tasks/{id}/fail
213    pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
214        let url = self
215            .join_segments(&["tasks", &task_id.to_string(), "fail"])
216            .context("join /fail")?;
217        let mut headers = self.auth_headers().await?;
218        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
219        if let Some(preview) = json_debug_preview(body) {
220            tracing::debug!(
221                endpoint = %url,
222                task_id = %task_id,
223                body = %preview,
224                "Sending DMS fail request"
225            );
226        }
227        // First attempt
228        let mut res = self
229            .http
230            .post(url.clone())
231            .headers(headers.clone())
232            .json(body)
233            .send()
234            .await
235            .context("send POST /fail")?;
236        let mut status = res.status();
237        let mut body_text = res
238            .text()
239            .await
240            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
241        if status == StatusCode::UNAUTHORIZED {
242            let preview = truncate_preview(&body_text);
243            tracing::warn!(
244                status = %status,
245                body = %preview,
246                task_id = %task_id,
247                "DMS fail unauthorized; refreshing token and retrying"
248            );
249            self.auth.on_unauthorized().await;
250            headers = self.auth_headers().await?;
251            res = self
252                .http
253                .post(url)
254                .headers(headers)
255                .json(body)
256                .send()
257                .await
258                .context("retry POST /fail")?;
259            status = res.status();
260            body_text = res
261                .text()
262                .await
263                .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
264        }
265        let preview = truncate_preview(&body_text);
266        if tracing::enabled!(Level::DEBUG) {
267            tracing::debug!(
268                status = %status,
269                body = %preview,
270                task_id = %task_id,
271                "DMS fail response"
272            );
273        }
274        if !status.is_success() {
275            tracing::error!(
276                status = %status,
277                body = %preview,
278                task_id = %task_id,
279                "DMS fail endpoint returned non-success status"
280            );
281            return Err(anyhow!(
282                "POST /tasks/{task_id}/fail status {status}; body: {preview}"
283            ));
284        }
285        Ok(())
286    }
287
288    /// Heartbeat: POST /tasks/{id}/heartbeat with progress payload.
289    /// Returns potential new access token for storage.
290    pub async fn heartbeat(
291        &self,
292        task_id: Uuid,
293        body: &HeartbeatRequest,
294    ) -> Result<HeartbeatResponse> {
295        let url = self
296            .join_segments(&["tasks", &task_id.to_string(), "heartbeat"])
297            .context("join /heartbeat")?;
298        let mut headers = self.auth_headers().await?;
299        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
300        if let Some(preview) = json_debug_preview(body) {
301            tracing::debug!(
302                endpoint = %url,
303                task_id = %task_id,
304                body = %preview,
305                "Sending DMS heartbeat request"
306            );
307        }
308        // First attempt
309        let mut res = self
310            .http
311            .post(url.clone())
312            .headers(headers.clone())
313            .json(body)
314            .send()
315            .await
316            .context("send POST /heartbeat")?;
317        let mut status = res.status();
318        let mut bytes = res.bytes().await.context("read heartbeat response body")?;
319        if status == StatusCode::UNAUTHORIZED {
320            let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
321            tracing::warn!(
322                status = %status,
323                body = %preview,
324                task_id = %task_id,
325                "DMS heartbeat unauthorized; refreshing token and retrying"
326            );
327            self.auth.on_unauthorized().await;
328            headers = self.auth_headers().await?;
329            res = self
330                .http
331                .post(url)
332                .headers(headers)
333                .json(body)
334                .send()
335                .await
336                .context("retry POST /heartbeat")?;
337            status = res.status();
338            bytes = res
339                .bytes()
340                .await
341                .context("read heartbeat response body (retry)")?;
342        }
343        let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
344        if tracing::enabled!(Level::DEBUG) {
345            tracing::debug!(
346                status = %status,
347                body = %preview,
348                task_id = %task_id,
349                "DMS heartbeat response"
350            );
351        }
352        if !status.is_success() {
353            return Err(anyhow!(
354                "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
355            ));
356        }
357        let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
358            .context("decode heartbeat response")?;
359        Ok(hb)
360    }
361}
362
/// Shorten `body` for use in log lines and error messages.
///
/// Bodies of at most 512 characters are returned unchanged. Longer ones are cut to
/// their first 512 characters with `"… (truncated)"` appended. The limit is counted
/// in `char`s, not bytes, for both the check and the cut. Multi-byte UTF-8 text is
/// therefore never split inside a character, and is never marked as truncated when
/// nothing was removed.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    const SUFFIX: &str = "… (truncated)";
    // Byte offset of the first character beyond the limit, if there is one.
    match body.char_indices().nth(MAX) {
        None => body.to_string(),
        Some((cut, _)) => {
            let mut preview = String::with_capacity(cut + SUFFIX.len());
            preview.push_str(&body[..cut]);
            preview.push_str(SUFFIX);
            preview
        }
    }
}
373
374fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
375    if !tracing::enabled!(Level::DEBUG) {
376        return None;
377    }
378    serde_json::to_string(value)
379        .map(|s| truncate_preview(&s))
380        .ok()
381}