// asimov_module/resolve.rs

// This is free and unencumbered software released into the public domain.

3use crate::models::ModuleManifest;
4use alloc::{
5    boxed::Box,
6    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
7    format,
8    rc::Rc,
9    string::{String, ToString},
10    vec::Vec,
11};
12use core::{borrow::Borrow, error::Error};
13use trie_rs::{
14    inc_search::{IncSearch, Position},
15    map::{Trie, TrieBuilder},
16};
17
18#[derive(Clone, Debug)]
19pub struct Resolver {
20    trie: Trie<Sect, Vec<Rc<Module>>>,
21}
22
23impl Resolver {
24    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, Box<dyn Error>> {
25        Ok(self.find(url)?.collect())
26    }
27
28    pub fn find(&self, url: &str) -> Result<impl Iterator<Item = Rc<Module>>, Box<dyn Error>> {
29        Ok(SearchIter {
30            trie: &self.trie,
31            input_idx: 0,
32            input: split_url(url)?,
33            items: &[],
34            save_stack: Vec::new(),
35            search: self.trie.inc_search(),
36            unique: BTreeSet::new(),
37        })
38    }
39
40    pub fn try_from_iter<I, T>(iter: I) -> Result<Self, Box<dyn Error>>
41    where
42        I: Iterator<Item = T>,
43        T: Borrow<ModuleManifest>,
44    {
45        ResolverBuilder::try_from_iter(iter)?.build()
46    }
47}
48
49impl TryFrom<&[ModuleManifest]> for Resolver {
50    type Error = Box<dyn Error>;
51
52    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
53        ResolverBuilder::try_from(value)?.build()
54    }
55}
56
57#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
58pub struct Module {
59    pub name: String,
60}
61
62#[derive(Clone, Debug, Default)]
63pub struct ResolverBuilder {
64    modules: BTreeMap<String, Rc<Module>>,
65    protocol_modules: BTreeMap<String, Vec<Rc<Module>>>,
66    pattern_modules: BTreeMap<String, Vec<Rc<Module>>>,
67    prefix_modules: BTreeMap<String, Vec<Rc<Module>>>,
68}
69
70impl ResolverBuilder {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    pub fn build(self) -> Result<Resolver, Box<dyn Error>> {
76        let mut trie = TrieBuilder::new();
77        for (k, v) in self.prefix_modules {
78            let k = split_url(&k)?;
79            trie.push(k, v);
80        }
81        for (k, v) in self.protocol_modules {
82            let k = Sect::Protocol(k);
83            trie.push([k], v);
84        }
85        for (k, v) in self.pattern_modules {
86            let k = split_url(&k)?.into_iter().map(Sect::into_pattern);
87            trie.insert(k, v);
88        }
89        let trie = trie.build();
90
91        Ok(Resolver { trie })
92    }
93
94    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Box<dyn Error>> {
95        let module = self.add_module(module);
96        let mods = self
97            .protocol_modules
98            .entry(protocol.to_string())
99            .or_default();
100        mods.push(module);
101        Ok(())
102    }
103    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), Box<dyn Error>> {
104        let _ = split_url(prefix)?;
105        let module = self.add_module(module);
106        let mods = self.prefix_modules.entry(prefix.to_string()).or_default();
107        mods.push(module.clone());
108        Ok(())
109    }
110    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), Box<dyn Error>> {
111        let _ = split_url(pattern)?;
112        let module = self.add_module(module);
113        let mods = self.pattern_modules.entry(pattern.to_string()).or_default();
114        mods.push(module.clone());
115        Ok(())
116    }
117
118    pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), Box<dyn Error>> {
119        for protocol in &manifest.handles.url_protocols {
120            self.insert_protocol(&manifest.name, protocol)?;
121        }
122        for prefix in &manifest.handles.url_prefixes {
123            self.insert_prefix(&manifest.name, prefix)?;
124        }
125        for pattern in &manifest.handles.url_patterns {
126            self.insert_pattern(&manifest.name, pattern)?;
127        }
128        Ok(())
129    }
130
131    pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, Box<dyn Error>>
132    where
133        I: Iterator<Item = T>,
134        T: Borrow<ModuleManifest>,
135    {
136        iter.try_fold(ResolverBuilder::new(), |mut b, m| {
137            b.insert_manifest(m.borrow())?;
138            Ok(b)
139        })
140    }
141
142    fn add_module(&mut self, name: &str) -> Rc<Module> {
143        let name = name.to_string();
144        self.modules
145            .entry(name.clone())
146            .or_insert_with(|| Rc::new(Module { name }))
147            .clone()
148    }
149}
150
151impl TryFrom<&[ModuleManifest]> for ResolverBuilder {
152    type Error = Box<dyn Error>;
153
154    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
155        Self::try_from_iter(value.iter())
156    }
157}
158
159struct SearchIter<'r> {
160    trie: &'r Trie<Sect, Vec<Rc<Module>>>,
161    input_idx: usize,
162    input: Vec<Sect>,
163    items: &'r [Rc<Module>],
164    save_stack: Vec<(Position, usize)>,
165    search: IncSearch<'r, Sect, Vec<Rc<Module>>>,
166    unique: BTreeSet<Rc<Module>>,
167}
168
169impl<'r> Iterator for SearchIter<'r> {
170    type Item = Rc<Module>;
171
172    fn next(&mut self) -> Option<Self::Item> {
173        while let Some((first, rest)) = self.items.split_first() {
174            self.items = rest;
175            if self.unique.insert(first.clone()) {
176                return Some(first.clone());
177            }
178        }
179
180        loop {
181            // Try to get current part or backtrack
182            let part = loop {
183                if let Some(part) = self.input.get(self.input_idx) {
184                    break part;
185                }
186
187                // No more input, try to backtrack
188                if let Some(save_state) = self.save_stack.pop() {
189                    // Restore saved state
190                    self.search = IncSearch::resume(self.trie, save_state.0);
191                    self.input_idx = save_state.1;
192
193                    // Check if the resumed state has values to return
194                    if let Some(cur) = self.search.value() {
195                        self.items = cur;
196                        while let Some((first, rest)) = self.items.split_first() {
197                            self.items = rest;
198                            if self.unique.insert(first.clone()) {
199                                return Some(first.clone());
200                            }
201                        }
202                    }
203
204                    // otherwise continue consuming input from the resumed state
205                    continue;
206                };
207
208                return None; // No more save states, we're done
209            };
210
211            // Try different matching strategies based on the part type
212            let answer = match part {
213                Sect::Protocol(_) => self.search.query(part),
214                Sect::Domain(_) => {
215                    let answer = self.search.query(part);
216
217                    // *after* matching the current domain section try to match
218                    // a wildcard domain. If it matches, consume inputs as
219                    // long as there are domain sections.
220                    let mut search = self.search.clone();
221                    if search.query(&Sect::WildcardDomain).is_some() {
222                        let mut n = 1;
223                        while self
224                            .input
225                            .get(self.input_idx + n)
226                            .is_some_and(|i| matches!(i, Sect::Domain(_)))
227                        {
228                            n += 1;
229                        }
230
231                        // save a state with (matched wildcard, all subdomains consumed)
232                        let pos = Position::from(search);
233                        self.save_stack.push((pos, self.input_idx + n));
234                    }
235
236                    answer
237                }
238                Sect::Path(_) => {
239                    {
240                        let mut search = self.search.clone();
241                        if search.query(&Sect::WildcardPath).is_some() {
242                            // We matched a wildcard path element.
243                            // Save the position that represents a consumed input.
244                            let pos = Position::from(search);
245                            self.save_stack.push((pos, self.input_idx + 1));
246                        }
247                    }
248                    self.search.query(part)
249                }
250                Sect::QueryParamName(_) => self.search.query(part),
251                Sect::QueryParamValue(_) => {
252                    {
253                        let mut search = self.search.clone();
254                        if search.query(&Sect::WildcardQueryParamValue).is_some() {
255                            let pos = Position::from(search);
256                            self.save_stack.push((pos, self.input_idx + 1));
257                        };
258                    };
259                    self.search.query(part)
260                }
261                _ => unreachable!(),
262            };
263
264            self.input_idx += 1;
265
266            if !answer.is_some_and(|a| a.is_prefix()) {
267                // Current node is not a prefix, i.e. complete match found.
268                // Consume remaining input.
269                self.input_idx = self.input.len();
270            }
271
272            // Check if current node has values (could use `answer.is_match()`).
273            if let Some(cur) = self.search.value() {
274                self.items = cur;
275                while let Some((first, rest)) = self.items.split_first() {
276                    self.items = rest;
277                    if self.unique.insert(first.clone()) {
278                        return Some(first.clone());
279                    }
280                }
281            }
282        }
283    }
284}
285
/// One token of a URL as produced by `split_url`, or a wildcard token
/// produced by [`Sect::into_pattern`]. Tokens are the labels of the resolver
/// trie.
///
/// The derived `Ord` follows the declaration order of the variants, so the
/// order below is kept as is.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// URL scheme, or a bare protocol name.
    Protocol(String),
    /// One host label; a host's labels are emitted starting from the TLD.
    Domain(String),
    /// Pattern `*` host label.
    WildcardDomain,
    /// One non-empty path segment, or the whole path of a cannot-be-a-base URL.
    Path(String),
    /// Pattern `:name` path segment.
    WildcardPath,
    /// Query parameter name.
    QueryParamName(String),
    /// Non-empty query parameter value.
    QueryParamValue(String),
    /// Pattern `:name` query parameter value.
    WildcardQueryParamValue,
}

