poem_openapi/
openapi.rs

1use std::{
2    collections::{BTreeMap, HashMap, HashSet},
3    marker::PhantomData,
4};
5
6use poem::{
7    Endpoint, EndpointExt, IntoEndpoint, Request, Response, Result, Route, RouteMethod,
8    endpoint::{BoxEndpoint, make_sync},
9};
10#[cfg(feature = "cookie")]
11use poem::{middleware::CookieJarManager, web::cookie::CookieKey};
12
13use crate::{
14    OpenApi, Webhook,
15    base::UrlQuery,
16    registry::{
17        Document, MetaContact, MetaExternalDocument, MetaHeader, MetaInfo, MetaLicense,
18        MetaOperationParam, MetaParamIn, MetaSchemaRef, MetaServer, MetaServerVariable, Registry,
19    },
20    types::Type,
21};
22
23/// An object representing a Server.
24#[derive(Debug, Clone)]
25pub struct ServerObject {
26    url: String,
27    description: Option<String>,
28    variables: BTreeMap<String, MetaServerVariable>,
29}
30
31impl<T: Into<String>> From<T> for ServerObject {
32    fn from(url: T) -> Self {
33        Self::new(url)
34    }
35}
36
37impl ServerObject {
38    /// Create a server object by url.
39    pub fn new(url: impl Into<String>) -> ServerObject {
40        Self {
41            url: url.into(),
42            description: None,
43            variables: BTreeMap::new(),
44        }
45    }
46
47    /// Sets an string describing the host designated by the URL.
48    #[must_use]
49    pub fn description(self, description: impl Into<String>) -> Self {
50        Self {
51            description: Some(description.into()),
52            ..self
53        }
54    }
55
56    /// Adds a server variable with a limited set of values.
57    ///
58    /// The variable name must be present in the server URL in curly braces.
59    #[must_use]
60    pub fn enum_variable(
61        mut self,
62        name: impl Into<String>,
63        description: impl Into<String>,
64        default: impl Into<String>,
65        enum_values: Vec<impl Into<String>>,
66    ) -> Self {
67        self.variables.insert(
68            name.into(),
69            MetaServerVariable {
70                description: description.into(),
71                default: default.into(),
72                enum_values: enum_values.into_iter().map(Into::into).collect(),
73            },
74        );
75        self
76    }
77
78    /// Adds a server variable that can take any value.
79    ///
80    /// The variable name must be present in the server URL in curly braces.
81    #[must_use]
82    pub fn variable(
83        mut self,
84        name: impl Into<String>,
85        description: impl Into<String>,
86        default: impl Into<String>,
87    ) -> Self {
88        self.variables.insert(
89            name.into(),
90            MetaServerVariable {
91                description: description.into(),
92                default: default.into(),
93                enum_values: Vec::new(),
94            },
95        );
96        self
97    }
98}
99
/// Contact information for the exposed API.
///
/// All fields are optional; unset fields are left as `None`.
#[derive(Debug, Default, Clone)]
pub struct ContactObject {
    name: Option<String>,
    url: Option<String>,
    email: Option<String>,
}

impl ContactObject {
    /// Create a new, empty contact object.
    #[inline]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the identifying name of the contact person/organization.
    #[must_use]
    pub fn name(self, name: impl Into<String>) -> Self {
        Self {
            name: Some(name.into()),
            ..self
        }
    }

    /// Sets the URL pointing to the contact information.
    #[must_use]
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }

    /// Sets the email address of the contact person/organization.
    #[must_use]
    pub fn email(self, email: impl Into<String>) -> Self {
        Self {
            email: Some(email.into()),
            ..self
        }
    }
}
142
/// License information for the exposed API.
///
/// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#license-object>
#[derive(Debug, Clone)]
pub struct LicenseObject {
    name: String,
    identifier: Option<String>,
    url: Option<String>,
}

impl<T: Into<String>> From<T> for LicenseObject {
    /// Builds a license object from its name (not its URL).
    fn from(name: T) -> Self {
        Self::new(name)
    }
}

