1use alloc::{
4 boxed::Box,
5 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6 format,
7 rc::Rc,
8 string::{String, ToString},
9 vec::Vec,
10};
11use core::error::Error;
12use trie_rs::{
13 inc_search::{IncSearch, Position},
14 map::{Trie, TrieBuilder},
15};
16
17#[derive(Clone, Debug)]
18pub struct Resolver {
19 trie: Trie<Sect, Vec<Rc<Module>>>,
20}
21
22impl Resolver {
23 pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, Box<dyn Error>> {
24 Ok(self.find(url)?.collect())
25 }
26
27 pub fn find(&self, url: &str) -> Result<impl Iterator<Item = Rc<Module>>, Box<dyn Error>> {
28 Ok(SearchIter {
29 trie: &self.trie,
30 input_idx: 0,
31 input: split_url(url)?,
32 items: &[],
33 save_stack: Vec::new(),
34 search: self.trie.inc_search(),
35 unique: BTreeSet::new(),
36 })
37 }
38}
39
/// A named module that a URL can resolve to.
///
/// Equality and ordering are derived, so two modules compare by `name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Module {
    /// Name the module was registered under.
    pub name: String,
}
44
45#[derive(Clone, Debug, Default)]
46pub struct ResolverBuilder {
47 modules: BTreeMap<String, Rc<Module>>,
48 protocol_modules: BTreeMap<String, Vec<Rc<Module>>>,
49 pattern_modules: BTreeMap<String, Vec<Rc<Module>>>,
50 prefix_modules: BTreeMap<String, Vec<Rc<Module>>>,
51}
52
53impl ResolverBuilder {
54 pub fn new() -> Self {
55 Self::default()
56 }
57
58 pub fn build(self) -> Result<Resolver, Box<dyn Error>> {
59 let mut trie = TrieBuilder::new();
60 for (k, v) in self.prefix_modules {
61 let k = split_url(&k)?;
62 trie.push(k, v);
63 }
64 for (k, v) in self.protocol_modules {
65 let k = Sect::Protocol(k);
66 trie.push([k], v);
67 }
68 for (k, v) in self.pattern_modules {
69 let k = split_url(&k)?.into_iter().map(Sect::into_pattern);
70 trie.insert(k, v);
71 }
72 let trie = trie.build();
73
74 Ok(Resolver { trie })
75 }
76
77 pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Box<dyn Error>> {
78 let module = self.add_module(module);
79 let mods = self
80 .protocol_modules
81 .entry(protocol.to_string())
82 .or_default();
83 mods.push(module);
84 Ok(())
85 }
86 pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), Box<dyn Error>> {
87 let _ = split_url(prefix)?;
88 let module = self.add_module(module);
89 let mods = self.prefix_modules.entry(prefix.to_string()).or_default();
90 mods.push(module.clone());
91 Ok(())
92 }
93 pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), Box<dyn Error>> {
94 let _ = split_url(pattern)?;
95 let module = self.add_module(module);
96 let mods = self.pattern_modules.entry(pattern.to_string()).or_default();
97 mods.push(module.clone());
98 Ok(())
99 }
100
101 fn add_module(&mut self, name: &str) -> Rc<Module> {
102 let name = name.to_string();
103 self.modules
104 .entry(name.clone())
105 .or_insert_with(|| Rc::new(Module { name }))
106 .clone()
107 }
108}
109
110struct SearchIter<'r> {
111 trie: &'r Trie<Sect, Vec<Rc<Module>>>,
112 input_idx: usize,
113 input: Vec<Sect>,
114 items: &'r [Rc<Module>],
115 save_stack: Vec<(Position, usize)>,
116 search: IncSearch<'r, Sect, Vec<Rc<Module>>>,
117 unique: BTreeSet<Rc<Module>>,
118}
119
120impl<'r> Iterator for SearchIter<'r> {
121 type Item = Rc<Module>;
122
123 fn next(&mut self) -> Option<Self::Item> {
124 while let Some((first, rest)) = self.items.split_first() {
125 self.items = rest;
126 if self.unique.insert(first.clone()) {
127 return Some(first.clone());
128 }
129 }
130
131 loop {
132 let part = loop {
134 if let Some(part) = self.input.get(self.input_idx) {
135 break part;
136 }
137
138 if let Some(save_state) = self.save_stack.pop() {
140 self.search = IncSearch::resume(self.trie, save_state.0);
142 self.input_idx = save_state.1;
143
144 if let Some(cur) = self.search.value() {
146 self.items = cur;
147 while let Some((first, rest)) = self.items.split_first() {
148 self.items = rest;
149 if self.unique.insert(first.clone()) {
150 return Some(first.clone());
151 }
152 }
153 }
154
155 continue;
157 };
158
159 return None; };
161
162 let answer = match part {
164 Sect::Protocol(_) => self.search.query(part),
165 Sect::Domain(_) => {
166 let answer = self.search.query(part);
167
168 let mut search = self.search.clone();
172 if search.query(&Sect::WildcardDomain).is_some() {
173 let mut n = 1;
174 while self
175 .input
176 .get(self.input_idx + n)
177 .is_some_and(|i| matches!(i, Sect::Domain(_)))
178 {
179 n += 1;
180 }
181
182 let pos = Position::from(search);
184 self.save_stack.push((pos, self.input_idx + n));
185 }
186
187 answer
188 }
189 Sect::Path(_) => {
190 {
191 let mut search = self.search.clone();
192 if search.query(&Sect::WildcardPath).is_some() {
193 let pos = Position::from(search);
196 self.save_stack.push((pos, self.input_idx + 1));
197 }
198 }
199 self.search.query(part)
200 }
201 Sect::QueryParamName(_) => self.search.query(part),
202 Sect::QueryParamValue(_) => {
203 {
204 let mut search = self.search.clone();
205 if search.query(&Sect::WildcardQueryParamValue).is_some() {
206 let pos = Position::from(search);
207 self.save_stack.push((pos, self.input_idx + 1));
208 };
209 };
210 self.search.query(part)
211 }
212 _ => unreachable!(),
213 };
214
215 self.input_idx += 1;
216
217 if !answer.is_some_and(|a| a.is_prefix()) {
218 self.input_idx = self.input.len();
221 }
222
223 if let Some(cur) = self.search.value() {
225 self.items = cur;
226 while let Some((first, rest)) = self.items.split_first() {
227 self.items = rest;
228 if self.unique.insert(first.clone()) {
229 return Some(first.clone());
230 }
231 }
232 }
233 }
234 }
235}
236
/// One section of a split URL, used as a trie label.
///
/// The derived ordering follows the declaration order of the variants, so
/// that order must stay as it is.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
enum Sect {
    /// URL scheme, or a whole colon-less input string.
    Protocol(String),
    /// A single host label.
    Domain(String),
    /// Pattern placeholder produced from a `*` host label.
    WildcardDomain,
    /// A path segment (or the whole path of a cannot-be-a-base URL).
    Path(String),
    /// Pattern placeholder produced from a path segment starting with `:`.
    WildcardPath,
    /// Name of a query parameter.
    QueryParamName(String),
    /// Non-empty value of a query parameter.
    QueryParamValue(String),
    /// Pattern placeholder produced from a query value starting with `:`.
    WildcardQueryParamValue,
}

