asimov_module/
resolve.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::models::ModuleManifest;
4use alloc::{
5    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6    rc::Rc,
7    string::{String, ToString},
8    vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
/// Maps URLs (and, for `file:` URLs, file extensions) to the modules that
/// registered themselves as handlers for them.
///
/// Registered protocols, prefixes and patterns are stored as a graph of
/// [`Node`]s whose edges are labelled with URL sections; resolving a URL walks
/// that graph section by section.
#[derive(Clone, Debug, Default)]
pub struct Resolver {
    /// Every module seen so far, keyed by name, so that one shared `Rc` handle
    /// is reused for all registrations of the same module.
    modules: BTreeMap<String, Rc<Module>>,
    /// File extension → modules registered for it. Only consulted when
    /// resolving `file:` URLs.
    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
    /// Storage for all graph nodes; nodes refer to each other by slab index.
    nodes: slab::Slab<Node>,
    /// First section of each registered path → index of its root node in `nodes`.
    roots: BTreeMap<Sect, usize>,
}
22
23impl Resolver {
24    pub fn new() -> Self {
25        Resolver::default()
26    }
27
28    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
29        let input = split_url(url)?;
30
31        let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
32
33        if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
34            if let Some(Sect::Path(filename)) = input.last() {
35                if let Some((_, ext)) = filename.split_once(".") {
36                    self.file_extensions
37                        .get(ext)
38                        .into_iter()
39                        .flatten()
40                        .for_each(|module| {
41                            results.insert(module.clone());
42                        });
43                }
44            }
45        }
46
47        let with_freemove = |node_idx: usize| {
48            // Return the node ID
49            core::iter::once(node_idx)
50                // And the destination ID after following a `FreeMove` path from the node
51                .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
52        };
53
54        // Initialize with start states that match the first input
55        let start_states: BTreeSet<usize> = self
56            .roots
57            .iter()
58            .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
59            .collect();
60
61        let final_states = if input.len() == 1 {
62            // There is no further input, just get free moves from the start_states
63            start_states.into_iter().flat_map(with_freemove).collect()
64        } else {
65            // Process remaining input
66            input[1..].iter().fold(start_states, |states, sect| {
67                states
68                    .into_iter()
69                    .flat_map(|node_idx| &self.nodes[node_idx].paths)
70                    .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
71                    .flat_map(with_freemove)
72                    .collect()
73            })
74        };
75
76        // Collect all modules from final states
77        for &state_idx in &final_states {
78            for module in &self.nodes[state_idx].modules {
79                results.insert(module.clone());
80            }
81        }
82
83        Ok(results.into_iter().collect())
84    }
85
86    pub fn insert_file_extension(
87        &mut self,
88        module: &str,
89        file_extension: &str,
90    ) -> Result<(), Infallible> {
91        let module = self.add_module(module);
92
93        self.file_extensions
94            .entry(file_extension.to_string())
95            .or_default()
96            .push(module);
97
98        Ok(())
99    }
100    pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), UrlParseError> {
101        for protocol in &manifest.handles.url_protocols {
102            self.insert_protocol(&manifest.name, protocol).ok();
103        }
104        for prefix in &manifest.handles.url_prefixes {
105            self.insert_prefix(&manifest.name, prefix)?;
106        }
107        for pattern in &manifest.handles.url_patterns {
108            self.insert_pattern(&manifest.name, pattern)?;
109        }
110        for file_extension in &manifest.handles.file_extensions {
111            self.insert_file_extension(&manifest.name, file_extension)
112                .ok();
113        }
114        Ok(())
115    }
116    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
117        let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
118        let module = self.add_module(module);
119        let node_idx = self.get_or_create_node(path);
120
121        // Add a free move back to itself from the `FreeMove` node. (represents a protocol as a prefix):
122        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
123        self.nodes[node_idx].modules.insert(module);
124
125        Ok(())
126    }
127    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
128        let mut path = split_url(prefix)?;
129        // Add a `FreeMove` node at the end of the path to separate the prefix from
130        // patterns at the same node
131        path.push(Sect::FreeMove);
132        let module = self.add_module(module);
133        let node_idx = self.get_or_create_node(&path);
134
135        // Add a free move back to itself from the `FreeMove` node. Enables matching
136        // zero-or-more of anything:
137        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
138        self.nodes[node_idx].modules.insert(module);
139
140        Ok(())
141    }
142    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
143        let path: Vec<Sect> = split_url(pattern)?
144            .into_iter()
145            .map(Sect::into_pattern)
146            .collect();
147        let module = self.add_module(module);
148        let node_idx = self.get_or_create_node(&path);
149
150        self.nodes[node_idx].modules.insert(module);
151
152        Ok(())
153    }
154
155    #[cfg(all(feature = "std", feature = "serde"))]
156    pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
157        use error::FromDirError;
158
159        let path = path.as_ref();
160
161        let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
162            path: path.into(),
163            source,
164        })?;
165
166        let mut resolver = Resolver::new();
167
168        for entry in dir {
169            let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
170                path: path.into(),
171                source,
172            })?;
173            if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
174                continue;
175            }
176            let filename = entry.file_name();
177            let filename = filename.to_string_lossy();
178            if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
179                continue;
180            }
181            let path = entry.path();
182            let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
183                path: path.clone(),
184                source,
185            })?;
186
187            let manifest = serde_yml::from_reader(file).map_err(|source| FromDirError::Parse {
188                path: path.clone(),
189                source,
190            })?;
191            resolver
192                .insert_manifest(&manifest)
193                .map_err(|source| FromDirError::Insert {
194                    path: path.clone(),
195                    source,
196                })?;
197        }
198
199        Ok(resolver)
200    }
201
202    pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, UrlParseError>
203    where
204        I: Iterator<Item = T>,
205        T: Borrow<ModuleManifest>,
206    {
207        iter.try_fold(Resolver::default(), |mut r, m| {
208            r.insert_manifest(m.borrow())?;
209            Ok(r)
210        })
211    }
212
213    fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
214        // Get or create the root node
215        let root_idx = *self
216            .roots
217            .entry(path[0].clone())
218            .or_insert_with(|| self.nodes.insert(Node::default()));
219
220        path[1..].iter().fold(root_idx, |cur_idx, sect| {
221            match (self.nodes[cur_idx].paths.get(sect), sect) {
222                (Some(&idx), _sect) => idx,
223                (None, Sect::WildcardDomain) => {
224                    // If the sect is a wildcard domain add a link to self, this will also match multiple subdomains.
225                    self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
226                    cur_idx
227                },
228                (None, sect) => {
229                    // Create a new node
230                    let new_node_idx = self.nodes.insert(Node::default());
231
232                    // Add the transition from current node to new node
233                    self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
234                    new_node_idx
235                },
236            }
237        })
238    }
239
240    fn add_module(&mut self, name: &str) -> Rc<Module> {
241        let name = name.to_string();
242        self.modules
243            .entry(name.clone())
244            .or_insert_with(|| Rc::new(Module { name }))
245            .clone()
246    }
247}
248
249impl TryFrom<&[ModuleManifest]> for Resolver {
250    type Error = UrlParseError;
251
252    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
253        Resolver::try_from_iter(value.iter())
254    }
255}
256
/// A module that can be returned by [`Resolver::resolve`].
///
/// Ordering and equality are derived from `name`, the only field.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Module {
    /// The module's name, as passed to the `insert_*` methods.
    pub name: String,
}
261
/// One state in the resolver's matching graph.
#[derive(Clone, Debug, Default)]
struct Node {
    /// Outgoing edges: the section accepted → index of the destination node
    /// in `Resolver::nodes`.
    paths: BTreeMap<Sect, usize>,
    /// Modules that match when resolution ends on this node.
    modules: BTreeSet<Rc<Module>>,
}
267
/// One section (token) of a URL, as produced by `split_url`, or one edge label
/// in the resolver graph.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// The protocol (a.k.a. scheme) of a URL: `https` in `https://example.org/`.
    Protocol(String),
    /// A single literal domain label: `org` and `example` in `https://example.org/`.
    Domain(String),
    /// The `*` in `https://*.example.org/`; matches zero-or-more subdomains.
    WildcardDomain,
    /// A literal path segment: `file` and `path` in `https://example.org/file/path`.
    Path(String),
    /// The `:name` in `https://example.org/file/:name`; matches any single path segment.
    WildcardPath,
    /// A query parameter name: `q` in `https://example.org/?q=example`.
    QueryParamName(String),
    /// A literal query parameter value: `example` in `https://example.org/?q=example`.
    QueryParamValue(String),
    /// The `:query` in `https://example.org/?q=:query`; matches any query parameter value.
    WildcardQueryParamValue,
    /// Matches a single section of any kind.
    FreeMove,
}

