Skip to main content

aperion_shield/orgmode/
client.rs

1//! Thin reqwest wrapper around the Smartflow REST API.
2//!
3//! Owns its own `reqwest::Client` with conservative defaults (HTTPS-only
4//! in the typical deployment, 10 s timeouts, no fancy connection pool
5//! settings -- we make few requests). Errors surface as
6//! [`OrgApiError`]; callers decide whether to retry, fall back to local,
7//! or propagate.
8
9use std::time::Duration;
10
11use chrono::{DateTime, Utc};
12use reqwest::{Client, StatusCode};
13use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
16use super::state::OrgState;
17
/// Total per-request timeout -- matches what the enterprise device API
/// expects.
///
/// Applied through `ClientBuilder::timeout`, so it bounds the whole
/// request: connecting, sending, and reading the complete response body.
/// It is *not* only a connect timeout.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
21
22#[derive(Debug, Error)]
23pub enum OrgApiError {
24    #[error("network: {0}")]
25    Network(String),
26    #[error("http {status}: {body}")]
27    Http { status: u16, body: String },
28    #[error("decode: {0}")]
29    Decode(String),
30    #[error("unauthorized -- vkey rejected; the device may have been revoked")]
31    Unauthorized,
32}
33
34impl From<reqwest::Error> for OrgApiError {
35    fn from(e: reqwest::Error) -> Self {
36        OrgApiError::Network(e.to_string())
37    }
38}
39
40pub struct OrgApi {
41    http: Client,
42    smartflow_url: String,
43    vkey: String,
44}
45
46impl OrgApi {
47    pub fn new(smartflow_url: impl Into<String>, vkey: impl Into<String>) -> Self {
48        let http = Client::builder()
49            .timeout(REQUEST_TIMEOUT)
50            .user_agent(format!("aperion-shield/{}", env!("CARGO_PKG_VERSION")))
51            .build()
52            .expect("reqwest client build");
53        Self {
54            http,
55            smartflow_url: smartflow_url.into(),
56            vkey: vkey.into(),
57        }
58    }
59
60    pub fn from_state(state: &OrgState) -> Self {
61        Self::new(&state.smartflow_url, &state.vkey)
62    }
63
64    fn url(&self, path: &str) -> String {
65        format!("{}{}", self.smartflow_url.trim_end_matches('/'), path)
66    }
67
68    async fn unwrap_response<T: for<'de> Deserialize<'de>>(
69        &self,
70        resp: reqwest::Response,
71    ) -> Result<T, OrgApiError> {
72        let status = resp.status();
73        if status == StatusCode::UNAUTHORIZED {
74            return Err(OrgApiError::Unauthorized);
75        }
76        if !status.is_success() {
77            let body = resp.text().await.unwrap_or_default();
78            return Err(OrgApiError::Http {
79                status: status.as_u16(),
80                body,
81            });
82        }
83        resp.json::<T>()
84            .await
85            .map_err(|e| OrgApiError::Decode(e.to_string()))
86    }
87
88    // ── Enrollment ────────────────────────────────────────────────
89
90    /// Exchange a one-time enrollment token for a virtual key + device
91    /// id. Uses the existing `enterprise_device_api::token_enroll`
92    /// endpoint -- no new server code required for enrollment itself.
93    pub async fn token_enroll(
94        smartflow_url: &str,
95        enrollment_token: &str,
96        device_fingerprint: &str,
97        device_name: &str,
98        platform: &str,
99        user_email: Option<&str>,
100    ) -> Result<TokenEnrollResponse, OrgApiError> {
101        let http = Client::builder()
102            .timeout(REQUEST_TIMEOUT)
103            .user_agent(format!("aperion-shield/{}", env!("CARGO_PKG_VERSION")))
104            .build()?;
105        let body = serde_json::json!({
106            "enrollment_token": enrollment_token,
107            "device_fingerprint": device_fingerprint,
108            "device_name": device_name,
109            "platform": platform,
110            "user_email": user_email,
111        });
112        let resp = http
113            .post(format!(
114                "{}/api/enterprise/devices/token-enroll",
115                smartflow_url.trim_end_matches('/')
116            ))
117            .json(&body)
118            .send()
119            .await?;
120        let status = resp.status();
121        if !status.is_success() {
122            let body = resp.text().await.unwrap_or_default();
123            return Err(OrgApiError::Http {
124                status: status.as_u16(),
125                body,
126            });
127        }
128        resp.json::<TokenEnrollResponse>()
129            .await
130            .map_err(|e| OrgApiError::Decode(e.to_string()))
131    }
132
133    // ── Heartbeat ─────────────────────────────────────────────────
134
135    pub async fn heartbeat(&self, device_id: &str) -> Result<(), OrgApiError> {
136        let resp = self
137            .http
138            .post(self.url(&format!(
139                "/api/enterprise/devices/{}/heartbeat",
140                device_id
141            )))
142            .bearer_auth(&self.vkey)
143            .send()
144            .await?;
145        let status = resp.status();
146        if !status.is_success() {
147            let body = resp.text().await.unwrap_or_default();
148            return Err(OrgApiError::Http {
149                status: status.as_u16(),
150                body,
151            });
152        }
153        Ok(())
154    }
155
156    // ── Shieldset (policy) ────────────────────────────────────────
157
158    /// Fetch the current shieldset YAML for a group, returning
159    /// `(yaml, version)`. Version comes from the `X-Shield-Policy-Version`
160    /// header so the caller can decide whether to hot-reload.
161    pub async fn get_shieldset(&self, group: &str) -> Result<(String, u64), OrgApiError> {
162        let resp = self
163            .http
164            .get(self.url(&format!("/api/enterprise/shield/shieldset/{}", group)))
165            .bearer_auth(&self.vkey)
166            .send()
167            .await?;
168        let status = resp.status();
169        if status == StatusCode::UNAUTHORIZED {
170            return Err(OrgApiError::Unauthorized);
171        }
172        if !status.is_success() {
173            let body = resp.text().await.unwrap_or_default();
174            return Err(OrgApiError::Http {
175                status: status.as_u16(),
176                body,
177            });
178        }
179        let version: u64 = resp
180            .headers()
181            .get("X-Shield-Policy-Version")
182            .and_then(|v| v.to_str().ok())
183            .and_then(|s| s.parse().ok())
184            .unwrap_or(0);
185        let yaml = resp.text().await?;
186        Ok((yaml, version))
187    }
188
189    /// Cheap version probe -- no payload. Used by the policy pull loop
190    /// to decide whether to fetch a full shieldset.
191    pub async fn get_shieldset_version(&self, group: &str) -> Result<VersionInfo, OrgApiError> {
192        let resp = self
193            .http
194            .get(self.url(&format!(
195                "/api/enterprise/shield/shieldset/{}/version",
196                group
197            )))
198            .bearer_auth(&self.vkey)
199            .send()
200            .await?;
201        self.unwrap_response(resp).await
202    }
203
204    // ── Events (audit sink) ───────────────────────────────────────
205
206    pub async fn post_events(
207        &self,
208        events: &[serde_json::Value],
209    ) -> Result<EventsAck, OrgApiError> {
210        let resp = self
211            .http
212            .post(self.url("/api/enterprise/shield/events"))
213            .bearer_auth(&self.vkey)
214            .json(&serde_json::json!({ "events": events }))
215            .send()
216            .await?;
217        self.unwrap_response(resp).await
218    }
219
220    // ── Identity (M3) ─────────────────────────────────────────────
221
222    pub async fn identity_check(
223        &self,
224        req: &IdentityCheckRequest,
225    ) -> Result<IdentityCheckResponse, OrgApiError> {
226        let resp = self
227            .http
228            .post(self.url("/api/enterprise/shield/identity/check"))
229            .bearer_auth(&self.vkey)
230            .json(req)
231            .send()
232            .await?;
233        self.unwrap_response(resp).await
234    }
235
236    pub async fn identity_begin(
237        &self,
238        req: &IdentityCheckRequest,
239    ) -> Result<IdentityCheckResponse, OrgApiError> {
240        let resp = self
241            .http
242            .post(self.url("/api/enterprise/shield/identity/begin"))
243            .bearer_auth(&self.vkey)
244            .json(req)
245            .send()
246            .await?;
247        self.unwrap_response(resp).await
248    }
249
250    pub async fn identity_result(
251        &self,
252        challenge_id: &str,
253    ) -> Result<IdentityCheckResponse, OrgApiError> {
254        let resp = self
255            .http
256            .get(self.url(&format!(
257                "/api/enterprise/shield/identity/result/{}",
258                challenge_id
259            )))
260            .bearer_auth(&self.vkey)
261            .send()
262            .await?;
263        self.unwrap_response(resp).await
264    }
265
266    // ── Info / killswitch ─────────────────────────────────────────
267
268    pub async fn info(&self) -> Result<InfoResponse, OrgApiError> {
269        let resp = self
270            .http
271            .get(self.url("/api/enterprise/shield/info"))
272            .bearer_auth(&self.vkey)
273            .send()
274            .await?;
275        self.unwrap_response(resp).await
276    }
277}
278
279// ─────────────────────────────────────────────────────────────────────
280// DTOs
281// ─────────────────────────────────────────────────────────────────────
282
283#[derive(Debug, Deserialize)]
284pub struct TokenEnrollResponse {
285    pub device_id: String,
286    pub vkey: String,
287    pub proxy_url: String,
288    pub policy_group: String,
289    #[serde(default)]
290    pub policy_version: String,
291    #[serde(default)]
292    pub policy_ws_url: String,
293}
294
295#[derive(Debug, Deserialize)]
296pub struct VersionInfo {
297    pub group: String,
298    pub version: u64,
299    pub killswitch: KillswitchState,
300    pub server_time: DateTime<Utc>,
301}
302
303#[derive(Debug, Clone, Deserialize, Default)]
304pub struct KillswitchState {
305    pub on: bool,
306    pub reason: Option<String>,
307}
308
309#[derive(Debug, Deserialize)]
310pub struct EventsAck {
311    pub ok: bool,
312    pub received: usize,
313}
314
315#[derive(Debug, Serialize, Clone)]
316pub struct IdentityCheckRequest {
317    pub provider: String,
318    pub scope: String,
319    #[serde(skip_serializing_if = "Vec::is_empty")]
320    pub allowed_subjects: Vec<String>,
321    #[serde(skip_serializing_if = "Option::is_none")]
322    pub min_loa: Option<u8>,
323    pub max_age_seconds: u64,
324}
325
326#[derive(Debug, Clone, Deserialize)]
327pub struct IdentityCheckResponse {
328    pub verified: bool,
329    #[serde(default)]
330    pub subject: Option<String>,
331    #[serde(default)]
332    pub loa: Option<u8>,
333    #[serde(default)]
334    pub expires_at: Option<DateTime<Utc>>,
335    #[serde(default)]
336    pub signature: Option<String>,
337    #[serde(default)]
338    pub verify_url: Option<String>,
339    #[serde(default)]
340    pub challenge_id: Option<String>,
341    #[serde(default)]
342    pub provider: String,
343}
344
345#[derive(Debug, Deserialize)]
346pub struct InfoResponse {
347    pub device_id: String,
348    pub policy_group: String,
349    pub owner_email: String,
350    pub policy_version: u64,
351    pub killswitch: KillswitchState,
352    pub server_time: DateTime<Utc>,
353    pub identity_providers: Vec<IdentityProviderInfo>,
354}
355
356#[derive(Debug, Deserialize)]
357pub struct IdentityProviderInfo {
358    pub id: String,
359    pub display_name: String,
360    pub kind: String,
361    pub ready: bool,
362}