impl LicenseObject {
    /// Create a license object by name.
    pub fn new(name: impl Into<String>) -> LicenseObject {
        Self {
            name: name.into(),
            identifier: None,
            url: None,
        }
    }

    /// Sets the [`SPDX`](https://spdx.org/spdx-specification-21-web-version#h.jxpfx0ykyb60) license expression for the API.
    #[must_use]
    pub fn identifier(self, identifier: impl Into<String>) -> Self {
        Self {
            identifier: Some(identifier.into()),
            ..self
        }
    }

    /// Sets the URL to the license used for the API.
    #[must_use]
    pub fn url(self, url: impl Into<String>) -> Self {
        Self {
            url: Some(url.into()),
            ..self
        }
    }
}
185
/// An object referencing external documentation.
///
/// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#external-documentation-object>
#[derive(Debug, Clone)]
pub struct ExternalDocumentObject {
    url: String,
    description: Option<String>,
}

impl<T: Into<String>> From<T> for ExternalDocumentObject {
    /// Treats the given string as the documentation URL.
    fn from(url: T) -> Self {
        ExternalDocumentObject::new(url)
    }
}

impl ExternalDocumentObject {
    /// Create an external document object pointing at `url`, with no description.
    pub fn new(url: impl Into<String>) -> ExternalDocumentObject {
        let url: String = url.into();
        ExternalDocumentObject {
            description: None,
            url,
        }
    }

    /// Sets a description of the target documentation.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        let text: String = description.into();
        self.description = Some(text);
        self
    }
}
217
/// An extra header that is added to the generated OpenAPI document.
///
/// The header name is normalized to upper case when the object is created.
#[derive(Debug, Clone)]
pub struct ExtraHeader {
    name: String,
    description: Option<String>,
    deprecated: bool,
}

impl<T: AsRef<str>> From<T> for ExtraHeader {
    fn from(name: T) -> Self {
        Self::new(name)
    }
}

impl ExtraHeader {
    /// Create an extra header object by name.
    ///
    /// The name is converted to upper case.
    pub fn new(name: impl AsRef<str>) -> ExtraHeader {
        Self {
            name: name.as_ref().to_uppercase(),
            description: None,
            deprecated: false,
        }
    }

    /// Sets a description of the extra header.
    #[must_use]
    pub fn description(self, description: impl Into<String>) -> Self {
        Self {
            description: Some(description.into()),
            ..self
        }
    }

