asimov_module/
resolve.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::models::ModuleManifest;
4use alloc::{
5    boxed::Box,
6    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
7    format,
8    rc::Rc,
9    string::{String, ToString},
10    vec::Vec,
11};
12use core::{borrow::Borrow, error::Error};
13
14#[derive(Clone, Debug, Default)]
15pub struct Resolver {
16    modules: BTreeMap<String, Rc<Module>>,
17    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
18    nodes: slab::Slab<Node>,
19    roots: BTreeMap<Sect, usize>,
20}
21
22impl Resolver {
23    pub fn new() -> Self {
24        Resolver::default()
25    }
26
27    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, Box<dyn Error>> {
28        let input = split_url(url)?;
29
30        let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
31
32        if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
33            if let Some(Sect::Path(filename)) = input.last() {
34                if let Some((_, ext)) = filename.split_once(".") {
35                    self.file_extensions
36                        .get(ext)
37                        .into_iter()
38                        .flatten()
39                        .for_each(|module| {
40                            results.insert(module.clone());
41                        });
42                }
43            }
44        }
45
46        let with_freemove = |node_idx: usize| {
47            // Return the node ID
48            core::iter::once(node_idx)
49                // And the destination ID after following a `FreeMove` path from the node
50                .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
51        };
52
53        // Initialize with start states that match the first input
54        let start_states: BTreeSet<usize> = self
55            .roots
56            .iter()
57            .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
58            .collect();
59
60        let final_states = if input.len() == 1 {
61            // There is no further input, just get free moves from the start_states
62            start_states.into_iter().flat_map(with_freemove).collect()
63        } else {
64            // Process remaining input
65            input[1..].iter().fold(start_states, |states, sect| {
66                states
67                    .into_iter()
68                    .flat_map(|node_idx| &self.nodes[node_idx].paths)
69                    .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
70                    .flat_map(with_freemove)
71                    .collect()
72            })
73        };
74
75        // Collect all modules from final states
76        for &state_idx in &final_states {
77            for module in &self.nodes[state_idx].modules {
78                results.insert(module.clone());
79            }
80        }
81
82        Ok(results.into_iter().collect())
83    }
84
85    pub fn insert_file_extension(
86        &mut self,
87        module: &str,
88        file_extension: &str,
89    ) -> Result<(), Box<dyn Error>> {
90        let module = self.add_module(module);
91
92        self.file_extensions
93            .entry(file_extension.to_string())
94            .or_default()
95            .push(module);
96
97        Ok(())
98    }
99    pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), Box<dyn Error>> {
100        for protocol in &manifest.handles.url_protocols {
101            self.insert_protocol(&manifest.name, protocol)?;
102        }
103        for prefix in &manifest.handles.url_prefixes {
104            self.insert_prefix(&manifest.name, prefix)?;
105        }
106        for pattern in &manifest.handles.url_patterns {
107            self.insert_pattern(&manifest.name, pattern)?;
108        }
109        for file_extension in &manifest.handles.file_extensions {
110            self.insert_file_extension(&manifest.name, file_extension)?;
111        }
112        Ok(())
113    }
114    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Box<dyn Error>> {
115        let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
116        let module = self.add_module(module);
117        let node_idx = self.get_or_create_node(path);
118
119        // Add a free move back to itself from the `FreeMove` node. (represents a protocol as a prefix):
120        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
121        self.nodes[node_idx].modules.insert(module);
122
123        Ok(())
124    }
125    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), Box<dyn Error>> {
126        let mut path = split_url(prefix)?;
127        // Add a `FreeMove` node at the end of the path to separate the prefix from
128        // patterns at the same node
129        path.push(Sect::FreeMove);
130        let module = self.add_module(module);
131        let node_idx = self.get_or_create_node(&path);
132
133        // Add a free move back to itself from the `FreeMove` node. Enables matching
134        // zero-or-more of anything:
135        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
136        self.nodes[node_idx].modules.insert(module);
137
138        Ok(())
139    }
140    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), Box<dyn Error>> {
141        let path: Vec<Sect> = split_url(pattern)?
142            .into_iter()
143            .map(Sect::into_pattern)
144            .collect();
145        let module = self.add_module(module);
146        let node_idx = self.get_or_create_node(&path);
147
148        self.nodes[node_idx].modules.insert(module);
149
150        Ok(())
151    }
152
153    pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, Box<dyn Error>>
154    where
155        I: Iterator<Item = T>,
156        T: Borrow<ModuleManifest>,
157    {
158        iter.try_fold(Resolver::default(), |mut b, m| {
159            b.insert_manifest(m.borrow())?;
160            Ok(b)
161        })
162    }
163
164    fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
165        // Get or create the root node
166        let root_idx = *self
167            .roots
168            .entry(path[0].clone())
169            .or_insert_with(|| self.nodes.insert(Node::default()));
170
171        path[1..].iter().fold(root_idx, |cur_idx, sect| {
172            match (self.nodes[cur_idx].paths.get(sect), sect) {
173                (Some(&idx), _sect) => idx,
174                (None, Sect::WildcardDomain) => {
175                    // If the sect is a wildcard domain add a link to self, this will also match multiple subdomains.
176                    self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
177                    cur_idx
178                }
179                (None, sect) => {
180                    // Create a new node
181                    let new_node_idx = self.nodes.insert(Node::default());
182
183                    // Add the transition from current node to new node
184                    self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
185                    new_node_idx
186                }
187            }
188        })
189    }
190
191    fn add_module(&mut self, name: &str) -> Rc<Module> {
192        let name = name.to_string();
193        self.modules
194            .entry(name.clone())
195            .or_insert_with(|| Rc::new(Module { name }))
196            .clone()
197    }
198}
199
200impl TryFrom<&[ModuleManifest]> for Resolver {
201    type Error = Box<dyn Error>;
202
203    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
204        Resolver::try_from_iter(value.iter())
205    }
206}
207
/// A module that the resolver can return as a handler for a URL.
///
/// Modules compare and sort by their `name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Module {
    /// The module's name, exactly as given when it was registered.
    pub name: String,
}
212
213#[derive(Clone, Debug, Default)]
214struct Node {
215    paths: BTreeMap<Sect, usize>,
216    modules: BTreeSet<Rc<Module>>,
217}
218
/// One token of a URL, or of a protocol/prefix/pattern registered with the resolver.
///
/// The derived `Ord` (variant order first) determines the iteration order of
/// the `Sect`-keyed maps used by the resolver.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
enum Sect {
    /// The protocol (a.k.a. scheme) of a URL, e.g. `https` in `https://example.org/`.
    Protocol(String),
    /// One literal host label, e.g. `org` or `example` in `https://example.org/`.
    Domain(String),
    /// The `*` in `https://*.example.org/`; stands for zero or more subdomains.
    WildcardDomain,
    /// One literal path segment, e.g. `file` or `path` in `https://example.org/file/path`.
    Path(String),
    /// The `:name` in `https://example.org/file/:name`; stands for any single path segment.
    WildcardPath,
    /// A query parameter name, e.g. `q` in `https://example.org/?q=example`.
    QueryParamName(String),
    /// A literal query parameter value, e.g. `example` in `https://example.org/?q=example`.
    QueryParamValue(String),
    /// The `:query` in `https://example.org/?q=:query`; stands for any query parameter value.
    WildcardQueryParamValue,
    /// Accepts one section of any kind. The resolver also follows `FreeMove`
    /// transitions without consuming input.
    FreeMove,
}

