#![allow(clippy::needless_doctest_main)]
pub use exonum_api::ApiAccess;
use actix_web::{
test::{self, TestServer},
web, App,
};
use exonum::{
blockchain::ApiSender,
messages::{AnyTx, Verified},
};
use exonum_api::{self as api, ApiAggregator};
use exonum_proto::ProtobufConvert;
use reqwest::{
redirect::Policy as RedirectPolicy, Client, ClientBuilder, RequestBuilder as ReqwestBuilder,
Response, StatusCode,
};
use serde::{de::DeserializeOwned, Serialize};
use std::{
collections::HashMap,
fmt::{self, Display},
};
use crate::TestKit;
/// Kind of API exposed by the testkit test server.
///
/// Values of this type are rendered with `Display` into the URL prefix of requests
/// created by `TestKitApi::public` / `TestKitApi::private` (any `Display` value is
/// accepted there); the `Display` impl for this enum defines the exact paths.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum ApiKind {
/// Renders as `api/system`.
System,
/// Renders as `api/explorer`.
Explorer,
/// Renders as `api/runtimes/rust`.
RustRuntime,
/// Renders as `api/services/{name}` for the service with the given name.
Service(&'static str),
}
impl fmt::Display for ApiKind {
    /// Writes the URL path prefix of this API kind (without leading or trailing slashes).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let path = match *self {
            Self::System => "api/system",
            Self::Explorer => "api/explorer",
            Self::RustRuntime => "api/runtimes/rust",
            // Service paths embed the service name, so they cannot be a static string.
            Self::Service(service_name) => {
                return write!(formatter, "api/services/{}", service_name);
            }
        };
        formatter.write_str(path)
    }
}
/// Client for the HTTP API of a testkit, served by the actix-web test server that
/// `create_test_server` starts.
pub struct TestKitApi {
// Never read; held so the test server lives as long as this struct
// (dropping it presumably shuts the server down — TODO confirm).
_test_server_handle: TestServer,
// Client bound to the URL of the test server above.
test_client: TestKitApiClient,
// Used by `send` to broadcast transactions.
api_sender: ApiSender,
}
impl fmt::Debug for TestKitApi {
    /// Prints only the type name; the server handle, client and sender are not shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        formatter.write_str("TestKitApi")
    }
}
impl TestKitApi {
    /// Creates an API client for `testkit`.
    ///
    /// A new test server is started with the aggregator returned by
    /// `TestKit::update_aggregator`; transactions are sent through the testkit's API sender.
    pub fn new(testkit: &mut TestKit) -> Self {
        Self::from_raw_parts(testkit.update_aggregator(), testkit.api_sender.clone())
    }

    /// Starts a test server serving `aggregator` and creates a client connected to it.
    ///
    /// # Panics
    ///
    /// Panics if the underlying HTTP client cannot be built.
    pub(crate) fn from_raw_parts(aggregator: ApiAggregator, api_sender: ApiSender) -> Self {
        // Redirects are not followed, so tests see redirect responses exactly as served.
        let inner = ClientBuilder::new()
            .redirect(RedirectPolicy::none())
            .build()
            .expect("Cannot build HTTP client for the testkit API");
        let test_server = create_test_server(aggregator);
        let test_client = TestKitApiClient {
            test_server_url: test_server.url(""),
            inner,
        };
        Self {
            _test_server_handle: test_server,
            test_client,
            api_sender,
        }
    }

    /// Broadcasts `transaction` via the testkit's API sender.
    ///
    /// # Panics
    ///
    /// Panics if the transaction cannot be broadcast.
    pub async fn send<T>(&self, transaction: T)
    where
        T: Into<Verified<AnyTx>>,
    {
        self.api_sender
            .broadcast_transaction(transaction.into())
            .await
            .expect("Cannot broadcast transaction");
    }

