1use std::any::TypeId;
2use std::collections::{BTreeSet, HashMap};
3use std::sync::{LazyLock, RwLock};
4
5use salvo_core::Router;
6
7use crate::SecurityRequirement;
8use crate::path::PathItemType;
9
/// Converts a Salvo route path into an OpenAPI path template.
///
/// Each `{...}` parameter is reduced to `{name}`. Everything after the first `:` (a named
/// constraint such as `num` or `num(3..=10)`) or `|` (an inline regex) is dropped.
///
/// While searching for the `}` that closes a parameter:
/// - nested `{`/`}` pairs are balanced;
/// - a backslash makes the following character literal.
///
/// So constraints like `[a-z]{2}` or `\}` do not end the parameter early.
///
/// A doubled `{{` is copied through unchanged and never starts a parameter. All other text
/// outside parameters is copied as-is. If a parameter is never closed, the rest of the path,
/// starting at its opening `{`, is copied verbatim.
fn normalize_oapi_path(path: &str) -> String {
    /// Consumes characters up to and including the `}` that closes the current parameter.
    /// Returns that brace's byte offset, or `None` if the input ends first.
    fn closing_brace(rest: impl Iterator<Item = (usize, char)>) -> Option<usize> {
        let mut depth = 0usize;
        let mut skip_next = false;
        for (offset, c) in rest {
            if skip_next {
                skip_next = false;
                continue;
            }
            match c {
                '\\' => skip_next = true,
                '{' => depth += 1,
                '}' if depth == 0 => return Some(offset),
                '}' => depth -= 1,
                _ => {}
            }
        }
        None
    }

    let mut out = String::with_capacity(path.len());
    let mut chars = path.char_indices().peekable();

    while let Some((open, c)) = chars.next() {
        if c != '{' {
            out.push(c);
        } else if chars.next_if(|&(_, next)| next == '{').is_some() {
            // A doubled brace is not a parameter; keep both characters as written.
            out.push_str("{{");
        } else if let Some(close) = closing_brace(&mut chars) {
            let Some(body) = path.get(open + c.len_utf8()..close) else {
                break;
            };
            // `split` always yields at least one piece: the text before the first `:`/`|`,
            // or the whole body when neither is present.
            let name = body.split([':', '|']).next().unwrap_or(body);
            out.push('{');
            out.push_str(name);
            out.push('}');
        } else {
            // Unterminated parameter: emit the remainder untouched and stop.
            out.push_str(path.get(open..).unwrap_or_default());
            break;
        }
    }
    out
}
76
/// One node of a normalized view of a `Router` tree.
///
/// Holds the pieces of a router (handler, method, path, metadata) that are needed to
/// build OpenAPI paths. Produced by `NormNode::new`.
#[derive(Debug, Default)]
pub(crate) struct NormNode {
    /// `TypeId` of the router's goal handler, if the router has one.
    pub(crate) handler_type_id: Option<TypeId>,
    /// Type name reported by the router's goal handler, if the router has one.
    pub(crate) handler_type_name: Option<&'static str>,
    /// HTTP method taken from the router's `method:` filter, when it is a recognized verb.
    pub(crate) method: Option<PathItemType>,
    /// Path taken from the router's `path:` filter, normalized by `normalize_oapi_path`.
    pub(crate) path: Option<String>,
    /// Nodes built from the router's nested routers, in the order they are listed.
    pub(crate) children: Vec<Self>,
    /// Tags and security requirements inherited from ancestors plus the router's own.
    pub(crate) metadata: Metadata,
}
87
88impl NormNode {
89 pub(crate) fn new(router: &Router, inherited_metadata: Metadata) -> Self {
90 let mut node = Self {
91 metadata: inherited_metadata,
93 ..Self::default()
94 };
95 let registry = METADATA_REGISTRY
96 .read()
97 .expect("failed to lock METADATA_REGISTRY for read");
98 if let Some(metadata) = registry.get(&router.id) {
99 node.metadata.tags.extend(metadata.tags.iter().cloned());
100 node.metadata
101 .securities
102 .extend(metadata.securities.iter().cloned());
103 }
104
105 for filter in router.filters() {
106 let info = format!("{filter:?}");
107 if info.starts_with("path:") {
108 let path = info
109 .split_once(':')
110 .expect("split once by ':' should not be get `None`")
111 .1;
112 node.path = Some(normalize_oapi_path(path));
113 } else if info.starts_with("method:") {
114 match info
115 .split_once(':')
116 .expect("split once by ':' should not be get `None`.")
117 .1
118 {
119 "GET" => node.method = Some(PathItemType::Get),
120 "POST" => node.method = Some(PathItemType::Post),
121 "PUT" => node.method = Some(PathItemType::Put),
122 "DELETE" => node.method = Some(PathItemType::Delete),
123 "HEAD" => node.method = Some(PathItemType::Head),
124 "OPTIONS" => node.method = Some(PathItemType::Options),
125 "CONNECT" => node.method = Some(PathItemType::Connect),
126 "TRACE" => node.method = Some(PathItemType::Trace),
127 "PATCH" => node.method = Some(PathItemType::Patch),
128 _ => {}
129 }
130 }
131 }
132 node.handler_type_id = router.goal.as_ref().map(|h| h.type_id());
133 node.handler_type_name = router.goal.as_ref().map(|h| h.type_name());
134 let routers = router.routers();
135 if !routers.is_empty() {
136 for router in routers {
137 node.children.push(Self::new(router, node.metadata.clone()));
138 }
139 }
140 node
141 }
142}
143
144type MetadataMap = RwLock<HashMap<usize, Metadata>>;
146static METADATA_REGISTRY: LazyLock<MetadataMap> = LazyLock::new(MetadataMap::default);
147
148pub trait RouterExt {
150 #[must_use]
154 fn oapi_security(self, security: SecurityRequirement) -> Self;
155
156 #[must_use]
160 fn oapi_securities<I>(self, security: I) -> Self
161 where
162 I: IntoIterator<Item = SecurityRequirement>;
163
164 #[must_use]
168 fn oapi_tag(self, tag: impl Into<String>) -> Self;
169
170 #[must_use]
174 fn oapi_tags<I, V>(self, tags: I) -> Self
175 where
176 I: IntoIterator<Item = V>,
177 V: Into<String>;
178}
179
180impl RouterExt for Router {
181 fn oapi_security(self, security: SecurityRequirement) -> Self {
182 let mut guard = METADATA_REGISTRY
183 .write()
184 .expect("failed to lock METADATA_REGISTRY for write");
185 let metadata = guard.entry(self.id).or_default();
186 metadata.securities.push(security);
187 self
188 }
189 fn oapi_securities<I>(self, iter: I) -> Self
190 where
191 I: IntoIterator<Item = SecurityRequirement>,
192 {
193 let mut guard = METADATA_REGISTRY
194 .write()
195 .expect("failed to lock METADATA_REGISTRY for write");
196 let metadata = guard.entry(self.id).or_default();
197 metadata.securities.extend(iter);
198 self
199 }
200 fn oapi_tag(self, tag: impl Into<String>) -> Self {
201 let mut guard = METADATA_REGISTRY
202 .write()
203 .expect("failed to lock METADATA_REGISTRY for write");
204 let metadata = guard.entry(self.id).or_default();
205 metadata.tags.insert(tag.into());
206 self
207 }
208 fn oapi_tags<I, V>(self, iter: I) -> Self
209 where
210 I: IntoIterator<Item = V>,
211 V: Into<String>,
212 {
213 let mut guard = METADATA_REGISTRY
214 .write()
215 .expect("failed to lock METADATA_REGISTRY for write");
216 let metadata = guard.entry(self.id).or_default();
217 metadata.tags.extend(iter.into_iter().map(Into::into));
218 self
219 }
220}
221
/// OpenAPI metadata attached to a router through `RouterExt` and inherited by its
/// nested routers.
// NOTE(review): `#[non_exhaustive]` only restricts code in other crates, so it has no effect
// on this `pub(crate)` type.
#[non_exhaustive]
#[derive(Default, Clone, Debug)]
pub(crate) struct Metadata {
    /// Tag names. Stored in a `BTreeSet`, so duplicates collapse and iteration is sorted.
    pub(crate) tags: BTreeSet<String>,
    /// Security requirements in the order they were added. Duplicates are not removed.
    pub(crate) securities: Vec<SecurityRequirement>,
}
228
#[cfg(test)]
mod tests {
    use super::normalize_oapi_path;

