1use std::time::Duration;
12
13use reqwest::{Client, RequestBuilder, Response};
14use serde_json::Value;
15use tokio::sync::OnceCell;
16use tracing::Instrument;
17
18use crate::core::error::{Error, Result};
19use crate::core::secret::ApiToken;
20
21use super::responses::{
22 UnifiDnsPolicy, UnifiDnsPolicyPage, UnifiSite, match_site, parse_page, parse_site_page,
23};
24
/// Largest `limit` sent in a single paginated request; the `list_*_page` methods
/// clamp any larger requested limit down to this value.
pub const MAX_PAGE_LIMIT: u32 = 200;

/// Page size requested by the `list_all_*` helpers while walking every page.
pub const DEFAULT_PAGE_LIMIT: u32 = 200;

// When the server reports no total, the `list_all_*` loops treat a page shorter than
// `DEFAULT_PAGE_LIMIT` as the last one. Requests are clamped to `MAX_PAGE_LIMIT`, so a
// default above the maximum would make every full page look short and silently stop
// pagination early. Reject that configuration at compile time.
const _: () = assert!(DEFAULT_PAGE_LIMIT <= MAX_PAGE_LIMIT);
30
/// Async client for the UniFi Network integration API, bound to a single site.
///
/// Every request carries the API key in an `X-API-KEY` header. The configured
/// `site` selector is resolved to a concrete site id on first use and cached.
// NOTE(review): `Debug` is derived, so `token` is printed through `ApiToken`'s own
// `Debug` impl — confirm that impl redacts the secret.
#[derive(Clone, Debug)]
pub struct UnifiClient {
    /// Shared HTTP client; `UnifiClient::new` builds it with a 30-second timeout.
    http: Client,
    /// Base URL that request paths such as `/sites` are appended to.
    base_url: String,
    /// API key attached to every request.
    token: ApiToken,
    /// User-supplied site selector, matched against the controller's sites by
    /// `match_site` (presumably by name or id — confirm in that function).
    site: String,
    /// Cached result of resolving `site` to a site id. Held behind an `Arc` so that
    /// clones of the client share one cache and reuse a successful lookup.
    resolved_site_id: std::sync::Arc<OnceCell<String>>,
}
45
46impl UnifiClient {
47 pub fn new(base_url: String, token: ApiToken, site: String) -> Result<Self> {
50 let http = Client::builder()
51 .timeout(Duration::from_secs(30))
52 .build()
53 .map_err(Error::Network)?;
54 Ok(Self {
55 http,
56 base_url,
57 token,
58 site,
59 resolved_site_id: std::sync::Arc::new(OnceCell::new()),
60 })
61 }
62
63 pub fn base_url(&self) -> &str {
64 &self.base_url
65 }
66
67 pub fn site(&self) -> &str {
70 &self.site
71 }
72
73 #[cfg(test)]
76 pub fn token_for_test(&self) -> &str {
77 self.token.expose_for_auth()
78 }
79
80 fn url(&self, path: &str) -> String {
81 format!("{}{}", self.base_url, path)
82 }
83
84 fn policies_path(&self, site_id: &str) -> String {
85 format!("/sites/{site_id}/dns/policies")
86 }
87
88 fn policy_path(&self, site_id: &str, policy_id: &str) -> String {
89 format!("/sites/{site_id}/dns/policies/{policy_id}")
90 }
91
92 pub async fn resolve_site_id(&self) -> Result<&str> {
105 let cached = self
106 .resolved_site_id
107 .get_or_try_init(|| async {
108 let sites = self.list_all_sites().await?;
109 match match_site(&sites, &self.site) {
110 Some(site) => Ok(site.id.clone()),
111 None => {
112 let available = if sites.is_empty() {
113 "<no sites visible to this API key>".to_string()
114 } else {
115 sites
116 .iter()
117 .map(|s| s.display_name())
118 .collect::<Vec<_>>()
119 .join(", ")
120 };
121 Err(Error::Api {
122 message: format!(
123 "UniFi site '{}' not found on this controller; available sites: [{}]",
124 self.site, available
125 ),
126 })
127 }
128 }
129 })
130 .await?;
131 Ok(cached.as_str())
132 }
133
134 fn auth(&self, req: RequestBuilder) -> RequestBuilder {
135 req.header("X-API-KEY", self.token.expose_for_auth())
136 .header("Accept", "application/json")
137 }
138
139 async fn send(
140 &self,
141 method: &'static str,
142 path: &str,
143 builder: RequestBuilder,
144 ) -> Result<Response> {
145 let span = tracing::debug_span!(
146 "http.request",
147 method,
148 path,
149 http.status = tracing::field::Empty
150 );
151 async {
152 tracing::debug!("sending request");
153 let resp = self.auth(builder).send().await.map_err(|e| {
154 tracing::warn!(error = %e, "request failed");
155 Error::Network(e)
156 })?;
157 tracing::Span::current().record("http.status", resp.status().as_u16());
158 tracing::debug!("received response");
159 Ok(resp)
160 }
161 .instrument(span)
162 .await
163 }
164
165 async fn get(&self, path: &str, params: &[(&str, String)]) -> Result<Value> {
168 let req = self.http.get(self.url(path)).query(params);
169 let resp = self.send("GET", path, req).await?;
170 parse_json_response(resp).await
171 }
172
173 async fn post(&self, path: &str, body: &Value) -> Result<Value> {
174 let req = self.http.post(self.url(path)).json(body);
175 let resp = self.send("POST", path, req).await?;
176 parse_json_response(resp).await
177 }
178
179 async fn put(&self, path: &str, body: &Value) -> Result<Value> {
180 let req = self.http.put(self.url(path)).json(body);
181 let resp = self.send("PUT", path, req).await?;
182 parse_json_response(resp).await
183 }
184
185 async fn delete(&self, path: &str) -> Result<Value> {
186 let req = self.http.delete(self.url(path));
187 let resp = self.send("DELETE", path, req).await?;
188 parse_optional_json_response(resp).await
189 }
190
191 pub async fn list_sites_page(
195 &self,
196 offset: u32,
197 limit: u32,
198 ) -> Result<super::responses::UnifiSitePage> {
199 let limit = limit.min(MAX_PAGE_LIMIT);
200 let params: Vec<(&str, String)> =
201 vec![("offset", offset.to_string()), ("limit", limit.to_string())];
202 let value = self.get("/sites", ¶ms).await?;
203 parse_site_page(value).map_err(|e| Error::parse(format!("decoding UniFi site page: {e}")))
204 }
205
206 pub async fn list_all_sites(&self) -> Result<Vec<UnifiSite>> {
211 let mut out: Vec<UnifiSite> = Vec::new();
212 let mut offset = 0u32;
213 let mut pages = 0u32;
214 loop {
215 let page = self.list_sites_page(offset, DEFAULT_PAGE_LIMIT).await?;
216 let returned = page.data.len() as u32;
217 let total = page.total();
218 out.extend(page.data);
219 offset += returned.max(1);
220 pages += 1;
221 if returned == 0 {
222 break;
223 }
224 if let Some(total) = total {
225 if out.len() as u32 >= total {
226 break;
227 }
228 } else if returned < DEFAULT_PAGE_LIMIT {
229 break;
230 }
231 if pages >= 1000 {
232 return Err(Error::parse(
233 "UniFi site pagination exceeded 1000 pages without terminating",
234 ));
235 }
236 }
237 Ok(out)
238 }
239
240 pub async fn list_dns_policies_page(
247 &self,
248 offset: u32,
249 limit: u32,
250 filter: Option<&str>,
251 ) -> Result<UnifiDnsPolicyPage> {
252 let limit = limit.min(MAX_PAGE_LIMIT);
253 let mut params: Vec<(&str, String)> =
254 vec![("offset", offset.to_string()), ("limit", limit.to_string())];
255 if let Some(f) = filter {
256 params.push(("filter", f.to_string()));
257 }
258 let site_id = self.resolve_site_id().await?.to_string();
259 let path = self.policies_path(&site_id);
260 let value = self.get(&path, ¶ms).await?;
261 parse_page(value).map_err(|e| Error::parse(format!("decoding UniFi DNS policy page: {e}")))
262 }
263
264 pub async fn list_all_dns_policies(&self, filter: Option<&str>) -> Result<Vec<UnifiDnsPolicy>> {
270 let mut out: Vec<UnifiDnsPolicy> = Vec::new();
271 let mut offset = 0u32;
272 let mut pages = 0u32;
273 loop {
274 let page = self
275 .list_dns_policies_page(offset, DEFAULT_PAGE_LIMIT, filter)
276 .await?;
277 let returned = page.data.len() as u32;
278 let total = page.total();
279 out.extend(page.data);
280 offset += returned.max(1); pages += 1;
282
283 if returned == 0 {
285 break;
286 }
287 if let Some(total) = total {
288 if out.len() as u32 >= total {
289 break;
290 }
291 } else if returned < DEFAULT_PAGE_LIMIT {
292 break;
294 }
295 if pages >= 1000 {
296 return Err(Error::parse(
297 "UniFi DNS policy pagination exceeded 1000 pages without terminating",
298 ));
299 }
300 }
301 Ok(out)
302 }
303
304 pub async fn create_dns_policy(&self, body: &Value) -> Result<UnifiDnsPolicy> {
306 let site_id = self.resolve_site_id().await?.to_string();
307 let path = self.policies_path(&site_id);
308 let value = self.post(&path, body).await?;
309 serde_json::from_value(value)
310 .map_err(|e| Error::parse(format!("decoding UniFi create DNS policy response: {e}")))
311 }
312
313 pub async fn get_dns_policy(&self, policy_id: &str) -> Result<UnifiDnsPolicy> {
315 let site_id = self.resolve_site_id().await?.to_string();
316 let path = self.policy_path(&site_id, policy_id);
317 let value = self.get(&path, &[]).await?;
318 serde_json::from_value(value)
319 .map_err(|e| Error::parse(format!("decoding UniFi get DNS policy response: {e}")))
320 }
321
322 pub async fn update_dns_policy(&self, policy_id: &str, body: &Value) -> Result<UnifiDnsPolicy> {
327 let site_id = self.resolve_site_id().await?.to_string();
328 let path = self.policy_path(&site_id, policy_id);
329 let value = self.put(&path, body).await?;
330 serde_json::from_value(value)
331 .map_err(|e| Error::parse(format!("decoding UniFi update DNS policy response: {e}")))
332 }
333
334 pub async fn delete_dns_policy(&self, policy_id: &str) -> Result<()> {
336 let site_id = self.resolve_site_id().await?.to_string();
337 let path = self.policy_path(&site_id, policy_id);
338 self.delete(&path).await?;
339 Ok(())
340 }
341}
342
343async fn parse_json_response(resp: Response) -> Result<Value> {
349 let status = resp.status();
350 let bytes = resp.bytes().await.map_err(Error::Network)?;
351
352 if bytes.is_empty() {
353 if status.is_success() {
354 return Ok(Value::Null);
355 }
356 return Err(Error::Http {
357 status: status.as_u16(),
358 body: String::new(),
359 });
360 }
361
362 let value: Value = serde_json::from_slice(&bytes).map_err(|e| {
363 let _ = e;
366 Error::Parse {
367 context: format!(
368 "UniFi response body is not valid JSON (status {}): {}",
369 status.as_u16(),
370 String::from_utf8_lossy(&bytes)
371 .chars()
372 .take(200)
373 .collect::<String>(),
374 ),
375 }
376 })?;
377
378 if status.is_success() {
379 return Ok(value);
380 }
381
382 let message = unifi_error_message(&value).unwrap_or_else(|| value.to_string());
383
384 if status.as_u16() == 403 {
385 return Err(Error::forbidden(message));
386 }
387 if status.is_client_error() || status.is_server_error() {
388 return Err(Error::Api { message });
390 }
391
392 Err(Error::Http {
393 status: status.as_u16(),
394 body: value.to_string(),
395 })
396}
397
398async fn parse_optional_json_response(resp: Response) -> Result<Value> {
401 let status = resp.status();
402 if status.is_success() {
403 let bytes = resp.bytes().await.map_err(Error::Network)?;
404 if bytes.is_empty() {
405 return Ok(Value::Null);
406 }
407 return serde_json::from_slice::<Value>(&bytes).map_err(|_| Error::Parse {
408 context: format!(
409 "UniFi DELETE response was not valid JSON (status {})",
410 status.as_u16()
411 ),
412 });
413 }
414 parse_json_response(resp).await
415}
416
417fn unifi_error_message(value: &Value) -> Option<String> {
419 if let Some(msg) = value.get("message").and_then(|m| m.as_str()) {
420 return Some(msg.to_string());
421 }
422 if let Some(msg) = value.get("statusName").and_then(|m| m.as_str()) {
423 return Some(msg.to_string());
424 }
425 None
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `reqwest::Response` with `status` and a JSON body.
    fn make_resp(status: u16, body: Value) -> reqwest::Response {
        http::Response::builder()
            .status(status)
            .header("content-type", "application/json")
            .body(body.to_string())
            .map(reqwest::Response::from)
            .unwrap()
    }

    /// Builds a `reqwest::Response` with `status` and a raw text body (no content type).
    fn make_text_resp(status: u16, body: &str) -> reqwest::Response {
        http::Response::builder()
            .status(status)
            .body(body.to_string())
            .map(reqwest::Response::from)
            .unwrap()
    }

    /// Builds a `reqwest::Response` with `status` and an empty body.
    fn make_empty_resp(status: u16) -> reqwest::Response {
        make_text_resp(status, "")
    }

    #[tokio::test]
    async fn success_returns_body() {
        let resp = make_resp(200, json!({ "id": "abc" }));
        let v = parse_json_response(resp).await.unwrap();
        assert_eq!(v["id"], "abc");
    }

    #[tokio::test]
    async fn forbidden_maps_to_forbidden_error() {
        let resp = make_resp(
            403,
            json!({
                "statusCode": 403,
                "statusName": "Forbidden",
                "message": "Invalid API key"
            }),
        );
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Forbidden { ref message } if message == "Invalid API key"));
    }

    #[tokio::test]
    async fn client_error_maps_to_api_error() {
        let resp = make_resp(
            400,
            json!({
                "statusCode": 400,
                "statusName": "BadRequest",
                "message": "domain is required"
            }),
        );
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Api { ref message } if message == "domain is required"));
    }

    #[tokio::test]
    async fn server_error_falls_back_to_status_name() {
        let resp = make_resp(500, json!({ "statusName": "InternalServerError" }));
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Api { ref message } if message == "InternalServerError"));
    }

    #[tokio::test]
    async fn error_without_message_fields_uses_raw_json() {
        let resp = make_resp(404, json!({ "error": "missing" }));
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Api { ref message } if message == r#"{"error":"missing"}"#));
    }

    #[tokio::test]
    async fn non_json_error_body_maps_to_parse_error() {
        let resp = make_text_resp(502, "<html>Bad Gateway</html>");
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(
            err,
            Error::Parse { ref context }
                if context.contains("status 502") && context.contains("Bad Gateway")
        ));
    }

    #[tokio::test]
    async fn non_error_status_maps_to_http_error() {
        let resp = make_resp(302, json!({ "location": "elsewhere" }));
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(
            err,
            Error::Http { status: 302, ref body } if body == r#"{"location":"elsewhere"}"#
        ));
    }

    #[tokio::test]
    async fn empty_success_returns_null() {
        let resp = make_empty_resp(200);
        let v = parse_json_response(resp).await.unwrap();
        assert!(v.is_null());
    }

    #[tokio::test]
    async fn empty_failure_returns_http_error() {
        let resp = make_empty_resp(502);
        let err = parse_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Http { status: 502, .. }));
    }

    #[tokio::test]
    async fn delete_empty_success_returns_ok_null() {
        let resp = make_empty_resp(200);
        let v = parse_optional_json_response(resp).await.unwrap();
        assert!(v.is_null());
    }

    #[tokio::test]
    async fn delete_json_success_returns_body() {
        let resp = make_resp(200, json!({ "deleted": true }));
        let v = parse_optional_json_response(resp).await.unwrap();
        assert_eq!(v["deleted"], true);
    }

    #[tokio::test]
    async fn delete_invalid_json_success_maps_to_parse_error() {
        let resp = make_text_resp(200, "OK");
        let err = parse_optional_json_response(resp).await.unwrap_err();
        assert!(matches!(
            err,
            Error::Parse { ref context }
                if context == "UniFi DELETE response was not valid JSON (status 200)"
        ));
    }

    #[tokio::test]
    async fn delete_error_status_uses_standard_mapping() {
        let resp = make_resp(403, json!({ "message": "nope" }));
        let err = parse_optional_json_response(resp).await.unwrap_err();
        assert!(matches!(err, Error::Forbidden { ref message } if message == "nope"));
    }

    #[test]
    fn unifi_error_message_prefers_message_over_status_name() {
        let v = json!({"message": "boom", "statusName": "Ouch"});
        assert_eq!(unifi_error_message(&v).as_deref(), Some("boom"));
    }

    #[test]
    fn unifi_error_message_falls_back_to_status_name() {
        let v = json!({"statusName": "Ouch"});
        assert_eq!(unifi_error_message(&v).as_deref(), Some("Ouch"));
    }
}