posemesh_compute_node/dms/
client.rs1use crate::auth::token_manager::TokenProvider;
2use crate::dms::types::{
3 CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
4};
5use anyhow::{anyhow, Context, Result};
6use reqwest::Client;
7use reqwest::{
8 header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
9 StatusCode,
10};
11use serde::Serialize;
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::Level;
15use url::Url;
16use uuid::Uuid;
17
/// HTTP client for the DMS task endpoints (`/tasks` and `/tasks/{id}/{complete,fail,heartbeat}`).
///
/// Cloning is supported via `#[derive(Clone)]`; the token provider is held behind an `Arc`.
#[derive(Clone)]
pub struct DmsClient {
    /// Base URL that endpoint path segments are appended to.
    base: Url,
    /// reqwest client used for every request issued by this type.
    http: Client,
    /// Source of bearer tokens; also notified when the server answers 401.
    auth: Arc<dyn TokenProvider>,
}
25impl DmsClient {
26 pub fn new(base: Url, timeout: Duration, auth: Arc<dyn TokenProvider>) -> Result<Self> {
28 let http = Client::builder()
29 .use_rustls_tls()
30 .timeout(timeout)
31 .build()
32 .context("build dms reqwest client")?;
33 Ok(Self { base, http, auth })
34 }
35
36 async fn auth_headers(&self) -> Result<HeaderMap> {
37 let mut h = HeaderMap::new();
38 let b = self
39 .auth
40 .bearer()
41 .await
42 .map_err(|e| anyhow!("token provider: {e}"))?;
43 let token = format!("Bearer {}", b);
44 let mut v = HeaderValue::from_str(&token)
45 .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
46 v.set_sensitive(true);
47 h.insert(AUTHORIZATION, v);
48 Ok(h)
49 }
50
51 fn join_segments(&self, segments: &[&str]) -> Result<Url> {
52 let mut url = self.base.clone();
53 url.path_segments_mut()
54 .map_err(|_| anyhow!("invalid DMS base URL; cannot be a base"))?
55 .extend(segments.iter().copied());
56 Ok(url)
57 }
58
59 pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
63 let url = self.join_segments(&["tasks"]).context("join /tasks")?;
64 if tracing::enabled!(Level::DEBUG) {
65 tracing::debug!(
66 endpoint = %url,
67 "Sending DMS lease request"
68 );
69 }
70 let mut headers = self.auth_headers().await?;
72 let mut res = self
73 .http
74 .get(url.clone())
75 .headers(headers.clone())
76 .send()
77 .await
78 .context("send GET /tasks")?;
79 let mut status = res.status();
80 let mut bytes = res.bytes().await.context("read lease body")?;
81 if status == StatusCode::UNAUTHORIZED {
83 let body_preview = String::from_utf8_lossy(&bytes);
84 tracing::warn!(
85 status = %status,
86 body = %body_preview,
87 "DMS lease unauthorized; refreshing token and retrying"
88 );
89 self.auth.on_unauthorized().await;
90 headers = self.auth_headers().await?;
91 res = self
92 .http
93 .get(url)
94 .headers(headers)
95 .send()
96 .await
97 .context("retry GET /tasks")?;
98 status = res.status();
99 bytes = res.bytes().await.context("read lease body (retry)")?;
100 }
101 if status == StatusCode::NO_CONTENT {
102 tracing::debug!("DMS lease returned 204 (no work available)");
103 return Ok(None);
104 }
105 let body_preview = String::from_utf8_lossy(&bytes);
106 if status == StatusCode::CONFLICT {
107 tracing::debug!(
108 status = %status,
109 body = %body_preview,
110 "DMS lease returned conflict (busy); treating as no work"
111 );
112 return Ok(None);
113 }
114 if !status.is_success() {
115 tracing::warn!(
116 status = %status,
117 body = %body_preview,
118 "DMS lease request returned non-success status"
119 );
120 return Err(anyhow!("/tasks status: {}", status));
121 }
122 let lease: LeaseResponse = serde_json::from_slice(&bytes)
123 .map_err(|err| {
124 tracing::error!(
125 error = %err,
126 body = %body_preview,
127 "Failed to decode DMS lease response"
128 );
129 err
130 })
131 .context("decode lease")?;
132
133 if tracing::enabled!(Level::DEBUG) {
134 tracing::debug!(
135 status = %status,
136 body = %body_preview,
137 "Decoded DMS lease response"
138 );
139 }
140
141 Ok(Some(lease))
142 }
143
144 pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
146 let url = self
147 .join_segments(&["tasks", &task_id.to_string(), "complete"])
148 .context("join /complete")?;
149 let mut headers = self.auth_headers().await?;
150 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
151 if let Some(preview) = json_debug_preview(body) {
152 tracing::debug!(
153 endpoint = %url,
154 task_id = %task_id,
155 body = %preview,
156 "Sending DMS complete request"
157 );
158 }
159 let mut res = self
161 .http
162 .post(url.clone())
163 .headers(headers.clone())
164 .json(body)
165 .send()
166 .await
167 .context("send POST /complete")?;
168 let mut status = res.status();
169 let mut body_text = res
170 .text()
171 .await
172 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
173 if status == StatusCode::UNAUTHORIZED {
174 let preview = truncate_preview(&body_text);
175 tracing::warn!(
176 status = %status,
177 body = %preview,
178 task_id = %task_id,
179 "DMS complete unauthorized; refreshing token and retrying"
180 );
181 self.auth.on_unauthorized().await;
182 headers = self.auth_headers().await?;
183 res = self
184 .http
185 .post(url)
186 .headers(headers)
187 .json(body)
188 .send()
189 .await
190 .context("retry POST /complete")?;
191 status = res.status();
192 body_text = res
193 .text()
194 .await
195 .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
196 }
197 let preview = truncate_preview(&body_text);
198 if tracing::enabled!(Level::DEBUG) {
199 tracing::debug!(
200 status = %status,
201 body = %preview,
202 task_id = %task_id,
203 "DMS complete response"
204 );
205 }
206 if !status.is_success() {
207 tracing::error!(
208 status = %status,
209 body = %preview,
210 task_id = %task_id,
211 "DMS complete endpoint returned non-success status"
212 );
213 return Err(anyhow!(
214 "POST /tasks/{task_id}/complete status {status}; body: {preview}"
215 ));
216 }
217 Ok(())
218 }
219
220 pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
222 let url = self
223 .join_segments(&["tasks", &task_id.to_string(), "fail"])
224 .context("join /fail")?;
225 let mut headers = self.auth_headers().await?;
226 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
227 if let Some(preview) = json_debug_preview(body) {
228 tracing::debug!(
229 endpoint = %url,
230 task_id = %task_id,
231 body = %preview,
232 "Sending DMS fail request"
233 );
234 }
235 let mut res = self
237 .http
238 .post(url.clone())
239 .headers(headers.clone())
240 .json(body)
241 .send()
242 .await
243 .context("send POST /fail")?;
244 let mut status = res.status();
245 let mut body_text = res
246 .text()
247 .await
248 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
249 if status == StatusCode::UNAUTHORIZED {
250 let preview = truncate_preview(&body_text);
251 tracing::warn!(
252 status = %status,
253 body = %preview,
254 task_id = %task_id,
255 "DMS fail unauthorized; refreshing token and retrying"
256 );
257 self.auth.on_unauthorized().await;
258 headers = self.auth_headers().await?;
259 res = self
260 .http
261 .post(url)
262 .headers(headers)
263 .json(body)
264 .send()
265 .await
266 .context("retry POST /fail")?;
267 status = res.status();
268 body_text = res
269 .text()
270 .await
271 .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
272 }
273 let preview = truncate_preview(&body_text);
274 if tracing::enabled!(Level::DEBUG) {
275 tracing::debug!(
276 status = %status,
277 body = %preview,
278 task_id = %task_id,
279 "DMS fail response"
280 );
281 }
282 if !status.is_success() {
283 tracing::error!(
284 status = %status,
285 body = %preview,
286 task_id = %task_id,
287 "DMS fail endpoint returned non-success status"
288 );
289 return Err(anyhow!(
290 "POST /tasks/{task_id}/fail status {status}; body: {preview}"
291 ));
292 }
293 Ok(())
294 }
295
296 pub async fn heartbeat(
299 &self,
300 task_id: Uuid,
301 body: &HeartbeatRequest,
302 ) -> Result<HeartbeatResponse> {
303 let url = self
304 .join_segments(&["tasks", &task_id.to_string(), "heartbeat"])
305 .context("join /heartbeat")?;
306 let mut headers = self.auth_headers().await?;
307 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
308 if let Some(preview) = json_debug_preview(body) {
309 tracing::debug!(
310 endpoint = %url,
311 task_id = %task_id,
312 body = %preview,
313 "Sending DMS heartbeat request"
314 );
315 }
316 let mut res = self
318 .http
319 .post(url.clone())
320 .headers(headers.clone())
321 .json(body)
322 .send()
323 .await
324 .context("send POST /heartbeat")?;
325 let mut status = res.status();
326 let mut bytes = res.bytes().await.context("read heartbeat response body")?;
327 if status == StatusCode::UNAUTHORIZED {
328 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
329 tracing::warn!(
330 status = %status,
331 body = %preview,
332 task_id = %task_id,
333 "DMS heartbeat unauthorized; refreshing token and retrying"
334 );
335 self.auth.on_unauthorized().await;
336 headers = self.auth_headers().await?;
337 res = self
338 .http
339 .post(url)
340 .headers(headers)
341 .json(body)
342 .send()
343 .await
344 .context("retry POST /heartbeat")?;
345 status = res.status();
346 bytes = res
347 .bytes()
348 .await
349 .context("read heartbeat response body (retry)")?;
350 }
351 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
352 if tracing::enabled!(Level::DEBUG) {
353 tracing::debug!(
354 status = %status,
355 body = %preview,
356 task_id = %task_id,
357 "DMS heartbeat response"
358 );
359 }
360 if !status.is_success() {
361 return Err(anyhow!(
362 "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
363 ));
364 }
365 let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
366 .context("decode heartbeat response")?;
367 Ok(hb)
368 }
369}
370
/// Returns `body` shortened for logging.
///
/// At most 512 characters are kept. The suffix `… (truncated)` is appended only when
/// characters were actually dropped, so multi-byte text that is longer than 512 bytes but
/// not longer than 512 characters is returned unchanged.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    const MARKER: &str = "… (truncated)";

    // Fast path: a string of at most MAX bytes cannot hold more than MAX characters.
    if body.len() <= MAX {
        return body.to_string();
    }
    // Byte offset of the first character past the limit, if there is one.
    match body.char_indices().nth(MAX) {
        None => body.to_string(),
        Some((cut, _)) => {
            let mut preview = String::with_capacity(cut + MARKER.len());
            preview.push_str(&body[..cut]);
            preview.push_str(MARKER);
            preview
        }
    }
}
381
382fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
383 if !tracing::enabled!(Level::DEBUG) {
384 return None;
385 }
386 serde_json::to_string(value)
387 .map(|s| truncate_preview(&s))
388 .ok()
389}