1#![allow(
2 clippy::missing_errors_doc,
3 clippy::missing_panics_doc,
4 clippy::must_use_candidate,
5 clippy::doc_markdown,
6 clippy::too_long_first_doc_paragraph,
7 clippy::module_name_repetitions
8)]
9use indexmap::IndexMap;
18use serde::{Deserialize, Serialize};
19use serde_json::{Map, Value, json};
20
21use super::annotations::{ApiKeyLocation, AuthRequirement, HttpMethod, HttpParamBinding};
22use super::route::SqlRoute;
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct OpenApiInfo {
26 pub title: String,
27 pub version: String,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 pub description: Option<String>,
30}
31
32impl OpenApiInfo {
33 pub fn new(title: impl Into<String>, version: impl Into<String>) -> Self {
34 Self {
35 title: title.into(),
36 version: version.into(),
37 description: None,
38 }
39 }
40}
41
42pub fn openapi_from_routes(routes: &[SqlRoute], info: &OpenApiInfo) -> Value {
45 let (security_schemes, scheme_names) = collect_security_schemes(routes);
48
49 let mut paths: IndexMap<String, Map<String, Value>> = IndexMap::new();
53 for route in routes {
54 let entry = paths.entry(route.http.path.clone()).or_default();
55 let operation = build_operation(route, &scheme_names);
56 entry.insert(method_key(route.http.method).to_string(), operation);
57 }
58
59 let mut paths_obj = Map::new();
60 for (path, methods) in paths {
61 paths_obj.insert(path, Value::Object(methods));
62 }
63
64 let mut spec = Map::new();
65 spec.insert("openapi".into(), json!("3.1.0"));
66 spec.insert("info".into(), serde_json::to_value(info).expect("info serializes"));
67 spec.insert("paths".into(), Value::Object(paths_obj));
68
69 let mut components = Map::new();
70 if !security_schemes.is_empty() {
71 components.insert("securitySchemes".into(), Value::Object(security_schemes));
72 }
73 if !components.is_empty() {
74 spec.insert("components".into(), Value::Object(components));
75 }
76
77 Value::Object(spec)
78}
79
80fn build_operation(route: &SqlRoute, scheme_names: &std::collections::BTreeMap<AuthRequirement, String>) -> Value {
81 let mut op = Map::new();
82 op.insert("operationId".into(), json!(&route.operation_id));
83
84 if let Some(s) = &route.http.summary {
85 op.insert("summary".into(), json!(s));
86 }
87 if let Some(d) = &route.http.description {
88 op.insert("description".into(), json!(d));
89 }
90 if !route.http.tags.is_empty() {
91 op.insert("tags".into(), json!(&route.http.tags));
92 }
93
94 let parameters = build_parameters(route);
95 if !parameters.is_empty() {
96 op.insert("parameters".into(), Value::Array(parameters));
97 }
98
99 if let Some(request_body) = build_request_body(route) {
100 op.insert("requestBody".into(), request_body);
101 }
102
103 op.insert("responses".into(), build_responses(route));
104
105 if let Some(auth) = &route.http.auth
106 && !matches!(auth, AuthRequirement::None)
107 && let Some(name) = scheme_names.get(auth)
108 {
109 op.insert("security".into(), json!([{ name.as_str(): [] }]));
110 }
111
112 Value::Object(op)
113}
114
115fn build_parameters(route: &SqlRoute) -> Vec<Value> {
116 let mut out = Vec::new();
117 let parameter_schema = &route.metadata["parameter_schema"];
118 let properties = parameter_schema.get("properties").and_then(Value::as_object);
119 let Some(properties) = properties else {
120 return out;
121 };
122 let required: std::collections::HashSet<&str> = parameter_schema
123 .get("required")
124 .and_then(Value::as_array)
125 .map(|arr| arr.iter().filter_map(Value::as_str).collect())
126 .unwrap_or_default();
127
128 for (name, schema) in properties {
129 let location = match route.param_locations.get(name) {
130 Some(HttpParamBinding::Path) => "path",
131 Some(HttpParamBinding::Query) => "query",
132 Some(HttpParamBinding::Header) => "header",
133 _ => continue,
134 };
135 let is_required = location == "path" || required.contains(name.as_str());
136 let mut p = Map::new();
137 p.insert("name".into(), json!(name));
138 p.insert("in".into(), json!(location));
139 p.insert("required".into(), json!(is_required));
140 p.insert("schema".into(), schema.clone());
141 out.push(Value::Object(p));
142 }
143 out
144}
145
146fn build_request_body(route: &SqlRoute) -> Option<Value> {
147 let request_schema = route.metadata.get("request_schema")?;
148 if request_schema.is_null() {
149 return None;
150 }
151 Some(json!({
152 "required": true,
153 "content": {
154 "application/json": { "schema": request_schema }
155 }
156 }))
157}
158
159fn build_responses(route: &SqlRoute) -> Value {
160 let mut responses = Map::new();
161 let response_schema = route.metadata.get("response_schema").cloned().unwrap_or(Value::Null);
162 let codes: Vec<u16> = if route.http.status_codes.is_empty() {
163 vec![route.default_status]
164 } else {
165 route.http.status_codes.clone()
166 };
167 for (idx, code) in codes.iter().enumerate() {
168 let is_primary = idx == 0;
169 let mut body = Map::new();
170 body.insert("description".into(), json!(describe_status(*code)));
171 if is_primary && !response_schema.is_null() && *code != 204 {
172 body.insert(
173 "content".into(),
174 json!({ "application/json": { "schema": response_schema.clone() } }),
175 );
176 }
177 responses.insert(code.to_string(), Value::Object(body));
178 }
179 Value::Object(responses)
180}
181
/// Human-readable `description` for an OpenAPI response with status `code`.
///
/// Returns the standard HTTP reason phrase for commonly used status codes, and
/// the generic `"Response"` for anything else. OpenAPI requires every Response
/// object to have a description, so some text is always returned.
const fn describe_status(code: u16) -> &'static str {
    match code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        205 => "Reset Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        409 => "Conflict",
        410 => "Gone",
        412 => "Precondition Failed",
        413 => "Payload Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Response",
    }
}
198
199fn collect_security_schemes(
200 routes: &[SqlRoute],
201) -> (Map<String, Value>, std::collections::BTreeMap<AuthRequirement, String>) {
202 let mut schemes = Map::new();
203 let mut name_for = std::collections::BTreeMap::new();
204 for route in routes {
205 let Some(auth) = &route.http.auth else { continue };
206 if matches!(auth, AuthRequirement::None) {
207 continue;
208 }
209 if name_for.contains_key(auth) {
210 continue;
211 }
212 let name = match auth {
213 AuthRequirement::None => unreachable!(),
214 AuthRequirement::Bearer { format: None } => "bearerAuth".to_string(),
215 AuthRequirement::Bearer { format: Some(f) } => format!("bearer{}", f.to_uppercase()),
216 AuthRequirement::ApiKey { location, name } => {
217 format!("apiKey_{}_{}", location_short(*location), name.replace('-', "_"))
218 }
219 };
220 let scheme_value = match auth {
221 AuthRequirement::None => unreachable!(),
222 AuthRequirement::Bearer { format } => {
223 let mut s = Map::new();
224 s.insert("type".into(), json!("http"));
225 s.insert("scheme".into(), json!("bearer"));
226 if let Some(f) = format {
227 s.insert("bearerFormat".into(), json!(f));
228 }
229 Value::Object(s)
230 }
231 AuthRequirement::ApiKey { location, name } => json!({
232 "type": "apiKey",
233 "in": location_str(*location),
234 "name": name,
235 }),
236 };
237 schemes.insert(name.clone(), scheme_value);
238 name_for.insert(auth.clone(), name);
239 }
240 (schemes, name_for)
241}
242
243const fn location_short(loc: ApiKeyLocation) -> &'static str {
244 match loc {
245 ApiKeyLocation::Header => "h",
246 ApiKeyLocation::Query => "q",
247 ApiKeyLocation::Cookie => "c",
248 }
249}
250
251const fn location_str(loc: ApiKeyLocation) -> &'static str {
252 match loc {
253 ApiKeyLocation::Header => "header",
254 ApiKeyLocation::Query => "query",
255 ApiKeyLocation::Cookie => "cookie",
256 }
257}
258
259const fn method_key(m: HttpMethod) -> &'static str {
260 match m {
261 HttpMethod::Get => "get",
262 HttpMethod::Post => "post",
263 HttpMethod::Put => "put",
264 HttpMethod::Patch => "patch",
265 HttpMethod::Delete => "delete",
266 HttpMethod::Head => "head",
267 HttpMethod::Options => "options",
268 }
269}
270
#[cfg(test)]
mod tests {
    // End-to-end tests: two analyzed-query fixtures (`GetUser`, `CreateUser`)
    // are turned into routes by `route_from_query`, rendered with
    // `openapi_from_routes`, and the resulting JSON document is inspected.
    use super::*;
    use crate::sql::neutral_to_json_schema::BuildOptions;
    use crate::sql::route::route_from_query;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::catalog::Catalog;
    use scythe_core::parser::{CustomAnnotation, QueryCommand};

