1use std::any::TypeId;
4
5use crate::dependency_builder::{self, DepBuilder};
6use crate::types::{
7 HashMap, NonAsyncRwLock, Registerable, RegisterableSingleton, Visitor,
8};
9
/// Summary error returned by the cached validation paths
/// (`validate_all`, `validate`, `dotgraph`).
///
/// `Eq` is derived alongside `Hash` so the error can be used as a key in
/// hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ValidationError {
    /// The dependency graph contains a cycle.
    Cycle,
    /// At least one registered type depends on a type that was never
    /// registered.
    Missing,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Cycle => "cycle detected!",
            Self::Missing => "dependencies missing!",
        };
        fmt.write_str(message)
    }
}

// `source` keeps its default (`None`): this error wraps no underlying cause.
impl std::error::Error for ValidationError {}
35
/// Detailed error returned by a full (uncached) validation.
///
/// Derives `Eq` for consistency with [`MissingDependencies`], which it
/// contains, and so it can be used as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FullValidationError {
    /// A cycle was detected; holds the name of a node on the cycle, or `None`
    /// if that node's name could not be resolved.
    Cycle(Option<String>),
    /// One entry per registered type that has unregistered dependencies.
    Missing(Vec<MissingDependencies>),
}

impl std::fmt::Display for FullValidationError {
    // `TypeId` only implements `Debug`, so it has to be printed with `{:?}`.
    #[allow(clippy::use_debug)]
    fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cycle(Some(node)) => write!(fmt, "cycle detected at {node}"),
            Self::Cycle(None) => write!(fmt, "cycle detected!"),
            Self::Missing(all_missing) => {
                writeln!(fmt, "dependencies missing:")?;

                for missing in all_missing {
                    let (type_id, type_name) = missing.ty;
                    writeln!(fmt, "dependencies missing for {type_name} ({type_id:?}):")?;
                    for (dep_id, dep_name) in &missing.deps {
                        writeln!(fmt, " - {dep_name} ({dep_id:?})")?;
                    }
                    // NOTE(review): `writeln!` with "\n" emits two newlines, so
                    // entries are separated by two blank lines. Kept as-is so
                    // the message text does not change.
                    writeln!(fmt, "\n")?;
                }

                Ok(())
            }
        }
    }
}

// `source` keeps its default (`None`): this error wraps no underlying cause.
impl std::error::Error for FullValidationError {}

