use reqwest::{
header::{HeaderMap, HeaderValue},
Client, Error, Method, Response,
};
/// Accumulates the parts of one HTTP request (method, URL, query pairs,
/// headers and body) and finally converts them into a `reqwest` request.
///
/// The methods of this type take the builder by value and return it again,
/// so calls can be chained.
#[derive(Clone, Debug)]
pub struct Builder {
/// HTTP verb of the request; initialised to `GET` by `new`.
method: Method,
/// Target URL, stored with trailing `/` characters removed.
url: String,
/// Optional schema name; when present it is sent as a profile header by `build`.
schema: Option<String>,
/// Query-string pairs in insertion order; the same key may appear more than once.
pub(crate) queries: Vec<(String, String)>,
/// Headers sent with the request.
headers: HeaderMap,
/// Request body; `build` sends an empty body when this is `None`.
body: Option<String>,
/// Set to `true` by `rpc`; none of the non-test code visible here reads it.
is_rpc: bool,
/// `reqwest` client the final request is created from.
client: Client,
}
impl Builder {
/// Creates a builder for a `GET` request against `url`.
///
/// Trailing `/` characters are stripped from `url`. The `Accept` header is
/// always set to `application/json`, replacing any `Accept` value already
/// present in `headers`.
pub fn new<T>(url: T, schema: Option<String>, headers: HeaderMap, client: Client) -> Self
where
    T: Into<String>,
{
    let raw_url: String = url.into();

    // Prepare the header map up front so the struct can be built in one go.
    let mut headers = headers;
    headers.insert("Accept", HeaderValue::from_static("application/json"));

    Builder {
        method: Method::GET,
        url: raw_url.trim_end_matches('/').to_string(),
        schema,
        queries: Vec::new(),
        headers,
        body: None,
        is_rpc: false,
        client,
    }
}
/// Sets the `Authorization` header to `Bearer <token>`, replacing any
/// previous value.
///
/// # Panics
///
/// Panics if the formatted value is rejected by `HeaderValue::from_str`.
pub fn auth<T>(mut self, token: T) -> Self
where
    T: AsRef<str>,
{
    let bearer = format!("Bearer {}", token.as_ref());
    let value = HeaderValue::from_str(&bearer).unwrap();
    self.headers.insert("Authorization", value);
    self
}
/// Appends a `select=<columns>` pair to the query string.
pub fn select<T>(mut self, columns: T) -> Self
where
    T: Into<String>,
{
    let pair = (String::from("select"), columns.into());
    self.queries.push(pair);
    self
}
/// Appends an `order=<columns>` pair to the query string; `columns` is
/// passed through unchanged.
pub fn order<T>(mut self, columns: T) -> Self
where
    T: Into<String>,
{
    let pair = (String::from("order"), columns.into());
    self.queries.push(pair);
    self
}
/// Adds an ordering clause of the form
/// `<columns>.<asc|desc>.<nullsfirst|nullslast>`.
///
/// The clause is stored under the `order` query key, or under
/// `<foreign_table>.order` when a non-empty `foreign_table` is given.
///
/// If a pair with that key already exists, the clause is appended to the
/// first such pair, separated by a comma, so the key is emitted only once
/// with the full ordering (e.g. `order=a.asc.nullslast,b.desc.nullsfirst`).
///
/// * `ascending` - `true` yields `asc`, `false` yields `desc`.
/// * `nulls_first` - `true` yields `nullsfirst`, `false` yields `nullslast`.
pub fn order_with_options<T, U>(
    mut self,
    columns: T,
    foreign_table: Option<U>,
    ascending: bool,
    nulls_first: bool,
) -> Self
where
    T: Into<String>,
    U: Into<String>,
{
    // An absent or empty foreign table means the ordering applies to the
    // top-level resource.
    let foreign_table: Option<String> = foreign_table.map(Into::into);
    let key = match foreign_table {
        Some(table) if !table.is_empty() => format!("{}.order", table),
        _ => "order".to_string(),
    };

    let direction = if ascending { "asc" } else { "desc" };
    let nulls = if nulls_first { "nullsfirst" } else { "nullslast" };
    let clause = format!("{}.{}.{}", columns.into(), direction, nulls);

    // Extend the existing entry in place; pushing a second pair with the
    // same key would send the key twice, once with a stale value.
    let existing = self.queries.iter().position(|(k, _)| *k == key);
    match existing {
        Some(index) => {
            let value = &mut self.queries[index].1;
            value.push(',');
            value.push_str(&clause);
        }
        None => self.queries.push((key, clause)),
    }
    self
}
/// Limits the result to the first `count` items.
///
/// For `count > 0` this sets `Range-Unit: items` and `Range: 0-<count - 1>`,
/// replacing any earlier values of those headers.
///
/// For `count == 0` an inclusive `0-N` range cannot describe "no items" and
/// `count - 1` would underflow. In that case both range headers are removed
/// and a `limit=0` query pair is added instead (the top-level counterpart of
/// the `<table>.limit` pair written by `foreign_table_limit`).
pub fn limit(mut self, count: usize) -> Self {
    match count.checked_sub(1) {
        Some(last) => {
            self.headers
                .insert("Range-Unit", HeaderValue::from_static("items"));
            let range = HeaderValue::from_str(&format!("0-{}", last))
                .expect("digits and '-' are always a valid header value");
            self.headers.insert("Range", range);
        }
        None => {
            // Drop any earlier range so it cannot contradict the zero limit.
            self.headers.remove("Range-Unit");
            self.headers.remove("Range");
            self.queries.push(("limit".to_string(), "0".to_string()));
        }
    }
    self
}
/// Appends a `<foreign_table>.limit=<count>` pair to the query string.
pub fn foreign_table_limit<T>(mut self, count: usize, foreign_table: T) -> Self
where
    T: Into<String>,
{
    let key = format!("{}.limit", foreign_table.into());
    let value = count.to_string();
    self.queries.push((key, value));
    self
}
/// Sets `Range-Unit: items` and `Range: <low>-<high>`, replacing earlier
/// values of both headers. The bounds are written as given; no check is
/// made that `low <= high`.
pub fn range(mut self, low: usize, high: usize) -> Self {
    let unit = HeaderValue::from_static("items");
    self.headers.insert("Range-Unit", unit);

    let span = format!("{}-{}", low, high);
    self.headers
        .insert("Range", HeaderValue::from_str(&span).unwrap());
    self
}
/// Shared implementation of the `*_count` methods.
///
/// Sets `Range-Unit: items`, `Range: 0-0` and `Prefer: count=<method>`,
/// replacing earlier values of those headers.
///
/// # Panics
///
/// Panics if `count=<method>` is rejected by `HeaderValue::from_str`.
fn count(mut self, method: &str) -> Self {
    let headers = &mut self.headers;
    headers.insert("Range-Unit", HeaderValue::from_static("items"));
    headers.insert("Range", HeaderValue::from_static("0-0"));

    let preference = format!("count={}", method);
    headers.insert("Prefer", HeaderValue::from_str(&preference).unwrap());
    self
}
pub fn exact_count(self) -> Self {
self.count("exact")
}
pub fn planned_count(self) -> Self {
self.count("planned")
}
pub fn estimated_count(self) -> Self {
self.count("estimated")
}
/// Replaces the `Accept` header with `application/vnd.pgrst.object+json`.
pub fn single(mut self) -> Self {
    let object_json = HeaderValue::from_static("application/vnd.pgrst.object+json");
    self.headers.insert("Accept", object_json);
    self
}
/// Serializes `body` to JSON and configures the builder as an insert
/// (see `insert_impl`).
///
/// # Errors
///
/// Returns the `serde_json` error if `body` cannot be serialized.
pub fn insert<T>(self, body: &T) -> serde_json::Result<Self>
where
    T: serde::Serialize,
{
    serde_json::to_string(body).map(|json| self.insert_impl(json))
}
/// Stores `body`, switches the method to `POST` and sets
/// `Prefer: return=representation`.
fn insert_impl(mut self, body: String) -> Self {
    let prefer = HeaderValue::from_static("return=representation");
    self.headers.insert("Prefer", prefer);
    self.body = Some(body);
    self.method = Method::POST;
    self
}
/// Serializes `body` to JSON and configures the builder as an upsert
/// (see `upsert_impl`).
///
/// # Errors
///
/// Returns the `serde_json` error if `body` cannot be serialized.
pub fn upsert<T>(self, body: &T) -> serde_json::Result<Self>
where
    T: serde::Serialize,
{
    serde_json::to_string(body).map(|json| self.upsert_impl(json))
}
/// Stores `body`, switches the method to `POST` and sets
/// `Prefer: return=representation,resolution=merge-duplicates`.
fn upsert_impl(mut self, body: String) -> Self {
    let prefer = HeaderValue::from_static("return=representation,resolution=merge-duplicates");
    self.headers.insert("Prefer", prefer);
    self.body = Some(body);
    self.method = Method::POST;
    self
}
/// Appends an `on_conflict=<columns>` pair to the query string.
pub fn on_conflict<T>(mut self, columns: T) -> Self
where
    T: Into<String>,
{
    let pair = (String::from("on_conflict"), columns.into());
    self.queries.push(pair);
    self
}
/// Serializes `body` to JSON and configures the builder as an update
/// (see `update_impl`).
///
/// # Errors
///
/// Returns the `serde_json` error if `body` cannot be serialized.
pub fn update<T>(self, body: &T) -> serde_json::Result<Self>
where
    T: serde::Serialize,
{
    serde_json::to_string(body).map(|json| self.update_impl(json))
}
/// Stores `body`, switches the method to `PATCH` and sets
/// `Prefer: return=representation`.
fn update_impl(mut self, body: String) -> Self {
    let prefer = HeaderValue::from_static("return=representation");
    self.headers.insert("Prefer", prefer);
    self.body = Some(body);
    self.method = Method::PATCH;
    self
}
/// Switches the method to `DELETE` and sets `Prefer: return=representation`.
/// The body and query pairs are left untouched.
pub fn delete(mut self) -> Self {
    let prefer = HeaderValue::from_static("return=representation");
    self.headers.insert("Prefer", prefer);
    self.method = Method::DELETE;
    self
}
/// Configures the builder as a `POST` whose body is `params`, and marks it
/// as an RPC call via the `is_rpc` flag. `params` is stored verbatim; it is
/// not validated or serialized here.
pub fn rpc<T>(mut self, params: T) -> Self
where
    T: Into<String>,
{
    let payload: String = params.into();
    self.is_rpc = true;
    self.body = Some(payload);
    self.method = Method::POST;
    self
}
pub fn build(mut self) -> reqwest::RequestBuilder {
if let Some(schema) = self.schema {
let key = match self.method {
Method::GET | Method::HEAD => "Accept-Profile",
_ => "Content-Profile",
};
self.headers
.insert(key, HeaderValue::from_str(&schema).unwrap());
}
match self.method {
Method::GET | Method::HEAD => {}
_ => {
self.headers
.insert("Content-Type", HeaderValue::from_static("application/json"));
}
};
self.client
.request(self.method, self.url)
.headers(self.headers)
.query(&self.queries)
.body(self.body.unwrap_or_default())
}
/// Builds the request (see `build`) and sends it, returning the response
/// or the `reqwest` error.
pub async fn execute(self) -> Result<Response, Error> {
    let request = self.build();
    request.send().await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const TABLE_URL: &str = "http://localhost:3000/table";
    const RPC_URL: &str = "http://localhost:3000/rpc";

    /// Builder for `url` with no schema, no preset headers and a fresh client.
    fn builder_for(url: &str) -> Builder {
        Builder::new(url, None, HeaderMap::new(), Client::new())
    }

    /// Returns the named header, panicking if it is absent.
    fn header<'a>(builder: &'a Builder, name: &str) -> &'a HeaderValue {
        builder.headers.get(name).unwrap()
    }

    /// True when `builder` holds the exact query pair `(key, value)`.
    fn has_query(builder: &Builder, key: &str, value: &str) -> bool {
        builder
            .queries
            .contains(&(key.to_string(), value.to_string()))
    }

    #[test]
    fn only_accept_json() {
        let builder = builder_for(TABLE_URL);
        assert_eq!(
            header(&builder, "Accept"),
            HeaderValue::from_static("application/json")
        );
    }

    #[test]
    fn auth_with_token() {
        let builder = builder_for(TABLE_URL).auth("$Up3rS3crET");
        assert_eq!(
            header(&builder, "Authorization"),
            HeaderValue::from_static("Bearer $Up3rS3crET")
        );
    }

    #[test]
    fn select_assert_query() {
        let builder = builder_for(TABLE_URL).select("some_table");
        assert_eq!(builder.method, Method::GET);
        assert!(has_query(&builder, "select", "some_table"));
    }

    #[test]
    fn order_assert_query() {
        let builder = builder_for(TABLE_URL).order("id");
        assert!(has_query(&builder, "order", "id"));
    }

    #[test]
    fn order_with_options_assert_query() {
        let builder =
            builder_for(TABLE_URL).order_with_options("name", Some("cities"), true, false);
        assert!(has_query(&builder, "cities.order", "name.asc.nullslast"));
    }

    #[test]
    fn limit_assert_range_header() {
        let builder = builder_for(TABLE_URL).limit(20);
        assert_eq!(header(&builder, "Range"), HeaderValue::from_static("0-19"));
    }

    #[test]
    fn foreign_table_limit_assert_query() {
        let builder = builder_for(TABLE_URL).foreign_table_limit(20, "some_table");
        assert!(has_query(&builder, "some_table.limit", "20"));
    }

    #[test]
    fn range_assert_range_header() {
        let builder = builder_for(TABLE_URL).range(10, 20);
        assert_eq!(header(&builder, "Range"), HeaderValue::from_static("10-20"));
    }

    #[test]
    fn single_assert_accept_header() {
        let builder = builder_for(TABLE_URL).single();
        assert_eq!(
            header(&builder, "Accept"),
            HeaderValue::from_static("application/vnd.pgrst.object+json")
        );
    }

    #[test]
    fn upsert_assert_prefer_header_serde() {
        let builder = builder_for(TABLE_URL).upsert(&()).unwrap();
        assert_eq!(
            header(&builder, "Prefer"),
            HeaderValue::from_static("return=representation,resolution=merge-duplicates")
        );
    }

    #[test]
    fn not_rpc_should_not_have_flag() {
        let builder = builder_for(TABLE_URL).select("ignored");
        assert!(!builder.is_rpc);
    }

    #[test]
    fn rpc_should_have_body_and_flag() {
        let params = "{\"a\": 1, \"b\": 2}";
        let builder = builder_for(RPC_URL).rpc(params);
        assert_eq!(builder.body.as_deref(), Some(params));
        assert!(builder.is_rpc);
    }

    #[test]
    fn chain_filters() -> Result<(), Box<dyn std::error::Error>> {
        let builder = builder_for(TABLE_URL)
            .eq("username", "supabot")
            .neq("message", "hello world")
            .gte("channel_id", "1")
            .select("*");

        // Three filters plus the select clause.
        assert_eq!(builder.queries.len(), 4);
        assert!(has_query(&builder, "username", "eq.supabot"));
        assert!(has_query(&builder, "message", "neq.hello world"));
        assert!(has_query(&builder, "channel_id", "gte.1"));
        Ok(())
    }
}