#![allow(clippy::unwrap_used)]
use std::convert::TryInto;
use std::env;
use std::fmt::Debug;
use std::sync::Arc;
use cfg_if::cfg_if;
use surf::{Client, Config, StatusCode, Url};
use tide::{http, Server};
use crate::builtins::monitor::setup_monitor;
use crate::logging::{log_format_json, log_format_pretty};
use crate::middleware::json_error::JsonError;
use crate::middleware::{JsonErrorMiddleware, LogMiddleware, RequestIdMiddleware};
use crate::VariadicRoutes;
#[cfg(feature = "honeycomb")]
use tracing_subscriber::Registry;
cfg_if! {
if #[cfg(feature = "postgres")] {
use async_std::sync::RwLock;
use sqlx::postgres::{PgConnectOptions, PgPoolOptions, Postgres};
use sqlx::ConnectOptions;
use tide::{Middleware, Next, Request};
use crate::middleware::postgres::{ConnectionWrap, ConnectionWrapInner};
}
}
/// Result type returned by the client/server construction helpers in this module; an alias
/// for [`surf::Result`], so `?` converts most error types into a `surf::Error`.
pub type TestResult<T> = surf::Result<T>;
/// Builds a surf [`Client`] whose HTTP backend is the in-process test server.
///
/// The server comes from `create_server(state, setup_routes_fns)`. It is installed as the
/// client's HTTP client, and the base URL is set to `http://localhost:8080`, so relative
/// request paths such as `/api/v1/...` resolve against it.
///
/// # Errors
///
/// Returns any error from `create_server`, from parsing the base URL, or from converting the
/// client `Config` into a `Client`.
pub async fn create_client<State>(
    state: State,
    setup_routes_fns: impl Into<VariadicRoutes<State>>,
) -> TestResult<Client>
where
    State: Send + Sync + 'static,
{
    let app = create_server(state, setup_routes_fns)?;
    let base_url = Url::parse("http://localhost:8080")?;

    Ok(Config::new()
        .set_http_client(app)
        .set_base_url(base_url)
        .try_into()?)
}
/// Builds a test [`Client`] together with a shared Postgres transaction handle.
///
/// Connection settings are read from the environment:
/// - `TEST_DATABASE_HOST` (default `localhost`)
/// - `TEST_DATABASE_PORT` (default `5432`; must parse as a port number)
/// - `TEST_DATABASE_NAME`, falling back to `"<CARGO_PKG_NAME>-test"`, then to `database_test`
///
/// A pool of at most 5 connections is opened, with statement logging at `Debug` level. A
/// transaction is begun on that pool and wrapped in `ConnectionWrapInner::Transacting`.
///
/// The same `Arc<RwLock<..>>` is attached to every request by `PostgresTestMiddleware` and is
/// also returned to the caller. The test body and the request handlers therefore share one
/// transaction.
///
/// # Errors
///
/// Returns any error from:
/// - `create_server`;
/// - parsing `TEST_DATABASE_PORT`;
/// - connecting the pool;
/// - beginning the transaction;
/// - building the client.
#[cfg(feature = "postgres")]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgres")))]
pub async fn create_client_and_postgres<State>(
    state: State,
    setup_routes_fns: impl Into<VariadicRoutes<State>>,
) -> TestResult<(Client, Arc<RwLock<ConnectionWrapInner<Postgres>>>)>
where
    State: Send + Sync + 'static,
{
    // Must run before the environment reads below: `create_server` loads `.env` via dotenv,
    // so TEST_DATABASE_* values defined there become visible.
    let mut app = create_server(state, setup_routes_fns)?;

    let db_host = env::var("TEST_DATABASE_HOST").unwrap_or_else(|_| String::from("localhost"));
    let db_port: u16 = match env::var("TEST_DATABASE_PORT") {
        Ok(raw_port) => raw_port.parse()?,
        Err(_) => 5432,
    };
    let db_name = env::var("TEST_DATABASE_NAME")
        .or_else(|_| env::var("CARGO_PKG_NAME").map(|pkg| format!("{}-test", pkg)))
        .unwrap_or_else(|_| String::from("database_test"));

    let mut pg_opts = PgConnectOptions::new()
        .host(&db_host)
        .port(db_port)
        .database(&db_name);
    pg_opts.log_statements(log::LevelFilter::Debug);

    let pool = PgPoolOptions::new()
        .max_connections(5)
        .connect_with(pg_opts)
        .await?;
    let transaction = pool.begin().await?;
    let shared_conn = Arc::new(RwLock::new(ConnectionWrapInner::Transacting(transaction)));

    // Hand the middleware a second handle to the same lock; the caller keeps the first.
    app.with(PostgresTestMiddleware(Arc::clone(&shared_conn)));

    let client: Client = Config::new()
        .set_http_client(app)
        .set_base_url(Url::parse("http://localhost:8080")?)
        .try_into()?;
    Ok((client, shared_conn))
}
/// Builds the in-process tide server used by the test clients.
///
/// In order, this:
/// 1. Loads a `.env` file if one is present; failures are ignored.
/// 2. Initializes `env_logger` at the level named by `LOGLEVEL` (default: off).
///    - It uses the JSON formatter with colors disabled when `ENVIRONMENT` starts with `prod`,
///      and the pretty formatter otherwise (`ENVIRONMENT` defaults to `development`).
///    - Errors from `try_init` are ignored, for example when a logger is already installed.
/// 3. With the `honeycomb` feature, tries to install a default tracing `Registry` as the
///    global subscriber; failure is ignored.
/// 4. Wraps `state` in an `Arc`, installs the request-id, logging and JSON-error middleware in
///    that order, and registers the monitor routes.
/// 5. Mounts the routes from the n-th setup function (1-based) under `/api/v{n}`.
///
/// # Panics
///
/// Panics if `LOGLEVEL` is set but does not parse as a `log::LevelFilter`.
// Never returns `Err`; the `Result` return type lets callers chain it with `?`.
#[allow(clippy::unnecessary_wraps)]
pub(crate) fn create_server<State>(
    state: State,
    setup_routes_fns: impl Into<VariadicRoutes<State>>,
) -> TestResult<Server<Arc<State>>>
where
    State: Send + Sync + 'static,
{
    dotenv::dotenv().ok();

    let log_level = match env::var("LOGLEVEL") {
        Ok(raw_level) => raw_level
            .parse::<log::LevelFilter>()
            .expect("LOGLEVEL must be a valid log level."),
        Err(_) => log::LevelFilter::Off,
    };
    let environment = env::var("ENVIRONMENT").unwrap_or_else(|_| "development".to_string());

    let mut logger = env_logger::builder();
    if environment.starts_with("prod") {
        logger
            .format(log_format_json)
            .write_style(env_logger::WriteStyle::Never);
    } else {
        logger.format(log_format_pretty);
    }
    logger.filter_level(log_level).try_init().ok();

    #[cfg(feature = "honeycomb")]
    {
        tracing::subscriber::set_global_default(Registry::default()).ok();
    }

    let mut server = tide::with_state(Arc::new(state));
    server.with(RequestIdMiddleware::new());
    server.with(LogMiddleware::new());
    server.with(JsonErrorMiddleware::new());
    setup_monitor("preroll_test_utils", &mut server);

    for (index, register_routes) in setup_routes_fns.into().routes.into_iter().enumerate() {
        let prefix = format!("/api/v{}", index + 1);
        register_routes(server.at(&prefix));
    }

    Ok(server)
}
/// Test-only middleware that attaches a clone of its wrapped `ConnectionWrap<Postgres>` to
/// every request as a request extension before passing the request down the chain.
#[cfg(feature = "postgres")]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgres")))]
#[derive(Debug, Clone)]
struct PostgresTestMiddleware(ConnectionWrap<Postgres>);

