Skip to main content

asimov_module/
resolve.rs

1// This is free and unencumbered software released into the public domain.
2
3use crate::{ModuleManifest, resolve::error::InsertManifestError};
4use alloc::{
5    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6    rc::Rc,
7    string::{String, ToString},
8    vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
/// Maps URLs, file extensions and content types to the modules registered to
/// handle them.
///
/// URL matching is implemented as a graph of [`Node`]s stored in `nodes`; each
/// edge is labelled with a [`Sect`] (one tokenized section of a URL).
#[derive(Clone, Debug, Default)]
pub struct Resolver {
    /// Every known module, keyed by name, so that each name is backed by a
    /// single shared `Rc<Module>`.
    modules: BTreeMap<String, Rc<Module>>,
    /// Modules registered per file extension (stored without a leading dot).
    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
    /// Modules registered per content type.
    content_types: BTreeMap<mime::Mime, Vec<Rc<Module>>>,
    /// Storage for all graph nodes; nodes refer to each other by index.
    nodes: slab::Slab<Node>,
    /// Entry points into the graph: first section of a path -> node index.
    roots: BTreeMap<Sect, usize>,
}
23
24impl Resolver {
25    pub fn new() -> Self {
26        Resolver::default()
27    }
28
29    pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
30        let input = split_url(url)?;
31
32        // `results` is a set of `(path_length, Module)` items.
33        // `path_length` first so they get sorted by the length of matching path.
34        // We reverse iter this later so that longer paths are returned first.
35        let mut results: BTreeSet<(usize, Rc<Module>)> = BTreeSet::new();
36
37        if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
38            if let Some(Sect::Path(filename)) = input.last() {
39                if let Some((_, ext)) = filename.split_once(".") {
40                    self.file_extensions
41                        .get(ext)
42                        .into_iter()
43                        .flatten()
44                        .for_each(|module| {
45                            results.insert((0, module.clone()));
46                        });
47                }
48            }
49        }
50
51        let with_freemove = |node_idx: usize| {
52            // Return the node ID
53            core::iter::once(node_idx)
54                // And the destination ID after following a `FreeMove` path from the node
55                .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
56        };
57
58        // Initialize with start states that match the first input
59        let start_states: BTreeSet<usize> = self
60            .roots
61            .iter()
62            .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
63            .collect();
64
65        let final_states = if input.len() == 1 {
66            // There is no further input, just get free moves from the start_states
67            start_states.into_iter().flat_map(with_freemove).collect()
68        } else {
69            // Process remaining input
70            input[1..].iter().fold(start_states, |states, sect| {
71                states
72                    .into_iter()
73                    .flat_map(|node_idx| &self.nodes[node_idx].paths)
74                    .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
75                    .flat_map(with_freemove)
76                    .collect()
77            })
78        };
79
80        // Collect all modules from final states
81        for node in final_states.iter().map(|&idx| &self.nodes[idx]) {
82            for module in &node.modules {
83                results.insert((node.path_length, module.clone()));
84            }
85        }
86
87        Ok(results
88            .into_iter()
89             // The `results` set is sorted by path_length.
90             // Reverse to prefer longer matches.
91            .rev()
92            .map(|(_, module)| module)
93            .collect())
94    }
95
96    pub fn resolve_content_type(&self, content_type: &mime::Mime) -> Vec<Rc<Module>> {
97        let mut modules: BTreeSet<Rc<Module>> = BTreeSet::new();
98
99        let exact = self.content_types.get(content_type);
100        modules.extend(exact.into_iter().flatten().cloned());
101
102        let toptype = content_type.type_();
103        let star_sub = alloc::format!("{toptype}/*");
104        if let Ok(star_sub) = star_sub.parse() {
105            let matches = self.content_types.get(&star_sub);
106            modules.extend(matches.into_iter().flatten().cloned());
107        }
108
109        let starstar = self.content_types.get(&mime::STAR_STAR);
110        modules.extend(starstar.into_iter().flatten().cloned());
111
112        modules.into_iter().collect()
113    }
114
115    pub fn insert_manifest(
116        &mut self,
117        manifest: &ModuleManifest,
118    ) -> Result<(), InsertManifestError> {
119        for protocol in &manifest.handles.url_protocols {
120            self.insert_protocol(&manifest.name, protocol).ok();
121        }
122        for prefix in &manifest.handles.url_prefixes {
123            self.insert_prefix(&manifest.name, prefix)?;
124        }
125        for pattern in &manifest.handles.url_patterns {
126            self.insert_pattern(&manifest.name, pattern)?;
127        }
128        for file_extension in &manifest.handles.file_extensions {
129            self.insert_file_extension(&manifest.name, file_extension)
130                .ok();
131        }
132        for content_type in &manifest.handles.content_types {
133            let content_type = content_type.parse()?;
134            self.insert_content_type(&manifest.name, content_type).ok();
135        }
136        Ok(())
137    }
138
139    pub fn insert_content_type(
140        &mut self,
141        module: &str,
142        content_type: mime::Mime,
143    ) -> Result<(), Infallible> {
144        let module = self.add_module(module);
145
146        self.content_types
147            .entry(content_type)
148            .or_default()
149            .push(module);
150
151        Ok(())
152    }
153    pub fn insert_file_extension(
154        &mut self,
155        module: &str,
156        file_extension: &str,
157    ) -> Result<(), Infallible> {
158        let module = self.add_module(module);
159
160        let ext = file_extension
161            .strip_prefix(".")
162            .unwrap_or(file_extension)
163            .to_string();
164
165        self.file_extensions.entry(ext).or_default().push(module);
166
167        Ok(())
168    }
169    pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
170        let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
171        let module = self.add_module(module);
172        let node_idx = self.get_or_create_node(path);
173
174        // Add a free move back to itself from the `FreeMove` node. (represents a protocol as a prefix):
175        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
176        self.nodes[node_idx].modules.insert(module);
177
178        Ok(())
179    }
180    pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
181        let mut path = split_url(prefix)?;
182        // Add a `FreeMove` node at the end of the path to separate the prefix from
183        // patterns at the same node
184        path.push(Sect::FreeMove);
185        let module = self.add_module(module);
186        let node_idx = self.get_or_create_node(&path);
187
188        // Add a free move back to itself from the `FreeMove` node. Enables matching
189        // zero-or-more of anything:
190        self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
191        self.nodes[node_idx].modules.insert(module);
192
193        Ok(())
194    }
195    pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
196        let path: Vec<Sect> = split_url(pattern)?
197            .into_iter()
198            .map(Sect::into_pattern)
199            .collect();
200        let module = self.add_module(module);
201        let node_idx = self.get_or_create_node(&path);
202
203        self.nodes[node_idx].modules.insert(module);
204
205        Ok(())
206    }
207
208    #[cfg(all(feature = "std", feature = "serde"))]
209    pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
210        use error::FromDirError;
211
212        let path = path.as_ref();
213
214        let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
215            path: path.into(),
216            source,
217        })?;
218
219        let mut resolver = Resolver::new();
220
221        for entry in dir {
222            let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
223                path: path.into(),
224                source,
225            })?;
226            if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
227                continue;
228            }
229            let filename = entry.file_name();
230            let filename = filename.to_string_lossy();
231            if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
232                continue;
233            }
234            let path = entry.path();
235            let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
236                path: path.clone(),
237                source,
238            })?;
239
240            let manifest =
241                serde_yaml_ng::from_reader(file).map_err(|source| FromDirError::Parse {
242                    path: path.clone(),
243                    source,
244                })?;
245            resolver
246                .insert_manifest(&manifest)
247                .map_err(|source| FromDirError::Insert {
248                    path: path.clone(),
249                    source,
250                })?;
251        }
252
253        Ok(resolver)
254    }
255
256    pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, InsertManifestError>
257    where
258        I: Iterator<Item = T>,
259        T: Borrow<ModuleManifest>,
260    {
261        iter.try_fold(Resolver::default(), |mut r, m| {
262            r.insert_manifest(m.borrow())?;
263            Ok(r)
264        })
265    }
266
267    fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
268        // Get or create the root node
269        let root_idx = *self
270            .roots
271            .entry(path[0].clone())
272            .or_insert_with(|| self.nodes.insert(Node::default()));
273
274        path.iter()
275            .enumerate()
276            .skip(1) // first one was already inserted to `self.roots`
277            .fold(root_idx, |cur_idx, (path_length, sect)| {
278                match (self.nodes[cur_idx].paths.get(sect), sect) {
279                    (Some(&idx), _sect) => idx,
280                    (None, Sect::WildcardDomain) => {
281                        // If the sect is a wildcard domain add a link to self, this will also match multiple subdomains.
282                        self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
283                        cur_idx
284                    },
285                    (None, sect) => {
286                        // Create a new node
287                        let new_node_idx = self.nodes.insert(Node {
288                            path_length,
289                            ..Default::default()
290                        });
291
292                        // Add the transition from current node to new node
293                        self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
294                        new_node_idx
295                    },
296                }
297            })
298    }
299
300    fn add_module(&mut self, name: &str) -> Rc<Module> {
301        let name = name.to_string();
302        self.modules
303            .entry(name.clone())
304            .or_insert_with(|| Rc::new(Module { name }))
305            .clone()
306    }
307}
308
309impl TryFrom<&[ModuleManifest]> for Resolver {
310    type Error = InsertManifestError;
311
312    fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
313        Resolver::try_from_iter(value.iter())
314    }
315}
316
/// A module known to the [`Resolver`], identified by its name.
///
/// Ordering and equality are derived from `name`, the only field.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Module {
    /// The module's name.
    pub name: String,
}
321
/// A single state in the resolver's URL-matching graph.
#[derive(Clone, Debug, Default)]
struct Node {
    /// Outgoing transitions: section to match -> index of the destination node.
    paths: BTreeMap<Sect, usize>,
    /// Modules registered at this node.
    modules: BTreeSet<Rc<Module>>,
    /// Position within the inserted path at which this node was created
    /// (`0` for root nodes). Used to rank longer matches first.
    path_length: usize,
}
328
/// One tokenized section of a URL, or of a registered prefix / pattern.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// `https` from `https://example.org/`, matches the protocol (a.k.a. scheme) of an URL
    Protocol(String),
    /// `org` and `example` from `https://example.org/`, matches a single literal subdomain
    Domain(String),
    /// `*` from `https://*.example.org/`, matches zero-or-more subdomains
    WildcardDomain,
    /// `file` and `path` from `https://example.org/file/path`, match literal path segments
    Path(String),
    /// `:name` from `https://example.org/file/:name`, matches any single path segment
    WildcardPath,
    /// `q` from `https://example.org/?q=example`, matches a parameter name
    QueryParamName(String),
    /// `example` from `https://example.org/?q=example`, matches a literal parameter value
    QueryParamValue(String),
    /// `:query` from `https://example.org/?q=:query`, matches any query param value
    WildcardQueryParamValue,
    /// Matches a single section of any kind
    FreeMove,
}

