#![allow(dead_code)]
use axum::extract::DefaultBodyLimit;
use ferrokinesis::store::StoreOptions;
use reqwest::Client;
use reqwest::header::{HeaderMap, HeaderValue};
use serde_json::{Value, json};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::net::TcpListener;
/// Content type sent with JSON-encoded requests.
pub const AMZ_JSON: &str = "application/x-amz-json-1.1";
/// Content type sent with CBOR-encoded requests.
pub const AMZ_CBOR: &str = "application/x-amz-cbor-1.1";
/// Prefix of the `X-Amz-Target` header; helpers send `"{VERSION}.{Operation}"`.
pub const VERSION: &str = "Kinesis_20131202";

/// Process-wide counter that feeds [`unique_stream_name`].
static PROP_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Returns `"{prefix}-{n}"`, where `n` is taken from a process-wide counter that
/// increments on every call, so names never repeat within one test binary run.
pub fn unique_stream_name(prefix: &str) -> String {
    // Relaxed is enough: only uniqueness of the fetched value matters, not ordering
    // relative to other memory operations.
    let suffix = PROP_COUNTER.fetch_add(1, Ordering::Relaxed);
    format!("{prefix}-{suffix}")
}
/// Consumes `res` and returns its numeric status together with the decoded body.
///
/// The body is decoded as CBOR when the `content-type` header contains `"cbor"`
/// (converted to JSON via `ferrokinesis::server::cbor_to_json`), and as JSON
/// otherwise. An empty body yields `Value::Null`. A body that fails to parse
/// also decodes to null rather than panicking, so tests can inspect non-JSON
/// error responses by status alone.
///
/// # Panics
/// If reading the response body fails.
pub async fn decode_body(res: reqwest::Response) -> (u16, Value) {
    let status = res.status().as_u16();
    // Must be read before `bytes()` consumes the response.
    let is_cbor = res
        .headers()
        .get("content-type")
        .and_then(|header| header.to_str().ok())
        .is_some_and(|content_type| content_type.contains("cbor"));
    let payload = res.bytes().await.unwrap();

    let decoded = if payload.is_empty() {
        Value::Null
    } else if is_cbor {
        let parsed: ciborium::Value =
            ciborium::from_reader(&payload[..]).unwrap_or(ciborium::Value::Null);
        ferrokinesis::server::cbor_to_json(&parsed)
    } else {
        serde_json::from_slice(&payload).unwrap_or(Value::Null)
    };
    (status, decoded)
}
pub struct TestServer {
pub addr: SocketAddr,
pub client: Client,
pub store: ferrokinesis::store::Store,
}
impl TestServer {
pub async fn new() -> Self {
Self::with_options(StoreOptions {
create_stream_ms: 0,
delete_stream_ms: 0,
update_stream_ms: 0,
shard_limit: 50,
..Default::default()
})
.await
}
pub async fn with_options(options: StoreOptions) -> Self {
let (app, store) = ferrokinesis::create_app(options);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
ferrokinesis::serve_plain_http(listener, app, std::future::pending())
.await
.unwrap();
});
TestServer {
addr,
client: Client::new(),
store,
}
}
pub async fn with_capture(
options: StoreOptions,
capture: ferrokinesis::capture::CaptureWriter,
) -> Self {
let (app, store) = ferrokinesis::create_app_with_capture(options, Some(capture));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
ferrokinesis::serve_plain_http(listener, app, std::future::pending())
.await
.unwrap();
});
TestServer {
addr,
client: Client::new(),
store,
}
}
pub async fn with_body_limit(options: StoreOptions, max_body_bytes: usize) -> Self {
let (app, store) = ferrokinesis::create_app(options);
let app = app.layer(DefaultBodyLimit::max(max_body_bytes));
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
ferrokinesis::serve_plain_http(listener, app, std::future::pending())
.await
.unwrap();
});
TestServer {
addr,
client: Client::new(),
store,
}
}
pub fn url(&self) -> String {
format!("http://{}", self.addr)
}
pub async fn request(&self, target: &str, data: &Value) -> reqwest::Response {
self.signed_request_to(self.url(), target, data).await
}
async fn signed_request_to(
&self,
url: String,
target: &str,
data: &Value,
) -> reqwest::Response {
self.client
.post(url)
.header("Content-Type", AMZ_JSON)
.header("X-Amz-Target", format!("{VERSION}.{target}"))
.header(
"Authorization",
"AWS4-HMAC-SHA256 Credential=AKID/20150101/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=abcd1234",
)
.header("X-Amz-Date", "20150101T000000Z")
.body(serde_json::to_vec(data).unwrap())
.send()
.await
.unwrap()
}
pub async fn cbor_request(&self, target: &str, data: &Value) -> reqwest::Response {
let mut buf = Vec::new();
ciborium::into_writer(data, &mut buf).unwrap();
self.client
.post(self.url())
.header("Content-Type", AMZ_CBOR)
.header("X-Amz-Target", format!("{VERSION}.{target}"))
.header(
"Authorization",
"AWS4-HMAC-SHA256 Credential=AKID/20150101/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=abcd1234",
)
.header("X-Amz-Date", "20150101T000000Z")
.body(buf)
.send()
.await
.unwrap()
}
pub async fn request_both(&self, target: &str, data: &Value) -> ((u16, Value), (u16, Value)) {
let json_resp = decode_body(self.request(target, data).await).await;
let cbor_resp = decode_body(self.cbor_request(target, data).await).await;
(json_resp, cbor_resp)
}
pub async fn cbor_request_raw_data(
&self,
target: &str,
fields: &Value,
data_field_path: &str,
raw_data: &[u8],
) -> reqwest::Response {
let cbor_val = json_to_cbor_with_bytes(fields, data_field_path, raw_data);
let mut buf = Vec::new();
ciborium::into_writer(&cbor_val, &mut buf).unwrap();
self.client
.post(self.url())
.header("Content-Type", AMZ_CBOR)
.header("X-Amz-Target", format!("{VERSION}.{target}"))
.header(
"Authorization",
"AWS4-HMAC-SHA256 Credential=AKID/20150101/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=abcd1234",
)
.header("X-Amz-Date", "20150101T000000Z")
.body(buf)
.send()
.await
.unwrap()
}
pub async fn cbor_request_raw_data_many(
&self,
target: &str,
fields: &Value,
data_field_path: &str,
raw_data_many: &[Vec<u8>],
) -> reqwest::Response {
let cbor_val = json_to_cbor_with_bytes_many(fields, data_field_path, raw_data_many);
let mut buf = Vec::new();
ciborium::into_writer(&cbor_val, &mut buf).unwrap();
self.client
.post(self.url())
.header("Content-Type", AMZ_CBOR)
.header("X-Amz-Target", format!("{VERSION}.{target}"))
.header(
"Authorization",
"AWS4-HMAC-SHA256 Credential=AKID/20150101/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=abcd1234",
)
.header("X-Amz-Date", "20150101T000000Z")
.body(buf)
.send()
.await
.unwrap()
}
pub async fn raw_request(
&self,
method: reqwest::Method,
path: &str,
headers: HeaderMap,
body: Vec<u8>,
) -> reqwest::Response {
self.client
.request(method, format!("http://{}{}", self.addr, path))
.headers(headers)
.body(body)
.send()
.await
.unwrap()
}
pub async fn wait_for_stream_active(&self, name: &str) {
for _ in 0..20 {
let desc = self.describe_stream(name).await;
if desc["StreamDescription"]["StreamStatus"].as_str() == Some("ACTIVE") {
return;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
panic!("stream {name} did not become ACTIVE within timeout");
}
pub async fn create_stream(&self, name: &str, shard_count: u32) {
let res = self
.request(
"CreateStream",
&json!({"StreamName": name, "ShardCount": shard_count}),
)
.await;
assert_eq!(res.status(), 200, "Failed to create stream {name}");
self.wait_for_stream_active(name).await;
}
pub async fn describe_stream(&self, name: &str) -> Value {
let res = self
.request("DescribeStream", &json!({"StreamName": name}))
.await;
assert_eq!(res.status(), 200);
res.json().await.unwrap()
}
pub async fn wait_for_stream_status(
&self,
name: &str,
target: &str,
interval: tokio::time::Duration,
max_attempts: u32,
) -> Result<Value, String> {
for _ in 0..max_attempts {
let res = self
.request("DescribeStream", &json!({"StreamName": name}))
.await;
if res.status() == 200 {
let body: Value = res.json().await.unwrap();
if body["StreamDescription"]["StreamStatus"].as_str() == Some(target) {
return Ok(body);
}
}
tokio::time::sleep(interval).await;
}
Err(format!(
"stream {name:?} did not reach status {target:?} after {max_attempts} attempts"
))
}
pub async fn wait_for_stream_deleted(
&self,
name: &str,
interval: tokio::time::Duration,
max_attempts: u32,
) -> Result<(), String> {
for _ in 0..max_attempts {
let res = self
.request("DescribeStream", &json!({"StreamName": name}))
.await;
if res.status() != 200 {
return Ok(());
}
tokio::time::sleep(interval).await;
}
Err(format!(
"stream {name:?} was not fully deleted after {max_attempts} attempts"
))
}
pub async fn put_record(&self, stream: &str, data: &str, partition_key: &str) -> Value {
let res = self
.request(
"PutRecord",
&json!({
"StreamName": stream,
"Data": data,
"PartitionKey": partition_key,
}),
)
.await;
assert_eq!(res.status(), 200);
res.json().await.unwrap()
}
pub async fn get_shard_iterator(
&self,
stream: &str,
shard_id: &str,
iterator_type: &str,
) -> String {
let res = self
.request(
"GetShardIterator",
&json!({
"StreamName": stream,
"ShardId": shard_id,
"ShardIteratorType": iterator_type,
}),
)
.await;
assert_eq!(res.status(), 200);
let body: Value = res.json().await.unwrap();
body["ShardIterator"].as_str().unwrap().to_string()
}
pub async fn get_stream_arn(&self, name: &str) -> String {
let desc = self.describe_stream(name).await;
desc["StreamDescription"]["StreamARN"]
.as_str()
.unwrap()
.to_string()
}
pub async fn get_records(&self, iterator: &str) -> Value {
let res = self
.request("GetRecords", &json!({"ShardIterator": iterator}))
.await;
assert_eq!(res.status(), 200);
res.json().await.unwrap()
}
}
#[cfg(feature = "tls")]
impl TestServer {
    /// Starts an HTTPS server with a freshly generated self-signed certificate
    /// for `localhost` / `127.0.0.1`. It uses the same store options as
    /// `TestServer::new` and a client that accepts invalid certificates.
    ///
    /// # Panics
    /// If certificate or TLS config generation fails, or the server does not
    /// report a listening address within 5 seconds.
    pub async fn new_tls() -> Self {
        let (app, store) = ferrokinesis::create_app(StoreOptions {
            create_stream_ms: 0,
            delete_stream_ms: 0,
            update_stream_ms: 0,
            shard_limit: 50,
            ..Default::default()
        });

        let self_signed =
            rcgen::generate_simple_self_signed(vec!["localhost".into(), "127.0.0.1".into()])
                .expect("failed to generate self-signed cert");
        let rustls_config = axum_server::tls_rustls::RustlsConfig::from_pem(
            self_signed.cert.pem().into_bytes(),
            self_signed.signing_key.serialize_pem().into_bytes(),
        )
        .await
        .expect("failed to build RustlsConfig");

        // One handle goes to the server task; the other is used here to learn the
        // ephemeral port once binding completes.
        let server_handle = axum_server::Handle::new();
        let task_handle = server_handle.clone();
        tokio::spawn(async move {
            axum_server::bind_rustls("127.0.0.1:0".parse().unwrap(), rustls_config)
                .handle(task_handle)
                .serve(app.into_make_service())
                .await
                .unwrap();
        });

        let startup_timeout = tokio::time::Duration::from_secs(5);
        let addr = tokio::time::timeout(startup_timeout, server_handle.listening())
            .await
            .expect("timed out waiting for TLS server to start — server may have panicked")
            .unwrap();

        Self {
            addr,
            client: reqwest::Client::builder()
                .danger_accept_invalid_certs(true)
                .build()
                .unwrap(),
            store,
        }
    }

    /// Base `https://` URL of the server.
    pub fn tls_url(&self) -> String {
        format!("https://{}", self.addr)
    }

    /// Sends a JSON-encoded Kinesis `target` operation over HTTPS with fake SigV4 headers.
    pub async fn tls_request(&self, target: &str, data: &Value) -> reqwest::Response {
        let url = self.tls_url();
        self.signed_request_to(url, target, data).await
    }
}
/// Returns headers for a JSON `ListStreams` call. They carry the same fixed
/// SigV4-shaped `Authorization` and `X-Amz-Date` values the `TestServer` helpers
/// send, for use with hand-built requests such as `TestServer::raw_request`.
pub fn signed_headers() -> HeaderMap {
    let pairs = [
        ("Content-Type", AMZ_JSON),
        ("X-Amz-Target", "Kinesis_20131202.ListStreams"),
        (
            "Authorization",
            "AWS4-HMAC-SHA256 Credential=AKID/20150101/us-east-1/kinesis/aws4_request, SignedHeaders=content-type;host;x-amz-date;x-amz-target, Signature=abcd1234",
        ),
        ("X-Amz-Date", "20150101T000000Z"),
    ];
    let mut headers = HeaderMap::with_capacity(pairs.len());
    for (name, value) in pairs {
        headers.insert(name, HeaderValue::from_static(value));
    }
    headers
}
/// Converts `val` to CBOR, replacing every field selected by `data_field_path`
/// with a CBOR byte string containing `raw_data`.
///
/// `data_field_path` is a dot-separated list of object keys; a `*` segment
/// descends into every element of an array (e.g. `"Records.*.Data"`). Every
/// matched field receives the same `raw_data`. Everything off the path is
/// converted unchanged, and an empty path converts `val` without substitution.
pub fn json_to_cbor_with_bytes(
    val: &Value,
    data_field_path: &str,
    raw_data: &[u8],
) -> ciborium::Value {
    json_to_cbor_at_path(val, data_field_path, &mut || raw_data.to_vec())
}

/// Like [`json_to_cbor_with_bytes`], but gives each matched field its own
/// payload. Payloads are taken from `raw_data_many` in traversal order: array
/// elements in order, object keys in the map's iteration order.
///
/// # Panics
/// If `data_field_path` matches a different number of fields than
/// `raw_data_many.len()`.
pub fn json_to_cbor_with_bytes_many(
    val: &Value,
    data_field_path: &str,
    raw_data_many: &[Vec<u8>],
) -> ciborium::Value {
    let mut next_idx = 0;
    let cbor = json_to_cbor_at_path(val, data_field_path, &mut || {
        let idx = next_idx;
        let raw = raw_data_many.get(idx).unwrap_or_else(|| {
            panic!(
                "path {:?} matched more fields than raw payloads provided: needed payload index {}, provided {}",
                data_field_path,
                idx,
                raw_data_many.len()
            )
        });
        next_idx += 1;
        raw.clone()
    });
    assert_eq!(
        next_idx,
        raw_data_many.len(),
        "path {:?} matched {} field(s), but {} raw payload(s) were provided",
        data_field_path,
        next_idx,
        raw_data_many.len()
    );
    cbor
}

/// Walks `val` along the dot-separated `path` and replaces each field at the end
/// of the path with `Bytes(next_payload())`. Subtrees off the path are converted
/// by [`json_to_cbor_plain`].
fn json_to_cbor_at_path(
    val: &Value,
    path: &str,
    next_payload: &mut dyn FnMut() -> Vec<u8>,
) -> ciborium::Value {
    // An empty path selects nothing. Without this guard, `"".split_once('.')`
    // would yield `("", "")` and wrongly match an object key equal to `""`.
    if path.is_empty() {
        return json_to_cbor_plain(val);
    }
    let (first, rest) = path.split_once('.').unwrap_or((path, ""));
    match val {
        Value::Array(items) if first == "*" => ciborium::Value::Array(
            items
                .iter()
                .map(|item| json_to_cbor_at_path(item, rest, next_payload))
                .collect(),
        ),
        Value::Object(map) => ciborium::Value::Map(
            map.iter()
                .map(|(key, child)| {
                    let converted = if key != first {
                        json_to_cbor_plain(child)
                    } else if rest.is_empty() {
                        ciborium::Value::Bytes(next_payload())
                    } else {
                        json_to_cbor_at_path(child, rest, next_payload)
                    };
                    (ciborium::Value::Text(key.clone()), converted)
                })
                .collect(),
        ),
        // Scalars, and arrays reached by a non-`*` segment, cannot match further.
        _ => json_to_cbor_plain(val),
    }
}

/// Converts a JSON value to the structurally equivalent CBOR value.
///
/// Integers that fit in `i64` or `u64` become CBOR integers without loss; other
/// numbers become floats.
fn json_to_cbor_plain(val: &Value) -> ciborium::Value {
    match val {
        Value::Null => ciborium::Value::Null,
        Value::Bool(b) => ciborium::Value::Bool(*b),
        Value::Number(n) => {
            if let Some(i) = n.as_i64() {
                ciborium::Value::Integer(i.into())
            } else if let Some(u) = n.as_u64() {
                // Values above i64::MAX would otherwise lose precision as f64.
                ciborium::Value::Integer(u.into())
            } else if let Some(f) = n.as_f64() {
                ciborium::Value::Float(f)
            } else {
                ciborium::Value::Null
            }
        }
        Value::String(s) => ciborium::Value::Text(s.clone()),
        Value::Array(items) => ciborium::Value::Array(items.iter().map(json_to_cbor_plain).collect()),
        Value::Object(map) => ciborium::Value::Map(
            map.iter()
                .map(|(key, child)| (ciborium::Value::Text(key.clone()), json_to_cbor_plain(child)))
                .collect(),
        ),
    }
}
/// Builds a proptest `TestRunner` that runs `cases` cases, with every other
/// configuration field left at its default.
pub fn prop_runner(cases: u32) -> proptest::test_runner::TestRunner {
    use proptest::test_runner::{Config, TestRunner};

    let config = Config {
        cases,
        ..Config::default()
    };
    TestRunner::new(config)
}
/// Recursively removes every object entry whose key is in `keys`, at any
/// nesting depth inside objects and arrays. Scalars are left untouched.
pub fn strip_keys(val: &mut Value, keys: &[&str]) {
    if let Value::Object(map) = val {
        for key in keys.iter().copied() {
            map.remove(key);
        }
        map.values_mut().for_each(|child| strip_keys(child, keys));
    } else if let Value::Array(items) = val {
        items.iter_mut().for_each(|child| strip_keys(child, keys));
    }
}
/// Asserts that `a` and `b` are equal after removing `ignore_keys` from both at
/// every depth (see [`strip_keys`]). On failure, both stripped values are
/// pretty-printed.
pub fn assert_values_equivalent(a: &Value, b: &Value, ignore_keys: &[&str]) {
    let normalize = |original: &Value| {
        let mut owned = original.clone();
        strip_keys(&mut owned, ignore_keys);
        owned
    };
    let (left, right) = (normalize(a), normalize(b));
    assert_eq!(
        left,
        right,
        "Values not equivalent after stripping {:?}:\n left: {}\n right: {}",
        ignore_keys,
        serde_json::to_string_pretty(&left).unwrap(),
        serde_json::to_string_pretty(&right).unwrap(),
    );
}