use std::{
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
time::Duration,
};
use chrono::{DateTime, Utc};
use hyper::{
body::Buf,
service::{make_service_fn, service_fn},
Body, Request, Response, Server, StatusCode,
};
use lazy_static::lazy_static;
use matches::assert_matches;
use parking_lot::Mutex;
use serde_json::json;
use tokio::sync::{
mpsc::{self, Receiver},
oneshot,
};
use crate::{timeout, TelemetryClient, TelemetryConfig};
lazy_static! {
    // Held for the whole duration of every `manual_timeout_test!` test so that
    // those tests run one at a time.
    static ref SERIAL_TEST_MUTEX: Mutex<()> = Mutex::default();
}
// Defines a `#[test]` that runs an async body on a fresh Tokio runtime,
// bracketed by `timeout::init()` / `timeout::reset()`.
//
// Tests are serialized through `SERIAL_TEST_MUTEX`, presumably because the
// `timeout` module keeps process-wide state — TODO confirm.
// NOTE(review): `timeout::reset()` is skipped if the body panics.
macro_rules! manual_timeout_test {
    (async fn $name: ident() $body: block) => {
        #[test]
        fn $name() {
            let _serial = SERIAL_TEST_MUTEX.lock();
            let runtime = tokio::runtime::Runtime::new().expect("runtime");
            let test = async {
                timeout::init();
                $body;
                timeout::reset();
            };
            runtime.block_on(test);
        }
    };
}
manual_timeout_test! {
    async fn it_sends_one_telemetry_item() {
        let mut mock = server().status(StatusCode::OK).create();
        let telemetry = create_client(mock.url());

        telemetry.track_event("--event--");
        timeout::expire();

        // A single tick must deliver the queued event.
        let received = mock.next_request_timeout().await;
        assert_matches!(received, Ok(_));

        mock.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_does_not_resend_submitted_telemetry_items() {
        let mut mock = server().status(StatusCode::OK).create();
        let telemetry = create_client(mock.url());

        telemetry.track_event("--event--");

        // First tick: the queued item is submitted.
        timeout::expire();
        let first = mock.next_request_timeout().await;
        assert_matches!(first, Ok(_));

        // Second tick: the item was already accepted, so no request may arrive.
        timeout::expire();
        let second = mock.next_request_timeout().await;
        assert_matches!(second, Err(RecvTimeoutError::Timeout));

        mock.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_sends_telemetry_items_in_several_batches() {
        let mut server = server().status(StatusCode::OK).status(StatusCode::OK).create();
        let client = create_client(server.url());

        // First batch: events 0..10, expected to go out on the first tick.
        for i in 0..10 {
            client.track_event(format!("--event {}--", i));
        }
        timeout::expire();

        // Second batch: events 10..15, expected to go out on the second tick.
        for i in 10..15 {
            client.track_event(format!("--event {}--", i));
        }
        timeout::expire();

        // Each tick must produce its own request, and together they carry all 15 events.
        let requests = server.wait_for_requests(2).await;
        assert_eq!(requests.len(), 2);
        let content = requests.concat();
        let items_count = (0..15)
            .filter(|i| content.contains(&format!("--event {}--", i)))
            .count();
        assert_eq!(items_count, 15);

        server.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_flushes_all_pending_telemetry_items() {
        let mut server = server().status(StatusCode::OK).status(StatusCode::OK).create();
        let client = create_client(server.url());
        for i in 0..15 {
            client.track_event(format!("--event {}--", i));
        }

        // Flushing sends everything queued without waiting for a timeout tick.
        client.flush_channel();

        let requests = server.wait_for_requests(1).await;
        assert_eq!(requests.len(), 1);
        let content = requests.concat();
        let items_count = (0..15)
            .filter(|i| content.contains(&format!("--event {}--", i)))
            .count();
        assert_eq!(items_count, 15);

        server.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_does_not_send_any_pending_telemetry_items_when_drop_client() {
        let mut mock = server().status(StatusCode::OK).status(StatusCode::OK).create();
        let telemetry = create_client(mock.url());
        for n in 0..15 {
            telemetry.track_event(format!("--event {}--", n));
        }

        // Dropping the client without flushing is expected to discard queued items.
        drop(telemetry);

        let outcome = mock.next_request_timeout().await;
        assert_matches!(outcome, Err(RecvTimeoutError::Timeout));

        mock.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_tries_to_send_pending_telemetry_items_when_close_channel_requested() {
        let mut server = server().status(StatusCode::OK).status(StatusCode::OK).create();
        let client = create_client(server.url());
        for i in 0..15 {
            client.track_event(format!("--event {}--", i));
        }

        // Closing the channel must still attempt to deliver everything queued.
        client.close_channel().await;

        let requests = server.wait_for_requests(1).await;
        assert_eq!(requests.len(), 1);
        let content = requests.concat();
        let items_count = (0..15)
            .filter(|i| content.contains(&format!("--event {}--", i)))
            .count();
        assert_eq!(items_count, 15);

        server.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_does_not_try_to_send_pending_telemetry_items_when_client_terminated() {
        let mut mock = server().status(StatusCode::OK).status(StatusCode::OK).create();
        let telemetry = create_client(mock.url());
        for n in 0..15 {
            telemetry.track_event(format!("--event {}--", n));
        }

        // Terminating the client is expected to drop queued items unsent.
        telemetry.terminate().await;

        let received = mock.wait_for_requests(1).await;
        assert!(received.is_empty());

        mock.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_retries_when_previous_submission_failed() {
        // The first submission fails with 500; the second accepts the whole batch.
        let accepted_all = json!({
            "itemsAccepted": 15,
            "itemsReceived": 15,
            "errors": [],
        });
        let mut mock = server()
            .response(StatusCode::INTERNAL_SERVER_ERROR, json!({}), None)
            .response(StatusCode::OK, accepted_all, None)
            .create();
        let telemetry = create_client(mock.url());
        (0..15).for_each(|n| {
            telemetry.track_event(format!("--event {}--", n));
        });

        timeout::expire();
        timeout::expire();

        // The retry must carry exactly the same payload as the failed attempt.
        let requests = mock.wait_for_requests(2).await;
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0], requests[1]);

        mock.terminate().await;
    }
}
manual_timeout_test! {
    async fn it_retries_when_partial_content() {
        // First submission: 12 of 15 items accepted, items at indexes 4, 9 and 14
        // rejected with status 500. Second submission: the 3 resent items are accepted.
        let mut server = server()
            .response(
                StatusCode::PARTIAL_CONTENT,
                json!({
                    "itemsAccepted": 12,
                    "itemsReceived": 15,
                    "errors": [
                        {
                            "index": 4,
                            "statusCode": StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
                            "message": "Internal Server Error"
                        },
                        {
                            "index": 9,
                            "statusCode": StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
                            "message": "Internal Server Error"
                        },
                        {
                            "index": 14,
                            "statusCode": StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
                            "message": "Internal Server Error"
                        }
                    ],
                }),
                None,
            )
            .response(
                StatusCode::OK,
                json!({
                    "itemsAccepted": 3,
                    "itemsReceived": 3,
                    "errors": [],
                }),
                None,
            )
            .create();
        let client = create_client(server.url());
        for i in 0..15 {
            client.track_event(format!("--event {}--", i));
        }
        timeout::expire();
        timeout::expire();

        // The first request carries the whole batch.
        let requests = server.wait_for_requests(1).await;
        assert_eq!(requests.len(), 1);
        let content = requests.concat();
        let items_count = (0..15)
            .filter(|i| content.contains(&format!("--event {}--", i)))
            .count();
        assert_eq!(items_count, 15);

        // The retry carries only the rejected items, never ones already accepted.
        // (The trailing "--" in the pattern keeps "--event 1--" from matching "--event 14--".)
        let requests = server.wait_for_requests(1).await;
        assert_eq!(requests.len(), 1);
        let content = requests.concat();
        let retried: Vec<i32> = (0..15)
            .filter(|i| content.contains(&format!("--event {}--", i)))
            .collect();
        assert_eq!(retried, vec![4, 9, 14]);

        server.terminate().await;
    }
}
/// Builds a client that submits telemetry to `endpoint` with a 300 ms send interval.
fn create_client(endpoint: &str) -> TelemetryClient {
    TelemetryClient::from_config(
        TelemetryConfig::builder()
            .i_key("instrumentation key")
            .endpoint(endpoint)
            .interval(Duration::from_millis(300))
            .build(),
    )
}
/// Starts describing a mock ingestion server that has no canned responses yet.
fn server() -> Builder {
    let responses = Vec::new();
    Builder { responses }
}
/// Handle to a running mock HTTP server.
struct HyperTestServer {
    /// Base URL (`http://host:port`) the server listens on.
    url: String,
    /// Shutdown trigger; taken (set to `None`) once used.
    shutdown_send: Option<oneshot::Sender<()>>,
    /// Bodies of the requests the server received, in arrival order.
    request_recv: Receiver<String>,
}
impl HyperTestServer {
fn url(&self) -> &str {
&self.url
}
async fn next_request_timeout(&mut self) -> Result<String, RecvTimeoutError> {
match tokio::time::timeout(Duration::from_millis(100), self.request_recv.recv()).await {
Ok(Some(x)) => Ok(x),
Ok(None) => Err(RecvTimeoutError::Disconnected),
Err(_) => Err(RecvTimeoutError::Timeout),
}
}
async fn wait_for_requests(&mut self, count: usize) -> Vec<String> {
let mut requests = Vec::new();
for _ in 0..count {
match self.next_request_timeout().await {
Result::Ok(request) => requests.push(request),
Result::Err(err) => {
log::error!("{:?}", err);
}
}
}
requests
}
async fn terminate(mut self) {
if let Some(shutdown) = self.shutdown_send.take() {
shutdown.send(()).unwrap();
}
}
}
/// Why `HyperTestServer::next_request_timeout` produced no request body.
#[derive(Debug)]
enum RecvTimeoutError {
    /// No request arrived within the wait window.
    Timeout,
    /// The request channel was closed, so no further requests can arrive.
    Disconnected,
}
/// Accumulates canned responses for a mock server, which is started by `Builder::create`.
struct Builder {
    // Served in insertion order: the n-th incoming request gets the n-th response.
    responses: Vec<Response<String>>,
}
impl Builder {
    /// Queues a canned response with the given status and body.
    ///
    /// When `retry_after` is set, it is sent as a `Retry-After` header formatted
    /// with `to_rfc2822`.
    fn response(mut self, status: StatusCode, body: impl ToString, retry_after: Option<DateTime<Utc>>) -> Self {
        let mut builder = Response::builder().status(status);
        if let Some(retry_after) = retry_after {
            builder = builder.header("Retry-After", retry_after.to_rfc2822());
        }
        let response = builder.body(body.to_string()).unwrap();
        self.responses.push(response);
        self
    }

    /// Queues a response with `status` whose body reports one received and accepted item.
    fn status(self, status: StatusCode) -> Self {
        self.response(
            status,
            json!({
                "itemsAccepted": 1,
                "itemsReceived": 1,
                "errors": [],
            }),
            None,
        )
    }

    /// Starts the mock server on a free loopback port and returns a handle to it.
    ///
    /// Every request body is forwarded to the handle's request channel. The n-th
    /// request is answered with the n-th queued response (status, headers and body);
    /// requests beyond the queued responses get `404 Not Found`.
    fn create(self) -> HyperTestServer {
        let (shutdown_send, shutdown_recv) = oneshot::channel();
        let (request_sender, request_receiver) = mpsc::channel(100);
        let responses = Arc::new(self.responses);
        let counter = Arc::new(AtomicUsize::new(0));
        let make_service = make_service_fn(move |_| {
            let request_send = request_sender.clone();
            let counter = counter.clone();
            let responses = responses.clone();
            async move {
                Ok::<_, hyper::Error>(service_fn(move |req: Request<Body>| {
                    let request_send = request_send.clone();
                    let counter = counter.clone();
                    let responses = responses.clone();
                    async move {
                        let body = hyper::body::aggregate(req).await?;
                        use std::io::Read;
                        let mut content = String::default();
                        body.reader().read_to_string(&mut content).unwrap();
                        request_send.send(content).await.expect("send request");
                        let count = counter.fetch_add(1, Ordering::AcqRel);
                        let response = match responses.get(count) {
                            Some(canned) => {
                                let mut response = Response::new(Body::from(canned.body().clone()));
                                *response.status_mut() = canned.status();
                                // Copy the headers as well, so a `Retry-After` set via
                                // `response()` actually reaches the client.
                                *response.headers_mut() = canned.headers().clone();
                                response
                            }
                            None => Response::builder()
                                .status(StatusCode::NOT_FOUND)
                                .body(Body::empty())
                                .unwrap(),
                        };
                        Ok::<_, hyper::Error>(response)
                    }
                }))
            }
        });
        // Bind to loopback: the resulting URL is handed to the client as a destination,
        // and `0.0.0.0` is not a connectable address on every platform.
        let server = Server::bind(&([127, 0, 0, 1], 0).into()).serve(make_service);
        let url = format!("http://{}", server.local_addr());
        let graceful = server.with_graceful_shutdown(async {
            shutdown_recv.await.ok();
        });
        tokio::spawn(async move {
            if let Err(e) = graceful.await {
                log::error!("server error: {}", e);
            }
        });
        HyperTestServer {
            url,
            request_recv: request_receiver,
            shutdown_send: Some(shutdown_send),
        }
    }
}