    /// Returns the full URL of the path `url` relative to the server's `public/` root.
    pub fn public_url(&self, url: &str) -> String {
        self.test_client.public_url(url)
    }

    /// Returns the full URL of the path `url` relative to the server's `private/` root.
    pub fn private_url(&self, url: &str) -> String {
        self.test_client.private_url(url)
    }

    /// Starts building a request to the public API; `kind` (e.g. an `ApiKind`) is
    /// rendered with `Display` into the URL prefix.
    pub fn public(&self, kind: impl Display) -> RequestBuilder<'_, '_> {
        self.test_client.public(kind)
    }

    /// Starts building a request to the private API; `kind` (e.g. an `ApiKind`) is
    /// rendered with `Display` into the URL prefix.
    pub fn private(&self, kind: impl Display) -> RequestBuilder<'_, '_> {
        self.test_client.private(kind)
    }

    /// Returns the underlying client bound to the test server.
    pub fn client(&self) -> &TestKitApiClient {
        &self.test_client
    }
}
/// Cloneable HTTP client bound to the base URL of a testkit test server.
#[derive(Debug, Clone)]
pub struct TestKitApiClient {
/// Base URL of the test server; in `TestKitApi::from_raw_parts` this is `TestServer::url("")`.
test_server_url: String,
/// Underlying `reqwest` client; `TestKitApi::from_raw_parts` builds it with redirects disabled.
inner: Client,
}
impl TestKitApiClient {
    /// Returns the full URL of the path `url` relative to the server's `public/` root.
    pub fn public_url(&self, url: &str) -> String {
        format!("{}public/{}", self.test_server_url, url)
    }

    /// Returns the full URL of the path `url` relative to the server's `private/` root.
    pub fn private_url(&self, url: &str) -> String {
        format!("{}private/{}", self.test_server_url, url)
    }

    /// Starts building a request to the public API.
    ///
    /// `kind` is rendered with `Display` into the URL prefix, e.g. `ApiKind::Explorer`
    /// becomes `api/explorer`.
    pub fn public(&self, kind: impl Display) -> RequestBuilder<'_, '_> {
        let prefix = kind.to_string();
        RequestBuilder::new(
            self.test_server_url.as_str(),
            &self.inner,
            ApiAccess::Public,
            prefix,
        )
    }

    /// Starts building a request to the private API.
    ///
    /// `kind` is rendered with `Display` into the URL prefix, e.g. `ApiKind::System`
    /// becomes `api/system`.
    pub fn private(&self, kind: impl Display) -> RequestBuilder<'_, '_> {
        let prefix = kind.to_string();
        RequestBuilder::new(
            self.test_server_url.as_str(),
            &self.inner,
            ApiAccess::Private,
            prefix,
        )
    }