    /// Specifies this header is deprecated.
    #[must_use]
    pub fn deprecated(self) -> Self {
        Self {
            deprecated: true,
            ..self
        }
    }
}
259
260/// An OpenAPI service for Poem.
261#[derive(Clone)]
262pub struct OpenApiService<T, W> {
263    api: T,
264    _webhook: PhantomData<W>,
265    info: MetaInfo,
266    external_document: Option<MetaExternalDocument>,
267    servers: Vec<MetaServer>,
268    #[cfg(feature = "cookie")]
269    cookie_key: Option<CookieKey>,
270    extra_response_headers: Vec<(ExtraHeader, MetaSchemaRef, bool)>,
271    extra_request_headers: Vec<(ExtraHeader, MetaSchemaRef, bool)>,
272    url_prefix: Option<String>,
273}
274
275impl<T> OpenApiService<T, ()> {
276    /// Create an OpenAPI container.
277    #[must_use]
278    pub fn new(api: T, title: impl Into<String>, version: impl Into<String>) -> Self {
279        Self {
280            api,
281            _webhook: PhantomData,
282            info: MetaInfo {
283                title: title.into(),
284                summary: None,
285                description: None,
286                version: version.into(),
287                terms_of_service: None,
288                contact: None,
289                license: None,
290            },
291            external_document: None,
292            servers: Vec::new(),
293            #[cfg(feature = "cookie")]
294            cookie_key: None,
295            extra_response_headers: vec![],
296            extra_request_headers: vec![],
297            url_prefix: None,
298        }
299    }
300}
301
302impl<T, W> OpenApiService<T, W> {
303    /// Sets the webhooks.
304    pub fn webhooks<W2>(self) -> OpenApiService<T, W2> {
305        OpenApiService {
306            api: self.api,
307            _webhook: PhantomData,
308            info: self.info,
309            external_document: self.external_document,
310            servers: self.servers,
311            #[cfg(feature = "cookie")]
312            cookie_key: self.cookie_key,
313            extra_response_headers: self.extra_response_headers,
314            extra_request_headers: self.extra_request_headers,
315            url_prefix: None,
316        }
317    }
318
319    /// Sets the summary of the API container.
320    #[must_use]
321    pub fn summary(mut self, summary: impl Into<String>) -> Self {
322        self.info.summary = Some(summary.into());
323        self
324    }
325
326    /// Sets the description of the API container.
327    #[must_use]
328    pub fn description(mut self, description: impl Into<String>) -> Self {
329        self.info.description = Some(description.into());
330        self
331    }
332
333    /// Sets a URL to the Terms of Service for the API.
334    #[must_use]
335    pub fn terms_of_service(mut self, url: impl Into<String>) -> Self {
336        self.info.terms_of_service = Some(url.into());
337        self
338    }
339
340    /// Appends a server to the API container.
341    ///
342    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#server-object>
343    #[must_use]
344    pub fn server(mut self, server: impl Into<ServerObject>) -> Self {
345        let server = server.into();
346        self.servers.push(MetaServer {
347            url: server.url,
348            description: server.description,
349            variables: server.variables,
350        });
351        self
352    }
353
354    /// Sets the contact information for the exposed API.
355    #[must_use]
356    pub fn contact(mut self, contact: ContactObject) -> Self {
357        self.info.contact = Some(MetaContact {
358            name: contact.name,
359            url: contact.url,
360            email: contact.email,
361        });
362        self
363    }
364
365    /// Sets the license information for the exposed API.
366    ///
367    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#license-object>
368    #[must_use]
369    pub fn license(mut self, license: impl Into<LicenseObject>) -> Self {
370        let license = license.into();
371        self.info.license = Some(MetaLicense {
372            name: license.name,
373            identifier: license.identifier,
374            url: license.url,
375        });
376        self
377    }
378
379    /// Add a external document object.
380    ///
381    /// Reference: <https://github.com/OAI/OpenAPI-Specification/blob/main/versions/3.1.0.md#external-documentation-object>
382    #[must_use]
383    pub fn external_document(
384        mut self,
385        external_document: impl Into<ExternalDocumentObject>,
386    ) -> Self {
387        let external_document = external_document.into();
388        self.external_document = Some(MetaExternalDocument {
389            url: external_document.url,
390            description: external_document.description,
391        });
392        self
393    }
394
395    /// Add extra response header
396    #[must_use]
397    pub fn extra_response_header<HT, H>(mut self, header: H) -> Self
398    where
399        HT: Type,
400        H: Into<ExtraHeader>,
401    {
402        let extra_header = header.into();
403        self.extra_response_headers
404            .push((extra_header, HT::schema_ref(), HT::IS_REQUIRED));
405        self
406    }
407
408    /// Add extra request header
409    #[must_use]
410    pub fn extra_request_header<HT, H>(mut self, header: H) -> Self
411    where
412        HT: Type,
413        H: Into<ExtraHeader>,
414    {
415        let extra_header = header.into();
416        self.extra_request_headers
417            .push((extra_header, HT::schema_ref(), HT::IS_REQUIRED));
418        self
419    }
420
421    /// Sets the cookie key.
422    #[must_use]
423    #[cfg(feature = "cookie")]
424    pub fn cookie_key(self, key: CookieKey) -> Self {
425        Self {
426            cookie_key: Some(key),
427            ..self
428        }
429    }
430
431    /// Sets optional URl prefix to be added to path
432    pub fn url_prefix(self, url_prefix: impl Into<String>) -> Self {
433        Self {
434            url_prefix: Some(url_prefix.into()),
435            ..self
436        }
437    }
438
439    /// Create the OpenAPI Explorer endpoint.
440    #[must_use]
441    #[cfg(feature = "openapi-explorer")]
442    pub fn openapi_explorer(&self) -> impl Endpoint + 'static
443    where
444        T: OpenApi,
445        W: Webhook,
446    {
447        crate::ui::openapi_explorer::create_endpoint(self.spec())
448    }
449
450    /// Create the OpenAPI Explorer HTML
451    #[cfg(feature = "openapi-explorer")]
452    pub fn openapi_explorer_html(&self) -> String
453    where
454        T: OpenApi,
455        W: Webhook,
456    {
457        crate::ui::openapi_explorer::create_html(&self.spec())
458    }
459
460    /// Create the Swagger UI endpoint.
461    #[must_use]
462    #[cfg(feature = "swagger-ui")]
463    pub fn swagger_ui(&self) -> impl Endpoint + 'static
464    where
465        T: OpenApi,
466        W: Webhook,
467    {
468        crate::ui::swagger_ui::create_endpoint(self.spec())
469    }
470
471    /// Create the Swagger UI HTML
472    #[cfg(feature = "swagger-ui")]
473    pub fn swagger_ui_html(&self) -> String
474    where
475        T: OpenApi,
476        W: Webhook,
477    {
478        crate::ui::swagger_ui::create_html(&self.spec())
479    }
480
481    /// Create the Rapidoc endpoint.
482    #[must_use]
483    #[cfg(feature = "rapidoc")]
484    pub fn rapidoc(&self) -> impl Endpoint + 'static
485    where
486        T: OpenApi,
487        W: Webhook,
488    {
489        crate::ui::rapidoc::create_endpoint(self.spec())
490    }
491
492    /// Create the Rapidoc HTML
493    #[cfg(feature = "rapidoc")]
494    pub fn rapidoc_html(&self) -> String
495    where
496        T: OpenApi,
497        W: Webhook,
498    {
499        crate::ui::rapidoc::create_html(&self.spec())
500    }
501
502    /// Create the Redoc endpoint.
503    #[must_use]
504    #[cfg(feature = "redoc")]
505    pub fn redoc(&self) -> impl Endpoint + 'static
506    where
507        T: OpenApi,
508        W: Webhook,
509    {
510        crate::ui::redoc::create_endpoint(self.spec())
511    }
512
513    /// Create the Redoc HTML
514    #[must_use]
515    #[cfg(feature = "redoc")]
516    pub fn redoc_html(&self) -> String
517    where
518        T: OpenApi,
519        W: Webhook,
520    {
521        crate::ui::redoc::create_html(&self.spec())
522    }
523
524    /// Create the Scalar endpoint.
525    #[must_use]
526    #[cfg(feature = "scalar")]
527    pub fn scalar(&self) -> impl Endpoint + 'static
528    where
529        T: OpenApi,
530        W: Webhook,
531    {
532        crate::ui::scalar::create_endpoint(self.spec())
533    }
534
535    /// Create the Scalar HTML
536    #[must_use]
537    #[cfg(feature = "scalar")]
538    pub fn scalar_html(&self) -> String
539    where
540        T: OpenApi,
541        W: Webhook,
542    {
543        crate::ui::scalar::create_html(&self.spec())
544    }
545
546    /// Create the Stoplight Elements endpoint.
547    #[must_use]
548    #[cfg(feature = "stoplight-elements")]
549    pub fn stoplight_elements(&self) -> impl Endpoint + 'static
550    where
551        T: OpenApi,
552        W: Webhook,
553    {
554        crate::ui::stoplight_elements::create_endpoint(self.spec())
555    }
556
557    /// Create the Stoplight Elements HTML.
558    #[must_use]
559    #[cfg(feature = "stoplight-elements")]
560    pub fn stoplight_elements_html(&self) -> String
561    where
562        T: OpenApi,
563        W: Webhook,
564    {
565        crate::ui::stoplight_elements::create_html(&self.spec())
566    }
567
568    /// Create an endpoint to serve the open api specification as JSON.
569    pub fn spec_endpoint(&self) -> impl Endpoint + 'static
570    where
571        T: OpenApi,
572        W: Webhook,
573    {
574        let spec = self.spec();
575        make_sync(move |_| {
576            Response::builder()
577                .content_type("application/json")
578                .body(spec.clone())
579        })
580    }
581
582    /// Create an endpoint to serve the open api specification as YAML.
583    pub fn spec_endpoint_yaml(&self) -> impl Endpoint + 'static
584    where
585        T: OpenApi,
586        W: Webhook,
587    {
588        let spec = self.spec_yaml();
589        make_sync(move |_| {
590            Response::builder()
591                .content_type("application/x-yaml")
592                .header("Content-Disposition", "inline; filename=\"spec.yaml\"")
593                .body(spec.clone())
594        })
595    }
596
597    fn document(&self) -> Document<'_>
598    where
599        T: OpenApi,
600        W: Webhook,
601    {
602        let mut registry = Registry::new();
603        let mut apis = T::meta();
604
605        // update extra request headers
606        for operation in apis
607            .iter_mut()
608            .flat_map(|meta_api| meta_api.paths.iter_mut())
609            .flat_map(|path| path.operations.iter_mut())
610        {
611            for (idx, (header, schema_ref, is_required)) in
612                self.extra_request_headers.iter().enumerate()
613            {
614                operation.params.insert(
615                    idx,
616                    MetaOperationParam {
617                        name: header.name.clone(),
618                        schema: schema_ref.clone(),
619                        in_type: MetaParamIn::Header,
620                        description: header.description.clone(),
621                        required: *is_required,
622                        deprecated: header.deprecated,
623                        explode: true,
624                        style: None,
625                    },
626                );
627            }
628        }
629
630        // update extra response headers
631        for resp in apis
632            .iter_mut()
633            .flat_map(|meta_api| meta_api.paths.iter_mut())
634            .flat_map(|path| path.operations.iter_mut())
635            .flat_map(|operation| operation.responses.responses.iter_mut())
636        {
637            for (idx, (header, schema_ref, is_required)) in
638                self.extra_response_headers.iter().enumerate()
639            {
640                resp.headers.insert(
641                    idx,
642                    MetaHeader {
643                        name: header.name.clone(),
644                        description: header.description.clone(),
645                        required: *is_required,
646                        deprecated: header.deprecated,
647                        schema: schema_ref.clone(),
648                    },
649                );
650            }
651        }
652
653        T::register(&mut registry);
654        W::register(&mut registry);
655
656        let webhooks = W::meta();
657
658        let mut doc = Document {
659            info: &self.info,
660            servers: &self.servers,
661            apis,
662            webhooks,
663            registry,
664            external_document: self.external_document.as_ref(),
665            url_prefix: self.url_prefix.as_deref(),
666        };
667        doc.remove_unused_schemas();
668
669        doc
670    }
671
672    /// Returns the OAS specification file as JSON.
673    pub fn spec(&self) -> String
674    where
675        T: OpenApi,
676        W: Webhook,
677    {
678        let doc = self.document();
679        serde_json::to_string_pretty(&doc).unwrap()
680    }
681
682    /// Returns the OAS specification file as YAML.
683    pub fn spec_yaml(&self) -> String
684    where
685        T: OpenApi,
686        W: Webhook,
687    {
688        let doc = self.document();
689        serde_yaml::to_string(&doc).unwrap()
690    }
691}
692
693impl<T: OpenApi, W: Webhook> IntoEndpoint for OpenApiService<T, W> {
694    type Endpoint = BoxEndpoint<'static>;
695
696    fn into_endpoint(self) -> Self::Endpoint {
697        async fn extract_query(mut req: Request) -> Result<Request> {
698            let url_query: Vec<(String, String)> = req.params().unwrap_or_default();
699            req.extensions_mut().insert(UrlQuery(url_query));
700            Ok(req)
701        }
702
703        #[cfg(feature = "cookie")]
704        let cookie_jar_manager = match self.cookie_key {
705            Some(key) => CookieJarManager::with_key(key),
706            None => CookieJarManager::new(),
707        };
708
709        // check duplicate operation id
710        let mut operation_ids = HashSet::new();
711        for operation in T::meta()
712            .into_iter()
713            .flat_map(|api| api.paths.into_iter())
714            .flat_map(|path| path.operations.into_iter())
715        {
716            if let Some(operation_id) = operation.operation_id {
717                if !operation_ids.insert(operation_id) {
718                    panic!("duplicate operation id: {operation_id}");
719                }
720            }
721        }
722
723        let mut items = HashMap::new();
724        self.api.add_routes(&mut items);
725
726        let route = items
727            .into_iter()
728            .fold(Route::new(), |route, (path, paths)| {
729                route.at(
730                    path,
731                    paths
732                        .into_iter()
733                        .fold(RouteMethod::new(), |route_method, (method, ep)| {
734                            route_method.method(method, ep)
735                        }),
736                )
737            });
738
739        #[cfg(feature = "cookie")]
740        let route = route.with(cookie_jar_manager);
741
742        route.before(extract_query).map_to_response().boxed()
743    }
744}
745
#[cfg(test)]
mod tests {
    use super::*;
    use crate::OpenApi;

