1use crate::ModuleManifest;
4use alloc::{
5 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
6 rc::Rc,
7 string::{String, ToString},
8 vec::Vec,
9};
10use core::{borrow::Borrow, convert::Infallible};
11use error::UrlParseError;
12
13pub mod error;
14
15#[derive(Clone, Debug, Default)]
16pub struct Resolver {
17 modules: BTreeMap<String, Rc<Module>>,
18 file_extensions: BTreeMap<String, Vec<Rc<Module>>>,
19 nodes: slab::Slab<Node>,
20 roots: BTreeMap<Sect, usize>,
21}
22
23impl Resolver {
24 pub fn new() -> Self {
25 Resolver::default()
26 }
27
28 pub fn resolve(&self, url: &str) -> Result<Vec<Rc<Module>>, UrlParseError> {
29 let input = split_url(url)?;
30
31 let mut results: BTreeSet<Rc<Module>> = BTreeSet::new();
32
33 if matches!(input.first(), Some(Sect::Protocol(proto)) if proto == "file") {
34 if let Some(Sect::Path(filename)) = input.last() {
35 if let Some((_, ext)) = filename.split_once(".") {
36 self.file_extensions
37 .get(ext)
38 .into_iter()
39 .flatten()
40 .for_each(|module| {
41 results.insert(module.clone());
42 });
43 }
44 }
45 }
46
47 let with_freemove = |node_idx: usize| {
48 core::iter::once(node_idx)
50 .chain(self.nodes[node_idx].paths.get(&Sect::FreeMove).copied())
52 };
53
54 let start_states: BTreeSet<usize> = self
56 .roots
57 .iter()
58 .filter_map(|(path, &node_idx)| path.matches_input(&input[0]).then_some(node_idx))
59 .collect();
60
61 let final_states = if input.len() == 1 {
62 start_states.into_iter().flat_map(with_freemove).collect()
64 } else {
65 input[1..].iter().fold(start_states, |states, sect| {
67 states
68 .into_iter()
69 .flat_map(|node_idx| &self.nodes[node_idx].paths)
70 .filter_map(|(path, &node_idx)| path.matches_input(sect).then_some(node_idx))
71 .flat_map(with_freemove)
72 .collect()
73 })
74 };
75
76 for &state_idx in &final_states {
78 for module in &self.nodes[state_idx].modules {
79 results.insert(module.clone());
80 }
81 }
82
83 Ok(results.into_iter().collect())
84 }
85
86 pub fn insert_file_extension(
87 &mut self,
88 module: &str,
89 file_extension: &str,
90 ) -> Result<(), Infallible> {
91 let module = self.add_module(module);
92
93 self.file_extensions
94 .entry(file_extension.to_string())
95 .or_default()
96 .push(module);
97
98 Ok(())
99 }
100 pub fn insert_manifest(&mut self, manifest: &ModuleManifest) -> Result<(), UrlParseError> {
101 for protocol in &manifest.handles.url_protocols {
102 self.insert_protocol(&manifest.name, protocol).ok();
103 }
104 for prefix in &manifest.handles.url_prefixes {
105 self.insert_prefix(&manifest.name, prefix)?;
106 }
107 for pattern in &manifest.handles.url_patterns {
108 self.insert_pattern(&manifest.name, pattern)?;
109 }
110 for file_extension in &manifest.handles.file_extensions {
111 self.insert_file_extension(&manifest.name, file_extension)
112 .ok();
113 }
114 Ok(())
115 }
116 pub fn insert_protocol(&mut self, module: &str, protocol: &str) -> Result<(), Infallible> {
117 let path = &[Sect::Protocol(protocol.to_string()), Sect::FreeMove];
118 let module = self.add_module(module);
119 let node_idx = self.get_or_create_node(path);
120
121 self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
123 self.nodes[node_idx].modules.insert(module);
124
125 Ok(())
126 }
127 pub fn insert_prefix(&mut self, module: &str, prefix: &str) -> Result<(), UrlParseError> {
128 let mut path = split_url(prefix)?;
129 path.push(Sect::FreeMove);
132 let module = self.add_module(module);
133 let node_idx = self.get_or_create_node(&path);
134
135 self.nodes[node_idx].paths.insert(Sect::FreeMove, node_idx);
138 self.nodes[node_idx].modules.insert(module);
139
140 Ok(())
141 }
142 pub fn insert_pattern(&mut self, module: &str, pattern: &str) -> Result<(), UrlParseError> {
143 let path: Vec<Sect> = split_url(pattern)?
144 .into_iter()
145 .map(Sect::into_pattern)
146 .collect();
147 let module = self.add_module(module);
148 let node_idx = self.get_or_create_node(&path);
149
150 self.nodes[node_idx].modules.insert(module);
151
152 Ok(())
153 }
154
155 #[cfg(all(feature = "std", feature = "serde"))]
156 pub fn try_from_dir(path: impl AsRef<std::path::Path>) -> Result<Self, error::FromDirError> {
157 use error::FromDirError;
158
159 let path = path.as_ref();
160
161 let dir = std::fs::read_dir(path).map_err(|source| FromDirError::ManifestDirIo {
162 path: path.into(),
163 source,
164 })?;
165
166 let mut resolver = Resolver::new();
167
168 for entry in dir {
169 let entry = entry.map_err(|source| FromDirError::ManifestDirIo {
170 path: path.into(),
171 source,
172 })?;
173 if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
174 continue;
175 }
176 let filename = entry.file_name();
177 let filename = filename.to_string_lossy();
178 if !filename.ends_with(".yaml") && !filename.ends_with(".yml") {
179 continue;
180 }
181 let path = entry.path();
182 let file = std::fs::File::open(&path).map_err(|source| FromDirError::ManifestIo {
183 path: path.clone(),
184 source,
185 })?;
186
187 let manifest =
188 serde_yaml_ng::from_reader(file).map_err(|source| FromDirError::Parse {
189 path: path.clone(),
190 source,
191 })?;
192 resolver
193 .insert_manifest(&manifest)
194 .map_err(|source| FromDirError::Insert {
195 path: path.clone(),
196 source,
197 })?;
198 }
199
200 Ok(resolver)
201 }
202
203 pub fn try_from_iter<I, T>(mut iter: I) -> Result<Self, UrlParseError>
204 where
205 I: Iterator<Item = T>,
206 T: Borrow<ModuleManifest>,
207 {
208 iter.try_fold(Resolver::default(), |mut r, m| {
209 r.insert_manifest(m.borrow())?;
210 Ok(r)
211 })
212 }
213
214 fn get_or_create_node(&mut self, path: &[Sect]) -> usize {
215 let root_idx = *self
217 .roots
218 .entry(path[0].clone())
219 .or_insert_with(|| self.nodes.insert(Node::default()));
220
221 path[1..].iter().fold(root_idx, |cur_idx, sect| {
222 match (self.nodes[cur_idx].paths.get(sect), sect) {
223 (Some(&idx), _sect) => idx,
224 (None, Sect::WildcardDomain) => {
225 self.nodes[cur_idx].paths.insert(sect.clone(), cur_idx);
227 cur_idx
228 },
229 (None, sect) => {
230 let new_node_idx = self.nodes.insert(Node::default());
232
233 self.nodes[cur_idx].paths.insert(sect.clone(), new_node_idx);
235 new_node_idx
236 },
237 }
238 })
239 }
240
241 fn add_module(&mut self, name: &str) -> Rc<Module> {
242 let name = name.to_string();
243 self.modules
244 .entry(name.clone())
245 .or_insert_with(|| Rc::new(Module { name }))
246 .clone()
247 }
248}
249
250impl TryFrom<&[ModuleManifest]> for Resolver {
251 type Error = UrlParseError;
252
253 fn try_from(value: &[ModuleManifest]) -> Result<Self, Self::Error> {
254 Resolver::try_from_iter(value.iter())
255 }
256}
257
/// A module that can handle URLs, identified by its name.
///
/// Equality and ordering compare the name only.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub struct Module {
    /// The module's name, as passed to the `Resolver::insert_*` methods.
    pub name: String,
}
262
263#[derive(Clone, Debug, Default)]
264struct Node {
265 paths: BTreeMap<Sect, usize>,
266 modules: BTreeSet<Rc<Module>>,
267}
268
/// One section of a split URL, or a pattern element matched against one.
///
/// The variant order matters: the derived `Ord` (used as `BTreeMap` key
/// order) follows declaration order.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Sect {
    /// The URL scheme, e.g. `https`; always the first section.
    Protocol(String),
    /// One host label; `split_url` emits labels right-to-left
    /// (`com` before `example`).
    Domain(String),
    /// Pattern form of a `*` host label; matches any `Domain`.
    WildcardDomain,
    /// One non-empty path segment, or the whole path of a cannot-be-a-base
    /// URL such as `data:`.
    Path(String),
    /// Pattern form of a path segment starting with `:`; matches any `Path`.
    WildcardPath,
    /// The (decoded) name of a query parameter.
    QueryParamName(String),
    /// The (decoded) value of a query parameter; only emitted when non-empty.
    QueryParamValue(String),
    /// Pattern form of a query value starting with `:`; matches any
    /// `QueryParamValue`.
    WildcardQueryParamValue,
    /// Pattern element that matches any section; used to let protocols and
    /// prefixes absorb the rest of a URL.
    FreeMove,
}
290
291impl Sect {
292 pub fn into_pattern(self) -> Self {
297 match self {
298 Sect::Domain(p) if p == "*" => Sect::WildcardDomain,
299 Sect::Path(p) if p.starts_with(':') => Sect::WildcardPath,
300 Sect::QueryParamValue(p) if p.starts_with(':') => Sect::WildcardQueryParamValue,
301 _ => self,
302 }
303 }
304
305 fn matches_input(&self, input: &Self) -> bool {
306 use Sect::*;
307 match (self, input) {
308 (a, b) if a == b => true,
309 (WildcardDomain, Domain(_)) => true,
310 (WildcardPath, Path(_)) => true,
311 (WildcardQueryParamValue, QueryParamValue(_)) => true,
312 (FreeMove, _) => true,
314 _ => false,
315 }
316 }
317}
318
319fn split_url(url: &str) -> Result<Vec<Sect>, UrlParseError> {
321 if url.is_empty() {
322 return Err(UrlParseError::EmptyUrl);
323 }
324
325 let mut res = Vec::new();
326
327 if !url.contains(':') {
328 res.push(Sect::Protocol(url.into()));
329 return Ok(res);
330 }
331
332 let url: url::Url = url.parse().map_err(|e| UrlParseError::InvalidUrl {
333 url: url.to_string(),
334 source: e,
335 })?;
336
337 let proto = url.scheme();
338 res.push(Sect::Protocol(proto.into()));
339
340 if let Some(host) = url.host_str() {
341 let mut host_parts: Vec<&str> = host.split('.').rev().collect();
342
343 if (proto == "http" || proto == "https")
344 && host_parts.last().is_some_and(|last| *last == "www")
345 {
346 let _www = host_parts.pop();
348 }
349
350 for part in host_parts {
351 res.push(Sect::Domain(part.into()));
352 }
353 }
354
355 if url.cannot_be_a_base() {
356 res.push(Sect::Path(url.path().into()))
357 } else if let Some(path_parts) = url.path_segments() {
358 for part in path_parts {
359 if part.is_empty() {
360 continue;
361 }
362 res.push(Sect::Path(part.into()));
363 }
364 }
365
366 for (k, v) in url.query_pairs() {
367 res.push(Sect::QueryParamName(k.into()));
368 if !v.is_empty() {
369 res.push(Sect::QueryParamValue(v.into()));
370 }
371 }
372
373 Ok(res)
374}
375
#[cfg(test)]
mod test {
    use super::*;

