use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
sync::{
Arc, RwLock,
atomic::{AtomicU64, Ordering},
},
};
use trillium::{Conn, Handler, Status};
use trillium_api::{ApiConnExt, Body, FromConn, Json, State, TryFromConn, Value, api};
use trillium_router::{RouterConnExt, router};
/// A single todo item as stored in [`Db`] and returned to clients as JSON.
#[derive(Debug, Serialize, Deserialize, Clone)]
struct Todo {
/// Identifier allocated from `Db::next_id` when the todo is created.
id: u64,
/// Free-form description of the task.
title: String,
/// Whether the task has been marked done.
completed: bool,
/// Value of the creator's `x-user` header at creation time.
owner: String,
}
/// JSON request body accepted by `POST /todos`.
#[derive(Debug, Deserialize)]
struct NewTodo {
/// Title of the new todo.
title: String,
/// Initial completion state; the handler treats an absent value as `false`.
completed: Option<bool>,
}
/// JSON request body accepted by `PATCH /todos/:todo_id`.
///
/// Every field is optional: a field that is absent (deserialized as `None`) leaves the
/// stored value unchanged.
#[derive(Debug, Deserialize)]
struct UpdateTodo {
/// Replacement title; the handler rejects an empty string.
title: Option<String>,
/// Replacement completion state.
completed: Option<bool>,
}
/// In-memory todo store shared by every request.
///
/// Both fields sit behind `Arc`, so the derived `Clone` yields another handle to the
/// same map and counter rather than a copy of the data.
#[derive(Debug, Clone, Default)]
struct Db {
/// Todos keyed by their `id`.
todos: Arc<RwLock<HashMap<u64, Todo>>>,
/// Next id to hand out; starts at 0 via `Default` and is advanced with `fetch_add`.
next_id: Arc<AtomicU64>,
}
impl Handler for Db {
async fn run(&self, conn: Conn) -> Conn {
conn.with_state(self.clone())
}
}
impl FromConn for Db {
    /// Fetch the store handle that the `Db` handler placed in the conn's state.
    ///
    /// Returns `None` if no `Db` was installed upstream.
    async fn from_conn(conn: &mut Conn) -> Option<Self> {
        conn.state::<Self>().map(Self::clone)
    }
}
/// Identity of the caller, taken from the `x-user` request header.
///
/// NOTE(review): the header value is trusted as-is; nothing in this file authenticates it.
#[derive(Debug, Clone)]
struct User(String);
impl FromConn for User {
    /// Extract the caller's identity from the `x-user` request header.
    ///
    /// Returns `None` when the header is absent **or blank**. A blank value must not count
    /// as an identity: it would otherwise pass `require_user` and create todos owned by an
    /// empty string. Non-blank values are kept verbatim.
    async fn from_conn(conn: &mut Conn) -> Option<Self> {
        conn.request_headers()
            .get_str("x-user")
            .filter(|name| !name.trim().is_empty())
            .map(|name| User(name.to_owned()))
    }
}
impl TryFromConn for Todo {
    type Error = Status;

