Skip to main content

modkit/api/
openapi_registry.rs

1//! `OpenAPI` registry for schema and operation management
2//!
3//! This module provides a standalone `OpenAPI` registry that collects operation specs
4//! and schemas, and builds a complete `OpenAPI` document from them.
5
6use anyhow::Result;
7use arc_swap::ArcSwap;
8use dashmap::DashMap;
9use std::collections::HashMap;
10use std::sync::Arc;
11use utoipa::openapi::{
12    OpenApi, OpenApiBuilder, Ref, RefOr, Required,
13    content::ContentBuilder,
14    info::InfoBuilder,
15    path::{
16        HttpMethod, OperationBuilder as UOperationBuilder, ParameterBuilder, ParameterIn,
17        PathItemBuilder, PathsBuilder,
18    },
19    request_body::RequestBodyBuilder,
20    response::{ResponseBuilder, ResponsesBuilder},
21    schema::{ComponentsBuilder, ObjectBuilder, Schema, SchemaFormat, SchemaType},
22    security::{HttpAuthScheme, HttpBuilder, SecurityScheme},
23};
24
25use crate::api::{operation_builder, problem};
26
27/// Type alias for schema collections used in API operations.
28type SchemaCollection = Vec<(String, RefOr<Schema>)>;
29
/// Top-level metadata of a generated `OpenAPI` document.
///
/// Rendered into the document's `info` object by `OpenApiRegistryImpl::build_openapi`.
#[derive(Clone, Debug)]
pub struct OpenApiInfo {
    /// Human-readable API title (`info.title`).
    pub title: String,
    /// API version string (`info.version`).
    pub version: String,
    /// Optional longer description (`info.description`).
    pub description: Option<String>,
}

