use axum::{
extract::State,
routing::{get, post},
Router,
};
use sqlx::prelude::*;
/// Shorthand for the sqlx Postgres connection pool type used by every function in this example.
type PgPool = sqlx::postgres::PgPool;
/// Connects a Postgres pool to `conn_uri` and applies the migrations embedded from
/// `examples/sqlx-migrations` before handing the pool back.
///
/// # Panics
/// Panics if the pool cannot connect or if any migration fails.
async fn connection_pool(conn_uri: &str) -> PgPool {
    // Migrations are embedded at compile time by the macro; the path is resolved at build time.
    let migrator = sqlx::migrate!("examples/sqlx-migrations");
    let connected = sqlx::postgres::PgPoolOptions::new()
        .connect(conn_uri)
        .await
        .expect("failed to create connection pool");
    migrator
        .run(&connected)
        .await
        .expect("failed to run migrations");
    connected
}
/// Axum handler: returns the text of every task with `completed = true`.
///
/// Each task is written followed by a `\n`; the body is empty when nothing has completed.
/// Row order is whatever Postgres returns (the query has no `ORDER BY`).
///
/// # Panics
/// Panics if the select query fails.
async fn list_completed_tasks(pool: State<PgPool>) -> String {
    let completed = sqlx::query("SELECT task FROM tasks WHERE completed = true")
        .fetch_all(&*pool)
        .await
        .expect("failed to execute select query");
    completed.iter().fold(String::new(), |mut listing, row| {
        let task: &str = row.get(0);
        listing.push_str(task);
        listing.push('\n');
        listing
    })
}
/// Axum handler: stores the raw request body as a new row in `tasks` and replies `"ok"`.
///
/// NOTE(review): the worker is woken via an `insert_tasks` notification; presumably a
/// trigger defined in the migrations emits it — not visible here, confirm.
///
/// # Panics
/// Panics if the insert fails.
async fn create_task(pool: State<PgPool>, body: String) -> &'static str {
    let insert = sqlx::query("INSERT INTO tasks (task) VALUES ($1)").bind(body);
    insert
        .execute(&*pool)
        .await
        .expect("failed to execute insert query");
    "ok"
}
/// Builds the HTTP API: `GET /list_completed_tasks` and `POST /create_task`,
/// with `pool` shared to both handlers as router state.
fn axum_router(pool: PgPool) -> Router {
    let routes = Router::new()
        .route("/list_completed_tasks", get(list_completed_tasks))
        .route("/create_task", post(create_task));
    routes.with_state(pool)
}
async fn run_listener(conn_uri: &str) {
let pool = connection_pool(conn_uri).await;
let mut pglistener = sqlx::postgres::PgListener::connect_with(&pool)
.await
.expect("failed to create listener");
pglistener
.listen("insert_tasks")
.await
.expect("failed to start listening to insert_task events");
println!("listener is ready");
loop {
let notif = pglistener.recv().await.expect("listener recv failed.");
let row_id = notif.payload();
println!("new task, row id {row_id}");
let mut tx = pool.begin().await.expect("failed to start transaction");
let task_row = sqlx::query(
"SELECT id, task FROM tasks WHERE completed = false FOR UPDATE SKIP LOCKED LIMIT 1",
)
.fetch_one(&mut *tx)
.await
.expect("failed to select open task");
let id: i32 = task_row.get(0);
let task: &str = task_row.get(1);
println!("executing task `{task}`");
sqlx::query("UPDATE tasks SET completed = true WHERE id = ($1)")
.bind(id)
.execute(&mut *tx)
.await
.expect("failed to execute update on task");
tx.commit()
.await
.expect("failed to commit task execution transaction");
sqlx::query("NOTIFY task_executed")
.execute(&pool)
.await
.expect("failed to notify task executed");
}
}
/// Runs the end-to-end queue scenario as a test; any failure surfaces as a panic.
#[tokio::test]
async fn test_sqlx_queue_example() {
    run_sqlx_queue_example().await
}
/// Binary entry point: runs the end-to-end queue scenario once.
#[tokio::main]
async fn main() {
    run_sqlx_queue_example().await
}
/// End-to-end scenario for the Postgres-backed task queue.
///
/// The scenario runs these steps in order:
/// 1. Starts a throwaway Postgres via `pgtemp`.
/// 2. Runs [`run_listener`] on its own OS thread and Tokio runtime.
/// 3. Serves [`axum_router`] on an ephemeral localhost port.
/// 4. Creates two tasks over HTTP.
/// 5. Waits for two `task_executed` notifications.
/// 6. Checks that both tasks are reported by `/list_completed_tasks`.
///
/// # Panics
/// Panics (via `expect`/`assert!`) if any step fails.
async fn run_sqlx_queue_example() {
    // Kept bound (not `_`) so the temporary database lives until the function returns.
    let temp_db = pgtemp::PgTempDB::new();
    let conn_uri = temp_db.connection_uri().clone();
    let pool = connection_pool(&conn_uri).await;

    // The worker gets a dedicated thread and runtime of its own.
    std::thread::spawn(move || {
        let worker_rt = tokio::runtime::Runtime::new().expect("failed to start runtime");
        worker_rt.block_on(run_listener(&conn_uri));
    });
    // NOTE(review): a fixed delay is only a heuristic for "the worker has issued LISTEN".
    // If it hasn't yet, the insert notifications are missed and the waits below can hang.
    // The std sleep also blocks the current executor thread while it runs.
    std::thread::sleep(std::time::Duration::from_millis(200));

    // Subscribe before creating any task so no completion notification is missed.
    let mut executed_events = sqlx::postgres::PgListener::connect_with(&pool)
        .await
        .expect("failed to create listener");
    executed_events
        .listen("task_executed")
        .await
        .expect("failed to start listening to task_executed events");

    let app = axum_router(pool.clone());
    let tcp = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .expect("failed to start listener");
    let addr = tcp.local_addr().unwrap();
    tokio::spawn(async move {
        axum::serve(tcp, app)
            .await
            .expect("failed to run axum server");
    });

    let base_url = format!("http://{addr}");
    let client = reqwest::Client::new();
    let create_url = format!("{base_url}/create_task");
    for (task_body, failure) in [
        ("hello", "failed to create task 1"),
        ("task 2", "failed to create task 2"),
    ] {
        let resp = client
            .post(create_url.as_str())
            .body(task_body)
            .send()
            .await
            .expect(failure);
        assert!(resp.status().is_success());
    }

    // The worker sends one completion notification per executed task.
    for _ in 0..2 {
        let _event = executed_events.recv().await.expect("listener recv failed.");
    }

    let resp = client
        .get(format!("{base_url}/list_completed_tasks"))
        .send()
        .await
        .expect("failed to list tasks");
    assert!(resp.status().is_success());
    let body = resp.text().await.expect("failed to parse body");
    assert!(body.contains("hello"));
    assert!(body.contains("task 2"));
}