Skip to main content

salvo_oapi/
routing.rs

1use std::any::TypeId;
2use std::collections::{BTreeSet, HashMap};
3use std::sync::{LazyLock, RwLock};
4
5use salvo_core::Router;
6
7use crate::SecurityRequirement;
8use crate::path::PathItemType;
9
/// Rewrites a salvo route path into an OpenAPI path template.
///
/// Every `{...}` parameter keeps only its name: whatever follows the first `:` (a type
/// constraint such as `num` or `num(3..=10)`) or `|` (a regex constraint) is dropped, so
/// `/posts/{id:num}` becomes `/posts/{id}`.
///
/// Inside a parameter:
/// - `\` escapes the next character;
/// - balanced inner braces (e.g. the `{2}` in `{id|[a-z]{2}}`) are skipped over.
///
/// A doubled `{{` is copied through verbatim. A parameter that is never closed causes
/// the remainder of the path, from its `{` onwards, to be copied unchanged.
fn normalize_oapi_path(path: &str) -> String {
    let mut out = String::with_capacity(path.len());
    let mut iter = path.char_indices().peekable();

    while let Some((open_at, ch)) = iter.next() {
        if ch != '{' {
            out.push(ch);
            continue;
        }

        // `{{` is an escaped literal brace: emit both characters and move on.
        if iter.next_if(|&(_, next)| next == '{').is_some() {
            out.push_str("{{");
            continue;
        }

        // Look for the `}` closing this parameter, ignoring escaped characters and
        // any nested brace pairs that belong to a regex constraint.
        let mut depth = 0usize;
        let mut escaped = false;
        let close_at = iter.find_map(|(idx, cur)| {
            if escaped {
                escaped = false;
                return None;
            }
            match cur {
                '\\' => escaped = true,
                '{' => depth += 1,
                '}' if depth == 0 => return Some(idx),
                '}' => depth -= 1,
                _ => {}
            }
            None
        });

        let Some(close_at) = close_at else {
            // Unterminated parameter: keep everything from the `{` onwards untouched.
            out.push_str(&path[open_at..]);
            break;
        };

        // `{` is a single byte and `close_at` comes from `char_indices`, so both slice
        // bounds fall on char boundaries.
        let body = &path[open_at + 1..close_at];
        let name = body.split([':', '|']).next().unwrap_or(body);
        out.push('{');
        out.push_str(name);
        out.push('}');
    }

    out
}
76
77#[derive(Debug, Default)]
78pub(crate) struct NormNode {
79    // pub(crate) router_id: usize,
80    pub(crate) handler_type_id: Option<TypeId>,
81    pub(crate) handler_type_name: Option<&'static str>,
82    pub(crate) method: Option<PathItemType>,
83    pub(crate) path: Option<String>,
84    pub(crate) children: Vec<Self>,
85    pub(crate) metadata: Metadata,
86}
87
88impl NormNode {
89    pub(crate) fn new(router: &Router, inherited_metadata: Metadata) -> Self {
90        let mut node = Self {
91            // router_id: router.id,
92            metadata: inherited_metadata,
93            ..Self::default()
94        };
95        let registry = METADATA_REGISTRY
96            .read()
97            .expect("failed to lock METADATA_REGISTRY for read");
98        if let Some(metadata) = registry.get(&router.id) {
99            node.metadata.tags.extend(metadata.tags.iter().cloned());
100            node.metadata
101                .securities
102                .extend(metadata.securities.iter().cloned());
103        }
104
105        for filter in router.filters() {
106            let info = format!("{filter:?}");
107            if info.starts_with("path:") {
108                let path = info
109                    .split_once(':')
110                    .expect("split once by ':' should not be get `None`")
111                    .1;
112                node.path = Some(normalize_oapi_path(path));
113            } else if info.starts_with("method:") {
114                match info
115                    .split_once(':')
116                    .expect("split once by ':' should not be get `None`.")
117                    .1
118                {
119                    "GET" => node.method = Some(PathItemType::Get),
120                    "POST" => node.method = Some(PathItemType::Post),
121                    "PUT" => node.method = Some(PathItemType::Put),
122                    "DELETE" => node.method = Some(PathItemType::Delete),
123                    "HEAD" => node.method = Some(PathItemType::Head),
124                    "OPTIONS" => node.method = Some(PathItemType::Options),
125                    "CONNECT" => node.method = Some(PathItemType::Connect),
126                    "TRACE" => node.method = Some(PathItemType::Trace),
127                    "PATCH" => node.method = Some(PathItemType::Patch),
128                    _ => {}
129                }
130            }
131        }
132        node.handler_type_id = router.goal.as_ref().map(|h| h.type_id());
133        node.handler_type_name = router.goal.as_ref().map(|h| h.type_name());
134        let routers = router.routers();
135        if !routers.is_empty() {
136            for router in routers {
137                node.children.push(Self::new(router, node.metadata.clone()));
138            }
139        }
140        node
141    }
142}
143
144/// A component for save router metadata.
145type MetadataMap = RwLock<HashMap<usize, Metadata>>;
146static METADATA_REGISTRY: LazyLock<MetadataMap> = LazyLock::new(MetadataMap::default);
147
148/// Router extension trait for openapi metadata.
149pub trait RouterExt {
150    /// Add security requirement to the router.
151    ///
152    /// All endpoints in the router and it's descents will inherit this security requirement.
153    #[must_use]
154    fn oapi_security(self, security: SecurityRequirement) -> Self;
155
156    /// Add security requirements to the router.
157    ///
158    /// All endpoints in the router and it's descents will inherit these security requirements.
159    #[must_use]
160    fn oapi_securities<I>(self, security: I) -> Self
161    where
162        I: IntoIterator<Item = SecurityRequirement>;
163
164    /// Add tag to the router.
165    ///
166    /// All endpoints in the router and it's descents will inherit this tag.
167    #[must_use]
168    fn oapi_tag(self, tag: impl Into<String>) -> Self;
169
170    /// Add tags to the router.
171    ///
172    /// All endpoints in the router and it's descents will inherit these tags.
173    #[must_use]
174    fn oapi_tags<I, V>(self, tags: I) -> Self
175    where
176        I: IntoIterator<Item = V>,
177        V: Into<String>;
178}
179
180impl RouterExt for Router {
181    fn oapi_security(self, security: SecurityRequirement) -> Self {
182        let mut guard = METADATA_REGISTRY
183            .write()
184            .expect("failed to lock METADATA_REGISTRY for write");
185        let metadata = guard.entry(self.id).or_default();
186        metadata.securities.push(security);
187        self
188    }
189    fn oapi_securities<I>(self, iter: I) -> Self
190    where
191        I: IntoIterator<Item = SecurityRequirement>,
192    {
193        let mut guard = METADATA_REGISTRY
194            .write()
195            .expect("failed to lock METADATA_REGISTRY for write");
196        let metadata = guard.entry(self.id).or_default();
197        metadata.securities.extend(iter);
198        self
199    }
200    fn oapi_tag(self, tag: impl Into<String>) -> Self {
201        let mut guard = METADATA_REGISTRY
202            .write()
203            .expect("failed to lock METADATA_REGISTRY for write");
204        let metadata = guard.entry(self.id).or_default();
205        metadata.tags.insert(tag.into());
206        self
207    }
208    fn oapi_tags<I, V>(self, iter: I) -> Self
209    where
210        I: IntoIterator<Item = V>,
211        V: Into<String>,
212    {
213        let mut guard = METADATA_REGISTRY
214            .write()
215            .expect("failed to lock METADATA_REGISTRY for write");
216        let metadata = guard.entry(self.id).or_default();
217        metadata.tags.extend(iter.into_iter().map(Into::into));
218        self
219    }
220}
221
222#[non_exhaustive]
223#[derive(Default, Clone, Debug)]
224pub(crate) struct Metadata {
225    pub(crate) tags: BTreeSet<String>,
226    pub(crate) securities: Vec<SecurityRequirement>,
227}
228
#[cfg(test)]
mod tests {
    use super::normalize_oapi_path;