impl Sect {
    /// Converts a section taken from a pattern into its wildcard form.
    ///
    /// A `Domain` equal to `*`, and a `Path` or `QueryParamValue` starting
    /// with `:`, become the matching wildcard variant; every other section is
    /// returned unchanged.
    pub fn into_pattern(self) -> Self {
        let wildcard = match &self {
            Sect::Domain(label) if label.as_str() == "*" => Some(Sect::WildcardDomain),
            Sect::Path(segment) if segment.starts_with(':') => Some(Sect::WildcardPath),
            Sect::QueryParamValue(value) if value.starts_with(':') => {
                Some(Sect::WildcardQueryParamValue)
            }
            _ => None,
        };
        wildcard.unwrap_or(self)
    }
}
263
264fn split_url(url: &str) -> Result<Vec<Sect>, Box<dyn Error>> {
266 let mut res = Vec::new();
267
268 if !url.contains(':') {
269 res.push(Sect::Protocol(url.into()));
270 return Ok(res);
271 }
272
273 let url: url::Url = url
274 .parse()
275 .map_err(|e| format!("Unable to handle URL {url:?}: {e}"))?;
276
277 let proto = url.scheme();
278 res.push(Sect::Protocol(proto.into()));
279
280 if let Some(host) = url.host_str() {
281 let mut host_parts: Vec<&str> = host.split('.').rev().collect();
282
283 if (proto == "http" || proto == "https")
284 && host_parts.last().is_some_and(|last| *last == "www")
285 {
286 let _www = host_parts.pop();
288 }
289
290 for part in host_parts {
291 res.push(Sect::Domain(part.into()));
292 }
293 }
294
295 if url.cannot_be_a_base() {
296 res.push(Sect::Path(url.path().into()))
297 } else {
298 if let Some(path_parts) = url.path_segments() {
299 for part in path_parts {
300 if part.is_empty() {
301 continue;
302 }
303 res.push(Sect::Path(part.into()));
304 }
305 }
306 }
307
308 for (k, v) in url.query_pairs() {
309 res.push(Sect::QueryParamName(k.into()));
310 if !v.is_empty() {
311 res.push(Sect::QueryParamValue(v.into()));
312 }
313 }
314
315 Ok(res)
316}
317
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::{eprintln, vec};

    /// Registers a mix of protocol, prefix and pattern modules and checks that
    /// each sample URL yields (at least) the expected module.
    #[test]
    fn matching() {
        let mut builder = ResolverBuilder::default();

        builder.insert_protocol("near", "near").unwrap();

        let prefixes = [
            ("google", "https://google.com/search?q="),
            ("x", "https://x.com/"),
        ];
        for (module, prefix) in prefixes {
            builder.insert_prefix(module, prefix).unwrap();
        }

        let patterns = [
            ("near-account", "near://account/:id"),
            ("near-tx", "near://tx/:id"),
            ("linkedin", "https://*.linkedin.com/in/:account/test"),
            ("youtube", "https://youtube.com/watch?v=:v"),
            ("subdomains", "https://*.baz.com/"),
            ("data", "data:text/plain"),
            ("fs", "file://"),
            ("fs2", "file:///2"),
        ];
        for (module, pattern) in patterns {
            builder.insert_pattern(module, pattern).unwrap();
        }

        let resolver = builder.build().expect("resolver should build");

        eprintln!("{resolver:?}");

        // (input URL, module that must be among the results)
        let cases = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
        ];

        for (input, want) in cases {
            let found = resolver
                .find(input)
                .expect("resolve succeeds")
                .any(|module| module.name == want);
            assert!(
                found,
                "the wanted result should be returned, input={input} want={want}"
            );
        }
    }
}