Skip to main content

dnslib/vendors/unifi/
client.rs

1//! UniFi Network Integration API client (DNS policies only).
2//!
3//! UniFi authenticates with an `X-API-KEY` header rather than HTTP bearer
4//! auth, so this client builds its own `reqwest::Client` rather than reusing
5//! the shared `vendors::http::HttpClient`.
6//!
7//! All paths are appended to `base_url`. The expected effective URL is
8//! `<base_url>/sites/{siteId}/dns/policies[...]`, where `base_url` typically
9//! ends in `/proxy/network/integration/v1` on a local controller.
10
11use std::time::Duration;
12
13use reqwest::{Client, RequestBuilder, Response};
14use serde_json::Value;
15use tokio::sync::OnceCell;
16use tracing::Instrument;
17
18use crate::core::error::{Error, Result};
19use crate::core::secret::ApiToken;
20
21use super::responses::{
22    UnifiDnsPolicy, UnifiDnsPolicyPage, UnifiSite, match_site, parse_page, parse_site_page,
23};
24
/// Largest `limit` the UniFi list endpoints accept; the page methods clamp
/// any larger request down to this value.
pub const MAX_PAGE_LIMIT: u32 = 200;

/// Page size used by the `list_all_*` helpers: the largest size allowed.
pub const DEFAULT_PAGE_LIMIT: u32 = MAX_PAGE_LIMIT;
30
31/// UniFi DNS-policy client.
32///
33/// `site` holds the user-configured value — typically the human-readable site
34/// name (e.g. `"Default"`), but a UUID is also accepted. The first DNS call
35/// resolves that label to the controller's actual site UUID via
36/// `GET /v1/sites` and caches it for the lifetime of the client.
37#[derive(Clone, Debug)]
38pub struct UnifiClient {
39    http: Client,
40    base_url: String,
41    token: ApiToken,
42    site: String,
43    resolved_site_id: std::sync::Arc<OnceCell<String>>,
44}
45
46impl UnifiClient {
47    /// Build a new client. Uses a 30-second timeout to match the rest of the
48    /// vendor clients in this crate.
49    pub fn new(base_url: String, token: ApiToken, site: String) -> Result<Self> {
50        let http = Client::builder()
51            .timeout(Duration::from_secs(30))
52            .build()
53            .map_err(Error::Network)?;
54        Ok(Self {
55            http,
56            base_url,
57            token,
58            site,
59            resolved_site_id: std::sync::Arc::new(OnceCell::new()),
60        })
61    }
62
63    pub fn base_url(&self) -> &str {
64        &self.base_url
65    }
66
67    /// The configured site identifier as supplied via config or env vars.
68    /// This is what the user typed and may be either a name or a UUID.
69    pub fn site(&self) -> &str {
70        &self.site
71    }
72
73    /// Test-only helper for verifying credential plumbing without forcing the
74    /// production code to expose the token via `Debug`.
75    #[cfg(test)]
76    pub fn token_for_test(&self) -> &str {
77        self.token.expose_for_auth()
78    }
79
80    fn url(&self, path: &str) -> String {
81        format!("{}{}", self.base_url, path)
82    }
83
84    fn policies_path(&self, site_id: &str) -> String {
85        format!("/sites/{site_id}/dns/policies")
86    }
87
88    fn policy_path(&self, site_id: &str, policy_id: &str) -> String {
89        format!("/sites/{site_id}/dns/policies/{policy_id}")
90    }
91
92    /// Resolve the configured `site` (name or UUID) to the canonical UniFi
93    /// site UUID.
94    ///
95    /// On first call this performs `GET /v1/sites` and matches the
96    /// configured value against each site's `id`, `name`, and
97    /// `internalReference` (case-insensitively). The resolved UUID is cached
98    /// for the lifetime of the client so subsequent DNS calls don't pay the
99    /// site-list cost.
100    ///
101    /// If no site matches the configured value, returns `Error::Api` whose
102    /// message lists every valid human-readable site name so the user can
103    /// fix their config without leaving the CLI.
104    pub async fn resolve_site_id(&self) -> Result<&str> {
105        let cached = self
106            .resolved_site_id
107            .get_or_try_init(|| async {
108                let sites = self.list_all_sites().await?;
109                match match_site(&sites, &self.site) {
110                    Some(site) => Ok(site.id.clone()),
111                    None => {
112                        let available = if sites.is_empty() {
113                            "<no sites visible to this API key>".to_string()
114                        } else {
115                            sites
116                                .iter()
117                                .map(|s| s.display_name())
118                                .collect::<Vec<_>>()
119                                .join(", ")
120                        };
121                        Err(Error::Api {
122                            message: format!(
123                                "UniFi site '{}' not found on this controller; available sites: [{}]",
124                                self.site, available
125                            ),
126                        })
127                    }
128                }
129            })
130            .await?;
131        Ok(cached.as_str())
132    }
133
134    fn auth(&self, req: RequestBuilder) -> RequestBuilder {
135        req.header("X-API-KEY", self.token.expose_for_auth())
136            .header("Accept", "application/json")
137    }
138
139    async fn send(
140        &self,
141        method: &'static str,
142        path: &str,
143        builder: RequestBuilder,
144    ) -> Result<Response> {
145        let span = tracing::debug_span!(
146            "http.request",
147            method,
148            path,
149            http.status = tracing::field::Empty
150        );
151        async {
152            tracing::debug!("sending request");
153            let resp = self.auth(builder).send().await.map_err(|e| {
154                tracing::warn!(error = %e, "request failed");
155                Error::Network(e)
156            })?;
157            tracing::Span::current().record("http.status", resp.status().as_u16());
158            tracing::debug!("received response");
159            Ok(resp)
160        }
161        .instrument(span)
162        .await
163    }
164
165    // ── Raw HTTP verbs ──────────────────────────────────────────────────────
166
167    async fn get(&self, path: &str, params: &[(&str, String)]) -> Result<Value> {
168        let req = self.http.get(self.url(path)).query(params);
169        let resp = self.send("GET", path, req).await?;
170        parse_json_response(resp).await
171    }
172
173    async fn post(&self, path: &str, body: &Value) -> Result<Value> {
174        let req = self.http.post(self.url(path)).json(body);
175        let resp = self.send("POST", path, req).await?;
176        parse_json_response(resp).await
177    }
178
179    async fn put(&self, path: &str, body: &Value) -> Result<Value> {
180        let req = self.http.put(self.url(path)).json(body);
181        let resp = self.send("PUT", path, req).await?;
182        parse_json_response(resp).await
183    }
184
185    async fn delete(&self, path: &str) -> Result<Value> {
186        let req = self.http.delete(self.url(path));
187        let resp = self.send("DELETE", path, req).await?;
188        parse_optional_json_response(resp).await
189    }
190
191    // ── Site discovery ──────────────────────────────────────────────────────
192
193    /// `GET /v1/sites` — single page of sites accessible to this API key.
194    pub async fn list_sites_page(
195        &self,
196        offset: u32,
197        limit: u32,
198    ) -> Result<super::responses::UnifiSitePage> {
199        let limit = limit.min(MAX_PAGE_LIMIT);
200        let params: Vec<(&str, String)> =
201            vec![("offset", offset.to_string()), ("limit", limit.to_string())];
202        let value = self.get("/sites", &params).await?;
203        parse_site_page(value).map_err(|e| Error::parse(format!("decoding UniFi site page: {e}")))
204    }
205
206    /// Fetch every site by paginating until exhausted.
207    ///
208    /// Same termination logic as `list_all_dns_policies`: stops on empty page,
209    /// known `totalCount`, or short page; capped at 1000 pages.
210    pub async fn list_all_sites(&self) -> Result<Vec<UnifiSite>> {
211        let mut out: Vec<UnifiSite> = Vec::new();
212        let mut offset = 0u32;
213        let mut pages = 0u32;
214        loop {
215            let page = self.list_sites_page(offset, DEFAULT_PAGE_LIMIT).await?;
216            let returned = page.data.len() as u32;
217            let total = page.total();
218            out.extend(page.data);
219            offset += returned.max(1);
220            pages += 1;
221            if returned == 0 {
222                break;
223            }
224            if let Some(total) = total {
225                if out.len() as u32 >= total {
226                    break;
227                }
228            } else if returned < DEFAULT_PAGE_LIMIT {
229                break;
230            }
231            if pages >= 1000 {
232                return Err(Error::parse(
233                    "UniFi site pagination exceeded 1000 pages without terminating",
234                ));
235            }
236        }
237        Ok(out)
238    }
239
240    // ── DNS-policy endpoints ────────────────────────────────────────────────
241
242    /// `GET /sites/{siteId}/dns/policies` — single page.
243    ///
244    /// Caller controls pagination through `offset` and `limit`. `limit` is
245    /// clamped to the documented maximum of 200.
246    pub async fn list_dns_policies_page(
247        &self,
248        offset: u32,
249        limit: u32,
250        filter: Option<&str>,
251    ) -> Result<UnifiDnsPolicyPage> {
252        let limit = limit.min(MAX_PAGE_LIMIT);
253        let mut params: Vec<(&str, String)> =
254            vec![("offset", offset.to_string()), ("limit", limit.to_string())];
255        if let Some(f) = filter {
256            params.push(("filter", f.to_string()));
257        }
258        let site_id = self.resolve_site_id().await?.to_string();
259        let path = self.policies_path(&site_id);
260        let value = self.get(&path, &params).await?;
261        parse_page(value).map_err(|e| Error::parse(format!("decoding UniFi DNS policy page: {e}")))
262    }
263
264    /// Fetch every DNS policy by paginating until exhausted.
265    ///
266    /// Termination: stops when an empty page is returned, or when `totalCount`
267    /// (if present) has been reached. Hard cap of 1000 pages guards against
268    /// pathological controller responses.
269    pub async fn list_all_dns_policies(&self, filter: Option<&str>) -> Result<Vec<UnifiDnsPolicy>> {
270        let mut out: Vec<UnifiDnsPolicy> = Vec::new();
271        let mut offset = 0u32;
272        let mut pages = 0u32;
273        loop {
274            let page = self
275                .list_dns_policies_page(offset, DEFAULT_PAGE_LIMIT, filter)
276                .await?;
277            let returned = page.data.len() as u32;
278            let total = page.total();
279            out.extend(page.data);
280            offset += returned.max(1); // ensure progress even if controller returns count=0
281            pages += 1;
282
283            // Stop conditions: empty page, reached known total, or page cap.
284            if returned == 0 {
285                break;
286            }
287            if let Some(total) = total {
288                if out.len() as u32 >= total {
289                    break;
290                }
291            } else if returned < DEFAULT_PAGE_LIMIT {
292                // No totalCount header — short page means we're done.
293                break;
294            }
295            if pages >= 1000 {
296                return Err(Error::parse(
297                    "UniFi DNS policy pagination exceeded 1000 pages without terminating",
298                ));
299            }
300        }
301        Ok(out)
302    }
303
304    /// `POST /sites/{siteId}/dns/policies`
305    pub async fn create_dns_policy(&self, body: &Value) -> Result<UnifiDnsPolicy> {
306        let site_id = self.resolve_site_id().await?.to_string();
307        let path = self.policies_path(&site_id);
308        let value = self.post(&path, body).await?;
309        serde_json::from_value(value)
310            .map_err(|e| Error::parse(format!("decoding UniFi create DNS policy response: {e}")))
311    }
312
313    /// `GET /sites/{siteId}/dns/policies/{dnsPolicyId}`
314    pub async fn get_dns_policy(&self, policy_id: &str) -> Result<UnifiDnsPolicy> {
315        let site_id = self.resolve_site_id().await?.to_string();
316        let path = self.policy_path(&site_id, policy_id);
317        let value = self.get(&path, &[]).await?;
318        serde_json::from_value(value)
319            .map_err(|e| Error::parse(format!("decoding UniFi get DNS policy response: {e}")))
320    }
321
322    /// `PUT /sites/{siteId}/dns/policies/{dnsPolicyId}`
323    ///
324    /// UniFi requires the full create/update payload — partial updates are
325    /// not supported. Caller is responsible for sending all fields.
326    pub async fn update_dns_policy(&self, policy_id: &str, body: &Value) -> Result<UnifiDnsPolicy> {
327        let site_id = self.resolve_site_id().await?.to_string();
328        let path = self.policy_path(&site_id, policy_id);
329        let value = self.put(&path, body).await?;
330        serde_json::from_value(value)
331            .map_err(|e| Error::parse(format!("decoding UniFi update DNS policy response: {e}")))
332    }
333
334    /// `DELETE /sites/{siteId}/dns/policies/{dnsPolicyId}`
335    pub async fn delete_dns_policy(&self, policy_id: &str) -> Result<()> {
336        let site_id = self.resolve_site_id().await?.to_string();
337        let path = self.policy_path(&site_id, policy_id);
338        self.delete(&path).await?;
339        Ok(())
340    }
341}
342
343/// Parse a UniFi JSON response into a `Value`.
344///
345/// UniFi error responses follow `{"statusCode": 4xx, "statusName": "...",
346/// "message": "..."}` and may also include a `details` array. Non-2xx
347/// responses are mapped to the standard dnsync error variants.
348async fn parse_json_response(resp: Response) -> Result<Value> {
349    let status = resp.status();
350    let bytes = resp.bytes().await.map_err(Error::Network)?;
351
352    if bytes.is_empty() {
353        if status.is_success() {
354            return Ok(Value::Null);
355        }
356        return Err(Error::Http {
357            status: status.as_u16(),
358            body: String::new(),
359        });
360    }
361
362    let value: Value = serde_json::from_slice(&bytes).map_err(|e| {
363        // Use a faux parse error rather than InvalidJson(reqwest::Error)
364        // because we already consumed the response bytes.
365        let _ = e;
366        Error::Parse {
367            context: format!(
368                "UniFi response body is not valid JSON (status {}): {}",
369                status.as_u16(),
370                String::from_utf8_lossy(&bytes)
371                    .chars()
372                    .take(200)
373                    .collect::<String>(),
374            ),
375        }
376    })?;
377
378    if status.is_success() {
379        return Ok(value);
380    }
381
382    let message = unifi_error_message(&value).unwrap_or_else(|| value.to_string());
383
384    if status.as_u16() == 403 {
385        return Err(Error::forbidden(message));
386    }
387    if status.is_client_error() || status.is_server_error() {
388        // 4xx/5xx with an error payload — surface as a vendor API error.
389        return Err(Error::Api { message });
390    }
391
392    Err(Error::Http {
393        status: status.as_u16(),
394        body: value.to_string(),
395    })
396}
397
398/// Like `parse_json_response`, but treats an empty body as success — DELETE
399/// often returns 200 OK with no payload.
400async fn parse_optional_json_response(resp: Response) -> Result<Value> {
401    let status = resp.status();
402    if status.is_success() {
403        let bytes = resp.bytes().await.map_err(Error::Network)?;
404        if bytes.is_empty() {
405            return Ok(Value::Null);
406        }
407        return serde_json::from_slice::<Value>(&bytes).map_err(|_| Error::Parse {
408            context: format!(
409                "UniFi DELETE response was not valid JSON (status {})",
410                status.as_u16()
411            ),
412        });
413    }
414    parse_json_response(resp).await
415}
416
417/// Pull a human-readable error string out of a UniFi error envelope.
418fn unifi_error_message(value: &Value) -> Option<String> {
419    if let Some(msg) = value.get("message").and_then(|m| m.as_str()) {
420        return Some(msg.to_string());
421    }
422    if let Some(msg) = value.get("statusName").and_then(|m| m.as_str()) {
423        return Some(msg.to_string());
424    }
425    None
426}
427
428// ─── Tests ────────────────────────────────────────────────────────────────────
429
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Build a `reqwest::Response` with the given status and a JSON body.
    fn make_resp(status: u16, body: Value) -> reqwest::Response {
        http::Response::builder()
            .status(status)
            .header("content-type", "application/json")
            .body(body.to_string())
            .map(reqwest::Response::from)
            .unwrap()
    }

    /// Build a `reqwest::Response` whose body is sent verbatim (not JSON).
    fn make_text_resp(status: u16, body: &str) -> reqwest::Response {
        http::Response::builder()
            .status(status)
            .body(body.to_string())
            .map(reqwest::Response::from)
            .unwrap()
    }

    fn make_empty_resp(status: u16) -> reqwest::Response {
        make_text_resp(status, "")
    }

    #[tokio::test]
    async fn success_returns_body() {
        let resp = make_resp(200, json!({ "id": "abc" }));
        let v = parse_json_response(resp).await.unwrap();
        assert_eq!(v["id"], "abc");
    }

    #[tokio::test]
    async fn forbidden_maps_to_forbidden_error() {
        let resp = make_resp(
            403,
            json!({
                "statusCode": 403,
                "statusName": "Forbidden",
                "message": "Invalid API key"
            }),
        );
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Forbidden { ref message } if message == "Invalid API key"));
    }

    #[tokio::test]
    async fn client_error_maps_to_api_error() {
        let resp = make_resp(
            400,
            json!({
                "statusCode": 400,
                "statusName": "BadRequest",
                "message": "domain is required"
            }),
        );
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Api { ref message } if message == "domain is required"));
    }

    #[tokio::test]
    async fn server_error_uses_status_name_when_message_missing() {
        let resp = make_resp(
            500,
            json!({
                "statusCode": 500,
                "statusName": "InternalServerError"
            }),
        );
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Api { ref message } if message == "InternalServerError"));
    }

    #[tokio::test]
    async fn non_error_non_success_json_maps_to_http_error() {
        let resp = make_resp(302, json!({ "location": "elsewhere" }));
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Http { status: 302, .. }));
    }

    #[tokio::test]
    async fn success_with_non_json_body_is_parse_error() {
        let resp = make_text_resp(200, "<html>not json</html>");
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Parse { .. }));
    }

    #[tokio::test]
    async fn empty_success_returns_null() {
        let resp = make_empty_resp(200);
        let v = parse_json_response(resp).await.unwrap();
        assert!(v.is_null());
    }

    #[tokio::test]
    async fn empty_failure_returns_http_error() {
        let resp = make_empty_resp(502);
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Http { status: 502, .. }));
    }

    #[tokio::test]
    async fn delete_empty_success_returns_ok_null() {
        let resp = make_empty_resp(200);
        let v = parse_optional_json_response(resp).await.unwrap();
        assert!(v.is_null());
    }

    #[tokio::test]
    async fn delete_json_success_returns_body() {
        let resp = make_resp(200, json!({ "id": "abc" }));
        let v = parse_optional_json_response(resp).await.unwrap();
        assert_eq!(v["id"], "abc");
    }

    #[tokio::test]
    async fn delete_non_json_success_is_parse_error() {
        let resp = make_text_resp(200, "OK");
        let err = parse_optional_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Parse { .. }));
    }

    #[tokio::test]
    async fn delete_failure_uses_standard_error_mapping() {
        let resp = make_resp(
            404,
            json!({
                "statusCode": 404,
                "statusName": "NotFound",
                "message": "policy not found"
            }),
        );
        let err = parse_optional_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Api { ref message } if message == "policy not found"));
    }

    #[test]
    fn unifi_error_message_prefers_message_over_status_name() {
        let v = json!({"message": "boom", "statusName": "Ouch"});
        assert_eq!(unifi_error_message(&v).as_deref(), Some("boom"));
    }

    #[test]
    fn unifi_error_message_falls_back_to_status_name() {
        let v = json!({"statusName": "Ouch"});
        assert_eq!(unifi_error_message(&v).as_deref(), Some("Ouch"));
    }

    #[test]
    fn unifi_error_message_is_none_without_known_fields() {
        assert_eq!(unifi_error_message(&json!({"statusCode": 400})), None);
        assert_eq!(unifi_error_message(&json!("plain string")), None);
    }
}