use reqwest::{self};
use serde::{de::DeserializeOwned, Deserialize};
use simple_error;
use std::collections::HashMap;
use std::error::Error;
use std::net::IpAddr;
/// Summary statistics returned by `/admin/api.php?summaryRaw`, decoded as numbers.
///
/// See [`Summary`] for the variant of this endpoint whose values are kept as strings.
#[derive(Debug, Deserialize)]
pub struct SummaryRaw {
    pub domains_being_blocked: u64,
    pub dns_queries_today: u64,
    pub ads_blocked_today: u64,
    /// Decoded as a float, unlike the other counters.
    pub ads_percentage_today: f64,
    pub unique_domains: u64,
    pub queries_forwarded: u64,
    pub queries_cached: u64,
    pub clients_ever_seen: u64,
    pub unique_clients: u64,
    pub dns_queries_all_types: u64,
    /// Read from the JSON key `reply_NODATA`.
    #[serde(rename = "reply_NODATA")]
    pub reply_nodata: u64,
    /// Read from the JSON key `reply_NXDOMAIN`.
    #[serde(rename = "reply_NXDOMAIN")]
    pub reply_nxdomain: u64,
    /// Read from the JSON key `reply_CNAME`.
    #[serde(rename = "reply_CNAME")]
    pub reply_cname: u64,
    /// Read from the JSON key `reply_IP`.
    #[serde(rename = "reply_IP")]
    pub reply_ip: u64,
    pub privacy_level: u64,
    /// Status string exactly as the server reports it.
    pub status: String,
}
/// Summary statistics returned by `/admin/api.php?summary`.
///
/// Carries the same fields as [`SummaryRaw`], but every value is kept as the string the server
/// sent. The strings are presumably display-formatted; they are not parsed here.
#[derive(Debug, Deserialize)]
pub struct Summary {
    pub domains_being_blocked: String,
    pub dns_queries_today: String,
    pub ads_blocked_today: String,
    pub ads_percentage_today: String,
    pub unique_domains: String,
    pub queries_forwarded: String,
    pub queries_cached: String,
    pub clients_ever_seen: String,
    pub unique_clients: String,
    pub dns_queries_all_types: String,
    /// Read from the JSON key `reply_NODATA`.
    #[serde(rename = "reply_NODATA")]
    pub reply_nodata: String,
    /// Read from the JSON key `reply_NXDOMAIN`.
    #[serde(rename = "reply_NXDOMAIN")]
    pub reply_nxdomain: String,
    /// Read from the JSON key `reply_CNAME`.
    #[serde(rename = "reply_CNAME")]
    pub reply_cname: String,
    /// Read from the JSON key `reply_IP`.
    #[serde(rename = "reply_IP")]
    pub reply_ip: String,
    pub privacy_level: String,
    pub status: String,
}
/// Time-bucketed counters returned by `/admin/api.php?overTimeData10mins`.
///
/// NOTE(review): the `i64` keys are presumably Unix timestamps of each bucket; confirm against
/// the server.
#[derive(Debug, Deserialize)]
pub struct OverTimeData {
    /// Counter per bucket, read from `domains_over_time`.
    pub domains_over_time: HashMap<i64, u64>,
    /// Counter per bucket, read from `ads_over_time`.
    pub ads_over_time: HashMap<i64, u64>,
}
/// Result of `/admin/api.php?topItems=N`.
///
/// Both maps are presumably keyed by domain with a hit count as the value; verify against the
/// server.
#[derive(Debug, Deserialize)]
pub struct TopItems {
    pub top_queries: HashMap<String, u64>,
    pub top_ads: HashMap<String, u64>,
}
/// Result of `/admin/api.php?topClients=N`: a map read from the `top_sources` key.
#[derive(Debug, Deserialize)]
pub struct TopClients {
    /// Presumably client identifier -> query count (verify against the server).
    pub top_sources: HashMap<String, u64>,
}
/// Result of `/admin/api.php?topClientsBlocked=N`: a map read from `top_sources_blocked`.
#[derive(Debug, Deserialize)]
pub struct TopClientsBlocked {
    /// Presumably client identifier -> blocked query count (verify against the server).
    pub top_sources_blocked: HashMap<String, u64>,
}
/// Result of `/admin/api.php?getForwardDestinations`.
#[derive(Debug, Deserialize)]
pub struct ForwardDestinations {
    /// Destination -> floating-point value. NOTE(review): presumably a percentage share; confirm.
    pub forward_destinations: HashMap<String, f64>,
}
/// Result of `/admin/api.php?getQueryTypes`.
#[derive(Debug, Deserialize)]
pub struct QueryTypes {
    /// Query type -> floating-point value. NOTE(review): presumably a percentage share; confirm.
    pub querytypes: HashMap<String, f64>,
}
/// A single row of the `getAllQueries` result.
///
/// `PiHoleAPI::get_all_queries` fills the fields, in declaration order, from the first five
/// columns of each row. Every column is kept as the raw string.
#[derive(Debug, Deserialize)]
pub struct Query {
    /// Column 0.
    pub timestring: String,
    /// Column 1.
    pub query_type: String,
    /// Column 2.
    pub domain: String,
    /// Column 3.
    pub client: String,
    /// Column 4.
    pub answer_type: String,
}
/// All queries returned by `/admin/api.php?getAllQueries=N`, one [`Query`] per row.
#[derive(Deserialize, Debug)]
pub struct AllQueries {
    /// The parsed rows, in the order the server returned them.
    ///
    /// Public so that callers outside this module can read the result of
    /// `PiHoleAPI::get_all_queries`; the type exposes no other accessor.
    pub data: Vec<Query>,
}
/// Response body of the `enable` / `disable` endpoints: a single `status` string.
#[derive(Debug, Deserialize)]
pub struct Status {
    /// Status string exactly as the server reports it.
    pub status: String,
}
/// Response body of `/admin/api.php?version`.
#[derive(Debug, Deserialize)]
pub struct Version {
    /// Version number as reported by the server, decoded as an unsigned integer.
    pub version: u32,
}
/// Cache counters found under the `cacheinfo` key of `/admin/api.php?getCacheInfo`.
///
/// The JSON keys are kebab-case (`cache-size`, `cache-live-freed`, `cache-inserted`).
/// `rename_all` derives exactly those names from the snake_case fields.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct CacheInfo {
    /// Read from `cache-size`.
    pub cache_size: u64,
    /// Read from `cache-live-freed`.
    pub cache_live_freed: u64,
    /// Read from `cache-inserted`.
    pub cache_inserted: u64,
}
/// One entry of the `clients` list returned by `/admin/api.php?getClientNames`.
#[derive(Debug, Deserialize)]
pub struct ClientName {
    /// Name string as sent by the server.
    pub name: String,
    /// Address, parsed into a `std::net::IpAddr`. Deserialization fails if it is not a valid IP.
    pub ip: IpAddr,
}
/// One device entry from `/admin/api_db.php?network`.
///
/// The multi-word JSON keys are camelCase (`firstSeen`, `lastQuery`, `numQueries`,
/// `macVendor`). `rename_all` derives those from the snake_case fields and leaves the
/// single-word keys unchanged.
///
/// NOTE(review): `ip` is decoded as a list while `name` is a single string; confirm this
/// matches the server version in use.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct NetworkClient {
    pub id: u64,
    /// Every entry must parse as an IP address.
    pub ip: Vec<IpAddr>,
    pub hwaddr: String,
    pub interface: String,
    pub name: String,
    /// Read from `firstSeen`.
    pub first_seen: u64,
    /// Read from `lastQuery`.
    pub last_query: u64,
    /// Read from `numQueries`.
    pub num_queries: u64,
    /// Read from `macVendor`.
    pub mac_vendor: String,
}
/// Response body of `/admin/api_db.php?network`: the list under the `network` key.
#[derive(Debug, Deserialize)]
pub struct Network {
    /// One entry per device reported by the server.
    pub network: Vec<NetworkClient>,
}
/// Client for the HTTP API of a single Pi-hole instance.
///
/// Holds only the base URL and an optional API key. Each request issues its own
/// `reqwest::get` call.
pub struct PiHoleAPI {
    /// Base URL; every request path (e.g. `/admin/api.php?...`) is appended to it verbatim.
    host: String,
    /// Value sent as the `auth` query parameter, if one has been configured.
    api_key: Option<String>,
}
impl PiHoleAPI {
pub fn new(host: String, api_key: Option<String>) -> Self {
Self { host, api_key }
}
pub fn set_api_key(&mut self, api_key: &String) {
self.api_key = Some(api_key.into());
}
async fn simple_json_request<T>(&self, path_query: String) -> Result<T, Box<dyn Error>>
where
T: DeserializeOwned,
{
let response = reqwest::get(&format!("{}{}", self.host, path_query)).await?;
Ok(response.json().await?)
}
async fn authenticated_json_request<T>(&self, path_query: String) -> Result<T, Box<dyn Error>>
where
T: DeserializeOwned,
{
if self.api_key == None {
simple_error::bail!("API key is required for authenticated requests");
}
let auth_path_query;
match path_query.contains("?") {
true => {
auth_path_query = format!(
"{}{}&auth={}",
self.host,
path_query,
self.api_key.as_ref().unwrap()
)
}
false => {
auth_path_query = format!(
"{}{}?auth={}",
self.host,
path_query,
self.api_key.as_ref().unwrap()
)
}
}
let response = reqwest::get(&auth_path_query).await?;
Ok(response.json().await?)
}
pub async fn get_summary_raw(&self) -> Result<SummaryRaw, Box<dyn Error>> {
self.simple_json_request("/admin/api.php?summaryRaw".to_string())
.await
}
pub async fn get_summary(&self) -> Result<Summary, Box<dyn Error>> {
self.simple_json_request("/admin/api.php?summary".to_string())
.await
}
pub async fn get_over_time_data_10_mins(&self) -> Result<OverTimeData, Box<dyn Error>> {
self.simple_json_request("/admin/api.php?overTimeData10mins".to_string())
.await
}
pub async fn get_top_items(&self, count: Option<u32>) -> Result<TopItems, Box<dyn Error>> {
self.authenticated_json_request(format!("/admin/api.php?topItems={}", count.unwrap_or(10)))
.await
}
pub async fn get_top_clients(&self, count: Option<u32>) -> Result<TopClients, Box<dyn Error>> {
self.authenticated_json_request(format!(
"/admin/api.php?topClients={}",
count.unwrap_or(10)
))
.await
}
pub async fn get_top_clients_blocked(
&self,
count: Option<u32>,
) -> Result<TopClientsBlocked, Box<dyn Error>> {
self.authenticated_json_request(format!(
"/admin/api.php?topClientsBlocked={}",
count.unwrap_or(10)
))
.await
}
pub async fn get_forward_destinations(&self) -> Result<ForwardDestinations, Box<dyn Error>> {
self.authenticated_json_request("/admin/api.php?getForwardDestinations".to_string())
.await
}
pub async fn get_query_types(&self) -> Result<QueryTypes, Box<dyn Error>> {
self.authenticated_json_request("/admin/api.php?getQueryTypes".to_string())
.await
}
pub async fn get_all_queries(&self, count: u32) -> Result<AllQueries, Box<dyn Error>> {
let mut raw_data: HashMap<String, Vec<Vec<String>>> = self
.authenticated_json_request(format!("/admin/api.php?getAllQueries={}", count))
.await?;
let data = AllQueries {
data: raw_data
.remove("data")
.unwrap()
.iter()
.map(|raw_query| Query {
timestring: raw_query[0].clone(),
query_type: raw_query[1].clone(),
domain: raw_query[2].clone(),
client: raw_query[3].clone(),
answer_type: raw_query[4].clone(),
})
.collect(),
};
Ok(data)
}
pub async fn enable(&self) -> Result<Status, Box<dyn Error>> {
self.authenticated_json_request("/admin/api.php?enable".to_string())
.await
}
pub async fn disable(&self, seconds: u64) -> Result<Status, Box<dyn Error>> {
self.authenticated_json_request(format!("/admin/api.php?disable={}", seconds))
.await
}
pub async fn get_version(&self) -> Result<Version, Box<dyn Error>> {
self.simple_json_request("/admin/api.php?version".to_string())
.await
}
pub async fn get_cache_info(&self) -> Result<CacheInfo, Box<dyn Error>> {
let mut raw_data: HashMap<String, CacheInfo> = self
.authenticated_json_request("/admin/api.php?getCacheInfo".to_string())
.await?;
Ok(raw_data.remove("cacheinfo").expect("Missing cache info"))
}
pub async fn get_client_names(&self) -> Result<Vec<ClientName>, Box<dyn Error>> {
let mut raw_data: HashMap<String, Vec<ClientName>> = self
.authenticated_json_request("/admin/api.php?getClientNames".to_string())
.await?;
Ok(raw_data
.remove("clients")
.expect("Missing clients attribute"))
}
pub async fn get_over_time_data_clients(
&self,
) -> Result<HashMap<u64, Vec<u64>>, Box<dyn Error>> {
let mut raw_data: HashMap<String, HashMap<u64, Vec<u64>>> = self
.authenticated_json_request("/admin/api.php?overTimeDataClients".to_string())
.await?;
Ok(raw_data
.remove("over_time")
.expect("Missing over_time attribute"))
}
pub async fn get_network(&self) -> Result<Network, Box<dyn Error>> {
self.authenticated_json_request("/admin/api_db.php?network".to_string())
.await
}
pub async fn get_queries_count(&self) -> Result<u64, Box<dyn Error>> {
let mut raw_data: HashMap<String, u64> = self
.authenticated_json_request("/admin/api_db.php?getQueriesCount".to_string())
.await?;
Ok(raw_data.remove("count").expect("Missing count attribute"))
}
pub async fn add(&self, domains: Vec<String>, list: String) -> Result<(), Box<dyn Error>> {
let url = format!(
"{}/admin/api.php?add={}&list={}&auth={}",
self.host,
domains.join(" "),
list,
self.api_key.as_ref().unwrap_or(&"".to_string())
);
let body = reqwest::get(&url).await?.text().await?;
match body.contains("Success") {
true => Ok(()),
false => simple_error::bail!("Pi-Hole API error: ".to_string() + &body),
}
}
}