    extern crate std;
    use std::eprintln;

    /// Asserts that resolving `input` yields, among others, the module named `want`.
    fn assert_resolves_to(resolver: &Resolver, input: &str, want: &str) {
        let modules = resolver.resolve(input).expect("resolve succeeds");
        let found = modules
            .iter()
            .find(|module| module.name == want)
            .unwrap_or_else(|| {
                panic!("the wanted result should be returned, input={input} want={want}")
            });
        assert_eq!(found.name, want);
    }

    #[test]
    fn matching() {
        let mut resolver = Resolver::default();

        resolver.insert_protocol("near", "near").unwrap();
        resolver
            .insert_pattern("near-account", "near://account/:id")
            .unwrap();
        resolver.insert_pattern("near-tx", "near://tx/:id").unwrap();
        resolver
            .insert_prefix("google", "https://google.com/search?q=")
            .unwrap();
        resolver.insert_prefix("x", "https://x.com/").unwrap();
        resolver
            .insert_pattern("linkedin", "https://*.linkedin.com/in/:account/test")
            .unwrap();
        resolver
            .insert_pattern("youtube", "https://youtube.com/watch?v=:v")
            .unwrap();
        resolver
            .insert_pattern("subdomains", "https://*.baz.com/")
            .unwrap();
        resolver.insert_prefix("data", "data:text/plain").unwrap();
        resolver.insert_prefix("fs", "file://").unwrap();
        resolver.insert_prefix("fs2", "file:///2").unwrap();
        resolver.insert_file_extension("txt-ext", "txt").unwrap();
        resolver.insert_file_extension("tar-ext", "tar.gz").unwrap();

        eprintln!("{resolver:#?}");

        let cases = [
            ("near", "near"),
            ("near://tx/1234", "near-tx"),
            ("near://account/1234", "near-account"),
            ("near://other/1234", "near"),
            ("https://google.com/search?q=foobar", "google"),
            ("https://x.com/foobar", "x"),
            ("https://www.linkedin.com/in/foobar/test", "linkedin"),
            ("https://youtube.com/watch?v=foobar", "youtube"),
            ("https://multiple.subdomains.foo.bar.baz.com/", "subdomains"),
            ("data:text/plain?Hello+World", "data"),
            ("file:///foo/bar/baz", "fs"),
            ("file:///2/foo", "fs2"),
            ("file:///foobar.txt", "txt-ext"),
            ("file:///foobar.tar.gz", "tar-ext"),
        ];

        for (input, want) in cases {
            assert_resolves_to(&resolver, input, want);
        }
    }

    #[test]
    fn prefix_doesnt_turn_pattern_to_prefix() {
        // Resolves `url` and logs the outcome, mirroring the debug output of each step.
        let resolve_logged = |resolver: &Resolver, url: &str| {
            let results = resolver.resolve(url).unwrap();
            eprintln!("{results:?}");
            results
        };

        let mut resolver = Resolver::new();

        resolver
            .insert_pattern("pattern", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let exact = resolve_logged(&resolver, "https://foobar.com/");
        assert!(
            exact.first().is_some_and(|module| module.name == "pattern"),
            "the pattern should match"
        );

        let longer = resolve_logged(&resolver, "https://foobar.com/more");
        assert!(longer.is_empty(), "the pattern shouldn't be a prefix");

        resolver
            .insert_prefix("prefix", "https://foobar.com/")
            .unwrap();
        eprintln!("{resolver:#?}");

        let exact = resolve_logged(&resolver, "https://foobar.com/");
        assert!(exact.len() == 2, "both items should match");

        let longer = resolve_logged(&resolver, "https://foobar.com/more");
        assert!(longer.len() == 1, "only the prefix should match");
        assert!(
            longer.first().is_some_and(|module| module.name == "prefix"),
            "only the prefix should match"
        );
    }
}