1use std::any::TypeId;
2use std::collections::{BTreeSet, HashMap};
3use std::sync::{LazyLock, RwLock};
4
5use salvo_core::Router;
6
7use crate::SecurityRequirement;
8use crate::path::PathItemType;
9
/// Converts a Salvo route path into an OpenAPI path template.
///
/// Salvo lets a braced parameter carry a constraint after its name: a named
/// matcher introduced by `:` (`{id:num}`, `{id:num(3..=10)}`) or a regex
/// introduced by `|` (`{id|\d+}`). OpenAPI templates only carry the parameter
/// name, so every braced parameter is reduced to `{name}`, where `name` is the
/// text before the first `:` or `|` inside the braces.
///
/// Scanning rules:
/// - `{{` is copied through unchanged as two literal braces.
/// - Inside a parameter, `\` escapes the next character, and nested `{`/`}`
///   pairs (e.g. a regex quantifier such as `[a-z]{2}`) are balanced, so only
///   the `}` at nesting depth zero closes the parameter.
/// - If a `{` is never closed, everything from that `{` to the end of the path
///   is copied verbatim.
fn normalize_oapi_path(path: &str) -> String {
    let mut normalized = String::with_capacity(path.len());
    let mut chars = path.char_indices().peekable();

    while let Some((start, ch)) = chars.next() {
        if ch != '{' {
            normalized.push(ch);
            continue;
        }
        // `{{` is an escaped literal brace pair, not the start of a parameter.
        if chars.next_if(|&(_, next)| next == '{').is_some() {
            normalized.push_str("{{");
            continue;
        }

        let content_start = start + ch.len_utf8();
        let mut depth = 0usize;
        let mut escaping = false;
        let mut param_end = None;

        for (idx, current) in chars.by_ref() {
            if escaping {
                escaping = false;
                continue;
            }
            match current {
                '\\' => escaping = true,
                '{' => depth += 1,
                '}' if depth == 0 => {
                    param_end = Some(idx);
                    break;
                }
                '}' => depth -= 1,
                _ => {}
            }
        }

        let Some(param_end) = param_end else {
            // Unterminated parameter: keep the remainder of the path untouched.
            normalized.push_str(&path[start..]);
            break;
        };

        let content = &path[content_start..param_end];
        // Drop any `:matcher` or `|regex` constraint; keep only the name.
        let name = content
            .find([':', '|'])
            .map_or(content, |name_end| &content[..name_end]);
        normalized.push('{');
        normalized.push_str(name);
        normalized.push('}');
    }
    normalized
}
69
/// One node of a normalized router tree, built by `NormNode::new` from a Salvo
/// [`Router`] and, recursively, from each of its child routers.
#[derive(Debug, Default)]
pub(crate) struct NormNode {
    /// `TypeId` reported by the router's goal handler; `None` when the router has no goal.
    pub(crate) handler_type_id: Option<TypeId>,
    /// Type name reported by the router's goal handler; `None` when the router has no goal.
    pub(crate) handler_type_name: Option<&'static str>,
    /// HTTP method taken from this router's `method:` filter; `None` when there is no
    /// such filter or the verb is not recognized.
    pub(crate) method: Option<PathItemType>,
    /// Path taken from this router's own `path:` filter, normalized to OpenAPI template
    /// syntax. Only this router's filters are inspected; NOTE(review): presumably callers
    /// join ancestor paths themselves — confirm.
    pub(crate) path: Option<String>,
    /// Nodes built from this router's child routers, in the same order.
    pub(crate) children: Vec<Self>,
    /// Tags and security requirements inherited from ancestors plus those registered
    /// for this router.
    pub(crate) metadata: Metadata,
}
80
81impl NormNode {
82 pub(crate) fn new(router: &Router, inherited_metadata: Metadata) -> Self {
83 let mut node = Self {
84 metadata: inherited_metadata,
86 ..Self::default()
87 };
88 let registry = METADATA_REGISTRY
89 .read()
90 .expect("failed to lock METADATA_REGISTRY for read");
91 if let Some(metadata) = registry.get(&router.id) {
92 node.metadata.tags.extend(metadata.tags.iter().cloned());
93 node.metadata
94 .securities
95 .extend(metadata.securities.iter().cloned());
96 }
97
98 for filter in router.filters() {
99 let info = format!("{filter:?}");
100 if info.starts_with("path:") {
101 let path = info
102 .split_once(':')
103 .expect("split once by ':' should not be get `None`")
104 .1;
105 node.path = Some(normalize_oapi_path(path));
106 } else if info.starts_with("method:") {
107 match info
108 .split_once(':')
109 .expect("split once by ':' should not be get `None`.")
110 .1
111 {
112 "GET" => node.method = Some(PathItemType::Get),
113 "POST" => node.method = Some(PathItemType::Post),
114 "PUT" => node.method = Some(PathItemType::Put),
115 "DELETE" => node.method = Some(PathItemType::Delete),
116 "HEAD" => node.method = Some(PathItemType::Head),
117 "OPTIONS" => node.method = Some(PathItemType::Options),
118 "CONNECT" => node.method = Some(PathItemType::Connect),
119 "TRACE" => node.method = Some(PathItemType::Trace),
120 "PATCH" => node.method = Some(PathItemType::Patch),
121 _ => {}
122 }
123 }
124 }
125 node.handler_type_id = router.goal.as_ref().map(|h| h.type_id());
126 node.handler_type_name = router.goal.as_ref().map(|h| h.type_name());
127 let routers = router.routers();
128 if !routers.is_empty() {
129 for router in routers {
130 node.children.push(Self::new(router, node.metadata.clone()));
131 }
132 }
133 node
134 }
135}
136
137type MetadataMap = RwLock<HashMap<usize, Metadata>>;
139static METADATA_REGISTRY: LazyLock<MetadataMap> = LazyLock::new(MetadataMap::default);
140
141pub trait RouterExt {
143 #[must_use]
147 fn oapi_security(self, security: SecurityRequirement) -> Self;
148
149 #[must_use]
153 fn oapi_securities<I>(self, security: I) -> Self
154 where
155 I: IntoIterator<Item = SecurityRequirement>;
156
157 #[must_use]
161 fn oapi_tag(self, tag: impl Into<String>) -> Self;
162
163 #[must_use]
167 fn oapi_tags<I, V>(self, tags: I) -> Self
168 where
169 I: IntoIterator<Item = V>,
170 V: Into<String>;
171}
172
173impl RouterExt for Router {
174 fn oapi_security(self, security: SecurityRequirement) -> Self {
175 let mut guard = METADATA_REGISTRY
176 .write()
177 .expect("failed to lock METADATA_REGISTRY for write");
178 let metadata = guard.entry(self.id).or_default();
179 metadata.securities.push(security);
180 self
181 }
182 fn oapi_securities<I>(self, iter: I) -> Self
183 where
184 I: IntoIterator<Item = SecurityRequirement>,
185 {
186 let mut guard = METADATA_REGISTRY
187 .write()
188 .expect("failed to lock METADATA_REGISTRY for write");
189 let metadata = guard.entry(self.id).or_default();
190 metadata.securities.extend(iter);
191 self
192 }
193 fn oapi_tag(self, tag: impl Into<String>) -> Self {
194 let mut guard = METADATA_REGISTRY
195 .write()
196 .expect("failed to lock METADATA_REGISTRY for write");
197 let metadata = guard.entry(self.id).or_default();
198 metadata.tags.insert(tag.into());
199 self
200 }
201 fn oapi_tags<I, V>(self, iter: I) -> Self
202 where
203 I: IntoIterator<Item = V>,
204 V: Into<String>,
205 {
206 let mut guard = METADATA_REGISTRY
207 .write()
208 .expect("failed to lock METADATA_REGISTRY for write");
209 let metadata = guard.entry(self.id).or_default();
210 metadata.tags.extend(iter.into_iter().map(Into::into));
211 self
212 }
213}
214
/// OpenAPI metadata attached to a router through [`RouterExt`] and handed down to
/// descendant routers when a `NormNode` tree is built.
// NOTE(review): `#[non_exhaustive]` only restricts other crates, so it has no
// effect on this `pub(crate)` type.
#[non_exhaustive]
#[derive(Default, Clone, Debug)]
pub(crate) struct Metadata {
    /// Tag names; a `BTreeSet`, so duplicates collapse and iteration is sorted.
    pub(crate) tags: BTreeSet<String>,
    /// Security requirements in insertion order; duplicates are not removed.
    pub(crate) securities: Vec<SecurityRequirement>,
}
221
#[cfg(test)]
mod tests {
    use super::normalize_oapi_path;

