posemesh_compute_node/dms/
client.rs1use crate::auth::token_manager::TokenProvider;
2use crate::dms::types::{
3 CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
4};
5use anyhow::{anyhow, Context, Result};
6use reqwest::Client;
7use reqwest::{
8 header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
9 StatusCode,
10};
11use serde::Serialize;
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::Level;
15use url::Url;
16use uuid::Uuid;
17
18#[derive(Clone)]
20pub struct DmsClient {
21 base: Url,
22 http: Client,
23 auth: Arc<dyn TokenProvider>,
24}
25impl DmsClient {
26 pub fn new(base: Url, timeout: Duration, auth: Arc<dyn TokenProvider>) -> Result<Self> {
28 let http = Client::builder()
29 .use_rustls_tls()
30 .timeout(timeout)
31 .build()
32 .context("build dms reqwest client")?;
33 Ok(Self { base, http, auth })
34 }
35
36 async fn auth_headers(&self) -> Result<HeaderMap> {
37 let mut h = HeaderMap::new();
38 let b = self
39 .auth
40 .bearer()
41 .await
42 .map_err(|e| anyhow!("token provider: {e}"))?;
43 let token = format!("Bearer {}", b);
44 let mut v = HeaderValue::from_str(&token)
45 .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
46 v.set_sensitive(true);
47 h.insert(AUTHORIZATION, v);
48 Ok(h)
49 }
50
51 fn join_segments(&self, segments: &[&str]) -> Result<Url> {
52 let mut url = self.base.clone();
53 url.path_segments_mut()
54 .map_err(|_| anyhow!("invalid DMS base URL; cannot be a base"))?
55 .extend(segments.iter().copied());
56 Ok(url)
57 }
58
59 pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
63 let url = self.join_segments(&["tasks"]).context("join /tasks")?;
64 if tracing::enabled!(Level::DEBUG) {
65 tracing::debug!(
66 endpoint = %url,
67 "Sending DMS lease request"
68 );
69 }
70 let mut headers = self.auth_headers().await?;
72 let mut res = self
73 .http
74 .get(url.clone())
75 .headers(headers.clone())
76 .send()
77 .await
78 .context("send GET /tasks")?;
79 let mut status = res.status();
80 let mut bytes = res.bytes().await.context("read lease body")?;
81 if status == StatusCode::UNAUTHORIZED {
83 let body_preview = String::from_utf8_lossy(&bytes);
84 tracing::warn!(
85 status = %status,
86 body = %body_preview,
87 "DMS lease unauthorized; refreshing token and retrying"
88 );
89 self.auth.on_unauthorized().await;
90 headers = self.auth_headers().await?;
91 res = self
92 .http
93 .get(url)
94 .headers(headers)
95 .send()
96 .await
97 .context("retry GET /tasks")?;
98 status = res.status();
99 bytes = res.bytes().await.context("read lease body (retry)")?;
100 }
101 if status == StatusCode::NO_CONTENT {
102 tracing::debug!("DMS lease returned 204 (no work available)");
103 return Ok(None);
104 }
105 let body_preview = String::from_utf8_lossy(&bytes);
106 if !status.is_success() {
107 tracing::warn!(
108 status = %status,
109 body = %body_preview,
110 "DMS lease request returned non-success status"
111 );
112 return Err(anyhow!("/tasks status: {}", status));
113 }
114 let lease: LeaseResponse = serde_json::from_slice(&bytes)
115 .map_err(|err| {
116 tracing::error!(
117 error = %err,
118 body = %body_preview,
119 "Failed to decode DMS lease response"
120 );
121 err
122 })
123 .context("decode lease")?;
124
125 if tracing::enabled!(Level::DEBUG) {
126 tracing::debug!(
127 status = %status,
128 body = %body_preview,
129 "Decoded DMS lease response"
130 );
131 }
132
133 Ok(Some(lease))
134 }
135
136 pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
138 let url = self
139 .join_segments(&["tasks", &task_id.to_string(), "complete"])
140 .context("join /complete")?;
141 let mut headers = self.auth_headers().await?;
142 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
143 if let Some(preview) = json_debug_preview(body) {
144 tracing::debug!(
145 endpoint = %url,
146 task_id = %task_id,
147 body = %preview,
148 "Sending DMS complete request"
149 );
150 }
151 let mut res = self
153 .http
154 .post(url.clone())
155 .headers(headers.clone())
156 .json(body)
157 .send()
158 .await
159 .context("send POST /complete")?;
160 let mut status = res.status();
161 let mut body_text = res
162 .text()
163 .await
164 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
165 if status == StatusCode::UNAUTHORIZED {
166 let preview = truncate_preview(&body_text);
167 tracing::warn!(
168 status = %status,
169 body = %preview,
170 task_id = %task_id,
171 "DMS complete unauthorized; refreshing token and retrying"
172 );
173 self.auth.on_unauthorized().await;
174 headers = self.auth_headers().await?;
175 res = self
176 .http
177 .post(url)
178 .headers(headers)
179 .json(body)
180 .send()
181 .await
182 .context("retry POST /complete")?;
183 status = res.status();
184 body_text = res
185 .text()
186 .await
187 .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
188 }
189 let preview = truncate_preview(&body_text);
190 if tracing::enabled!(Level::DEBUG) {
191 tracing::debug!(
192 status = %status,
193 body = %preview,
194 task_id = %task_id,
195 "DMS complete response"
196 );
197 }
198 if !status.is_success() {
199 tracing::error!(
200 status = %status,
201 body = %preview,
202 task_id = %task_id,
203 "DMS complete endpoint returned non-success status"
204 );
205 return Err(anyhow!(
206 "POST /tasks/{task_id}/complete status {status}; body: {preview}"
207 ));
208 }
209 Ok(())
210 }
211
212 pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
214 let url = self
215 .join_segments(&["tasks", &task_id.to_string(), "fail"])
216 .context("join /fail")?;
217 let mut headers = self.auth_headers().await?;
218 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
219 if let Some(preview) = json_debug_preview(body) {
220 tracing::debug!(
221 endpoint = %url,
222 task_id = %task_id,
223 body = %preview,
224 "Sending DMS fail request"
225 );
226 }
227 let mut res = self
229 .http
230 .post(url.clone())
231 .headers(headers.clone())
232 .json(body)
233 .send()
234 .await
235 .context("send POST /fail")?;
236 let mut status = res.status();
237 let mut body_text = res
238 .text()
239 .await
240 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
241 if status == StatusCode::UNAUTHORIZED {
242 let preview = truncate_preview(&body_text);
243 tracing::warn!(
244 status = %status,
245 body = %preview,
246 task_id = %task_id,
247 "DMS fail unauthorized; refreshing token and retrying"
248 );
249 self.auth.on_unauthorized().await;
250 headers = self.auth_headers().await?;
251 res = self
252 .http
253 .post(url)
254 .headers(headers)
255 .json(body)
256 .send()
257 .await
258 .context("retry POST /fail")?;
259 status = res.status();
260 body_text = res
261 .text()
262 .await
263 .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
264 }
265 let preview = truncate_preview(&body_text);
266 if tracing::enabled!(Level::DEBUG) {
267 tracing::debug!(
268 status = %status,
269 body = %preview,
270 task_id = %task_id,
271 "DMS fail response"
272 );
273 }
274 if !status.is_success() {
275 tracing::error!(
276 status = %status,
277 body = %preview,
278 task_id = %task_id,
279 "DMS fail endpoint returned non-success status"
280 );
281 return Err(anyhow!(
282 "POST /tasks/{task_id}/fail status {status}; body: {preview}"
283 ));
284 }
285 Ok(())
286 }
287
288 pub async fn heartbeat(
291 &self,
292 task_id: Uuid,
293 body: &HeartbeatRequest,
294 ) -> Result<HeartbeatResponse> {
295 let url = self
296 .join_segments(&["tasks", &task_id.to_string(), "heartbeat"])
297 .context("join /heartbeat")?;
298 let mut headers = self.auth_headers().await?;
299 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
300 if let Some(preview) = json_debug_preview(body) {
301 tracing::debug!(
302 endpoint = %url,
303 task_id = %task_id,
304 body = %preview,
305 "Sending DMS heartbeat request"
306 );
307 }
308 let mut res = self
310 .http
311 .post(url.clone())
312 .headers(headers.clone())
313 .json(body)
314 .send()
315 .await
316 .context("send POST /heartbeat")?;
317 let mut status = res.status();
318 let mut bytes = res.bytes().await.context("read heartbeat response body")?;
319 if status == StatusCode::UNAUTHORIZED {
320 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
321 tracing::warn!(
322 status = %status,
323 body = %preview,
324 task_id = %task_id,
325 "DMS heartbeat unauthorized; refreshing token and retrying"
326 );
327 self.auth.on_unauthorized().await;
328 headers = self.auth_headers().await?;
329 res = self
330 .http
331 .post(url)
332 .headers(headers)
333 .json(body)
334 .send()
335 .await
336 .context("retry POST /heartbeat")?;
337 status = res.status();
338 bytes = res
339 .bytes()
340 .await
341 .context("read heartbeat response body (retry)")?;
342 }
343 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
344 if tracing::enabled!(Level::DEBUG) {
345 tracing::debug!(
346 status = %status,
347 body = %preview,
348 task_id = %task_id,
349 "DMS heartbeat response"
350 );
351 }
352 if !status.is_success() {
353 return Err(anyhow!(
354 "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
355 ));
356 }
357 let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
358 .context("decode heartbeat response")?;
359 Ok(hb)
360 }
361}
362
/// Shortens `body` for logs and error messages.
///
/// Returns `body` unchanged when it has at most 512 characters. Otherwise returns its
/// first 512 characters followed by `"… (truncated)"`. The limit counts `char`s, not
/// bytes, so multi-byte text is never cut mid-character and is never marked as
/// truncated when nothing was removed.
fn truncate_preview(body: &str) -> String {
    const MAX_CHARS: usize = 512;
    match body.char_indices().nth(MAX_CHARS) {
        // Byte offset of the first character beyond the limit; slicing there always
        // lands on a char boundary.
        Some((cut, _)) => format!("{}… (truncated)", &body[..cut]),
        None => body.to_string(),
    }
}
373
374fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
375 if !tracing::enabled!(Level::DEBUG) {
376 return None;
377 }
378 serde_json::to_string(value)
379 .map(|s| truncate_preview(&s))
380 .ok()
381}