    /// Load the todo named by the `:todo_id` route parameter.
    ///
    /// # Errors
    ///
    /// * `500 Internal Server Error` if no [`Db`] is in the conn's state.
    /// * `400 Bad Request` if `todo_id` is missing or not a valid `u64`.
    /// * `404 Not Found` if no todo has that id.
    async fn try_from_conn(conn: &mut Conn) -> Result<Self, Status> {
        let Some(store) = Db::from_conn(conn).await else {
            return Err(Status::InternalServerError);
        };
        let todo_id = match conn.param("todo_id").map(str::parse::<u64>) {
            Some(Ok(parsed)) => parsed,
            _ => return Err(Status::BadRequest),
        };
        let todos = store.todos.read().unwrap();
        let Some(found) = todos.get(&todo_id) else {
            return Err(Status::NotFound);
        };
        // Returns an owned snapshot; the read lock is released when `todos` drops.
        Ok(found.clone())
    }
}
/// Gate every downstream route on the presence of a [`User`].
///
/// When no user could be extracted, returns a handler that responds `403 Forbidden` and
/// halts the conn; otherwise returns `None`, which lets the request continue.
async fn require_user(
    _conn: &mut Conn,
    user: Option<User>,
) -> Option<(Status, trillium_api::Halt)> {
    match user {
        Some(_) => None,
        None => Some((Status::Forbidden, trillium_api::Halt)),
    }
}
/// Application-level error returned by handlers and rendered as a JSON body.
///
/// `#[serde(tag = "error")]` serializes the (renamed) variant name as an `error` field next
/// to the variant's own fields, e.g. `{"error":"not_found","message":"..."}`.
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "error")]
enum AppError {
/// Rendered with status 404.
#[serde(rename = "not_found")]
NotFound { message: String },
/// Rendered with status 400.
#[serde(rename = "bad_request")]
BadRequest { message: String },
}
impl Handler for AppError {
    /// Write this error as the JSON response body, set the matching status, and halt.
    async fn run(&self, conn: Conn) -> Conn {
        // The body must be written first: `with_json` sets its own success status,
        // which the status chosen below then overrides.
        let conn = conn.with_json(self);
        let conn = match self {
            Self::NotFound { .. } => conn.with_status(Status::NotFound),
            Self::BadRequest { .. } => conn.with_status(Status::BadRequest),
        };
        conn.halt()
    }
}
/// `GET /`: liveness check whose response body is always `ok`.
///
/// NOTE(review): `app` mounts `api(require_user)` ahead of the router, so this route
/// also requires an `x-user` header.
async fn health(_conn: &mut Conn, _: ()) -> &'static str {
"ok"
}
/// `GET /todos`: return every stored todo, ordered by ascending id.
///
/// The store is a `HashMap`, whose iteration order is unspecified and differs between
/// processes. Sorting gives clients a stable, predictable order.
async fn list_todos(_conn: &mut Conn, db: Db) -> Json<Vec<Todo>> {
    // The read guard is a temporary of this statement, so the lock is released before sorting.
    let mut todos: Vec<Todo> = db.todos.read().unwrap().values().cloned().collect();
    todos.sort_unstable_by_key(|todo| todo.id);
    Json(todos)
}
async fn create_todo(
_conn: &mut Conn,
(User(owner), Body(new_todo), db): (User, Body<NewTodo>, Db),
) -> (Status, Json<Todo>) {
let id = db.next_id.fetch_add(1, Ordering::Relaxed);
let todo = Todo {
id,
title: new_todo.title,
completed: new_todo.completed.unwrap_or(false),
owner,
};
db.todos.write().unwrap().insert(id, todo.clone());
(Status::Created, Json(todo))
}
/// `GET /todos/:todo_id`: return the todo loaded by the `Todo` extractor.
///
/// A missing id (404) or an unparsable id (400) is rejected by the extractor,
/// `Todo::try_from_conn`.
async fn show_todo(_conn: &mut Conn, todo: Todo) -> Json<Todo> {
Json(todo)
}
async fn update_todo(
_conn: &mut Conn,
(mut todo, Body(update), db): (Todo, Body<UpdateTodo>, Db),
) -> Result<Json<Todo>, AppError> {
if let Some(title) = update.title {
if title.is_empty() {
return Err(AppError::BadRequest {
message: "title cannot be empty".into(),
});
}
todo.title = title;
}
if let Some(completed) = update.completed {
todo.completed = completed;
}
db.todos.write().unwrap().insert(todo.id, todo.clone());
Ok(Json(todo))
}
async fn delete_todo(_conn: &mut Conn, (todo, db): (Todo, Db)) -> Status {
db.todos.write().unwrap().remove(&todo.id);
Status::NoContent
}
/// `GET /todos/search?q=...`: return every todo whose title contains `q`.
///
/// The `q` value is decoded as an `application/x-www-form-urlencoded` component (`+` is a
/// space, `%XX` is a byte) before matching. As a result, `?q=buy%20milk` and `?q=buy+milk`
/// both find "buy milk". Matching is a case-sensitive substring test, and results are
/// sorted by id so the response order is stable.
///
/// # Errors
///
/// * [`AppError::BadRequest`] if the query string has no `q=` pair.
/// * [`AppError::NotFound`] if no title contains the decoded query.
async fn search_todos(conn: &mut Conn, db: Db) -> Result<Json<Vec<Todo>>, AppError> {
    let query = conn
        .querystring()
        .split('&')
        .find_map(|pair| pair.strip_prefix("q="))
        .map(decode_query_component)
        .ok_or_else(|| AppError::BadRequest {
            message: "missing query parameter `q`".into(),
        })?;
    // The read guard is a temporary of this statement, so the lock is released right after.
    let mut matches: Vec<Todo> = db
        .todos
        .read()
        .unwrap()
        .values()
        .filter(|todo| todo.title.contains(query.as_str()))
        .cloned()
        .collect();
    // `HashMap` iteration order is unspecified; sort so clients see a stable order.
    matches.sort_unstable_by_key(|todo| todo.id);
    if matches.is_empty() {
        Err(AppError::NotFound {
            message: format!("no todos matching \"{query}\""),
        })
    } else {
        Ok(Json(matches))
    }
}

