use audb::model::project::{AuthRequirement, Endpoint};
use audb::model::query::Query;
use audb::schema::Schema;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use std::collections::HashMap;
/// Generates axum handler functions, a router (`create_router`), and a shared
/// prelude from project `Endpoint` definitions.
pub struct EndpointGenerator {
    /// Schemas keyed by name; consulted when generating CRUD handlers.
    schemas: HashMap<String, Schema>,
    /// Queries keyed by name; their parameters become handler inputs.
    queries: HashMap<String, Query>,
    /// When `true`, endpoints whose `auth` is not `AuthRequirement::None` get an
    /// `Auth` extractor and a user check in the generated handler.
    pub with_auth: bool,
}
impl EndpointGenerator {
pub fn new() -> Self {
Self {
with_auth: true,
schemas: HashMap::new(),
queries: HashMap::new(),
}
}
pub fn with_schemas(schemas: HashMap<String, Schema>) -> Self {
Self {
with_auth: true,
schemas,
queries: HashMap::new(),
}
}
pub fn with_schemas_and_queries(
schemas: HashMap<String, Schema>,
queries: HashMap<String, Query>,
) -> Self {
Self {
with_auth: true,
schemas,
queries,
}
}
/// Emits the items a generated endpoints module starts with: a module-wide
/// `dead_code` allowance, the imports generated handlers rely on, and the
/// `AppState` type that handlers receive through axum's `State` extractor.
///
/// The output begins with an inner attribute (`#![allow(dead_code)]`), so it
/// has to be placed at the very start of the generated file/module.
pub fn generate_prelude(&self) -> TokenStream {
    let imports = quote! {
        use crate::generated::queries;
        use crate::generated::schemas::*;
        use audb_runtime::Database;
        use std::sync::Arc;
        use axum::response::IntoResponse;
    };
    let app_state = quote! {
        #[derive(Clone)]
        pub struct AppState {
            pub db: Arc<Database>,
        }
    };
    quote! {
        #![allow(dead_code)]
        #imports
        #app_state
    }
}
pub fn generate(&self, endpoint: &Endpoint) -> TokenStream {
let handler_name = self.generate_handler_name(endpoint);
let method = endpoint.method.to_uppercase();
let doc_comment = if let Some(ref doc) = endpoint.doc_comment {
let doc_lines: Vec<_> = doc.lines().map(|line| quote! { #[doc = #line] }).collect();
quote! { #(#doc_lines)* }
} else {
quote! {}
};
let handler_fn = self.generate_handler_function(endpoint);
let path = &endpoint.path;
let _route = match method.as_str() {
"GET" => quote! { .route(#path, axum::routing::get(#handler_name)) },
"POST" => quote! { .route(#path, axum::routing::post(#handler_name)) },
"PUT" => quote! { .route(#path, axum::routing::put(#handler_name)) },
"PATCH" => quote! { .route(#path, axum::routing::patch(#handler_name)) },
"DELETE" => quote! { .route(#path, axum::routing::delete(#handler_name)) },
_ => quote! { .route(#path, axum::routing::get(#handler_name)) },
};
quote! {
#doc_comment
#handler_fn
}
}
/// Derives the handler function name for an endpoint, e.g.
/// `GET /api/users/:id` -> `handle_get_api_users_id`.
///
/// Path-parameter sigils (`:`, `*`, `{`, `}`) are dropped. Every other
/// character that cannot appear in a Rust identifier (`/`, `-`, `.`, ...) is
/// replaced by `_`. Without this, paths such as `/api/user-profiles` make
/// `format_ident!` panic. Leading underscores of the path part are trimmed.
fn generate_handler_name(&self, endpoint: &Endpoint) -> syn::Ident {
    let sanitize = |raw: &str| -> String {
        raw.chars()
            .filter(|c| !matches!(c, ':' | '*' | '{' | '}'))
            .map(|c| {
                if c.is_ascii_alphanumeric() {
                    c.to_ascii_lowercase()
                } else {
                    '_'
                }
            })
            .collect()
    };
    let path_clean = sanitize(&endpoint.path);
    let path_clean = path_clean.trim_start_matches('_');
    let method_lower = sanitize(&endpoint.method);
    format_ident!("handle_{}_{}", method_lower, path_clean)
}
fn generate_handler_function(&self, endpoint: &Endpoint) -> TokenStream {
let handler_name = self.generate_handler_name(endpoint);
if let Some(ref schema_name) = endpoint.crud_schema {
return self.generate_crud_handler(endpoint, schema_name, &handler_name);
}
let path_params = self.extract_path_params(&endpoint.path);
let query_params = if let Some(ref query_name) = endpoint.query {
if let Some(query) = self.queries.get(query_name) {
query
.params
.iter()
.filter(|p| !path_params.contains(&p.name))
.cloned()
.collect::<Vec<_>>()
} else {
Vec::new()
}
} else {
Vec::new()
};
let (query_params_struct, query_params_type) = if !query_params.is_empty() {
let struct_name = format_ident!(
"{}QueryParams",
handler_name
.to_string()
.split('_')
.map(|s| {
let mut c = s.chars();
match c.next() {
None => String::new(),
Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
}
})
.collect::<String>()
);
let fields = query_params.iter().map(|p| {
let field_name = format_ident!("{}", p.name);
let field_type = self.convert_type(&p.param_type);
quote! { pub #field_name: #field_type }
});
let struct_def = quote! {
#[derive(Debug, serde::Deserialize)]
struct #struct_name {
#(#fields),*
}
};
(Some(struct_def), Some(struct_name))
} else {
(None, None)
};
let extractors =
self.generate_extractors(endpoint, &path_params, query_params_type.as_ref());
let auth_check = if self.with_auth {
match endpoint.auth {
AuthRequirement::Required => {
quote! {
let user = auth.require_user()?;
}
}
AuthRequirement::Optional => {
quote! {
let user = auth.optional_user();
}
}
AuthRequirement::None => quote! {},
}
} else {
quote! {}
};
let query_call = if let Some(ref query_name) = endpoint.query {
let query_fn = format_ident!("{}", query_name);
let mut path_param_tokens: Vec<TokenStream> = Vec::new();
let mut query_param_tokens: Vec<TokenStream> = Vec::new();
for p in &path_params {
let param_name = format_ident!("{}", p);
path_param_tokens.push(quote! { #param_name });
}
for p in &query_params {
let field_name = format_ident!("{}", p.name);
query_param_tokens.push(quote! { query_params.#field_name });
}
let all_params = path_param_tokens.into_iter().chain(query_param_tokens);
if path_params.is_empty() && query_params.is_empty() {
quote! {
let result = queries::#query_fn(&state.db).await?;
}
} else {
quote! {
let result = queries::#query_fn(&state.db, #(#all_params),*).await?;
}
}
} else {
quote! {
let result = String::from("Not implemented");
}
};
let response = quote! {
Ok(axum::Json(result))
};
quote! {
#query_params_struct
async fn #handler_name(
#extractors
) -> impl axum::response::IntoResponse {
#auth_check
let result: Result<_, audb_runtime::QueryError> = async {
#query_call
#response
}.await;
match result {
Ok(json) => json.into_response(),
Err(e) => {
let status = axum::http::StatusCode::INTERNAL_SERVER_ERROR;
let body = format!("Database error: {}", e);
(status, body).into_response()
}
}
}
}
}
/// Maps an audb schema type to the Rust type used in generated code.
///
/// `Named`/`Custom` types are emitted as bare identifiers, so they must be in
/// scope where the generated code lives. Enum types are represented as
/// `String`. `Option`/`Vec` wrap the recursively converted inner type.
fn convert_type(&self, typ: &audb::schema::Type) -> TokenStream {
    use audb::schema::Type;
    match typ {
        Type::Option(inner) => {
            let converted = self.convert_type(inner);
            quote! { Option<#converted> }
        }
        Type::Vec(inner) => {
            let converted = self.convert_type(inner);
            quote! { Vec<#converted> }
        }
        Type::Named(name) | Type::Custom(name) => {
            let ident = format_ident!("{}", name);
            quote! { #ident }
        }
        Type::String | Type::Enum(_) => quote! { String },
        Type::Integer => quote! { i64 },
        Type::Float => quote! { f64 },
        Type::Boolean => quote! { bool },
        Type::EntityId => quote! { uuid::Uuid },
        Type::Timestamp => quote! { chrono::DateTime<chrono::Utc> },
        Type::Vector => quote! { Vec<f32> },
        Type::JsonValue => quote! { serde_json::Value },
        Type::Unit => quote! { () },
    }
}
/// Collects the names of `:name` path parameters in the order they appear,
/// e.g. `/api/users/:user_id/posts/:post_id` -> `["user_id", "post_id"]`.
///
/// All leading `:` characters are stripped from a matching segment.
fn extract_path_params(&self, path: &str) -> Vec<String> {
    let mut names = Vec::new();
    for segment in path.split('/') {
        if segment.starts_with(':') {
            names.push(segment.trim_start_matches(':').to_string());
        }
    }
    names
}
/// Builds the comma-separated extractor parameter list of a generated handler.
///
/// The extractors appear in this order:
/// 1. `State<AppState>`.
/// 2. `Auth`, only when `with_auth` is set and `endpoint.auth` is not `None`.
/// 3. `Path` for the path params: a single binding for one param, a tuple for
///    several.
/// 4. `Query<query_params_type>` when a query-params struct exists.
/// 5. `Json<RequestBody>` for POST/PUT/PATCH, placed last.
///
/// Path params named `id` or ending in `_id` are typed `uuid::Uuid`; all
/// others are typed `String`.
fn generate_extractors(
    &self,
    endpoint: &Endpoint,
    path_params: &[String],
    query_params_type: Option<&syn::Ident>,
) -> TokenStream {
    let path_param_type = |name: &str| {
        if name == "id" || name.ends_with("_id") {
            quote! { uuid::Uuid }
        } else {
            quote! { String }
        }
    };
    let mut extractors = vec![quote! {
        axum::extract::State(state): axum::extract::State<AppState>
    }];
    if self.with_auth && endpoint.auth != AuthRequirement::None {
        extractors.push(quote! {
            auth: Auth
        });
    }
    match path_params {
        [] => {}
        [single] => {
            let binding = format_ident!("{}", single);
            let ty = path_param_type(single.as_str());
            extractors.push(quote! {
                axum::extract::Path(#binding): axum::extract::Path<#ty>
            });
        }
        several => {
            let bindings = several.iter().map(|p| format_ident!("{}", p));
            let types = several.iter().map(|p| path_param_type(p.as_str()));
            extractors.push(quote! {
                axum::extract::Path((#(#bindings),*)): axum::extract::Path<(#(#types),*)>
            });
        }
    }
    if let Some(struct_ty) = query_params_type {
        extractors.push(quote! {
            axum::extract::Query(query_params): axum::extract::Query<#struct_ty>
        });
    }
    let method = endpoint.method.to_uppercase();
    let takes_body = matches!(method.as_str(), "POST" | "PUT" | "PATCH");
    if takes_body {
        extractors.push(quote! {
            axum::Json(body): axum::Json<RequestBody>
        });
    }
    quote! {
        #(#extractors),*
    }
}
pub fn generate_router(&self, endpoints: &[&Endpoint]) -> TokenStream {
let routes: Vec<_> = endpoints
.iter()
.map(|endpoint| {
let handler_name = self.generate_handler_name(endpoint);
let path = &endpoint.path;
let method = endpoint.method.to_uppercase();
match method.as_str() {
"GET" => quote! { .route(#path, axum::routing::get(#handler_name)) },
"POST" => quote! { .route(#path, axum::routing::post(#handler_name)) },
"PUT" => quote! { .route(#path, axum::routing::put(#handler_name)) },
"PATCH" => quote! { .route(#path, axum::routing::patch(#handler_name)) },
"DELETE" => quote! { .route(#path, axum::routing::delete(#handler_name)) },
_ => quote! { .route(#path, axum::routing::get(#handler_name)) },
}
})
.collect();
quote! {
pub fn create_router(state: AppState) -> axum::Router {
axum::Router::new()
#(#routes)*
.with_state(state)
}
}
}
/// Generates the complete endpoints module. It emits the prelude first, then
/// one handler per endpoint in input order, and finally `create_router`.
pub fn generate_all(&self, endpoints: &[&Endpoint]) -> TokenStream {
    let mut module = self.generate_prelude();
    for endpoint in endpoints {
        module.extend(self.generate(endpoint));
    }
    module.extend(self.generate_router(endpoints));
    module
}
fn generate_field_extractions(&self, schema: &Schema) -> TokenStream {
use audb::schema::Type;
let extractions: Vec<_> = schema
.fields
.iter()
.filter(|f| f.name != "id" && f.embedding_config.is_none())
.map(|field| {
let field_name = format_ident!("{}", field.name);
let field_str = &field.name;
match &field.field_type {
Type::String => quote! {
let #field_name = body[#field_str]
.as_str()
.ok_or_else(|| audb_runtime::QueryError::serialization(
format!("Missing or invalid field: {}", #field_str)
))?
.to_string();
},
Type::Integer => quote! {
let #field_name = body[#field_str]
.as_i64()
.ok_or_else(|| audb_runtime::QueryError::serialization(
format!("Missing or invalid field: {}", #field_str)
))?;
},
Type::Boolean => quote! {
let #field_name = body[#field_str]
.as_bool()
.ok_or_else(|| audb_runtime::QueryError::serialization(
format!("Missing or invalid field: {}", #field_str)
))?;
},
Type::Timestamp => quote! {
let #field_name = if let Some(ts_str) = body[#field_str].as_str() {
chrono::DateTime::parse_from_rfc3339(ts_str)
.map_err(|e| audb_runtime::QueryError::serialization(
format!("Invalid timestamp for {}: {}", #field_str, e)
))?
.with_timezone(&chrono::Utc)
} else {
chrono::Utc::now()
};
},
Type::Float => quote! {
let #field_name = body[#field_str]
.as_f64()
.ok_or_else(|| audb_runtime::QueryError::serialization(
format!("Missing or invalid field: {}", #field_str)
))?;
},
_ => quote! {
let #field_name = serde_json::from_value(body[#field_str].clone())
.map_err(|e| audb_runtime::QueryError::serialization(
format!("Invalid field {}: {}", #field_str, e)
))?;
},
}
})
.collect();
quote! {
#(#extractions)*
}
}
/// Generates the axum handler for an endpoint bound to a CRUD schema. The
/// operation is chosen from the HTTP method and path:
///
/// - `POST .../search`: JSON `{ "query": <string>, "limit": <usize, optional,
///   default 10> }`, handled by `<Schema>::find_similar`.
/// - `POST`: uses `<Schema>::create_with_embedding` when the registered schema
///   has an embedding field. Fields are then read from the JSON body (see
///   `generate_field_extractions`). Otherwise it uses
///   `<Schema>::create_from_json`. Either way it responds `201 Created`.
/// - `GET` on a path containing `:id`: `<Schema>::get`. Any other `GET`:
///   `<Schema>::list_all`.
/// - `PUT`: `body.update(&state.db)` on the deserialized entity; `204`.
/// - `DELETE`: `<Schema>::delete` by path id; `204`.
/// - Anything else: `501 Not Implemented`.
///
/// `QueryError`s become `500` responses with a `Database error: ...` body.
/// The handler fn is always emitted before any helper type it uses. A doc
/// comment placed in front of the returned tokens therefore documents the
/// handler.
fn generate_crud_handler(
    &self,
    endpoint: &Endpoint,
    schema_name: &str,
    handler_name: &syn::Ident,
) -> TokenStream {
    let schema_type = format_ident!("{}", schema_name);
    // Shared by every handler below: map a `QueryError` to a 500 response.
    let db_error_arm = quote! {
        Err(e) => {
            let status = axum::http::StatusCode::INTERNAL_SERVER_ERROR;
            let body = format!("Database error: {}", e);
            (status, body).into_response()
        }
    };
    let method = endpoint.method.to_uppercase();
    match method.as_str() {
        "POST" if endpoint.path.ends_with("/search") => {
            // The request type is named after the handler so several search
            // endpoints can coexist in one generated module. A fixed
            // `SearchRequest` name would be defined more than once.
            let request_type = {
                let pascal: String = handler_name
                    .to_string()
                    .split('_')
                    .map(|word| {
                        let mut chars = word.chars();
                        match chars.next() {
                            Some(first) => first.to_uppercase().chain(chars).collect::<String>(),
                            None => String::new(),
                        }
                    })
                    .collect();
                format_ident!("{}SearchRequest", pascal)
            };
            quote! {
                async fn #handler_name(
                    axum::extract::State(state): axum::extract::State<AppState>,
                    axum::Json(req): axum::Json<#request_type>,
                ) -> impl axum::response::IntoResponse {
                    let result: Result<_, audb_runtime::QueryError> = async {
                        let limit = req.limit.unwrap_or(10);
                        let results = #schema_type::find_similar(&state.db, &req.query, limit)?;
                        Ok(axum::Json(results))
                    }.await;
                    match result {
                        Ok(json) => json.into_response(),
                        #db_error_arm
                    }
                }
                #[derive(serde::Deserialize)]
                struct #request_type {
                    query: String,
                    #[serde(default)]
                    limit: Option<usize>,
                }
            }
        }
        "POST" => {
            // Only a registered schema with at least one embedding field needs
            // per-field extraction. Everything else is created from raw JSON.
            let embedding_schema = self
                .schemas
                .get(schema_name)
                .filter(|schema| schema.fields.iter().any(|f| f.embedding_config.is_some()));
            let create = match embedding_schema {
                Some(schema) => {
                    let field_extractions = self.generate_field_extractions(schema);
                    let field_names = schema
                        .fields
                        .iter()
                        .filter(|f| f.name != "id" && f.embedding_config.is_none())
                        .map(|f| format_ident!("{}", f.name));
                    quote! {
                        #field_extractions
                        let created = #schema_type::create_with_embedding(&state.db, #(#field_names),*)?;
                    }
                }
                None => quote! {
                    let created = #schema_type::create_from_json(&state.db, body)?;
                },
            };
            quote! {
                async fn #handler_name(
                    axum::extract::State(state): axum::extract::State<AppState>,
                    axum::Json(body): axum::Json<serde_json::Value>,
                ) -> impl axum::response::IntoResponse {
                    let result: Result<_, audb_runtime::QueryError> = async {
                        #create
                        Ok(axum::Json(created))
                    }.await;
                    match result {
                        Ok(json) => (axum::http::StatusCode::CREATED, json).into_response(),
                        #db_error_arm
                    }
                }
            }
        }
        "GET" if endpoint.path.contains(":id") => quote! {
            async fn #handler_name(
                axum::extract::State(state): axum::extract::State<AppState>,
                axum::extract::Path(id): axum::extract::Path<uuid::Uuid>,
            ) -> impl axum::response::IntoResponse {
                let result: Result<_, audb_runtime::QueryError> = async {
                    let entity = #schema_type::get(&state.db, id)?;
                    Ok(axum::Json(entity))
                }.await;
                match result {
                    Ok(json) => json.into_response(),
                    #db_error_arm
                }
            }
        },
        "GET" => quote! {
            async fn #handler_name(
                axum::extract::State(state): axum::extract::State<AppState>,
            ) -> impl axum::response::IntoResponse {
                let result: Result<_, audb_runtime::QueryError> = async {
                    let entities = #schema_type::list_all(&state.db)?;
                    Ok(axum::Json(entities))
                }.await;
                match result {
                    Ok(json) => json.into_response(),
                    #db_error_arm
                }
            }
        },
        // NOTE(review): the path id is extracted but ignored. The entity that
        // gets updated is whatever `body.update` targets, presumably the id in
        // the body. Confirm the two cannot diverge.
        "PUT" => quote! {
            async fn #handler_name(
                axum::extract::State(state): axum::extract::State<AppState>,
                axum::extract::Path(_id): axum::extract::Path<uuid::Uuid>,
                axum::Json(body): axum::Json<#schema_type>,
            ) -> impl axum::response::IntoResponse {
                let result: Result<_, audb_runtime::QueryError> = async {
                    body.update(&state.db)?;
                    Ok(axum::http::StatusCode::NO_CONTENT)
                }.await;
                match result {
                    Ok(status) => status.into_response(),
                    #db_error_arm
                }
            }
        },
        "DELETE" => quote! {
            async fn #handler_name(
                axum::extract::State(state): axum::extract::State<AppState>,
                axum::extract::Path(id): axum::extract::Path<uuid::Uuid>,
            ) -> impl axum::response::IntoResponse {
                let result: Result<_, audb_runtime::QueryError> = async {
                    #schema_type::delete(&state.db, id)?;
                    Ok(axum::http::StatusCode::NO_CONTENT)
                }.await;
                match result {
                    Ok(status) => status.into_response(),
                    #db_error_arm
                }
            }
        },
        _ => quote! {
            async fn #handler_name() -> impl axum::response::IntoResponse {
                (axum::http::StatusCode::NOT_IMPLEMENTED, "Not implemented")
            }
        },
    }
}
}
impl Default for EndpointGenerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds an endpoint with the given method, path, query and auth. Every
    /// other field is empty: no CRUD schema, roles, rate limit, validation
    /// rules or doc comment. `crud_schema` must be set explicitly because the
    /// generator reads it.
    fn endpoint(method: &str, path: &str, query: Option<&str>, auth: AuthRequirement) -> Endpoint {
        Endpoint {
            method: method.to_string(),
            path: path.to_string(),
            query: query.map(str::to_string),
            crud_schema: None,
            auth,
            roles: Vec::new(),
            rate_limit: None,
            validation: HashMap::new(),
            doc_comment: None,
        }
    }

    #[test]
    fn test_generate_simple_endpoint() {
        let ep = endpoint("GET", "/api/users", Some("list_users"), AuthRequirement::None);
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("async fn handle_get_api_users"));
        assert!(code_str.contains("list_users"));
    }

    #[test]
    fn test_generate_endpoint_with_path_params() {
        let ep = endpoint("GET", "/api/users/:id", Some("get_user"), AuthRequirement::None);
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("handle_get_api_users_id"));
        assert!(code_str.contains("Path"));
        assert!(code_str.contains("id"));
    }

    #[test]
    fn test_generate_post_endpoint() {
        let ep = endpoint("POST", "/api/users", Some("create_user"), AuthRequirement::Required);
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("handle_post_api_users"));
        assert!(code_str.contains("Json"));
    }

    #[test]
    fn test_auth_required() {
        let ep = endpoint("GET", "/api/profile", Some("get_profile"), AuthRequirement::Required);
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("auth : Auth"));
        assert!(code_str.contains("require_user"));
    }

    #[test]
    fn test_auth_optional() {
        let ep = endpoint("GET", "/api/content", Some("get_content"), AuthRequirement::Optional);
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("optional_user"));
    }

    #[test]
    fn test_generate_handler_name() {
        let generator = EndpointGenerator::new();
        let ep1 = endpoint("GET", "/api/users/:id", None, AuthRequirement::None);
        assert_eq!(
            generator.generate_handler_name(&ep1).to_string(),
            "handle_get_api_users_id"
        );
        let ep2 = endpoint("POST", "/api/auth/login", None, AuthRequirement::None);
        assert_eq!(
            generator.generate_handler_name(&ep2).to_string(),
            "handle_post_api_auth_login"
        );
    }

    #[test]
    fn test_extract_path_params() {
        let generator = EndpointGenerator::new();
        assert_eq!(generator.extract_path_params("/api/users/:id"), vec!["id"]);
        assert_eq!(
            generator.extract_path_params("/api/users/:user_id/posts/:post_id"),
            vec!["user_id", "post_id"]
        );
        assert!(generator.extract_path_params("/api/users").is_empty());
    }

    #[test]
    fn test_generate_router() {
        let ep1 = endpoint("GET", "/api/users", Some("list_users"), AuthRequirement::None);
        let ep2 = endpoint("POST", "/api/users", Some("create_user"), AuthRequirement::Required);
        let code_str = EndpointGenerator::new()
            .generate_router(&[&ep1, &ep2])
            .to_string();
        assert!(code_str.contains("create_router"));
        assert!(code_str.contains("Router :: new"));
        assert!(code_str.contains("with_state"));
    }

    #[test]
    fn test_with_doc_comment() {
        let mut ep = endpoint("GET", "/api/users", Some("list_users"), AuthRequirement::None);
        ep.doc_comment = Some("List all users".to_string());
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("List all users"));
    }

    #[test]
    fn test_without_auth() {
        let ep = endpoint("GET", "/api/public", Some("get_public"), AuthRequirement::Required);
        let mut generator = EndpointGenerator::new();
        generator.with_auth = false;
        let code_str = generator.generate(&ep).to_string();
        assert!(!code_str.contains("auth : Auth"));
    }

    #[test]
    fn test_crud_list_endpoint() {
        let mut ep = endpoint("GET", "/api/users", None, AuthRequirement::None);
        ep.crud_schema = Some("User".to_string());
        let code_str = EndpointGenerator::new().generate(&ep).to_string();
        assert!(code_str.contains("async fn handle_get_api_users"));
        assert!(code_str.contains("User :: list_all"));
    }
}