1use anyhow::Result;
7use arc_swap::ArcSwap;
8use dashmap::DashMap;
9use std::collections::HashMap;
10use std::sync::Arc;
11use utoipa::openapi::{
12 OpenApi, OpenApiBuilder, Ref, RefOr, Required,
13 content::ContentBuilder,
14 info::InfoBuilder,
15 path::{
16 HttpMethod, OperationBuilder as UOperationBuilder, ParameterBuilder, ParameterIn,
17 PathItemBuilder, PathsBuilder,
18 },
19 request_body::RequestBodyBuilder,
20 response::{ResponseBuilder, ResponsesBuilder},
21 schema::{ComponentsBuilder, ObjectBuilder, Schema, SchemaFormat, SchemaType},
22 security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
23};
24
25use crate::api::{operation_builder, problem};
26
27type SchemaCollection = Vec<(String, RefOr<Schema>)>;
29
/// Top-level metadata copied into the generated document's `info` object.
#[derive(Debug, Clone)]
pub struct OpenApiInfo {
    /// Written to `info.title`.
    pub title: String,
    /// Written to `info.version`.
    pub version: String,
    /// Written to `info.description` when present.
    pub description: Option<String>,
}

impl Default for OpenApiInfo {
    /// Placeholder metadata: title `"API Documentation"`, version `"0.1.0"`, no description.
    fn default() -> Self {
        let title = String::from("API Documentation");
        let version = String::from("0.1.0");
        Self {
            title,
            version,
            description: None,
        }
    }
}
47
48pub trait OpenApiRegistry: Send + Sync {
50 fn register_operation(&self, spec: &operation_builder::OperationSpec);
52
53 fn ensure_schema_raw(&self, name: &str, schemas: SchemaCollection) -> String;
57
58 fn as_any(&self) -> &dyn std::any::Any;
60}
61
62pub fn ensure_schema<T: utoipa::ToSchema + utoipa::PartialSchema + 'static>(
64 registry: &dyn OpenApiRegistry,
65) -> String {
66 use utoipa::PartialSchema;
67
68 let root_name = T::name().to_string();
70
71 let mut collected: SchemaCollection = vec![(root_name.clone(), <T as PartialSchema>::schema())];
74
75 T::schemas(&mut collected);
77
78 registry.ensure_schema_raw(&root_name, collected)
80}
81
82pub struct OpenApiRegistryImpl {
84 pub operation_specs: DashMap<String, operation_builder::OperationSpec>,
86 pub components_registry: ArcSwap<HashMap<String, RefOr<Schema>>>,
88}
89
90impl OpenApiRegistryImpl {
91 #[must_use]
93 pub fn new() -> Self {
94 Self {
95 operation_specs: DashMap::new(),
96 components_registry: ArcSwap::from_pointee(HashMap::new()),
97 }
98 }
99
100 pub fn build_openapi(&self, info: &OpenApiInfo) -> Result<OpenApi> {
108 use http::Method;
109
110 let op_count = self.operation_specs.len();
112 tracing::info!("Building OpenAPI: found {op_count} registered operations");
113
114 let mut paths = PathsBuilder::new();
116
117 for spec in self.operation_specs.iter().map(|e| e.value().clone()) {
118 let mut op = UOperationBuilder::new()
119 .operation_id(spec.operation_id.clone().or(Some(spec.handler_id.clone())))
120 .summary(spec.summary.clone())
121 .description(spec.description.clone());
122
123 for tag in &spec.tags {
124 op = op.tag(tag.clone());
125 }
126
127 let mut ext = utoipa::openapi::extensions::Extensions::default();
129
130 if let Some(rl) = spec.rate_limit.as_ref() {
132 ext.insert("x-rate-limit-rps".to_owned(), serde_json::json!(rl.rps));
133 ext.insert("x-rate-limit-burst".to_owned(), serde_json::json!(rl.burst));
134 ext.insert(
135 "x-in-flight-limit".to_owned(),
136 serde_json::json!(rl.in_flight),
137 );
138 }
139
140 if let Some(pagination) = spec.vendor_extensions.x_odata_filter.as_ref()
142 && let Ok(value) = serde_json::to_value(pagination)
143 {
144 ext.insert("x-odata-filter".to_owned(), value);
145 }
146 if let Some(pagination) = spec.vendor_extensions.x_odata_orderby.as_ref()
147 && let Ok(value) = serde_json::to_value(pagination)
148 {
149 ext.insert("x-odata-orderby".to_owned(), value);
150 }
151
152 if !ext.is_empty() {
153 op = op.extensions(Some(ext));
154 }
155
156 for p in &spec.params {
158 let in_ = match p.location {
159 operation_builder::ParamLocation::Path => ParameterIn::Path,
160 operation_builder::ParamLocation::Query => ParameterIn::Query,
161 operation_builder::ParamLocation::Header => ParameterIn::Header,
162 operation_builder::ParamLocation::Cookie => ParameterIn::Cookie,
163 };
164 let required =
165 if matches!(p.location, operation_builder::ParamLocation::Path) || p.required {
166 Required::True
167 } else {
168 Required::False
169 };
170
171 let schema_type = match p.param_type.as_str() {
172 "integer" => SchemaType::Type(utoipa::openapi::schema::Type::Integer),
173 "number" => SchemaType::Type(utoipa::openapi::schema::Type::Number),
174 "boolean" => SchemaType::Type(utoipa::openapi::schema::Type::Boolean),
175 _ => SchemaType::Type(utoipa::openapi::schema::Type::String),
176 };
177 let schema = Schema::Object(ObjectBuilder::new().schema_type(schema_type).build());
178
179 let param = ParameterBuilder::new()
180 .name(&p.name)
181 .parameter_in(in_)
182 .required(required)
183 .description(p.description.clone())
184 .schema(Some(schema))
185 .build();
186
187 op = op.parameter(param);
188 }
189
190 if let Some(rb) = &spec.request_body {
192 let content = match &rb.schema {
193 operation_builder::RequestBodySchema::Ref { schema_name } => {
194 ContentBuilder::new()
195 .schema(Some(RefOr::Ref(Ref::from_schema_name(schema_name.clone()))))
196 .build()
197 }
198 operation_builder::RequestBodySchema::MultipartFile { field_name } => {
199 let file_schema = Schema::Object(
205 ObjectBuilder::new()
206 .schema_type(SchemaType::Type(
207 utoipa::openapi::schema::Type::String,
208 ))
209 .format(Some(SchemaFormat::Custom("binary".into())))
210 .build(),
211 );
212 let obj = ObjectBuilder::new()
213 .property(field_name.clone(), file_schema)
214 .required(field_name.clone());
215 let schema = Schema::Object(obj.build());
216 ContentBuilder::new().schema(Some(schema)).build()
217 }
218 operation_builder::RequestBodySchema::Binary => {
219 let schema = Schema::Object(
222 ObjectBuilder::new()
223 .schema_type(SchemaType::Type(
224 utoipa::openapi::schema::Type::String,
225 ))
226 .format(Some(SchemaFormat::Custom("binary".into())))
227 .build(),
228 );
229
230 ContentBuilder::new().schema(Some(schema)).build()
231 }
232 operation_builder::RequestBodySchema::InlineObject => {
233 ContentBuilder::new()
235 .schema(Some(Schema::Object(ObjectBuilder::new().build())))
236 .build()
237 }
238 };
239 let mut rbld = RequestBodyBuilder::new()
240 .description(rb.description.clone())
241 .content(rb.content_type.to_owned(), content);
242 if rb.required {
243 rbld = rbld.required(Some(Required::True));
244 }
245 op = op.request_body(Some(rbld.build()));
246 }
247
248 let mut responses = ResponsesBuilder::new();
250 for r in &spec.responses {
251 let is_json_like = r.content_type == "application/json"
252 || r.content_type == problem::APPLICATION_PROBLEM_JSON
253 || r.content_type == "text/event-stream";
254 let resp = if is_json_like {
255 if let Some(name) = &r.schema_name {
256 let content = ContentBuilder::new()
258 .schema(Some(RefOr::Ref(Ref::new(format!(
259 "#/components/schemas/{name}"
260 )))))
261 .build();
262 ResponseBuilder::new()
263 .description(&r.description)
264 .content(r.content_type, content)
265 .build()
266 } else {
267 let content = ContentBuilder::new()
268 .schema(Some(Schema::Object(ObjectBuilder::new().build())))
269 .build();
270 ResponseBuilder::new()
271 .description(&r.description)
272 .content(r.content_type, content)
273 .build()
274 }
275 } else {
276 let schema = Schema::Object(
277 ObjectBuilder::new()
278 .schema_type(SchemaType::Type(utoipa::openapi::schema::Type::String))
279 .format(Some(SchemaFormat::Custom(r.content_type.into())))
280 .build(),
281 );
282 let content = ContentBuilder::new().schema(Some(schema)).build();
283 ResponseBuilder::new()
284 .description(&r.description)
285 .content(r.content_type, content)
286 .build()
287 };
288 responses = responses.response(r.status.to_string(), resp);
289 }
290 op = op.responses(responses.build());
291
292 if spec.authenticated {
294 let sec_req = utoipa::openapi::security::SecurityRequirement::new(
295 "bearerAuth",
296 Vec::<String>::new(),
297 );
298 op = op.security(sec_req);
299 }
300
301 let method = match spec.method {
302 Method::POST => HttpMethod::Post,
303 Method::PUT => HttpMethod::Put,
304 Method::DELETE => HttpMethod::Delete,
305 Method::PATCH => HttpMethod::Patch,
306 _ => HttpMethod::Get,
308 };
309
310 let item = PathItemBuilder::new().operation(method, op.build()).build();
311 let openapi_path = operation_builder::axum_to_openapi_path(&spec.path);
313 paths = paths.path(openapi_path, item);
314 }
315
316 let mut components = ComponentsBuilder::new();
318 for (name, schema) in self.components_registry.load().iter() {
319 components = components.schema(name.clone(), schema.clone());
320 }
321
322 components = components.security_scheme(
324 "bearerAuth",
325 SecurityScheme::Http(
326 HttpBuilder::new()
327 .scheme(HttpAuthScheme::Bearer)
328 .bearer_format("JWT")
329 .build(),
330 ),
331 );
332
333 let openapi_info = InfoBuilder::new()
335 .title(&info.title)
336 .version(&info.version)
337 .description(info.description.clone())
338 .build();
339
340 let openapi = OpenApiBuilder::new()
341 .info(openapi_info)
342 .paths(paths.build())
343 .components(Some(components.build()))
344 .build();
345
346 Ok(openapi)
347 }
348}
349
350impl Default for OpenApiRegistryImpl {
351 fn default() -> Self {
352 Self::new()
353 }
354}
355
356impl OpenApiRegistry for OpenApiRegistryImpl {
357 fn register_operation(&self, spec: &operation_builder::OperationSpec) {
358 let operation_key = format!("{}:{}", spec.method.as_str(), spec.path);
359 self.operation_specs
360 .insert(operation_key.clone(), spec.clone());
361
362 tracing::debug!(
363 handler_id = %spec.handler_id,
364 method = %spec.method.as_str(),
365 path = %spec.path,
366 summary = %spec.summary.as_deref().unwrap_or("No summary"),
367 operation_key = %operation_key,
368 "Registered API operation in registry"
369 );
370 }
371
372 fn ensure_schema_raw(&self, root_name: &str, schemas: SchemaCollection) -> String {
373 let current = self.components_registry.load();
375 let mut reg = (**current).clone();
376
377 for (name, schema) in schemas {
378 if let Some(existing) = reg.get(&name) {
380 let a = serde_json::to_value(existing).ok();
381 let b = serde_json::to_value(&schema).ok();
382 if a == b {
383 continue; }
385 tracing::warn!(%name, "Schema content conflict; overriding with latest");
386 }
387 reg.insert(name, schema);
388 }
389
390 self.components_registry.store(Arc::new(reg));
391 root_name.to_owned()
392 }
393
394 fn as_any(&self) -> &dyn std::any::Any {
395 self
396 }
397}
398
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::api::operation_builder::{
        OperationSpec, ParamLocation, ParamSpec, RequestBodySchema, RequestBodySpec,
        ResponseSpec, VendorExtensions,
    };
    use http::Method;

    /// Minimal unauthenticated operation with a single `200 application/json` response.
    ///
    /// Summary, description, tags, params and request body start empty; each test
    /// fills in what it needs.
    fn base_spec(
        method: Method,
        path: &str,
        operation_id: &str,
        handler_id: &str,
        ok_description: &str,
    ) -> OperationSpec {
        OperationSpec {
            method,
            path: path.to_owned(),
            operation_id: Some(operation_id.to_owned()),
            summary: None,
            description: None,
            tags: Vec::new(),
            params: Vec::new(),
            request_body: None,
            responses: vec![ResponseSpec {
                status: 200,
                content_type: "application/json",
                description: ok_description.to_owned(),
                schema_name: None,
            }],
            handler_id: handler_id.to_owned(),
            authenticated: false,
            is_public: false,
            rate_limit: None,
            allowed_request_content_types: None,
            vendor_extensions: VendorExtensions::default(),
            license_requirement: None,
        }
    }

    /// Builds the document for `registry` and serializes it to JSON.
    fn openapi_json(registry: &OpenApiRegistryImpl, info: &OpenApiInfo) -> serde_json::Value {
        let doc = registry.build_openapi(info).unwrap();
        serde_json::to_value(&doc).unwrap()
    }

    /// Turns a list of string literals into owned strings.
    fn strings(items: &[&str]) -> Vec<String> {
        items.iter().map(|s| (*s).to_owned()).collect()
    }

    #[test]
    fn test_registry_creation() {
        let registry = OpenApiRegistryImpl::new();
        assert_eq!(registry.operation_specs.len(), 0);
        assert_eq!(registry.components_registry.load().len(), 0);
    }

    #[test]
    fn test_register_operation() {
        let registry = OpenApiRegistryImpl::new();
        let mut spec = base_spec(Method::GET, "/test", "test_op", "get_test", "Success");
        spec.summary = Some("Test operation".to_owned());

        registry.register_operation(&spec);
        assert_eq!(registry.operation_specs.len(), 1);
    }

    #[test]
    fn test_build_empty_openapi() {
        let registry = OpenApiRegistryImpl::new();
        let info = OpenApiInfo {
            title: "Test API".to_owned(),
            version: "1.0.0".to_owned(),
            description: Some("Test API Description".to_owned()),
        };
        let json = openapi_json(&registry, &info);

        for key in ["openapi", "info", "paths"] {
            assert!(json.get(key).is_some(), "missing top-level key {key}");
        }

        let openapi_info = json.get("info").unwrap();
        assert_eq!(openapi_info.get("title").unwrap(), "Test API");
        assert_eq!(openapi_info.get("version").unwrap(), "1.0.0");
        assert_eq!(
            openapi_info.get("description").unwrap(),
            "Test API Description"
        );
    }

    #[test]
    fn test_build_openapi_with_operation() {
        let registry = OpenApiRegistryImpl::new();
        let mut spec = base_spec(
            Method::GET,
            "/users/{id}",
            "get_user",
            "get_users_id",
            "User found",
        );
        spec.summary = Some("Get user by ID".to_owned());
        spec.description = Some("Retrieves a user by their ID".to_owned());
        spec.tags = vec!["users".to_owned()];
        spec.params = vec![ParamSpec {
            name: "id".to_owned(),
            location: ParamLocation::Path,
            required: true,
            description: Some("User ID".to_owned()),
            param_type: "string".to_owned(),
        }];
        registry.register_operation(&spec);

        let json = openapi_json(&registry, &OpenApiInfo::default());
        let user_path = json
            .get("paths")
            .unwrap()
            .get("/users/{id}")
            .expect("/users/{id} should be documented");

        let get_op = user_path.get("get").unwrap();
        assert_eq!(get_op.get("operationId").unwrap(), "get_user");
        assert_eq!(get_op.get("summary").unwrap(), "Get user by ID");
    }

    #[test]
    fn test_ensure_schema_raw() {
        let registry = OpenApiRegistryImpl::new();
        let schemas = vec![(
            "TestSchema".to_owned(),
            RefOr::T(Schema::Object(ObjectBuilder::new().build())),
        )];

        let name = registry.ensure_schema_raw("TestSchema", schemas);
        assert_eq!(name, "TestSchema");
        assert_eq!(registry.components_registry.load().len(), 1);
    }

    #[test]
    fn test_build_openapi_with_binary_request() {
        let registry = OpenApiRegistryImpl::new();
        let mut spec = base_spec(
            Method::POST,
            "/files/v1/upload",
            "upload_file",
            "post_upload",
            "Upload successful",
        );
        spec.summary = Some("Upload a file".to_owned());
        spec.description = Some("Upload raw binary file".to_owned());
        spec.tags = vec!["upload".to_owned()];
        spec.request_body = Some(RequestBodySpec {
            content_type: "application/octet-stream",
            description: Some("Raw file bytes".to_owned()),
            schema: RequestBodySchema::Binary,
            required: true,
        });
        spec.allowed_request_content_types = Some(vec!["application/octet-stream"]);
        registry.register_operation(&spec);

        let json = openapi_json(&registry, &OpenApiInfo::default());
        let upload_path = json
            .get("paths")
            .unwrap()
            .get("/files/v1/upload")
            .expect("/files/v1/upload should be documented");

        let request_body = upload_path.get("post").unwrap().get("requestBody").unwrap();
        let octet_stream = request_body
            .get("content")
            .unwrap()
            .get("application/octet-stream")
            .expect("application/octet-stream content type should exist");

        let schema = octet_stream.get("schema").unwrap();
        assert_eq!(schema.get("type").unwrap(), "string");
        assert_eq!(schema.get("format").unwrap(), "binary");

        assert_eq!(request_body.get("required").unwrap(), true);
    }

    #[test]
    fn test_build_openapi_with_pagination() {
        let registry = OpenApiRegistryImpl::new();

        let mut filter: operation_builder::ODataPagination<
            std::collections::BTreeMap<String, Vec<String>>,
        > = operation_builder::ODataPagination::default();
        filter.allowed_fields.insert(
            "name".to_owned(),
            strings(&["eq", "ne", "contains", "startswith", "endswith", "in"]),
        );
        filter.allowed_fields.insert(
            "age".to_owned(),
            strings(&["eq", "ne", "gt", "ge", "lt", "le", "in"]),
        );

        let mut order_by: operation_builder::ODataPagination<Vec<String>> =
            operation_builder::ODataPagination::default();
        order_by
            .allowed_fields
            .extend(strings(&["name asc", "name desc", "age asc", "age desc"]));

        let mut spec = base_spec(Method::GET, "/test", "test_op", "get_test", "OK");
        spec.summary = Some("Test".to_owned());
        spec.vendor_extensions.x_odata_filter = Some(filter);
        spec.vendor_extensions.x_odata_orderby = Some(order_by);
        registry.register_operation(&spec);

        let json = openapi_json(&registry, &OpenApiInfo::default());
        let op = json
            .get("paths")
            .unwrap()
            .get("/test")
            .unwrap()
            .get("get")
            .unwrap();

        let filter_fields = op
            .get("x-odata-filter")
            .expect("x-odata-filter should be present")
            .get("allowedFields")
            .unwrap();
        assert!(filter_fields.get("name").is_some());
        assert!(filter_fields.get("age").is_some());

        let order_fields = op
            .get("x-odata-orderby")
            .expect("x-odata-orderby should be present")
            .get("allowedFields")
            .unwrap()
            .as_array()
            .unwrap();
        let lists = |wanted: &str| order_fields.iter().any(|v| v.as_str() == Some(wanted));
        assert!(lists("name asc"));
        assert!(lists("age desc"));
    }
}