    #[test]
    fn normalize_braced_path_constraints() {
        assert_eq!(normalize_oapi_path("/posts/{id}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id:num}"), "/posts/{id}");
        assert_eq!(
            normalize_oapi_path("/posts/{id:num(3..=10)}"),
            "/posts/{id}"
        );
        assert_eq!(normalize_oapi_path(r"/posts/{id|\d+}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id|[a-z]{2}}"), "/posts/{id}");
        assert_eq!(
            normalize_oapi_path("/posts/article_{id:num}"),
            "/posts/article_{id}"
        );
    }

    /// Covers the escape and error-recovery branches: `{{` literals, escaped `}` inside a
    /// constraint, and parameters that are never closed.
    #[test]
    fn normalize_escaped_and_unterminated_braces() {
        assert_eq!(normalize_oapi_path(""), "");
        assert_eq!(normalize_oapi_path("/files/{{name}}"), "/files/{{name}}");
        assert_eq!(normalize_oapi_path(r"/posts/{id|\}}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id"), "/posts/{id");
        assert_eq!(
            normalize_oapi_path("/posts/{id|[a-z]{2}"),
            "/posts/{id|[a-z]{2}"
        );
    }

    #[test]
    fn normalize_multiple_and_non_ascii_params() {
        assert_eq!(
            normalize_oapi_path("/posts/{id:num}/comments/{cid|\\d+}"),
            "/posts/{id}/comments/{cid}"
        );
        assert_eq!(normalize_oapi_path("/文章/{编号:num}/ü"), "/文章/{编号}/ü");
    }
}
248}