Skip to main content

posemesh_compute_node/dms/
client.rs

1use crate::auth::token_manager::TokenProvider;
2use crate::dms::types::{
3    CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
4};
5use anyhow::{anyhow, Context, Result};
6use reqwest::Client;
7use reqwest::{
8    header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
9    StatusCode,
10};
11use serde::Serialize;
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::Level;
15use url::Url;
16use uuid::Uuid;
17
18/// Minimal DMS HTTP client using rustls with sensitive Authorization header.
19#[derive(Clone)]
20pub struct DmsClient {
21    base: Url,
22    http: Client,
23    auth: Arc<dyn TokenProvider>,
24}
25impl DmsClient {
26    /// Create client with base URL, timeout, and a token provider for Authorization.
27    pub fn new(base: Url, timeout: Duration, auth: Arc<dyn TokenProvider>) -> Result<Self> {
28        let http = Client::builder()
29            .use_rustls_tls()
30            .timeout(timeout)
31            .build()
32            .context("build dms reqwest client")?;
33        Ok(Self { base, http, auth })
34    }
35
36    async fn auth_headers(&self) -> Result<HeaderMap> {
37        let mut h = HeaderMap::new();
38        let b = self
39            .auth
40            .bearer()
41            .await
42            .map_err(|e| anyhow!("token provider: {e}"))?;
43        let token = format!("Bearer {}", b);
44        let mut v = HeaderValue::from_str(&token)
45            .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
46        v.set_sensitive(true);
47        h.insert(AUTHORIZATION, v);
48        Ok(h)
49    }
50
51    fn join_segments(&self, segments: &[&str]) -> Result<Url> {
52        let mut url = self.base.clone();
53        url.path_segments_mut()
54            .map_err(|_| anyhow!("invalid DMS base URL; cannot be a base"))?
55            .extend(segments.iter().copied());
56        Ok(url)
57    }
58
59    /// Lease a task: GET /tasks
60    ///
61    /// `capability` is accepted for optional filter but not implemented yet.
62    pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
63        let url = self.join_segments(&["tasks"]).context("join /tasks")?;
64        if tracing::enabled!(Level::DEBUG) {
65            tracing::debug!(
66                endpoint = %url,
67                "Sending DMS lease request"
68            );
69        }
70        // First attempt
71        let mut headers = self.auth_headers().await?;
72        let mut res = self
73            .http
74            .get(url.clone())
75            .headers(headers.clone())
76            .send()
77            .await
78            .context("send GET /tasks")?;
79        let mut status = res.status();
80        let mut bytes = res.bytes().await.context("read lease body")?;
81        // Retry once on 401
82        if status == StatusCode::UNAUTHORIZED {
83            let body_preview = String::from_utf8_lossy(&bytes);
84            tracing::warn!(
85                status = %status,
86                body = %body_preview,
87                "DMS lease unauthorized; refreshing token and retrying"
88            );
89            self.auth.on_unauthorized().await;
90            headers = self.auth_headers().await?;
91            res = self
92                .http
93                .get(url)
94                .headers(headers)
95                .send()
96                .await
97                .context("retry GET /tasks")?;
98            status = res.status();
99            bytes = res.bytes().await.context("read lease body (retry)")?;
100        }
101        if status == StatusCode::NO_CONTENT {
102            tracing::debug!("DMS lease returned 204 (no work available)");
103            return Ok(None);
104        }
105        let body_preview = String::from_utf8_lossy(&bytes);
106        if status == StatusCode::CONFLICT {
107            tracing::debug!(
108                status = %status,
109                body = %body_preview,
110                "DMS lease returned conflict (busy); treating as no work"
111            );
112            return Ok(None);
113        }
114        if !status.is_success() {
115            tracing::warn!(
116                status = %status,
117                body = %body_preview,
118                "DMS lease request returned non-success status"
119            );
120            return Err(anyhow!("/tasks status: {}", status));
121        }
122        let lease: LeaseResponse = serde_json::from_slice(&bytes)
123            .map_err(|err| {
124                tracing::error!(
125                    error = %err,
126                    body = %body_preview,
127                    "Failed to decode DMS lease response"
128                );
129                err
130            })
131            .context("decode lease")?;
132
133        if tracing::enabled!(Level::DEBUG) {
134            tracing::debug!(
135                status = %status,
136                body = %body_preview,
137                "Decoded DMS lease response"
138            );
139        }
140
141        Ok(Some(lease))
142    }
143
144    /// Complete task: POST /tasks/{id}/complete
145    pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
146        let url = self
147            .join_segments(&["tasks", &task_id.to_string(), "complete"])
148            .context("join /complete")?;
149        let mut headers = self.auth_headers().await?;
150        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
151        if let Some(preview) = json_debug_preview(body) {
152            tracing::debug!(
153                endpoint = %url,
154                task_id = %task_id,
155                body = %preview,
156                "Sending DMS complete request"
157            );
158        }
159        // First attempt
160        let mut res = self
161            .http
162            .post(url.clone())
163            .headers(headers.clone())
164            .json(body)
165            .send()
166            .await
167            .context("send POST /complete")?;
168        let mut status = res.status();
169        let mut body_text = res
170            .text()
171            .await
172            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
173        if status == StatusCode::UNAUTHORIZED {
174            let preview = truncate_preview(&body_text);
175            tracing::warn!(
176                status = %status,
177                body = %preview,
178                task_id = %task_id,
179                "DMS complete unauthorized; refreshing token and retrying"
180            );
181            self.auth.on_unauthorized().await;
182            headers = self.auth_headers().await?;
183            res = self
184                .http
185                .post(url)
186                .headers(headers)
187                .json(body)
188                .send()
189                .await
190                .context("retry POST /complete")?;
191            status = res.status();
192            body_text = res
193                .text()
194                .await
195                .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
196        }
197        let preview = truncate_preview(&body_text);
198        if tracing::enabled!(Level::DEBUG) {
199            tracing::debug!(
200                status = %status,
201                body = %preview,
202                task_id = %task_id,
203                "DMS complete response"
204            );
205        }
206        if !status.is_success() {
207            tracing::error!(
208                status = %status,
209                body = %preview,
210                task_id = %task_id,
211                "DMS complete endpoint returned non-success status"
212            );
213            return Err(anyhow!(
214                "POST /tasks/{task_id}/complete status {status}; body: {preview}"
215            ));
216        }
217        Ok(())
218    }
219
220    /// Fail task: POST /tasks/{id}/fail
221    pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
222        let url = self
223            .join_segments(&["tasks", &task_id.to_string(), "fail"])
224            .context("join /fail")?;
225        let mut headers = self.auth_headers().await?;
226        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
227        if let Some(preview) = json_debug_preview(body) {
228            tracing::debug!(
229                endpoint = %url,
230                task_id = %task_id,
231                body = %preview,
232                "Sending DMS fail request"
233            );
234        }
235        // First attempt
236        let mut res = self
237            .http
238            .post(url.clone())
239            .headers(headers.clone())
240            .json(body)
241            .send()
242            .await
243            .context("send POST /fail")?;
244        let mut status = res.status();
245        let mut body_text = res
246            .text()
247            .await
248            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
249        if status == StatusCode::UNAUTHORIZED {
250            let preview = truncate_preview(&body_text);
251            tracing::warn!(
252                status = %status,
253                body = %preview,
254                task_id = %task_id,
255                "DMS fail unauthorized; refreshing token and retrying"
256            );
257            self.auth.on_unauthorized().await;
258            headers = self.auth_headers().await?;
259            res = self
260                .http
261                .post(url)
262                .headers(headers)
263                .json(body)
264                .send()
265                .await
266                .context("retry POST /fail")?;
267            status = res.status();
268            body_text = res
269                .text()
270                .await
271                .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
272        }
273        let preview = truncate_preview(&body_text);
274        if tracing::enabled!(Level::DEBUG) {
275            tracing::debug!(
276                status = %status,
277                body = %preview,
278                task_id = %task_id,
279                "DMS fail response"
280            );
281        }
282        if !status.is_success() {
283            tracing::error!(
284                status = %status,
285                body = %preview,
286                task_id = %task_id,
287                "DMS fail endpoint returned non-success status"
288            );
289            return Err(anyhow!(
290                "POST /tasks/{task_id}/fail status {status}; body: {preview}"
291            ));
292        }
293        Ok(())
294    }
295
296    /// Heartbeat: POST /tasks/{id}/heartbeat with progress payload.
297    /// Returns potential new access token for storage.
298    pub async fn heartbeat(
299        &self,
300        task_id: Uuid,
301        body: &HeartbeatRequest,
302    ) -> Result<HeartbeatResponse> {
303        let url = self
304            .join_segments(&["tasks", &task_id.to_string(), "heartbeat"])
305            .context("join /heartbeat")?;
306        let mut headers = self.auth_headers().await?;
307        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
308        if let Some(preview) = json_debug_preview(body) {
309            tracing::debug!(
310                endpoint = %url,
311                task_id = %task_id,
312                body = %preview,
313                "Sending DMS heartbeat request"
314            );
315        }
316        // First attempt
317        let mut res = self
318            .http
319            .post(url.clone())
320            .headers(headers.clone())
321            .json(body)
322            .send()
323            .await
324            .context("send POST /heartbeat")?;
325        let mut status = res.status();
326        let mut bytes = res.bytes().await.context("read heartbeat response body")?;
327        if status == StatusCode::UNAUTHORIZED {
328            let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
329            tracing::warn!(
330                status = %status,
331                body = %preview,
332                task_id = %task_id,
333                "DMS heartbeat unauthorized; refreshing token and retrying"
334            );
335            self.auth.on_unauthorized().await;
336            headers = self.auth_headers().await?;
337            res = self
338                .http
339                .post(url)
340                .headers(headers)
341                .json(body)
342                .send()
343                .await
344                .context("retry POST /heartbeat")?;
345            status = res.status();
346            bytes = res
347                .bytes()
348                .await
349                .context("read heartbeat response body (retry)")?;
350        }
351        let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
352        if tracing::enabled!(Level::DEBUG) {
353            tracing::debug!(
354                status = %status,
355                body = %preview,
356                task_id = %task_id,
357                "DMS heartbeat response"
358            );
359        }
360        if !status.is_success() {
361            return Err(anyhow!(
362                "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
363            ));
364        }
365        let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
366            .context("decode heartbeat response")?;
367        Ok(hb)
368    }
369}
370
/// Bound a request/response body for logging.
///
/// Returns `body` unchanged when it has at most `MAX` characters. Otherwise it
/// returns the first `MAX` characters followed by `"… (truncated)"`. The limit
/// counts `char`s rather than bytes, so the cut always falls on a UTF-8
/// boundary. The marker is only added when something was actually dropped.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    // Byte offset of the first character beyond the limit, if any exists.
    match body.char_indices().nth(MAX) {
        None => body.to_string(),
        Some((cut, _)) => format!("{}… (truncated)", &body[..cut]),
    }
}
381
382fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
383    if !tracing::enabled!(Level::DEBUG) {
384        return None;
385    }
386    serde_json::to_string(value)
387        .map(|s| truncate_preview(&s))
388        .ok()
389}