1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use crate::settings::OpenApiSettings;
use crate::OperationInfo;

extern crate okapi;

use okapi::openapi3::{Components, OpenApi, Operation, PathItem, RefOr, SecurityScheme};
use rocket::http::Method;
use schemars::gen::SchemaGenerator;
use schemars::schema::SchemaObject;
use schemars::JsonSchema;
use schemars::{Map, MapEntry};
use std::collections::HashMap;

/// A struct that visits all `rocket::Route`s, and aggregates information about them.
///
/// Operations are registered with `add_operation` and security schemes with
/// `add_security_scheme`; `into_openapi` consumes the generator and assembles an `OpenApi`
/// document from everything collected.
#[derive(Debug, Clone)]
pub struct OpenApiGenerator {
    /// Settings the generator was created with; `schema_generator` is built from
    /// `settings.schema_settings` in `new`.
    settings: OpenApiSettings,
    /// Produces JSON schemas for types and accumulates their definitions, which
    /// `into_openapi` later moves into `Components::schemas`.
    schema_generator: SchemaGenerator,
    /// Security schemes keyed by name; emitted as `Components::security_schemes`.
    security_schemes: Map<String, SecurityScheme>,
    /// Registered operations, keyed by `OperationInfo::path` and then by HTTP method.
    /// At most one operation is kept per (path, method) pair.
    operations: Map<String, HashMap<Method, Operation>>,
}

impl OpenApiGenerator {
    /// Create a new `OpenApiGenerator` from the settings provided.
    ///
    /// The internal `SchemaGenerator` is built from a clone of `settings.schema_settings`.
    #[must_use]
    pub fn new(settings: OpenApiSettings) -> Self {
        OpenApiGenerator {
            schema_generator: settings.schema_settings.clone().into_generator(),
            settings,
            security_schemes: Default::default(),
            operations: Default::default(),
        }
    }

    /// Adds a security scheme to the generated output.
    ///
    /// A scheme previously registered under the same `name` is overwritten.
    pub fn add_security_scheme(&mut self, name: String, scheme: SecurityScheme) {
        self.security_schemes.insert(name, scheme);
    }

    /// Add a new `HTTP Method` to the collection of endpoints in the `OpenApiGenerator`.
    ///
    /// If the operation has an id, it is normalised first. All leading `:` characters are
    /// stripped and every `::` is replaced by `_`, so `::api::get_user` becomes
    /// `api_get_user`.
    ///
    /// If an operation is already registered for the same path and method, the new one
    /// replaces it and a warning is written to stderr.
    pub fn add_operation(&mut self, mut op: OperationInfo) {
        if let Some(op_id) = op.operation.operation_id.as_mut() {
            // TODO do this outside add_operation
            *op_id = op_id.trim_start_matches(':').replace("::", "_");
        }
        match self.operations.entry(op.path) {
            MapEntry::Occupied(mut e) => {
                if e.get_mut().insert(op.method, op.operation).is_some() {
                    // Two routes may legitimately share a path and method when Rocket ranks
                    // them, e.g. `#[get("/user", rank = 2)]`; only the last one is kept.
                    // See: https://rocket.rs/v0.4/guide/requests/#forwarding
                    // Diagnostics go to stderr so they never mix with program output.
                    eprintln!("Warning: Operation replaced for {}:{}", op.method, e.key());
                }
            }
            MapEntry::Vacant(e) => {
                let mut map = HashMap::new();
                map.insert(op.method, op.operation);
                e.insert(map);
            }
        }
    }

    /// Returns a JSON Schema object for the type `T`.
    ///
    /// This goes through `SchemaGenerator::subschema_for`, so the result may be a reference
    /// to a definition stored in the generator rather than the inline schema. Use
    /// `json_schema_no_ref` to get the inline definition.
    pub fn json_schema<T: ?Sized + JsonSchema>(&mut self) -> SchemaObject {
        self.schema_generator.subschema_for::<T>().into()
    }

    /// Obtain the internal `SchemaGenerator` object.
    #[must_use]
    pub fn schema_generator(&self) -> &SchemaGenerator {
        &self.schema_generator
    }

    /// Return the component definition/schema of an object without any references.
    pub fn json_schema_no_ref<T: ?Sized + JsonSchema>(&mut self) -> SchemaObject {
        <T>::json_schema(&mut self.schema_generator).into()
    }

    /// Generate an `OpenApi` specification for all added operations.
    ///
    /// This consumes the generator and proceeds in three steps:
    /// - Every definition collected by the internal `SchemaGenerator` is passed through
    ///   the generator's visitors and emitted as `Components::schemas`.
    /// - Registered security schemes become `Components::security_schemes`.
    /// - Each registered operation is placed on its path's `PathItem`.
    ///
    /// `CONNECT` operations have no OpenAPI 3 equivalent and are omitted.
    #[must_use]
    pub fn into_openapi(self) -> OpenApi {
        let mut schema_generator = self.schema_generator;
        let mut schemas = schema_generator.take_definitions();

        // Let every configured visitor post-process the collected definitions.
        for visitor in schema_generator.visitors_mut() {
            for schema in schemas.values_mut() {
                visitor.visit_schema(schema);
            }
        }

        let security_schemes: Map<String, RefOr<SecurityScheme>> = self
            .security_schemes
            .into_iter()
            .map(|(name, scheme)| (name, scheme.into()))
            .collect();

        // Fetch each path's item once (moving the key in), then fill one field per method.
        let mut paths = Map::new();
        for (path, methods) in self.operations {
            let path_item = paths.entry(path).or_default();
            for (method, op) in methods {
                set_operation(path_item, method, op);
            }
        }

        OpenApi {
            openapi: "3.0.0".to_owned(),
            paths,
            components: Some(Components {
                schemas: schemas.into_iter().map(|(k, v)| (k, v.into())).collect(),
                security_schemes,
                ..Default::default()
            }),
            ..OpenApi::default()
        }
    }
}

fn set_operation(path_item: &mut PathItem, method: Method, op: Operation) {
    use Method::{Connect, Delete, Get, Head, Options, Patch, Post, Put, Trace};
    let option = match method {
        Get => &mut path_item.get,
        Put => &mut path_item.put,
        Post => &mut path_item.post,
        Delete => &mut path_item.delete,
        Options => &mut path_item.options,
        Head => &mut path_item.head,
        Patch => &mut path_item.patch,
        Trace => &mut path_item.trace,
        // Connect not available in OpenAPI3. Maybe should set in extensions?
        Connect => return,
    };
    assert!(option.is_none());
    option.replace(op);
}