impl Sect {
    /// Turns a section written in pattern syntax into the matching wildcard:
    /// - a `*` domain label becomes [`Sect::WildcardDomain`];
    /// - a path segment starting with `:` (as in `/:foo/:bar`) becomes [`Sect::WildcardPath`];
    /// - a query value starting with `:` (as in `q=:query`) becomes
    ///   [`Sect::WildcardQueryParamValue`].
    ///
    /// Any other section is returned unchanged.
    pub fn into_pattern(self) -> Self {
        match self {
            Sect::Domain(ref label) if label == "*" => Sect::WildcardDomain,
            Sect::Path(ref segment) if segment.starts_with(':') => Sect::WildcardPath,
            Sect::QueryParamValue(ref value) if value.starts_with(':') => {
                Sect::WildcardQueryParamValue
            }
            unchanged => unchanged,
        }
    }

    /// Whether this (possibly wildcard) section accepts the `input` section.
    ///
    /// Equal sections always match. Each wildcard accepts any section of its
    /// literal counterpart's kind, and `FreeMove` accepts anything.
    fn matches_input(&self, input: &Self) -> bool {
        self == input
            || matches!(
                (self, input),
                (Sect::WildcardDomain, Sect::Domain(_))
                    | (Sect::WildcardPath, Sect::Path(_))
                    | (Sect::WildcardQueryParamValue, Sect::QueryParamValue(_))
                    | (Sect::FreeMove, _)
            )
    }
}
268
269/// Split and URL into sections that we care about. This is effectively a tokenizer.
270fn split_url(url: &str) -> Result<Vec<Sect>, Box<dyn Error>> {
271    if url.is_empty() {
272        return Err("URL cannot be empty".into());
273    }
274
275    let mut res = Vec::new();
276
277    if !url.contains(':') {
278        res.push(Sect::Protocol(url.into()));
279        return Ok(res);
280    }
281
282    let url: url::Url = url
283        .parse()
284        .map_err(|e| format!("Unable to handle URL {url:?}: {e}"))?;
285
286    let proto = url.scheme();
287    res.push(Sect::Protocol(proto.into()));
288
289    if let Some(host) = url.host_str() {
290        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
291
292        if (proto == "http" || proto == "https")
293            && host_parts.last().is_some_and(|last| *last == "www")
294        {
295            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
296            let _www = host_parts.pop();
297        }
298
299        for part in host_parts {
300            res.push(Sect::Domain(part.into()));
301        }
302    }
303
304    if url.cannot_be_a_base() {
305        res.push(Sect::Path(url.path().into()))
306    } else if let Some(path_parts) = url.path_segments() {
307        for part in path_parts {
308            if part.is_empty() {
309                continue;
310            }
311            res.push(Sect::Path(part.into()));
312        }
313    }
314
315    for (k, v) in url.query_pairs() {
316        res.push(Sect::QueryParamName(k.into()));
317        if !v.is_empty() {
318            res.push(Sect::QueryParamValue(v.into()));
319        }
320    }
321
322    Ok(res)
323}
324
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::eprintln;

    /// Registers handlers of every kind and checks that each sample input
    /// resolves to (at least) the expected module.
    #[test]
    fn matching() {
        let mut resolver = Resolver::default();

        resolver.insert_protocol("near", "near").unwrap();
        resolver
            .insert_pattern("near-account", "near://account/:id")
            .unwrap();
        resolver.insert_pattern("near-tx", "near://tx/:id").unwrap();
        resolver
            .insert_prefix("google", "https://google.com/search?q=")
            .unwrap();
        resolver.insert_prefix("x", "https://x.com/").unwrap();
        resolver
            .insert_pattern("linkedin", "https://*.linkedin.com/in/:account/test")
            .unwrap();
        resolver
            .insert_pattern("youtube", "https://youtube.com/watch?v=:v")
            .unwrap();
        resolver
            .insert_pattern("subdomains", "https://*.baz.com/")
            .unwrap();
        resolver.insert_prefix("data", "data:text/plain").unwrap();
        resolver.insert_prefix("fs", "file://").unwrap();
        resolver.insert_prefix("fs2", "file:///2").unwrap();
        resolver.insert_file_extension("txt-ext", "txt").unwrap();
        resolver.insert_file_extension("tar-ext", "tar.gz").unwrap();

        eprintln!("{resolver:#?}");

        // (input URL, module that must be among the resolved modules)
        let tests = [
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in tests {
            let results = resolver.resolve(input).expect("resolve succeeds");
            assert!(
                results.iter().any(|module| module.name == want),
                "the wanted result should be returned, input={input} want={want} got={results:?}"
            );
        }
    }

    /// A pattern must only match its exact URL, even after a prefix for the
    /// same URL is registered alongside it.
    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        let mut resolver = Resolver::new();

        resolver
            .insert_pattern("pattern", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let results = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{results:?}");
        assert!(
            results
                .first()
                .is_some_and(|module| module.name == "pattern"),
            "the pattern should match"
        );

        let results = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{results:?}");
        assert!(results.is_empty(), "the pattern shouldn't be a prefix");

        resolver
            .insert_prefix("prefix", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let results = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{results:?}");
        assert_eq!(results.len(), 2, "both items should match");

        let results = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{results:?}");
        assert_eq!(results.len(), 1, "only the prefix should match");
        assert_eq!(results[0].name, "prefix", "only the prefix should match");
    }
}