1use crate::ModuleManifest;
4use alloc::{
5 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6 rc::Rc,
7 string::{String, ToString},
8 vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
/// Maps URLs to the modules registered as able to handle them.
///
/// URL registrations are stored as a graph of `Node`s whose edges are
/// labelled with URL sections (`Sect`); file-extension registrations live in
/// a separate lookup table.
#[derive(Clone, Debug, Default)]
pub struct Resolver {
    /// Every module known to the resolver, keyed by name, so a single shared
    /// `Rc<Module>` is handed out per name.
    modules: BTreeMap<String, Rc<Module>>,
    /// Modules registered per file extension; consulted for `file` URLs only.
    file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
    /// Arena holding all graph nodes; edges refer to nodes by slab index.
    nodes: slab::Slab<Node>,
    /// Entry points into the graph, keyed by the first section of each
    /// registered path, mapping to an index into `nodes`.
    roots: BTreeMap<Sect, usize>,
}
22
23impl Resolver {
24 pub fn new() -> Self {
25 Resolver::default()
26 }
27
28 pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
29 let input = split_url(url)?;
30
31 let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
32
33 if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
34 if let Some(Sect::Path(filename)) = input.last() {
35 if let Some((_, ext)) = filename.split_once(".") {
36 self.file_extensions
37 .get(ext)
38 .into_iter()
39 .flatten()
40 .for_each(|module| {
41 results.insert(module.clone());
42 });
43 }
44 }
45 }
46
47 let with_freemove = |node_idx: usize| {
48 core::iter::once(node_idx)
50 .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
52 };
53
54 let start_states: BTreeSet<usize> = self
56 .roots
57 .iter()
58 .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
59 .collect();
60
61 let final_states = if input.len() == 1 {
62 start_states.into_iter().flat_map(with_freemove).collect()
64 } else {
65 input[1..].iter().fold(start_states, |states, sect| {
67 states
68 .into_iter()
69 .flat_map(|node_idx| &self.nodes[node_idx].paths)
70 .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
71 .flat_map(with_freemove)
72 .collect()
73 })
74 };
75
76 for &state_idx in &final_states {
78 for module in &self.nodes[state_idx].modules {
79 results.insert(module.clone());
80 }
81 }
82
83 Ok(results.into_iter().collect())
84 }
85
86 pub fn insert_file_extension(
87 &mut self,
88 module: &str,
89 file_extension: &str,
90 ) -> Result<(), Infallible> {
91 let module = self.add_module(module);
92
93 self.file_extensions
94 .entry(file_extension.to_string())
95 .or_default()
96 .push(module);
97
98 Ok(())
99 }
100 pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), UrlParseError> {
101 for protocol in &manifest.handles.url_protocols {
102 self.insert_protocol(&manifest.name, protocol).ok();
103 }
104 for prefix in &manifest.handles.url_prefixes {
105 self.insert_prefix(&manifest.name, prefix)?;
106 }
107 for pattern in &manifest.handles.url_patterns {
108 self.insert_pattern(&manifest.name, pattern)?;
109 }
110 for file_extension in &manifest.handles.file_extensions {
111 self.insert_file_extension(&manifest.name, file_extension)
112 .ok();
113 }
114 Ok(())
115 }
116 pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
117 let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
118 let module = self.add_module(module);
119 let node_idx = self.get_or_create_node(path);
120
121 self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
123 self.nodes[node_idx].modules.insert(module);
124
125 Ok(())
126 }
127 pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
128 let mut path = split_url(prefix)?;
129 path.push(Sect::FreeMove);
132 let module = self.add_module(module);
133 let node_idx = self.get_or_create_node(&path);
134
135 self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
138 self.nodes[node_idx].modules.insert(module);
139
140 Ok(())
141 }
142 pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
143 let path: Vec<Sect> = split_url(pattern)?
144 .into_iter()
145 .map(Sect::into_pattern)
146 .collect();
147 let module = self.add_module(module);
148 let node_idx = self.get_or_create_node(&path);
149
150 self.nodes[node_idx].modules.insert(module);
151
152 Ok(())
153 }
154
155 #[cfg(all(feature = "std", feature = "serde"))]
156 pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
157 use error::FromDirError;
158
159 let path = path.as_ref();
160
161 let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
162 path: path.into(),
163 source,
164 })?;
165
166 let mut resolver = Resolver::new();
167
168 for entry in dir {
169 let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
170 path: path.into(),
171 source,
172 })?;
173 if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
174 continue;
175 }
176 let filename = entry.file_name();
177 let filename = filename.to_string_lossy();
178 if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
179 continue;
180 }
181 let path = entry.path();
182 let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
183 path: path.clone(),
184 source,
185 })?;
186
187 let manifest = serde_yml::from_reader(file).map_err(|source| FromDirError::Parse {
188 path: path.clone(),
189 source,
190 })?;
191 resolver
192 .insert_manifest(&manifest)
193 .map_err(|source| FromDirError::Insert {
194 path: path.clone(),
195 source,
196 })?;
197 }
198
199 Ok(resolver)
200 }
201
202 pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, UrlParseError>
203 where
204 I: Iterator<Item = T>,
205 T: Borrow<ModuleManifest>,
206 {
207 iter.try_fold(Resolver::default(), |mut r, m| {
208 r.insert_manifest(m.borrow())?;
209 Ok(r)
210 })
211 }
212
213 fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
214 let root_idx = *self
216 .roots
217 .entry(path[0].clone())
218 .or_insert_with(|| self.nodes.insert(Node::default()));
219
220 path[1..].iter().fold(root_idx, |cur_idx, sect| {
221 match (self.nodes[cur_idx].paths.get(sect), sect) {
222 (Some(&idx), _sect) => idx,
223 (None, Sect::WildcardDomain) => {
224 self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
226 cur_idx
227 },
228 (None, sect) => {
229 let new_node_idx = self.nodes.insert(Node::default());
231
232 self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
234 new_node_idx
235 },
236 }
237 })
238 }
239
240 fn add_module(&mut self, name: &str) -> Rc<Module> {
241 let name = name.to_string();
242 self.modules
243 .entry(name.clone())
244 .or_insert_with(|| Rc::new(Module { name }))
245 .clone()
246 }
247}
248
249impl TryFrom<&[ModuleManifest]> for Resolver {
250 type Error = UrlParseError;
251
252 fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
253 Resolver::try_from_iter(value.iter())
254 }
255}
256
/// A handler module, identified solely by its name.
///
/// Equality and ordering are derived from `name`.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct Module {
    /// Name the module was registered under.
    pub name: String,
}
261
/// One state in the resolver's matching graph.
#[derive(Clone, Debug, Default)]
struct Node {
    /// Outgoing edges: section label -> slab index of the target node.
    /// A target may be this node itself (self-loop).
    paths: BTreeMap<Sect, usize>,
    /// Modules that handle a URL whose matching ends on this node.
    modules: BTreeSet<Rc<Module>>,
}
267
/// One section of a split URL; also used as an edge label in the resolver
/// graph, where the wildcard variants and `FreeMove` may appear.
///
/// The derived `Ord` follows the declaration order of the variants.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// URL scheme, e.g. `https`.
    Protocol(String),
    /// A single host label, e.g. `com`.
    Domain(String),
    /// Edge label matching any host label.
    WildcardDomain,
    /// A single path segment.
    Path(String),
    /// Edge label matching any path segment.
    WildcardPath,
    /// Name of a query parameter.
    QueryParamName(String),
    /// Value of a query parameter.
    QueryParamValue(String),
    /// Edge label matching any query parameter value.
    WildcardQueryParamValue,
    /// Edge label matching any section at all.
    FreeMove,
}

