posemesh_compute_node/dms/
client.rs1use crate::auth::token_manager::TokenProvider;
2use crate::dms::types::{
3 CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
4};
5use anyhow::{anyhow, Context, Result};
6use reqwest::Client;
7use reqwest::{
8 header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
9 StatusCode,
10};
11use serde::Serialize;
12use std::sync::Arc;
13use std::time::Duration;
14use tracing::Level;
15use url::Url;
16use uuid::Uuid;
17
18#[derive(Clone)]
20pub struct DmsClient {
21 base: Url,
22 http: Client,
23 auth: Arc<dyn TokenProvider>,
24}
25impl DmsClient {
26 pub fn new(base: Url, timeout: Duration, auth: Arc<dyn TokenProvider>) -> Result<Self> {
28 let http = Client::builder()
29 .use_rustls_tls()
30 .timeout(timeout)
31 .build()
32 .context("build dms reqwest client")?;
33 Ok(Self { base, http, auth })
34 }
35
36 async fn auth_headers(&self) -> Result<HeaderMap> {
37 let mut h = HeaderMap::new();
38 let b = self
39 .auth
40 .bearer()
41 .await
42 .map_err(|e| anyhow!("token provider: {e}"))?;
43 let token = format!("Bearer {}", b);
44 let mut v = HeaderValue::from_str(&token)
45 .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
46 v.set_sensitive(true);
47 h.insert(AUTHORIZATION, v);
48 Ok(h)
49 }
50
51 pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
55 let url = self.base.join("tasks").context("join /tasks")?;
56 if tracing::enabled!(Level::DEBUG) {
57 tracing::debug!(
58 endpoint = %url,
59 "Sending DMS lease request"
60 );
61 }
62 let mut headers = self.auth_headers().await?;
64 let mut res = self
65 .http
66 .get(url.clone())
67 .headers(headers.clone())
68 .send()
69 .await
70 .context("send GET /tasks")?;
71 let mut status = res.status();
72 let mut bytes = res.bytes().await.context("read lease body")?;
73 if status == StatusCode::UNAUTHORIZED {
75 let body_preview = String::from_utf8_lossy(&bytes);
76 tracing::warn!(
77 status = %status,
78 body = %body_preview,
79 "DMS lease unauthorized; refreshing token and retrying"
80 );
81 self.auth.on_unauthorized().await;
82 headers = self.auth_headers().await?;
83 res = self
84 .http
85 .get(url)
86 .headers(headers)
87 .send()
88 .await
89 .context("retry GET /tasks")?;
90 status = res.status();
91 bytes = res.bytes().await.context("read lease body (retry)")?;
92 }
93 if status == StatusCode::NO_CONTENT {
94 tracing::debug!("DMS lease returned 204 (no work available)");
95 return Ok(None);
96 }
97 let body_preview = String::from_utf8_lossy(&bytes);
98 if !status.is_success() {
99 tracing::warn!(
100 status = %status,
101 body = %body_preview,
102 "DMS lease request returned non-success status"
103 );
104 return Err(anyhow!("/tasks status: {}", status));
105 }
106 let lease: LeaseResponse = serde_json::from_slice(&bytes)
107 .map_err(|err| {
108 tracing::error!(
109 error = %err,
110 body = %body_preview,
111 "Failed to decode DMS lease response"
112 );
113 err
114 })
115 .context("decode lease")?;
116
117 if tracing::enabled!(Level::DEBUG) {
118 tracing::debug!(
119 status = %status,
120 body = %body_preview,
121 "Decoded DMS lease response"
122 );
123 }
124
125 Ok(Some(lease))
126 }
127
128 pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
130 let url = self
131 .base
132 .join(&format!("tasks/{}/complete", task_id))
133 .context("join /complete")?;
134 let mut headers = self.auth_headers().await?;
135 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
136 if let Some(preview) = json_debug_preview(body) {
137 tracing::debug!(
138 endpoint = %url,
139 task_id = %task_id,
140 body = %preview,
141 "Sending DMS complete request"
142 );
143 }
144 let mut res = self
146 .http
147 .post(url.clone())
148 .headers(headers.clone())
149 .json(body)
150 .send()
151 .await
152 .context("send POST /complete")?;
153 let mut status = res.status();
154 let mut body_text = res
155 .text()
156 .await
157 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
158 if status == StatusCode::UNAUTHORIZED {
159 let preview = truncate_preview(&body_text);
160 tracing::warn!(
161 status = %status,
162 body = %preview,
163 task_id = %task_id,
164 "DMS complete unauthorized; refreshing token and retrying"
165 );
166 self.auth.on_unauthorized().await;
167 headers = self.auth_headers().await?;
168 res = self
169 .http
170 .post(url)
171 .headers(headers)
172 .json(body)
173 .send()
174 .await
175 .context("retry POST /complete")?;
176 status = res.status();
177 body_text = res
178 .text()
179 .await
180 .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
181 }
182 let preview = truncate_preview(&body_text);
183 if tracing::enabled!(Level::DEBUG) {
184 tracing::debug!(
185 status = %status,
186 body = %preview,
187 task_id = %task_id,
188 "DMS complete response"
189 );
190 }
191 if !status.is_success() {
192 tracing::error!(
193 status = %status,
194 body = %preview,
195 task_id = %task_id,
196 "DMS complete endpoint returned non-success status"
197 );
198 return Err(anyhow!(
199 "POST /tasks/{task_id}/complete status {status}; body: {preview}"
200 ));
201 }
202 Ok(())
203 }
204
205 pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
207 let url = self
208 .base
209 .join(&format!("tasks/{}/fail", task_id))
210 .context("join /fail")?;
211 let mut headers = self.auth_headers().await?;
212 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
213 if let Some(preview) = json_debug_preview(body) {
214 tracing::debug!(
215 endpoint = %url,
216 task_id = %task_id,
217 body = %preview,
218 "Sending DMS fail request"
219 );
220 }
221 let mut res = self
223 .http
224 .post(url.clone())
225 .headers(headers.clone())
226 .json(body)
227 .send()
228 .await
229 .context("send POST /fail")?;
230 let mut status = res.status();
231 let mut body_text = res
232 .text()
233 .await
234 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
235 if status == StatusCode::UNAUTHORIZED {
236 let preview = truncate_preview(&body_text);
237 tracing::warn!(
238 status = %status,
239 body = %preview,
240 task_id = %task_id,
241 "DMS fail unauthorized; refreshing token and retrying"
242 );
243 self.auth.on_unauthorized().await;
244 headers = self.auth_headers().await?;
245 res = self
246 .http
247 .post(url)
248 .headers(headers)
249 .json(body)
250 .send()
251 .await
252 .context("retry POST /fail")?;
253 status = res.status();
254 body_text = res
255 .text()
256 .await
257 .unwrap_or_else(|e| format!("<failed to read body (retry): {e}>"));
258 }
259 let preview = truncate_preview(&body_text);
260 if tracing::enabled!(Level::DEBUG) {
261 tracing::debug!(
262 status = %status,
263 body = %preview,
264 task_id = %task_id,
265 "DMS fail response"
266 );
267 }
268 if !status.is_success() {
269 tracing::error!(
270 status = %status,
271 body = %preview,
272 task_id = %task_id,
273 "DMS fail endpoint returned non-success status"
274 );
275 return Err(anyhow!(
276 "POST /tasks/{task_id}/fail status {status}; body: {preview}"
277 ));
278 }
279 Ok(())
280 }
281
282 pub async fn heartbeat(
285 &self,
286 task_id: Uuid,
287 body: &HeartbeatRequest,
288 ) -> Result<HeartbeatResponse> {
289 let url = self
290 .base
291 .join(&format!("tasks/{}/heartbeat", task_id))
292 .context("join /heartbeat")?;
293 let mut headers = self.auth_headers().await?;
294 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
295 if let Some(preview) = json_debug_preview(body) {
296 tracing::debug!(
297 endpoint = %url,
298 task_id = %task_id,
299 body = %preview,
300 "Sending DMS heartbeat request"
301 );
302 }
303 let mut res = self
305 .http
306 .post(url.clone())
307 .headers(headers.clone())
308 .json(body)
309 .send()
310 .await
311 .context("send POST /heartbeat")?;
312 let mut status = res.status();
313 let mut bytes = res.bytes().await.context("read heartbeat response body")?;
314 if status == StatusCode::UNAUTHORIZED {
315 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
316 tracing::warn!(
317 status = %status,
318 body = %preview,
319 task_id = %task_id,
320 "DMS heartbeat unauthorized; refreshing token and retrying"
321 );
322 self.auth.on_unauthorized().await;
323 headers = self.auth_headers().await?;
324 res = self
325 .http
326 .post(url)
327 .headers(headers)
328 .json(body)
329 .send()
330 .await
331 .context("retry POST /heartbeat")?;
332 status = res.status();
333 bytes = res
334 .bytes()
335 .await
336 .context("read heartbeat response body (retry)")?;
337 }
338 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
339 if tracing::enabled!(Level::DEBUG) {
340 tracing::debug!(
341 status = %status,
342 body = %preview,
343 task_id = %task_id,
344 "DMS heartbeat response"
345 );
346 }
347 if !status.is_success() {
348 return Err(anyhow!(
349 "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
350 ));
351 }
352 let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
353 .context("decode heartbeat response")?;
354 Ok(hb)
355 }
356}
357
/// Shortens `body` for logging.
///
/// Returns `body` unchanged if it has at most 512 characters. Otherwise returns its first 512
/// characters followed by `"… (truncated)"`.
///
/// The limit counts `char`s, not bytes. Multi-byte UTF-8 text is therefore never split
/// mid-character, and it is only marked as truncated when characters were actually dropped.
/// The previous byte-length check appended the marker to short non-ASCII bodies.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    const MARKER: &str = "… (truncated)";
    match body.char_indices().nth(MAX) {
        // A character exists at index MAX, so the body is longer than MAX chars.
        // `cut` is the byte offset of that character, which is always a char boundary.
        Some((cut, _)) => {
            let mut preview = String::with_capacity(cut + MARKER.len());
            preview.push_str(&body[..cut]);
            preview.push_str(MARKER);
            preview
        }
        None => body.to_string(),
    }
}
368
369fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
370 if !tracing::enabled!(Level::DEBUG) {
371 return None;
372 }
373 serde_json::to_string(value)
374 .map(|s| truncate_preview(&s))
375 .ok()
376}