    /// Catalog built from an empty DDL list.
    fn empty_catalog() -> Catalog {
        Catalog::from_ddl(&[]).unwrap()
    }

    /// `GetUser`: `GET /users/{id}` with bearer (jwt) auth, status codes
    /// 200 and 404, tag `users` and a summary; selects `id` and `email`.
    fn get_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "GetUser".to_string(),
            command: QueryCommand::One,
            sql: "SELECT id, email FROM users WHERE id = $1".to_string(),
            columns: vec![
                AnalyzedColumn {
                    name: "id".into(),
                    neutral_type: "int64".into(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".into(),
                    neutral_type: "string".into(),
                    nullable: false,
                },
            ],
            params: vec![AnalyzedParam {
                name: "id".into(),
                neutral_type: "int64".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: Some("users".into()),
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: vec![
                CustomAnnotation {
                    name: "http".into(),
                    value: "GET /users/{id}".into(),
                    line: 1,
                },
                CustomAnnotation {
                    name: "http_auth".into(),
                    value: "bearer:jwt".into(),
                    line: 2,
                },
                CustomAnnotation {
                    name: "http_status".into(),
                    value: "200,404".into(),
                    line: 3,
                },
                CustomAnnotation {
                    name: "http_tags".into(),
                    value: "users".into(),
                    line: 4,
                },
                CustomAnnotation {
                    name: "http_summary".into(),
                    value: "Fetch a user".into(),
                    line: 5,
                },
            ],
        }
    }

    /// `CreateUser`: `POST /users` with the same bearer (jwt) auth as `GetUser`
    /// and status 201; takes one `email` parameter and returns no columns.
    fn create_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "CreateUser".to_string(),
            command: QueryCommand::ExecRows,
            sql: "INSERT INTO users (email) VALUES ($1)".to_string(),
            columns: vec![],
            params: vec![AnalyzedParam {
                name: "email".into(),
                neutral_type: "string".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: None,
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: vec![
                CustomAnnotation {
                    name: "http".into(),
                    value: "POST /users".into(),
                    line: 1,
                },
                CustomAnnotation {
                    name: "http_auth".into(),
                    value: "bearer:jwt".into(),
                    line: 2,
                },
                CustomAnnotation {
                    name: "http_status".into(),
                    value: "201".into(),
                    line: 3,
                },
            ],
        }
    }

    /// Routes for both fixtures, in `[GetUser, CreateUser]` order. Both fixtures
    /// carry an `http` annotation, so each conversion is expected to yield a route.
    fn build_two_routes() -> Vec<SqlRoute> {
        let opts = BuildOptions::default();
        let r1 = route_from_query(&get_user_query(), &empty_catalog(), &opts)
            .unwrap()
            .unwrap();
        let r2 = route_from_query(&create_user_query(), &empty_catalog(), &opts)
            .unwrap()
            .unwrap();
        vec![r1, r2]
    }

    // The document header carries the fixed spec version and the caller's info.
    #[test]
    fn emits_openapi_3_1_header() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("test", "0.1.0"));
        assert_eq!(spec["openapi"], "3.1.0");
        assert_eq!(spec["info"]["title"], "test");
        assert_eq!(spec["info"]["version"], "0.1.0");
    }

    // Each route appears under its own path, keyed by lower-case method name.
    #[test]
    fn groups_methods_under_shared_path() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        assert!(spec["paths"]["/users"]["post"].is_object());
        assert!(spec["paths"]["/users/{id}"]["get"].is_object());
    }

    // operationId comes from the query name; summary and tags from annotations.
    #[test]
    fn operation_carries_operation_id_summary_tags() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let op = &spec["paths"]["/users/{id}"]["get"];
        assert_eq!(op["operationId"], "GetUser");
        assert_eq!(op["summary"], "Fetch a user");
        assert_eq!(op["tags"], json!(["users"]));
    }

    // The `{id}` template segment becomes a required path parameter.
    #[test]
    fn path_parameter_emitted() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let params = spec["paths"]["/users/{id}"]["get"]["parameters"].as_array().unwrap();
        assert_eq!(params.len(), 1);
        assert_eq!(params[0]["name"], "id");
        assert_eq!(params[0]["in"], "path");
        assert_eq!(params[0]["required"], true);
    }

    // The POST route's `email` parameter ends up in a required JSON request body.
    #[test]
    fn post_carries_request_body() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let body = &spec["paths"]["/users"]["post"]["requestBody"];
        assert_eq!(body["required"], true);
        assert!(body["content"]["application/json"]["schema"]["properties"]["email"].is_object());
    }

    // Every code from `http_status` gets its own response entry.
    #[test]
    fn responses_keyed_by_status_codes() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let resp = &spec["paths"]["/users/{id}"]["get"]["responses"];
        assert!(resp["200"].is_object());
        assert!(resp["404"].is_object());
    }

    // The first listed code (200) carries the response schema built from the columns.
    #[test]
    fn primary_response_includes_schema() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let primary = &spec["paths"]["/users/{id}"]["get"]["responses"]["200"];
        assert!(primary["content"]["application/json"]["schema"]["properties"]["id"].is_object());
    }

    // Both routes use `bearer:jwt`, so exactly one scheme is registered; the
    // format is kept as written ("jwt") in `bearerFormat`.
    #[test]
    fn registers_bearer_security_scheme_once() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let schemes = &spec["components"]["securitySchemes"];
        assert_eq!(schemes.as_object().unwrap().len(), 1);
        let (_name, scheme) = schemes.as_object().unwrap().iter().next().unwrap();
        assert_eq!(scheme["type"], "http");
        assert_eq!(scheme["scheme"], "bearer");
        assert_eq!(scheme["bearerFormat"], "jwt");
    }

    // The operation's `security` entry names a scheme that exists in components.
    #[test]
    fn operations_reference_security_scheme() {
        let routes = build_two_routes();
        let spec = openapi_from_routes(&routes, &OpenApiInfo::new("t", "1"));
        let op = &spec["paths"]["/users/{id}"]["get"];
        let sec = op["security"].as_array().unwrap();
        assert_eq!(sec.len(), 1);
        let scheme_name = sec[0].as_object().unwrap().keys().next().unwrap();
        assert!(spec["components"]["securitySchemes"][scheme_name].is_object());
    }

    // An `Exec` query without `http_status` is expected to default to 204,
    // whose response must not carry content.
    // NOTE(review): if no "204" entry exists, `resp` is Null and this assertion
    // still passes; consider also asserting `resp.is_object()`.
    #[test]
    fn no_204_response_carries_body() {
        let mut q = create_user_query();
        q.command = QueryCommand::Exec;
        q.custom.retain(|a| a.name != "http_status");
        let route = route_from_query(&q, &empty_catalog(), &BuildOptions::default())
            .unwrap()
            .unwrap();
        let spec = openapi_from_routes(&[route], &OpenApiInfo::new("t", "1"));
        let resp = &spec["paths"]["/users"]["post"]["responses"]["204"];
        assert!(resp["content"].is_null());
    }
}