    /// Returns the underlying `reqwest` client for requests not covered by the builder.
    pub fn inner(&self) -> &Client {
        &self.inner
    }
}
/// Boxed one-shot closure applied to the `reqwest` request builder right before the request is sent.
type ReqwestModifier<'b> = Box<dyn FnOnce(ReqwestBuilder) -> ReqwestBuilder + Send + 'b>;
/// Builder of a single request to the testkit API.
///
/// `Q` is the type of the query (defaults to `()` when no query is set). The request is
/// sent by one of the consuming methods `get`, `post` or `post_pb`.
pub struct RequestBuilder<'a, 'b, Q = ()> {
/// Base URL of the test server.
test_server_url: &'a str,
/// HTTP client used to send the request.
test_client: &'a Client,
/// Whether the public or the private API is targeted.
access: ApiAccess,
/// Path segment between the access level and the endpoint, e.g. `api/explorer`.
prefix: String,
/// URL query (for `get`) or request body (for `post` / `post_pb`), if any.
query: Option<&'b Q>,
/// Optional closure customizing the `reqwest` builder before sending.
modifier: Option<ReqwestModifier<'b>>,
/// Header name -> value pairs the response must contain; checked after the response arrives.
expected_headers: HashMap<String, String>,
}
impl<'a, 'b, Q> fmt::Debug for RequestBuilder<'a, 'b, Q>
where
    Q: 'b + fmt::Debug,
{
    /// Shows the access level, prefix and query; the server URL, client, modifier and
    /// expected headers are omitted.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let mut debug = formatter.debug_struct("RequestBuilder");
        debug.field("access", &self.access);
        debug.field("prefix", &self.prefix);
        debug.field("query", &self.query);
        debug.finish()
    }
}
impl<'a, 'b, Q> RequestBuilder<'a, 'b, Q>
where
Q: 'b,
{
fn new(
test_server_url: &'a str,
test_client: &'a Client,
access: ApiAccess,
prefix: String,
) -> Self {
RequestBuilder {
test_server_url,
test_client,
access,
prefix,
query: None,
modifier: None,
expected_headers: HashMap::new(),
}
}
pub fn query<T>(self, query: &'b T) -> RequestBuilder<'a, 'b, T> {
RequestBuilder {
test_server_url: self.test_server_url,
test_client: self.test_client,
access: self.access,
prefix: self.prefix,
query: Some(query),
modifier: self.modifier,
expected_headers: self.expected_headers,
}
}
pub fn with<F>(self, f: F) -> Self
where
F: FnOnce(ReqwestBuilder) -> ReqwestBuilder + Send + 'b,
{
Self {
modifier: Some(Box::new(f)),
..self
}
}
pub fn expect_header(self, header: &str, value: &str) -> Self {
let mut expected_headers = self.expected_headers;
expected_headers.insert(header.into(), value.into());
Self {
expected_headers,
..self
}
}
fn verify_headers(expected_headers: &HashMap<String, String>, response: &Response) {
let headers = response.headers();
for (header, expected_value) in expected_headers {
let header_value = headers.get(header).unwrap_or_else(|| {
panic!(
"Response {:?} was expected to have header {}, but it isn't present",
response, header
);
});
assert_eq!(
header_value, expected_value,
"Unexpected value of response header {}",
header
);
}
}
async fn response_to_api_result<R>(response: Response) -> api::Result<R>
where
R: DeserializeOwned + 'static,
{
let code = response.status();
let body = response.text().await.expect("Unable to get response text");
log::trace!("Body: {}", body);
if code == StatusCode::OK {
let value = serde_json::from_str(&body).expect("Unable to deserialize body");
Ok(value)
} else {
let error = api::Error::parse(code, &body).expect("Unable to deserialize API error");
Err(error)
}
}
}
impl<Q> RequestBuilder<'_, '_, Q>
where
    Q: Serialize,
{
    /// Sends a GET request to `endpoint` and decodes the JSON response.
    ///
    /// If a query is set, it is URL-encoded and appended as the query string.
    ///
    /// # Panics
    ///
    /// Panics if the query cannot be encoded, the request cannot be sent, an expected
    /// header is missing or has a different value, or the response cannot be decoded.
    pub async fn get<R>(self, endpoint: &str) -> api::Result<R>
    where
        R: DeserializeOwned + 'static,
    {
        let query_string = match self.query {
            Some(query) => {
                let encoded =
                    serde_urlencoded::to_string(query).expect("Unable to serialize query.");
                format!("?{}", encoded)
            }
            None => String::new(),
        };
        let url = format!(
            "{}{}/{}/{}{}",
            self.test_server_url, self.access, self.prefix, endpoint, query_string
        );
        log::trace!("GET {}", url);

        let request = self.test_client.get(&url);
        let request = match self.modifier {
            Some(modify) => modify(request),
            None => request,
        };
        let response = request.send().await.expect("Unable to send request");
        Self::verify_headers(&self.expected_headers, &response);
        Self::response_to_api_result(response).await
    }

    /// Sends a POST request to `endpoint` and decodes the JSON response.
    ///
    /// The query, if set, is sent as a JSON body; otherwise the body is JSON `null`.
    ///
    /// # Panics
    ///
    /// Panics if the request cannot be sent, an expected header is missing or has a
    /// different value, or the response cannot be decoded.
    pub async fn post<R>(self, endpoint: &str) -> api::Result<R>
    where
        R: DeserializeOwned + 'static,
    {
        let url = format!(
            "{}{}/{}/{}",
            self.test_server_url, self.access, self.prefix, endpoint
        );
        log::trace!("POST {}", url);

        let request = self.test_client.post(&url);
        let request = match self.query {
            Some(query) => request.json(query),
            None => request.json(&serde_json::Value::Null),
        };
        let request = match self.modifier {
            Some(modify) => modify(request),
            None => request,
        };
        let response = request.send().await.expect("Unable to send request");
        Self::verify_headers(&self.expected_headers, &response);
        Self::response_to_api_result(response).await
    }
}
impl<Q> RequestBuilder<'_, '_, Q>
where
    Q: ProtobufConvert,
    Q::ProtoStruct: protobuf::Message,
{
    /// Sends a POST request to `endpoint` and decodes the JSON response.
    ///
    /// The body is the Protobuf encoding of the query, or empty if no query is set.
    /// It is sent as `application/octet-stream`.
    ///
    /// # Panics
    ///
    /// Panics if the query cannot be encoded, the request cannot be sent, an expected
    /// header is missing or has a different value, or the response cannot be decoded.
    pub async fn post_pb<R>(self, endpoint: &str) -> api::Result<R>
    where
        R: DeserializeOwned + 'static,
    {
        let url = format!(
            "{}{}/{}/{}",
            self.test_server_url, self.access, self.prefix, endpoint
        );
        log::trace!("POST Protobuf {}", url);

        let body = match self.query {
            Some(query) => protobuf::Message::write_to_bytes(&query.to_pb())
                .expect("Cannot write Protobuf message to `Vec<u8>`"),
            None => Vec::new(),
        };

        let request = self
            .test_client
            .post(&url)
            .header("Content-Type", "application/octet-stream")
            .body(body);
        let request = match self.modifier {
            Some(modify) => modify(request),
            None => request,
        };
        let response = request.send().await.expect("Unable to send request");
        Self::verify_headers(&self.expected_headers, &response);
        Self::response_to_api_result(response).await
    }
}
/// Starts an actix-web test server exposing the aggregated API.
///
/// Public endpoints are mounted under `public/api` and private ones under `private/api`.
fn create_test_server(aggregator: ApiAggregator) -> TestServer {
    let test_server = test::start(move || {
        let public_scope = web::scope("public/api");
        let private_scope = web::scope("private/api");
        App::new()
            .service(aggregator.extend_backend(ApiAccess::Public, public_scope))
            .service(aggregator.extend_backend(ApiAccess::Private, private_scope))
    });
    log::info!("Test server created on {}", test_server.addr());
    test_server
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestKitBuilder;

    /// Compiles only if `T: Send`; the argument itself is never inspected.
    fn assert_send<T: Send>(_value: &T) {}

    /// Request futures produced via `TestKitApi` must be `Send`.
    #[test]
    fn assert_send_for_testkit_api() {
        let mut testkit = TestKitBuilder::validator().build();
        let api = testkit.api();

        let get_future = api.public(ApiKind::Explorer).get::<()>("v1/transactions");
        assert_send(&get_future);
        let post_future = api.public(ApiKind::Explorer).post::<()>("v1/transactions");
        assert_send(&post_future);
    }

    /// Request futures produced via a cloned `TestKitApiClient` must be `Send`.
    #[test]
    fn assert_send_for_testkit_client() {
        let api = TestKitBuilder::validator().build().api();
        let client = api.client().clone();

        let ping_future = client.public(ApiKind::Explorer).get::<()>("ping");
        assert_send(&ping_future);
    }
}