    /// Asserts that each `(input, expected)` pair normalizes as expected. On failure, the
    /// message names the offending input.
    fn assert_normalizes(cases: &[(&str, &str)]) {
        for &(input, expected) in cases {
            assert_eq!(normalize_oapi_path(input), expected, "input: {input:?}");
        }
    }

    #[test]
    fn normalize_braced_path_constraints() {
        assert_normalizes(&[
            ("/posts/{id}", "/posts/{id}"),
            ("/posts/{id:num}", "/posts/{id}"),
            ("/posts/{id:num(3..=10)}", "/posts/{id}"),
            (r"/posts/{id|\d+}", "/posts/{id}"),
            ("/posts/{id|[a-z]{2}}", "/posts/{id}"),
            (r"/files/{name|[^\}]+}/raw", "/files/{name}/raw"),
            ("/posts/article_{id:num}", "/posts/article_{id}"),
        ]);
    }

    #[test]
    fn normalize_multiple_and_unicode_parameters() {
        assert_normalizes(&[
            ("/{user:num}/posts/{slug|[a-z-]+}", "/{user}/posts/{slug}"),
            ("/café/{naïve:num}", "/café/{naïve}"),
        ]);
    }

    #[test]
    fn normalize_passes_through_non_parameter_text() {
        assert_normalizes(&[
            ("", ""),
            ("/posts", "/posts"),
            // A doubled brace never starts a parameter and is kept as written.
            ("/posts/{{id}}", "/posts/{{id}}"),
            // An unterminated parameter is emitted verbatim.
            ("/posts/{id", "/posts/{id"),
            // A stray closing brace outside a parameter is ordinary text.
            ("/posts/{id}}", "/posts/{id}}"),
        ]);
    }
}