impl Sect {
    /// Converts a literal section taken from a pattern string into its
    /// wildcard form where the pattern syntax asks for one:
    ///
    /// * a host label equal to `*` becomes `WildcardDomain`;
    /// * a path segment starting with `:` becomes `WildcardPath`;
    /// * a query value starting with `:` becomes `WildcardQueryParamValue`.
    ///
    /// Every other section is returned unchanged.
    pub fn into_pattern(self) -> Self {
        let wildcard = match &self {
            Sect::Domain(label) => (label.as_str() == "*").then_some(Sect::WildcardDomain),
            Sect::Path(segment) => segment.starts_with(':').then_some(Sect::WildcardPath),
            Sect::QueryParamValue(value) => value
                .starts_with(':')
                .then_some(Sect::WildcardQueryParamValue),
            _ => None,
        };

        wildcard.unwrap_or(self)
    }

    /// Reports whether this edge label accepts the section `input`.
    ///
    /// Identical sections always match. Beyond that, each wildcard accepts
    /// any literal of its own kind and `FreeMove` accepts everything.
    fn matches_input(&self, input: &Self) -> bool {
        if self == input {
            return true;
        }

        matches!(
            (self, input),
            (Sect::FreeMove, _)
                | (Sect::WildcardDomain, Sect::Domain(_))
                | (Sect::WildcardPath, Sect::Path(_))
                | (Sect::WildcardQueryParamValue, Sect::QueryParamValue(_))
        )
    }
}
317
318fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
320 if url.is_empty() {
321 return Err(UrlParseError::EmptyUrl);
322 }
323
324 let mut res = Vec::new();
325
326 if !url.contains(':') {
327 res.push(Sect::Protocol(url.into()));
328 return Ok(res);
329 }
330
331 let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
332 url: url.to_string(),
333 source: e,
334 })?;
335
336 let proto = url.scheme();
337 res.push(Sect::Protocol(proto.into()));
338
339 if let Some(host) = url.host_str() {
340 let mut host_parts: Vec<&str> = host.split('.').rev().collect();
341
342 if (proto == "http" || proto == "https")
343 && host_parts.last().is_some_and(|last| *last == "www")
344 {
345 let _www = host_parts.pop();
347 }
348
349 for part in host_parts {
350 res.push(Sect::Domain(part.into()));
351 }
352 }
353
354 if url.cannot_be_a_base() {
355 res.push(Sect::Path(url.path().into()))
356 } else if let Some(path_parts) = url.path_segments() {
357 for part in path_parts {
358 if part.is_empty() {
359 continue;
360 }
361 res.push(Sect::Path(part.into()));
362 }
363 }
364
365 for (k, v) in url.query_pairs() {
366 res.push(Sect::QueryParamName(k.into()));
367 if !v.is_empty() {
368 res.push(Sect::QueryParamValue(v.into()));
369 }
370 }
371
372 Ok(res)
373}
374
/// Unit tests for registering handlers and resolving URLs against them.
#[cfg(test)]
mod test {
    use super::*;