#[cfg(feature = "postgres")]
#[tide::utils::async_trait]
impl<State: Clone + Send + Sync + 'static> Middleware<State> for PostgresTestMiddleware {
    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> tide::Result {
        // NOTE(review): `ConnectionWrap` appears to be the same `Arc<RwLock<..>>` handle that
        // `create_client_and_postgres` returns, so this clone shares the connection rather
        // than copying it. Confirm against its definition.
        let PostgresTestMiddleware(shared_conn) = self;
        req.set_ext(shared_conn.clone());
        let response = next.run(req).await;
        Ok(response)
    }
}
/// Builds a [`Client`] backed by an in-process mock tide server.
///
/// `setup_mocks_fn` is called once with a fresh stateless `Server<()>`, on which it should
/// register the mock endpoints. That server is installed as the client's HTTP client, and
/// relative request paths resolve against `base_url`.
///
/// # Panics
///
/// Panics if `base_url` is not a valid absolute URL (the message names the URL and the parse
/// error), or if converting the client `Config` into a `Client` fails.
pub fn mock_client<MocksFn>(base_url: impl AsRef<str>, setup_mocks_fn: MocksFn) -> Client
where
    // `FnOnce` suffices because the setup function is invoked exactly once. Every `Fn` closure
    // is also `FnOnce`, so existing callers keep compiling.
    MocksFn: FnOnce(&mut Server<()>),
{
    let mut mocks_server = tide::new();
    setup_mocks_fn(&mut mocks_server);

    let base_url = base_url.as_ref();
    let parsed_url = Url::parse(base_url)
        .unwrap_or_else(|err| panic!("mock_client: invalid base_url {:?}: {}", base_url, err));

    Config::new()
        .set_http_client(mocks_server)
        .set_base_url(parsed_url)
        .try_into()
        .expect("async-h1 client from config is infallible")
}
#[allow(dead_code)] #[track_caller]
pub async fn assert_json_error<Status>(
mut res: impl AsMut<http::Response>,
status: Status,
err_msg: &str,
) where
Status: TryInto<StatusCode>,
Status::Error: Debug,
{
let res = res.as_mut();
let status: StatusCode = status
.try_into()
.expect("test must specify valid status code");
let str_response = res.body_string().await.unwrap();
let error: JsonError = serde_json::from_str(&str_response).map_err(|e| {
surf::Error::from_str(
res.status(),
format!("Error, could not parse Response into JsonError! json err: \"{}\", response body: \"{}\"", e, str_response)
)
}).unwrap();
assert_eq!(res.status(), status);
assert_eq!(&error.title, status.canonical_reason());
assert_eq!(error.message, err_msg);
assert_eq!(error.status, status as u16);
assert_eq!(
error.request_id.as_str(),
res["X-Request-Id"].last().as_str()
);
if res.status().is_server_error() {
assert_eq!(
error
.correlation_id
.expect("Internal server errors must have correlation ids.")
.as_str(),
res["X-Correlation-Id"].last().as_str()
);
} else {
assert_eq!(error.correlation_id, None);
assert!(res.header("X-Correlation-Id").is_none());
}
}
/// Asserts that `res` has the given status, then deserializes its JSON body into `StructType`.
///
/// If the status does not match, the response body is included in the failure message.
///
/// # Panics
///
/// Panics, failing the test, in any of these cases:
/// - `status` cannot be converted to a [`StatusCode`];
/// - the body cannot be read;
/// - the status differs from `status`;
/// - the body does not deserialize into `StructType`.
// `#[track_caller]` is intentionally absent. On an `async fn` it is a no-op on stable Rust
// (rustc warns with `ungated_async_fn_track_caller`).
pub async fn assert_status_json<StructType, Status>(
    mut res: impl AsMut<http::Response>,
    status: Status,
) -> StructType
where
    StructType: serde::de::DeserializeOwned,
    Status: TryInto<StatusCode>,
    Status::Error: Debug,
{
    let res = res.as_mut();
    let status: StatusCode = status
        .try_into()
        .expect("test must specify valid status code");
    let body = res
        .body_string()
        .await
        .expect("could not read the response body as a string");
    assert_eq!(res.status(), status, "Response body: {}", body);
    serde_json::from_str(&body).unwrap_or_else(|err| {
        panic!(
            "Error: \"{}\" Body was not parseable into a {}, body was: \"{}\"",
            err,
            std::any::type_name::<StructType>(),
            body
        )
    })
}
/// Asserts that `res` has the given status and returns the response body as a `String`.
///
/// If the status does not match, the response body is included in the failure message.
///
/// # Panics
///
/// Panics, failing the test, if `status` cannot be converted to a [`StatusCode`], if the body
/// cannot be read, or if the status differs from `status`.
// `#[track_caller]` is intentionally absent. On an `async fn` it is a no-op on stable Rust
// (rustc warns with `ungated_async_fn_track_caller`).
pub async fn assert_status<Status>(mut res: impl AsMut<http::Response>, status: Status) -> String
where
    Status: TryInto<StatusCode>,
    Status::Error: Debug,
{
    let res = res.as_mut();
    let status: StatusCode = status
        .try_into()
        .expect("test must specify valid status code");
    let body = res
        .body_string()
        .await
        .expect("could not read the response body as a string");
    assert_eq!(res.status(), status, "Response body: {}", body);
    body
}