impl Sect {
    /// Converts a literal section written in pattern syntax into its wildcard form.
    /// - A domain section equal to `*` becomes a wildcard domain.
    /// - A path section starting with `:` (`/:foo/:bar`) becomes a wildcard path.
    /// - A query parameter value starting with `:` (`q=:query`) becomes a
    ///   wildcard query parameter value.
    ///
    /// Every other section is returned unchanged.
    pub fn into_pattern(self) -> Self {
        let wildcard = match &self {
            Sect::Domain(label) if label == "*" => Some(Sect::WildcardDomain),
            Sect::Path(segment) if segment.starts_with(':') => Some(Sect::WildcardPath),
            Sect::QueryParamValue(value) if value.starts_with(':') => {
                Some(Sect::WildcardQueryParamValue)
            },
            _ => None,
        };
        wildcard.unwrap_or(self)
    }

    /// Reports whether this (registered) section accepts the `input` section.
    ///
    /// Equal sections always match, each wildcard matches any literal of its
    /// own kind, and `FreeMove` accepts every input.
    fn matches_input(&self, input: &Self) -> bool {
        if self == input {
            return true;
        }
        matches!(
            (self, input),
            (Sect::WildcardDomain, Sect::Domain(_))
                | (Sect::WildcardPath, Sect::Path(_))
                | (Sect::WildcardQueryParamValue, Sect::QueryParamValue(_))
                | (Sect::FreeMove, _)
        )
    }
}
378
379/// Split and URL into sections that we care about. This is effectively a tokenizer.
380fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
381    if url.is_empty() {
382        return Err(UrlParseError::EmptyUrl);
383    }
384
385    let mut res = Vec::new();
386
387    if !url.contains(':') {
388        res.push(Sect::Protocol(url.into()));
389        return Ok(res);
390    }
391
392    let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
393        url: url.to_string(),
394        source: e,
395    })?;
396
397    let proto = url.scheme();
398    res.push(Sect::Protocol(proto.into()));
399
400    if let Some(host) = url.host_str() {
401        let mut host_parts: Vec<&str> = host.split('.').rev().collect();
402
403        if (proto == "http" || proto == "https")
404            && host_parts.last().is_some_and(|last| *last == "www")
405        {
406            // ignore a "www." at the beginning of the domain. The domain has been reversed so we're popping the last element
407            let _www = host_parts.pop();
408        }
409
410        for part in host_parts {
411            res.push(Sect::Domain(part.into()));
412        }
413    }
414
415    if url.cannot_be_a_base() {
416        res.push(Sect::Path(url.path().into()))
417    } else if let Some(path_parts) = url.path_segments() {
418        for part in path_parts {
419            if part.is_empty() {
420                continue;
421            }
422            res.push(Sect::Path(part.into()));
423        }
424    }
425
426    for (k, v) in url.query_pairs() {
427        res.push(Sect::QueryParamName(k.into()));
428        if !v.is_empty() {
429            res.push(Sect::QueryParamValue(v.into()));
430        }
431    }
432
433    Ok(res)
434}
435
436#[cfg(test)]
437mod test {
438    use super::*;
439
440    extern crate std;
441    use std::{eprintln, vec};
442
443    #[test]
444    fn matching() {
445        let mut resolver = Resolver::default();
446
447        resolver.insert_protocol("near", "near").unwrap();
448        resolver
449            .insert_pattern("near-account", "near://account/:id")
450            .unwrap();
451        resolver.insert_pattern("near-tx", "near://tx/:id").unwrap();
452        resolver
453            .insert_prefix("google", "https://google.com/search?q=")
454            .unwrap();
455        resolver.insert_prefix("x", "https://x.com/").unwrap();
456        resolver
457            .insert_pattern("linkedin", "https://*.linkedin.com/in/:account/test")
458            .unwrap();
459        resolver
460            .insert_pattern("youtube", "https://youtube.com/watch?v=:v")
461            .unwrap();
462        resolver
463            .insert_pattern("subdomains", "https://*.baz.com/")
464            .unwrap();
465        resolver.insert_prefix("data", "data:text/plain").unwrap();
466        resolver.insert_prefix("fs", "file://").unwrap();
467        resolver.insert_prefix("fs2", "file:///2").unwrap();
468        resolver.insert_file_extension("txt-ext", "txt").unwrap();
469        resolver.insert_file_extension("zip-ext", ".zip").unwrap(); // leading dot should work just as well
470        resolver.insert_file_extension("tar-ext", "tar.gz").unwrap();
471
472        eprintln!("{resolver:#?}");
473
474        let tests = vec![
475            ("near", "near"),
476            ("near://tx/1234", "near-tx"),
477            ("near://account/1234", "near-account"),
478            ("near://other/1234", "near"),
479            ("https://google.com/search?q=foobar", "google"),
480            ("https://x.com/foobar", "x"),
481            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
482            ("https://youtube.com/watch?v=foobar", "youtube"),
483            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
484            ("data:text/plain?Hello+World", "data"),
485            ("file:///foo/bar/baz", "fs"),
486            ("file:/archive.zip", "zip-ext"),
487            ("file:///2/foo", "fs2"),
488            ("file:///foobar.txt", "txt-ext"),
489            ("file:///foobar.tar.gz", "tar-ext"),
490        ];
491
492        for (input, want) in tests {
493            assert_eq!(
494                resolver
495                    .resolve(input)
496                    .expect("resolve succeeds")
497                    .iter()
498                    .find(|out| out.name == want)
499                    .unwrap_or_else(|| panic!(
500                        "the wanted result should be returned, input={input} want={want}"
501                    ))
502                    .name,
503                want
504            );
505        }
506    }
507
508    #[test]
509    fn content_type() {
510        let mut resolver = Resolver::default();
511        resolver
512            .insert_content_type("starstar", mime::STAR_STAR)
513            .unwrap();
514        resolver
515            .insert_content_type("textstar", mime::TEXT_STAR)
516            .unwrap();
517        resolver
518            .insert_content_type("textplain", mime::TEXT_PLAIN)
519            .unwrap();
520
521        let results = resolver.resolve_content_type(&mime::TEXT_PLAIN);
522
523        assert!(
524            results.iter().any(|out| out.name == "starstar"),
525            "*/* should match"
526        );
527        assert!(
528            results.iter().any(|out| out.name == "textstar"),
529            "text/* should match"
530        );
531        assert!(
532            results.iter().any(|out| out.name == "textplain"),
533            "text/plain should match"
534        );
535
536        let results = resolver.resolve_content_type(&mime::APPLICATION_JSON);
537        assert!(
538            results.iter().any(|out| out.name == "starstar"),
539            "*/* should match"
540        );
541        assert!(
542            !results.iter().any(|out| out.name == "textstar"),
543            "text/* shouldn't match"
544        );
545        assert!(
546            !results.iter().any(|out| out.name == "textplain"),
547            "text/plain shouldn't match"
548        );
549    }
550
551    #[test]
552    fn prefix_doesnt_turn_pattern_to_prefix() {
553        let mut resolver = Resolver::new();
554
555        resolver
556            .insert_pattern("pattern", "https://foobar.com/")
557            .unwrap();
558        eprintln!("{resolver:#?}");
559
560        let results = resolver.resolve("https://foobar.com/").unwrap();
561        eprintln!("{results:?}");
562        assert!(
563            results
564                .first()
565                .is_some_and(|module| module.name == "pattern"),
566            "the pattern should match"
567        );
568
569        let results = resolver.resolve("https://foobar.com/more").unwrap();
570        eprintln!("{results:?}");
571        assert!(results.is_empty(), "the pattern shouldn't be a prefix");
572
573        resolver
574            .insert_prefix("prefix", "https://foobar.com/")
575            .unwrap();
576        eprintln!("{resolver:#?}");
577
578        let results = resolver.resolve("https://foobar.com/").unwrap();
579        eprintln!("{results:?}");
580        assert!(results.len() == 2, "both items should match");
581
582        let results = resolver.resolve("https://foobar.com/more").unwrap();
583        eprintln!("{results:?}");
584        assert!(results.len() == 1, "only the prefix should match");
585        assert!(
586            results
587                .first()
588                .is_some_and(|module| module.name == "prefix"),
589            "only the prefix should match"
590        );
591    }
592
593    #[test]
594    fn longer_matches_are_returned_first() {
595        let mut resolver = Resolver::new();
596        resolver.insert_prefix("fs", "file://").unwrap();
597        resolver.insert_prefix("fs2", "file:///path").unwrap();
598        resolver.insert_prefix("fs3", "file:///path/to").unwrap();
599        resolver.insert_file_extension("txt-ext", "txt").unwrap();
600
601        let mut it = resolver
602            .resolve("file:///path/to/file.txt")
603            .unwrap()
604            .into_iter();
605        assert_eq!("fs3", it.next().unwrap().name);
606        assert_eq!("fs2", it.next().unwrap().name);
607        assert_eq!("fs", it.next().unwrap().name);
608        assert_eq!("txt-ext", it.next().unwrap().name);
609        assert_eq!(None, it.next());
610
611        let mut it = resolver.resolve("file:///file.txt").unwrap().into_iter();
612        assert_eq!("fs", it.next().unwrap().name);
613        assert_eq!("txt-ext", it.next().unwrap().name);
614        assert_eq!(None, it.next());
615
616        let mut it = resolver.resolve("https://example.org").unwrap().into_iter();
617        assert_eq!(None, it.next());
618    }
619}