#[cfg(feature = "openapi")]
mod openapi;
use std::collections::HashMap;
use std::sync::Arc;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::{header, StatusCode};
use axum::response::Response;
use axum::routing::get;
use axum::Router;
use serde_json::{json, Value};
use crate::core::{
Assignment, CountQuery, DeleteQuery, FieldType, Filter, InsertQuery, ModelSchema, Op,
OrderClause, SearchClause, SelectQuery, SqlValue, UpdateQuery, WhereExpr,
};
use crate::forms::{collect_values, parse_form_value, parse_pk_string, FormError};
use crate::sql::sqlx::{PgPool, Row as _};
/// Permission codenames required by each viewset action.
///
/// Every list holds alternatives: the permission check in this file lets a
/// request through when the list is empty or when the caller holds any one of
/// the listed codenames.
#[derive(Clone, Debug, Default)]
pub struct ViewSetPerms {
    /// Codenames accepted for the list endpoint.
    pub list: Vec<String>,
    /// Codenames accepted for the retrieve endpoint.
    pub retrieve: Vec<String>,
    /// Codenames accepted for the create endpoint.
    pub create: Vec<String>,
    /// Codenames accepted for the full and partial update endpoints.
    pub update: Vec<String>,
    /// Codenames accepted for the destroy endpoint.
    pub destroy: Vec<String>,
}
/// Pagination strategy used by the list endpoint.
#[derive(Clone, Debug)]
pub enum PaginationStyle {
/// Offset pagination driven by the `page` / `page_size` query parameters.
PageNumber,
/// Keyset pagination driven by the `cursor` query parameter.
Cursor {
/// Name of the model field the cursor walks over.
field: &'static str,
/// When `true`, rows are returned in descending order of `field`.
desc: bool,
},
}
impl PaginationStyle {
#[must_use]
pub const fn page_number() -> Self {
Self::PageNumber
}
#[must_use]
pub const fn cursor(field: &'static str) -> Self {
Self::Cursor { field, desc: false }
}
#[must_use]
pub const fn cursor_desc(field: &'static str) -> Self {
Self::Cursor { field, desc: true }
}
}
// Shared, thread-safe callback that turns one Postgres row into the JSON value sent to the client.
type RowRender = std::sync::Arc<dyn Fn(&crate::sql::sqlx::postgres::PgRow) -> Value + Send + Sync>;
/// Builder-style description of a CRUD REST resource backed by one model schema.
#[derive(Clone)]
pub struct ViewSet {
/// Schema of the model the endpoints operate on.
schema: &'static ModelSchema,
/// Field names to serialize; `None` means every scalar field of the schema.
fields: Option<Vec<String>>,
/// Field names clients may filter on through query parameters.
filter_fields: Vec<String>,
/// Field names consulted by the `search` query parameter.
search_fields: Vec<String>,
/// Page size used when the request carries no usable `page_size`.
default_page_size: usize,
/// `(field name, descending)` pairs applied when the request carries no `ordering`.
default_ordering: Vec<(String, bool)>,
/// Permission codenames required per action.
perms: ViewSetPerms,
/// When set, only the GET routes are registered.
read_only: bool,
/// Pagination strategy of the list endpoint.
pagination: PaginationStyle,
/// Optional custom row renderer; when absent rows go through `row_to_json`.
row_render: Option<RowRender>,
}
impl ViewSet {
pub fn for_model(schema: &'static ModelSchema) -> Self {
Self {
schema,
fields: None,
filter_fields: Vec::new(),
search_fields: Vec::new(),
default_page_size: 20,
default_ordering: Vec::new(),
perms: ViewSetPerms::default(),
read_only: false,
pagination: PaginationStyle::PageNumber,
row_render: None,
}
}
#[must_use]
pub fn serializer<S>(mut self) -> Self
where
S: crate::serializer::ModelSerializer + 'static,
S::Model:
for<'r> crate::sql::sqlx::FromRow<'r, crate::sql::sqlx::postgres::PgRow> + Send + Unpin,
{
let render: RowRender = std::sync::Arc::new(|row| {
match <S::Model as crate::sql::sqlx::FromRow<_>>::from_row(row) {
Ok(model) => {
let s = S::from_model(&model);
serde_json::to_value(&s).unwrap_or(Value::Null)
}
Err(_) => Value::Null,
}
});
self.row_render = Some(render);
self
}
#[must_use]
pub fn cursor_pagination(mut self, field: &'static str) -> Self {
self.pagination = PaginationStyle::Cursor { field, desc: false };
self
}
#[must_use]
pub fn cursor_pagination_desc(mut self, field: &'static str) -> Self {
self.pagination = PaginationStyle::Cursor { field, desc: true };
self
}
#[must_use]
pub fn pagination(mut self, style: PaginationStyle) -> Self {
self.pagination = style;
self
}
pub fn fields(mut self, fields: &[&str]) -> Self {
self.fields = Some(fields.iter().map(|&s| s.to_owned()).collect());
self
}
pub fn filter_fields(mut self, fields: &[&str]) -> Self {
self.filter_fields = fields.iter().map(|&s| s.to_owned()).collect();
self
}
pub fn search_fields(mut self, fields: &[&str]) -> Self {
self.search_fields = fields.iter().map(|&s| s.to_owned()).collect();
self
}
pub fn page_size(mut self, n: usize) -> Self {
self.default_page_size = n.min(1000);
self
}
pub fn ordering(mut self, ordering: &[(&str, bool)]) -> Self {
self.default_ordering = ordering.iter().map(|&(f, d)| (f.to_owned(), d)).collect();
self
}
pub fn permissions(mut self, perms: ViewSetPerms) -> Self {
self.perms = perms;
self
}
#[cfg(feature = "tenancy")]
pub fn permissions_for_model<T: crate::core::Model>(mut self) -> Self {
let cn = |action: &str| crate::permissions::codename_for::<T>(action);
self.perms = ViewSetPerms {
list: vec![cn("view")],
retrieve: vec![cn("view")],
create: vec![cn("add")],
update: vec![cn("change")],
destroy: vec![cn("delete")],
};
self
}
pub fn read_only(mut self) -> Self {
self.read_only = true;
self
}
pub fn router(self, prefix: &str, pool: PgPool) -> Router {
Self::router_with_source(self, prefix, PoolSource::Static(pool))
}
#[cfg(feature = "tenancy")]
#[must_use]
pub fn tenant_router(self, prefix: &str) -> Router {
Self::router_with_source(self, prefix, PoolSource::Tenant)
}
fn router_with_source(self, prefix: &str, pool_source: PoolSource) -> Router {
let state = Arc::new(ViewSetState {
pool_source,
vs: self.clone(),
});
let prefix = prefix.trim_end_matches('/').to_owned();
let collection = prefix.clone();
let item = format!("{prefix}/{{pk}}");
let collection_route = if self.read_only {
get(handle_list)
} else {
get(handle_list).post(handle_create)
};
let item_route = if self.read_only {
axum::routing::MethodRouter::new().get(handle_retrieve)
} else {
axum::routing::MethodRouter::new()
.get(handle_retrieve)
.put(handle_update)
.patch(handle_partial_update)
.delete(handle_destroy)
};
Router::new()
.route(&collection, collection_route)
.route(&item, item_route)
.with_state(state)
}
}
/// Where a request's database connection comes from.
#[derive(Clone)]
enum PoolSource {
/// A fixed pool handed to `ViewSet::router`.
Static(PgPool),
/// Resolved per request through the tenant extractor (see `ViewSetState::acquire`).
#[cfg(feature = "tenancy")]
Tenant,
}
/// Router state shared by all handlers of one viewset.
#[derive(Clone)]
struct ViewSetState {
/// How handlers obtain their database connection.
pool_source: PoolSource,
/// The viewset configuration the handlers consult.
vs: ViewSet,
}
/// Connection handle obtained for a single request.
enum AcquiredConn {
/// Clone of the statically configured pool.
Static(PgPool),
/// Tenant extracted from the request; boxed, queries go through `t.conn()`.
#[cfg(feature = "tenancy")]
Tenant(Box<crate::extractors::Tenant>),
}
/// Query helpers that forward to the matching `crate::sql::*_on` function,
/// using either the static pool or the tenant's connection.
impl AcquiredConn {
/// Forwards `q` to `crate::sql::select_rows_on`.
async fn select_rows(
&mut self,
q: &SelectQuery,
) -> Result<Vec<crate::sql::sqlx::postgres::PgRow>, crate::sql::ExecError> {
match self {
Self::Static(pool) => crate::sql::select_rows_on(&*pool, q).await,
#[cfg(feature = "tenancy")]
Self::Tenant(t) => crate::sql::select_rows_on(t.conn(), q).await,
}
}
/// Forwards `q` to `crate::sql::count_rows_on`.
async fn count_rows(&mut self, q: &CountQuery) -> Result<i64, crate::sql::ExecError> {
match self {
Self::Static(pool) => crate::sql::count_rows_on(&*pool, q).await,
#[cfg(feature = "tenancy")]
Self::Tenant(t) => crate::sql::count_rows_on(t.conn(), q).await,
}
}
/// Forwards `q` to `crate::sql::select_one_row_on`; the result is an optional row.
async fn select_one_row(
&mut self,
q: &SelectQuery,
) -> Result<Option<crate::sql::sqlx::postgres::PgRow>, crate::sql::ExecError> {
match self {
Self::Static(pool) => crate::sql::select_one_row_on(&*pool, q).await,
#[cfg(feature = "tenancy")]
Self::Tenant(t) => crate::sql::select_one_row_on(t.conn(), q).await,
}
}
/// Forwards `q` to `crate::sql::insert_returning_on`, yielding the returned row.
async fn insert_returning(
&mut self,
q: &InsertQuery,
) -> Result<crate::sql::sqlx::postgres::PgRow, crate::sql::ExecError> {
match self {
Self::Static(pool) => crate::sql::insert_returning_on(&*pool, q).await,
#[cfg(feature = "tenancy")]
Self::Tenant(t) => crate::sql::insert_returning_on(t.conn(), q).await,
}
}
/// Forwards `q` to `crate::sql::update_on`.
// NOTE(review): the u64 is presumably the affected-row count — confirm against `crate::sql`.
async fn update(&mut self, q: &UpdateQuery) -> Result<u64, crate::sql::ExecError> {
match self {
Self::Static(pool) => crate::sql::update_on(&*pool, q).await,
#[cfg(feature = "tenancy")]
Self::Tenant(t) => crate::sql::update_on(t.conn(), q).await,
}
}
/// Forwards `q` to `crate::sql::delete_on`; `handle_destroy` treats `Ok(0)` as "not found".
async fn delete(&mut self, q: &DeleteQuery) -> Result<u64, crate::sql::ExecError> {
match self {
Self::Static(pool) => crate::sql::delete_on(&*pool, q).await,
#[cfg(feature = "tenancy")]
Self::Tenant(t) => crate::sql::delete_on(t.conn(), q).await,
}
}
/// Asks `crate::tenancy::permissions::has_perm_on` whether user `uid` holds `codename`.
#[cfg(feature = "tenancy")]
async fn has_perm(
&mut self,
uid: i64,
codename: &str,
) -> Result<bool, crate::sql::sqlx::Error> {
match self {
Self::Static(pool) => {
crate::tenancy::permissions::has_perm_on(uid, codename, &*pool).await
}
Self::Tenant(t) => {
crate::tenancy::permissions::has_perm_on(uid, codename, t.conn()).await
}
}
}
}
impl ViewSetState {
/// Fields to serialize: the configured names that resolve on the schema
/// (names the schema does not know are dropped), or every scalar field when
/// no explicit list was configured.
fn effective_fields(&self) -> Vec<&'static crate::core::FieldSchema> {
let schema = self.vs.schema;
match &self.vs.fields {
Some(names) => names.iter().filter_map(|n| schema.field(n)).collect(),
None => schema.scalar_fields().collect(),
}
}
/// Obtains the connection for this request.
///
/// A static source yields a clone of the pool. A tenant source runs the
/// `Tenant` extractor on `parts`; a rejection is converted into the
/// response returned as `Err`.
async fn acquire(
&self,
parts: &mut axum::http::request::Parts,
) -> Result<AcquiredConn, Response> {
match &self.pool_source {
PoolSource::Static(pool) => Ok(AcquiredConn::Static(pool.clone())),
#[cfg(feature = "tenancy")]
PoolSource::Tenant => {
use axum::extract::FromRequestParts as _;
use axum::response::IntoResponse as _;
crate::extractors::Tenant::from_request_parts(parts, &())
.await
.map(|t| AcquiredConn::Tenant(Box::new(t)))
.map_err(|e| e.into_response())
}
}
}
/// Decides whether the request may perform an action guarded by `codenames`.
///
/// An empty list always allows. With the `tenancy` feature the request
/// must carry an `AuthenticatedUser` extension; superusers pass, anyone
/// else needs at least one of the codenames. Without the feature a
/// non-empty list always denies.
async fn check_perm(
&self,
codenames: &[String],
parts: &axum::http::request::Parts,
conn: &mut AcquiredConn,
) -> bool {
if codenames.is_empty() {
return true;
}
#[cfg(feature = "tenancy")]
{
let Some(auth) = parts
.extensions
.get::<crate::tenancy::middleware::AuthenticatedUser>()
else {
return false;
};
if auth.is_superuser {
return true;
}
for cn in codenames {
// A lookup error counts as "not granted", so failures deny rather than allow.
if let Ok(true) = conn.has_perm(auth.id, cn).await {
return true;
}
}
false
}
#[cfg(not(feature = "tenancy"))]
{
// Nothing can verify the codenames in this build; the bindings are
// consumed only to avoid unused-variable warnings.
let _ = (parts, conn);
false
}
}
/// Primary-key field of the underlying schema, if it declares one.
fn pk_field(&self) -> Option<&'static crate::core::FieldSchema> {
self.vs.schema.primary_key()
}
}
pub(crate) use crate::sql::row_to_json;
fn json_response(body: Value) -> Response {
Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body.to_string()))
.unwrap()
}
fn json_error(status: StatusCode, msg: &str) -> Response {
Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(json!({"error": msg}).to_string()))
.unwrap()
}
fn json_created(body: Value) -> Response {
Response::builder()
.status(StatusCode::CREATED)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body.to_string()))
.unwrap()
}
/// Builds an empty `204 No Content` response.
fn no_content() -> Response {
    let mut response = Response::new(Body::empty());
    *response.status_mut() = StatusCode::NO_CONTENT;
    response
}
fn build_lookup_filter(
field: &'static crate::core::FieldSchema,
lookup: Option<&str>,
raw: &str,
) -> Option<WhereExpr> {
let column = field.column;
let predicate =
|op: Op, value: SqlValue| Some(WhereExpr::Predicate(Filter { column, op, value }));
match lookup.unwrap_or("exact") {
"exact" => parse_form_value(field, Some(raw))
.ok()
.and_then(|v| predicate(Op::Eq, v)),
"ne" => parse_form_value(field, Some(raw))
.ok()
.and_then(|v| predicate(Op::Ne, v)),
"gt" => parse_form_value(field, Some(raw))
.ok()
.and_then(|v| predicate(Op::Gt, v)),
"gte" => parse_form_value(field, Some(raw))
.ok()
.and_then(|v| predicate(Op::Gte, v)),
"lt" => parse_form_value(field, Some(raw))
.ok()
.and_then(|v| predicate(Op::Lt, v)),
"lte" => parse_form_value(field, Some(raw))
.ok()
.and_then(|v| predicate(Op::Lte, v)),
"in" | "not_in" => {
let parts: Vec<SqlValue> = raw
.split(',')
.map(str::trim)
.filter(|s| !s.is_empty())
.filter_map(|s| parse_form_value(field, Some(s)).ok())
.collect();
if parts.is_empty() {
return None;
}
let op = if lookup == Some("not_in") {
Op::NotIn
} else {
Op::In
};
predicate(op, SqlValue::List(parts))
}
"contains" => predicate(Op::Like, SqlValue::String(format!("%{raw}%"))),
"icontains" => predicate(Op::ILike, SqlValue::String(format!("%{raw}%"))),
"startswith" => predicate(Op::Like, SqlValue::String(format!("{raw}%"))),
"istartswith" => predicate(Op::ILike, SqlValue::String(format!("{raw}%"))),
"endswith" => predicate(Op::Like, SqlValue::String(format!("%{raw}"))),
"iendswith" => predicate(Op::ILike, SqlValue::String(format!("%{raw}"))),
"isnull" => {
let is_null = matches!(raw.to_ascii_lowercase().as_str(), "true" | "1" | "yes");
predicate(Op::IsNull, SqlValue::Bool(is_null))
}
_ => None, }
}
/// `GET <prefix>` — lists rows with filtering, search, ordering and pagination.
///
/// Query parameters:
/// - `page_size`: rows per page, clamped to `1..=1000`; falls back to the
///   viewset default when absent or unparsable.
/// - `<field>` / `<field>__<lookup>`: filters, honoured only for names listed
///   in `filter_fields`; parameters that produce no predicate are ignored.
/// - `search`: non-empty text matched against the configured search fields.
/// - `ordering`: comma-separated field names, `-` prefix for descending;
///   names the schema does not know are dropped. Without the parameter the
///   viewset's default ordering applies.
/// - `page` (page-number style) or `cursor` (cursor style).
///
/// Responds `403` when the `list` permission check fails and `500` with the
/// error text when a query fails.
async fn handle_list(
    State(state): State<Arc<ViewSetState>>,
    Query(params): Query<HashMap<String, String>>,
    req: axum::extract::Request,
) -> Response {
    let (mut parts, _) = req.into_parts();
    let mut acq = match state.acquire(&mut parts).await {
        Ok(a) => a,
        Err(resp) => return resp,
    };
    if !state
        .check_perm(&state.vs.perms.list, &parts, &mut acq)
        .await
    {
        return json_error(StatusCode::FORBIDDEN, "permission denied");
    }

    let page_size: i64 = params
        .get("page_size")
        .and_then(|p| p.parse().ok())
        .unwrap_or(state.vs.default_page_size as i64)
        .clamp(1, 1000);

    // Every query parameter that is not a reserved name is a filter candidate.
    let mut filters: Vec<WhereExpr> = Vec::new();
    for (param_key, raw_val) in &params {
        if matches!(
            param_key.as_str(),
            "page" | "page_size" | "ordering" | "search" | "cursor"
        ) {
            continue;
        }
        let (field_name, lookup) = match param_key.split_once("__") {
            Some((name, lk)) => (name, Some(lk)),
            None => (param_key.as_str(), None),
        };
        // Only fields explicitly opted in may be filtered on.
        if !state.vs.filter_fields.iter().any(|f| f == field_name) {
            continue;
        }
        let Some(field) = state.vs.schema.field(field_name) else {
            continue;
        };
        if let Some(predicate) = build_lookup_filter(field, lookup, raw_val) {
            filters.push(predicate);
        }
    }
    // A single predicate is used bare; zero or several are wrapped in `And`.
    let where_clause = if filters.len() == 1 {
        filters.remove(0)
    } else {
        WhereExpr::And(filters)
    };

    let search = params.get("search").filter(|s| !s.is_empty()).cloned();
    let search_clause = search.map(|q| SearchClause {
        query: q,
        columns: state
            .vs
            .search_fields
            .iter()
            .filter_map(|n| state.vs.schema.field(n).map(|f| f.column))
            .collect(),
    });

    let order_by: Vec<OrderClause> = match params.get("ordering") {
        Some(raw) => raw
            .split(',')
            .filter(|s| !s.is_empty())
            .filter_map(|part| {
                let (field_name, desc) = match part.strip_prefix('-') {
                    Some(name) => (name, true),
                    None => (part, false),
                };
                state.vs.schema.field(field_name).map(|f| OrderClause {
                    column: f.column,
                    desc,
                })
            })
            .collect(),
        None => state
            .vs
            .default_ordering
            .iter()
            .filter_map(|(name, desc)| {
                state.vs.schema.field(name).map(|f| OrderClause {
                    column: f.column,
                    desc: *desc,
                })
            })
            .collect(),
    };

    let fields = state.effective_fields();
    match &state.vs.pagination {
        PaginationStyle::PageNumber => {
            let page: i64 = params
                .get("page")
                .and_then(|p| p.parse().ok())
                .unwrap_or(1)
                .max(1);
            // `page` is client-controlled: a huge value must not overflow the
            // multiplication (panic in debug builds, negative OFFSET in release).
            let offset = (page - 1).saturating_mul(page_size);
            let select_q = SelectQuery {
                model: state.vs.schema,
                where_clause: where_clause.clone(),
                search: search_clause.clone(),
                joins: vec![],
                order_by,
                limit: Some(page_size),
                offset: Some(offset),
            };
            let count_q = CountQuery {
                model: state.vs.schema,
                where_clause,
                search: search_clause,
            };
            let rows = match acq.select_rows(&select_q).await {
                Ok(r) => r,
                Err(e) => return json_error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
            };
            let count = match acq.count_rows(&count_q).await {
                Ok(c) => c,
                Err(e) => return json_error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
            };
            let results: Vec<Value> = match &state.vs.row_render {
                Some(render) => rows.iter().map(|r| (render)(r)).collect(),
                None => rows.iter().map(|row| row_to_json(row, &fields)).collect(),
            };
            // Ceiling division that still reports page 1 for an empty result set.
            let last_page = ((count - 1).max(0) / page_size) + 1;
            json_response(json!({
                "count": count,
                "page": page,
                "page_size": page_size,
                "last_page": last_page,
                "results": results,
            }))
        }
        PaginationStyle::Cursor {
            field: cursor_field,
            desc,
        } => {
            // Cursor pagination orders by the cursor field itself, so the
            // `ordering` computed above is not forwarded.
            handle_list_cursor(
                state.as_ref(),
                &mut acq,
                params,
                where_clause,
                search_clause,
                fields,
                page_size,
                cursor_field,
                *desc,
            )
            .await
        }
    }
}
/// Cursor (keyset) variant of the list endpoint.
///
/// Rows are ordered by `cursor_field` (descending when `desc`) and restricted
/// to those strictly past the value carried by the `cursor` query parameter.
/// One row beyond `page_size` is fetched to learn whether another page exists;
/// when it does, `next` holds the encoded cursor value of the last row
/// returned, otherwise `null`.
///
/// Responds `400` for an undecodable cursor and `500` when the cursor field
/// is missing from the model, is not an i16/i32/i64 column, the query fails,
/// or the cursor column cannot be read back from the last row.
async fn handle_list_cursor(
    state: &ViewSetState,
    acq: &mut AcquiredConn,
    params: HashMap<String, String>,
    where_clause: WhereExpr,
    search_clause: Option<SearchClause>,
    fields: Vec<&'static crate::core::FieldSchema>,
    page_size: i64,
    cursor_field: &str,
    desc: bool,
) -> Response {
    let Some(cursor_schema) = state.vs.schema.field(cursor_field) else {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            &format!("cursor field `{cursor_field}` not found on model"),
        );
    };
    if !matches!(
        cursor_schema.ty,
        FieldType::I16 | FieldType::I32 | FieldType::I64
    ) {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "cursor pagination requires an integer field (i16/i32/i64)",
        );
    }

    let cursor_val: Option<i64> = match params.get("cursor") {
        Some(c) if !c.is_empty() => match decode_cursor(c) {
            Some(v) => Some(v),
            None => return json_error(StatusCode::BAD_REQUEST, "invalid cursor"),
        },
        _ => None,
    };

    // Append "strictly past the cursor" to whatever filters the caller built.
    let final_where = match cursor_val {
        Some(after) => {
            let op = if desc { Op::Lt } else { Op::Gt };
            let cursor_pred = WhereExpr::Predicate(Filter {
                column: cursor_schema.column,
                op,
                value: SqlValue::I64(after),
            });
            match where_clause {
                WhereExpr::And(existing) if existing.is_empty() => cursor_pred,
                WhereExpr::And(mut existing) => {
                    existing.push(cursor_pred);
                    WhereExpr::And(existing)
                }
                other => WhereExpr::And(vec![other, cursor_pred]),
            }
        }
        None => where_clause,
    };

    let select_q = SelectQuery {
        model: state.vs.schema,
        where_clause: final_where,
        search: search_clause,
        joins: vec![],
        order_by: vec![OrderClause {
            column: cursor_schema.column,
            desc,
        }],
        // One extra row tells us whether a further page exists.
        limit: Some(page_size.saturating_add(1)),
        offset: None,
    };
    let rows = match acq.select_rows(&select_q).await {
        Ok(r) => r,
        Err(e) => return json_error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
    };

    // A non-positive `page_size` maps to an empty page instead of a slice panic.
    let page_len = usize::try_from(page_size).unwrap_or(0);
    let has_more = rows.len() > page_len;
    let page_rows = &rows[..rows.len().min(page_len)];

    let next_cursor = if has_more {
        let Some(last) = page_rows.last() else {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "cursor page is unexpectedly empty",
            );
        };
        let column = cursor_schema.column;
        let decoded = match cursor_schema.ty {
            FieldType::I16 => last.try_get::<i16, _>(column).map(i64::from),
            FieldType::I32 => last.try_get::<i32, _>(column).map(i64::from),
            FieldType::I64 => last.try_get::<i64, _>(column),
            _ => {
                return json_error(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    "cursor pagination requires an integer field (i16/i32/i64)",
                )
            }
        };
        match decoded {
            Ok(val) => Some(encode_cursor(val)),
            // Falling back to a made-up cursor (e.g. 0) would silently send
            // the client back to the start of the collection.
            Err(e) => {
                return json_error(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    &format!("could not read cursor column `{column}`: {e}"),
                )
            }
        }
    } else {
        None
    };

    let results: Vec<Value> = match &state.vs.row_render {
        Some(render) => page_rows.iter().map(|r| (render)(r)).collect(),
        None => page_rows
            .iter()
            .map(|row| row_to_json(row, &fields))
            .collect(),
    };
    json_response(json!({
        "page_size": page_size,
        "next": next_cursor,
        "results": results,
    }))
}
/// Encodes a cursor position as URL-safe, unpadded base64 of its decimal text.
fn encode_cursor(value: i64) -> String {
    use base64::Engine as _;
    let digits = value.to_string();
    base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digits)
}
/// Decodes a cursor token produced by `encode_cursor`.
///
/// Returns `None` when the token is not valid URL-safe unpadded base64, the
/// payload is not UTF-8, or the text does not parse as an `i64`.
fn decode_cursor(token: &str) -> Option<i64> {
    use base64::Engine as _;
    let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let Ok(bytes) = engine.decode(token) else {
        return None;
    };
    let text = String::from_utf8(bytes).ok()?;
    text.parse().ok()
}
/// `GET <prefix>/{pk}` — returns the single row whose primary key is `pk`.
///
/// Responds `403` on a failed `retrieve` permission check, `400` for an
/// unparsable key, `404` when no row matches and `500` when the model has no
/// primary key or the query fails.
async fn handle_retrieve(
    State(state): State<Arc<ViewSetState>>,
    Path(pk_raw): Path<String>,
    req: axum::extract::Request,
) -> Response {
    let (mut parts, _) = req.into_parts();
    let mut conn = match state.acquire(&mut parts).await {
        Ok(conn) => conn,
        Err(rejection) => return rejection,
    };
    let allowed = state
        .check_perm(&state.vs.perms.retrieve, &parts, &mut conn)
        .await;
    if !allowed {
        return json_error(StatusCode::FORBIDDEN, "permission denied");
    }

    let pk_field = match state.pk_field() {
        Some(field) => field,
        None => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "model has no primary key",
            )
        }
    };
    let pk_val = match parse_pk_string(pk_field, &pk_raw) {
        Ok(value) => value,
        Err(err) => return json_error(StatusCode::BAD_REQUEST, &err.to_string()),
    };

    let by_pk = WhereExpr::Predicate(Filter {
        column: pk_field.column,
        op: Op::Eq,
        value: pk_val,
    });
    let select_q = SelectQuery {
        model: state.vs.schema,
        where_clause: by_pk,
        search: None,
        joins: Vec::new(),
        order_by: Vec::new(),
        limit: Some(1),
        offset: None,
    };
    let fields = state.effective_fields();

    let row = match conn.select_one_row(&select_q).await {
        Ok(Some(row)) => row,
        Ok(None) => return json_error(StatusCode::NOT_FOUND, "not found"),
        Err(err) => return json_error(StatusCode::INTERNAL_SERVER_ERROR, &err.to_string()),
    };
    let payload = match &state.vs.row_render {
        Some(render) => render(&row),
        None => row_to_json(&row, &fields),
    };
    json_response(payload)
}
/// `POST <prefix>` — inserts a row built from the request body and returns it.
///
/// Primary-key and auto fields are never taken from the body. After the
/// insert the row is re-read by its returned primary key and sent back with
/// `201 Created`.
///
/// Responds `403` on a failed `create` permission check, `400` when the body
/// or its values are invalid or the insert fails, and `500` when the model
/// has no primary key, the key type is not i16/i32/i64, the returned key
/// cannot be read, or the created row cannot be fetched.
async fn handle_create(
    State(state): State<Arc<ViewSetState>>,
    req: axum::extract::Request,
) -> Response {
    let (mut parts, body) = req.into_parts();
    let mut acq = match state.acquire(&mut parts).await {
        Ok(a) => a,
        Err(resp) => return resp,
    };
    if !state
        .check_perm(&state.vs.perms.create, &parts, &mut acq)
        .await
    {
        return json_error(StatusCode::FORBIDDEN, "permission denied");
    }
    let form = match extract_form_body(parts, body).await {
        Ok(f) => f,
        Err(e) => return json_error(StatusCode::BAD_REQUEST, &e),
    };

    // Server-managed columns are excluded from what the client may set.
    let skip: Vec<&str> = state
        .vs
        .schema
        .scalar_fields()
        .filter(|f| f.primary_key || f.auto)
        .map(|f| f.name)
        .collect();
    let collected = match collect_values(state.vs.schema, &form, &skip) {
        Ok(v) => v,
        Err(e) => return json_error(StatusCode::BAD_REQUEST, &e.to_string()),
    };
    let (columns, values): (Vec<_>, Vec<_>) = collected.into_iter().unzip();

    let Some(pk_field) = state.pk_field() else {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "model has no primary key",
        );
    };
    // Reject unsupported key types before writing anything: failing after the
    // INSERT would report an error for a row that was in fact created.
    if !matches!(
        pk_field.ty,
        FieldType::I16 | FieldType::I32 | FieldType::I64
    ) {
        return json_error(StatusCode::INTERNAL_SERVER_ERROR, "unsupported PK type");
    }

    let query = InsertQuery {
        model: state.vs.schema,
        columns,
        values,
        returning: vec![pk_field.column],
        on_conflict: None,
    };
    let row = match acq.insert_returning(&query).await {
        Ok(r) => r,
        Err(e) => return json_error(StatusCode::BAD_REQUEST, &e.to_string()),
    };

    let pk_read = match pk_field.ty {
        FieldType::I64 => row.try_get(pk_field.column).map(SqlValue::I64),
        FieldType::I32 => row.try_get(pk_field.column).map(SqlValue::I32),
        FieldType::I16 => row.try_get(pk_field.column).map(SqlValue::I16),
        _ => return json_error(StatusCode::INTERNAL_SERVER_ERROR, "unsupported PK type"),
    };
    let pk_val = match pk_read {
        Ok(v) => v,
        // Substituting a default key here could return some unrelated row
        // (the one with PK 0) as if it were the created object.
        Err(e) => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                &format!("created but could not read primary key: {e}"),
            )
        }
    };

    let fields = state.effective_fields();
    match fetch_by_pk(&state, &mut acq, pk_field, pk_val, &fields).await {
        Some(obj) => json_created(obj),
        None => json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "created but could not retrieve",
        ),
    }
}
/// `PUT <prefix>/{pk}` — full update; delegates to `update_inner` with `partial = false`.
async fn handle_update(
State(state): State<Arc<ViewSetState>>,
Path(pk_raw): Path<String>,
req: axum::extract::Request,
) -> Response {
update_inner(state, pk_raw, req, false).await
}
/// `PATCH <prefix>/{pk}` — partial update; delegates to `update_inner` with `partial = true`.
async fn handle_partial_update(
State(state): State<Arc<ViewSetState>>,
Path(pk_raw): Path<String>,
req: axum::extract::Request,
) -> Response {
update_inner(state, pk_raw, req, true).await
}
/// Shared implementation of the PUT and PATCH handlers.
///
/// Builds one assignment per writable scalar field (primary-key and auto
/// fields are skipped). With `partial` set, fields absent from the body are
/// left untouched; otherwise every writable field is parsed, so a missing
/// required value is reported as an error. The updated row is re-read and
/// returned.
///
/// Responds `403` on a failed `update` permission check, `400` for a bad key,
/// bad body, bad value, empty assignment list or failed update, `404` when
/// the row cannot be read back, and `500` when the model has no primary key.
async fn update_inner(
    state: Arc<ViewSetState>,
    pk_raw: String,
    req: axum::extract::Request,
    partial: bool,
) -> Response {
    let (mut parts, body) = req.into_parts();
    let mut conn = match state.acquire(&mut parts).await {
        Ok(conn) => conn,
        Err(rejection) => return rejection,
    };
    let permitted = state
        .check_perm(&state.vs.perms.update, &parts, &mut conn)
        .await;
    if !permitted {
        return json_error(StatusCode::FORBIDDEN, "permission denied");
    }

    let pk_field = match state.pk_field() {
        Some(field) => field,
        None => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "model has no primary key",
            )
        }
    };
    let pk_val = match parse_pk_string(pk_field, &pk_raw) {
        Ok(value) => value,
        Err(err) => return json_error(StatusCode::BAD_REQUEST, &err.to_string()),
    };
    let form = match extract_form_body(parts, body).await {
        Ok(form) => form,
        Err(err) => return json_error(StatusCode::BAD_REQUEST, &err),
    };

    let mut assignments: Vec<Assignment> = Vec::new();
    let writable = state
        .vs
        .schema
        .scalar_fields()
        .filter(|f| !(f.primary_key || f.auto));
    for field in writable {
        let raw = form.get(field.name).map(String::as_str);
        // PATCH leaves fields the client did not send alone.
        if partial && raw.is_none() {
            continue;
        }
        match parse_form_value(field, raw) {
            Ok(value) => assignments.push(Assignment {
                column: field.column,
                value,
            }),
            Err(FormError::Missing { .. }) if partial => {}
            Err(err) => return json_error(StatusCode::BAD_REQUEST, &err.to_string()),
        }
    }
    if assignments.is_empty() {
        return json_error(StatusCode::BAD_REQUEST, "no fields to update");
    }

    let by_pk = WhereExpr::Predicate(Filter {
        column: pk_field.column,
        op: Op::Eq,
        value: pk_val.clone(),
    });
    let query = UpdateQuery {
        model: state.vs.schema,
        set: assignments,
        where_clause: by_pk,
    };
    if let Err(err) = conn.update(&query).await {
        return json_error(StatusCode::BAD_REQUEST, &err.to_string());
    }

    let fields = state.effective_fields();
    let refreshed = fetch_by_pk(&state, &mut conn, pk_field, pk_val, &fields).await;
    match refreshed {
        Some(obj) => json_response(obj),
        None => json_error(StatusCode::NOT_FOUND, "not found after update"),
    }
}
/// `DELETE <prefix>/{pk}` — deletes the row whose primary key is `pk`.
///
/// Responds `204` on success, `403` on a failed `destroy` permission check,
/// `400` for an unparsable key, `404` when the delete reports zero rows and
/// `500` when the model has no primary key or the delete fails.
async fn handle_destroy(
    State(state): State<Arc<ViewSetState>>,
    Path(pk_raw): Path<String>,
    req: axum::extract::Request,
) -> Response {
    let (mut parts, _) = req.into_parts();
    let mut conn = match state.acquire(&mut parts).await {
        Ok(conn) => conn,
        Err(rejection) => return rejection,
    };
    let permitted = state
        .check_perm(&state.vs.perms.destroy, &parts, &mut conn)
        .await;
    if !permitted {
        return json_error(StatusCode::FORBIDDEN, "permission denied");
    }

    let pk_field = match state.pk_field() {
        Some(field) => field,
        None => {
            return json_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                "model has no primary key",
            )
        }
    };
    let pk_val = match parse_pk_string(pk_field, &pk_raw) {
        Ok(value) => value,
        Err(err) => return json_error(StatusCode::BAD_REQUEST, &err.to_string()),
    };

    let by_pk = WhereExpr::Predicate(Filter {
        column: pk_field.column,
        op: Op::Eq,
        value: pk_val,
    });
    let query = DeleteQuery {
        model: state.vs.schema,
        where_clause: by_pk,
    };
    let outcome = conn.delete(&query).await;
    match outcome {
        Err(err) => json_error(StatusCode::INTERNAL_SERVER_ERROR, &err.to_string()),
        Ok(0) => json_error(StatusCode::NOT_FOUND, "not found"),
        Ok(_) => no_content(),
    }
}
/// Loads the row whose primary key equals `pk_val` and renders it as JSON,
/// using the viewset's custom renderer when one is configured.
///
/// Returns `None` both when no row matches and when the query fails.
async fn fetch_by_pk(
    state: &ViewSetState,
    acq: &mut AcquiredConn,
    pk_field: &'static crate::core::FieldSchema,
    pk_val: SqlValue,
    fields: &[&'static crate::core::FieldSchema],
) -> Option<Value> {
    let by_pk = WhereExpr::Predicate(Filter {
        column: pk_field.column,
        op: Op::Eq,
        value: pk_val,
    });
    let select_q = SelectQuery {
        model: state.vs.schema,
        where_clause: by_pk,
        search: None,
        joins: Vec::new(),
        order_by: Vec::new(),
        limit: Some(1),
        offset: None,
    };
    let row = match acq.select_one_row(&select_q).await {
        Ok(Some(row)) => row,
        Ok(None) | Err(_) => return None,
    };
    let rendered = match &state.vs.row_render {
        Some(render) => render(&row),
        None => row_to_json(&row, fields),
    };
    Some(rendered)
}
/// Reads the request body (at most 4 MiB) into a flat `name -> text` map.
///
/// A body whose `Content-Type` mentions `application/json` (in any letter
/// case) must be a JSON object: strings are taken verbatim, numbers and
/// booleans use their textual form, `null` becomes the empty string, and
/// arrays/objects keep their JSON text. Any other content type is parsed as
/// URL-encoded form data.
///
/// Returns the error text when the body cannot be read or parsed.
async fn extract_form_body(
    parts: axum::http::request::Parts,
    body: Body,
) -> Result<HashMap<String, String>, String> {
    use axum::body::to_bytes;

    // Media types are case-insensitive, so `Application/JSON` must count too.
    let is_json = parts
        .headers
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .map_or(false, |ct| {
            ct.to_ascii_lowercase().contains("application/json")
        });

    let bytes = to_bytes(body, 4 * 1024 * 1024)
        .await
        .map_err(|e| e.to_string())?;

    if !is_json {
        return serde_urlencoded::from_bytes::<HashMap<String, String>>(&bytes)
            .map_err(|e| e.to_string());
    }

    let value: serde_json::Value = serde_json::from_slice(&bytes).map_err(|e| e.to_string())?;
    let obj = value.as_object().ok_or("expected a JSON object")?;
    let form = obj
        .iter()
        .map(|(key, val)| {
            let text = match val {
                Value::String(s) => s.clone(),
                Value::Number(n) => n.to_string(),
                Value::Bool(b) => b.to_string(),
                Value::Null => String::new(),
                other => other.to_string(),
            };
            (key.clone(), text)
        })
        .collect();
    Ok(form)
}
// Unit tests for the cursor token codec (`encode_cursor` / `decode_cursor`).
#[cfg(test)]
mod cursor_tests {
use super::{decode_cursor, encode_cursor};
// An ordinary positive value survives an encode/decode round trip.
#[test]
fn cursor_roundtrip_positive() {
let token = encode_cursor(12345);
assert_eq!(decode_cursor(&token), Some(12345));
}
// Zero survives a round trip.
#[test]
fn cursor_roundtrip_zero() {
let token = encode_cursor(0);
assert_eq!(decode_cursor(&token), Some(0));
}
// The largest i64 survives a round trip.
#[test]
fn cursor_roundtrip_max() {
let token = encode_cursor(i64::MAX);
assert_eq!(decode_cursor(&token), Some(i64::MAX));
}
// Characters outside the URL-safe base64 alphabet are rejected.
#[test]
fn cursor_decode_invalid_base64_returns_none() {
assert!(decode_cursor("not!valid!base64@@").is_none());
}
// Well-formed base64 whose payload is not a number is rejected.
#[test]
fn cursor_decode_non_numeric_payload_returns_none() {
use base64::Engine;
let token = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("not_a_number");
assert!(decode_cursor(&token).is_none());
}
}
// Smoke tests: building a tenant router must not panic. No request is sent.
#[cfg(all(test, feature = "tenancy"))]
mod tenant_router_tests {
use super::*;
// A read-only viewset over the tenancy `User` schema builds a tenant router.
#[test]
fn tenant_router_builds_for_a_basic_model() {
use crate::core::Model as _;
let _r = ViewSet::for_model(crate::tenancy::auth::User::SCHEMA)
.read_only()
.tenant_router("/api/users");
}
// A longer builder chain still ends in a tenant router.
#[test]
fn tenant_router_carries_over_full_builder_chain() {
use crate::core::Model as _;
let _r = ViewSet::for_model(crate::tenancy::auth::User::SCHEMA)
.filter_fields(&["username"])
.search_fields(&["username"])
.ordering(&[("id", true)])
.page_size(50)
.tenant_router("/api/users");
}
// NOTE(review): despite its name this test only builds the tenant router from
// a cloned viewset; it neither builds a static router nor inspects a pool source.
#[test]
fn router_and_tenant_router_set_distinct_pool_sources() {
use crate::core::Model as _;
let static_state = ViewSet::for_model(crate::tenancy::auth::User::SCHEMA);
let vs = static_state.read_only();
let _r = vs.clone().tenant_router("/api/users");
}
}
// Unit tests for `build_lookup_filter`: one test per lookup family.
#[cfg(test)]
mod lookup_tests {
use super::*;
use crate::core::{FieldSchema, FieldType};
// Non-nullable i64 field used for the typed comparison lookups.
fn int_field() -> &'static FieldSchema {
&FieldSchema {
name: "author_id",
column: "author_id",
ty: FieldType::I64,
nullable: false,
primary_key: false,
relation: None,
max_length: None,
min: None,
max: None,
default: None,
auto: false,
unique: false,
generated_as: None,
}
}
// Nullable string field used for the pattern and null lookups.
fn string_field() -> &'static FieldSchema {
&FieldSchema {
name: "title",
column: "title",
ty: FieldType::String,
nullable: true,
primary_key: false,
relation: None,
max_length: None,
min: None,
max: None,
default: None,
auto: false,
unique: false,
generated_as: None,
}
}
// Unwraps the single predicate; any other expression shape fails the test.
fn extract_pred(expr: WhereExpr) -> Filter {
match expr {
WhereExpr::Predicate(f) => f,
_ => panic!("expected Predicate"),
}
}
// A missing lookup behaves like `exact`.
#[test]
fn no_lookup_means_eq() {
let f = extract_pred(build_lookup_filter(int_field(), None, "42").unwrap());
assert_eq!(f.op, Op::Eq);
assert!(matches!(f.value, SqlValue::I64(42)));
}
#[test]
fn explicit_exact_means_eq() {
let f = extract_pred(build_lookup_filter(int_field(), Some("exact"), "42").unwrap());
assert_eq!(f.op, Op::Eq);
}
// Each comparison lookup maps to its operator.
#[test]
fn comparison_lookups() {
for (lk, expected) in [
("gt", Op::Gt),
("gte", Op::Gte),
("lt", Op::Lt),
("lte", Op::Lte),
("ne", Op::Ne),
] {
let f = extract_pred(build_lookup_filter(int_field(), Some(lk), "10").unwrap());
assert_eq!(f.op, expected, "lookup {lk}");
}
}
// `in` splits the raw value on commas into a list.
#[test]
fn in_lookup_parses_csv() {
let f = extract_pred(build_lookup_filter(int_field(), Some("in"), "1,2,3").unwrap());
assert_eq!(f.op, Op::In);
match f.value {
SqlValue::List(v) => assert_eq!(v.len(), 3),
_ => panic!("expected List"),
}
}
#[test]
fn not_in_lookup_parses_csv() {
let f = extract_pred(build_lookup_filter(int_field(), Some("not_in"), "1,2").unwrap());
assert_eq!(f.op, Op::NotIn);
}
// Blank entries between commas are skipped, not treated as values.
#[test]
fn in_lookup_drops_empty_entries() {
let f = extract_pred(build_lookup_filter(int_field(), Some("in"), "1,,2,").unwrap());
match f.value {
SqlValue::List(v) => assert_eq!(v.len(), 2),
_ => panic!("expected List"),
}
}
#[test]
fn contains_wraps_with_percents_and_uses_like() {
let f =
extract_pred(build_lookup_filter(string_field(), Some("contains"), "hello").unwrap());
assert_eq!(f.op, Op::Like);
assert!(matches!(f.value, SqlValue::String(ref s) if s == "%hello%"));
}
#[test]
fn icontains_uses_ilike() {
let f = extract_pred(build_lookup_filter(string_field(), Some("icontains"), "hi").unwrap());
assert_eq!(f.op, Op::ILike);
assert!(matches!(f.value, SqlValue::String(ref s) if s == "%hi%"));
}
#[test]
fn startswith_only_trailing_percent() {
let f =
extract_pred(build_lookup_filter(string_field(), Some("startswith"), "pre").unwrap());
assert!(matches!(f.value, SqlValue::String(ref s) if s == "pre%"));
}
#[test]
fn endswith_only_leading_percent() {
let f = extract_pred(build_lookup_filter(string_field(), Some("endswith"), "fix").unwrap());
assert!(matches!(f.value, SqlValue::String(ref s) if s == "%fix"));
}
#[test]
fn isnull_true() {
let f = extract_pred(build_lookup_filter(string_field(), Some("isnull"), "true").unwrap());
assert_eq!(f.op, Op::IsNull);
assert!(matches!(f.value, SqlValue::Bool(true)));
}
#[test]
fn isnull_false() {
let f = extract_pred(build_lookup_filter(string_field(), Some("isnull"), "false").unwrap());
assert!(matches!(f.value, SqlValue::Bool(false)));
}
// An unrecognised lookup name yields no predicate.
#[test]
fn unknown_lookup_returns_none() {
let r = build_lookup_filter(int_field(), Some("frobulate"), "x");
assert!(r.is_none());
}
// A value that does not parse as the field's type yields no predicate.
#[test]
fn parse_failure_returns_none() {
let r = build_lookup_filter(int_field(), Some("gt"), "not-a-number");
assert!(r.is_none());
}
}
// Tests for `ViewSet::permissions_for_model` against a throwaway model.
#[cfg(all(test, feature = "tenancy"))]
mod typed_perms_tests {
use super::*;
use crate::sql::Auto;
// Minimal model whose table name shows up in the expected codenames below.
#[derive(crate::Model)]
#[rustango(table = "vs_typed_perm_post")]
#[allow(dead_code)]
pub struct PermPost {
#[rustango(primary_key)]
pub id: Auto<i64>,
#[rustango(max_length = 200)]
pub title: String,
}
// List and retrieve share `view`; create, update and destroy use `add`,
// `change` and `delete`, each prefixed with the table name.
#[test]
fn permissions_for_model_fills_all_four_crud_codenames() {
use crate::core::Model;
let vs =
ViewSet::for_model(<PermPost as Model>::SCHEMA).permissions_for_model::<PermPost>();
assert_eq!(vs.perms.list, vec!["vs_typed_perm_post.view"]);
assert_eq!(vs.perms.retrieve, vec!["vs_typed_perm_post.view"]);
assert_eq!(vs.perms.create, vec!["vs_typed_perm_post.add"]);
assert_eq!(vs.perms.update, vec!["vs_typed_perm_post.change"]);
assert_eq!(vs.perms.destroy, vec!["vs_typed_perm_post.delete"]);
}
}