Skip to main content

spikard_codegen/sql/
openapi.rs

1#![allow(
2    clippy::missing_errors_doc,
3    clippy::missing_panics_doc,
4    clippy::must_use_candidate,
5    clippy::doc_markdown,
6    clippy::too_long_first_doc_paragraph,
7    clippy::module_name_repetitions
8)]
9//! Emit an OpenAPI 3.1 document from a slice of [`SqlRoute`].
10//!
11//! The spec is built as a raw `serde_json::Value` rather than reusing
12//! `crate::openapi::OpenApiSpec` because the existing struct is a subset that
13//! lacks several 3.1 idioms we need (array-typed `type`, `oneOf` for
14//! nullability, `enum`). Emitting as `Value` keeps this module decoupled and
15//! the output round-trips through any OpenAPI 3.1 consumer.
16
17use indexmap::IndexMap;
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value, json};
20
21use super::annotations::{ApiKeyLocation, AuthRequirement, HttpMethod, HttpParamBinding};
22use super::route::SqlRoute;
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct OpenApiInfo {
26    pub title: String,
27    pub version: String,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub description: Option<String>,
30}
31
32impl OpenApiInfo {
33    pub fn new(title: impl Into<String>, version: impl Into<String>) -> Self {
34        Self {
35            title: title.into(),
36            version: version.into(),
37            description: None,
38        }
39    }
40}
41
42/// Build an OpenAPI 3.1 document from a list of SQL-derived routes. The
43/// returned `Value` is ready to be `serde_json::to_writer_pretty`-ed to disk.
44pub fn openapi_from_routes(routes: &[SqlRoute], info: &OpenApiInfo) -> Value {
45    // Collect security schemes (one per distinct auth requirement) so we can
46    // reference them by name on per-operation `security` lists.
47    let (security_schemes, scheme_names) = collect_security_schemes(routes);
48
49    // Group operations by path so multiple methods on the same path share a
50    // `PathItem`. Using IndexMap keeps insertion order stable for snapshot
51    // testing.
52    let mut paths: IndexMap<String, Map<String, Value>> = IndexMap::new();
53    for route in routes {
54        let entry = paths.entry(route.http.path.clone()).or_default();
55        let operation = build_operation(route, &scheme_names);
56        entry.insert(method_key(route.http.method).to_string(), operation);
57    }
58
59    let mut paths_obj = Map::new();
60    for (path, methods) in paths {
61        paths_obj.insert(path, Value::Object(methods));
62    }
63
64    let mut spec = Map::new();
65    spec.insert("openapi".into(), json!("3.1.0"));
66    spec.insert("info".into(), serde_json::to_value(info).expect("info serializes"));
67    spec.insert("paths".into(), Value::Object(paths_obj));
68
69    let mut components = Map::new();
70    if !security_schemes.is_empty() {
71        components.insert("securitySchemes".into(), Value::Object(security_schemes));
72    }
73    if !components.is_empty() {
74        spec.insert("components".into(), Value::Object(components));
75    }
76
77    Value::Object(spec)
78}
79
80fn build_operation(route: &SqlRoute, scheme_names: &std::collections::BTreeMap<AuthRequirement, String>) -> Value {
81    let mut op = Map::new();
82    op.insert("operationId".into(), json!(&route.operation_id));
83
84    if let Some(s) = &route.http.summary {
85        op.insert("summary".into(), json!(s));
86    }
87    if let Some(d) = &route.http.description {
88        op.insert("description".into(), json!(d));
89    }
90    if !route.http.tags.is_empty() {
91        op.insert("tags".into(), json!(&route.http.tags));
92    }
93
94    let parameters = build_parameters(route);
95    if !parameters.is_empty() {
96        op.insert("parameters".into(), Value::Array(parameters));
97    }
98
99    if let Some(request_body) = build_request_body(route) {
100        op.insert("requestBody".into(), request_body);
101    }
102
103    op.insert("responses".into(), build_responses(route));
104
105    if let Some(auth) = &route.http.auth
106        && !matches!(auth, AuthRequirement::None)
107        && let Some(name) = scheme_names.get(auth)
108    {
109        op.insert("security".into(), json!([{ name.as_str(): [] }]));
110    }
111
112    Value::Object(op)
113}
114
115fn build_parameters(route: &SqlRoute) -> Vec<Value> {
116    let mut out = Vec::new();
117    let parameter_schema = &route.metadata["parameter_schema"];
118    let properties = parameter_schema.get("properties").and_then(Value::as_object);
119    let Some(properties) = properties else {
120        return out;
121    };
122    let required: std::collections::HashSet<&str> = parameter_schema
123        .get("required")
124        .and_then(Value::as_array)
125        .map(|arr| arr.iter().filter_map(Value::as_str).collect())
126        .unwrap_or_default();
127
128    for (name, schema) in properties {
129        let location = match route.param_locations.get(name) {
130            Some(HttpParamBinding::Path) => "path",
131            Some(HttpParamBinding::Query) => "query",
132            Some(HttpParamBinding::Header) => "header",
133            _ => continue,
134        };
135        let is_required = location == "path" || required.contains(name.as_str());
136        let mut p = Map::new();
137        p.insert("name".into(), json!(name));
138        p.insert("in".into(), json!(location));
139        p.insert("required".into(), json!(is_required));
140        p.insert("schema".into(), schema.clone());
141        out.push(Value::Object(p));
142    }
143    out
144}
145
146fn build_request_body(route: &SqlRoute) -> Option<Value> {
147    let request_schema = route.metadata.get("request_schema")?;
148    if request_schema.is_null() {
149        return None;
150    }
151    Some(json!({
152        "required": true,
153        "content": {
154            "application/json": { "schema": request_schema }
155        }
156    }))
157}
158
159fn build_responses(route: &SqlRoute) -> Value {
160    let mut responses = Map::new();
161    let response_schema = route.metadata.get("response_schema").cloned().unwrap_or(Value::Null);
162    let codes: Vec<u16> = if route.http.status_codes.is_empty() {
163        vec![route.default_status]
164    } else {
165        route.http.status_codes.clone()
166    };
167    for (idx, code) in codes.iter().enumerate() {
168        let is_primary = idx == 0;
169        let mut body = Map::new();
170        body.insert("description".into(), json!(describe_status(*code)));
171        if is_primary && !response_schema.is_null() && *code != 204 {
172            body.insert(
173                "content".into(),
174                json!({ "application/json": { "schema": response_schema.clone() } }),
175            );
176        }
177        responses.insert(code.to_string(), Value::Object(body));
178    }
179    Value::Object(responses)
180}
181
/// Human-readable description for a response status code, used as the
/// (required) `description` field of an OpenAPI Response Object.
///
/// Reason phrases follow RFC 9110 and RFC 6585 (428/429). Codes not listed
/// here fall back to the generic `"Response"`.
const fn describe_status(code: u16) -> &'static str {
    match code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        203 => "Non-Authoritative Information",
        204 => "No Content",
        205 => "Reset Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        412 => "Precondition Failed",
        413 => "Content Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        428 => "Precondition Required",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Response",
    }
}
198
199fn collect_security_schemes(
200    routes: &[SqlRoute],
201) -> (Map<String, Value>, std::collections::BTreeMap<AuthRequirement, String>) {
202    let mut schemes = Map::new();
203    let mut name_for = std::collections::BTreeMap::new();
204    for route in routes {
205        let Some(auth) = &route.http.auth else { continue };
206        if matches!(auth, AuthRequirement::None) {
207            continue;
208        }
209        if name_for.contains_key(auth) {
210            continue;
211        }
212        let name = match auth {
213            AuthRequirement::None => unreachable!(),
214            AuthRequirement::Bearer { format: None } => "bearerAuth".to_string(),
215            AuthRequirement::Bearer { format: Some(f) } => format!("bearer{}", f.to_uppercase()),
216            AuthRequirement::ApiKey { location, name } => {
217                format!("apiKey_{}_{}", location_short(*location), name.replace('-', "_"))
218            }
219        };
220        let scheme_value = match auth {
221            AuthRequirement::None => unreachable!(),
222            AuthRequirement::Bearer { format } => {
223                let mut s = Map::new();
224                s.insert("type".into(), json!("http"));
225                s.insert("scheme".into(), json!("bearer"));
226                if let Some(f) = format {
227                    s.insert("bearerFormat".into(), json!(f));
228                }
229                Value::Object(s)
230            }
231            AuthRequirement::ApiKey { location, name } => json!({
232                "type": "apiKey",
233                "in": location_str(*location),
234                "name": name,
235            }),
236        };
237        schemes.insert(name.clone(), scheme_value);
238        name_for.insert(auth.clone(), name);
239    }
240    (schemes, name_for)
241}
242
243const fn location_short(loc: ApiKeyLocation) -> &'static str {
244    match loc {
245        ApiKeyLocation::Header => "h",
246        ApiKeyLocation::Query => "q",
247        ApiKeyLocation::Cookie => "c",
248    }
249}
250
251const fn location_str(loc: ApiKeyLocation) -> &'static str {
252    match loc {
253        ApiKeyLocation::Header => "header",
254        ApiKeyLocation::Query => "query",
255        ApiKeyLocation::Cookie => "cookie",
256    }
257}
258
259const fn method_key(m: HttpMethod) -> &'static str {
260    match m {
261        HttpMethod::Get => "get",
262        HttpMethod::Post => "post",
263        HttpMethod::Put => "put",
264        HttpMethod::Patch => "patch",
265        HttpMethod::Delete => "delete",
266        HttpMethod::Head => "head",
267        HttpMethod::Options => "options",
268    }
269}
270
// Unit tests: build `SqlRoute`s from hand-written `AnalyzedQuery` fixtures via
// `route_from_query`, then assert on the JSON produced by
// `openapi_from_routes`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::neutral_to_json_schema::BuildOptions;
    use crate::sql::route::route_from_query;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::catalog::Catalog;
    use scythe_core::parser::{CustomAnnotation, QueryCommand};

    /// Catalog built from an empty list of DDL statements.
    fn empty_catalog() -> Catalog {
        Catalog::from_ddl(&[]).unwrap()
    }

    /// `GetUser` fixture: a `QueryCommand::One` SELECT annotated as
    /// `GET /users/{id}` with `bearer:jwt` auth, status codes 200 and 404,
    /// tag `users`, and summary "Fetch a user".
    fn get_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "GetUser".to_string(),
            command: QueryCommand::One,
            sql: "SELECT id, email FROM users WHERE id = $1".to_string(),
            columns: vec![
                AnalyzedColumn {
                    name: "id".into(),
                    neutral_type: "int64".into(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".into(),
                    neutral_type: "string".into(),
                    nullable: false,
                },
            ],
            params: vec![AnalyzedParam {
                name: "id".into(),
                neutral_type: "int64".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: Some("users".into()),
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: vec![
                CustomAnnotation {
                    name: "http".into(),
                    value: "GET /users/{id}".into(),
                    line: 1,
                },
                CustomAnnotation {
                    name: "http_auth".into(),
                    value: "bearer:jwt".into(),
                    line: 2,
                },
                CustomAnnotation {
                    name: "http_status".into(),
                    value: "200,404".into(),
                    line: 3,
                },
                CustomAnnotation {
                    name: "http_tags".into(),
                    value: "users".into(),
                    line: 4,
                },
                CustomAnnotation {
                    name: "http_summary".into(),
                    value: "Fetch a user".into(),
                    line: 5,
                },
            ],
        }
    }

    /// `CreateUser` fixture: a `QueryCommand::ExecRows` INSERT annotated as
    /// `POST /users` with `bearer:jwt` auth and an explicit 201 status.
    fn create_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "CreateUser".to_string(),
            command: QueryCommand::ExecRows,
            sql: "INSERT INTO users (email) VALUES ($1)".to_string(),
            columns: vec![],
            params: vec![AnalyzedParam {
                name: "email".into(),
                neutral_type: "string".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: None,
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: vec![
                CustomAnnotation {
                    name: "http".into(),
                    value: "POST /users".into(),
                    line: 1,
                },
                CustomAnnotation {
                    name: "http_auth".into(),
                    value: "bearer:jwt".into(),
                    line: 2,
                },
                CustomAnnotation {
                    name: "http_status".into(),
                    value: "201".into(),
                    line: 3,
                },
            ],
        }
    }

    /// Build the `GetUser` and `CreateUser` routes with default options.
    /// The double `unwrap` expects `route_from_query` to return `Ok(Some(_))`
    /// for both fixtures.
    fn build_two_routes() -> Vec<SqlRoute> {
        let opts = BuildOptions::default();
        let r1 = route_from_query(&get_user_query(), &empty_catalog(), &opts)
            .unwrap()
            .unwrap();
        let r2 = route_from_query(&create_user_query(), &empty_catalog(), &opts)
            .unwrap()
            .unwrap();
        vec![r1, r2]
    }

    /// The document declares OpenAPI 3.1.0 and copies `OpenApiInfo` into `info`.
    #[test]
    fn emits_openapi_3_1_header() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("test", "0.1.0"));
        assert_eq!(spec["openapi"], "3.1.0");
        assert_eq!(spec["info"]["title"], "test");
        assert_eq!(spec["info"]["version"], "0.1.0");
    }

    /// Each route lands under its own path and lowercase method key.
    #[test]
    fn groups_methods_under_shared_path() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        // /users has POST; /users/{id} has GET.
        assert!(spec["paths"]["/users"]["post"].is_object());
        assert!(spec["paths"]["/users/{id}"]["get"].is_object());
    }

    /// `operationId`, `summary` and `tags` come from the query name and the
    /// `@http_summary` / `@http_tags` annotations.
    #[test]
    fn operation_carries_operation_id_summary_tags() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let op = &spec["paths"]["/users/{id}"]["get"];
        assert_eq!(op["operationId"], "GetUser");
        assert_eq!(op["summary"], "Fetch a user");
        assert_eq!(op["tags"], json!(["users"]));
    }

    /// The `{id}` template variable becomes a required `in: path` parameter.
    #[test]
    fn path_parameter_emitted() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let params = spec["paths"]["/users/{id}"]["get"]["parameters"].as_array().unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0]["name"], "id");
        assert_eq!(params[0]["in"], "path");
        assert_eq!(params[0]["required"], true);
    }

    /// The POST route's `email` parameter ends up in a required JSON request body.
    #[test]
    fn post_carries_request_body() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let body = &spec["paths"]["/users"]["post"]["requestBody"];
        assert_eq!(body["required"], true);
        assert!(body["content"]["application/json"]["schema"]["properties"]["email"].is_object());
    }

    /// Every code listed in `@http_status` becomes a response entry.
    #[test]
    fn responses_keyed_by_status_codes() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let resp = &spec["paths"]["/users/{id}"]["get"]["responses"];
        assert!(resp["200"].is_object());
        assert!(resp["404"].is_object());
    }

    /// The first listed status code (200) carries the response schema.
    #[test]
    fn primary_response_includes_schema() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let primary = &spec["paths"]["/users/{id}"]["get"]["responses"]["200"];
        assert!(primary["content"]["application/json"]["schema"]["properties"]["id"].is_object());
    }

    /// Identical auth requirements on several routes register a single scheme.
    #[test]
    fn registers_bearer_security_scheme_once() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let schemes = &spec["components"]["securitySchemes"];
        // Both routes share `bearer:jwt`, so exactly one scheme is registered.
        assert_eq!(schemes.as_object().unwrap().len(), 1);
        let (_name, scheme) = schemes.as_object().unwrap().iter().next().unwrap();
        assert_eq!(scheme["type"], "http");
        assert_eq!(scheme["scheme"], "bearer");
        assert_eq!(scheme["bearerFormat"], "jwt");
    }

    /// An operation's `security` entry names a registered scheme.
    #[test]
    fn operations_reference_security_scheme() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let op = &spec["paths"]["/users/{id}"]["get"];
        let sec = op["security"].as_array().unwrap();
        assert_eq!(sec.len(), 1);
        let scheme_name = sec[0].as_object().unwrap().keys().next().unwrap();
        // The name must exist in components.securitySchemes.
        assert!(spec["components"]["securitySchemes"][scheme_name].is_object());
    }

    /// A 204 response must not declare `content`.
    #[test]
    fn no_204_response_carries_body() {
        let mut q = create_user_query();
        // Use :exec instead of :exec_rows to get the 204-default path.
        q.command = QueryCommand::Exec;
        // adjust the @http_status to omit explicit codes
        q.custom.retain(|a| a.name != "http_status");
        let route = route_from_query(&q, &empty_catalog(), &BuildOptions::default())
            .unwrap()
            .unwrap();
        let spec = openapi_from_routes(&[route], &OpenApiInfo::new("t", "1"));
        let resp = &spec["paths"]["/users"]["post"]["responses"]["204"];
        // NOTE(review): indexing a serde_json `Value` yields `Null` for a
        // missing key, so if no "204" response is generated `resp` is `Null`
        // and the assertion below passes vacuously. Consider also asserting
        // `resp.is_object()`.
        assert!(resp["content"].is_null());
    }
}