    #[test]
    fn normalize_braced_path_constraints() {
        assert_eq!(normalize_oapi_path("/posts/{id}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id:num}"), "/posts/{id}");
        assert_eq!(
            normalize_oapi_path("/posts/{id:num(3..=10)}"),
            "/posts/{id}"
        );
        assert_eq!(normalize_oapi_path(r"/posts/{id|\d+}"), "/posts/{id}");
        assert_eq!(normalize_oapi_path("/posts/{id|[a-z]{2}}"), "/posts/{id}");
        assert_eq!(
            normalize_oapi_path("/posts/article_{id:num}"),
            "/posts/article_{id}"
        );
    }

    #[test]
    fn normalize_escapes_and_malformed_input() {
        // `{{` is a literal brace pair and must not be treated as a parameter.
        assert_eq!(normalize_oapi_path("/a/{{b}}"), "/a/{{b}}");
        // An escaped `}` inside a regex does not close the parameter.
        assert_eq!(normalize_oapi_path(r"/f/{name|[^\}]+}/x"), "/f/{name}/x");
        // Every parameter in the path is normalized, not only the first.
        assert_eq!(
            normalize_oapi_path(r"/u/{uid:num}/p/{pid|\d+}"),
            "/u/{uid}/p/{pid}"
        );
        // An unterminated parameter is copied through verbatim.
        assert_eq!(normalize_oapi_path("/posts/{id:num"), "/posts/{id:num");
        // Non-ASCII input is sliced on char boundaries.
        assert_eq!(normalize_oapi_path("/ü/{ïd:num}"), "/ü/{ïd}");
        assert_eq!(normalize_oapi_path("/plain/path"), "/plain/path");
        assert_eq!(normalize_oapi_path(""), "");
    }
}