1use crate::models::ModuleManifest;
4use alloc::{
5 boxed::Box,
6 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
7 format,
8 rc::Rc,
9 string::{String, ToString},
10 vec::Vec,
11};
12use core::{borrow::Borrow, error::Error};
13use trie_rs::{
14 inc_search::{IncSearch, Position},
15 map::{Trie, TrieBuilder},
16};
17
/// Maps URLs to the modules registered to handle them.
///
/// Registrations are stored in a trie keyed by URL sections (see `split_url`);
/// each trie value is the list of modules registered under that exact key.
#[derive(Clone, Debug)]
pub struct Resolver {
    /// Section-keyed trie holding the registered modules.
    trie: Trie<Sect, Vec<Rc<Module>>>,
}
22
23impl Resolver {
24 pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, Box<dyn Error>> {
25 Ok(self.find(url)?.collect())
26 }
27
28 pub fn find(&self, url: &str) -> Result<impl Iterator<Item = Rc<Module>>, Box<dyn Error>> {
29 Ok(SearchIter {
30 trie: &self.trie,
31 input_idx: 0,
32 input: split_url(url)?,
33 items: &[],
34 save_stack: Vec::new(),
35 search: self.trie.inc_search(),
36 unique: BTreeSet::new(),
37 })
38 }
39
40 pub fn try_from_iter<I, T>(iter: I) -> Result<Self, Box<dyn Error>>
41 where
42 I: Iterator<Item = T>,
43 T: Borrow<ModuleManifest>,
44 {
45 ResolverBuilder::try_from_iter(iter)?.build()
46 }
47}
48
49impl TryFrom<&[ModuleManifest]> for Resolver {
50 type Error = Box<dyn Error>;
51
52 fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
53 ResolverBuilder::try_from(value)?.build()
54 }
55}
56
/// A module that the resolver can return for a URL.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Module {
    /// Name the module was registered under.
    pub name: String,
}
61
/// Collects module registrations and compiles them into a [`Resolver`].
#[derive(Clone, Debug, Default)]
pub struct ResolverBuilder {
    /// Every module seen so far, keyed by name, so that all registrations of
    /// the same module share one `Rc`.
    modules: BTreeMap<String, Rc<Module>>,
    /// Modules registered per protocol string (see `insert_protocol`).
    protocol_modules: BTreeMap<String, Vec<Rc<Module>>>,
    /// Modules registered per pattern string (see `insert_pattern`).
    pattern_modules: BTreeMap<String, Vec<Rc<Module>>>,
    /// Modules registered per prefix string (see `insert_prefix`).
    prefix_modules: BTreeMap<String, Vec<Rc<Module>>>,
}
69
70impl ResolverBuilder {
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn build(self) -> Result<Resolver, Box<dyn Error>> {
76 let mut trie = TrieBuilder::new();
77 for (k, v) in self.prefix_modules {
78 let k = split_url(&k)?;
79 trie.push(k, v);
80 }
81 for (k, v) in self.protocol_modules {
82 let k = Sect::Protocol(k);
83 trie.push([k], v);
84 }
85 for (k, v) in self.pattern_modules {
86 let k = split_url(&k)?.into_iter().map(Sect::into_pattern);
87 trie.insert(k, v);
88 }
89 let trie = trie.build();
90
91 Ok(Resolver { trie })
92 }
93
94 pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Box<dyn Error>> {
95 let module = self.add_module(module);
96 let mods = self
97 .protocol_modules
98 .entry(protocol.to_string())
99 .or_default();
100 mods.push(module);
101 Ok(())
102 }
103 pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), Box<dyn Error>> {
104 let _ = split_url(prefix)?;
105 let module = self.add_module(module);
106 let mods = self.prefix_modules.entry(prefix.to_string()).or_default();
107 mods.push(module.clone());
108 Ok(())
109 }
110 pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), Box<dyn Error>> {
111 let _ = split_url(pattern)?;
112 let module = self.add_module(module);
113 let mods = self.pattern_modules.entry(pattern.to_string()).or_default();
114 mods.push(module.clone());
115 Ok(())
116 }
117
118 pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), Box<dyn Error>> {
119 for protocol in &manifest.handles.url_protocols {
120 self.insert_protocol(&manifest.name, protocol)?;
121 }
122 for prefix in &manifest.handles.url_prefixes {
123 self.insert_prefix(&manifest.name, prefix)?;
124 }
125 for pattern in &manifest.handles.url_patterns {
126 self.insert_pattern(&manifest.name, pattern)?;
127 }
128 Ok(())
129 }
130
131 pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, Box<dyn Error>>
132 where
133 I: Iterator<Item = T>,
134 T: Borrow<ModuleManifest>,
135 {
136 iter.try_fold(ResolverBuilder::new(), |mut b, m| {
137 b.insert_manifest(m.borrow())?;
138 Ok(b)
139 })
140 }
141
142 fn add_module(&mut self, name: &str) -> Rc<Module> {
143 let name = name.to_string();
144 self.modules
145 .entry(name.clone())
146 .or_insert_with(|| Rc::new(Module { name }))
147 .clone()
148 }
149}
150
151impl TryFrom<&[ModuleManifest]> for ResolverBuilder {
152 type Error = Box<dyn Error>;
153
154 fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
155 Self::try_from_iter(value.iter())
156 }
157}
158
/// Lazy, backtracking walk over the resolver trie for one split URL.
struct SearchIter<'r> {
    /// Trie being searched; needed to resume saved positions.
    trie: &'r Trie<Sect, Vec<Rc<Module>>>,
    /// Index into `input` of the next section to feed to `search`.
    input_idx: usize,
    /// The URL being resolved, split into sections.
    input: Vec<Sect>,
    /// Modules of the most recently visited trie value that have not been
    /// examined yet.
    items: &'r [Rc<Module>],
    /// Alternative branches still to explore: a trie position together with
    /// the input index to continue from once it is resumed.
    save_stack: Vec<(Position, usize)>,
    /// Current incremental search state.
    search: IncSearch<'r, Sect, Vec<Rc<Module>>>,
    /// Modules already yielded, so that each one is returned only once.
    unique: BTreeSet<Rc<Module>>,
}
168
169impl<'r> Iterator for SearchIter<'r> {
170 type Item = Rc<Module>;
171
172 fn next(&mut self) -> Option<Self::Item> {
173 while let Some((first, rest)) = self.items.split_first() {
174 self.items = rest;
175 if self.unique.insert(first.clone()) {
176 return Some(first.clone());
177 }
178 }
179
180 loop {
181 let part = loop {
183 if let Some(part) = self.input.get(self.input_idx) {
184 break part;
185 }
186
187 if let Some(save_state) = self.save_stack.pop() {
189 self.search = IncSearch::resume(self.trie, save_state.0);
191 self.input_idx = save_state.1;
192
193 if let Some(cur) = self.search.value() {
195 self.items = cur;
196 while let Some((first, rest)) = self.items.split_first() {
197 self.items = rest;
198 if self.unique.insert(first.clone()) {
199 return Some(first.clone());
200 }
201 }
202 }
203
204 continue;
206 };
207
208 return None; };
210
211 let answer = match part {
213 Sect::Protocol(_) => self.search.query(part),
214 Sect::Domain(_) => {
215 let answer = self.search.query(part);
216
217 let mut search = self.search.clone();
221 if search.query(&Sect::WildcardDomain).is_some() {
222 let mut n = 1;
223 while self
224 .input
225 .get(self.input_idx + n)
226 .is_some_and(|i| matches!(i, Sect::Domain(_)))
227 {
228 n += 1;
229 }
230
231 let pos = Position::from(search);
233 self.save_stack.push((pos, self.input_idx + n));
234 }
235
236 answer
237 }
238 Sect::Path(_) => {
239 {
240 let mut search = self.search.clone();
241 if search.query(&Sect::WildcardPath).is_some() {
242 let pos = Position::from(search);
245 self.save_stack.push((pos, self.input_idx + 1));
246 }
247 }
248 self.search.query(part)
249 }
250 Sect::QueryParamName(_) => self.search.query(part),
251 Sect::QueryParamValue(_) => {
252 {
253 let mut search = self.search.clone();
254 if search.query(&Sect::WildcardQueryParamValue).is_some() {
255 let pos = Position::from(search);
256 self.save_stack.push((pos, self.input_idx + 1));
257 };
258 };
259 self.search.query(part)
260 }
261 _ => unreachable!(),
262 };
263
264 self.input_idx += 1;
265
266 if !answer.is_some_and(|a| a.is_prefix()) {
267 self.input_idx = self.input.len();
270 }
271
272 if let Some(cur) = self.search.value() {
274 self.items = cur;
275 while let Some((first, rest)) = self.items.split_first() {
276 self.items = rest;
277 if self.unique.insert(first.clone()) {
278 return Some(first.clone());
279 }
280 }
281 }
282 }
283 }
284}
285
/// One section of a split URL; the label type of the resolver trie.
///
/// The variant order is part of the derived `Ord` and must not be changed
/// casually.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// URL scheme, or the whole input when it contains no `:`.
    Protocol(String),
    /// A single dot-separated host label.
    Domain(String),
    /// Pattern-only placeholder for host labels (written `*`).
    WildcardDomain,
    /// A single path segment, or the whole path of a cannot-be-a-base URL.
    Path(String),
    /// Pattern-only placeholder for one path segment (written `:name`).
    WildcardPath,
    /// Name of a query parameter.
    QueryParamName(String),
    /// Non-empty value of a query parameter.
    QueryParamValue(String),
    /// Pattern-only placeholder for one query parameter value (written `:name`).
    WildcardQueryParamValue,
}