    /// Extra response headers are prepended, upper-cased, to every response.
    #[test]
    fn extra_response_headers() {
        struct Api;

        #[OpenApi(internal)]
        impl Api {
            #[oai(path = "/", method = "get")]
            async fn test(&self) {}
        }

        let service = OpenApiService::new(Api, "demo", "1.0")
            .extra_response_header::<i32, _>("a1")
            .extra_response_header::<String, _>(ExtraHeader::new("A2").description("abc"))
            .extra_response_header::<f32, _>(ExtraHeader::new("A3").deprecated());
        let doc = service.document();
        let headers = &doc.apis[0].paths[0].operations[0].responses.responses[0].headers;

        // (name, description, deprecated, schema)
        let expected = [
            ("A1", None, false, i32::schema_ref()),
            ("A2", Some("abc"), false, String::schema_ref()),
            ("A3", None, true, f32::schema_ref()),
        ];
        for (idx, want) in expected.iter().enumerate() {
            let header = &headers[idx];
            assert_eq!(header.name, want.0);
            assert_eq!(header.description.as_deref(), want.1);
            assert_eq!(header.deprecated, want.2);
            assert_eq!(header.schema, want.3);
        }
    }

    /// Extra request headers are prepended, upper-cased, as header parameters.
    #[test]
    fn extra_request_headers() {
        struct Api;

        #[OpenApi(internal)]
        impl Api {
            #[oai(path = "/", method = "get")]
            async fn test(&self) {}
        }

        let service = OpenApiService::new(Api, "demo", "1.0")
            .extra_request_header::<i32, _>("a1")
            .extra_request_header::<String, _>(ExtraHeader::new("A2").description("abc"))
            .extra_request_header::<f32, _>(ExtraHeader::new("A3").deprecated());
        let doc = service.document();
        let params = &doc.apis[0].paths[0].operations[0].params;

        // (name, description, deprecated, schema)
        let expected = [
            ("A1", None, false, i32::schema_ref()),
            ("A2", Some("abc"), false, String::schema_ref()),
            ("A3", None, true, f32::schema_ref()),
        ];
        for (idx, want) in expected.iter().enumerate() {
            let param = &params[idx];
            assert_eq!(param.name, want.0);
            assert_eq!(param.in_type, MetaParamIn::Header);
            assert_eq!(param.description.as_deref(), want.1);
            assert_eq!(param.deprecated, want.2);
            assert_eq!(param.schema, want.3);
        }
    }
}