    // NOTE(review): presumably the crate is `no_std`, which is why `std` is
    // pulled in explicitly for `eprintln!` and `vec!` — confirm at crate root.
    extern crate std;
    use std::{eprintln, vec};

    /// Registers one handler of every kind and checks that each sample URL
    /// resolves to (at least) the expected module.
    #[test]
    fn matching() {
        let mut resolver = Resolver::default();

        resolver.insert_protocol("near", "near").unwrap();
        resolver
            .insert_pattern("near-account", "near://account/:id")
            .unwrap();
        resolver.insert_pattern("near-tx", "near://tx/:id").unwrap();
        resolver
            .insert_prefix("google", "https://google.com/search?q=")
            .unwrap();
        resolver.insert_prefix("x", "https://x.com/").unwrap();
        resolver
            .insert_pattern("linkedin", "https://*.linkedin.com/in/:account/test")
            .unwrap();
        resolver
            .insert_pattern("youtube", "https://youtube.com/watch?v=:v")
            .unwrap();
        resolver
            .insert_pattern("subdomains", "https://*.baz.com/")
            .unwrap();
        resolver.insert_prefix("data", "data:text/plain").unwrap();
        resolver.insert_prefix("fs", "file://").unwrap();
        resolver.insert_prefix("fs2", "file:///2").unwrap();
        resolver.insert_file_extension("txt-ext", "txt").unwrap();
        resolver.insert_file_extension("tar-ext", "tar.gz").unwrap();

        // Dump the graph; only shown when the test fails or with --nocapture.
        eprintln!("{resolver:#?}");

        // (input URL, name of a module that must be among the results)
        let tests = vec![
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in tests {
            // Other modules may match too; only the wanted one is required.
            assert_eq!(
                resolver
                    .resolve(input)
                    .expect("resolve succeeds")
                    .iter()
                    .find(|out| out.name == want)
                    .unwrap_or_else(|| panic!(
                        "the wanted result should be returned, input={input} want={want}"
                    ))
                    .name,
                want
            );
        }
    }

    /// A pattern and a prefix registered for the same URL must stay
    /// independent: adding the prefix must not make the pattern match longer
    /// URLs.
    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        let mut resolver = Resolver::new();

        // Step 1: only the pattern is registered.
        resolver
            .insert_pattern("pattern", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let results = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{results:?}");
        assert!(
            results
                .first()
                .is_some_and(|module| module.name == "pattern"),
            "the pattern should match"
        );

        let results = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{results:?}");
        assert!(results.is_empty(), "the pattern shouldn't be a prefix");

        // Step 2: add a prefix for the very same URL.
        resolver
            .insert_prefix("prefix", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let results = resolver.resolve("https://foobar.com/").unwrap();
        eprintln!("{results:?}");
        assert!(results.len() == 2, "both items should match");

        let results = resolver.resolve("https://foobar.com/more").unwrap();
        eprintln!("{results:?}");
        assert!(results.len() == 1, "only the prefix should match");
        assert!(
            results
                .first()
                .is_some_and(|module| module.name == "prefix"),
            "only the prefix should match"
        );
    }
}