// dnslib/vendors/pihole/client.rs

1use reqwest::{Client, Response};
2use serde_json::Value;
3
4use crate::core::error::{Error, Result};
5use crate::core::secret::ApiToken;
6
7/// Pi-hole v6 REST API client.
8///
9/// Pi-hole uses session-based authentication: the password is exchanged for a
10/// session SID via `POST /api/auth`, and that SID is sent as a Bearer token on
11/// every subsequent request.  Because sessions expire (default 1800 s), each
12/// public HTTP method obtains a fresh SID so callers don't need to manage state.
13#[derive(Clone, Debug)]
14pub struct PiholeClient {
15    pub http: Client,
16    pub base_url: String,
17    password: ApiToken,
18}
19
20impl PiholeClient {
21    pub fn new(base_url: String, password: ApiToken) -> Result<Self> {
22        let base_url = base_url.trim_end_matches('/').to_string();
23        let http = Client::builder()
24            .timeout(std::time::Duration::from_secs(30))
25            .build()
26            .map_err(Error::Network)?;
27        Ok(Self {
28            http,
29            base_url,
30            password,
31        })
32    }
33
34    /// Authenticate and return the session SID.
35    async fn session_id(&self) -> Result<String> {
36        let url = format!("{}/api/auth", self.base_url);
37        let body = serde_json::json!({ "password": self.password.expose_for_auth() });
38        let resp = self.http.post(&url).json(&body).send().await.map_err(|e| {
39            tracing::warn!(error = %e, "Pi-hole authentication request failed");
40            Error::Network(e)
41        })?;
42        let status = resp.status();
43        let data: Value = resp.json().await.map_err(|e| {
44            if e.is_decode() {
45                Error::InvalidJson(e)
46            } else {
47                Error::Network(e)
48            }
49        })?;
50        if !status.is_success() {
51            let message = data
52                .get("error")
53                .and_then(|e| e.get("message"))
54                .and_then(|m| m.as_str())
55                .unwrap_or("authentication failed")
56                .to_string();
57            return if status.as_u16() == 401 || status.as_u16() == 403 {
58                Err(Error::forbidden(message))
59            } else {
60                Err(Error::Api { message })
61            };
62        }
63        data.get("session")
64            .and_then(|s| s.get("sid"))
65            .and_then(|s| s.as_str())
66            .map(ToOwned::to_owned)
67            .ok_or_else(|| Error::parse("Pi-hole auth response missing session SID"))
68    }
69
70    pub async fn get(&self, path: &str, params: &[(&str, String)]) -> Result<Value> {
71        let sid = self.session_id().await?;
72        let url = format!("{}{}", self.base_url, path);
73        let span = tracing::debug_span!("http.get", path, http.status = tracing::field::Empty);
74        let _enter = span.enter();
75        tracing::debug!("sending GET");
76        let resp = self
77            .http
78            .get(&url)
79            .bearer_auth(&sid)
80            .query(params)
81            .send()
82            .await
83            .map_err(|e| {
84                tracing::warn!(error = %e, "GET failed");
85                Error::Network(e)
86            })?;
87        span.record("http.status", resp.status().as_u16());
88        tracing::debug!("received response");
89        parse_response(resp).await
90    }
91
92    pub async fn post(&self, path: &str, body: &Value) -> Result<Value> {
93        let sid = self.session_id().await?;
94        let url = format!("{}{}", self.base_url, path);
95        let span = tracing::debug_span!("http.post", path, http.status = tracing::field::Empty);
96        let _enter = span.enter();
97        tracing::debug!("sending POST");
98        let resp = self
99            .http
100            .post(&url)
101            .bearer_auth(&sid)
102            .json(body)
103            .send()
104            .await
105            .map_err(|e| {
106                tracing::warn!(error = %e, "POST failed");
107                Error::Network(e)
108            })?;
109        span.record("http.status", resp.status().as_u16());
110        tracing::debug!("received response");
111        parse_response(resp).await
112    }
113
114    pub async fn delete(&self, path: &str) -> Result<Value> {
115        let sid = self.session_id().await?;
116        let url = format!("{}{}", self.base_url, path);
117        let span = tracing::debug_span!("http.delete", path, http.status = tracing::field::Empty);
118        let _enter = span.enter();
119        tracing::debug!("sending DELETE");
120        let resp = self
121            .http
122            .delete(&url)
123            .bearer_auth(&sid)
124            .send()
125            .await
126            .map_err(|e| {
127                tracing::warn!(error = %e, "DELETE failed");
128                Error::Network(e)
129            })?;
130        span.record("http.status", resp.status().as_u16());
131        tracing::debug!("received response");
132        parse_response(resp).await
133    }
134
135    pub async fn delete_with_body(&self, path: &str, body: &Value) -> Result<Value> {
136        let sid = self.session_id().await?;
137        let url = format!("{}{}", self.base_url, path);
138        let span = tracing::debug_span!("http.delete", path, http.status = tracing::field::Empty);
139        let _enter = span.enter();
140        tracing::debug!("sending DELETE");
141        let resp = self
142            .http
143            .delete(&url)
144            .bearer_auth(&sid)
145            .json(body)
146            .send()
147            .await
148            .map_err(|e| {
149                tracing::warn!(error = %e, "DELETE failed");
150                Error::Network(e)
151            })?;
152        span.record("http.status", resp.status().as_u16());
153        tracing::debug!("received response");
154        parse_response(resp).await
155    }
156}
157
158async fn parse_response(resp: Response) -> Result<Value> {
159    let status = resp.status();
160
161    // 204 No Content (successful DELETE operations return no body)
162    if status == reqwest::StatusCode::NO_CONTENT {
163        return Ok(serde_json::json!({}));
164    }
165
166    let body: Value = resp.json().await.map_err(|e| {
167        if e.is_decode() {
168            Error::InvalidJson(e)
169        } else {
170            Error::Network(e)
171        }
172    })?;
173
174    if status.is_success() {
175        return Ok(body);
176    }
177
178    let message = body
179        .get("error")
180        .and_then(|e| e.get("message"))
181        .and_then(|m| m.as_str())
182        .unwrap_or("unknown error")
183        .to_string();
184
185    if status.as_u16() == 401 || status.as_u16() == 403 {
186        return Err(Error::forbidden(message));
187    }
188
189    Err(Error::Api { message })
190}
191
// ─── Tests ────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `reqwest::Response` with status `code` and `payload` serialised
    /// as a JSON body.
    fn json_response(code: u16, payload: Value) -> reqwest::Response {
        let raw = http::Response::builder()
            .status(code)
            .header("content-type", "application/json")
            .body(payload.to_string())
            .unwrap();
        reqwest::Response::from(raw)
    }

    /// Runs a canned JSON response through `parse_response` and returns the
    /// error it yields, panicking if parsing unexpectedly succeeds.
    async fn parse_failure(code: u16, payload: Value) -> Error {
        parse_response(json_response(code, payload)).await.unwrap_err()
    }

    /// Constructs a client for `base_url` with the given password.
    fn test_client(base_url: &str, password: &'static str) -> PiholeClient {
        let token = crate::core::secret::ApiToken::new(password);
        PiholeClient::new(base_url.to_string(), token).unwrap()
    }

    #[test]
    fn client_builds_successfully() {
        let client = test_client("http://pi.hole", "test-password");
        assert_eq!(client.base_url, "http://pi.hole");
    }

    #[test]
    fn trailing_slash_stripped_from_base_url() {
        let client = test_client("http://pi.hole/", "pass");
        assert_eq!(client.base_url, "http://pi.hole");
    }

    #[tokio::test]
    async fn no_content_response_returns_empty_object() {
        let empty = http::Response::builder()
            .status(204)
            .body(String::new())
            .unwrap();
        let parsed = parse_response(reqwest::Response::from(empty)).await.unwrap();
        assert!(parsed.is_object());
    }

    #[tokio::test]
    async fn success_response_returns_body() {
        let payload = json!({"dns": [{"ip": "1.2.3.4", "host": "myhost.local"}]});
        let parsed = parse_response(json_response(200, payload)).await.unwrap();
        assert_eq!(parsed["dns"][0]["ip"], "1.2.3.4");
    }

    #[tokio::test]
    async fn forbidden_response_returns_forbidden_error() {
        let err = parse_failure(
            403,
            json!({"error": {"key": "unauthorized", "message": "Unauthorized", "hint": null}}),
        )
        .await;
        assert!(matches!(err, Error::Forbidden { ref message } if message == "Unauthorized"));
    }

    #[tokio::test]
    async fn unauthorized_response_returns_forbidden_error() {
        let err = parse_failure(
            401,
            json!({"error": {"key": "unauthorized", "message": "Invalid password", "hint": null}}),
        )
        .await;
        assert!(matches!(err, Error::Forbidden { ref message } if message == "Invalid password"));
    }

    #[tokio::test]
    async fn api_error_returns_message() {
        let err = parse_failure(
            400,
            json!({"error": {"key": "bad_request", "message": "Invalid domain", "hint": null}}),
        )
        .await;
        assert!(matches!(err, Error::Api { ref message } if message == "Invalid domain"));
    }

    #[tokio::test]
    async fn missing_error_key_uses_unknown_error() {
        let err = parse_failure(500, json!({})).await;
        assert!(matches!(err, Error::Api { ref message } if message == "unknown error"));
    }
}