impl Sect {
    /// Converts a literal section written in pattern syntax into its wildcard form.
    ///
    /// - a domain label equal to `*` becomes [`Sect::WildcardDomain`];
    /// - a path segment starting with `:` (`/:foo/:bar`) becomes [`Sect::WildcardPath`];
    /// - a query value starting with `:` (`q=:query`) becomes
    ///   [`Sect::WildcardQueryParamValue`].
    ///
    /// Every other section is returned unchanged.
    pub fn into_pattern(self) -> Self {
        let wildcard = match &self {
            Sect::Domain(label) if label == "*" => Some(Sect::WildcardDomain),
            Sect::Path(segment) if segment.starts_with(':') => Some(Sect::WildcardPath),
            Sect::QueryParamValue(value) if value.starts_with(':') => {
                Some(Sect::WildcardQueryParamValue)
            },
            _ => None,
        };

        wildcard.unwrap_or(self)
    }

    /// Reports whether this edge label accepts the given input section.
    ///
    /// Equal sections always match. A wildcard matches any literal of its own
    /// kind, and `FreeMove` matches every input.
    fn matches_input(&self, input: &Self) -> bool {
        self == input
            || matches!(
                (self, input),
                (Sect::FreeMove, _)
                    | (Sect::WildcardDomain, Sect::Domain(_))
                    | (Sect::WildcardPath, Sect::Path(_))
                    | (Sect::WildcardQueryParamValue, Sect::QueryParamValue(_))
            )
    }
}
317
318/// Split and URL into sections that we care about. This is effectively a tokenizer.
319fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
320    if url.is_empty() {
321        return Err(UrlParseError::EmptyUrl);
322    }
323
324    let mut res = Vec::new();
325
326    if !url.contains(':') {
327        res.push(Sect::Protocol(url.into()));
328        return Ok(res);
329    }
330
331    let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
332        url: url.to_string(),
333        source: e,
334    })?;
335
336    let proto = url.scheme();
337    res.push(Sect::Protocol(proto.into()));
338
339    if let Some(host) = url.host_str() {
340        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
341
342        if (proto == "http" || proto == "https")
343            && host_parts.last().is_some_and(|last| *last == "www")
344        {
345            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
346            let _www = host_parts.pop();
347        }
348
349        for part in host_parts {
350            res.push(Sect::Domain(part.into()));
351        }
352    }
353
354    if url.cannot_be_a_base() {
355        res.push(Sect::Path(url.path().into()))
356    } else if let Some(path_parts) = url.path_segments() {
357        for part in path_parts {
358            if part.is_empty() {
359                continue;
360            }
361            res.push(Sect::Path(part.into()));
362        }
363    }
364
365    for (k, v) in url.query_pairs() {
366        res.push(Sect::QueryParamName(k.into()));
367        if !v.is_empty() {
368            res.push(Sect::QueryParamValue(v.into()));
369        }
370    }
371
372    Ok(res)
373}
374
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::{eprintln, vec};

    /// Registers one handler of every kind and checks that each sample URL
    /// resolves to (at least) the module expected for it.
    #[test]
    fn matching() {
        let mut resolver = Resolver::default();

        resolver.insert_protocol("near", "near").unwrap();
        resolver
            .insert_pattern("near-account", "near://account/:id")
            .unwrap();
        resolver.insert_pattern("near-tx", "near://tx/:id").unwrap();
        resolver
            .insert_prefix("google", "https://google.com/search?q=")
            .unwrap();
        resolver.insert_prefix("x", "https://x.com/").unwrap();
        resolver
            .insert_pattern("linkedin", "https://*.linkedin.com/in/:account/test")
            .unwrap();
        resolver
            .insert_pattern("youtube", "https://youtube.com/watch?v=:v")
            .unwrap();
        resolver
            .insert_pattern("subdomains", "https://*.baz.com/")
            .unwrap();
        resolver.insert_prefix("data", "data:text/plain").unwrap();
        resolver.insert_prefix("fs", "file://").unwrap();
        resolver.insert_prefix("fs2", "file:///2").unwrap();
        resolver.insert_file_extension("txt-ext", "txt").unwrap();
        resolver.insert_file_extension("tar-ext", "tar.gz").unwrap();

        eprintln!("{resolver:#?}");

        // (input URL, name of a module that must be among the results)
        let cases = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in cases {
            let modules = resolver.resolve(input).expect("resolve succeeds");
            assert!(
                modules.iter().any(|module| module.name == want),
                "the wanted result should be returned, input={input} want={want}"
            );
        }
    }

    /// A pattern must only match exactly, and registering a prefix on the same
    /// URL afterwards must not turn that pattern into a prefix.
    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        let mut resolver = Resolver::new();

        resolver
            .insert_pattern("pattern", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        // With only the pattern registered, the exact URL matches...
        let exact = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{exact:?}");
        assert_eq!(
            exact.first().map(|module| module.name.as_str()),
            Some("pattern"),
            "the pattern should match"
        );

        // ...but a longer URL does not.
        let longer = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{longer:?}");
        assert!(longer.is_empty(), "the pattern shouldn't be a prefix");

        resolver
            .insert_prefix("prefix", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        // With the prefix added, the exact URL matches both registrations...
        let exact = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{exact:?}");
        assert_eq!(exact.len(), 2, "both items should match");

        // ...and the longer URL matches the prefix alone.
        let longer = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{longer:?}");
        assert_eq!(longer.len(), 1, "only the prefix should match");
        assert_eq!(
            longer.first().map(|module| module.name.as_str()),
            Some("prefix"),
            "only the prefix should match"
        );
    }
}
486}