dnslib/vendors/pihole/
client.rs1use reqwest::{Client, Response};
2use serde_json::Value;
3
4use crate::core::error::{Error, Result};
5use crate::core::secret::ApiToken;
6
7#[derive(Clone, Debug)]
14pub struct PiholeClient {
15 pub http: Client,
16 pub base_url: String,
17 password: ApiToken,
18}
19
20impl PiholeClient {
21 pub fn new(base_url: String, password: ApiToken) -> Result<Self> {
22 let base_url = base_url.trim_end_matches('/').to_string();
23 let http = Client::builder()
24 .timeout(std::time::Duration::from_secs(30))
25 .build()
26 .map_err(Error::Network)?;
27 Ok(Self {
28 http,
29 base_url,
30 password,
31 })
32 }
33
34 async fn session_id(&self) -> Result<String> {
36 let url = format!("{}/api/auth", self.base_url);
37 let body = serde_json::json!({ "password": self.password.expose_for_auth() });
38 let resp = self.http.post(&url).json(&body).send().await.map_err(|e| {
39 tracing::warn!(error = %e, "Pi-hole authentication request failed");
40 Error::Network(e)
41 })?;
42 let status = resp.status();
43 let data: Value = resp.json().await.map_err(|e| {
44 if e.is_decode() {
45 Error::InvalidJson(e)
46 } else {
47 Error::Network(e)
48 }
49 })?;
50 if !status.is_success() {
51 let message = data
52 .get("error")
53 .and_then(|e| e.get("message"))
54 .and_then(|m| m.as_str())
55 .unwrap_or("authentication failed")
56 .to_string();
57 return if status.as_u16() == 401 || status.as_u16() == 403 {
58 Err(Error::forbidden(message))
59 } else {
60 Err(Error::Api { message })
61 };
62 }
63 data.get("session")
64 .and_then(|s| s.get("sid"))
65 .and_then(|s| s.as_str())
66 .map(ToOwned::to_owned)
67 .ok_or_else(|| Error::parse("Pi-hole auth response missing session SID"))
68 }
69
70 pub async fn get(&self, path: &str, params: &[(&str, String)]) -> Result<Value> {
71 let sid = self.session_id().await?;
72 let url = format!("{}{}", self.base_url, path);
73 let span = tracing::debug_span!("http.get", path, http.status = tracing::field::Empty);
74 let _enter = span.enter();
75 tracing::debug!("sending GET");
76 let resp = self
77 .http
78 .get(&url)
79 .bearer_auth(&sid)
80 .query(params)
81 .send()
82 .await
83 .map_err(|e| {
84 tracing::warn!(error = %e, "GET failed");
85 Error::Network(e)
86 })?;
87 span.record("http.status", resp.status().as_u16());
88 tracing::debug!("received response");
89 parse_response(resp).await
90 }
91
92 pub async fn post(&self, path: &str, body: &Value) -> Result<Value> {
93 let sid = self.session_id().await?;
94 let url = format!("{}{}", self.base_url, path);
95 let span = tracing::debug_span!("http.post", path, http.status = tracing::field::Empty);
96 let _enter = span.enter();
97 tracing::debug!("sending POST");
98 let resp = self
99 .http
100 .post(&url)
101 .bearer_auth(&sid)
102 .json(body)
103 .send()
104 .await
105 .map_err(|e| {
106 tracing::warn!(error = %e, "POST failed");
107 Error::Network(e)
108 })?;
109 span.record("http.status", resp.status().as_u16());
110 tracing::debug!("received response");
111 parse_response(resp).await
112 }
113
114 pub async fn delete(&self, path: &str) -> Result<Value> {
115 let sid = self.session_id().await?;
116 let url = format!("{}{}", self.base_url, path);
117 let span = tracing::debug_span!("http.delete", path, http.status = tracing::field::Empty);
118 let _enter = span.enter();
119 tracing::debug!("sending DELETE");
120 let resp = self
121 .http
122 .delete(&url)
123 .bearer_auth(&sid)
124 .send()
125 .await
126 .map_err(|e| {
127 tracing::warn!(error = %e, "DELETE failed");
128 Error::Network(e)
129 })?;
130 span.record("http.status", resp.status().as_u16());
131 tracing::debug!("received response");
132 parse_response(resp).await
133 }
134
135 pub async fn delete_with_body(&self, path: &str, body: &Value) -> Result<Value> {
136 let sid = self.session_id().await?;
137 let url = format!("{}{}", self.base_url, path);
138 let span = tracing::debug_span!("http.delete", path, http.status = tracing::field::Empty);
139 let _enter = span.enter();
140 tracing::debug!("sending DELETE");
141 let resp = self
142 .http
143 .delete(&url)
144 .bearer_auth(&sid)
145 .json(body)
146 .send()
147 .await
148 .map_err(|e| {
149 tracing::warn!(error = %e, "DELETE failed");
150 Error::Network(e)
151 })?;
152 span.record("http.status", resp.status().as_u16());
153 tracing::debug!("received response");
154 parse_response(resp).await
155 }
156}
157
158async fn parse_response(resp: Response) -> Result<Value> {
159 let status = resp.status();
160
161 if status == reqwest::StatusCode::NO_CONTENT {
163 return Ok(serde_json::json!({}));
164 }
165
166 let body: Value = resp.json().await.map_err(|e| {
167 if e.is_decode() {
168 Error::InvalidJson(e)
169 } else {
170 Error::Network(e)
171 }
172 })?;
173
174 if status.is_success() {
175 return Ok(body);
176 }
177
178 let message = body
179 .get("error")
180 .and_then(|e| e.get("message"))
181 .and_then(|m| m.as_str())
182 .unwrap_or("unknown error")
183 .to_string();
184
185 if status.as_u16() == 401 || status.as_u16() == 403 {
186 return Err(Error::forbidden(message));
187 }
188
189 Err(Error::Api { message })
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Wrap `body` (serialized as JSON) in a `reqwest::Response` carrying `status`.
    fn make_resp(status: u16, body: Value) -> reqwest::Response {
        let raw = http::Response::builder()
            .status(status)
            .header("content-type", "application/json")
            .body(body.to_string())
            .unwrap();
        reqwest::Response::from(raw)
    }

    /// Run `parse_response` on a JSON reply that is expected to be rejected.
    async fn rejected(status: u16, body: Value) -> Error {
        parse_response(make_resp(status, body)).await.unwrap_err()
    }

    fn make_client() -> PiholeClient {
        let password = crate::core::secret::ApiToken::new("test-password");
        PiholeClient::new(String::from("http://pi.hole"), password).unwrap()
    }

    #[test]
    fn client_builds_successfully() {
        assert_eq!(make_client().base_url, "http://pi.hole");
    }

    #[test]
    fn trailing_slash_stripped_from_base_url() {
        let password = crate::core::secret::ApiToken::new("pass");
        let client = PiholeClient::new(String::from("http://pi.hole/"), password).unwrap();
        assert_eq!(client.base_url, "http://pi.hole");
    }

    #[tokio::test]
    async fn no_content_response_returns_empty_object() {
        let raw = http::Response::builder()
            .status(204)
            .body(String::new())
            .unwrap();
        let parsed = parse_response(reqwest::Response::from(raw)).await.unwrap();
        assert!(parsed.is_object());
    }

    #[tokio::test]
    async fn success_response_returns_body() {
        let payload = json!({"dns": [{"ip": "1.2.3.4", "host": "myhost.local"}]});
        let parsed = parse_response(make_resp(200, payload)).await.unwrap();
        assert_eq!(parsed["dns"][0]["ip"], "1.2.3.4");
    }

    #[tokio::test]
    async fn forbidden_response_returns_forbidden_error() {
        let err = rejected(
            403,
            json!({"error": {"key": "unauthorized", "message": "Unauthorized", "hint": null}}),
        )
        .await;
        assert!(matches!(err, Error::Forbidden { ref message } if message == "Unauthorized"));
    }

    #[tokio::test]
    async fn unauthorized_response_returns_forbidden_error() {
        let err = rejected(
            401,
            json!({"error": {"key": "unauthorized", "message": "Invalid password", "hint": null}}),
        )
        .await;
        assert!(matches!(err, Error::Forbidden { ref message } if message == "Invalid password"));
    }

    #[tokio::test]
    async fn api_error_returns_message() {
        let err = rejected(
            400,
            json!({"error": {"key": "bad_request", "message": "Invalid domain", "hint": null}}),
        )
        .await;
        assert!(matches!(err, Error::Api { ref message } if message == "Invalid domain"));
    }

    #[tokio::test]
    async fn missing_error_key_uses_unknown_error() {
        let err = rejected(500, json!({})).await;
        assert!(matches!(err, Error::Api { ref message } if message == "unknown error"));
    }
}