asimov_module/
resolve.rs

1// This is free and unencumbered software released into the public domain.
2
3use alloc::{
4    boxed::Box,
5    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6    format,
7    rc::Rc,
8    string::{String, ToString},
9    vec::Vec,
10};
11use core::error::Error;
12use trie_rs::{
13    inc_search::{IncSearch, Position},
14    map::{Trie, TrieBuilder},
15};
16
17#[derive(Clone, Debug)]
18pub struct Resolver {
19    trie: Trie<Sect, Vec<Rc<Module>>>,
20}
21
22impl Resolver {
23    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, Box<dyn Error>> {
24        Ok(self.find(url)?.collect())
25    }
26
27    pub fn find(&self, url: &str) -> Result<impl Iterator<Item = Rc<Module>>, Box<dyn Error>> {
28        Ok(SearchIter {
29            trie: &self.trie,
30            input_idx: 0,
31            input: split_url(url)?,
32            items: &[],
33            save_stack: Vec::new(),
34            search: self.trie.inc_search(),
35            unique: BTreeSet::new(),
36        })
37    }
38}
39
/// A named module that URLs can resolve to.
///
/// Equality and ordering are derived from `name`, the only field.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Module {
    /// Name the module was registered under.
    pub name: String,
}
44
45#[derive(Clone, Debug, Default)]
46pub struct ResolverBuilder {
47    modules: BTreeMap<String, Rc<Module>>,
48    protocol_modules: BTreeMap<String, Vec<Rc<Module>>>,
49    pattern_modules: BTreeMap<String, Vec<Rc<Module>>>,
50    prefix_modules: BTreeMap<String, Vec<Rc<Module>>>,
51}
52
53impl ResolverBuilder {
54    pub fn new() -> Self {
55        Self::default()
56    }
57
58    pub fn build(self) -> Result<Resolver, Box<dyn Error>> {
59        let mut trie = TrieBuilder::new();
60        for (k, v) in self.prefix_modules {
61            let k = split_url(&k)?;
62            trie.push(k, v);
63        }
64        for (k, v) in self.protocol_modules {
65            let k = Sect::Protocol(k);
66            trie.push([k], v);
67        }
68        for (k, v) in self.pattern_modules {
69            let k = split_url(&k)?.into_iter().map(Sect::into_pattern);
70            trie.insert(k, v);
71        }
72        let trie = trie.build();
73
74        Ok(Resolver { trie })
75    }
76
77    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Box<dyn Error>> {
78        let module = self.add_module(module);
79        let mods = self
80            .protocol_modules
81            .entry(protocol.to_string())
82            .or_default();
83        mods.push(module);
84        Ok(())
85    }
86    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), Box<dyn Error>> {
87        let _ = split_url(prefix)?;
88        let module = self.add_module(module);
89        let mods = self.prefix_modules.entry(prefix.to_string()).or_default();
90        mods.push(module.clone());
91        Ok(())
92    }
93    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), Box<dyn Error>> {
94        let _ = split_url(pattern)?;
95        let module = self.add_module(module);
96        let mods = self.pattern_modules.entry(pattern.to_string()).or_default();
97        mods.push(module.clone());
98        Ok(())
99    }
100
101    fn add_module(&mut self, name: &str) -> Rc<Module> {
102        let name = name.to_string();
103        self.modules
104            .entry(name.clone())
105            .or_insert_with(|| Rc::new(Module { name }))
106            .clone()
107    }
108}
109
110struct SearchIter<'r> {
111    trie: &'r Trie<Sect, Vec<Rc<Module>>>,
112    input_idx: usize,
113    input: Vec<Sect>,
114    items: &'r [Rc<Module>],
115    save_stack: Vec<(Position, usize)>,
116    search: IncSearch<'r, Sect, Vec<Rc<Module>>>,
117    unique: BTreeSet<Rc<Module>>,
118}
119
120impl<'r> Iterator for SearchIter<'r> {
121    type Item = Rc<Module>;
122
123    fn next(&mut self) -> Option<Self::Item> {
124        while let Some((first, rest)) = self.items.split_first() {
125            self.items = rest;
126            if self.unique.insert(first.clone()) {
127                return Some(first.clone());
128            }
129        }
130
131        loop {
132            // Try to get current part or backtrack
133            let part = loop {
134                if let Some(part) = self.input.get(self.input_idx) {
135                    break part;
136                }
137
138                // No more input, try to backtrack
139                if let Some(save_state) = self.save_stack.pop() {
140                    // Restore saved state
141                    self.search = IncSearch::resume(self.trie, save_state.0);
142                    self.input_idx = save_state.1;
143
144                    // Check if the resumed state has values to return
145                    if let Some(cur) = self.search.value() {
146                        self.items = cur;
147                        while let Some((first, rest)) = self.items.split_first() {
148                            self.items = rest;
149                            if self.unique.insert(first.clone()) {
150                                return Some(first.clone());
151                            }
152                        }
153                    }
154
155                    // otherwise continue consuming input from the resumed state
156                    continue;
157                };
158
159                return None; // No more save states, we're done
160            };
161
162            // Try different matching strategies based on the part type
163            let answer = match part {
164                Sect::Protocol(_) => self.search.query(part),
165                Sect::Domain(_) => {
166                    let answer = self.search.query(part);
167
168                    // *after* matching the current domain section try to match
169                    // a wildcard domain. If it matches, consume inputs as
170                    // long as there are domain sections.
171                    let mut search = self.search.clone();
172                    if search.query(&Sect::WildcardDomain).is_some() {
173                        let mut n = 1;
174                        while self
175                            .input
176                            .get(self.input_idx + n)
177                            .is_some_and(|i| matches!(i, Sect::Domain(_)))
178                        {
179                            n += 1;
180                        }
181
182                        // save a state with (matched wildcard, all subdomains consumed)
183                        let pos = Position::from(search);
184                        self.save_stack.push((pos, self.input_idx + n));
185                    }
186
187                    answer
188                }
189                Sect::Path(_) => {
190                    {
191                        let mut search = self.search.clone();
192                        if search.query(&Sect::WildcardPath).is_some() {
193                            // We matched a wildcard path element.
194                            // Save the position that represents a consumed input.
195                            let pos = Position::from(search);
196                            self.save_stack.push((pos, self.input_idx + 1));
197                        }
198                    }
199                    self.search.query(part)
200                }
201                Sect::QueryParamName(_) => self.search.query(part),
202                Sect::QueryParamValue(_) => {
203                    {
204                        let mut search = self.search.clone();
205                        if search.query(&Sect::WildcardQueryParamValue).is_some() {
206                            let pos = Position::from(search);
207                            self.save_stack.push((pos, self.input_idx + 1));
208                        };
209                    };
210                    self.search.query(part)
211                }
212                _ => unreachable!(),
213            };
214
215            self.input_idx += 1;
216
217            if !answer.is_some_and(|a| a.is_prefix()) {
218                // Current node is not a prefix, i.e. complete match found.
219                // Consume remaining input.
220                self.input_idx = self.input.len();
221            }
222
223            // Check if current node has values (could use `answer.is_match()`).
224            if let Some(cur) = self.search.value() {
225                self.items = cur;
226                while let Some((first, rest)) = self.items.split_first() {
227                    self.items = rest;
228                    if self.unique.insert(first.clone()) {
229                        return Some(first.clone());
230                    }
231                }
232            }
233        }
234    }
235}
236
/// One token of a split URL; sequences of these are the trie keys.
///
/// The variant order is significant: it determines the derived `Ord`.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// URL scheme, or the whole input when it contains no `:`.
    Protocol(String),
    /// One dot-separated label of the host.
    Domain(String),
    /// Pattern-only placeholder produced from a `*` host label.
    WildcardDomain,
    /// One path segment (or the entire path of a cannot-be-a-base URL).
    Path(String),
    /// Pattern-only placeholder produced from a path segment starting with `:`.
    WildcardPath,
    /// Name of a query parameter.
    QueryParamName(String),
    /// Non-empty value of a query parameter.
    QueryParamValue(String),
    /// Pattern-only placeholder produced from a query value starting with `:`.
    WildcardQueryParamValue,
}