/// The unregistered dependencies of one registered type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct MissingDependencies {
    /// `TypeId` and type name of the type whose dependencies are missing.
    pub(crate) ty: (TypeId, &'static str),
    /// `TypeId` and type name of every dependency that was not registered.
    pub(crate) deps: Vec<(TypeId, &'static str)>,
}

impl MissingDependencies {
    /// Returns the `TypeId` and type name of the type with missing
    /// dependencies.
    #[must_use]
    pub fn ty(&self) -> &(TypeId, &'static str) {
        &self.ty
    }

    /// Returns the `TypeId`s and type names of the dependencies that were
    /// not registered.
    #[must_use]
    pub fn missing_dependencies(&self) -> &[(TypeId, &'static str)] {
        &self.deps
    }
}
105
106pub(crate) struct DependencyValidator {
108 visitor: NonAsyncRwLock<HashMap<TypeId, Visitor>>,
111 context: NonAsyncRwLock<VisitorContext>,
113}
114
115impl DependencyValidator {
116 pub(crate) fn new() -> Self {
118 Self {
119 visitor: NonAsyncRwLock::new(HashMap::new()),
120 context: NonAsyncRwLock::new(VisitorContext::new()),
121 }
122 }
123
124 pub(crate) fn add_transient_no_deps<T>(&self)
126 where
127 T: Registerable,
128 {
129 let visitor = Visitor(|_this, _visitors, context| {
130 if let Some(index) = context.visited.get(&TypeId::of::<T>()) {
131 return *index;
132 }
133
134 let index = context.graph.add_node(std::any::type_name::<T>());
135
136 context.visited.insert(TypeId::of::<T>(), index);
137
138 index
139 });
140
141 {
142 let mut visitors = self.visitor.write();
143 visitors.insert(TypeId::of::<T>(), visitor);
144 {
145 let mut context = self.context.write();
146 context.reset();
147 }
148 }
149 }
150
151 pub(crate) fn add_singleton_no_deps<T>(&self)
153 where
154 T: RegisterableSingleton,
155 {
156 self.add_transient_no_deps::<T>();
157 }
158
159 pub(crate) fn add_transient_deps<
161 T: Registerable,
162 #[cfg(not(feature = "tokio"))] Deps: DepBuilder<T> + 'static,
163 #[cfg(feature = "tokio")] Deps: DepBuilder<T> + Sync + 'static,
164 >(
165 &self,
166 ) {
167 let visitor = Visitor(|this, visitors, context| {
168 if let Some(index) = context.visited.get(&TypeId::of::<T>()) {
170 return *index;
171 }
172
173 let current = context.graph.add_node(std::any::type_name::<T>());
174
175 {
177 context.visited.insert(TypeId::of::<T>(), current);
178 }
179
180 let type_ids =
181 Deps::as_typeids(dependency_builder::private::SealToken);
182
183 for (type_id, type_name) in &type_ids {
184 if let Some(index) = context.visited.get(type_id) {
186 context.graph.add_edge(current, *index, ());
187 continue;
188 }
189
190 if let Some(visitor) = visitors.get(type_id) {
192 let index = (visitor.0)(this, visitors, context);
193 context.graph.add_edge(current, index, ());
194 continue;
195 }
196
197 {
198 if let Some(ty) =
199 context.missing.get_mut(&TypeId::of::<T>())
200 {
201 ty.deps.push((*type_id, type_name));
202 } else {
203 context.missing.insert(
204 TypeId::of::<T>(),
205 MissingDependencies {
206 ty: (
207 TypeId::of::<T>(),
208 std::any::type_name::<T>(),
209 ),
210 deps: vec![(*type_id, type_name)],
211 },
212 );
213 }
214 }
215
216 #[cfg(feature = "tracing")]
217 tracing::warn!(
218 "couldn't add dependency of {}: {type_name}",
219 std::any::type_name::<T>()
220 );
221 }
222
223 current
224 });
225
226 {
227 let mut visitors = self.visitor.write();
228 visitors.insert(TypeId::of::<T>(), visitor);
229 {
230 let mut context = self.context.write();
231 context.reset();
232 }
233 }
234 }
235
236 pub(crate) fn add_singleton_deps<
238 T: RegisterableSingleton,
239 #[cfg(not(feature = "tokio"))] Deps: DepBuilder<T> + 'static,
240 #[cfg(feature = "tokio")] Deps: DepBuilder<T> + Sync + 'static,
241 >(
242 &self,
243 ) {
244 self.add_transient_deps::<T, Deps>();
245 }
246
247 pub(crate) fn validate_all(&self) -> Result<(), ValidationError> {
250 let read_context = self.context.read();
251 if Self::validate_context(&read_context)? {
252 return Ok(());
254 }
255
256 drop(read_context);
259 let visitors = self.visitor.read();
260 let mut write_context = self.context.write();
261 if Self::validate_context(&write_context)? {
262 return Ok(());
265 }
266
267 self.calculate_validation(&visitors, &mut write_context);
269
270 Self::validate_context(&write_context)?;
272
273 Ok(())
274 }
275
276 pub(crate) fn validate_all_full(&self) -> Result<(), FullValidationError> {
279 let mut context = VisitorContext::new();
280 {
281 let visitors = self.visitor.read();
282 self.calculate_validation(&visitors, &mut context);
283 }
284
285 if !context.missing.is_empty() {
291 let mut vec = Vec::with_capacity(context.missing.len());
292 context.missing.iter().for_each(|(_, ty)| {
293 vec.push(ty.clone());
294 });
295 return Err(FullValidationError::Missing(vec));
296 }
297
298 if let Some(cached) = &context.validation_cache {
299 return match cached {
300 Ok(_) => Ok(()),
301 Err(err) => {
302 let index = err.node_id();
303 let node_name = context.graph.node_weight(index);
304 return Err(FullValidationError::Cycle(
305 node_name.map(|el| (*el).to_owned()),
306 ));
307 }
308 };
309 }
310
311 unreachable!("this is a bug")
312 }
313
314 fn validate_context(
320 context: &VisitorContext,
321 ) -> Result<bool, ValidationError> {
322 if !context.missing.is_empty() {
323 return Err(ValidationError::Missing);
324 }
325
326 if let Some(cached) = &context.validation_cache {
327 return match cached {
328 Ok(_) => Ok(true),
329 Err(_) => Err(ValidationError::Cycle),
330 };
331 }
332
333 Ok(false)
334 }
335
336 fn calculate_validation(
338 &self,
339 visitors: &HashMap<TypeId, Visitor>,
340 context: &mut VisitorContext,
341 ) {
342 {
343 for cb in visitors.values() {
344 (cb.0)(self, visitors, context);
347 }
348 }
349
350 let mut space = petgraph::algo::DfsSpace::new(&context.graph);
352 context.validation_cache =
353 Some(petgraph::algo::toposort(&context.graph, Some(&mut space)));
354 }
355
356 pub(crate) fn validate<T>(&self) -> Result<(), ValidationError>
358 where
359 T: Registerable,
360 {
361 let _ = std::marker::PhantomData::<T>;
362 self.validate_all()
363 }
364
365 pub(crate) fn dotgraph(&self) -> Result<String, ValidationError> {
367 self.validate_all()?;
368
369 let context = self.context.read();
370 let dot = petgraph::dot::Dot::with_config(
371 &context.graph,
372 &[petgraph::dot::Config::EdgeNoLabel],
373 );
374
375 Ok(format!("{dot:?}"))
376 }
377}
378
379pub(crate) struct VisitorContext {
381 graph: petgraph::Graph<&'static str, (), petgraph::Directed>,
383 missing: HashMap<TypeId, MissingDependencies>,
385 visited: HashMap<TypeId, petgraph::graph::NodeIndex>,
387 validation_cache: Option<
389 Result<
390 Vec<petgraph::graph::NodeIndex>,
391 petgraph::algo::Cycle<petgraph::graph::NodeIndex>,
392 >,
393 >,
394}
395
396impl VisitorContext {
397 pub fn new() -> Self {
399 Self {
400 graph: petgraph::Graph::new(),
401 missing: HashMap::new(),
402 visited: HashMap::new(),
403 validation_cache: None,
404 }
405 }
406
407 pub fn reset(&mut self) {
409 self.graph.clear();
410 self.missing.clear();
411 self.visited.clear();
412 self.validation_cache = None;
413 }
414}