1use crate::{
2	extension::{
3		opaque_fn::OpaqueFunctionCall,
4		repository::{self, Repository},
5		Call, CallContext, CallNotFoundError, Dependency, Extension, ExtensionContext,
6	},
7	util::State,
8};
9use anyhow::Result;
10use async_recursion::async_recursion;
11use semver::{Version, VersionReq};
12use std::{
13	collections::{HashMap, HashSet},
14	sync::Arc,
15};
16use thiserror::Error;
17use tokio::sync::Mutex;
18
19use super::{extension_impl, extension_protocol::ExtensionProtocol, repository_context::RepositoryContext};
20
21pub struct ExtensionRepository {
22	locked: Arc<Mutex<bool>>,
23
24	all_extensions: Arc<Mutex<HashMap<&'static str, &'static Extension>>>,
27
28	activated_extensions: Arc<Mutex<Vec<&'static str>>>,
30
31	extension_id_to_version: Arc<Mutex<HashMap<&'static str, &'static str>>>,
33
34	extension_dependencies_resolved: Arc<Mutex<HashMap<&'static str, HashSet<&'static str>>>>,
41
42	extension_dependencies_expected: Arc<Mutex<HashMap<&'static str, HashMap<&'static str, (&'static str, bool)>>>>,
46
47	extensions_dependents_expected: Arc<Mutex<HashMap<&'static str, Vec<&'static str>>>>,
49
50	version_mismatches: Arc<Mutex<Vec<(&'static str, &'static str, &'static str, &'static str, &'static str)>>>,
52
53	extension_states: Arc<Mutex<HashMap<&'static str, Arc<Mutex<State>>>>>,
54
55	extension_calls: Arc<Mutex<HashMap<&'static str, Arc<OpaqueFunctionCall>>>>,
57}
58
59impl<'a> ExtensionRepository {
60	fn construct() -> ExtensionRepository {
61		ExtensionRepository {
62			locked: Arc::new(Mutex::new(false)),
63			all_extensions: Arc::new(Mutex::new(HashMap::new())),
64			activated_extensions: Arc::new(Mutex::new(Vec::new())),
65			extension_id_to_version: Arc::new(Mutex::new(HashMap::new())),
66			extension_dependencies_resolved: Arc::new(Mutex::new(HashMap::new())),
67			extension_dependencies_expected: Arc::new(Mutex::new(HashMap::new())),
68			extensions_dependents_expected: Arc::new(Mutex::new(HashMap::new())),
69			version_mismatches: Arc::new(Mutex::new(Vec::new())),
70			extension_states: Arc::new(Mutex::new(HashMap::new())),
71			extension_calls: Arc::new(Mutex::new(HashMap::new())),
72		}
73	}
74
75	fn init_state(&self) -> State {
76		let mut state = State::default();
77		state.put(RepositoryContext::new(Arc::clone(&self.extension_calls)));
78		state
79	}
80
81	async fn init(&self) {
82		self.extension_states
83			.lock()
84			.await
85			.insert("", Arc::new(Mutex::new(self.init_state())));
86		self.extension_calls.lock().await.insert(
87			repository::ADD_CALL.id,
88			Arc::new(OpaqueFunctionCall::from(&extension_impl::add_call)),
89		);
90	}
91
92	pub async fn new() -> ExtensionRepository {
93		let repository = Self::construct();
94		repository.init().await;
95		repository
96	}
97
98	async fn get_extension_version_for(&self, extension_name: &'static str) -> Option<&'static str> {
101		self.extension_id_to_version
102			.lock()
103			.await
104			.get(extension_name)
105			.map(|extension_version| *extension_version)
106	}
107
108	pub async fn add(&self, extension: &'static Extension) -> Result<()> {
116		self.try_insert_extension(extension).await
117	}
118
119	pub async fn print_problems(&self) {
120		for (extension, dependencies) in self.extension_dependencies_expected.lock().await.iter() {
121			if dependencies.len() == 0 {
122				continue;
123			}
124			let mut missing: Vec<&'static str> = Vec::new();
125			for (dependency, (_, is_required)) in dependencies {
126				if *is_required {
127					missing.push(dependency);
128				}
129			}
130			crate::critical!(
131				"Extension '{}@{}' was not activated, missing '{}'",
132				extension,
133				self.get_extension_version_for(extension).await.unwrap(),
134				missing.join("', '"),
135			)
136		}
137	}
138
139	pub async fn inject(&self, _extension: &'static Extension) -> Result<()> {
145		if *self.locked.lock().await {
146			return Err(ExtensionInstallationError::Locked.into());
147		}
148
149		todo!("James Bradlee: Implement inject")
150	}
151
152	async fn unsafely_insert_extension(&self, extension: &'static Extension) {
155		self.all_extensions.lock().await.insert(extension.name, extension);
156		self.extension_id_to_version
157			.lock()
158			.await
159			.insert(extension.name, extension.version);
160	}
161
162	async fn try_insert_extension(&self, extension: &'static Extension) -> Result<()> {
164		if *self.locked.lock().await {
165			return Err(ExtensionInstallationError::Locked.into());
166		}
167
168		if let Some(version) = self.get_extension_version_for(extension.name).await {
169			if version == extension.version {
170				return Ok(());
172			}
173			return Err(
175				ExtensionInstallationError::ExtensionAlreadyAdded(extension.name, version, extension.version).into(),
176			);
177		}
178
179		self.unsafely_insert_extension(extension).await;
180
181		self.resolve(extension).await?;
182
183		Ok(())
184	}
185
186	async fn resolve(&self, extension: &'static Extension) -> Result<()> {
187		let mut all_names: HashSet<&'static str> = HashSet::new();
188
189		let mut has_problems = false;
190		let mut pending_dependencies: HashMap<&'static str, (&'static str, bool)> = HashMap::new();
191		let mut solved_dependencies: HashSet<&'static str> = HashSet::new();
192		let mut pending_dependency_names: Vec<&'static str> = Vec::new();
193
194		for dependency in extension.dependencies {
195			let (is_required, name, version_matcher) = match dependency {
196				Dependency::Optional(name, version_matcher) => (false, *name, *version_matcher),
197				Dependency::Required(name, version_matcher) => (true, *name, *version_matcher),
198			};
199			if all_names.contains(name) {
200				return Err(
201					ExtensionInstallationError::DuplicateDependency(extension.name, extension.version, name).into(),
202				);
203			}
204			all_names.insert(name);
205
206			if !self.activated_extensions.lock().await.contains(&name) {
207				pending_dependencies.insert(name, (version_matcher, is_required));
208				pending_dependency_names.push(name);
209			} else {
210				if let Some(received_version) = self.match_dependency(name, version_matcher).await? {
211					if is_required {
212						has_problems = true;
213						self.version_mismatches.lock().await.push((
214							extension.name,
215							extension.version,
216							name,
217							received_version,
218							version_matcher,
219						));
220						crate::warn!(
221							"Extension '{}@{}' expected version '{}' from required dependency '{}' (but got '{}') - extension will not be initialized",
222							extension.name,
223							extension.version,
224							version_matcher,
225							name,
226							received_version
227						);
228					} else {
229						crate::warn!(
230							"Extension '{}@{}' expected version '{}' from optional dependency '{}' (but got '{}')",
231							extension.name,
232							extension.version,
233							version_matcher,
234							name,
235							received_version
236						);
237					}
238				} else {
239					solved_dependencies.insert(name);
240				}
241			}
242		}
243
244		{
245			let mut reverse = self.extensions_dependents_expected.lock().await;
246			for name in pending_dependency_names {
247				if let Some(lookup) = reverse.get_mut(name) {
248					lookup.push(extension.name);
249				} else {
250					reverse.insert(name, vec![extension.name]);
251				}
252			}
253		}
254
255		let has_pending = pending_dependencies.len() > 0;
256
257		self.extension_dependencies_expected
258			.lock()
259			.await
260			.insert(extension.name, pending_dependencies);
261
262		self.extension_dependencies_resolved
263			.lock()
264			.await
265			.insert(extension.name, solved_dependencies);
266
267		if !has_pending && !has_problems {
268			self.complete(extension).await?;
269		}
270
271		Ok(())
272	}
273
274	#[async_recursion(?Send)]
275	async fn complete(&self, extension: &'static Extension) -> Result<()> {
276		crate::debug!("[repository] Completing {}@{}", extension.name, extension.version);
277		self.activate_extension(extension).await?;
278		crate::debug!(
279			"[repository] Initialized {}@{} - now resolving dependents",
280			extension.name,
281			extension.version
282		);
283
284		let mut extensions_to_complete: Vec<&'static Extension> = Vec::new();
285
286		if let Some(dependents) = self.extensions_dependents_expected.lock().await.remove(extension.name) {
287			for dependent in dependents {
288				crate::debug!(
289					"[repository] from {}@{} resolving {}",
290					extension.name,
291					extension.version,
292					dependent
293				);
294				let mut has_problems = false;
295				let mut should_complete = false;
296
297				{
298					let dependent_version = self.get_extension_version_for(dependent).await.unwrap();
299					let mut expected = self.extension_dependencies_expected.lock().await;
300					let deps_dependencies = expected.get_mut(dependent).unwrap();
301					let size = deps_dependencies.len();
302					let (version_matcher, is_required) = deps_dependencies.remove(extension.name).unwrap();
303
304					if self.match_dependency(extension.name, version_matcher).await?.is_some() {
305						if is_required {
306							has_problems = true;
307							self.version_mismatches.lock().await.push((
308								dependent,
309								dependent_version,
310								extension.name,
311								extension.version,
312								version_matcher,
313							));
314						}
315					} else {
316						self.extension_dependencies_resolved
317							.lock()
318							.await
319							.get_mut(dependent)
320							.unwrap()
321							.insert(extension.name);
322						if size == 1 {
323							should_complete = true;
324						}
325					}
326				}
327
328				if !has_problems && should_complete {
329					extensions_to_complete.push(self.all_extensions.lock().await.get(dependent).unwrap());
330				}
331			}
332		}
333
334		for ext in extensions_to_complete {
335			self.complete(ext).await?;
336		}
337
338		crate::debug!("[repository] Completed {}@{}", extension.name, extension.version);
339
340		Ok(())
341	}
342
343	async fn match_dependency(
344		&self,
345		dependency_name: &'static str,
346		version_match: &'static str,
347	) -> Result<Option<&'static str>> {
348		if let Some(received_version) = self.extension_id_to_version.lock().await.get(dependency_name) {
349			let expected_version_match = VersionReq::parse(version_match)?;
350			let received_version_semver = Version::parse(received_version)?;
351			if expected_version_match.matches(&received_version_semver) {
352				Ok(None)
353			} else {
354				Ok(Some(received_version))
355			}
356		} else {
357			Err(ExtensionInstallationError::ExtensionNotFound(dependency_name).into())
358		}
359	}
360
361	async fn activate_extension(&'a self, extension: &'static Extension) -> Result<()> {
362		let state = Arc::new(Mutex::new(State::default()));
363
364		self.extension_states
365			.lock()
366			.await
367			.insert(extension.name, Arc::clone(&state));
368
369		(extension.init)(ExtensionContext::new(
370			Arc::clone(&state),
371			extension.name,
372			Repository(Box::pin(ExtensionProtocol::new(
373				extension.name,
374				Arc::clone(&self.extension_states),
375				Arc::clone(&self.extension_dependencies_resolved),
376				Arc::clone(&self.extension_calls),
377			))),
378		))
379		.await?;
380
381		self.activated_extensions.lock().await.push(extension.name);
382		Ok(())
383	}
384
385	pub async fn lock(&self) -> Result<()> {
386		{
387			let mut locked = self.locked.lock().await;
388			if *locked {
389				return Ok(());
390			}
391
392			*locked = true;
393		}
394		Ok(())
397	}
398
399	pub async fn call<Argument, Return>(
400		&self,
401		call: &'static Call<Argument, Return>,
402		argument: Argument,
403	) -> Result<Return> {
404		let state = Arc::clone(
405			self.extension_states
406				.lock()
407				.await
408				.get(call.owner)
409				.expect("should never happen"),
410		);
411		let handler = if let Some(fun) = self.extension_calls.lock().await.get(call.id) {
412			Arc::clone(fun)
413		} else {
414			return Err(CallNotFoundError("host", call.id).into());
415		};
416
417		unsafe {
418			handler.invoke(CallContext {
419				state,
420				caller: "",
421				argument,
422			})
423		}
424		.await
425	}
426}
427
428#[derive(Error, Debug)]
429pub enum ExtensionInstallationError {
430	#[error("Repository is locked")]
431	Locked,
432
433	#[error("Extension '{0}' has already been added with version '{1}', trying to add '{2}'!")]
439	ExtensionAlreadyAdded(&'static str, &'static str, &'static str),
440
441	#[error("Extension '{0}@{1}' is missing a required dependency '{2}@{3}'")]
449	MissingDependency(&'static str, &'static str, &'static str, &'static str),
450
451	#[error("Extension '{0}@{1}' expected dependency '{2}' with a version in range(s) '{4}', but got '{3}'")]
461	VersionMismatch(&'static str, &'static str, &'static str, &'static str, &'static str),
462
463	#[error("Extension '{0}@{1}' contains a duplicate dependency of '{2}'")]
468	DuplicateDependency(&'static str, &'static str, &'static str),
469
470	#[error("Extension '{0}' not found!")]
475	ExtensionNotFound(&'static str),
476}