asimov_module/
resolve.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::ModuleManifest;
4use alloc::{
5    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6    rc::Rc,
7    string::{String, ToString},
8    vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
15#[derive(Clone, Debug, Default)]
16pub struct Resolver {
17    modules: BTreeMap<String, Rc<Module>>,
18    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
19    nodes: slab::Slab<Node>,
20    roots: BTreeMap<Sect, usize>,
21}
22
23impl Resolver {
24    pub fn new() -> Self {
25        Resolver::default()
26    }
27
28    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
29        let input = split_url(url)?;
30
31        let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
32
33        if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
34            if let Some(Sect::Path(filename)) = input.last() {
35                if let Some((_, ext)) = filename.split_once(".") {
36                    self.file_extensions
37                        .get(ext)
38                        .into_iter()
39                        .flatten()
40                        .for_each(|module| {
41                            results.insert(module.clone());
42                        });
43                }
44            }
45        }
46
47        let with_freemove = |node_idx: usize| {
48            // Return the node ID
49            core::iter::once(node_idx)
50                // And the destination ID after following a `FreeMove` path from the node
51                .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
52        };
53
54        // Initialize with start states that match the first input
55        let start_states: BTreeSet<usize> = self
56            .roots
57            .iter()
58            .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
59            .collect();
60
61        let final_states = if input.len() == 1 {
62            // There is no further input, just get free moves from the start_states
63            start_states.into_iter().flat_map(with_freemove).collect()
64        } else {
65            // Process remaining input
66            input[1..].iter().fold(start_states, |states, sect| {
67                states
68                    .into_iter()
69                    .flat_map(|node_idx| &self.nodes[node_idx].paths)
70                    .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
71                    .flat_map(with_freemove)
72                    .collect()
73            })
74        };
75
76        // Collect all modules from final states
77        for &state_idx in &final_states {
78            for module in &self.nodes[state_idx].modules {
79                results.insert(module.clone());
80            }
81        }
82
83        Ok(results.into_iter().collect())
84    }
85
86    pub fn insert_file_extension(
87        &mut self,
88        module: &str,
89        file_extension: &str,
90    ) -> Result<(), Infallible> {
91        let module = self.add_module(module);
92
93        let ext = file_extension
94            .strip_prefix(".")
95            .unwrap_or(file_extension)
96            .to_string();
97
98        self.file_extensions.entry(ext).or_default().push(module);
99
100        Ok(())
101    }
102    pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), UrlParseError> {
103        for protocol in &manifest.handles.url_protocols {
104            self.insert_protocol(&manifest.name, protocol).ok();
105        }
106        for prefix in &manifest.handles.url_prefixes {
107            self.insert_prefix(&manifest.name, prefix)?;
108        }
109        for pattern in &manifest.handles.url_patterns {
110            self.insert_pattern(&manifest.name, pattern)?;
111        }
112        for file_extension in &manifest.handles.file_extensions {
113            self.insert_file_extension(&manifest.name, file_extension)
114                .ok();
115        }
116        Ok(())
117    }
118    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
119        let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
120        let module = self.add_module(module);
121        let node_idx = self.get_or_create_node(path);
122
123        // Add a free move back to itself from the `FreeMove` node. (represents a protocol as a prefix):
124        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
125        self.nodes[node_idx].modules.insert(module);
126
127        Ok(())
128    }
129    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
130        let mut path = split_url(prefix)?;
131        // Add a `FreeMove` node at the end of the path to separate the prefix from
132        // patterns at the same node
133        path.push(Sect::FreeMove);
134        let module = self.add_module(module);
135        let node_idx = self.get_or_create_node(&path);
136
137        // Add a free move back to itself from the `FreeMove` node. Enables matching
138        // zero-or-more of anything:
139        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
140        self.nodes[node_idx].modules.insert(module);
141
142        Ok(())
143    }
144    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
145        let path: Vec<Sect> = split_url(pattern)?
146            .into_iter()
147            .map(Sect::into_pattern)
148            .collect();
149        let module = self.add_module(module);
150        let node_idx = self.get_or_create_node(&path);
151
152        self.nodes[node_idx].modules.insert(module);
153
154        Ok(())
155    }
156
157    #[cfg(all(feature = "std", feature = "serde"))]
158    pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
159        use error::FromDirError;
160
161        let path = path.as_ref();
162
163        let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
164            path: path.into(),
165            source,
166        })?;
167
168        let mut resolver = Resolver::new();
169
170        for entry in dir {
171            let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
172                path: path.into(),
173                source,
174            })?;
175            if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
176                continue;
177            }
178            let filename = entry.file_name();
179            let filename = filename.to_string_lossy();
180            if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
181                continue;
182            }
183            let path = entry.path();
184            let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
185                path: path.clone(),
186                source,
187            })?;
188
189            let manifest =
190                serde_yaml_ng::from_reader(file).map_err(|source| FromDirError::Parse {
191                    path: path.clone(),
192                    source,
193                })?;
194            resolver
195                .insert_manifest(&manifest)
196                .map_err(|source| FromDirError::Insert {
197                    path: path.clone(),
198                    source,
199                })?;
200        }
201
202        Ok(resolver)
203    }
204
205    pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, UrlParseError>
206    where
207        I: Iterator<Item = T>,
208        T: Borrow<ModuleManifest>,
209    {
210        iter.try_fold(Resolver::default(), |mut r, m| {
211            r.insert_manifest(m.borrow())?;
212            Ok(r)
213        })
214    }
215
216    fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
217        // Get or create the root node
218        let root_idx = *self
219            .roots
220            .entry(path[0].clone())
221            .or_insert_with(|| self.nodes.insert(Node::default()));
222
223        path[1..].iter().fold(root_idx, |cur_idx, sect| {
224            match (self.nodes[cur_idx].paths.get(sect), sect) {
225                (Some(&idx), _sect) => idx,
226                (None, Sect::WildcardDomain) => {
227                    // If the sect is a wildcard domain add a link to self, this will also match multiple subdomains.
228                    self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
229                    cur_idx
230                },
231                (None, sect) => {
232                    // Create a new node
233                    let new_node_idx = self.nodes.insert(Node::default());
234
235                    // Add the transition from current node to new node
236                    self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
237                    new_node_idx
238                },
239            }
240        })
241    }
242
243    fn add_module(&mut self, name: &str) -> Rc<Module> {
244        let name = name.to_string();
245        self.modules
246            .entry(name.clone())
247            .or_insert_with(|| Rc::new(Module { name }))
248            .clone()
249    }
250}
251
252impl TryFrom<&[ModuleManifest]> for Resolver {
253    type Error = UrlParseError;
254
255    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
256        Resolver::try_from_iter(value.iter())
257    }
258}
259
/// A module known to the [`Resolver`], identified by its name.
///
/// Equality and ordering are derived from `name`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Module {
    /// The module's name, as given when it was registered.
    pub name: String,
}
264
265#[derive(Clone, Debug, Default)]
266struct Node {
267    paths: BTreeMap<Sect, usize>,
268    modules: BTreeSet<Rc<Module>>,
269}
270
/// One section (token) of a URL, or of a URL pattern, as produced by `split_url`.
///
/// The variant order is significant: the derived `Ord` is what sorts the keys
/// of the `BTreeMap`s these values are stored in.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// The protocol (a.k.a. scheme) of a URL, e.g. `https` in `https://example.org/`.
    Protocol(String),
    /// A single literal domain label, e.g. `org` or `example` in `https://example.org/`.
    Domain(String),
    /// The `*` in `https://*.example.org/`; matches zero-or-more subdomains.
    WildcardDomain,
    /// A literal path segment, e.g. `file` or `path` in `https://example.org/file/path`.
    Path(String),
    /// The `:name` in `https://example.org/file/:name`; matches any single path segment.
    WildcardPath,
    /// A query parameter name, e.g. `q` in `https://example.org/?q=example`.
    QueryParamName(String),
    /// A literal query parameter value, e.g. `example` in `https://example.org/?q=example`.
    QueryParamValue(String),
    /// The `:query` in `https://example.org/?q=:query`; matches any query parameter value.
    WildcardQueryParamValue,
    /// Matches a single section of any kind.
    FreeMove,
}
292
293impl Sect {
294    /// Transform a sect that matches a pattern format to a wildcard.
295    /// - If a domain section is "*", make it a wildcard domain pattern
296    /// - If a path section begins with ":" ("/:foo/:bar"), make it a wildcard path pattern
297    /// - If the value of a query parameter begins with ":" ("q=:query"), make it a wildcard query param pattern
298    pub fn into_pattern(self) -> Self {
299        match self {
300            Sect::Domain(p) if p == "*" => Sect::WildcardDomain,
301            Sect::Path(p) if p.starts_with(':') => Sect::WildcardPath,
302            Sect::QueryParamValue(p) if p.starts_with(':') => Sect::WildcardQueryParamValue,
303            _ => self,
304        }
305    }
306
307    fn matches_input(&self, input: &Self) -> bool {
308        use Sect::*;
309        match (self, input) {
310            (a, b) if a == b => true,
311            (WildcardDomain, Domain(_)) => true,
312            (WildcardPath, Path(_)) => true,
313            (WildcardQueryParamValue, QueryParamValue(_)) => true,
314            // As a special case if the path section is a `FreeMove` then always accept it.
315            (FreeMove, _) => true,
316            _ => false,
317        }
318    }
319}
320
321/// Split and URL into sections that we care about. This is effectively a tokenizer.
322fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
323    if url.is_empty() {
324        return Err(UrlParseError::EmptyUrl);
325    }
326
327    let mut res = Vec::new();
328
329    if !url.contains(':') {
330        res.push(Sect::Protocol(url.into()));
331        return Ok(res);
332    }
333
334    let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
335        url: url.to_string(),
336        source: e,
337    })?;
338
339    let proto = url.scheme();
340    res.push(Sect::Protocol(proto.into()));
341
342    if let Some(host) = url.host_str() {
343        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
344
345        if (proto == "http" || proto == "https")
346            && host_parts.last().is_some_and(|last| *last == "www")
347        {
348            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
349            let _www = host_parts.pop();
350        }
351
352        for part in host_parts {
353            res.push(Sect::Domain(part.into()));
354        }
355    }
356
357    if url.cannot_be_a_base() {
358        res.push(Sect::Path(url.path().into()))
359    } else if let Some(path_parts) = url.path_segments() {
360        for part in path_parts {
361            if part.is_empty() {
362                continue;
363            }
364            res.push(Sect::Path(part.into()));
365        }
366    }
367
368    for (k, v) in url.query_pairs() {
369        res.push(Sect::QueryParamName(k.into()));
370        if !v.is_empty() {
371            res.push(Sect::QueryParamValue(v.into()));
372        }
373    }
374
375    Ok(res)
376}
377
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::{eprintln, vec};

    /// Every kind of registration resolves to the module that was registered for it.
    #[test]
    fn matching() {
        let mut resolver = Resolver::default();

        resolver.insert_protocol("near", "near").unwrap();
        resolver
            .insert_pattern("near-account", "near://account/:id")
            .unwrap();
        resolver.insert_pattern("near-tx", "near://tx/:id").unwrap();
        resolver
            .insert_prefix("google", "https://google.com/search?q=")
            .unwrap();
        resolver.insert_prefix("x", "https://x.com/").unwrap();
        resolver
            .insert_pattern("linkedin", "https://*.linkedin.com/in/:account/test")
            .unwrap();
        resolver
            .insert_pattern("youtube", "https://youtube.com/watch?v=:v")
            .unwrap();
        resolver
            .insert_pattern("subdomains", "https://*.baz.com/")
            .unwrap();
        resolver.insert_prefix("data", "data:text/plain").unwrap();
        resolver.insert_prefix("fs", "file://").unwrap();
        resolver.insert_prefix("fs2", "file:///2").unwrap();
        resolver.insert_file_extension("txt-ext", "txt").unwrap();
        // A leading dot should work just as well.
        resolver.insert_file_extension("zip-ext", ".zip").unwrap();
        resolver.insert_file_extension("tar-ext", "tar.gz").unwrap();

        eprintln!("{resolver:#?}");

        // (input URL, name of a module that must be among the results)
        let cases = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:/archive.zip", "zip-ext"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in cases {
            let resolved = resolver.resolve(input).expect("resolve succeeds");
            assert!(
                resolved.iter().any(|module| module.name == want),
                "the wanted result should be returned, input={input} want={want}"
            );
        }
    }

    /// Adding a prefix for a URL must not make a pattern for that URL behave like a prefix.
    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        let mut resolver = Resolver::new();

        resolver
            .insert_pattern("pattern", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let exact = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{exact:?}");
        assert!(
            matches!(exact.first(), Some(module) if module.name == "pattern"),
            "the pattern should match"
        );

        let longer = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{longer:?}");
        assert!(longer.is_empty(), "the pattern shouldn't be a prefix");

        resolver
            .insert_prefix("prefix", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let exact = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{exact:?}");
        assert_eq!(exact.len(), 2, "both items should match");

        let longer = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{longer:?}");
        assert_eq!(longer.len(), 1, "only the prefix should match");
        assert!(
            matches!(longer.first(), Some(module) if module.name == "prefix"),
            "only the prefix should match"
        );
    }
}