1use crate::ModuleManifest;
4use alloc::{
5 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6 rc::Rc,
7 string::{String, ToString},
8 vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
/// Maps URLs to the modules registered to handle them.
///
/// Handlers can be registered by protocol, URL prefix, URL pattern or file
/// extension; [`Resolver::resolve`] returns every module whose registration
/// matches a given URL.
#[derive(Clone, Debug, Default)]
pub struct Resolver {
    /// Every module registered so far, keyed by module name.
    modules: BTreeMap<String, Rc<Module>>,
    /// Modules registered per file extension; keys are stored without a leading dot.
    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
    /// Storage for the nodes of the matching graph; nodes refer to each other by slab index.
    nodes: slab::Slab<Node>,
    /// Entry points of the graph: the index into `nodes` for each distinct first URL section.
    roots: BTreeMap<Sect, usize>,
}
22
23impl Resolver {
24 pub fn new() -> Self {
25 Resolver::default()
26 }
27
28 pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
29 let input = split_url(url)?;
30
31 let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
32
33 if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
34 if let Some(Sect::Path(filename)) = input.last() {
35 if let Some((_, ext)) = filename.split_once(".") {
36 self.file_extensions
37 .get(ext)
38 .into_iter()
39 .flatten()
40 .for_each(|module| {
41 results.insert(module.clone());
42 });
43 }
44 }
45 }
46
47 let with_freemove = |node_idx: usize| {
48 core::iter::once(node_idx)
50 .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
52 };
53
54 let start_states: BTreeSet<usize> = self
56 .roots
57 .iter()
58 .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
59 .collect();
60
61 let final_states = if input.len() == 1 {
62 start_states.into_iter().flat_map(with_freemove).collect()
64 } else {
65 input[1..].iter().fold(start_states, |states, sect| {
67 states
68 .into_iter()
69 .flat_map(|node_idx| &self.nodes[node_idx].paths)
70 .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
71 .flat_map(with_freemove)
72 .collect()
73 })
74 };
75
76 for &state_idx in &final_states {
78 for module in &self.nodes[state_idx].modules {
79 results.insert(module.clone());
80 }
81 }
82
83 Ok(results.into_iter().collect())
84 }
85
86 pub fn insert_file_extension(
87 &mut self,
88 module: &str,
89 file_extension: &str,
90 ) -> Result<(), Infallible> {
91 let module = self.add_module(module);
92
93 let ext = file_extension
94 .strip_prefix(".")
95 .unwrap_or(file_extension)
96 .to_string();
97
98 self.file_extensions.entry(ext).or_default().push(module);
99
100 Ok(())
101 }
102 pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), UrlParseError> {
103 for protocol in &manifest.handles.url_protocols {
104 self.insert_protocol(&manifest.name, protocol).ok();
105 }
106 for prefix in &manifest.handles.url_prefixes {
107 self.insert_prefix(&manifest.name, prefix)?;
108 }
109 for pattern in &manifest.handles.url_patterns {
110 self.insert_pattern(&manifest.name, pattern)?;
111 }
112 for file_extension in &manifest.handles.file_extensions {
113 self.insert_file_extension(&manifest.name, file_extension)
114 .ok();
115 }
116 Ok(())
117 }
118 pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
119 let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
120 let module = self.add_module(module);
121 let node_idx = self.get_or_create_node(path);
122
123 self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
125 self.nodes[node_idx].modules.insert(module);
126
127 Ok(())
128 }
129 pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
130 let mut path = split_url(prefix)?;
131 path.push(Sect::FreeMove);
134 let module = self.add_module(module);
135 let node_idx = self.get_or_create_node(&path);
136
137 self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
140 self.nodes[node_idx].modules.insert(module);
141
142 Ok(())
143 }
144 pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
145 let path: Vec<Sect> = split_url(pattern)?
146 .into_iter()
147 .map(Sect::into_pattern)
148 .collect();
149 let module = self.add_module(module);
150 let node_idx = self.get_or_create_node(&path);
151
152 self.nodes[node_idx].modules.insert(module);
153
154 Ok(())
155 }
156
157 #[cfg(all(feature = "std", feature = "serde"))]
158 pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
159 use error::FromDirError;
160
161 let path = path.as_ref();
162
163 let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
164 path: path.into(),
165 source,
166 })?;
167
168 let mut resolver = Resolver::new();
169
170 for entry in dir {
171 let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
172 path: path.into(),
173 source,
174 })?;
175 if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
176 continue;
177 }
178 let filename = entry.file_name();
179 let filename = filename.to_string_lossy();
180 if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
181 continue;
182 }
183 let path = entry.path();
184 let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
185 path: path.clone(),
186 source,
187 })?;
188
189 let manifest =
190 serde_yaml_ng::from_reader(file).map_err(|source| FromDirError::Parse {
191 path: path.clone(),
192 source,
193 })?;
194 resolver
195 .insert_manifest(&manifest)
196 .map_err(|source| FromDirError::Insert {
197 path: path.clone(),
198 source,
199 })?;
200 }
201
202 Ok(resolver)
203 }
204
205 pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, UrlParseError>
206 where
207 I: Iterator<Item = T>,
208 T: Borrow<ModuleManifest>,
209 {
210 iter.try_fold(Resolver::default(), |mut r, m| {
211 r.insert_manifest(m.borrow())?;
212 Ok(r)
213 })
214 }
215
216 fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
217 let root_idx = *self
219 .roots
220 .entry(path[0].clone())
221 .or_insert_with(|| self.nodes.insert(Node::default()));
222
223 path[1..].iter().fold(root_idx, |cur_idx, sect| {
224 match (self.nodes[cur_idx].paths.get(sect), sect) {
225 (Some(&idx), _sect) => idx,
226 (None, Sect::WildcardDomain) => {
227 self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
229 cur_idx
230 },
231 (None, sect) => {
232 let new_node_idx = self.nodes.insert(Node::default());
234
235 self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
237 new_node_idx
238 },
239 }
240 })
241 }
242
243 fn add_module(&mut self, name: &str) -> Rc<Module> {
244 let name = name.to_string();
245 self.modules
246 .entry(name.clone())
247 .or_insert_with(|| Rc::new(Module { name }))
248 .clone()
249 }
250}
251
252impl TryFrom<&[ModuleManifest]> for Resolver {
253 type Error = UrlParseError;
254
255 fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
256 Resolver::try_from_iter(value.iter())
257 }
258}
259
/// A handler module known to the resolver, identified by its name.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Module {
    /// Module name, as given to the resolver's `insert_*` methods or taken from a manifest.
    pub name: String,
}
264
/// A state in the resolver's matching graph.
#[derive(Clone, Debug, Default)]
struct Node {
    /// Outgoing edges: the section an input must match, mapped to the slab index of the
    /// node that edge leads to.
    paths: BTreeMap<Sect, usize>,
    /// Modules that are reported when resolution ends on this node.
    modules: BTreeSet<Rc<Module>>,
}
270
/// One section of a split URL, or an edge label in the matching graph.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// The URL scheme, or the whole input when it contains no `:`.
    Protocol(String),
    /// One dot-separated label of the host. `split_url` emits labels from the top-level
    /// domain inward.
    Domain(String),
    /// Pattern-only: matches any `Domain`. Produced by `into_pattern` from the label `*`.
    WildcardDomain,
    /// One non-empty path segment, or the entire path of a URL that cannot be a base.
    Path(String),
    /// Pattern-only: matches any `Path`. Produced by `into_pattern` from a segment that
    /// starts with `:`.
    WildcardPath,
    /// The name of a query parameter.
    QueryParamName(String),
    /// The non-empty value of a query parameter.
    QueryParamValue(String),
    /// Pattern-only: matches any `QueryParamValue`. Produced by `into_pattern` from a value
    /// that starts with `:`.
    WildcardQueryParamValue,
    /// Graph-only: matches any input section. Used for protocol and prefix registrations.
    FreeMove,
}
292
293impl Sect {
294 pub fn into_pattern(self) -> Self {
299 match self {
300 Sect::Domain(p) if p == "*" => Sect::WildcardDomain,
301 Sect::Path(p) if p.starts_with(':') => Sect::WildcardPath,
302 Sect::QueryParamValue(p) if p.starts_with(':') => Sect::WildcardQueryParamValue,
303 _ => self,
304 }
305 }
306
307 fn matches_input(&self, input: &Self) -> bool {
308 use Sect::*;
309 match (self, input) {
310 (a, b) if a == b => true,
311 (WildcardDomain, Domain(_)) => true,
312 (WildcardPath, Path(_)) => true,
313 (WildcardQueryParamValue, QueryParamValue(_)) => true,
314 (FreeMove, _) => true,
316 _ => false,
317 }
318 }
319}
320
321fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
323 if url.is_empty() {
324 return Err(UrlParseError::EmptyUrl);
325 }
326
327 let mut res = Vec::new();
328
329 if !url.contains(':') {
330 res.push(Sect::Protocol(url.into()));
331 return Ok(res);
332 }
333
334 let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
335 url: url.to_string(),
336 source: e,
337 })?;
338
339 let proto = url.scheme();
340 res.push(Sect::Protocol(proto.into()));
341
342 if let Some(host) = url.host_str() {
343 let mut host_parts: Vec<&str> = host.split('.').rev().collect();
344
345 if (proto == "http" || proto == "https")
346 && host_parts.last().is_some_and(|last| *last == "www")
347 {
348 let _www = host_parts.pop();
350 }
351
352 for part in host_parts {
353 res.push(Sect::Domain(part.into()));
354 }
355 }
356
357 if url.cannot_be_a_base() {
358 res.push(Sect::Path(url.path().into()))
359 } else if let Some(path_parts) = url.path_segments() {
360 for part in path_parts {
361 if part.is_empty() {
362 continue;
363 }
364 res.push(Sect::Path(part.into()));
365 }
366 }
367
368 for (k, v) in url.query_pairs() {
369 res.push(Sect::QueryParamName(k.into()));
370 if !v.is_empty() {
371 res.push(Sect::QueryParamValue(v.into()));
372 }
373 }
374
375 Ok(res)
376}
377
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::{eprintln, vec};

    /// Which registration method a fixture row goes through.
    enum Reg {
        Protocol,
        Pattern,
        Prefix,
        FileExtension,
    }

    /// Registers a mix of handlers and checks that each sample URL resolves to (at least)
    /// the module that is expected to handle it.
    #[test]
    fn matching() {
        let mut resolver = Resolver::default();

        // (registration kind, module name, protocol / pattern / prefix / extension)
        let registrations = vec![
            (Reg::Protocol, "near", "near"),
            (Reg::Pattern, "near-account", "near://account/:id"),
            (Reg::Pattern, "near-tx", "near://tx/:id"),
            (Reg::Prefix, "google", "https://google.com/search?q="),
            (Reg::Prefix, "x", "https://x.com/"),
            (
                Reg::Pattern,
                "linkedin",
                "https://*.linkedin.com/in/:account/test",
            ),
            (Reg::Pattern, "youtube", "https://youtube.com/watch?v=:v"),
            (Reg::Pattern, "subdomains", "https://*.baz.com/"),
            (Reg::Prefix, "data", "data:text/plain"),
            (Reg::Prefix, "fs", "file://"),
            (Reg::Prefix, "fs2", "file:///2"),
            (Reg::FileExtension, "txt-ext", "txt"),
            (Reg::FileExtension, "zip-ext", ".zip"),
            (Reg::FileExtension, "tar-ext", "tar.gz"),
        ];

        for (kind, module, spec) in registrations {
            match kind {
                Reg::Protocol => resolver.insert_protocol(module, spec).unwrap(),
                Reg::Pattern => resolver.insert_pattern(module, spec).unwrap(),
                Reg::Prefix => resolver.insert_prefix(module, spec).unwrap(),
                Reg::FileExtension => resolver.insert_file_extension(module, spec).unwrap(),
            }
        }

        eprintln!("{resolver:#?}");

        // (URL to resolve, module that must be among the results)
        let expectations = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:/archive.zip", "zip-ext"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in expectations {
            let resolved = resolver.resolve(input).expect("resolve succeeds");
            assert!(
                resolved.iter().any(|module| module.name == want),
                "the wanted result should be returned, input={input} want={want}"
            );
        }
    }

    /// A pattern must keep matching only the exact URL, even after a prefix with the same
    /// sections is registered next to it.
    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        const EXACT: &str = "https://foobar.com/";
        const LONGER: &str = "https://foobar.com/more";

        let mut resolver = Resolver::new();

        resolver.insert_pattern("pattern", EXACT).unwrap();
        eprintln!("{resolver:#?}");

        // With only the pattern registered, the exact URL matches and the longer one does not.
        let exact_hits = resolver.resolve(EXACT).unwrap();
        eprintln!("{exact_hits:?}");
        assert!(
            exact_hits
                .first()
                .is_some_and(|module| module.name == "pattern"),
            "the pattern should match"
        );

        let longer_hits = resolver.resolve(LONGER).unwrap();
        eprintln!("{longer_hits:?}");
        assert!(longer_hits.is_empty(), "the pattern shouldn't be a prefix");

        resolver.insert_prefix("prefix", EXACT).unwrap();
        eprintln!("{resolver:#?}");

        // With the prefix added, the exact URL matches both and the longer one only the prefix.
        let exact_hits = resolver.resolve(EXACT).unwrap();
        eprintln!("{exact_hits:?}");
        assert!(exact_hits.len() == 2, "both items should match");

        let longer_hits = resolver.resolve(LONGER).unwrap();
        eprintln!("{longer_hits:?}");
        assert!(longer_hits.len() == 1, "only the prefix should match");
        assert!(
            longer_hits
                .first()
                .is_some_and(|module| module.name == "prefix"),
            "only the prefix should match"
        );
    }
}