posemesh_compute_node/dms/
client.rs1use crate::dms::types::{
2 CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
3};
4use anyhow::{anyhow, Context, Result};
5use reqwest::Client;
6use reqwest::{
7 header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
8 StatusCode,
9};
10use serde::Serialize;
11use std::time::Duration;
12use tracing::Level;
13use url::Url;
14use uuid::Uuid;
15
16#[derive(Clone)]
18pub struct DmsClient {
19 base: Url,
20 http: Client,
21 bearer: Option<String>,
22}
23impl DmsClient {
24 pub fn new(base: Url, timeout: Duration, bearer: Option<String>) -> Result<Self> {
26 let http = Client::builder()
27 .use_rustls_tls()
28 .timeout(timeout)
29 .build()
30 .context("build dms reqwest client")?;
31 Ok(Self { base, http, bearer })
32 }
33
34 fn auth_headers(&self) -> HeaderMap {
35 let mut h = HeaderMap::new();
36 if let Some(b) = &self.bearer {
37 let token = format!("Bearer {}", b);
38 let mut v = HeaderValue::from_str(&token)
39 .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
40 v.set_sensitive(true);
41 h.insert(AUTHORIZATION, v);
42 }
43 h
44 }
45
46 pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
50 let url = self.base.join("tasks").context("join /tasks")?;
51 if tracing::enabled!(Level::DEBUG) {
52 tracing::debug!(
53 endpoint = %url,
54 "Sending DMS lease request"
55 );
56 }
57 let res = self
58 .http
59 .get(url)
60 .headers(self.auth_headers())
61 .send()
62 .await
63 .context("send GET /tasks")?;
64 let status = res.status();
65 let bytes = res.bytes().await.context("read lease body")?;
66 if status == StatusCode::NO_CONTENT {
67 tracing::debug!("DMS lease returned 204 (no work available)");
68 return Ok(None);
69 }
70 let body_preview = String::from_utf8_lossy(&bytes);
71 if !status.is_success() {
72 tracing::warn!(
73 status = %status,
74 body = %body_preview,
75 "DMS lease request returned non-success status"
76 );
77 return Err(anyhow!("/tasks status: {}", status));
78 }
79 let lease: LeaseResponse = serde_json::from_slice(&bytes)
80 .map_err(|err| {
81 tracing::error!(
82 error = %err,
83 body = %body_preview,
84 "Failed to decode DMS lease response"
85 );
86 err
87 })
88 .context("decode lease")?;
89
90 if tracing::enabled!(Level::DEBUG) {
91 tracing::debug!(
92 status = %status,
93 body = %body_preview,
94 "Decoded DMS lease response"
95 );
96 }
97
98 Ok(Some(lease))
99 }
100
101 pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
103 let url = self
104 .base
105 .join(&format!("tasks/{}/complete", task_id))
106 .context("join /complete")?;
107 let mut headers = self.auth_headers();
108 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
109 if let Some(preview) = json_debug_preview(body) {
110 tracing::debug!(
111 endpoint = %url,
112 task_id = %task_id,
113 body = %preview,
114 "Sending DMS complete request"
115 );
116 }
117 let res = self
118 .http
119 .post(url)
120 .headers(headers)
121 .json(body)
122 .send()
123 .await
124 .context("send POST /complete")?;
125 let status = res.status();
126 let body_text = res
127 .text()
128 .await
129 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
130 let preview = truncate_preview(&body_text);
131 if tracing::enabled!(Level::DEBUG) {
132 tracing::debug!(
133 status = %status,
134 body = %preview,
135 task_id = %task_id,
136 "DMS complete response"
137 );
138 }
139 if !status.is_success() {
140 tracing::error!(
141 status = %status,
142 body = %preview,
143 task_id = %task_id,
144 "DMS complete endpoint returned non-success status"
145 );
146 return Err(anyhow!(
147 "POST /tasks/{task_id}/complete status {status}; body: {preview}"
148 ));
149 }
150 Ok(())
151 }
152
153 pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
155 let url = self
156 .base
157 .join(&format!("tasks/{}/fail", task_id))
158 .context("join /fail")?;
159 let mut headers = self.auth_headers();
160 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
161 if let Some(preview) = json_debug_preview(body) {
162 tracing::debug!(
163 endpoint = %url,
164 task_id = %task_id,
165 body = %preview,
166 "Sending DMS fail request"
167 );
168 }
169 let res = self
170 .http
171 .post(url)
172 .headers(headers)
173 .json(body)
174 .send()
175 .await
176 .context("send POST /fail")?;
177 let status = res.status();
178 let body_text = res
179 .text()
180 .await
181 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
182 let preview = truncate_preview(&body_text);
183 if tracing::enabled!(Level::DEBUG) {
184 tracing::debug!(
185 status = %status,
186 body = %preview,
187 task_id = %task_id,
188 "DMS fail response"
189 );
190 }
191 if !status.is_success() {
192 tracing::error!(
193 status = %status,
194 body = %preview,
195 task_id = %task_id,
196 runner_error = %body.reason,
197 "DMS fail endpoint returned non-success status"
198 );
199 return Err(anyhow!(
200 "POST /tasks/{task_id}/fail status {status}; body: {preview}"
201 ));
202 }
203 Ok(())
204 }
205
206 pub async fn heartbeat(
209 &self,
210 task_id: Uuid,
211 body: &HeartbeatRequest,
212 ) -> Result<HeartbeatResponse> {
213 let url = self
214 .base
215 .join(&format!("tasks/{}/heartbeat", task_id))
216 .context("join /heartbeat")?;
217 let mut headers = self.auth_headers();
218 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
219 if let Some(preview) = json_debug_preview(body) {
220 tracing::debug!(
221 endpoint = %url,
222 task_id = %task_id,
223 body = %preview,
224 "Sending DMS heartbeat request"
225 );
226 }
227 let res = self
228 .http
229 .post(url)
230 .headers(headers)
231 .json(body)
232 .send()
233 .await
234 .context("send POST /heartbeat")?;
235 let status = res.status();
236 let bytes = res.bytes().await.context("read heartbeat response body")?;
237 let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
238 if tracing::enabled!(Level::DEBUG) {
239 tracing::debug!(
240 status = %status,
241 body = %preview,
242 task_id = %task_id,
243 "DMS heartbeat response"
244 );
245 }
246 if !status.is_success() {
247 return Err(anyhow!(
248 "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
249 ));
250 }
251 let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
252 .context("decode heartbeat response")?;
253 Ok(hb)
254 }
255}
256
/// Shortens `body` to at most 512 characters for use in logs and error messages.
///
/// Text of 512 characters or fewer is returned unchanged. Longer text is cut
/// after the 512th character and `"… (truncated)"` is appended. The limit is
/// measured in characters (Unicode scalar values), not bytes, so multi-byte
/// text is never split inside a character and is only marked as truncated
/// when something was actually removed.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    const SUFFIX: &str = "… (truncated)";

    // Fast path: a string has at least as many bytes as characters, so
    // anything within MAX bytes is certainly within MAX characters.
    if body.len() <= MAX {
        return body.to_string();
    }

    // Byte offset of the first character beyond the limit, if there is one.
    match body.char_indices().nth(MAX) {
        None => body.to_string(),
        Some((cut, _)) => {
            let mut preview = String::with_capacity(cut + SUFFIX.len());
            preview.push_str(&body[..cut]);
            preview.push_str(SUFFIX);
            preview
        }
    }
}
267
268fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
269 if !tracing::enabled!(Level::DEBUG) {
270 return None;
271 }
272 serde_json::to_string(value)
273 .map(|s| truncate_preview(&s))
274 .ok()
275}