impl Sect {
    /// Converts pattern placeholders into their wildcard variants.
    ///
    /// - a domain label that is exactly `"*"` becomes [`Sect::WildcardDomain`]
    /// - a path segment starting with `':'` (as in `/:foo/:bar`) becomes
    ///   [`Sect::WildcardPath`]
    /// - a query value starting with `':'` (as in `q=:query`) becomes
    ///   [`Sect::WildcardQueryParamValue`]
    ///
    /// Every other section is returned unchanged.
    pub fn into_pattern(self) -> Self {
        let wildcard = match &self {
            Sect::Domain(label) => (label.as_str() == "*").then_some(Sect::WildcardDomain),
            Sect::Path(segment) => segment.starts_with(':').then_some(Sect::WildcardPath),
            Sect::QueryParamValue(value) => value
                .starts_with(':')
                .then_some(Sect::WildcardQueryParamValue),
            Sect::Protocol(_)
            | Sect::QueryParamName(_)
            | Sect::WildcardDomain
            | Sect::WildcardPath
            | Sect::WildcardQueryParamValue => None,
        };

        wildcard.unwrap_or(self)
    }
}
263
264/// Split and URL into sections that we care about. This is effectively a tokenizer.
265fn split_url(url: &str) -> Result<Vec<Sect>, Box<dyn Error>> {
266    let mut res = Vec::new();
267
268    if !url.contains(':') {
269        res.push(Sect::Protocol(url.into()));
270        return Ok(res);
271    }
272
273    let url: url::Url = url
274        .parse()
275        .map_err(|e| format!("Unable to handle URL {url:?}: {e}"))?;
276
277    let proto = url.scheme();
278    res.push(Sect::Protocol(proto.into()));
279
280    if let Some(host) = url.host_str() {
281        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
282
283        if (proto == "http" || proto == "https")
284            && host_parts.last().is_some_and(|last| *last == "www")
285        {
286            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
287            let _www = host_parts.pop();
288        }
289
290        for part in host_parts {
291            res.push(Sect::Domain(part.into()));
292        }
293    }
294
295    if url.cannot_be_a_base() {
296        res.push(Sect::Path(url.path().into()))
297    } else {
298        if let Some(path_parts) = url.path_segments() {
299            for part in path_parts {
300                if part.is_empty() {
301                    continue;
302                }
303                res.push(Sect::Path(part.into()));
304            }
305        }
306    }
307
308    for (k, v) in url.query_pairs() {
309        res.push(Sect::QueryParamName(k.into()));
310        if !v.is_empty() {
311            res.push(Sect::QueryParamValue(v.into()));
312        }
313    }
314
315    Ok(res)
316}
317
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::{eprintln, vec};

    /// Registers a mix of protocols, prefixes and patterns, then checks that
    /// each sample URL resolves to (at least) the expected module.
    #[test]
    fn matching() {
        let mut builder = ResolverBuilder::default();

        builder.insert_protocol("near", "near").unwrap();

        let prefixes = [
            ("google", "https://google.com/search?q="),
            ("x", "https://x.com/"),
        ];
        for (module, prefix) in prefixes {
            builder.insert_prefix(module, prefix).unwrap();
        }

        let patterns = [
            ("near-account", "near://account/:id"),
            ("near-tx", "near://tx/:id"),
            ("linkedin", "https://*.linkedin.com/in/:account/test"),
            ("youtube", "https://youtube.com/watch?v=:v"),
            ("subdomains", "https://*.baz.com/"),
            ("data", "data:text/plain"),
            ("fs", "file://"),
            ("fs2", "file:///2"),
        ];
        for (module, pattern) in patterns {
            builder.insert_pattern(module, pattern).unwrap();
        }

        let resolver = builder.build().expect("resolver should build");

        eprintln!("{resolver:?}");

        // (input URL, module that must be among the results)
        let tests = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
        ];

        for (input, want) in tests {
            let mut candidates = resolver.find(input).expect("resolve succeeds");
            let hit = candidates
                .find(|out| out.name == want)
                .unwrap_or_else(|| {
                    panic!("the wanted result should be returned, input={input} want={want}")
                });
            assert_eq!(hit.name, want);
        }
    }
}