/// Decode a single `application/x-www-form-urlencoded` query component.
///
/// `+` becomes a space, and each well-formed `%XX` escape becomes the byte `0xXX`. A `%`
/// that is not followed by two hex digits is kept literally. Decoded bytes that are not
/// valid UTF-8 are replaced with U+FFFD.
fn decode_query_component(raw: &str) -> String {
    // Explicit digit check; `u8::from_str_radix` would also accept a leading `+` sign.
    fn hex_digit(byte: u8) -> Option<u8> {
        char::from(byte)
            .to_digit(16)
            .and_then(|digit| u8::try_from(digit).ok())
    }

    let bytes = raw.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while let Some(&byte) = bytes.get(i) {
        match byte {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let high = bytes.get(i + 1).copied().and_then(hex_digit);
                let low = bytes.get(i + 2).copied().and_then(hex_digit);
                match high.zip(low) {
                    Some((high, low)) => {
                        decoded.push((high << 4) | low);
                        i += 3;
                    }
                    None => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
/// `GET /me`: echo the caller's `x-user` identity and the application name.
///
/// The response object has the keys `user` and `app`.
async fn me(_conn: &mut Conn, (user, state): (User, State<String>)) -> Json<Value> {
    let User(name) = user;
    let State(app_name) = state;
    let payload = trillium_api::json!({
        "user": name,
        "app": app_name,
    });
    Json(payload)
}
/// Handler that, in `before_send`, turns an [`AppError`] left in the conn's state into a
/// JSON error response with the matching status. It does nothing in `run`.
#[derive(Copy, Clone, Debug)]
struct CustomErrorHandler;
impl Handler for CustomErrorHandler {
async fn run(&self, conn: Conn) -> Conn {
conn
}
async fn before_send(&self, mut conn: Conn) -> Conn {
if let Some(error) = conn.take_state::<AppError>() {
let status = match &error {
AppError::NotFound { .. } => Status::NotFound,
AppError::BadRequest { .. } => Status::BadRequest,
};
conn.with_json(&error).with_status(status)
} else {
conn
}
}
}
/// Build the complete application handler.
///
/// The handlers run in tuple order:
/// 1. The shared [`Db`] and the app-name `State` are injected.
/// 2. `CustomErrorHandler` runs.
/// 3. `require_user` responds 403 and halts when no `x-user` header is present.
/// 4. The router dispatches the request.
fn app() -> impl Handler {
    let routes = router()
        .get("/", api(health))
        .get("/me", api(me))
        .get("/todos", api(list_todos))
        .get("/todos/search", api(search_todos))
        .post("/todos", api(create_todo))
        .get("/todos/:todo_id", api(show_todo))
        .patch("/todos/:todo_id", api(update_todo))
        .delete("/todos/:todo_id", api(delete_todo));

    (
        Db::default(),
        trillium::State::new(String::from("Todo App")),
        CustomErrorHandler,
        api(require_user),
        routes,
    )
}
/// Initialize logging from the environment, then serve the app with the smol runtime.
fn main() {
    env_logger::init();
    let handler = app();
    trillium_smol::run(handler);
}
// End-to-end tests that drive `app()` through trillium_testing's request builders.
// Each test builds its own `app()`, so every test starts with an empty store and
// ids that begin at 0.
#[cfg(test)]
mod tests {
use super::*;
use trillium_testing::prelude::*;
// An authenticated listing of a fresh store is an empty JSON array.
#[test]
fn test_list_empty() {
assert_ok!(
get("/todos")
.with_request_header("x-user", "alice")
.on(&app()),
"[]"
);
}
// Creating a todo responds 201 with the todo (including its owner), and the first
// todo is then retrievable at id 0.
#[test]
fn test_create_and_show() {
let app = app();
let mut response = post("/todos")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body(r#"{"title": "buy milk"}"#)
.on(&app);
assert_status!(&response, Status::Created);
let body = response.take_response_body_string().unwrap();
assert!(body.contains("buy milk"));
assert!(body.contains("alice"));
let mut show = get("/todos/0")
.with_request_header("x-user", "alice")
.on(&app);
assert_status!(&show, Status::Ok);
let show_body = show.take_response_body_string().unwrap();
assert!(show_body.contains("buy milk"));
}
// A PATCH with only `completed` updates that field and returns the todo.
#[test]
fn test_update() {
let app = app();
post("/todos")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body(r#"{"title": "buy milk"}"#)
.on(&app);
let mut response = patch("/todos/0")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body(r#"{"completed": true}"#)
.on(&app);
assert_status!(&response, Status::Ok);
let body = response.take_response_body_string().unwrap();
assert!(body.contains("true"));
}
// A PATCH that sets an empty title is rejected with 400.
#[test]
fn test_update_empty_title_returns_error() {
let app = app();
post("/todos")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body(r#"{"title": "buy milk"}"#)
.on(&app);
assert_status!(
patch("/todos/0")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body(r#"{"title": ""}"#)
.on(&app),
Status::BadRequest
);
}
// DELETE responds 204, and the todo is gone afterwards (404 on GET).
#[test]
fn test_delete() {
let app = app();
post("/todos")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body(r#"{"title": "buy milk"}"#)
.on(&app);
assert_status!(
delete("/todos/0")
.with_request_header("x-user", "alice")
.on(&app),
Status::NoContent
);
assert_status!(
get("/todos/0")
.with_request_header("x-user", "alice")
.on(&app),
Status::NotFound
);
}
// An id that was never created yields 404 from the `Todo` extractor.
#[test]
fn test_not_found() {
assert_status!(
get("/todos/999")
.with_request_header("x-user", "alice")
.on(&app()),
Status::NotFound
);
}
// Without an `x-user` header, `require_user` halts with 403.
#[test]
fn test_missing_auth_returns_forbidden() {
assert_status!(get("/todos").on(&app()), Status::Forbidden);
}
// A malformed JSON body is rejected by the `Body` extractor. The expected status and
// the "parse_error" marker come from trillium_api's error rendering, not this file.
#[test]
fn test_bad_json() {
let mut response = post("/todos")
.with_request_header("x-user", "alice")
.with_request_header("content-type", "application/json")
.with_request_body("not json")
.on(&app());
assert_status!(&response, Status::UnprocessableEntity);
let body = response.take_response_body_string().unwrap();
assert!(body.contains("parse_error"), "got: {body}");
}
// `/me` echoes the caller and the app name installed via `trillium::State`.
#[test]
fn test_me() {
let mut response = get("/me").with_request_header("x-user", "alice").on(&app());
assert_status!(&response, Status::Ok);
let body = response.take_response_body_string().unwrap();
assert!(body.contains("alice"));
assert!(body.contains("Todo App"));
}
}