asimov_module/
resolve.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::ModuleManifest;
4use alloc::{
5    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6    rc::Rc,
7    string::{String, ToString},
8    vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
15#[derive(Clone, Debug, Default)]
16pub struct Resolver {
17    modules: BTreeMap<String, Rc<Module>>,
18    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
19    nodes: slab::Slab<Node>,
20    roots: BTreeMap<Sect, usize>,
21}
22
23impl Resolver {
24    pub fn new() -> Self {
25        Resolver::default()
26    }
27
28    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
29        let input = split_url(url)?;
30
31        let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
32
33        if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
34            if let Some(Sect::Path(filename)) = input.last() {
35                if let Some((_, ext)) = filename.split_once(".") {
36                    self.file_extensions
37                        .get(ext)
38                        .into_iter()
39                        .flatten()
40                        .for_each(|module| {
41                            results.insert(module.clone());
42                        });
43                }
44            }
45        }
46
47        let with_freemove = |node_idx: usize| {
48            // Return the node ID
49            core::iter::once(node_idx)
50                // And the destination ID after following a `FreeMove` path from the node
51                .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
52        };
53
54        // Initialize with start states that match the first input
55        let start_states: BTreeSet<usize> = self
56            .roots
57            .iter()
58            .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
59            .collect();
60
61        let final_states = if input.len() == 1 {
62            // There is no further input, just get free moves from the start_states
63            start_states.into_iter().flat_map(with_freemove).collect()
64        } else {
65            // Process remaining input
66            input[1..].iter().fold(start_states, |states, sect| {
67                states
68                    .into_iter()
69                    .flat_map(|node_idx| &self.nodes[node_idx].paths)
70                    .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
71                    .flat_map(with_freemove)
72                    .collect()
73            })
74        };
75
76        // Collect all modules from final states
77        for &state_idx in &final_states {
78            for module in &self.nodes[state_idx].modules {
79                results.insert(module.clone());
80            }
81        }
82
83        Ok(results.into_iter().collect())
84    }
85
86    pub fn insert_file_extension(
87        &mut self,
88        module: &str,
89        file_extension: &str,
90    ) -> Result<(), Infallible> {
91        let module = self.add_module(module);
92
93        self.file_extensions
94            .entry(file_extension.to_string())
95            .or_default()
96            .push(module);
97
98        Ok(())
99    }
100    pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), UrlParseError> {
101        for protocol in &manifest.handles.url_protocols {
102            self.insert_protocol(&manifest.name, protocol).ok();
103        }
104        for prefix in &manifest.handles.url_prefixes {
105            self.insert_prefix(&manifest.name, prefix)?;
106        }
107        for pattern in &manifest.handles.url_patterns {
108            self.insert_pattern(&manifest.name, pattern)?;
109        }
110        for file_extension in &manifest.handles.file_extensions {
111            self.insert_file_extension(&manifest.name, file_extension)
112                .ok();
113        }
114        Ok(())
115    }
116    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
117        let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
118        let module = self.add_module(module);
119        let node_idx = self.get_or_create_node(path);
120
121        // Add a free move back to itself from the `FreeMove` node. (represents a protocol as a prefix):
122        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
123        self.nodes[node_idx].modules.insert(module);
124
125        Ok(())
126    }
127    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
128        let mut path = split_url(prefix)?;
129        // Add a `FreeMove` node at the end of the path to separate the prefix from
130        // patterns at the same node
131        path.push(Sect::FreeMove);
132        let module = self.add_module(module);
133        let node_idx = self.get_or_create_node(&path);
134
135        // Add a free move back to itself from the `FreeMove` node. Enables matching
136        // zero-or-more of anything:
137        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
138        self.nodes[node_idx].modules.insert(module);
139
140        Ok(())
141    }
142    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
143        let path: Vec<Sect> = split_url(pattern)?
144            .into_iter()
145            .map(Sect::into_pattern)
146            .collect();
147        let module = self.add_module(module);
148        let node_idx = self.get_or_create_node(&path);
149
150        self.nodes[node_idx].modules.insert(module);
151
152        Ok(())
153    }
154
155    #[cfg(all(feature = "std", feature = "serde"))]
156    pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
157        use error::FromDirError;
158
159        let path = path.as_ref();
160
161        let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
162            path: path.into(),
163            source,
164        })?;
165
166        let mut resolver = Resolver::new();
167
168        for entry in dir {
169            let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
170                path: path.into(),
171                source,
172            })?;
173            if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
174                continue;
175            }
176            let filename = entry.file_name();
177            let filename = filename.to_string_lossy();
178            if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
179                continue;
180            }
181            let path = entry.path();
182            let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
183                path: path.clone(),
184                source,
185            })?;
186
187            let manifest =
188                serde_yaml_ng::from_reader(file).map_err(|source| FromDirError::Parse {
189                    path: path.clone(),
190                    source,
191                })?;
192            resolver
193                .insert_manifest(&manifest)
194                .map_err(|source| FromDirError::Insert {
195                    path: path.clone(),
196                    source,
197                })?;
198        }
199
200        Ok(resolver)
201    }
202
203    pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, UrlParseError>
204    where
205        I: Iterator<Item = T>,
206        T: Borrow<ModuleManifest>,
207    {
208        iter.try_fold(Resolver::default(), |mut r, m| {
209            r.insert_manifest(m.borrow())?;
210            Ok(r)
211        })
212    }
213
214    fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
215        // Get or create the root node
216        let root_idx = *self
217            .roots
218            .entry(path[0].clone())
219            .or_insert_with(|| self.nodes.insert(Node::default()));
220
221        path[1..].iter().fold(root_idx, |cur_idx, sect| {
222            match (self.nodes[cur_idx].paths.get(sect), sect) {
223                (Some(&idx), _sect) => idx,
224                (None, Sect::WildcardDomain) => {
225                    // If the sect is a wildcard domain add a link to self, this will also match multiple subdomains.
226                    self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
227                    cur_idx
228                },
229                (None, sect) => {
230                    // Create a new node
231                    let new_node_idx = self.nodes.insert(Node::default());
232
233                    // Add the transition from current node to new node
234                    self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
235                    new_node_idx
236                },
237            }
238        })
239    }
240
241    fn add_module(&mut self, name: &str) -> Rc<Module> {
242        let name = name.to_string();
243        self.modules
244            .entry(name.clone())
245            .or_insert_with(|| Rc::new(Module { name }))
246            .clone()
247    }
248}
249
250impl TryFrom<&[ModuleManifest]> for Resolver {
251    type Error = UrlParseError;
252
253    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
254        Resolver::try_from_iter(value.iter())
255    }
256}
257
/// A module registered with a resolver, identified and ordered by its name.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Module {
    /// The module's name, as passed to the `insert_*` methods.
    pub name: String,
}
262
263#[derive(Clone, Debug, Default)]
264struct Node {
265    paths: BTreeMap<Sect, usize>,
266    modules: BTreeSet<Rc<Module>>,
267}
268
/// One token of a split URL, or an edge label in the matching graph.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// The protocol (a.k.a. scheme) of a URL: `https` in `https://example.org/`.
    Protocol(String),
    /// One literal domain label: `org` and `example` in `https://example.org/`.
    Domain(String),
    /// The `*` in `https://*.example.org/`; matches zero-or-more subdomains.
    WildcardDomain,
    /// One literal path segment: `file` and `path` in `https://example.org/file/path`.
    Path(String),
    /// The `:name` in `https://example.org/file/:name`; matches any single path segment.
    WildcardPath,
    /// A query parameter name: `q` in `https://example.org/?q=example`.
    QueryParamName(String),
    /// A literal query parameter value: `example` in `https://example.org/?q=example`.
    QueryParamValue(String),
    /// The `:query` in `https://example.org/?q=:query`; matches any query parameter value.
    WildcardQueryParamValue,
    /// Matches a single section of any kind.
    FreeMove,
}

