#[cfg(feature = "openapi")]
mod openapi;
use std::collections::HashMap;
use std::future::Future;
#[cfg(feature = "serializer")]
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use axum::body::Body;
use axum::extract::{Path, Query, State};
use axum::http::{header, StatusCode};
use axum::response::Response;
use axum::routing::get;
use axum::Router;
use serde_json::{json, Value};
use crate::core::{
Assignment, CountQuery, DeleteQuery, FieldType, Filter, InsertQuery, ModelSchema, Op,
SearchClause, SelectQuery, SqlValue, UpdateQuery, WhereExpr,
};
use crate::forms::{collect_values, parse_form_value, parse_pk_string, FormError};
use crate::sql::Pool;
/// Permission codenames required for each ViewSet action.
///
/// Each list is passed to `ViewSetState::check_perm` for the matching
/// action. An empty list means "no permission required"; otherwise holding
/// any one of the listed codenames is enough.
#[derive(Clone, Default)]
pub struct ViewSetPerms {
/// Codenames accepted for the `list` action (collection `GET`).
pub list: Vec<String>,
/// Codenames accepted for the `retrieve` action (item `GET`).
pub retrieve: Vec<String>,
/// Codenames accepted for the `create` action (collection `POST`).
pub create: Vec<String>,
/// Codenames accepted for the `update` action (item `PUT` / `PATCH`).
pub update: Vec<String>,
/// Codenames accepted for the `destroy` action (item `DELETE`).
pub destroy: Vec<String>,
}
/// A request budget: at most `max` requests per `window_secs`-second window.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ThrottleRule {
    /// Number of requests allowed inside one window.
    pub max: u32,
    /// Length of the window, in seconds.
    pub window_secs: u64,
}

impl ThrottleRule {
    /// Builds a rule allowing `max` requests every `window_secs` seconds.
    #[must_use]
    pub const fn new(max: u32, window_secs: u64) -> Self {
        ThrottleRule { max, window_secs }
    }
}
/// Optional per-action throttle rules; `None` leaves that action unthrottled.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct ViewSetThrottle {
    /// Rule applied to the `list` action.
    pub list: Option<ThrottleRule>,
    /// Rule applied to the `retrieve` action.
    pub retrieve: Option<ThrottleRule>,
    /// Rule applied to the `create` action.
    pub create: Option<ThrottleRule>,
    /// Rule applied to the `update` action.
    pub update: Option<ThrottleRule>,
    /// Rule applied to the `destroy` action.
    pub destroy: Option<ThrottleRule>,
}

impl ViewSetThrottle {
    /// Applies the same `max` / `window_secs` rule to every action.
    #[must_use]
    pub const fn all(max: u32, window_secs: u64) -> Self {
        let every = Some(ThrottleRule::new(max, window_secs));
        ViewSetThrottle {
            list: every,
            retrieve: every,
            create: every,
            update: every,
            destroy: every,
        }
    }

    /// Looks up the rule for an action name (`"list"`, `"retrieve"`,
    /// `"create"`, `"update"`, `"destroy"`); unknown names yield `None`.
    #[must_use]
    pub fn for_action(&self, action: &str) -> Option<ThrottleRule> {
        let slots = [
            ("list", self.list),
            ("retrieve", self.retrieve),
            ("create", self.create),
            ("update", self.update),
            ("destroy", self.destroy),
        ];
        slots
            .iter()
            .find(|(name, _)| *name == action)
            .and_then(|(_, rule)| *rule)
    }
}
/// How the list endpoint splits results into pages.
#[derive(Clone, Debug)]
pub enum PaginationStyle {
    /// `page` / `page_size` pagination.
    PageNumber,
    /// Keyset pagination over `field`, descending when `desc` is set.
    Cursor {
        field: &'static str,
        desc: bool,
    },
    /// `limit` / `offset` pagination.
    LimitOffset,
}

impl PaginationStyle {
    /// Page-number pagination.
    #[must_use]
    pub const fn page_number() -> Self {
        PaginationStyle::PageNumber
    }

    /// Ascending cursor pagination over `field`.
    #[must_use]
    pub const fn cursor(field: &'static str) -> Self {
        PaginationStyle::Cursor {
            field,
            desc: false,
        }
    }

    /// Descending cursor pagination over `field`.
    #[must_use]
    pub const fn cursor_desc(field: &'static str) -> Self {
        PaginationStyle::Cursor { field, desc: true }
    }

    /// Limit/offset pagination.
    #[must_use]
    pub const fn limit_offset() -> Self {
        PaginationStyle::LimitOffset
    }
}
/// Object-safe adapter that lets a `ViewSet` hold a serializer behind
/// `Arc<dyn SerializerBridge>` without being generic over the serializer type.
///
/// The async methods return boxed futures because the trait is used as a
/// trait object.
trait SerializerBridge: Send + Sync {
/// Runs `q` on `acq` and renders every resulting row as a JSON value.
fn render_rows<'a>(
&'a self,
acq: &'a mut AcquiredConn,
q: &'a SelectQuery,
) -> Pin<Box<dyn Future<Output = Result<Vec<Value>, crate::sql::ExecError>> + Send + 'a>>;
/// Runs `q` on `acq` and renders a single row, or `None` when the query
/// returned no rows.
fn render_one<'a>(
&'a self,
acq: &'a mut AcquiredConn,
q: &'a SelectQuery,
) -> Pin<Box<dyn Future<Output = Result<Option<Value>, crate::sql::ExecError>> + Send + 'a>>;
/// Validates an incoming JSON body, returning per-field errors on failure.
fn validate_body(&self, body: &Value) -> Result<(), crate::forms::FormErrors>;
/// Model-side names of the writable fields. Callers zip this with
/// `writable_field_names`, so both slices are expected to be index-aligned.
fn writable_model_fields(&self) -> &'static [&'static str];
/// Input-side (serializer) names of the writable fields.
fn writable_field_names(&self) -> &'static [&'static str];
}
/// Zero-sized `SerializerBridge` implementation that forwards to the
/// associated functions of the serializer type `S`.
#[cfg(feature = "serializer")]
struct Bridge<S>(PhantomData<S>);

