posemesh_compute_node/dms/
client.rs

1use crate::dms::types::{
2    CompleteTaskRequest, FailTaskRequest, HeartbeatRequest, HeartbeatResponse, LeaseResponse,
3};
4use anyhow::{anyhow, Context, Result};
5use reqwest::Client;
6use reqwest::{
7    header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
8    StatusCode,
9};
10use serde::Serialize;
11use std::time::Duration;
12use tracing::Level;
13use url::Url;
14use uuid::Uuid;
15
16/// Minimal DMS HTTP client using rustls with sensitive Authorization header.
17#[derive(Clone)]
18pub struct DmsClient {
19    base: Url,
20    http: Client,
21    bearer: Option<String>,
22}
23impl DmsClient {
24    /// Create client with base URL, timeout, and optional bearer token.
25    pub fn new(base: Url, timeout: Duration, bearer: Option<String>) -> Result<Self> {
26        let http = Client::builder()
27            .use_rustls_tls()
28            .timeout(timeout)
29            .build()
30            .context("build dms reqwest client")?;
31        Ok(Self { base, http, bearer })
32    }
33
34    fn auth_headers(&self) -> HeaderMap {
35        let mut h = HeaderMap::new();
36        if let Some(b) = &self.bearer {
37            let token = format!("Bearer {}", b);
38            let mut v = HeaderValue::from_str(&token)
39                .unwrap_or_else(|_| HeaderValue::from_static("Bearer INVALID"));
40            v.set_sensitive(true);
41            h.insert(AUTHORIZATION, v);
42        }
43        h
44    }
45
46    /// Lease a task: GET /tasks
47    ///
48    /// `capability` is accepted for optional filter but not implemented yet.
49    pub async fn lease_by_capability(&self, _capability: &str) -> Result<Option<LeaseResponse>> {
50        let url = self.base.join("tasks").context("join /tasks")?;
51        if tracing::enabled!(Level::DEBUG) {
52            tracing::debug!(
53                endpoint = %url,
54                "Sending DMS lease request"
55            );
56        }
57        let res = self
58            .http
59            .get(url)
60            .headers(self.auth_headers())
61            .send()
62            .await
63            .context("send GET /tasks")?;
64        let status = res.status();
65        let bytes = res.bytes().await.context("read lease body")?;
66        if status == StatusCode::NO_CONTENT {
67            tracing::debug!("DMS lease returned 204 (no work available)");
68            return Ok(None);
69        }
70        let body_preview = String::from_utf8_lossy(&bytes);
71        if !status.is_success() {
72            tracing::warn!(
73                status = %status,
74                body = %body_preview,
75                "DMS lease request returned non-success status"
76            );
77            return Err(anyhow!("/tasks status: {}", status));
78        }
79        let lease: LeaseResponse = serde_json::from_slice(&bytes)
80            .map_err(|err| {
81                tracing::error!(
82                    error = %err,
83                    body = %body_preview,
84                    "Failed to decode DMS lease response"
85                );
86                err
87            })
88            .context("decode lease")?;
89
90        if tracing::enabled!(Level::DEBUG) {
91            tracing::debug!(
92                status = %status,
93                body = %body_preview,
94                "Decoded DMS lease response"
95            );
96        }
97
98        Ok(Some(lease))
99    }
100
101    /// Complete task: POST /tasks/{id}/complete
102    pub async fn complete(&self, task_id: Uuid, body: &CompleteTaskRequest) -> Result<()> {
103        let url = self
104            .base
105            .join(&format!("tasks/{}/complete", task_id))
106            .context("join /complete")?;
107        let mut headers = self.auth_headers();
108        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
109        if let Some(preview) = json_debug_preview(body) {
110            tracing::debug!(
111                endpoint = %url,
112                task_id = %task_id,
113                body = %preview,
114                "Sending DMS complete request"
115            );
116        }
117        let res = self
118            .http
119            .post(url)
120            .headers(headers)
121            .json(body)
122            .send()
123            .await
124            .context("send POST /complete")?;
125        let status = res.status();
126        let body_text = res
127            .text()
128            .await
129            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
130        let preview = truncate_preview(&body_text);
131        if tracing::enabled!(Level::DEBUG) {
132            tracing::debug!(
133                status = %status,
134                body = %preview,
135                task_id = %task_id,
136                "DMS complete response"
137            );
138        }
139        if !status.is_success() {
140            tracing::error!(
141                status = %status,
142                body = %preview,
143                task_id = %task_id,
144                "DMS complete endpoint returned non-success status"
145            );
146            return Err(anyhow!(
147                "POST /tasks/{task_id}/complete status {status}; body: {preview}"
148            ));
149        }
150        Ok(())
151    }
152
153    /// Fail task: POST /tasks/{id}/fail
154    pub async fn fail(&self, task_id: Uuid, body: &FailTaskRequest) -> Result<()> {
155        let url = self
156            .base
157            .join(&format!("tasks/{}/fail", task_id))
158            .context("join /fail")?;
159        let mut headers = self.auth_headers();
160        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
161        if let Some(preview) = json_debug_preview(body) {
162            tracing::debug!(
163                endpoint = %url,
164                task_id = %task_id,
165                body = %preview,
166                "Sending DMS fail request"
167            );
168        }
169        let res = self
170            .http
171            .post(url)
172            .headers(headers)
173            .json(body)
174            .send()
175            .await
176            .context("send POST /fail")?;
177        let status = res.status();
178        let body_text = res
179            .text()
180            .await
181            .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
182        let preview = truncate_preview(&body_text);
183        if tracing::enabled!(Level::DEBUG) {
184            tracing::debug!(
185                status = %status,
186                body = %preview,
187                task_id = %task_id,
188                "DMS fail response"
189            );
190        }
191        if !status.is_success() {
192            tracing::error!(
193                status = %status,
194                body = %preview,
195                task_id = %task_id,
196                runner_error = %body.reason,
197                "DMS fail endpoint returned non-success status"
198            );
199            return Err(anyhow!(
200                "POST /tasks/{task_id}/fail status {status}; body: {preview}"
201            ));
202        }
203        Ok(())
204    }
205
206    /// Heartbeat: POST /tasks/{id}/heartbeat with progress payload.
207    /// Returns potential new access token for storage.
208    pub async fn heartbeat(
209        &self,
210        task_id: Uuid,
211        body: &HeartbeatRequest,
212    ) -> Result<HeartbeatResponse> {
213        let url = self
214            .base
215            .join(&format!("tasks/{}/heartbeat", task_id))
216            .context("join /heartbeat")?;
217        let mut headers = self.auth_headers();
218        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
219        if let Some(preview) = json_debug_preview(body) {
220            tracing::debug!(
221                endpoint = %url,
222                task_id = %task_id,
223                body = %preview,
224                "Sending DMS heartbeat request"
225            );
226        }
227        let res = self
228            .http
229            .post(url)
230            .headers(headers)
231            .json(body)
232            .send()
233            .await
234            .context("send POST /heartbeat")?;
235        let status = res.status();
236        let bytes = res.bytes().await.context("read heartbeat response body")?;
237        let preview = truncate_preview(&String::from_utf8_lossy(&bytes));
238        if tracing::enabled!(Level::DEBUG) {
239            tracing::debug!(
240                status = %status,
241                body = %preview,
242                task_id = %task_id,
243                "DMS heartbeat response"
244            );
245        }
246        if !status.is_success() {
247            return Err(anyhow!(
248                "POST /tasks/{task_id}/heartbeat status {status}; body: {preview}"
249            ));
250        }
251        let hb = serde_json::from_slice::<HeartbeatResponse>(&bytes)
252            .context("decode heartbeat response")?;
253        Ok(hb)
254    }
255}
256
/// Shorten `body` for logs and error messages.
///
/// Returns `body` unchanged if it has at most 512 characters. Otherwise returns its
/// first 512 characters followed by `… (truncated)`.
///
/// The limit counts `char`s, not bytes. Multi-byte text is therefore never split
/// mid-character, and it is only marked as truncated when characters were actually
/// dropped. The previous byte-length check appended the suffix to short non-ASCII
/// bodies that had not been cut at all.
fn truncate_preview(body: &str) -> String {
    const MAX: usize = 512;
    // `nth(MAX)` yields the byte offset of the first character past the limit, if any.
    match body.char_indices().nth(MAX) {
        None => body.to_string(),
        Some((cut, _)) => format!("{}… (truncated)", &body[..cut]),
    }
}
267
268fn json_debug_preview<T: Serialize>(value: &T) -> Option<String> {
269    if !tracing::enabled!(Level::DEBUG) {
270        return None;
271    }
272    serde_json::to_string(value)
273        .map(|s| truncate_preview(&s))
274        .ok()
275}