use std::net::SocketAddr;
use axum::{extract::FromRequestParts, response::IntoResponse, routing::get, Router};
use http::request::Parts;
use serde::{Deserialize, Serialize};
use time::Duration;
use tower_sessions::{Expiry, MemoryStore, Session, SessionManagerLayer};
const COUNTER_KEY: &str = "counter";
#[derive(Default, Deserialize, Serialize)]
struct Counter(usize);
impl<S> FromRequestParts<S> for Counter
where
S: Send + Sync,
{
type Rejection = (http::StatusCode, &'static str);
async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let session = Session::from_request_parts(req, state).await?;
let counter: Counter = session.get(COUNTER_KEY).await.unwrap().unwrap_or_default();
session.insert(COUNTER_KEY, counter.0 + 1).await.unwrap();
Ok(counter)
}
}
#[tokio::main]
async fn main() {
let session_store = MemoryStore::default();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::seconds(10)));
let app = Router::new().route("/", get(handler)).layer(session_layer);
let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app.into_make_service())
.await
.unwrap();
}
async fn handler(counter: Counter) -> impl IntoResponse {
format!("Current count: {}", counter.0)
}