1use indexmap::IndexMap;
10use serde::{Deserialize, Serialize};
11use serde_json::{Map, Value, json};
12
13use super::annotations::{ApiKeyLocation, AuthRequirement, HttpMethod, HttpParamBinding};
14use super::route::SqlRoute;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct OpenApiInfo {
18 pub title: String,
19 pub version: String,
20 #[serde(skip_serializing_if = "Option::is_none")]
21 pub description: Option<String>,
22}
23
24impl OpenApiInfo {
25 pub fn new(title: impl Into<String>, version: impl Into<String>) -> Self {
26 Self {
27 title: title.into(),
28 version: version.into(),
29 description: None,
30 }
31 }
32}
33
34pub fn openapi_from_routes(routes: &[SqlRoute], info: &OpenApiInfo) -> Value {
37 let (security_schemes, scheme_names) = collect_security_schemes(routes);
40
41 let mut paths: IndexMap<String, Map<String, Value>> = IndexMap::new();
45 for route in routes {
46 let entry = paths.entry(route.http.path.clone()).or_default();
47 let operation = build_operation(route, &scheme_names);
48 entry.insert(method_key(route.http.method).to_string(), operation);
49 }
50
51 let mut paths_obj = Map::new();
52 for (path, methods) in paths {
53 paths_obj.insert(path, Value::Object(methods));
54 }
55
56 let mut spec = Map::new();
57 spec.insert("openapi".into(), json!("3.1.0"));
58 spec.insert("info".into(), serde_json::to_value(info).expect("info serializes"));
59 spec.insert("paths".into(), Value::Object(paths_obj));
60
61 let mut components = Map::new();
62 if !security_schemes.is_empty() {
63 components.insert("securitySchemes".into(), Value::Object(security_schemes));
64 }
65 if !components.is_empty() {
66 spec.insert("components".into(), Value::Object(components));
67 }
68
69 Value::Object(spec)
70}
71
72fn build_operation(route: &SqlRoute, scheme_names: &std::collections::BTreeMap<AuthRequirement, String>) -> Value {
73 let mut op = Map::new();
74 op.insert("operationId".into(), json!(&route.operation_id));
75
76 if let Some(s) = &route.http.summary {
77 op.insert("summary".into(), json!(s));
78 }
79 if let Some(d) = &route.http.description {
80 op.insert("description".into(), json!(d));
81 }
82 if !route.http.tags.is_empty() {
83 op.insert("tags".into(), json!(&route.http.tags));
84 }
85
86 let parameters = build_parameters(route);
87 if !parameters.is_empty() {
88 op.insert("parameters".into(), Value::Array(parameters));
89 }
90
91 if let Some(request_body) = build_request_body(route) {
92 op.insert("requestBody".into(), request_body);
93 }
94
95 op.insert("responses".into(), build_responses(route));
96
97 if let Some(auth) = &route.http.auth
98 && !matches!(auth, AuthRequirement::None)
99 && let Some(name) = scheme_names.get(auth)
100 {
101 op.insert("security".into(), json!([{ name.as_str(): [] }]));
102 }
103
104 Value::Object(op)
105}
106
107fn build_parameters(route: &SqlRoute) -> Vec<Value> {
108 let mut out = Vec::new();
109 let parameter_schema = &route.metadata["parameter_schema"];
110 let properties = parameter_schema.get("properties").and_then(Value::as_object);
111 let Some(properties) = properties else {
112 return out;
113 };
114 let required: std::collections::HashSet<&str> = parameter_schema
115 .get("required")
116 .and_then(Value::as_array)
117 .map(|arr| arr.iter().filter_map(Value::as_str).collect())
118 .unwrap_or_default();
119
120 for (name, schema) in properties {
121 let location = match route.param_locations.get(name) {
122 Some(HttpParamBinding::Path) => "path",
123 Some(HttpParamBinding::Query) => "query",
124 Some(HttpParamBinding::Header) => "header",
125 _ => continue,
126 };
127 let is_required = location == "path" || required.contains(name.as_str());
128 let mut p = Map::new();
129 p.insert("name".into(), json!(name));
130 p.insert("in".into(), json!(location));
131 p.insert("required".into(), json!(is_required));
132 p.insert("schema".into(), schema.clone());
133 out.push(Value::Object(p));
134 }
135 out
136}
137
138fn build_request_body(route: &SqlRoute) -> Option<Value> {
139 let request_schema = route.metadata.get("request_schema")?;
140 if request_schema.is_null() {
141 return None;
142 }
143 Some(json!({
144 "required": true,
145 "content": {
146 "application/json": { "schema": request_schema }
147 }
148 }))
149}
150
151fn build_responses(route: &SqlRoute) -> Value {
152 let mut responses = Map::new();
153 let response_schema = route.metadata.get("response_schema").cloned().unwrap_or(Value::Null);
154 let codes: Vec<u16> = if route.http.status_codes.is_empty() {
155 vec![route.default_status]
156 } else {
157 route.http.status_codes.clone()
158 };
159 for (idx, code) in codes.iter().enumerate() {
160 let is_primary = idx == 0;
161 let mut body = Map::new();
162 body.insert("description".into(), json!(describe_status(*code)));
163 if is_primary && !response_schema.is_null() && *code != 204 {
164 body.insert(
165 "content".into(),
166 json!({ "application/json": { "schema": response_schema.clone() } }),
167 );
168 }
169 responses.insert(code.to_string(), Value::Object(body));
170 }
171 Value::Object(responses)
172}
173
/// Returns a human-readable reason phrase for an HTTP status code. It is used as the
/// (mandatory) `description` of an OpenAPI response object.
///
/// Phrases follow RFC 9110, except that 422 keeps the "Unprocessable Entity" wording
/// this generator has always emitted. Codes not listed here fall back to the generic
/// `"Response"`.
const fn describe_status(code: u16) -> &'static str {
    match code {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        412 => "Precondition Failed",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Response",
    }
}
190
191fn collect_security_schemes(
192 routes: &[SqlRoute],
193) -> (Map<String, Value>, std::collections::BTreeMap<AuthRequirement, String>) {
194 let mut schemes = Map::new();
195 let mut name_for = std::collections::BTreeMap::new();
196 for route in routes {
197 let Some(auth) = &route.http.auth else { continue };
198 if matches!(auth, AuthRequirement::None) {
199 continue;
200 }
201 if name_for.contains_key(auth) {
202 continue;
203 }
204 let name = match auth {
205 AuthRequirement::None => unreachable!(),
206 AuthRequirement::Bearer { format: None } => "bearerAuth".to_string(),
207 AuthRequirement::Bearer { format: Some(f) } => format!("bearer{}", f.to_uppercase()),
208 AuthRequirement::ApiKey { location, name } => {
209 format!("apiKey_{}_{}", location_short(*location), name.replace('-', "_"))
210 }
211 };
212 let scheme_value = match auth {
213 AuthRequirement::None => unreachable!(),
214 AuthRequirement::Bearer { format } => {
215 let mut s = Map::new();
216 s.insert("type".into(), json!("http"));
217 s.insert("scheme".into(), json!("bearer"));
218 if let Some(f) = format {
219 s.insert("bearerFormat".into(), json!(f));
220 }
221 Value::Object(s)
222 }
223 AuthRequirement::ApiKey { location, name } => json!({
224 "type": "apiKey",
225 "in": location_str(*location),
226 "name": name,
227 }),
228 };
229 schemes.insert(name.clone(), scheme_value);
230 name_for.insert(auth.clone(), name);
231 }
232 (schemes, name_for)
233}
234
235const fn location_short(loc: ApiKeyLocation) -> &'static str {
236 match loc {
237 ApiKeyLocation::Header => "h",
238 ApiKeyLocation::Query => "q",
239 ApiKeyLocation::Cookie => "c",
240 }
241}
242
243const fn location_str(loc: ApiKeyLocation) -> &'static str {
244 match loc {
245 ApiKeyLocation::Header => "header",
246 ApiKeyLocation::Query => "query",
247 ApiKeyLocation::Cookie => "cookie",
248 }
249}
250
251const fn method_key(m: HttpMethod) -> &'static str {
252 match m {
253 HttpMethod::Get => "get",
254 HttpMethod::Post => "post",
255 HttpMethod::Put => "put",
256 HttpMethod::Patch => "patch",
257 HttpMethod::Delete => "delete",
258 HttpMethod::Head => "head",
259 HttpMethod::Options => "options",
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::neutral_to_json_schema::BuildOptions;
    use crate::sql::route::route_from_query;
    use scythe_core::analyzer::{AnalyzedColumn, AnalyzedParam, AnalyzedQuery};
    use scythe_core::catalog::Catalog;
    use scythe_core::parser::{CustomAnnotation, QueryCommand};

    /// A catalog built from no DDL statements.
    fn empty_catalog() -> Catalog {
        Catalog::from_ddl(&[]).unwrap()
    }

    /// Turns `(name, value)` pairs into custom annotations on consecutive lines from 1.
    fn annotations(pairs: &[(&'static str, &'static str)]) -> Vec<CustomAnnotation> {
        pairs
            .iter()
            .enumerate()
            .map(|(idx, &(name, value))| CustomAnnotation {
                name: name.into(),
                value: value.into(),
                line: (idx + 1) as _,
            })
            .collect()
    }

    /// Converts `query` into its HTTP route, panicking if that fails or yields no route.
    fn route_of(query: &AnalyzedQuery) -> SqlRoute {
        route_from_query(query, &empty_catalog(), &BuildOptions::default())
            .unwrap()
            .unwrap()
    }

    /// Renders `routes` with a placeholder title and version.
    fn spec_for(routes: &[SqlRoute]) -> Value {
        openapi_from_routes(routes, &OpenApiInfo::new("t", "1"))
    }

    /// `GET /users/{id}`: bearer-JWT auth, statuses 200 and 404, tagged and summarized.
    fn get_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "GetUser".to_string(),
            command: QueryCommand::One,
            sql: "SELECT id, email FROM users WHERE id = $1".to_string(),
            columns: vec![
                AnalyzedColumn {
                    name: "id".into(),
                    neutral_type: "int64".into(),
                    nullable: false,
                },
                AnalyzedColumn {
                    name: "email".into(),
                    neutral_type: "string".into(),
                    nullable: false,
                },
            ],
            params: vec![AnalyzedParam {
                name: "id".into(),
                neutral_type: "int64".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: Some("users".into()),
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: annotations(&[
                ("http", "GET /users/{id}"),
                ("http_auth", "bearer:jwt"),
                ("http_status", "200,404"),
                ("http_tags", "users"),
                ("http_summary", "Fetch a user"),
            ]),
        }
    }

    /// `POST /users`: bearer-JWT auth, status 201, email taken from the request body.
    fn create_user_query() -> AnalyzedQuery {
        AnalyzedQuery {
            name: "CreateUser".to_string(),
            command: QueryCommand::ExecRows,
            sql: "INSERT INTO users (email) VALUES ($1)".to_string(),
            columns: vec![],
            params: vec![AnalyzedParam {
                name: "email".into(),
                neutral_type: "string".into(),
                nullable: false,
                position: 1,
            }],
            deprecated: None,
            source_table: None,
            composites: vec![],
            enums: vec![],
            optional_params: vec![],
            group_by: None,
            custom: annotations(&[
                ("http", "POST /users"),
                ("http_auth", "bearer:jwt"),
                ("http_status", "201"),
            ]),
        }
    }

    /// Routes for `GetUser` followed by `CreateUser`.
    fn two_routes() -> Vec<SqlRoute> {
        [get_user_query(), create_user_query()]
            .iter()
            .map(route_of)
            .collect()
    }

    #[test]
    fn emits_openapi_3_1_header() {
        let spec = openapi_from_routes(&two_routes(), &OpenApiInfo::new("test", "0.1.0"));
        assert_eq!(spec["openapi"], "3.1.0");
        let info = &spec["info"];
        assert_eq!(info["title"], "test");
        assert_eq!(info["version"], "0.1.0");
    }

    #[test]
    fn groups_methods_under_shared_path() {
        let spec = spec_for(&two_routes());
        let paths = &spec["paths"];
        assert!(paths["/users"]["post"].is_object());
        assert!(paths["/users/{id}"]["get"].is_object());
    }

    #[test]
    fn operation_carries_operation_id_summary_tags() {
        let spec = spec_for(&two_routes());
        let op = &spec["paths"]["/users/{id}"]["get"];
        assert_eq!(op["operationId"], "GetUser");
        assert_eq!(op["summary"], "Fetch a user");
        assert_eq!(op["tags"], json!(["users"]));
    }

    #[test]
    fn path_parameter_emitted() {
        let spec = spec_for(&two_routes());
        let params = spec["paths"]["/users/{id}"]["get"]["parameters"]
            .as_array()
            .unwrap();
        let [param] = params.as_slice() else {
            panic!("expected exactly one parameter, got {params:?}");
        };
        assert_eq!(param["name"], "id");
        assert_eq!(param["in"], "path");
        assert_eq!(param["required"], true);
    }

    #[test]
    fn post_carries_request_body() {
        let spec = spec_for(&two_routes());
        let body = &spec["paths"]["/users"]["post"]["requestBody"];
        assert_eq!(body["required"], true);
        let schema = &body["content"]["application/json"]["schema"];
        assert!(schema["properties"]["email"].is_object());
    }

    #[test]
    fn responses_keyed_by_status_codes() {
        let spec = spec_for(&two_routes());
        let responses = &spec["paths"]["/users/{id}"]["get"]["responses"];
        for code in ["200", "404"] {
            assert!(responses[code].is_object(), "missing {code} response");
        }
    }

    #[test]
    fn primary_response_includes_schema() {
        let spec = spec_for(&two_routes());
        let primary = &spec["paths"]["/users/{id}"]["get"]["responses"]["200"];
        let schema = &primary["content"]["application/json"]["schema"];
        assert!(schema["properties"]["id"].is_object());
    }

    #[test]
    fn registers_bearer_security_scheme_once() {
        let spec = spec_for(&two_routes());
        let schemes = spec["components"]["securitySchemes"].as_object().unwrap();
        assert_eq!(schemes.len(), 1);
        let scheme = schemes.values().next().unwrap();
        assert_eq!(scheme["type"], "http");
        assert_eq!(scheme["scheme"], "bearer");
        assert_eq!(scheme["bearerFormat"], "jwt");
    }

    #[test]
    fn operations_reference_security_scheme() {
        let spec = spec_for(&two_routes());
        let requirements = spec["paths"]["/users/{id}"]["get"]["security"]
            .as_array()
            .unwrap();
        assert_eq!(requirements.len(), 1);
        let scheme_name = requirements[0].as_object().unwrap().keys().next().unwrap();
        assert!(spec["components"]["securitySchemes"][scheme_name].is_object());
    }

    #[test]
    fn no_204_response_carries_body() {
        let mut query = create_user_query();
        query.command = QueryCommand::Exec;
        // Without an explicit status list the route falls back to its default status.
        query.custom.retain(|a| a.name != "http_status");
        let spec = spec_for(&[route_of(&query)]);
        let no_content = &spec["paths"]["/users"]["post"]["responses"]["204"];
        assert!(no_content["content"].is_null());
    }
}
488}