use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::{Expr, Ident, LitStr, Token};
use std::path::PathBuf;
use crate::spec::{self, SpecIndex};
use crate::types;
/// Locates the OpenAPI spec file the macros should read.
///
/// Resolution order:
/// 1. `relative` itself, if it is an absolute path that exists;
/// 2. the `OPENAPI_SPEC_PATH` environment variable, if set and non-empty — used as-is when
///    absolute, otherwise joined onto `CARGO_MANIFEST_DIR` (existence is not checked here);
/// 3. `relative` joined onto `CARGO_MANIFEST_DIR` or one of its ancestor directories,
///    nearest first, returning the first candidate that exists;
/// 4. otherwise `CARGO_MANIFEST_DIR/relative`, so the caller's load error names that path.
///
/// `CARGO_MANIFEST_DIR` falls back to `.` when it is not set.
fn resolve_spec_path(relative: &str) -> PathBuf {
    let requested = PathBuf::from(relative);
    if requested.is_absolute() && requested.exists() {
        return requested;
    }
    let manifest_dir =
        PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".into()));
    // An empty override (`OPENAPI_SPEC_PATH=`) is treated as unset instead of resolving to
    // the manifest directory itself, which would fail later with an unhelpful read error.
    if let Some(from_env) = std::env::var("OPENAPI_SPEC_PATH")
        .ok()
        .filter(|value| !value.is_empty())
    {
        let from_env = PathBuf::from(from_env);
        return if from_env.is_absolute() {
            from_env
        } else {
            manifest_dir.join(from_env)
        };
    }
    // `ancestors()` yields the manifest directory first, then each parent up to the root.
    manifest_dir
        .ancestors()
        .map(|dir| dir.join(relative))
        .find(|candidate| candidate.exists())
        .unwrap_or_else(|| manifest_dir.join(relative))
}
/// Parsed input for [`generate_types_impl`]: exactly one string literal naming the spec file.
struct GenerateTypesInput {
    /// The spec path as written at the call site; kept as a literal so errors can point at it.
    spec_path: LitStr,
}
impl Parse for GenerateTypesInput {
    fn parse(input: ParseStream) -> syn::Result<Self> {
        // Any tokens after the literal are rejected by `syn::parse2` at the call site.
        Ok(Self {
            spec_path: input.parse::<LitStr>()?,
        })
    }
}
pub fn generate_types_impl(input: TokenStream) -> syn::Result<TokenStream> {
let parsed = syn::parse2::<GenerateTypesInput>(input)?;
let spec_path_str = parsed.spec_path.value();
let full_path = resolve_spec_path(&spec_path_str);
let spec = spec::load_spec(&full_path, &spec_path_str).map_err(|e| {
syn::Error::new_spanned(&parsed.spec_path, format!("failed to load spec: {e}"))
})?;
let type_tokens = types::generate_all_types(&spec.schemas);
Ok(type_tokens)
}
struct ApiInput {
method: Ident,
path: LitStr,
params: Vec<(Ident, Expr)>,
query_params: Vec<(Ident, Expr)>,
body: Option<Expr>,
}
impl Parse for ApiInput {
fn parse(input: ParseStream) -> syn::Result<Self> {
let method: Ident = input.parse()?;
let path: LitStr = input.parse()?;
let mut params = Vec::new();
let mut query_params = Vec::new();
let mut body = None;
while input.peek(Token![,]) {
let _: Token![,] = input.parse()?;
if input.is_empty() {
break;
}
if input.peek(Ident) {
let key: Ident = input.parse()?;
let key_str = key.to_string();
if key_str == "query" {
let _: Token![=] = input.parse()?;
let content;
syn::braced!(content in input);
while !content.is_empty() {
let qk: Ident = content.parse()?;
let _: Token![:] = content.parse()?;
let qv: Expr = content.parse()?;
query_params.push((qk, qv));
if content.peek(Token![,]) {
let _: Token![,] = content.parse()?;
}
}
} else if key_str == "body" {
let _: Token![=] = input.parse()?;
let expr: Expr = input.parse()?;
body = Some(expr);
} else {
let _: Token![=] = input.parse()?;
let expr: Expr = input.parse()?;
params.push((key, expr));
}
}
}
Ok(Self {
method,
path,
params,
query_params,
body,
})
}
}
/// Expands an API invocation into an `::openapi_contract::ApiRequest<T>` expression.
///
/// The spec is read from `openapi-spec.json` as located by [`resolve_spec_path`]. The
/// method/path pair and the supplied path parameters are checked by `validate_endpoint`.
/// `T` is the type produced by `types::schema_to_rust_type` for the endpoint's response schema,
/// or `()` when the endpoint is flagged `is_sse` or has no response schema. Query pairs are
/// passed to `::openapi_contract::build_query_string` and the body (if any) to `.body_json(..)`.
///
/// # Errors
/// Returns a spanned [`syn::Error`] when the input does not parse, the spec cannot be loaded,
/// or validation fails.
pub fn api_impl(input: TokenStream) -> syn::Result<TokenStream> {
    let parsed = syn::parse2::<ApiInput>(input)?;
    let method_str = parsed.method.to_string().to_uppercase();
    let path_str = parsed.path.value();
    let spec_path = resolve_spec_path("openapi-spec.json");
    let spec = spec::load_spec(&spec_path, "openapi-spec.json")
        .map_err(|e| syn::Error::new_spanned(&parsed.path, format!("failed to load spec: {e}")))?;
    validate_endpoint(&parsed, &method_str, &path_str, &spec)?;
    // `validate_endpoint` returned Ok, so this key is present and indexing cannot panic.
    let endpoint = &spec.endpoints[&(method_str.clone(), path_str.clone())];
    let method_ident = format_ident!("{}", method_str);
    let path_format = build_path_format(&path_str, &parsed.params);
    let response_type = if endpoint.is_sse {
        quote! { () }
    } else {
        match &endpoint.response_schema {
            Some(schema) => types::schema_to_rust_type(schema, &spec.schemas),
            None => quote! { () },
        }
    };
    let query_code = if parsed.query_params.is_empty() {
        quote! {}
    } else {
        let pairs: Vec<TokenStream> = parsed
            .query_params
            .iter()
            .map(|(k, v)| {
                let ks = k.to_string();
                // The user expression is spliced as flat tokens, so it must be parenthesised:
                // `&a + b as &dyn ToString` would parse as `(&a) + (b as &dyn ToString)`.
                quote! { (#ks, &(#v) as &dyn ::std::string::ToString) }
            })
            .collect();
        quote! { .query_raw(::openapi_contract::build_query_string(&[#(#pairs),*])) }
    };
    let body_code = match &parsed.body {
        Some(expr) => quote! { .body_json(#expr) },
        None => quote! {},
    };
    Ok(quote! {
        {
            let __path = #path_format;
            ::openapi_contract::ApiRequest::<#response_type>::new(
                ::openapi_contract::Method::#method_ident,
                __path,
            )
            #query_code
            #body_code
        }
    })
}
fn validate_endpoint(
parsed: &ApiInput,
method: &str,
path: &str,
spec: &SpecIndex,
) -> syn::Result<()> {
let key = (method.to_string(), path.to_string());
if !spec.endpoints.contains_key(&key) {
let has_path = spec.endpoints.keys().any(|(_, p)| p == path);
if has_path {
let hint = spec::available_methods_hint(spec, path);
return Err(syn::Error::new_spanned(
&parsed.method,
format!("{method} is not defined for \"{path}\". Available methods: {hint}"),
));
}
let prefix_end = path.char_indices().nth(8).map_or(path.len(), |(i, _)| i);
let prefix = &path[..prefix_end];
let hint = spec::available_paths_hint(spec, prefix);
return Err(syn::Error::new_spanned(
&parsed.path,
format!("unknown API path \"{path}\". Similar paths: {hint}"),
));
}
let endpoint = &spec.endpoints[&key];
let provided_params: Vec<String> = parsed.params.iter().map(|(k, _)| k.to_string()).collect();
for expected in &endpoint.path_params {
if !provided_params.contains(expected) {
return Err(syn::Error::new_spanned(
&parsed.path,
format!(
"missing path parameter `{expected}` for {method} \"{path}\". Required: {}",
endpoint.path_params.join(", ")
),
));
}
}
for (k, _) in &parsed.params {
let ks = k.to_string();
if !endpoint.path_params.contains(&ks) {
return Err(syn::Error::new_spanned(
k,
format!(
"unexpected parameter `{ks}` for {method} \"{path}\". Expected: {}",
endpoint.path_params.join(", ")
),
));
}
}
Ok(())
}
/// Builds an expression that evaluates to the request path as a `String`.
///
/// With no parameters, the template literal is turned into a `String` with `.to_string()`.
/// Otherwise each `{name}` placeholder becomes a positional `{}` in a `::std::format!` call,
/// and its argument is the expression bound to `name` in `params`. A literal `}` outside a
/// placeholder is escaped as `}}` so the generated format string stays valid.
///
/// NOTE(review): a placeholder with no matching entry in `params` still emits `{}` without an
/// argument, which surfaces as a `format!` error in user code. `validate_endpoint` is expected
/// to rule this out — confirm that template names always match `path_params`.
/// NOTE(review): values are formatted via `Display` and are not percent-encoded here.
fn build_path_format(path: &str, params: &[(Ident, Expr)]) -> TokenStream {
    if params.is_empty() {
        return quote! { #path.to_string() };
    }
    let param_map: std::collections::HashMap<String, &Expr> = params
        .iter()
        .map(|(name, expr)| (name.to_string(), expr))
        .collect();
    let mut fmt_str = String::with_capacity(path.len());
    let mut args = Vec::new();
    let mut chars = path.chars();
    while let Some(c) = chars.next() {
        match c {
            '{' => {
                // `take_while` also consumes the closing `}`.
                let name: String = chars.by_ref().take_while(|&c2| c2 != '}').collect();
                fmt_str.push_str("{}");
                if let Some(expr) = param_map.get(&name) {
                    args.push(quote! { #expr });
                }
            }
            // A lone `}` is plain text in a URL path but a syntax error in a format string.
            '}' => fmt_str.push_str("}}"),
            other => fmt_str.push(other),
        }
    }
    quote! { ::std::format!(#fmt_str, #(#args),*) }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, OnceLock};

    /// Serialises the tests in this module that read or mutate process-wide environment
    /// variables.
    ///
    /// A poisoned mutex (an earlier test panicked while holding it) is recovered rather than
    /// unwrapped, so one failing test does not make every later env-dependent test fail too.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Sets an environment variable and, on drop, restores its previous value (or unsets it
    /// if it was absent).
    ///
    /// Create it only while holding [`env_lock`], and declare it after the lock guard so it is
    /// dropped, and the variable restored, before the lock is released.
    struct EnvVarOverride {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarOverride {
        fn set(key: &'static str, value: impl AsRef<std::ffi::OsStr>) -> Self {
            let previous = std::env::var_os(key);
            // SAFETY: every test in this module that reads or writes the environment holds
            // `env_lock()`, so no other test thread accesses it concurrently.
            unsafe { std::env::set_var(key, value) };
            Self { key, previous }
        }
    }

    impl Drop for EnvVarOverride {
        fn drop(&mut self) {
            // SAFETY: same invariant as `EnvVarOverride::set`. The override is dropped before
            // the `env_lock()` guard declared ahead of it.
            unsafe {
                match self.previous.take() {
                    Some(value) => std::env::set_var(self.key, value),
                    None => std::env::remove_var(self.key),
                }
            }
        }
    }

    /// Writes `spec` to a fixed temp file and points `OPENAPI_SPEC_PATH` at it.
    ///
    /// Returns the file path and the env lock guard, which must stay alive for the test.
    fn write_test_spec(
        spec: &serde_json::Value,
    ) -> (std::path::PathBuf, std::sync::MutexGuard<'static, ()>) {
        let guard = env_lock();
        let tmp = std::env::temp_dir().join("openapi-contract-codegen-test-spec.json");
        std::fs::write(&tmp, spec.to_string()).unwrap();
        // SAFETY: `guard` serialises environment access among this module's tests.
        unsafe { std::env::set_var("OPENAPI_SPEC_PATH", tmp.to_str().unwrap()) };
        (tmp, guard)
    }

    /// Unsets `OPENAPI_SPEC_PATH` and removes the temp spec file (best effort).
    fn cleanup_test_spec(tmp: &std::path::Path) {
        // SAFETY: callers still hold the guard returned by `write_test_spec`.
        unsafe { std::env::remove_var("OPENAPI_SPEC_PATH") };
        let _ = std::fs::remove_file(tmp);
    }

    fn full_spec() -> serde_json::Value {
        serde_json::json!({
            "openapi": "3.0.3",
            "info": {"title": "T", "version": "1"},
            "paths": {
                "/items": {
                    "get": {
                        "responses": { "200": { "description": "ok", "content": {
                            "application/json": { "schema": {"type": "array", "items": {"type": "string"}} }
                        }}}
                    },
                    "post": {
                        "responses": { "200": { "description": "ok", "content": {
                            "application/json": { "schema": {"type": "string"} }
                        }}}
                    }
                },
                "/items/{id}": {
                    "get": {
                        "parameters": [{"name": "id", "in": "path", "required": true, "schema": {"type": "string"}}],
                        "responses": { "200": { "description": "ok", "content": {
                            "application/json": { "schema": {"type": "string"} }
                        }}}
                    }
                },
                "/stream": {
                    "get": {
                        "responses": { "200": { "description": "ok", "content": {
                            "text/event-stream": { "schema": {"type": "string"} }
                        }}}
                    }
                }
            }
        })
    }

    #[test]
    fn parse_api_input_variants() {
        let p: ApiInput = syn::parse2(quote! { GET "/api/test" }).unwrap();
        assert_eq!(p.method.to_string(), "GET");
        assert_eq!(p.path.value(), "/api/test");
        assert!(p.params.is_empty() && p.body.is_none());
        let p: ApiInput = syn::parse2(quote! { GET "/api/teams/{id}", id = &team_id }).unwrap();
        assert_eq!(p.params.len(), 1);
        assert_eq!(p.params[0].0.to_string(), "id");
        let p: ApiInput = syn::parse2(quote! { POST "/api/teams", body = &data }).unwrap();
        assert!(p.body.is_some());
        let p: ApiInput =
            syn::parse2(quote! { GET "/api/users", query = { limit: 10, offset: 0 } }).unwrap();
        assert_eq!(p.query_params.len(), 2);
        let p: ApiInput = syn::parse2(quote! { POST "/api/teams/{id}/invite", id = &tid, query = { draft: true }, body = &invite }).unwrap();
        assert_eq!(p.params.len(), 1);
        assert_eq!(p.query_params.len(), 1);
        assert!(p.body.is_some());
        let p: ApiInput = syn::parse2(quote! { GET "/api/test", }).unwrap();
        assert_eq!(p.method.to_string(), "GET");
        assert!(syn::parse2::<ApiInput>(quote! { GET "/api/test", 123 }).is_err());
        let g: GenerateTypesInput = syn::parse2(quote! { "spec.json" }).unwrap();
        assert_eq!(g.spec_path.value(), "spec.json");
    }

    #[test]
    fn build_path_format_variants() {
        assert!(
            build_path_format("/api/test", &[])
                .to_string()
                .contains("to_string")
        );
        let ts = build_path_format(
            "/api/teams/{id}",
            &[(format_ident!("id"), syn::parse_quote! { &team_id })],
        );
        assert!(ts.to_string().contains("format"));
        let ts = build_path_format(
            "/orgs/{org}/teams/{team}",
            &[
                (format_ident!("org"), syn::parse_quote! { &org_id }),
                (format_ident!("team"), syn::parse_quote! { &team_id }),
            ],
        );
        assert!(ts.to_string().contains("format"));
    }

    #[test]
    fn resolve_spec_path_all_strategies() {
        // An absolute path that exists is returned before the environment is consulted.
        let tmp = std::env::temp_dir().join("openapi-contract-abs-spec.json");
        std::fs::write(&tmp, "{}").unwrap();
        assert_eq!(resolve_spec_path(tmp.to_str().unwrap()), tmp);
        let _ = std::fs::remove_file(&tmp);
        let _guard = env_lock();
        // SAFETY: environment access is serialised by `_guard`.
        unsafe { std::env::set_var("OPENAPI_SPEC_PATH", "/absolute/path/spec.json") };
        assert!(
            resolve_spec_path("ignored.json")
                .to_string_lossy()
                .contains("spec.json")
        );
        // SAFETY: environment access is serialised by `_guard`.
        unsafe { std::env::set_var("OPENAPI_SPEC_PATH", "relative/spec.json") };
        assert!(
            resolve_spec_path("ignored.json")
                .to_string_lossy()
                .contains("spec.json")
        );
        // SAFETY: environment access is serialised by `_guard`.
        unsafe { std::env::remove_var("OPENAPI_SPEC_PATH") };
        let dir = std::env::temp_dir().join("openapi-contract-manifest-dir-test");
        std::fs::create_dir_all(&dir).unwrap();
        let spec = dir.join("my-spec.json");
        std::fs::write(&spec, "{}").unwrap();
        {
            // Restores cargo's CARGO_MANIFEST_DIR on scope exit instead of deleting it.
            let _manifest = EnvVarOverride::set("CARGO_MANIFEST_DIR", &dir);
            assert_eq!(resolve_spec_path("my-spec.json"), spec);
        }
        let _ = std::fs::remove_dir_all(&dir);
        let root = std::env::temp_dir().join("openapi-contract-parent-search");
        let nested = root.join("a").join("b");
        let spec = root.join("openapi-spec.json");
        std::fs::create_dir_all(&nested).unwrap();
        std::fs::write(&spec, "{}").unwrap();
        {
            let _manifest = EnvVarOverride::set("CARGO_MANIFEST_DIR", &nested);
            assert_eq!(resolve_spec_path("openapi-spec.json"), spec);
        }
        let _ = std::fs::remove_dir_all(&root);
        // SAFETY: environment access is serialised by `_guard`.
        unsafe { std::env::remove_var("OPENAPI_SPEC_PATH") };
        assert!(
            resolve_spec_path("test.json")
                .to_string_lossy()
                .contains("test.json")
        );
    }

    #[test]
    fn generate_types_impl_success_and_error() {
        let _guard = env_lock();
        // SAFETY: environment access is serialised by `_guard`.
        unsafe { std::env::remove_var("OPENAPI_SPEC_PATH") };
        let err = generate_types_impl(quote! { "nonexistent-spec-file.json" }).unwrap_err();
        assert!(err.to_string().contains("failed to load spec"));
        drop(_guard);
        let (tmp, _guard) = write_test_spec(&serde_json::json!({
            "openapi": "3.0.3",
            "info": {"title": "T", "version": "1"},
            "paths": {},
            "components": { "schemas": { "User": {
                "type": "object", "required": ["id"],
                "properties": { "id": {"type": "string"} }
            }}}
        }));
        let code = generate_types_impl(quote! { "ignored.json" })
            .unwrap()
            .to_string();
        assert!(code.contains("User"));
        cleanup_test_spec(&tmp);
    }

    #[test]
    fn api_impl_valid_spec_variants() {
        let (tmp, _guard) = write_test_spec(&full_spec());
        let code = api_impl(quote! { GET "/items" }).unwrap().to_string();
        assert!(code.contains("Vec"));
        let code = api_impl(quote! { GET "/items/{id}", id = &my_id })
            .unwrap()
            .to_string();
        assert!(code.contains("format"));
        let code = api_impl(quote! { GET "/items", query = { limit: 10 } })
            .unwrap()
            .to_string();
        assert!(code.contains("query_raw"));
        let code = api_impl(quote! { POST "/items", body = &data })
            .unwrap()
            .to_string();
        assert!(code.contains("body_json"));
        assert!(api_impl(quote! { GET "/stream" }).is_ok());
        cleanup_test_spec(&tmp);
    }

    #[test]
    fn api_impl_validation_errors() {
        let (tmp, _guard) = write_test_spec(&full_spec());
        assert!(
            api_impl(quote! { DELETE "/items" })
                .unwrap_err()
                .to_string()
                .contains("is not defined for")
        );
        assert!(
            api_impl(quote! { GET "/nonexistent" })
                .unwrap_err()
                .to_string()
                .contains("unknown API path")
        );
        assert!(
            api_impl(quote! { GET "/items/{id}" })
                .unwrap_err()
                .to_string()
                .contains("missing path parameter")
        );
        assert!(
            api_impl(quote! { GET "/items", id = &x })
                .unwrap_err()
                .to_string()
                .contains("unexpected parameter")
        );
        cleanup_test_spec(&tmp);
    }

    #[test]
    fn api_impl_missing_spec() {
        let _guard = env_lock();
        // SAFETY: environment access is serialised by `_guard`.
        unsafe { std::env::remove_var("OPENAPI_SPEC_PATH") };
        // Declared after `_guard`, so cargo's value is restored before the lock is released.
        let _manifest = EnvVarOverride::set("CARGO_MANIFEST_DIR", "/tmp/nonexistent");
        assert!(api_impl(quote! { GET "/api/test" }).is_err());
    }
}