impl Sect {
    /// Converts a token taken from a pattern string into its wildcard form:
    /// - a domain label equal to `*` becomes [`Sect::WildcardDomain`];
    /// - a path segment starting with `:` (`/:foo/:bar`) becomes [`Sect::WildcardPath`];
    /// - a query value starting with `:` (`q=:query`) becomes
    ///   [`Sect::WildcardQueryParamValue`].
    ///
    /// Any other token is returned unchanged.
    pub fn into_pattern(self) -> Self {
        match self {
            Sect::Domain(ref label) if label == "*" => Sect::WildcardDomain,
            Sect::Path(ref segment) if segment.starts_with(':') => Sect::WildcardPath,
            Sect::QueryParamValue(ref value) if value.starts_with(':') => {
                Sect::WildcardQueryParamValue
            }
            unchanged => unchanged,
        }
    }
}

313/// Split and URL into sections that we care about. This is effectively a tokenizer.
314fn split_url(url: &str) -> Result<Vec<Sect>, Box<dyn Error>> {
315    let mut res = Vec::new();
316
317    if !url.contains(':') {
318        res.push(Sect::Protocol(url.into()));
319        return Ok(res);
320    }
321
322    let url: url::Url = url
323        .parse()
324        .map_err(|e| format!("Unable to handle URL {url:?}: {e}"))?;
325
326    let proto = url.scheme();
327    res.push(Sect::Protocol(proto.into()));
328
329    if let Some(host) = url.host_str() {
330        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
331
332        if (proto == "http" || proto == "https")
333            && host_parts.last().is_some_and(|last| *last == "www")
334        {
335            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
336            let _www = host_parts.pop();
337        }
338
339        for part in host_parts {
340            res.push(Sect::Domain(part.into()));
341        }
342    }
343
344    if url.cannot_be_a_base() {
345        res.push(Sect::Path(url.path().into()))
346    } else {
347        if let Some(path_parts) = url.path_segments() {
348            for part in path_parts {
349                if part.is_empty() {
350                    continue;
351                }
352                res.push(Sect::Path(part.into()));
353            }
354        }
355    }
356
357    for (k, v) in url.query_pairs() {
358        res.push(Sect::QueryParamName(k.into()));
359        if !v.is_empty() {
360            res.push(Sect::QueryParamValue(v.into()));
361        }
362    }
363
364    Ok(res)
365}
366
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::eprintln;

    /// Registers a mix of protocol, prefix and pattern handlers and checks that
    /// each sample input resolves to (at least) its expected module.
    #[test]
    fn matching() {
        /// Which `ResolverBuilder::insert_*` method a registration uses.
        enum Kind {
            Protocol,
            Prefix,
            Pattern,
        }

        // (kind, module, protocol / prefix / pattern), in registration order.
        let registrations = [
            (Kind::Protocol, "near", "near"),
            (Kind::Pattern, "near-account", "near://account/:id"),
            (Kind::Pattern, "near-tx", "near://tx/:id"),
            (Kind::Prefix, "google", "https://google.com/search?q="),
            (Kind::Prefix, "x", "https://x.com/"),
            (Kind::Pattern, "linkedin", "https://*.linkedin.com/in/:account/test"),
            (Kind::Pattern, "youtube", "https://youtube.com/watch?v=:v"),
            (Kind::Pattern, "subdomains", "https://*.baz.com/"),
            (Kind::Pattern, "data", "data:text/plain"),
            (Kind::Pattern, "fs", "file://"),
            (Kind::Pattern, "fs2", "file:///2"),
        ];

        let mut builder = ResolverBuilder::default();
        for (kind, module, key) in registrations {
            let inserted = match kind {
                Kind::Protocol => builder.insert_protocol(module, key),
                Kind::Prefix => builder.insert_prefix(module, key),
                Kind::Pattern => builder.insert_pattern(module, key),
            };
            inserted.unwrap();
        }

        let resolver = builder.build().expect("resolver should build");
        eprintln!("{resolver:?}");

        // (input, module expected among the results)
        let cases = [
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
        ];

        for (input, want) in cases {
            let found = resolver
                .find(input)
                .expect("resolve succeeds")
                .find(|module| module.name == want)
                .unwrap_or_else(|| {
                    panic!("the wanted result should be returned, input={input} want={want}")
                });
            assert_eq!(found.name, want);
        }
    }
}