impl Sect {
    /// Turns a literal section written in pattern syntax into its wildcard form.
    ///
    /// - A domain label that is exactly `*` becomes a wildcard domain.
    /// - A path segment starting with `:` (`/:foo/:bar`) becomes a wildcard path.
    /// - A query value starting with `:` (`q=:query`) becomes a wildcard query value.
    ///
    /// Every other section is returned unchanged.
    pub fn into_pattern(self) -> Self {
        match self {
            Sect::Domain(label) => {
                if label == "*" {
                    Sect::WildcardDomain
                } else {
                    Sect::Domain(label)
                }
            },
            Sect::Path(segment) => {
                if segment.starts_with(':') {
                    Sect::WildcardPath
                } else {
                    Sect::Path(segment)
                }
            },
            Sect::QueryParamValue(value) => {
                if value.starts_with(':') {
                    Sect::WildcardQueryParamValue
                } else {
                    Sect::QueryParamValue(value)
                }
            },
            other => other,
        }
    }

    /// Tells whether this section, used as an edge label, accepts `input`.
    ///
    /// Equal sections always match. Beyond that, each wildcard accepts any
    /// literal of its own kind, and `FreeMove` accepts everything.
    fn matches_input(&self, input: &Self) -> bool {
        if self == input {
            return true;
        }

        matches!(
            (self, input),
            (Sect::FreeMove, _)
                | (Sect::WildcardDomain, Sect::Domain(_))
                | (Sect::WildcardPath, Sect::Path(_))
                | (Sect::WildcardQueryParamValue, Sect::QueryParamValue(_))
        )
    }
}
318
319/// Split and URL into sections that we care about. This is effectively a tokenizer.
320fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
321    if url.is_empty() {
322        return Err(UrlParseError::EmptyUrl);
323    }
324
325    let mut res = Vec::new();
326
327    if !url.contains(':') {
328        res.push(Sect::Protocol(url.into()));
329        return Ok(res);
330    }
331
332    let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
333        url: url.to_string(),
334        source: e,
335    })?;
336
337    let proto = url.scheme();
338    res.push(Sect::Protocol(proto.into()));
339
340    if let Some(host) = url.host_str() {
341        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
342
343        if (proto == "http" || proto == "https")
344            && host_parts.last().is_some_and(|last| *last == "www")
345        {
346            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
347            let _www = host_parts.pop();
348        }
349
350        for part in host_parts {
351            res.push(Sect::Domain(part.into()));
352        }
353    }
354
355    if url.cannot_be_a_base() {
356        res.push(Sect::Path(url.path().into()))
357    } else if let Some(path_parts) = url.path_segments() {
358        for part in path_parts {
359            if part.is_empty() {
360                continue;
361            }
362            res.push(Sect::Path(part.into()));
363        }
364    }
365
366    for (k, v) in url.query_pairs() {
367        res.push(Sect::QueryParamName(k.into()));
368        if !v.is_empty() {
369            res.push(Sect::QueryParamValue(v.into()));
370        }
371    }
372
373    Ok(res)
374}
375
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::eprintln;

    /// Each kind of registration resolves (at least) to the module expected
    /// for a representative URL.
    #[test]
    fn matching() {
        /// Which `insert_*` method a table row is registered with.
        enum Kind {
            Protocol,
            Pattern,
            Prefix,
            Extension,
        }

        // Registered strictly in this order.
        let registrations = [
            (Kind::Protocol, "near", "near"),
            (Kind::Pattern, "near-account", "near://account/:id"),
            (Kind::Pattern, "near-tx", "near://tx/:id"),
            (Kind::Prefix, "google", "https://google.com/search?q="),
            (Kind::Prefix, "x", "https://x.com/"),
            (
                Kind::Pattern,
                "linkedin",
                "https://*.linkedin.com/in/:account/test",
            ),
            (Kind::Pattern, "youtube", "https://youtube.com/watch?v=:v"),
            (Kind::Pattern, "subdomains", "https://*.baz.com/"),
            (Kind::Prefix, "data", "data:text/plain"),
            (Kind::Prefix, "fs", "file://"),
            (Kind::Prefix, "fs2", "file:///2"),
            (Kind::Extension, "txt-ext", "txt"),
            (Kind::Extension, "tar-ext", "tar.gz"),
        ];

        let mut resolver = Resolver::default();
        for (kind, module, spec) in registrations {
            match kind {
                Kind::Protocol => resolver.insert_protocol(module, spec).unwrap(),
                Kind::Pattern => resolver.insert_pattern(module, spec).unwrap(),
                Kind::Prefix => resolver.insert_prefix(module, spec).unwrap(),
                Kind::Extension => resolver.insert_file_extension(module, spec).unwrap(),
            }
        }

        eprintln!("{resolver:#?}");

        // (input URL, module that must be among the results)
        let cases = [
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in cases {
            let modules = resolver.resolve(input).expect("resolve succeeds");
            let found = modules
                .iter()
                .find(|module| module.name == want)
                .unwrap_or_else(|| {
                    panic!("the wanted result should be returned, input={input} want={want}")
                });
            assert_eq!(found.name, want);
        }
    }

    /// A pattern matches only the exact URL, and registering a prefix for the
    /// same URL afterwards does not change that.
    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        let mut resolver = Resolver::new();

        resolver
            .insert_pattern("pattern", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let exact = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{exact:?}");
        assert!(
            exact.first().map(|module| module.name.as_str()) == Some("pattern"),
            "the pattern should match"
        );

        let longer = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{longer:?}");
        assert!(longer.is_empty(), "the pattern shouldn't be a prefix");

        resolver
            .insert_prefix("prefix", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let exact = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{exact:?}");
        assert!(exact.len() == 2, "both items should match");

        let longer = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{longer:?}");
        assert!(longer.len() == 1, "only the prefix should match");
        assert!(
            longer.first().map(|module| module.name.as_str()) == Some("prefix"),
            "only the prefix should match"
        );
    }
}