#[cfg(feature = "serializer")]
impl<S> SerializerBridge for Bridge<S>
where
    S: crate::serializer::ModelSerializer + Send + Sync + 'static,
    S::Model: crate::sql::MaybePgFromRow
        + crate::sql::MaybeMyFromRow
        + crate::sql::MaybeSqliteFromRow
        + crate::sql::LoadRelated
        + crate::sql::MaybeMyLoadRelated
        + crate::sql::MaybeSqliteLoadRelated
        + Send
        + Unpin,
{
    /// Loads typed models for `q` and converts each one to JSON through `S`.
    fn render_rows<'a>(
        &'a self,
        acq: &'a mut AcquiredConn,
        q: &'a SelectQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<Value>, crate::sql::ExecError>> + Send + 'a>> {
        Box::pin(async move {
            let loaded = acq.select_rows_typed::<S::Model>(q).await?;
            let mut rendered = Vec::with_capacity(loaded.len());
            for model in &loaded {
                rendered.push(S::from_model(model).to_value());
            }
            Ok(rendered)
        })
    }

    /// Loads typed models for `q` and renders only the first one, if any.
    fn render_one<'a>(
        &'a self,
        acq: &'a mut AcquiredConn,
        q: &'a SelectQuery,
    ) -> Pin<Box<dyn Future<Output = Result<Option<Value>, crate::sql::ExecError>> + Send + 'a>>
    {
        Box::pin(async move {
            let loaded = acq.select_rows_typed::<S::Model>(q).await?;
            let rendered = match loaded.first() {
                Some(model) => Some(S::from_model(model).to_value()),
                None => None,
            };
            Ok(rendered)
        })
    }

    /// Parses the writable part of `body` into `S`, then runs its validation.
    fn validate_body(&self, body: &Value) -> Result<(), crate::forms::FormErrors> {
        S::from_writable_json(body)?.validate()
    }

    /// Model-side names of the fields `S` may write.
    fn writable_model_fields(&self) -> &'static [&'static str] {
        S::writable_source_fields()
    }

    /// Input-side names of the fields `S` may write.
    fn writable_field_names(&self) -> &'static [&'static str] {
        S::writable_fields()
    }
}
/// Pluggable filter backend for the list endpoint.
///
/// The list handler calls every registered backend with the request's query
/// parameters and appends the returned expressions to the other filters,
/// which are combined with `AND`.
pub trait ViewSetFilter: Send + Sync + 'static {
/// Builds extra `WHERE` expressions from the query-string `params` for
/// the model described by `schema`. Return an empty vector to add nothing.
fn filter(
&self,
params: &HashMap<String, String>,
schema: &'static ModelSchema,
) -> Vec<WhereExpr>;
}
/// Blanket implementation so a plain closure or function with the matching
/// signature can be used directly as a filter backend.
impl<F> ViewSetFilter for F
where
F: Fn(&HashMap<String, String>, &'static ModelSchema) -> Vec<WhereExpr> + Send + Sync + 'static,
{
/// Forwards to the wrapped callable.
fn filter(
&self,
params: &HashMap<String, String>,
schema: &'static ModelSchema,
) -> Vec<WhereExpr> {
self(params, schema)
}
}
/// Builder-style description of a REST resource for one model: which fields
/// are exposed, how lists are filtered / searched / ordered / paginated, and
/// which permissions and throttles apply.
#[derive(Clone)]
pub struct ViewSet {
// Model the endpoints operate on.
schema: &'static ModelSchema,
// Optional whitelist of field names to expose; `None` exposes every
// scalar field of the schema.
fields: Option<Vec<String>>,
// Field names that may be used as query-string filters on the list endpoint.
filter_fields: Vec<String>,
// Field names whose columns are searched by the `search` query parameter.
search_fields: Vec<String>,
// Field names passed to `parse_ordering` for the `ordering` query parameter.
ordering_fields: Vec<String>,
// Page size used when the request does not supply one.
default_page_size: usize,
// Fallback ordering as `(field name, descending)` pairs.
default_ordering: Vec<(String, bool)>,
// Permission codenames per action.
perms: ViewSetPerms,
// When set, only the GET routes are registered.
read_only: bool,
// Pagination strategy for the list endpoint.
pagination: PaginationStyle,
// Extra filter backends consulted by the list endpoint.
filter_backends: Vec<std::sync::Arc<dyn ViewSetFilter>>,
// Per-action throttle rules.
throttle: ViewSetThrottle,
// Optional serializer used for rendering rows and validating write bodies.
serializer: Option<Arc<dyn SerializerBridge>>,
}
impl ViewSet {
    /// Starts a ViewSet for `schema` with defaults: all scalar fields exposed,
    /// no filter/search/ordering fields, page size 20, page-number pagination,
    /// no permissions, no throttling, no serializer, writable.
    #[must_use]
    pub fn for_model(schema: &'static ModelSchema) -> Self {
        Self {
            schema,
            fields: None,
            filter_fields: Vec::new(),
            search_fields: Vec::new(),
            ordering_fields: Vec::new(),
            default_page_size: 20,
            default_ordering: Vec::new(),
            perms: ViewSetPerms::default(),
            read_only: false,
            pagination: PaginationStyle::PageNumber,
            filter_backends: Vec::new(),
            throttle: ViewSetThrottle::default(),
            serializer: None,
        }
    }

    /// Renders rows and validates write bodies through the serializer `S`.
    #[cfg(feature = "serializer")]
    #[must_use]
    pub fn serializer<S>(mut self) -> Self
    where
        S: crate::serializer::ModelSerializer + Send + Sync + 'static,
        S::Model: crate::sql::MaybePgFromRow
            + crate::sql::MaybeMyFromRow
            + crate::sql::MaybeSqliteFromRow
            + crate::sql::LoadRelated
            + crate::sql::MaybeMyLoadRelated
            + crate::sql::MaybeSqliteLoadRelated
            + Send
            + Unpin,
    {
        self.serializer = Some(Arc::new(Bridge::<S>(PhantomData)));
        self
    }

    /// Switches to ascending cursor pagination over `field`.
    #[must_use]
    pub fn cursor_pagination(mut self, field: &'static str) -> Self {
        self.pagination = PaginationStyle::Cursor { field, desc: false };
        self
    }

    /// Switches to descending cursor pagination over `field`.
    #[must_use]
    pub fn cursor_pagination_desc(mut self, field: &'static str) -> Self {
        self.pagination = PaginationStyle::Cursor { field, desc: true };
        self
    }

    /// Switches to `limit` / `offset` pagination.
    #[must_use]
    pub fn limit_offset_pagination(mut self) -> Self {
        self.pagination = PaginationStyle::LimitOffset;
        self
    }

    /// Sets the pagination style explicitly.
    #[must_use]
    pub fn pagination(mut self, style: PaginationStyle) -> Self {
        self.pagination = style;
        self
    }

    /// Registers an additional filter backend for the list endpoint.
    #[must_use]
    pub fn filter_backend(mut self, backend: impl ViewSetFilter) -> Self {
        self.filter_backends.push(Arc::new(backend));
        self
    }

    /// Replaces the per-action throttle configuration.
    #[must_use]
    pub fn throttle(mut self, throttle: ViewSetThrottle) -> Self {
        self.throttle = throttle;
        self
    }

    /// Applies one `max` / `window_secs` throttle rule to every action.
    #[must_use]
    pub fn throttle_all(mut self, max: u32, window_secs: u64) -> Self {
        self.throttle = ViewSetThrottle::all(max, window_secs);
        self
    }

    /// Restricts the exposed fields to `fields` (names unknown to the schema
    /// are ignored when the field list is resolved).
    #[must_use]
    pub fn fields(mut self, fields: &[&str]) -> Self {
        self.fields = Some(fields.iter().map(|&s| s.to_owned()).collect());
        self
    }

    /// Sets the fields that may be used as query-string filters.
    #[must_use]
    pub fn filter_fields(mut self, fields: &[&str]) -> Self {
        self.filter_fields = fields.iter().map(|&s| s.to_owned()).collect();
        self
    }

    /// Sets the fields searched by the `search` query parameter.
    #[must_use]
    pub fn search_fields(mut self, fields: &[&str]) -> Self {
        self.search_fields = fields.iter().map(|&s| s.to_owned()).collect();
        self
    }

    /// Sets the fields clients may order by via the `ordering` parameter.
    #[must_use]
    pub fn ordering_fields(mut self, fields: &[&str]) -> Self {
        self.ordering_fields = fields.iter().map(|&s| s.to_owned()).collect();
        self
    }

    /// Sets the default page size, capped at 1000. The list handler
    /// additionally enforces a minimum of 1 at request time.
    #[must_use]
    pub fn page_size(mut self, n: usize) -> Self {
        self.default_page_size = n.min(1000);
        self
    }

    /// Sets the fallback ordering as `(field name, descending)` pairs.
    #[must_use]
    pub fn ordering(mut self, ordering: &[(&str, bool)]) -> Self {
        self.default_ordering = ordering.iter().map(|&(f, d)| (f.to_owned(), d)).collect();
        self
    }

    /// Replaces the per-action permission codenames.
    #[must_use]
    pub fn permissions(mut self, perms: ViewSetPerms) -> Self {
        self.perms = perms;
        self
    }

    /// Derives the permission codenames for model `T` from the actions
    /// `view` (list/retrieve), `add`, `change` and `delete`.
    #[cfg(feature = "tenancy")]
    #[must_use]
    pub fn permissions_for_model<T: crate::core::Model>(mut self) -> Self {
        let cn = |action: &str| crate::permissions::codename_for::<T>(action);
        self.perms = ViewSetPerms {
            list: vec![cn("view")],
            retrieve: vec![cn("view")],
            create: vec![cn("add")],
            update: vec![cn("change")],
            destroy: vec![cn("delete")],
        };
        self
    }

    /// Registers only the GET routes (list and retrieve).
    #[must_use]
    pub fn read_only(mut self) -> Self {
        self.read_only = true;
        self
    }

    /// Builds the router backed by a fixed Postgres pool.
    #[cfg(feature = "postgres")]
    pub fn router(self, prefix: &str, pool: crate::sql::sqlx::PgPool) -> Router {
        Self::router_with_source(self, prefix, PoolSource::Static(pool))
    }

    /// Builds the router backed by a fixed backend-agnostic pool.
    pub fn router_pool(self, prefix: &str, pool: crate::sql::Pool) -> Router {
        Self::router_with_source(self, prefix, PoolSource::StaticPool(pool))
    }

    /// Builds the router resolving the pool per request from the tenant.
    #[cfg(feature = "tenancy")]
    #[must_use]
    pub fn tenant_router(self, prefix: &str) -> Router {
        Self::router_with_source(self, prefix, PoolSource::Tenant)
    }

    /// Wires the collection (`{prefix}`) and item (`{prefix}/{pk}`) routes to
    /// the handlers and attaches the shared state.
    fn router_with_source(self, prefix: &str, pool_source: PoolSource) -> Router {
        let prefix = prefix.trim_end_matches('/');
        // axum's `Router::route` panics on an empty path, so a root mount
        // (prefix "" or "/") must register the collection at "/".
        let collection = if prefix.is_empty() { "/" } else { prefix };
        let item = format!("{prefix}/{{pk}}");

        // Read the flag before `self` is moved into the shared state.
        let read_only = self.read_only;
        let collection_route = if read_only {
            get(handle_list)
        } else {
            get(handle_list).post(handle_create)
        };
        let item_route = if read_only {
            axum::routing::MethodRouter::new().get(handle_retrieve)
        } else {
            axum::routing::MethodRouter::new()
                .get(handle_retrieve)
                .put(handle_update)
                .patch(handle_partial_update)
                .delete(handle_destroy)
        };

        // `self` is owned here, so it is moved rather than cloned.
        let state = Arc::new(ViewSetState {
            pool_source,
            vs: self,
            throttle_store: Arc::new(Mutex::new(HashMap::new())),
        });
        Router::new()
            .route(collection, collection_route)
            .route(&item, item_route)
            .with_state(state)
    }
}
/// Where a request handler obtains its database pool from.
#[derive(Clone)]
enum PoolSource {
/// A fixed Postgres pool; converted with `Pool::from` on each request.
#[cfg(feature = "postgres")]
Static(crate::sql::sqlx::PgPool),
/// A fixed backend-agnostic pool, cloned on each request.
StaticPool(crate::sql::Pool),
/// The pool is taken from the `Tenant` extractor run against the request.
#[cfg(feature = "tenancy")]
Tenant,
}
/// Shared router state handed to every handler (wrapped in an `Arc`).
#[derive(Clone)]
struct ViewSetState {
// How to obtain the database pool for a request.
pool_source: PoolSource,
// The ViewSet configuration the routes were built from.
vs: ViewSet,
// Throttle counters keyed by "table:action:client"; each value is
// (requests seen in the current window, window start).
throttle_store: Arc<Mutex<HashMap<String, (u32, Instant)>>>,
}
/// The database pool resolved for one request, plus the tenant it came from
/// when the pool was resolved through the tenant extractor.
struct AcquiredConn {
// Pool every query of this request runs against.
pool: Pool,
// The extracted tenant, stored alongside the pool but never read.
// NOTE(review): presumably kept so the tenant outlives the request's
// queries — confirm against the `Tenant` extractor.
#[cfg(feature = "tenancy")]
#[allow(dead_code)]
_tenant: Option<Box<crate::extractors::Tenant>>,
}
/// Thin query helpers that forward to the `crate::sql` pool functions.
impl AcquiredConn {
/// Runs `q` and returns each row as a JSON value restricted to `fields`.
async fn select_rows_as_json(
&mut self,
q: &SelectQuery,
fields: &[&'static crate::core::FieldSchema],
) -> Result<Vec<Value>, crate::sql::ExecError> {
crate::sql::select_rows_as_json(&self.pool, q, fields).await
}
/// Runs the count query `q`.
async fn count_rows(&mut self, q: &CountQuery) -> Result<i64, crate::sql::ExecError> {
crate::sql::count_rows_pool(&self.pool, q).await
}
/// Runs `q` and returns one row as JSON, or `None` when there are no rows.
///
/// When the query yields several rows, the LAST one is returned (`pop`).
async fn select_one_as_json(
&mut self,
q: &SelectQuery,
fields: &[&'static crate::core::FieldSchema],
) -> Result<Option<Value>, crate::sql::ExecError> {
let mut rows = crate::sql::select_rows_as_json(&self.pool, q, fields).await?;
Ok(rows.pop())
}
/// Runs `q` and returns typed models, loaded via
/// `select_rows_pool_with_related`.
#[cfg(feature = "serializer")]
async fn select_rows_typed<T>(
&mut self,
q: &SelectQuery,
) -> Result<Vec<T>, crate::sql::ExecError>
where
T: crate::sql::MaybePgFromRow
+ crate::sql::MaybeMyFromRow
+ crate::sql::MaybeSqliteFromRow
+ crate::sql::LoadRelated
+ crate::sql::MaybeMyLoadRelated
+ crate::sql::MaybeSqliteLoadRelated
+ Send
+ Unpin,
{
crate::sql::select_rows_pool_with_related::<T>(&self.pool, q).await
}
/// Executes the insert `q` and extracts the new row's primary key as a
/// `SqlValue` typed according to `pk_field.ty`.
///
/// NOTE(review): for the Postgres and SQLite rows a failed `try_get` is
/// swallowed and replaced by `0` / an empty string, so a decoding problem
/// surfaces later as a lookup for the wrong key rather than as an error
/// here. Primary-key types other than I16/I32/I64/String yield
/// `SqlValue::Null`.
async fn insert_returning_pk(
&mut self,
q: &InsertQuery,
pk_field: &crate::core::FieldSchema,
) -> Result<SqlValue, crate::sql::ExecError> {
let returning = crate::sql::insert_returning_pool(&self.pool, q).await?;
let pk = match returning {
#[cfg(feature = "postgres")]
crate::sql::InsertReturningPool::PgRow(row) => {
use crate::sql::sqlx::Row as _;
match pk_field.ty {
FieldType::I64 => SqlValue::I64(row.try_get(pk_field.column).unwrap_or(0)),
FieldType::I32 => SqlValue::I32(row.try_get(pk_field.column).unwrap_or(0)),
FieldType::I16 => SqlValue::I16(row.try_get(pk_field.column).unwrap_or(0)),
FieldType::String => {
SqlValue::String(row.try_get(pk_field.column).unwrap_or_default())
}
_ => SqlValue::Null,
}
}
// MySQL reports only the auto-generated id.
// NOTE(review): the `as i32` / `as i16` casts truncate silently if the
// id does not fit the declared key type.
#[cfg(feature = "mysql")]
crate::sql::InsertReturningPool::MySqlAutoId(id) => match pk_field.ty {
FieldType::I64 => SqlValue::I64(id),
FieldType::I32 => SqlValue::I32(id as i32),
FieldType::I16 => SqlValue::I16(id as i16),
_ => SqlValue::I64(id),
},
#[cfg(feature = "sqlite")]
crate::sql::InsertReturningPool::SqliteRow(row) => {
use crate::sql::sqlx::Row as _;
match pk_field.ty {
FieldType::I64 => SqlValue::I64(row.try_get(pk_field.column).unwrap_or(0)),
FieldType::I32 => SqlValue::I32(row.try_get(pk_field.column).unwrap_or(0)),
FieldType::I16 => SqlValue::I16(row.try_get(pk_field.column).unwrap_or(0)),
FieldType::String => {
SqlValue::String(row.try_get(pk_field.column).unwrap_or_default())
}
_ => SqlValue::Null,
}
}
};
Ok(pk)
}
/// Executes the update `q`; the `u64` is whatever `update_pool` reports
/// (presumably the affected row count — confirm in `crate::sql`).
async fn update(&mut self, q: &UpdateQuery) -> Result<u64, crate::sql::ExecError> {
crate::sql::update_pool(&self.pool, q).await
}
/// Executes the delete `q`; the `u64` is whatever `delete_pool` reports.
async fn delete(&mut self, q: &DeleteQuery) -> Result<u64, crate::sql::ExecError> {
crate::sql::delete_pool(&self.pool, q).await
}
/// Returns whether user `uid` holds `codename`. A lookup error is treated
/// as "no permission" (fails closed).
#[cfg(feature = "tenancy")]
async fn has_perm(&mut self, uid: i64, codename: &str) -> bool {
crate::tenancy::permissions::has_perm_pool(uid, codename, &self.pool)
.await
.unwrap_or(false)
}
}
impl ViewSetState {
    /// Resolves the field schemas to expose: the configured whitelist (names
    /// the schema does not know are dropped) or every scalar field.
    fn effective_fields(&self) -> Vec<&'static crate::core::FieldSchema> {
        let schema = self.vs.schema;
        if let Some(names) = &self.vs.fields {
            names.iter().filter_map(|n| schema.field(n)).collect()
        } else {
            schema.scalar_fields().collect()
        }
    }

    /// Obtains the pool for this request according to the configured
    /// `PoolSource`. For the tenant source, a rejected tenant extraction is
    /// returned as the error response.
    async fn acquire(
        &self,
        parts: &mut axum::http::request::Parts,
    ) -> Result<AcquiredConn, Response> {
        match &self.pool_source {
            #[cfg(feature = "postgres")]
            PoolSource::Static(pg) => {
                let conn = AcquiredConn {
                    pool: Pool::from(pg.clone()),
                    #[cfg(feature = "tenancy")]
                    _tenant: None,
                };
                Ok(conn)
            }
            PoolSource::StaticPool(shared) => {
                let conn = AcquiredConn {
                    pool: shared.clone(),
                    #[cfg(feature = "tenancy")]
                    _tenant: None,
                };
                Ok(conn)
            }
            #[cfg(feature = "tenancy")]
            PoolSource::Tenant => {
                use axum::response::IntoResponse as _;
                let extracted = <crate::extractors::Tenant<
                    crate::tenancy::DefaultTenantDb,
                > as axum::extract::FromRequestParts<()>>::from_request_parts(parts, &())
                .await;
                let tenant = match extracted {
                    Ok(t) => t,
                    Err(rejection) => return Err(rejection.into_response()),
                };
                let pool = tenant.pool().clone();
                Ok(AcquiredConn {
                    pool,
                    _tenant: Some(Box::new(tenant)),
                })
            }
        }
    }

    /// Decides whether the request may perform an action guarded by
    /// `codenames`.
    ///
    /// An empty list always passes. With the `tenancy` feature the request
    /// must carry an `AuthenticatedUser` that is a superuser or holds at
    /// least one codename; without the feature any non-empty list is denied.
    async fn check_perm(
        &self,
        codenames: &[String],
        parts: &axum::http::request::Parts,
        conn: &mut AcquiredConn,
    ) -> bool {
        if codenames.is_empty() {
            return true;
        }
        #[cfg(feature = "tenancy")]
        {
            let user = match parts
                .extensions
                .get::<crate::tenancy::middleware::AuthenticatedUser>()
            {
                Some(u) => u,
                None => return false,
            };
            if user.is_superuser {
                return true;
            }
            for codename in codenames {
                let granted = conn.has_perm(user.id, codename).await;
                if granted {
                    return true;
                }
            }
            false
        }
        #[cfg(not(feature = "tenancy"))]
        {
            // No permission backend is compiled in: fail closed.
            let _ = (parts, conn);
            false
        }
    }

    /// The model's primary-key field, if the schema declares one.
    fn pk_field(&self) -> Option<&'static crate::core::FieldSchema> {
        self.vs.schema.primary_key()
    }
}
/// Serializes `body` into an `application/json` response with `status`.
fn json_with_status(status: StatusCode, body: Value) -> Response {
    let payload = Body::from(body.to_string());
    let builder = Response::builder()
        .status(status)
        .header(header::CONTENT_TYPE, "application/json");
    builder.body(payload).unwrap()
}
/// Wraps `body` in a `200 OK` JSON response.
fn json_response(body: Value) -> Response {
json_with_status(StatusCode::OK, body)
}
/// Builds a JSON error response of the form `{"error": msg}` with `status`.
fn json_error(status: StatusCode, msg: &str) -> Response {
json_with_status(status, json!({ "error": msg }))
}
/// Renders validation errors as a `400 Bad Request` JSON object mapping each
/// field to its messages, with form-level messages under `non_field_errors`
/// when there are any.
fn json_form_errors(errs: &crate::forms::FormErrors) -> Response {
    let mut payload = serde_json::Map::new();
    payload.extend(
        errs.fields()
            .into_iter()
            .map(|(field, msgs)| (field.clone(), json!(msgs))),
    );
    let general = errs.non_field();
    if !general.is_empty() {
        payload.insert("non_field_errors".to_owned(), json!(general));
    }
    json_with_status(StatusCode::BAD_REQUEST, Value::Object(payload))
}
/// Translates serializer input names in `form` to their model field names.
///
/// Returns `None` when there is no serializer or when no renamed input name
/// is present in `form`; otherwise returns a copy of `form` in which each
/// such key has been moved to its model name. A value already stored under
/// the model name is kept and the renamed one is dropped.
fn serializer_input_renamed_form(
    state: &ViewSetState,
    form: &HashMap<String, String>,
) -> Option<HashMap<String, String>> {
    let bridge = state.vs.serializer.as_ref()?;
    let input_names = bridge.writable_field_names();
    let model_names = bridge.writable_model_fields();

    // The copy is made lazily, only once a rename is actually needed.
    let mut translated: Option<HashMap<String, String>> = None;
    for (input, target) in input_names.iter().zip(model_names.iter()) {
        if input == target || !form.contains_key(*input) {
            continue;
        }
        let map = translated.get_or_insert_with(|| form.clone());
        if let Some(value) = map.remove(*input) {
            map.entry((*target).to_owned()).or_insert(value);
        }
    }
    translated
}
/// Prepares a write when a serializer is configured.
///
/// Validates the JSON body (if one was supplied), returning the rendered
/// validation errors as the `Err` response on failure. On success returns the
/// names of the schema's scalar fields that the serializer does NOT mark as
/// writable, so callers can skip them. Without a serializer the list is empty.
fn serializer_write_prep(
    state: &ViewSetState,
    json: Option<&Value>,
) -> Result<Vec<&'static str>, Response> {
    let Some(bridge) = state.vs.serializer.as_ref() else {
        return Ok(Vec::new());
    };
    if let Some(Err(errs)) = json.map(|body| bridge.validate_body(body)) {
        return Err(json_form_errors(&errs));
    }
    let writable = bridge.writable_model_fields();
    let mut read_only_names = Vec::new();
    for field in state.vs.schema.scalar_fields() {
        if !writable.contains(&field.name) {
            read_only_names.push(field.name);
        }
    }
    Ok(read_only_names)
}
// Unwraps a `Result` inside a handler that returns `Response`: yields the
// `Ok` value, or returns early from the enclosing function with a
// `500 Internal Server Error` JSON body carrying the error's `to_string()`.
macro_rules! or_500 {
($expr:expr) => {
match $expr {
::core::result::Result::Ok(v) => v,
::core::result::Result::Err(e) => {
return json_error(
::axum::http::StatusCode::INTERNAL_SERVER_ERROR,
&::std::string::ToString::to_string(&e),
);
}
}
};
}
// Unwraps a `Result` inside a handler that returns `Response`: yields the
// `Ok` value, or returns early from the enclosing function with a
// `400 Bad Request` JSON body carrying the error's `to_string()`.
macro_rules! or_400 {
($expr:expr) => {
match $expr {
::core::result::Result::Ok(v) => v,
::core::result::Result::Err(e) => {
return json_error(
::axum::http::StatusCode::BAD_REQUEST,
&::std::string::ToString::to_string(&e),
);
}
}
};
}
/// Common prologue of every handler: splits the request, applies the
/// throttle for `action`, acquires the database pool and checks that the
/// caller holds one of `codenames`.
///
/// On success returns the request parts, the untouched body and the acquired
/// connection; on failure returns the response to send (throttled, pool
/// acquisition rejection, or `403 permission denied`).
async fn enter(
    state: &Arc<ViewSetState>,
    req: axum::extract::Request,
    codenames: &[String],
    action: &'static str,
) -> Result<(axum::http::request::Parts, Body, AcquiredConn), Response> {
    let (mut parts, body) = req.into_parts();

    // Throttling runs first so rejected requests never touch the database.
    check_throttle(state, action, &parts).map_or(Ok(()), Err)?;

    let mut conn = state.acquire(&mut parts).await?;
    let allowed = state.check_perm(codenames, &parts, &mut conn).await;
    if allowed {
        Ok((parts, body, conn))
    } else {
        Err(json_error(StatusCode::FORBIDDEN, "permission denied"))
    }
}
/// Applies the fixed-window throttle configured for `action`.
///
/// Counters are keyed by `"{table}:{action}:{client}"`. Returns `None` when
/// the request may proceed (including when no rule is configured) and
/// `Some(429 response)` with a `Retry-After` header once the client exceeds
/// `rule.max` requests inside the current window.
fn check_throttle(
    state: &ViewSetState,
    action: &str,
    parts: &axum::http::request::Parts,
) -> Option<Response> {
    let rule = state.vs.throttle.for_action(action)?;
    let key = format!(
        "{}:{}:{}",
        state.vs.schema.table,
        action,
        client_key(parts)
    );
    let now = Instant::now();
    let window = std::time::Duration::from_secs(rule.window_secs);

    // A poisoned lock only means another thread panicked while holding it.
    // The map holds plain counters, so it is safe to keep using; bailing out
    // here would silently disable throttling for the rest of the process.
    let mut store = state
        .throttle_store
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    // NOTE(review): entries are never evicted, so the map grows with the
    // number of distinct client keys seen.
    let entry = store.entry(key).or_insert((0, now));
    if now.duration_since(entry.1) >= window {
        // The previous window has elapsed: start a fresh one.
        *entry = (0, now);
    }
    // Saturate instead of overflowing: a client hammering a throttled
    // endpoint must not be able to wrap (or, in debug builds, panic) the
    // counter while the lock is held.
    entry.0 = entry.0.saturating_add(1);
    if entry.0 > rule.max {
        let retry = window
            .checked_sub(now.duration_since(entry.1))
            .map_or(1, |d| d.as_secs().max(1));
        return Some(throttled_response(retry));
    }
    None
}
/// Derives the identifier used to bucket a client for throttling.
///
/// Preference order: the peer IP from `ConnectInfo<SocketAddr>`, then the
/// first non-empty comma-separated entry of `x-forwarded-for`, then of
/// `x-real-ip`, and finally the shared bucket `"global"`.
fn client_key(parts: &axum::http::request::Parts) -> String {
    let peer = parts
        .extensions
        .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>();
    if let Some(info) = peer {
        return info.0.ip().to_string();
    }

    let forwarded = ["x-forwarded-for", "x-real-ip"].iter().find_map(|name| {
        let raw = parts.headers.get(*name)?.to_str().ok()?;
        let first = raw.split(',').next()?.trim();
        if first.is_empty() {
            None
        } else {
            Some(first.to_owned())
        }
    });
    match forwarded {
        Some(addr) => addr,
        None => "global".to_owned(),
    }
}
fn throttled_response(retry_after_secs: u64) -> Response {
Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
.header(header::CONTENT_TYPE, "application/json")
.header(header::RETRY_AFTER, retry_after_secs.to_string())
.body(Body::from(
json!({ "error": "request throttled" }).to_string(),
))
.unwrap()
}
/// Parses the raw path segment `raw` as a value of the primary-key `field`,
/// turning a parse failure into a `400 Bad Request` response.
fn parse_pk_or_400(
    field: &'static crate::core::FieldSchema,
    raw: &str,
) -> Result<crate::core::SqlValue, Response> {
    match parse_pk_string(field, raw) {
        Ok(value) => Ok(value),
        Err(e) => Err(json_error(StatusCode::BAD_REQUEST, &e.to_string())),
    }
}
/// Returns the model's primary-key field, or a `500` response when the
/// schema declares none.
fn pk_field_or_500(state: &ViewSetState) -> Result<&'static crate::core::FieldSchema, Response> {
    match state.pk_field() {
        Some(field) => Ok(field),
        None => Err(json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "model has no primary key",
        )),
    }
}
/// Wraps `body` in a `201 Created` JSON response.
fn json_created(body: Value) -> Response {
json_with_status(StatusCode::CREATED, body)
}
/// Builds an empty `204 No Content` response.
fn no_content() -> Response {
    let builder = Response::builder().status(StatusCode::NO_CONTENT);
    builder.body(Body::empty()).unwrap()
}
/// Turns one `field[__lookup]=raw` query parameter into a `WHERE` predicate.
///
/// Supported lookups (default `exact`): the comparisons `exact`, `ne`, `gt`,
/// `gte`, `lt`, `lte`; the list forms `in` / `not_in` (comma-separated); the
/// pattern forms `contains`, `icontains`, `startswith`, `istartswith`,
/// `endswith`, `iendswith`; and `isnull`.
///
/// Returns `None` when the lookup is unknown, when a comparison value does
/// not parse, or when an `in` / `not_in` list has no parseable item
/// (unparseable items are dropped individually).
fn build_lookup_filter(
    field: &'static crate::core::FieldSchema,
    lookup: Option<&str>,
    raw: &str,
) -> Option<WhereExpr> {
    let column = field.column;
    let make = |op: Op, value: SqlValue| WhereExpr::Predicate(Filter { column, op, value });
    let lookup_name = lookup.unwrap_or("exact");

    // Plain comparisons parse the raw text with the field's own type.
    let comparison = match lookup_name {
        "exact" => Some(Op::Eq),
        "ne" => Some(Op::Ne),
        "gt" => Some(Op::Gt),
        "gte" => Some(Op::Gte),
        "lt" => Some(Op::Lt),
        "lte" => Some(Op::Lte),
        _ => None,
    };
    if let Some(op) = comparison {
        let value = parse_form_value(field, Some(raw)).ok()?;
        return Some(make(op, value));
    }

    // Pattern lookups embed the raw text between optional `%` wildcards.
    let pattern = |prefix: &str, suffix: &str| SqlValue::String(format!("{prefix}{raw}{suffix}"));
    let expr = match lookup_name {
        "in" | "not_in" => {
            let mut items: Vec<SqlValue> = Vec::new();
            for piece in raw.split(',') {
                let piece = piece.trim();
                if piece.is_empty() {
                    continue;
                }
                if let Ok(value) = parse_form_value(field, Some(piece)) {
                    items.push(value);
                }
            }
            if items.is_empty() {
                return None;
            }
            let op = if lookup_name == "not_in" {
                Op::NotIn
            } else {
                Op::In
            };
            make(op, SqlValue::List(items))
        }
        "contains" => make(Op::Like, pattern("%", "%")),
        "icontains" => make(Op::ILike, pattern("%", "%")),
        "startswith" => make(Op::Like, pattern("", "%")),
        "istartswith" => make(Op::ILike, pattern("", "%")),
        "endswith" => make(Op::Like, pattern("%", "")),
        "iendswith" => make(Op::ILike, pattern("%", "")),
        "isnull" => {
            let flag = raw.to_ascii_lowercase();
            let is_null = matches!(flag.as_str(), "true" | "1" | "yes");
            make(Op::IsNull, SqlValue::Bool(is_null))
        }
        _ => return None,
    };
    Some(expr)
}
async fn handle_list(
State(state): State<Arc<ViewSetState>>,
Query(params): Query<HashMap<String, String>>,
req: axum::extract::Request,
) -> Response {
let (_parts, _body, mut acq) = match enter(&state, req, &state.vs.perms.list, "list").await {
Ok(x) => x,
Err(resp) => return resp,
};
let page_size: i64 = params
.get("page_size")
.and_then(|p| p.parse().ok())
.unwrap_or(state.vs.default_page_size as i64)
.min(1000)
.max(1);
let mut filters: Vec<WhereExpr> = Vec::new();
for (param_key, raw_val) in ¶ms {
if crate::list_params::is_reserved_list_key(param_key) {
continue;
}
let (field_name, lookup) = match param_key.split_once("__") {
Some((name, lk)) => (name, Some(lk)),
None => (param_key.as_str(), None),
};
if !state.vs.filter_fields.iter().any(|f| f == field_name) {
continue;
}
let Some(field) = state.vs.schema.field(field_name) else {
continue;
};
if let Some(predicate) = build_lookup_filter(field, lookup, raw_val) {
filters.push(predicate);
}
}
for backend in &state.vs.filter_backends {
filters.extend(backend.filter(¶ms, state.vs.schema));
}
let where_clause = if filters.len() == 1 {
filters.remove(0)
} else if filters.is_empty() {
WhereExpr::And(vec![])
} else {
WhereExpr::And(filters)
};
let search = params.get("search").filter(|s| !s.is_empty()).cloned();
let search_clause = search.map(|q| SearchClause {
query: q,
columns: state
.vs
.search_fields
.iter()
.filter_map(|n| state.vs.schema.field(n).map(|f| f.column))
.collect(),
});
let order_by: Vec<crate::core::OrderItem> = params
.get("ordering")
.map(|raw| {
crate::list_params::parse_ordering(raw, &state.vs.ordering_fields, state.vs.schema)
})
.filter(|v| !v.is_empty())
.unwrap_or_else(|| {
state
.vs
.default_ordering
.iter()
.filter_map(|(name, desc)| {
state
.vs
.schema
.field(name)
.map(|f| crate::core::OrderItem::column(f.column, *desc))
})
.collect()
});
let fields = state.effective_fields();
match &state.vs.pagination {
PaginationStyle::PageNumber => {
let page: i64 = params
.get("page")
.and_then(|p| p.parse().ok())
.unwrap_or(1)
.max(1);
let offset = (page - 1) * page_size;
let select_q = SelectQuery {
where_clause: where_clause.clone(),
search: search_clause.clone(),
order_by: order_by.clone(),
limit: Some(page_size),
offset: Some(offset),
..SelectQuery::new(state.vs.schema)
};
let count_q = CountQuery {
model: state.vs.schema,
where_clause,
search: search_clause.clone(),
};
let results = or_500!(render_list(&state, &mut acq, &select_q, &fields).await);
let count = or_500!(acq.count_rows(&count_q).await);
let last_page = ((count - 1).max(0) / page_size) + 1;
json_response(json!({
"count": count,
"page": page,
"page_size": page_size,
"last_page": last_page,
"results": results,
}))
}
PaginationStyle::Cursor {
field: cursor_field,
desc,
} => {
handle_list_cursor(
state.as_ref(),
&mut acq,
params,
where_clause,
search_clause,
fields,
page_size,
cursor_field,
*desc,
)
.await
}
PaginationStyle::LimitOffset => {
let limit: i64 = params
.get("limit")
.and_then(|p| p.parse().ok())
.unwrap_or(page_size)
.min(1000)
.max(1);
let offset: i64 = params
.get("offset")
.and_then(|p| p.parse().ok())
.unwrap_or(0)
.max(0);
let select_q = SelectQuery {
where_clause: where_clause.clone(),
search: search_clause.clone(),
order_by,
limit: Some(limit),
offset: Some(offset),
..SelectQuery::new(state.vs.schema)
};
let count_q = CountQuery {
model: state.vs.schema,
where_clause,
search: search_clause,
};
let results = or_500!(render_list(&state, &mut acq, &select_q, &fields).await);
let count = or_500!(acq.count_rows(&count_q).await);
json_response(json!({
"count": count,
"limit": limit,
"offset": offset,
"results": results,
}))
}
}
}
/// Cursor (keyset) pagination branch of the list endpoint.
///
/// Rows are ordered by `cursor_field` (descending when `desc`) and, when a
/// `cursor` parameter is supplied, restricted to rows strictly after it. One
/// extra row is fetched to detect whether another page exists; the response
/// carries `page_size`, `next` (an opaque cursor or `null`) and `results`.
///
/// Responds with `500` when the cursor field is missing from the model or is
/// not an integer field, and with `400` for an undecodable cursor.
async fn handle_list_cursor(
    state: &ViewSetState,
    acq: &mut AcquiredConn,
    params: HashMap<String, String>,
    where_clause: WhereExpr,
    search_clause: Option<SearchClause>,
    fields: Vec<&'static crate::core::FieldSchema>,
    page_size: i64,
    cursor_field: &str,
    desc: bool,
) -> Response {
    let Some(cursor_schema) = state.vs.schema.field(cursor_field) else {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            &format!("cursor field `{cursor_field}` not found on model"),
        );
    };
    if !matches!(
        cursor_schema.ty,
        FieldType::I16 | FieldType::I32 | FieldType::I64
    ) {
        return json_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            "cursor pagination requires an integer field (i16/i32/i64)",
        );
    }

    let cursor_val: Option<i64> = match params.get("cursor") {
        Some(c) if !c.is_empty() => match decode_cursor(c) {
            Some(v) => Some(v),
            None => return json_error(StatusCode::BAD_REQUEST, "invalid cursor"),
        },
        _ => None,
    };

    // AND the "strictly after the cursor" predicate onto the caller's filters.
    let final_where = match cursor_val {
        Some(position) => {
            let op = if desc { Op::Lt } else { Op::Gt };
            let cursor_pred = WhereExpr::Predicate(Filter {
                column: cursor_schema.column,
                op,
                value: SqlValue::I64(position),
            });
            match where_clause {
                WhereExpr::And(existing) if existing.is_empty() => cursor_pred,
                WhereExpr::And(mut existing) => {
                    existing.push(cursor_pred);
                    WhereExpr::And(existing)
                }
                other => WhereExpr::And(vec![other, cursor_pred]),
            }
        }
        None => where_clause,
    };

    let order_by = vec![crate::core::OrderItem::column(cursor_schema.column, desc)];
    let select_q = SelectQuery {
        where_clause: final_where,
        search: search_clause,
        order_by,
        // One extra row tells us whether a further page exists.
        limit: Some(page_size + 1),
        ..SelectQuery::new(state.vs.schema)
    };
    let mut rows = or_500!(render_list(state, acq, &select_q, &fields).await);

    let has_more = rows.len() as i64 > page_size;
    if has_more {
        // Drop the probe row in place instead of cloning the whole page.
        rows.truncate(page_size as usize);
    }
    let next_cursor = if has_more {
        rows.last().map(|last| {
            // NOTE(review): if the rendered row lacks the cursor field (e.g.
            // it is excluded from `fields` or renamed by a serializer) this
            // silently falls back to 0 — confirm this is acceptable.
            let position: i64 = last
                .get(cursor_schema.name)
                .and_then(|v| v.as_i64())
                .unwrap_or(0);
            encode_cursor(position)
        })
    } else {
        None
    };

    json_response(json!({
        "page_size": page_size,
        "next": next_cursor,
        "results": rows,
    }))
}
/// Encodes a cursor position as URL-safe, unpadded base64 of its decimal text.
fn encode_cursor(value: i64) -> String {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;
    let decimal = value.to_string();
    URL_SAFE_NO_PAD.encode(decimal.as_bytes())
}
/// Decodes a cursor produced by `encode_cursor`; returns `None` when the
/// token is not valid base64, not UTF-8, or not a decimal `i64`.
fn decode_cursor(token: &str) -> Option<i64> {
    use base64::engine::general_purpose::URL_SAFE_NO_PAD;
    use base64::Engine;
    let raw = match URL_SAFE_NO_PAD.decode(token.as_bytes()) {
        Ok(bytes) => bytes,
        Err(_) => return None,
    };
    let text = String::from_utf8(raw).ok()?;
    text.parse::<i64>().ok()
}
async fn handle_retrieve(
State(state): State<Arc<ViewSetState>>,
Path(pk_raw): Path<String>,
req: axum::extract::Request,
) -> Response {
let (_parts, _body, mut acq) =
match enter(&state, req, &state.vs.perms.retrieve, "retrieve").await {
Ok(x) => x,
Err(resp) => return resp,
};
let pk_field = match pk_field_or_500(&state) {
Ok(f) => f,
Err(resp) => return resp,
};
let pk_val = match parse_pk_or_400(pk_field, &pk_raw) {
Ok(v) => v,
Err(resp) => return resp,
};
let select_q = SelectQuery::by_pk(state.vs.schema, pk_field.column, pk_val);
let fields = state.effective_fields();
match render_single(&state, &mut acq, &select_q, &fields).await {
Ok(Some(row)) => json_response(row),
Ok(None) => json_error(StatusCode::NOT_FOUND, "not found"),
Err(e) => json_error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
}
}
/// `POST {prefix}` — creates one row, or several for a bulk body.
///
/// Primary-key and auto fields are never taken from the request. Responds
/// `400` when the body cannot be extracted and `500` when the model has no
/// primary key; everything else is decided by `create_one` / `create_many`.
async fn handle_create(
    State(state): State<Arc<ViewSetState>>,
    req: axum::extract::Request,
) -> Response {
    let entered = enter(&state, req, &state.vs.perms.create, "create").await;
    let (parts, body, mut acq) = match entered {
        Ok(triple) => triple,
        Err(resp) => return resp,
    };
    let create_body = or_400!(extract_create_body(parts, body).await);

    // Fields the client is not allowed to supply on create.
    let mut skip: Vec<&str> = Vec::new();
    for field in state.vs.schema.scalar_fields() {
        if field.primary_key || field.auto {
            skip.push(field.name);
        }
    }

    let pk_field = match pk_field_or_500(&state) {
        Ok(field) => field,
        Err(resp) => return resp,
    };
    match create_body {
        CreateBody::Bulk(rows) => create_many(&state, &mut acq, &rows, &skip, pk_field).await,
        CreateBody::Single(form, json) => {
            create_one(&state, &mut acq, &form, json.as_ref(), &skip, pk_field).await
        }
    }
}
/// Inserts one row and reads it back by its new primary key.
///
/// Returns the rendered row, or a `(status, message)` pair: `400` with the
/// database error when the insert fails, `500` when the freshly inserted row
/// cannot be fetched.
async fn insert_and_fetch_one(
    state: &Arc<ViewSetState>,
    acq: &mut AcquiredConn,
    columns: Vec<&'static str>,
    values: Vec<SqlValue>,
    pk_field: &'static crate::core::FieldSchema,
    fields: &[&'static crate::core::FieldSchema],
) -> Result<Value, (StatusCode, String)> {
    let query = InsertQuery {
        model: state.vs.schema,
        columns,
        values,
        returning: vec![pk_field.column],
        on_conflict: None,
    };
    let inserted = acq.insert_returning_pk(&query, pk_field).await;
    let pk_val = match inserted {
        Ok(pk) => pk,
        Err(e) => return Err((StatusCode::BAD_REQUEST, e.to_string())),
    };
    match fetch_by_pk(state, acq, pk_field, pk_val, fields).await {
        Some(row) => Ok(row),
        None => Err((
            StatusCode::INTERNAL_SERVER_ERROR,
            "created but could not retrieve".to_owned(),
        )),
    }
}
/// Creates a single row from `form` and responds `201` with the new row.
///
/// When a serializer is configured the JSON body is validated first, its
/// non-writable fields are added to `skip`, and renamed input keys are
/// mapped to their model names. Invalid values yield `400`.
async fn create_one(
    state: &Arc<ViewSetState>,
    acq: &mut AcquiredConn,
    form: &HashMap<String, String>,
    json: Option<&Value>,
    skip: &[&str],
    pk_field: &'static crate::core::FieldSchema,
) -> Response {
    let mut all_skip: Vec<&str> = skip.to_vec();
    match serializer_write_prep(state, json) {
        Ok(extra) => all_skip.extend(extra),
        Err(resp) => return resp,
    }

    let renamed = serializer_input_renamed_form(state, form);
    let effective_form = match &renamed {
        Some(mapped) => mapped,
        None => form,
    };
    let collected = or_400!(collect_values(state.vs.schema, effective_form, &all_skip));
    let (columns, values): (Vec<_>, Vec<_>) = collected.into_iter().unzip();

    let fields = state.effective_fields();
    let outcome = insert_and_fetch_one(state, acq, columns, values, pk_field, &fields).await;
    match outcome {
        Ok(row) => json_created(row),
        Err((code, msg)) => json_error(code, &msg),
    }
}
/// Creates several rows and responds `201` with the array of new rows.
///
/// Every entry is validated and converted before anything is inserted, so a
/// bad entry rejects the request without writing. Inserts then run one by
/// one; a failure part-way returns an error naming the entry.
/// NOTE(review): no transaction is visible here, so entries inserted before
/// such a failure appear to stay in place — confirm.
async fn create_many(
    state: &Arc<ViewSetState>,
    acq: &mut AcquiredConn,
    rows: &[(HashMap<String, String>, Option<Value>)],
    skip: &[&str],
    pk_field: &'static crate::core::FieldSchema,
) -> Response {
    if rows.is_empty() {
        return json_created(Value::Array(Vec::new()));
    }

    // Phase 1: validate and convert every entry.
    let mut prepared: Vec<(Vec<&'static str>, Vec<SqlValue>)> = Vec::with_capacity(rows.len());
    for (index, (entry, json)) in rows.iter().enumerate() {
        let mut all_skip: Vec<&str> = skip.to_vec();
        match serializer_write_prep(state, json.as_ref()) {
            Ok(extra) => all_skip.extend(extra),
            Err(resp) => return resp,
        }
        let renamed = serializer_input_renamed_form(state, entry);
        let effective = match &renamed {
            Some(mapped) => mapped,
            None => entry,
        };
        match collect_values(state.vs.schema, effective, &all_skip) {
            Ok(collected) => prepared.push(collected.into_iter().unzip()),
            Err(e) => {
                return json_error(StatusCode::BAD_REQUEST, &format!("bulk entry {index}: {e}"));
            }
        }
    }

    // Phase 2: insert in order, collecting the rendered rows.
    let fields = state.effective_fields();
    let mut created: Vec<Value> = Vec::with_capacity(prepared.len());
    for (index, (columns, values)) in prepared.into_iter().enumerate() {
        let outcome = insert_and_fetch_one(state, acq, columns, values, pk_field, &fields).await;
        match outcome {
            Ok(row) => created.push(row),
            Err((code, msg)) => {
                return json_error(code, &format!("bulk entry {index}: {msg}"));
            }
        }
    }
    json_created(Value::Array(created))
}
/// `PUT {prefix}/{pk}` — full update; delegates to `update_inner` with
/// `partial = false`, so every writable field is parsed from the body.
async fn handle_update(
State(state): State<Arc<ViewSetState>>,
Path(pk_raw): Path<String>,
req: axum::extract::Request,
) -> Response {
update_inner(state, pk_raw, req, false).await
}
/// `PATCH {prefix}/{pk}` — partial update; delegates to `update_inner` with
/// `partial = true`, so only fields present in the body are changed.
async fn handle_partial_update(
State(state): State<Arc<ViewSetState>>,
Path(pk_raw): Path<String>,
req: axum::extract::Request,
) -> Response {
update_inner(state, pk_raw, req, true).await
}
/// Shared implementation behind the full and partial update handlers.
///
/// Steps, in order: run `enter` with the `update` permission list, resolve
/// and parse the primary key, decode the body, build one [`Assignment`] per
/// writable scalar field, execute the `UPDATE`, then re-fetch the row and
/// return it.
///
/// * `partial` — when `true`, fields missing from the body are skipped and a
///   `FormError::Missing` from parsing is ignored; when `false`, every
///   non-key, non-auto, writable field is parsed even if absent.
///
/// Responds `400` for body/parse errors, an empty assignment list, or a
/// failed update statement, and `404` when the row cannot be fetched after
/// the update.
async fn update_inner(
    state: Arc<ViewSetState>,
    pk_raw: String,
    req: axum::extract::Request,
    partial: bool,
) -> Response {
    let (parts, body, mut acq) = match enter(&state, req, &state.vs.perms.update, "update").await {
        Err(resp) => return resp,
        Ok(entered) => entered,
    };
    let pk_field = match pk_field_or_500(&state) {
        Err(resp) => return resp,
        Ok(found) => found,
    };
    let pk_val = match parse_pk_or_400(pk_field, &pk_raw) {
        Err(resp) => return resp,
        Ok(parsed) => parsed,
    };

    let (form, json) = or_400!(extract_form_body(parts, body).await);
    let non_writable = match serializer_write_prep(&state, json.as_ref()) {
        Err(resp) => return resp,
        Ok(names) => names,
    };
    // Prefer the serializer-renamed form when one was produced.
    let renamed = serializer_input_renamed_form(&state, &form);
    let input = renamed.as_ref().unwrap_or(&form);

    let mut assignments: Vec<Assignment> = Vec::new();
    for field in state.vs.schema.scalar_fields() {
        // Keys and auto-managed columns are never assigned through this path.
        let server_managed = field.primary_key || field.auto;
        if server_managed || non_writable.contains(&field.name) {
            continue;
        }
        let raw = input.get(field.name).map(String::as_str);
        if partial && raw.is_none() {
            continue;
        }
        match parse_form_value(field, raw) {
            // The guarded arm must stay ahead of the generic error arm.
            Err(FormError::Missing { .. }) if partial => {}
            Err(e) => return json_error(StatusCode::BAD_REQUEST, &e.to_string()),
            Ok(v) => assignments.push(Assignment {
                column: field.column,
                value: v.into(),
            }),
        }
    }
    if assignments.is_empty() {
        return json_error(StatusCode::BAD_REQUEST, "no fields to update");
    }

    // `pk_val` is cloned here because it is needed again for the re-fetch.
    let by_pk = WhereExpr::Predicate(Filter {
        column: pk_field.column,
        op: Op::Eq,
        value: pk_val.clone(),
    });
    let query = UpdateQuery {
        model: state.vs.schema,
        set: assignments,
        where_clause: by_pk,
    };
    if let Err(e) = acq.update(&query).await {
        return json_error(StatusCode::BAD_REQUEST, &e.to_string());
    }

    let fields = state.effective_fields();
    let refreshed = fetch_by_pk(&state, &mut acq, pk_field, pk_val, &fields).await;
    match refreshed {
        None => json_error(StatusCode::NOT_FOUND, "not found after update"),
        Some(obj) => json_response(obj),
    }
}
/// Delete handler: removes the row whose primary key matches the path
/// segment.
///
/// Responds `204` via `no_content()` when at least one row was deleted,
/// `404` when the statement affected zero rows, and `500` when the delete
/// itself fails. Permission/entry, key lookup and key parsing failures
/// return whatever response their helper produced.
async fn handle_destroy(
    State(state): State<Arc<ViewSetState>>,
    Path(pk_raw): Path<String>,
    req: axum::extract::Request,
) -> Response {
    let (_parts, _body, mut acq) =
        match enter(&state, req, &state.vs.perms.destroy, "destroy").await {
            Err(resp) => return resp,
            Ok(entered) => entered,
        };
    let pk_field = match pk_field_or_500(&state) {
        Err(resp) => return resp,
        Ok(found) => found,
    };
    let pk_val = match parse_pk_or_400(pk_field, &pk_raw) {
        Err(resp) => return resp,
        Ok(parsed) => parsed,
    };

    let by_pk = WhereExpr::Predicate(Filter {
        column: pk_field.column,
        op: Op::Eq,
        value: pk_val,
    });
    let query = DeleteQuery {
        model: state.vs.schema,
        where_clause: by_pk,
    };

    match acq.delete(&query).await {
        Err(e) => json_error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()),
        Ok(0) => json_error(StatusCode::NOT_FOUND, "not found"),
        Ok(_) => no_content(),
    }
}
/// Loads and renders the single row identified by `pk_val`.
///
/// Returns `None` both when no row matches and when the query fails: the
/// error from [`render_single`] is discarded, so callers cannot tell the two
/// cases apart.
async fn fetch_by_pk(
    state: &ViewSetState,
    acq: &mut AcquiredConn,
    pk_field: &'static crate::core::FieldSchema,
    pk_val: SqlValue,
    fields: &[&'static crate::core::FieldSchema],
) -> Option<Value> {
    let by_pk = SelectQuery::by_pk(state.vs.schema, pk_field.column, pk_val);
    match render_single(state, acq, &by_pk, fields).await {
        Ok(row) => row,
        Err(_) => None,
    }
}
/// Runs `select_q` and renders every resulting row as JSON.
///
/// When the view set has a serializer bridge configured, rendering is
/// delegated to it; otherwise the connection renders the rows directly using
/// `fields`.
async fn render_list(
    state: &ViewSetState,
    acq: &mut AcquiredConn,
    select_q: &SelectQuery,
    fields: &[&'static crate::core::FieldSchema],
) -> Result<Vec<Value>, crate::sql::ExecError> {
    if let Some(bridge) = &state.vs.serializer {
        bridge.render_rows(acq, select_q).await
    } else {
        acq.select_rows_as_json(select_q, fields).await
    }
}
/// Runs `select_q` and renders at most one row as JSON.
///
/// When the view set has a serializer bridge configured, rendering is
/// delegated to it; otherwise the connection renders the row directly using
/// `fields`. `Ok(None)` means the query matched nothing.
async fn render_single(
    state: &ViewSetState,
    acq: &mut AcquiredConn,
    select_q: &SelectQuery,
    fields: &[&'static crate::core::FieldSchema],
) -> Result<Option<Value>, crate::sql::ExecError> {
    if let Some(bridge) = &state.vs.serializer {
        bridge.render_one(acq, select_q).await
    } else {
        acq.select_one_as_json(select_q, fields).await
    }
}
/// Decodes a request body that must describe exactly one object.
///
/// Thin wrapper over [`extract_create_body`]: a single payload is returned
/// as `(form, original_json)`, while a JSON array is rejected with an error
/// message. Decoding errors are propagated unchanged.
async fn extract_form_body(
    parts: axum::http::request::Parts,
    body: Body,
) -> Result<(HashMap<String, String>, Option<Value>), String> {
    let parsed = extract_create_body(parts, body).await?;
    match parsed {
        CreateBody::Bulk(_) => Err("expected a JSON object; got an array".into()),
        CreateBody::Single(form, json) => Ok((form, json)),
    }
}
/// Decoded request body as produced by `extract_create_body`.
pub(crate) enum CreateBody {
    /// One object: its values flattened to strings, plus the original JSON
    /// value when the body was JSON (`None` for form-encoded bodies).
    Single(HashMap<String, String>, Option<Value>),
    /// A JSON array of objects: one `(flattened form, original JSON)` pair
    /// per array entry, in input order.
    Bulk(Vec<(HashMap<String, String>, Option<Value>)>),
}
pub(crate) async fn extract_create_body(
parts: axum::http::request::Parts,
body: Body,
) -> Result<CreateBody, String> {
use axum::body::to_bytes;
let content_type = parts
.headers
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let bytes = to_bytes(body, 4 * 1024 * 1024)
.await
.map_err(|e| e.to_string())?;
if content_type.contains("application/json") {
let value: serde_json::Value = serde_json::from_slice(&bytes).map_err(|e| e.to_string())?;
if let Some(array) = value.as_array() {
let mut bulk: Vec<(HashMap<String, String>, Option<Value>)> =
Vec::with_capacity(array.len());
for (i, entry) in array.iter().enumerate() {
let obj = entry
.as_object()
.ok_or_else(|| format!("bulk entry {i} is not a JSON object"))?;
bulk.push((json_object_to_form(obj), Some(entry.clone())));
}
return Ok(CreateBody::Bulk(bulk));
}
let obj = value
.as_object()
.ok_or("expected a JSON object or array of objects")?;
Ok(CreateBody::Single(json_object_to_form(obj), Some(value)))
} else {
let form = serde_urlencoded::from_bytes::<HashMap<String, String>>(&bytes)
.map_err(|e| e.to_string())?;
Ok(CreateBody::Single(form, None))
}
}
/// Flattens a JSON object into a string-to-string form map.
///
/// Strings are copied as-is, numbers and booleans use their textual form,
/// `null` becomes the empty string, and nested arrays/objects are rendered
/// through `Value`'s `to_string()`.
fn json_object_to_form(obj: &serde_json::Map<String, Value>) -> HashMap<String, String> {
    obj.iter()
        .map(|(key, value)| {
            let text = match value {
                Value::Null => String::new(),
                Value::Bool(b) => b.to_string(),
                Value::Number(n) => n.to_string(),
                Value::String(s) => s.clone(),
                other => other.to_string(),
            };
            (key.clone(), text)
        })
        .collect()
}
/// Round-trip and rejection tests for the cursor token helpers.
#[cfg(test)]
mod cursor_tests {
    use super::{decode_cursor, encode_cursor};

    /// Encodes `value`, decodes the token again and checks it survived.
    fn assert_roundtrip(value: i64) {
        let token = encode_cursor(value);
        assert_eq!(decode_cursor(&token), Some(value));
    }

    #[test]
    fn cursor_roundtrip_positive() {
        assert_roundtrip(12345);
    }

    #[test]
    fn cursor_roundtrip_zero() {
        assert_roundtrip(0);
    }

    #[test]
    fn cursor_roundtrip_max() {
        assert_roundtrip(i64::MAX);
    }

    #[test]
    fn cursor_decode_invalid_base64_returns_none() {
        let decoded = decode_cursor("not!valid!base64@@");
        assert_eq!(decoded, None);
    }

    #[test]
    fn cursor_decode_non_numeric_payload_returns_none() {
        use base64::Engine;
        // Well-formed base64 whose payload is not a number.
        let token = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("not_a_number");
        let decoded = decode_cursor(&token);
        assert_eq!(decoded, None);
    }
}
/// Smoke tests: building a tenant router from various builder chains must
/// compile and must not panic.
#[cfg(all(test, feature = "tenancy"))]
mod tenant_router_tests {
    use super::*;
    use crate::core::Model as _;

    #[test]
    fn tenant_router_builds_for_a_basic_model() {
        let read_only = ViewSet::for_model(crate::tenancy::auth::User::SCHEMA).read_only();
        let _router = read_only.tenant_router("/api/users");
    }

    #[test]
    fn tenant_router_carries_over_full_builder_chain() {
        let configured = ViewSet::for_model(crate::tenancy::auth::User::SCHEMA)
            .filter_fields(&["username"])
            .search_fields(&["username"])
            .ordering(&[("id", true)])
            .page_size(50);
        let _router = configured.tenant_router("/api/users");
    }

    // NOTE(review): despite its name this test only builds the tenant router
    // from a clone; it makes no assertion about pool sources.
    #[test]
    fn router_and_tenant_router_set_distinct_pool_sources() {
        let read_only = ViewSet::for_model(crate::tenancy::auth::User::SCHEMA).read_only();
        let _router = read_only.clone().tenant_router("/api/users");
    }
}
/// Unit tests for `build_lookup_filter`: which operator and value each
/// lookup suffix produces, and which inputs are rejected.
#[cfg(test)]
mod lookup_tests {
    use super::*;
    use crate::core::{FieldSchema, FieldType};

    /// Non-nullable integer field used by the numeric lookup tests.
    fn int_field() -> &'static FieldSchema {
        &FieldSchema {
            name: "author_id",
            column: "author_id",
            ty: FieldType::I64,
            nullable: false,
            primary_key: false,
            relation: None,
            max_length: None,
            min: None,
            max: None,
            default: None,
            auto: false,
            unique: false,
            generated_as: None,
            help_text: None,
            choices: None,
            db_comment: None,
            verbose_name: None,
            editable: true,
            blank: false,
            case_insensitive: false,
            fk_on_delete: None,
            validators: &[],
        }
    }

    /// Nullable string field used by the pattern and null lookup tests.
    fn string_field() -> &'static FieldSchema {
        &FieldSchema {
            name: "title",
            column: "title",
            ty: FieldType::String,
            nullable: true,
            primary_key: false,
            relation: None,
            max_length: None,
            min: None,
            max: None,
            default: None,
            auto: false,
            unique: false,
            generated_as: None,
            help_text: None,
            choices: None,
            db_comment: None,
            verbose_name: None,
            editable: true,
            blank: false,
            case_insensitive: false,
            fk_on_delete: None,
            validators: &[],
        }
    }

    /// Builds the lookup filter and unwraps it down to its predicate,
    /// panicking when no filter is produced or it is not a plain predicate.
    fn filter_for(
        field: &'static FieldSchema,
        lookup: Option<&'static str>,
        raw: &'static str,
    ) -> Filter {
        let expr = build_lookup_filter(field, lookup, raw).unwrap();
        if let WhereExpr::Predicate(filter) = expr {
            filter
        } else {
            panic!("expected Predicate")
        }
    }

    /// Number of items in a list value; panics for any other variant.
    fn list_len(value: SqlValue) -> usize {
        match value {
            SqlValue::List(items) => items.len(),
            _ => panic!("expected List"),
        }
    }

    #[test]
    fn no_lookup_means_eq() {
        let filter = filter_for(int_field(), None, "42");
        assert_eq!(filter.op, Op::Eq);
        assert!(matches!(filter.value, SqlValue::I64(42)));
    }

    #[test]
    fn explicit_exact_means_eq() {
        let filter = filter_for(int_field(), Some("exact"), "42");
        assert_eq!(filter.op, Op::Eq);
    }

    #[test]
    fn comparison_lookups() {
        let cases = [
            ("gt", Op::Gt),
            ("gte", Op::Gte),
            ("lt", Op::Lt),
            ("lte", Op::Lte),
            ("ne", Op::Ne),
        ];
        for (lk, expected) in cases {
            let filter = filter_for(int_field(), Some(lk), "10");
            assert_eq!(filter.op, expected, "lookup {lk}");
        }
    }

    #[test]
    fn in_lookup_parses_csv() {
        let filter = filter_for(int_field(), Some("in"), "1,2,3");
        assert_eq!(filter.op, Op::In);
        assert_eq!(list_len(filter.value), 3);
    }

    #[test]
    fn not_in_lookup_parses_csv() {
        let filter = filter_for(int_field(), Some("not_in"), "1,2");
        assert_eq!(filter.op, Op::NotIn);
    }

    #[test]
    fn in_lookup_drops_empty_entries() {
        let filter = filter_for(int_field(), Some("in"), "1,,2,");
        assert_eq!(list_len(filter.value), 2);
    }

    #[test]
    fn contains_wraps_with_percents_and_uses_like() {
        let filter = filter_for(string_field(), Some("contains"), "hello");
        assert_eq!(filter.op, Op::Like);
        assert!(matches!(filter.value, SqlValue::String(ref s) if s == "%hello%"));
    }

    #[test]
    fn icontains_uses_ilike() {
        let filter = filter_for(string_field(), Some("icontains"), "hi");
        assert_eq!(filter.op, Op::ILike);
        assert!(matches!(filter.value, SqlValue::String(ref s) if s == "%hi%"));
    }

    #[test]
    fn startswith_only_trailing_percent() {
        let filter = filter_for(string_field(), Some("startswith"), "pre");
        assert!(matches!(filter.value, SqlValue::String(ref s) if s == "pre%"));
    }

    #[test]
    fn endswith_only_leading_percent() {
        let filter = filter_for(string_field(), Some("endswith"), "fix");
        assert!(matches!(filter.value, SqlValue::String(ref s) if s == "%fix"));
    }

    #[test]
    fn isnull_true() {
        let filter = filter_for(string_field(), Some("isnull"), "true");
        assert_eq!(filter.op, Op::IsNull);
        assert!(matches!(filter.value, SqlValue::Bool(true)));
    }

    #[test]
    fn isnull_false() {
        let filter = filter_for(string_field(), Some("isnull"), "false");
        assert!(matches!(filter.value, SqlValue::Bool(false)));
    }

    #[test]
    fn unknown_lookup_returns_none() {
        let built = build_lookup_filter(int_field(), Some("frobulate"), "x");
        assert!(built.is_none());
    }

    #[test]
    fn parse_failure_returns_none() {
        let built = build_lookup_filter(int_field(), Some("gt"), "not-a-number");
        assert!(built.is_none());
    }
}
/// Checks that `permissions_for_model` derives the per-action permission
/// codenames from the model's table name.
#[cfg(all(test, feature = "tenancy"))]
mod typed_perms_tests {
    use super::*;
    use crate::sql::Auto;

    /// Minimal model whose table name drives the expected codenames.
    #[derive(crate::Model)]
    #[rustango(table = "vs_typed_perm_post")]
    #[allow(dead_code)]
    pub struct PermPost {
        #[rustango(primary_key)]
        pub id: Auto<i64>,
        #[rustango(max_length = 200)]
        pub title: String,
    }

    #[test]
    fn permissions_for_model_fills_all_four_crud_codenames() {
        use crate::core::Model;
        let schema = <PermPost as Model>::SCHEMA;
        let vs = ViewSet::for_model(schema).permissions_for_model::<PermPost>();

        // Both read actions share the `view` codename; each write action has
        // its own.
        let expectations = [
            (&vs.perms.list, "vs_typed_perm_post.view"),
            (&vs.perms.retrieve, "vs_typed_perm_post.view"),
            (&vs.perms.create, "vs_typed_perm_post.add"),
            (&vs.perms.update, "vs_typed_perm_post.change"),
            (&vs.perms.destroy, "vs_typed_perm_post.delete"),
        ];
        for (actual, codename) in expectations {
            assert_eq!(*actual, vec![codename]);
        }
    }
}