impl Default for OpenApiInfo {
    /// Placeholder metadata: title `"API Documentation"`, version `"0.1.0"`, no description.
    fn default() -> Self {
        let title = String::from("API Documentation");
        let version = String::from("0.1.0");
        Self {
            title,
            version,
            description: None,
        }
    }
}
47
48/// `OpenAPI` registry trait for operation and schema registration
49pub trait OpenApiRegistry: Send + Sync {
50    /// Register an API operation specification
51    fn register_operation(&self, spec: &operation_builder::OperationSpec);
52
53    /// Ensure schema for a type (including transitive dependencies) is registered
54    /// under components and return the canonical component name for `$ref`.
55    /// This is a type-erased version for dyn compatibility.
56    fn ensure_schema_raw(&self, name: &str, schemas: SchemaCollection) -> String;
57
58    /// Downcast support for accessing the concrete implementation if needed.
59    fn as_any(&self) -> &dyn std::any::Any;
60}
61
62/// Helper function to call `ensure_schema` with proper type information
63pub fn ensure_schema<T: utoipa::ToSchema + utoipa::PartialSchema + 'static>(
64    registry: &dyn OpenApiRegistry,
65) -> String {
66    use utoipa::PartialSchema;
67
68    // 1) Canonical component name for T as seen by utoipa
69    let root_name = T::name().to_string();
70
71    // 2) Always insert T's own schema first (actual object, not a ref)
72    //    This avoids self-referential components.
73    let mut collected: SchemaCollection = vec![(root_name.clone(), <T as PartialSchema>::schema())];
74
75    // 3) Collect and append all referenced schemas (dependencies) of T
76    T::schemas(&mut collected);
77
78    // 4) Pass to registry for insertion
79    registry.ensure_schema_raw(&root_name, collected)
80}
81
82/// Implementation of `OpenAPI` registry with lock-free data structures
83pub struct OpenApiRegistryImpl {
84    /// Store operation specs keyed by "METHOD:path"
85    pub operation_specs: DashMap<String, operation_builder::OperationSpec>,
86    /// Store schema components using arc-swap for lock-free reads
87    pub components_registry: ArcSwap<HashMap<String, RefOr<Schema>>>,
88}
89
90impl OpenApiRegistryImpl {
91    /// Create a new empty registry
92    #[must_use]
93    pub fn new() -> Self {
94        Self {
95            operation_specs: DashMap::new(),
96            components_registry: ArcSwap::from_pointee(HashMap::new()),
97        }
98    }
99
100    /// Build `OpenAPI` specification from registered operations and components.
101    ///
102    /// # Arguments
103    /// * `info` - `OpenAPI` document metadata (title, version, description)
104    ///
105    /// # Errors
106    /// Returns an error if the `OpenAPI` specification cannot be built.
107    pub fn build_openapi(&self, info: &OpenApiInfo) -> Result<OpenApi> {
108        use http::Method;
109
110        // Log operation count for visibility
111        let op_count = self.operation_specs.len();
112        tracing::info!("Building OpenAPI: found {op_count} registered operations");
113
114        // 1) Paths
115        let mut paths = PathsBuilder::new();
116
117        for spec in self.operation_specs.iter().map(|e| e.value().clone()) {
118            let mut op = UOperationBuilder::new()
119                .operation_id(spec.operation_id.clone().or(Some(spec.handler_id.clone())))
120                .summary(spec.summary.clone())
121                .description(spec.description.clone());
122
123            for tag in &spec.tags {
124                op = op.tag(tag.clone());
125            }
126
127            // Vendor extensions
128            let mut ext = utoipa::openapi::extensions::Extensions::default();
129
130            // Rate limit
131            if let Some(rl) = spec.rate_limit.as_ref() {
132                ext.insert("x-rate-limit-rps".to_owned(), serde_json::json!(rl.rps));
133                ext.insert("x-rate-limit-burst".to_owned(), serde_json::json!(rl.burst));
134                ext.insert(
135                    "x-in-flight-limit".to_owned(),
136                    serde_json::json!(rl.in_flight),
137                );
138            }
139
140            // Pagination
141            if let Some(pagination) = spec.vendor_extensions.x_odata_filter.as_ref()
142                && let Ok(value) = serde_json::to_value(pagination)
143            {
144                ext.insert("x-odata-filter".to_owned(), value);
145            }
146            if let Some(pagination) = spec.vendor_extensions.x_odata_orderby.as_ref()
147                && let Ok(value) = serde_json::to_value(pagination)
148            {
149                ext.insert("x-odata-orderby".to_owned(), value);
150            }
151
152            if !ext.is_empty() {
153                op = op.extensions(Some(ext));
154            }
155
156            // Parameters
157            for p in &spec.params {
158                let in_ = match p.location {
159                    operation_builder::ParamLocation::Path => ParameterIn::Path,
160                    operation_builder::ParamLocation::Query => ParameterIn::Query,
161                    operation_builder::ParamLocation::Header => ParameterIn::Header,
162                    operation_builder::ParamLocation::Cookie => ParameterIn::Cookie,
163                };
164                let required =
165                    if matches!(p.location, operation_builder::ParamLocation::Path) || p.required {
166                        Required::True
167                    } else {
168                        Required::False
169                    };
170
171                let schema_type = match p.param_type.as_str() {
172                    "integer" => SchemaType::Type(utoipa::openapi::schema::Type::Integer),
173                    "number" => SchemaType::Type(utoipa::openapi::schema::Type::Number),
174                    "boolean" => SchemaType::Type(utoipa::openapi::schema::Type::Boolean),
175                    _ => SchemaType::Type(utoipa::openapi::schema::Type::String),
176                };
177                let schema = Schema::Object(ObjectBuilder::new().schema_type(schema_type).build());
178
179                let param = ParameterBuilder::new()
180                    .name(&p.name)
181                    .parameter_in(in_)
182                    .required(required)
183                    .description(p.description.clone())
184                    .schema(Some(schema))
185                    .build();
186
187                op = op.parameter(param);
188            }
189
190            // Request body
191            if let Some(rb) = &spec.request_body {
192                let content = match &rb.schema {
193                    operation_builder::RequestBodySchema::Ref { schema_name } => {
194                        ContentBuilder::new()
195                            .schema(Some(RefOr::Ref(Ref::from_schema_name(schema_name.clone()))))
196                            .build()
197                    }
198                    operation_builder::RequestBodySchema::MultipartFile { field_name } => {
199                        // Build multipart/form-data schema with a single binary file field
200                        // type: object
201                        // properties:
202                        //   {field_name}: { type: string, format: binary }
203                        // required: [ field_name ]
204                        let file_schema = Schema::Object(
205                            ObjectBuilder::new()
206                                .schema_type(SchemaType::Type(
207                                    utoipa::openapi::schema::Type::String,
208                                ))
209                                .format(Some(SchemaFormat::Custom("binary".into())))
210                                .build(),
211                        );
212                        let obj = ObjectBuilder::new()
213                            .property(field_name.clone(), file_schema)
214                            .required(field_name.clone());
215                        let schema = Schema::Object(obj.build());
216                        ContentBuilder::new().schema(Some(schema)).build()
217                    }
218                    operation_builder::RequestBodySchema::Binary => {
219                        // Represent raw binary body as type string, format binary.
220                        // This is used for application/octet-stream and similar raw binary content.
221                        let schema = Schema::Object(
222                            ObjectBuilder::new()
223                                .schema_type(SchemaType::Type(
224                                    utoipa::openapi::schema::Type::String,
225                                ))
226                                .format(Some(SchemaFormat::Custom("binary".into())))
227                                .build(),
228                        );
229
230                        ContentBuilder::new().schema(Some(schema)).build()
231                    }
232                    operation_builder::RequestBodySchema::InlineObject => {
233                        // Preserve previous behavior for inline object bodies
234                        ContentBuilder::new()
235                            .schema(Some(Schema::Object(ObjectBuilder::new().build())))
236                            .build()
237                    }
238                };
239                let mut rbld = RequestBodyBuilder::new()
240                    .description(rb.description.clone())
241                    .content(rb.content_type.to_owned(), content);
242                if rb.required {
243                    rbld = rbld.required(Some(Required::True));
244                }
245                op = op.request_body(Some(rbld.build()));
246            }
247
248            // Responses
249            let mut responses = ResponsesBuilder::new();
250            for r in &spec.responses {
251                let is_json_like = r.content_type == "application/json"
252                    || r.content_type == problem::APPLICATION_PROBLEM_JSON
253                    || r.content_type == "text/event-stream";
254                let resp = if is_json_like {
255                    if let Some(name) = &r.schema_name {
256                        // Manually build content to preserve the correct content type
257                        let content = ContentBuilder::new()
258                            .schema(Some(RefOr::Ref(Ref::new(format!(
259                                "#/components/schemas/{name}"
260                            )))))
261                            .build();
262                        ResponseBuilder::new()
263                            .description(&r.description)
264                            .content(r.content_type, content)
265                            .build()
266                    } else {
267                        let content = ContentBuilder::new()
268                            .schema(Some(Schema::Object(ObjectBuilder::new().build())))
269                            .build();
270                        ResponseBuilder::new()
271                            .description(&r.description)
272                            .content(r.content_type, content)
273                            .build()
274                    }
275                } else {
276                    let schema = Schema::Object(
277                        ObjectBuilder::new()
278                            .schema_type(SchemaType::Type(utoipa::openapi::schema::Type::String))
279                            .format(Some(SchemaFormat::Custom(r.content_type.into())))
280                            .build(),
281                    );
282                    let content = ContentBuilder::new().schema(Some(schema)).build();
283                    ResponseBuilder::new()
284                        .description(&r.description)
285                        .content(r.content_type, content)
286                        .build()
287                };
288                responses = responses.response(r.status.to_string(), resp);
289            }
290            op = op.responses(responses.build());
291
292            // Add security requirement if operation requires authentication
293            if spec.authenticated {
294                let sec_req = utoipa::openapi::security::SecurityRequirement::new(
295                    "bearerAuth",
296                    Vec::<String>::new(),
297                );
298                op = op.security(sec_req);
299            }
300
301            let method = match spec.method {
302                Method::POST => HttpMethod::Post,
303                Method::PUT => HttpMethod::Put,
304                Method::DELETE => HttpMethod::Delete,
305                Method::PATCH => HttpMethod::Patch,
306                // GET and any other method default to Get
307                _ => HttpMethod::Get,
308            };
309
310            let item = PathItemBuilder::new().operation(method, op.build()).build();
311            // Convert Axum-style path to OpenAPI-style path
312            let openapi_path = operation_builder::axum_to_openapi_path(&spec.path);
313            paths = paths.path(openapi_path, item);
314        }
315
316        // 2) Components (from our registry)
317        let mut components = ComponentsBuilder::new();
318        for (name, schema) in self.components_registry.load().iter() {
319            components = components.schema(name.clone(), schema.clone());
320        }
321
322        // Add bearer auth security scheme
323        components = components.security_scheme(
324            "bearerAuth",
325            SecurityScheme::Http(
326                HttpBuilder::new()
327                    .scheme(HttpAuthScheme::Bearer)
328                    .bearer_format("JWT")
329                    .build(),
330            ),
331        );
332
333        // 3) Info & final OpenAPI doc
334        let openapi_info = InfoBuilder::new()
335            .title(&info.title)
336            .version(&info.version)
337            .description(info.description.clone())
338            .build();
339
340        let openapi = OpenApiBuilder::new()
341            .info(openapi_info)
342            .paths(paths.build())
343            .components(Some(components.build()))
344            .build();
345
346        Ok(openapi)
347    }
348}
349
350impl Default for OpenApiRegistryImpl {
351    fn default() -> Self {
352        Self::new()
353    }
354}
355
356impl OpenApiRegistry for OpenApiRegistryImpl {
357    fn register_operation(&self, spec: &operation_builder::OperationSpec) {
358        let operation_key = format!("{}:{}", spec.method.as_str(), spec.path);
359        self.operation_specs
360            .insert(operation_key.clone(), spec.clone());
361
362        tracing::debug!(
363            handler_id = %spec.handler_id,
364            method = %spec.method.as_str(),
365            path = %spec.path,
366            summary = %spec.summary.as_deref().unwrap_or("No summary"),
367            operation_key = %operation_key,
368            "Registered API operation in registry"
369        );
370    }
371
372    fn ensure_schema_raw(&self, root_name: &str, schemas: SchemaCollection) -> String {
373        // Snapshot & copy-on-write
374        let current = self.components_registry.load();
375        let mut reg = (**current).clone();
376
377        for (name, schema) in schemas {
378            // Conflict policy: identical → no-op; different → warn & override
379            if let Some(existing) = reg.get(&name) {
380                let a = serde_json::to_value(existing).ok();
381                let b = serde_json::to_value(&schema).ok();
382                if a == b {
383                    continue; // Skip identical schemas
384                }
385                tracing::warn!(%name, "Schema content conflict; overriding with latest");
386            }
387            reg.insert(name, schema);
388        }
389
390        self.components_registry.store(Arc::new(reg));
391        root_name.to_owned()
392    }
393
394    fn as_any(&self) -> &dyn std::any::Any {
395        self
396    }
397}
398
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::api::operation_builder::{
        OperationSpec, ParamLocation, ParamSpec, ResponseSpec, VendorExtensions,
    };
    use http::Method;

    /// Baseline spec shared by the tests.
    ///
    /// It has a single `200 application/json` response without a schema. It has no
    /// params, body or tags and is unauthenticated. Tests override fields with
    /// struct-update syntax or by assignment.
    fn sample_spec(
        method: Method,
        path: &str,
        operation_id: &str,
        summary: &str,
        handler_id: &str,
        ok_description: &str,
    ) -> OperationSpec {
        OperationSpec {
            method,
            path: path.to_owned(),
            operation_id: Some(operation_id.to_owned()),
            summary: Some(summary.to_owned()),
            description: None,
            tags: vec![],
            params: vec![],
            request_body: None,
            responses: vec![ResponseSpec {
                status: 200,
                content_type: "application/json",
                description: ok_description.to_owned(),
                schema_name: None,
            }],
            handler_id: handler_id.to_owned(),
            authenticated: false,
            is_public: false,
            rate_limit: None,
            allowed_request_content_types: None,
            vendor_extensions: VendorExtensions::default(),
            license_requirement: None,
        }
    }

    #[test]
    fn test_registry_creation() {
        let fresh = OpenApiRegistryImpl::new();
        assert_eq!(fresh.operation_specs.len(), 0);
        assert_eq!(fresh.components_registry.load().len(), 0);
    }

    #[test]
    fn test_register_operation() {
        let registry = OpenApiRegistryImpl::new();
        let spec = sample_spec(
            Method::GET,
            "/test",
            "test_op",
            "Test operation",
            "get_test",
            "Success",
        );

        registry.register_operation(&spec);
        assert_eq!(registry.operation_specs.len(), 1);
    }

    #[test]
    fn test_build_empty_openapi() {
        let info = OpenApiInfo {
            title: "Test API".to_owned(),
            version: "1.0.0".to_owned(),
            description: Some("Test API Description".to_owned()),
        };
        let doc = OpenApiRegistryImpl::new().build_openapi(&info).unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        // A valid document carries these top-level sections.
        for section in ["openapi", "info", "paths"] {
            assert!(json.get(section).is_some());
        }

        let meta = json.get("info").unwrap();
        assert_eq!(meta.get("title").unwrap(), "Test API");
        assert_eq!(meta.get("version").unwrap(), "1.0.0");
        assert_eq!(meta.get("description").unwrap(), "Test API Description");
    }

    #[test]
    fn test_build_openapi_with_operation() {
        let registry = OpenApiRegistryImpl::new();
        let spec = OperationSpec {
            description: Some("Retrieves a user by their ID".to_owned()),
            tags: vec!["users".to_owned()],
            params: vec![ParamSpec {
                name: "id".to_owned(),
                location: ParamLocation::Path,
                required: true,
                description: Some("User ID".to_owned()),
                param_type: "string".to_owned(),
            }],
            ..sample_spec(
                Method::GET,
                "/users/{id}",
                "get_user",
                "Get user by ID",
                "get_users_id",
                "User found",
            )
        };

        registry.register_operation(&spec);
        let doc = registry.build_openapi(&OpenApiInfo::default()).unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        let user_item = json.get("paths").unwrap().get("/users/{id}");
        assert!(user_item.is_some());

        let get_op = user_item.unwrap().get("get").unwrap();
        assert_eq!(get_op.get("operationId").unwrap(), "get_user");
        assert_eq!(get_op.get("summary").unwrap(), "Get user by ID");
    }

    #[test]
    fn test_ensure_schema_raw() {
        let registry = OpenApiRegistryImpl::new();
        let empty_object = RefOr::T(Schema::Object(ObjectBuilder::new().build()));

        let name =
            registry.ensure_schema_raw("TestSchema", vec![("TestSchema".to_owned(), empty_object)]);

        assert_eq!(name, "TestSchema");
        assert_eq!(registry.components_registry.load().len(), 1);
    }

    #[test]
    fn test_build_openapi_with_binary_request() {
        use crate::api::operation_builder::{RequestBodySchema, RequestBodySpec};

        let registry = OpenApiRegistryImpl::new();
        registry.register_operation(&OperationSpec {
            description: Some("Upload raw binary file".to_owned()),
            tags: vec!["upload".to_owned()],
            request_body: Some(RequestBodySpec {
                content_type: "application/octet-stream",
                description: Some("Raw file bytes".to_owned()),
                schema: RequestBodySchema::Binary,
                required: true,
            }),
            allowed_request_content_types: Some(vec!["application/octet-stream"]),
            ..sample_spec(
                Method::POST,
                "/files/v1/upload",
                "upload_file",
                "Upload a file",
                "post_upload",
                "Upload successful",
            )
        });

        let doc = registry.build_openapi(&OpenApiInfo::default()).unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        let upload_item = json.get("paths").unwrap().get("/files/v1/upload");
        assert!(upload_item.is_some());

        // The request body is advertised as application/octet-stream with a binary schema.
        let request_body = upload_item
            .unwrap()
            .get("post")
            .unwrap()
            .get("requestBody")
            .unwrap();
        let octet_stream = request_body
            .get("content")
            .unwrap()
            .get("application/octet-stream")
            .expect("application/octet-stream content type should exist");

        let body_schema = octet_stream.get("schema").unwrap();
        assert_eq!(body_schema.get("type").unwrap(), "string");
        assert_eq!(body_schema.get("format").unwrap(), "binary");

        assert_eq!(request_body.get("required").unwrap(), true);
    }

    #[test]
    fn test_build_openapi_with_pagination() {
        let registry = OpenApiRegistryImpl::new();

        let mut filter: operation_builder::ODataPagination<
            std::collections::BTreeMap<String, Vec<String>>,
        > = operation_builder::ODataPagination::default();
        let field_ops: [(&str, &[&str]); 2] = [
            ("name", &["eq", "ne", "contains", "startswith", "endswith", "in"]),
            ("age", &["eq", "ne", "gt", "ge", "lt", "le", "in"]),
        ];
        for (field, ops) in field_ops {
            filter.allowed_fields.insert(
                field.to_owned(),
                ops.iter().copied().map(String::from).collect(),
            );
        }

        let mut order_by: operation_builder::ODataPagination<Vec<String>> =
            operation_builder::ODataPagination::default();
        order_by
            .allowed_fields
            .extend(["name asc", "name desc", "age asc", "age desc"].map(String::from));

        let mut spec = sample_spec(Method::GET, "/test", "test_op", "Test", "get_test", "OK");
        spec.vendor_extensions.x_odata_filter = Some(filter);
        spec.vendor_extensions.x_odata_orderby = Some(order_by);

        registry.register_operation(&spec);
        let doc = registry.build_openapi(&OpenApiInfo::default()).unwrap();
        let json = serde_json::to_value(&doc).unwrap();

        let op = json
            .get("paths")
            .unwrap()
            .get("/test")
            .unwrap()
            .get("get")
            .unwrap();

        let filter_ext = op
            .get("x-odata-filter")
            .expect("x-odata-filter should be present");
        let allowed_fields = filter_ext.get("allowedFields").unwrap();
        for field in ["name", "age"] {
            assert!(allowed_fields.get(field).is_some());
        }

        let order_ext = op
            .get("x-odata-orderby")
            .expect("x-odata-orderby should be present");
        let allowed_order = order_ext.get("allowedFields").unwrap().as_array().unwrap();
        for wanted in ["name asc", "age desc"] {
            assert!(allowed_order.iter().any(|v| v.as_str() == Some(wanted)));
        }
    }
}