use rama::{
Context, Layer,
error::OpaqueError,
http::{
BodyExtractExt,
client::EasyHttpWebClient,
server::HttpServer,
service::{client::HttpClientExt, web::WebService},
},
layer::{
LimitLayer,
limit::{Policy, PolicyOutput, policy::PolicyResult},
},
net::client::pool::FiFoReuseLruDropPool,
rt::Executor,
tcp::server::TcpListener,
};
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use tokio::{sync::oneshot::Sender, sync::oneshot::channel};
use tokio_test::assert_err;
use tracing::level_filters::LevelFilter;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::{EnvFilter, fmt};
/// Loopback address shared by the demo server and the client requests.
const ADDRESS: &str = "127.0.0.1:62024";

/// Starts a server that only admits its first TCP connection, then checks that
/// a pooled client can issue repeated requests while an unpooled one is refused.
#[tokio::main]
async fn main() {
    setup_tracing();

    // Launch the server and block until its listener is bound.
    let (ready_tx, ready_rx) = channel();
    tokio::spawn(run_server(ADDRESS, ready_tx));
    ready_rx.await.unwrap();

    let url = format!("http://{ADDRESS}/");

    let client = EasyHttpWebClient::default()
        .with_connection_pool(FiFoReuseLruDropPool::new(5, 10).unwrap());

    // First request opens the one connection the server will ever accept.
    let first = client.get(url.clone()).send(Context::default()).await.unwrap();
    let body = first.try_into_string().await.unwrap();
    tracing::info!("body: {body:?}");
    assert_eq!(body, "Hello, World!");

    // Second request must succeed; since the server rejects new connections,
    // this is expected to ride on the pooled connection.
    let _second = client.get(url.clone()).send(Context::default()).await.unwrap();

    // Without the pool a fresh connection is needed, which the server aborts.
    let unpooled = client.without_connection_pool();
    let refused = unpooled.get(url).send(Context::default()).await;
    assert_err!(refused);
}
/// Installs the global `tracing` subscriber: a formatting layer plus an
/// environment-driven filter that defaults to `DEBUG` when no directive is set.
fn setup_tracing() {
    // `from_env_lossy` reads directives from the environment and ignores
    // invalid ones instead of failing.
    let filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::DEBUG.into())
        .from_env_lossy();

    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(filter)
        .init();
}
/// Binds a TCP listener on `addr`, signals `ready`, then drives the server.
///
/// Every accepted connection goes through a [`LimitLayer`] backed by
/// [`FirstConnOnly`], so only the first connection reaches the HTTP service.
/// `GET /` responds with `"Hello, World!"`.
///
/// # Panics
/// Panics if the listener cannot bind, or if the `ready` receiver was dropped.
async fn run_server(addr: &str, ready: Sender<()>) {
    tracing::info!("running service at: {addr}");

    let routes = WebService::default().get("/", "Hello, World!");
    let http_service = HttpServer::auto(Executor::default()).service(routes);

    let listener = TcpListener::build()
        .bind(addr)
        .await
        .expect("bind TCP Listener");

    let guarded = LimitLayer::new(FirstConnOnly::new()).layer(http_service);
    let serve = listener.serve(guarded);

    // Signal readiness only after the socket is bound, so clients can connect.
    ready.send(()).unwrap();
    serve.await;
}
/// Limit policy that lets only its first check through.
///
/// The "already admitted" flag lives behind an [`Arc`], so all clones of the
/// policy read and flip the same state.
#[derive(Clone)]
struct FirstConnOnly(Arc<AtomicBool>);

impl FirstConnOnly {
    /// Creates a policy whose flag starts cleared (nothing admitted yet).
    fn new() -> Self {
        let admitted = AtomicBool::new(false);
        Self(Arc::new(admitted))
    }
}
/// Admits the first check and aborts every subsequent one.
///
/// The shared flag is flipped with a single atomic `swap`. Only one caller,
/// across all clones, can observe the previous value `false`, and that caller
/// is let through.
impl<State, Request> Policy<State, Request> for FirstConnOnly
where
    State: Send + Sync + 'static,
    Request: Send + 'static,
{
    type Guard = ();
    type Error = OpaqueError;

    async fn check(
        &self,
        ctx: Context<State>,
        request: Request,
    ) -> PolicyResult<State, Request, Self::Guard, Self::Error> {
        // `swap` returns the previous value: `false` only for the very first caller.
        let already_admitted = self.0.swap(true, Ordering::AcqRel);
        let output = if already_admitted {
            PolicyOutput::Abort(OpaqueError::from_display(
                "Only first connection is allowed",
            ))
        } else {
            PolicyOutput::Ready(())
        };
        PolicyResult {
            ctx,
            request,
            output,
        }
    }
}