impl Sect {
    /// Converts a literal section taken from a pattern into its wildcard form.
    ///
    /// A domain label that is exactly `*`, and a path segment or query value
    /// starting with `:`, become the matching wildcard variant; the
    /// placeholder's name is discarded. Every other section is returned
    /// unchanged.
    pub fn into_pattern(self) -> Self {
        let wildcard = match &self {
            Sect::Domain(label) if label.as_str() == "*" => Some(Sect::WildcardDomain),
            Sect::Path(segment) if segment.starts_with(':') => Some(Sect::WildcardPath),
            Sect::QueryParamValue(value) if value.starts_with(':') => {
                Some(Sect::WildcardQueryParamValue)
            }
            _ => None,
        };
        wildcard.unwrap_or(self)
    }
}
312
313fn split_url(url: &str) -> Result<Vec<Sect>, Box<dyn Error>> {
315 let mut res = Vec::new();
316
317 if !url.contains(':') {
318 res.push(Sect::Protocol(url.into()));
319 return Ok(res);
320 }
321
322 let url: url::Url = url
323 .parse()
324 .map_err(|e| format!("Unable to handle URL {url:?}: {e}"))?;
325
326 let proto = url.scheme();
327 res.push(Sect::Protocol(proto.into()));
328
329 if let Some(host) = url.host_str() {
330 let mut host_parts: Vec<&str> = host.split('.').rev().collect();
331
332 if (proto == "http" || proto == "https")
333 && host_parts.last().is_some_and(|last| *last == "www")
334 {
335 let _www = host_parts.pop();
337 }
338
339 for part in host_parts {
340 res.push(Sect::Domain(part.into()));
341 }
342 }
343
344 if url.cannot_be_a_base() {
345 res.push(Sect::Path(url.path().into()))
346 } else {
347 if let Some(path_parts) = url.path_segments() {
348 for part in path_parts {
349 if part.is_empty() {
350 continue;
351 }
352 res.push(Sect::Path(part.into()));
353 }
354 }
355 }
356
357 for (k, v) in url.query_pairs() {
358 res.push(Sect::QueryParamName(k.into()));
359 if !v.is_empty() {
360 res.push(Sect::QueryParamValue(v.into()));
361 }
362 }
363
364 Ok(res)
365}
366
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::{eprintln, vec};

    /// Registers a mix of protocols, prefixes and patterns and checks that
    /// each sample URL resolves to (at least) the expected module.
    #[test]
    fn matching() {
        let mut builder = ResolverBuilder::default();

        builder.insert_protocol("near", "near").unwrap();

        let prefixes = [
            ("google", "https://google.com/search?q="),
            ("x", "https://x.com/"),
        ];
        for (module, prefix) in prefixes {
            builder.insert_prefix(module, prefix).unwrap();
        }

        let patterns = [
            ("near-account", "near://account/:id"),
            ("near-tx", "near://tx/:id"),
            ("linkedin", "https://*.linkedin.com/in/:account/test"),
            ("youtube", "https://youtube.com/watch?v=:v"),
            ("subdomains", "https://*.baz.com/"),
            ("data", "data:text/plain"),
            ("fs", "file://"),
            ("fs2", "file:///2"),
        ];
        for (module, pattern) in patterns {
            builder.insert_pattern(module, pattern).unwrap();
        }

        let resolver = builder.build().expect("resolver should build");

        eprintln!("{resolver:?}");

        // (input URL, name of a module that must be among the results)
        let tests = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
        ];

        for (input, want) in tests {
            let mut hits = resolver.find(input).expect("resolve succeeds");
            assert!(
                hits.any(|hit| hit.name == want),
                "the wanted result should be returned, input={input} want={want}"
            );
        }
    }
}