1use crate::{App, AppContext, GpuiBorrow, VisualContext, Window, seal::Sealed};
2use anyhow::{Context as _, Result};
3use collections::FxHashSet;
4use derive_more::{Deref, DerefMut};
5use parking_lot::{RwLock, RwLockUpgradableReadGuard};
6use slotmap::{KeyData, SecondaryMap, SlotMap};
7use std::{
8 any::{Any, TypeId, type_name},
9 cell::RefCell,
10 cmp::Ordering,
11 fmt::{self, Display},
12 hash::{Hash, Hasher},
13 marker::PhantomData,
14 num::NonZeroU64,
15 sync::{
16 Arc, Weak,
17 atomic::{AtomicU64, AtomicUsize, Ordering::SeqCst},
18 },
19 thread::panicking,
20};
21
22use super::Context;
23use crate::util::atomic_incr_if_not_zero;
24#[cfg(any(test, feature = "leak-detection"))]
25use collections::HashMap;
26
slotmap::new_key_type! {
    /// Identifier of an entity within an [`EntityMap`].
    ///
    /// This is a `slotmap` key: ids returned by `EntityMap::reserve` are allocated by
    /// the ref-count slot map in `EntityRefCounts`.
    pub struct EntityId;
}
31
32impl From<u64> for EntityId {
33 fn from(value: u64) -> Self {
34 Self(KeyData::from_ffi(value))
35 }
36}
37
38impl EntityId {
39 pub fn as_non_zero_u64(self) -> NonZeroU64 {
41 NonZeroU64::new(self.0.as_ffi()).unwrap()
42 }
43
44 pub fn as_u64(self) -> u64 {
46 self.0.as_ffi()
47 }
48}
49
50impl Display for EntityId {
51 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
52 write!(f, "{}", self.as_u64())
53 }
54}
55
/// Storage for all entity values, plus the shared strong ref-count table.
pub(crate) struct EntityMap {
    // Type-erased entity values keyed by id. An entry is absent while the entity
    // is leased out via `lease` (and for ids that were reserved but not yet inserted).
    entities: SecondaryMap<EntityId, Box<dyn Any>>,
    // Ids touched by `insert`, `lease` and `read` since the last `clear_accessed`.
    pub accessed_entities: RefCell<FxHashSet<EntityId>>,
    // Ref-count table; every handle holds a `Weak` to this same allocation.
    ref_counts: Arc<RwLock<EntityRefCounts>>,
}
61
/// Strong reference counts for every entity, shared between the [`EntityMap`]
/// (which owns the `Arc`) and all handles (which hold `Weak`s).
#[doc(hidden)]
pub(crate) struct EntityRefCounts {
    // One strong-handle counter per entity; this slot map also allocates `EntityId`s.
    counts: SlotMap<EntityId, AtomicUsize>,
    // Ids whose count reached zero, waiting to be reclaimed by `EntityMap::take_dropped`.
    dropped_entity_ids: Vec<EntityId>,
    #[cfg(any(test, feature = "leak-detection"))]
    leak_detector: LeakDetector,
}
69
70impl EntityMap {
71 pub fn new() -> Self {
72 Self {
73 entities: SecondaryMap::new(),
74 accessed_entities: RefCell::new(FxHashSet::default()),
75 ref_counts: Arc::new(RwLock::new(EntityRefCounts {
76 counts: SlotMap::with_key(),
77 dropped_entity_ids: Vec::new(),
78 #[cfg(any(test, feature = "leak-detection"))]
79 leak_detector: LeakDetector {
80 next_handle_id: 0,
81 entity_handles: HashMap::default(),
82 },
83 })),
84 }
85 }
86
87 #[doc(hidden)]
88 pub fn ref_counts_drop_handle(&self) -> Arc<RwLock<EntityRefCounts>> {
89 self.ref_counts.clone()
90 }
91
92 #[cfg(any(test, feature = "leak-detection"))]
98 pub fn leak_detector_snapshot(&self) -> LeakDetectorSnapshot {
99 self.ref_counts.read().leak_detector.snapshot()
100 }
101
102 #[cfg(any(test, feature = "leak-detection"))]
106 pub fn assert_no_new_leaks(&self, snapshot: &LeakDetectorSnapshot) {
107 self.ref_counts
108 .read()
109 .leak_detector
110 .assert_no_new_leaks(snapshot)
111 }
112
113 pub fn reserve<T: 'static>(&self) -> Slot<T> {
115 let id = self.ref_counts.write().counts.insert(1.into());
116 Slot(Entity::new(id, Arc::downgrade(&self.ref_counts)))
117 }
118
119 pub fn insert<T>(&mut self, slot: Slot<T>, entity: T) -> Entity<T>
121 where
122 T: 'static,
123 {
124 let mut accessed_entities = self.accessed_entities.get_mut();
125 accessed_entities.insert(slot.entity_id);
126
127 let handle = slot.0;
128 self.entities.insert(handle.entity_id, Box::new(entity));
129 handle
130 }
131
132 #[track_caller]
134 pub fn lease<T>(&mut self, pointer: &Entity<T>) -> Lease<T> {
135 self.assert_valid_context(pointer);
136 let mut accessed_entities = self.accessed_entities.get_mut();
137 accessed_entities.insert(pointer.entity_id);
138
139 let entity = Some(
140 self.entities
141 .remove(pointer.entity_id)
142 .unwrap_or_else(|| double_lease_panic::<T>("update")),
143 );
144 Lease {
145 entity,
146 id: pointer.entity_id,
147 entity_type: PhantomData,
148 }
149 }
150
151 pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
153 self.entities.insert(lease.id, lease.entity.take().unwrap());
154 }
155
156 pub fn read<T: 'static>(&self, entity: &Entity<T>) -> &T {
157 self.assert_valid_context(entity);
158 let mut accessed_entities = self.accessed_entities.borrow_mut();
159 accessed_entities.insert(entity.entity_id);
160
161 self.entities
162 .get(entity.entity_id)
163 .and_then(|entity| entity.downcast_ref())
164 .unwrap_or_else(|| double_lease_panic::<T>("read"))
165 }
166
167 fn assert_valid_context(&self, entity: &AnyEntity) {
168 debug_assert!(
169 Weak::ptr_eq(&entity.entity_map, &Arc::downgrade(&self.ref_counts)),
170 "used a entity with the wrong context"
171 );
172 }
173
174 pub fn extend_accessed(&mut self, entities: &FxHashSet<EntityId>) {
175 self.accessed_entities
176 .get_mut()
177 .extend(entities.iter().copied());
178 }
179
180 pub fn clear_accessed(&mut self) {
181 self.accessed_entities.get_mut().clear();
182 }
183
184 pub fn take_dropped(&mut self) -> Vec<(EntityId, Box<dyn Any>)> {
185 let mut ref_counts = &mut *self.ref_counts.write();
186 let dropped_entity_ids = ref_counts.dropped_entity_ids.drain(..);
187 let mut accessed_entities = self.accessed_entities.get_mut();
188
189 dropped_entity_ids
190 .filter_map(|entity_id| {
191 let count = ref_counts.counts.remove(entity_id).unwrap();
192 debug_assert_eq!(
193 count.load(SeqCst),
194 0,
195 "dropped an entity that was referenced"
196 );
197 accessed_entities.remove(&entity_id);
198 Some((entity_id, self.entities.remove(entity_id)?))
201 })
202 .collect()
203 }
204}
205
/// Panics with a message saying that `operation` could not be performed on a `T`
/// because that entity is already being updated (its value is leased out).
#[track_caller]
fn double_lease_panic<T>(operation: &str) -> ! {
    let type_name = std::any::type_name::<T>();
    panic!("cannot {operation} {type_name} while it is already being updated")
}
213
214pub(crate) struct Lease<T> {
215 entity: Option<Box<dyn Any>>,
216 pub id: EntityId,
217 entity_type: PhantomData<T>,
218}
219
220impl<T: 'static> core::ops::Deref for Lease<T> {
221 type Target = T;
222
223 fn deref(&self) -> &Self::Target {
224 self.entity.as_ref().unwrap().downcast_ref().unwrap()
225 }
226}
227
228impl<T: 'static> core::ops::DerefMut for Lease<T> {
229 fn deref_mut(&mut self) -> &mut Self::Target {
230 self.entity.as_mut().unwrap().downcast_mut().unwrap()
231 }
232}
233
234impl<T> Drop for Lease<T> {
235 fn drop(&mut self) {
236 if self.entity.is_some() && !panicking() {
237 panic!("Leases must be ended with EntityMap::end_lease")
238 }
239 }
240}
241
/// A freshly reserved entity id, together with its first strong handle, waiting to be
/// filled by `EntityMap::insert`. Derefs to the wrapped [`Entity`].
#[derive(Deref, DerefMut)]
pub(crate) struct Slot<T>(Entity<T>);
244
/// A type-erased strong (ref-counted) handle to an entity.
pub struct AnyEntity {
    pub(crate) entity_id: EntityId,
    // `TypeId` of the concrete entity type; checked by `downcast`.
    pub(crate) entity_type: TypeId,
    // Weak link to the shared ref-count table, so a handle does not keep the table alive.
    entity_map: Weak<RwLock<EntityRefCounts>>,
    // Identity of this particular handle within the leak detector.
    #[cfg(any(test, feature = "leak-detection"))]
    handle_id: HandleId,
}
253
254impl AnyEntity {
255 fn new(
256 id: EntityId,
257 entity_type: TypeId,
258 entity_map: Weak<RwLock<EntityRefCounts>>,
259 #[cfg(any(test, feature = "leak-detection"))] type_name: &'static str,
260 ) -> Self {
261 Self {
262 entity_id: id,
263 entity_type,
264 #[cfg(any(test, feature = "leak-detection"))]
265 handle_id: entity_map
266 .clone()
267 .upgrade()
268 .unwrap()
269 .write()
270 .leak_detector
271 .handle_created(id, Some(type_name)),
272 entity_map,
273 }
274 }
275
276 #[inline]
278 pub fn entity_id(&self) -> EntityId {
279 self.entity_id
280 }
281
282 #[inline]
284 pub fn entity_type(&self) -> TypeId {
285 self.entity_type
286 }
287
288 pub fn downgrade(&self) -> AnyWeakEntity {
290 AnyWeakEntity {
291 entity_id: self.entity_id,
292 entity_type: self.entity_type,
293 entity_ref_counts: self.entity_map.clone(),
294 }
295 }
296
297 pub fn downcast<T: 'static>(self) -> Result<Entity<T>, AnyEntity> {
300 if TypeId::of::<T>() == self.entity_type {
301 Ok(Entity {
302 any_entity: self,
303 entity_type: PhantomData,
304 })
305 } else {
306 Err(self)
307 }
308 }
309}
310
311impl Clone for AnyEntity {
312 fn clone(&self) -> Self {
313 if let Some(entity_map) = self.entity_map.upgrade() {
314 let entity_map = entity_map.read();
315 let count = entity_map
316 .counts
317 .get(self.entity_id)
318 .expect("detected over-release of a entity");
319 let prev_count = count.fetch_add(1, SeqCst);
320 assert_ne!(prev_count, 0, "Detected over-release of a entity.");
321 }
322
323 Self {
324 entity_id: self.entity_id,
325 entity_type: self.entity_type,
326 entity_map: self.entity_map.clone(),
327 #[cfg(any(test, feature = "leak-detection"))]
328 handle_id: self
329 .entity_map
330 .upgrade()
331 .unwrap()
332 .write()
333 .leak_detector
334 .handle_created(self.entity_id, None),
335 }
336 }
337}
338
impl Drop for AnyEntity {
    /// Releases this strong handle.
    ///
    /// Decrements the entity's ref count. When this was the last strong handle, the
    /// id is queued in `dropped_entity_ids` for `EntityMap::take_dropped` to reclaim.
    /// Nothing is counted if the ref-count table has already been dropped.
    ///
    /// # Panics
    /// Panics on over-release (no slot for the id, or the count was already zero).
    fn drop(&mut self) {
        if let Some(entity_map) = self.entity_map.upgrade() {
            // Upgradable read: the common path only needs shared access, but the
            // last-handle path below can upgrade to a write lock without releasing it.
            let entity_map = entity_map.upgradable_read();
            let count = entity_map
                .counts
                .get(self.entity_id)
                .expect("detected over-release of a handle.");
            // Note: on over-release the atomic has already wrapped by the time the
            // assertion below fires.
            let prev_count = count.fetch_sub(1, SeqCst);
            assert_ne!(prev_count, 0, "Detected over-release of a entity.");
            if prev_count == 1 {
                // This was the last strong handle: queue the id for cleanup.
                let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
                entity_map.dropped_entity_ids.push(self.entity_id);
            }
        }

        // The guard above has been released by now, so taking a fresh write lock for
        // the leak detector cannot self-deadlock.
        #[cfg(any(test, feature = "leak-detection"))]
        if let Some(entity_map) = self.entity_map.upgrade() {
            entity_map
                .write()
                .leak_detector
                .handle_released(self.entity_id, self.handle_id)
        }
    }
}
365
366impl<T> From<Entity<T>> for AnyEntity {
367 #[inline]
368 fn from(entity: Entity<T>) -> Self {
369 entity.any_entity
370 }
371}
372
373impl Hash for AnyEntity {
374 #[inline]
375 fn hash<H: Hasher>(&self, state: &mut H) {
376 self.entity_id.hash(state);
377 }
378}
379
380impl PartialEq for AnyEntity {
381 #[inline]
382 fn eq(&self, other: &Self) -> bool {
383 self.entity_id == other.entity_id
384 }
385}
386
387impl Eq for AnyEntity {}
388
389impl Ord for AnyEntity {
390 #[inline]
391 fn cmp(&self, other: &Self) -> Ordering {
392 self.entity_id.cmp(&other.entity_id)
393 }
394}
395
396impl PartialOrd for AnyEntity {
397 #[inline]
398 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
399 Some(self.cmp(other))
400 }
401}
402
403impl std::fmt::Debug for AnyEntity {
404 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
405 f.debug_struct("AnyEntity")
406 .field("entity_id", &self.entity_id.as_u64())
407 .finish()
408 }
409}
410
/// A strong (ref-counted) handle to an entity whose state has type `T`.
///
/// Derefs to the type-erased [`AnyEntity`].
#[derive(Deref, DerefMut)]
pub struct Entity<T> {
    #[deref]
    #[deref_mut]
    pub(crate) any_entity: AnyEntity,
    // `fn(T) -> T` makes the handle invariant in `T` without owning a `T`.
    pub(crate) entity_type: PhantomData<fn(T) -> T>,
}

impl<T> Sealed for Entity<T> {}
422
423impl<T: 'static> Entity<T> {
424 #[inline]
425 fn new(id: EntityId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self
426 where
427 T: 'static,
428 {
429 Self {
430 any_entity: AnyEntity::new(
431 id,
432 TypeId::of::<T>(),
433 entity_map,
434 #[cfg(any(test, feature = "leak-detection"))]
435 std::any::type_name::<T>(),
436 ),
437 entity_type: PhantomData,
438 }
439 }
440
441 #[inline]
443 pub fn entity_id(&self) -> EntityId {
444 self.any_entity.entity_id
445 }
446
447 #[inline]
449 pub fn downgrade(&self) -> WeakEntity<T> {
450 WeakEntity {
451 any_entity: self.any_entity.downgrade(),
452 entity_type: self.entity_type,
453 }
454 }
455
456 #[inline]
458 pub fn into_any(self) -> AnyEntity {
459 self.any_entity
460 }
461
462 #[inline]
464 pub fn read<'a>(&self, cx: &'a App) -> &'a T {
465 cx.entities.read(self)
466 }
467
468 #[inline]
470 pub fn read_with<R, C: AppContext>(&self, cx: &C, f: impl FnOnce(&T, &App) -> R) -> R {
471 cx.read_entity(self, f)
472 }
473
474 #[inline]
476 pub fn update<R, C: AppContext>(
477 &self,
478 cx: &mut C,
479 update: impl FnOnce(&mut T, &mut Context<T>) -> R,
480 ) -> R {
481 cx.update_entity(self, update)
482 }
483
484 #[inline]
486 pub fn as_mut<'a, C: AppContext>(&self, cx: &'a mut C) -> GpuiBorrow<'a, T> {
487 cx.as_mut(self)
488 }
489
490 pub fn write<C: AppContext>(&self, cx: &mut C, value: T) {
492 self.update(cx, |entity, cx| {
493 *entity = value;
494 cx.notify();
495 })
496 }
497
498 #[inline]
502 pub fn update_in<R, C: VisualContext>(
503 &self,
504 cx: &mut C,
505 update: impl FnOnce(&mut T, &mut Window, &mut Context<T>) -> R,
506 ) -> C::Result<R> {
507 cx.update_window_entity(self, update)
508 }
509}
510
511impl<T> Clone for Entity<T> {
512 #[inline]
513 fn clone(&self) -> Self {
514 Self {
515 any_entity: self.any_entity.clone(),
516 entity_type: self.entity_type,
517 }
518 }
519}
520
521impl<T> std::fmt::Debug for Entity<T> {
522 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
523 f.debug_struct("Entity")
524 .field("entity_id", &self.any_entity.entity_id)
525 .field("entity_type", &type_name::<T>())
526 .finish()
527 }
528}
529
530impl<T> Hash for Entity<T> {
531 #[inline]
532 fn hash<H: Hasher>(&self, state: &mut H) {
533 self.any_entity.hash(state);
534 }
535}
536
537impl<T> PartialEq for Entity<T> {
538 #[inline]
539 fn eq(&self, other: &Self) -> bool {
540 self.any_entity == other.any_entity
541 }
542}
543
544impl<T> Eq for Entity<T> {}
545
546impl<T> PartialEq<WeakEntity<T>> for Entity<T> {
547 #[inline]
548 fn eq(&self, other: &WeakEntity<T>) -> bool {
549 self.any_entity.entity_id() == other.entity_id()
550 }
551}
552
553impl<T: 'static> Ord for Entity<T> {
554 #[inline]
555 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
556 self.entity_id().cmp(&other.entity_id())
557 }
558}
559
560impl<T: 'static> PartialOrd for Entity<T> {
561 #[inline]
562 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
563 Some(self.cmp(other))
564 }
565}
566
/// A type-erased weak handle to an entity. It does not contribute to the entity's
/// ref count and can be upgraded while at least one strong handle exists.
#[derive(Clone)]
pub struct AnyWeakEntity {
    pub(crate) entity_id: EntityId,
    // `TypeId` of the concrete entity type, copied into upgraded handles.
    entity_type: TypeId,
    // Weak link to the shared ref-count table; `Weak::new()` for `new_invalid` handles.
    entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
574
575impl AnyWeakEntity {
576 #[inline]
578 pub fn entity_id(&self) -> EntityId {
579 self.entity_id
580 }
581
582 pub fn is_upgradable(&self) -> bool {
584 let ref_count = self
585 .entity_ref_counts
586 .upgrade()
587 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
588 .unwrap_or(0);
589 ref_count > 0
590 }
591
592 pub fn upgrade(&self) -> Option<AnyEntity> {
594 let ref_counts = &self.entity_ref_counts.upgrade()?;
595 let ref_counts = ref_counts.read();
596 let ref_count = ref_counts.counts.get(self.entity_id)?;
597
598 if atomic_incr_if_not_zero(ref_count) == 0 {
599 return None;
601 }
602 drop(ref_counts);
603
604 Some(AnyEntity {
605 entity_id: self.entity_id,
606 entity_type: self.entity_type,
607 entity_map: self.entity_ref_counts.clone(),
608 #[cfg(any(test, feature = "leak-detection"))]
609 handle_id: self
610 .entity_ref_counts
611 .upgrade()
612 .unwrap()
613 .write()
614 .leak_detector
615 .handle_created(self.entity_id, None),
616 })
617 }
618
619 #[cfg(any(test, feature = "leak-detection"))]
647 pub fn assert_released(&self) {
648 self.entity_ref_counts
649 .upgrade()
650 .unwrap()
651 .write()
652 .leak_detector
653 .assert_released(self.entity_id);
654
655 if self
656 .entity_ref_counts
657 .upgrade()
658 .and_then(|ref_counts| Some(ref_counts.read().counts.get(self.entity_id)?.load(SeqCst)))
659 .is_some()
660 {
661 panic!(
662 "entity was recently dropped but resources are retained until the end of the effect cycle."
663 )
664 }
665 }
666
667 pub fn new_invalid() -> Self {
669 static UNIQUE_NON_CONFLICTING_ID_GENERATOR: AtomicU64 = AtomicU64::new(u64::MAX);
673 let entity_id = UNIQUE_NON_CONFLICTING_ID_GENERATOR.fetch_sub(1, SeqCst);
674
675 Self {
676 entity_id: entity_id.into(),
686 entity_type: TypeId::of::<()>(),
687 entity_ref_counts: Weak::new(),
688 }
689 }
690}
691
692impl std::fmt::Debug for AnyWeakEntity {
693 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
694 f.debug_struct(type_name::<Self>())
695 .field("entity_id", &self.entity_id)
696 .field("entity_type", &self.entity_type)
697 .finish()
698 }
699}
700
701impl<T> From<WeakEntity<T>> for AnyWeakEntity {
702 #[inline]
703 fn from(entity: WeakEntity<T>) -> Self {
704 entity.any_entity
705 }
706}
707
708impl Hash for AnyWeakEntity {
709 #[inline]
710 fn hash<H: Hasher>(&self, state: &mut H) {
711 self.entity_id.hash(state);
712 }
713}
714
715impl PartialEq for AnyWeakEntity {
716 #[inline]
717 fn eq(&self, other: &Self) -> bool {
718 self.entity_id == other.entity_id
719 }
720}
721
722impl Eq for AnyWeakEntity {}
723
724impl Ord for AnyWeakEntity {
725 #[inline]
726 fn cmp(&self, other: &Self) -> Ordering {
727 self.entity_id.cmp(&other.entity_id)
728 }
729}
730
731impl PartialOrd for AnyWeakEntity {
732 #[inline]
733 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
734 Some(self.cmp(other))
735 }
736}
737
/// A weak handle to an entity whose state has type `T`.
///
/// Derefs to the type-erased [`AnyWeakEntity`].
#[derive(Deref, DerefMut)]
pub struct WeakEntity<T> {
    #[deref]
    #[deref_mut]
    any_entity: AnyWeakEntity,
    // `fn(T) -> T` makes the handle invariant in `T` without owning a `T`.
    entity_type: PhantomData<fn(T) -> T>,
}
746
747impl<T> std::fmt::Debug for WeakEntity<T> {
748 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
749 f.debug_struct(type_name::<Self>())
750 .field("entity_id", &self.any_entity.entity_id)
751 .field("entity_type", &type_name::<T>())
752 .finish()
753 }
754}
755
756impl<T> Clone for WeakEntity<T> {
757 fn clone(&self) -> Self {
758 Self {
759 any_entity: self.any_entity.clone(),
760 entity_type: self.entity_type,
761 }
762 }
763}
764
765impl<T: 'static> WeakEntity<T> {
766 pub fn upgrade(&self) -> Option<Entity<T>> {
768 Some(Entity {
769 any_entity: self.any_entity.upgrade()?,
770 entity_type: self.entity_type,
771 })
772 }
773
774 pub fn update<C, R>(
778 &self,
779 cx: &mut C,
780 update: impl FnOnce(&mut T, &mut Context<T>) -> R,
781 ) -> Result<R>
782 where
783 C: AppContext,
784 {
785 let entity = self.upgrade().context("entity released")?;
786 Ok(cx.update_entity(&entity, update))
787 }
788
789 pub fn update_in<C, R>(
793 &self,
794 cx: &mut C,
795 update: impl FnOnce(&mut T, &mut Window, &mut Context<T>) -> R,
796 ) -> Result<R>
797 where
798 C: AppContext,
799 {
800 let entity = self.upgrade().context("entity released")?;
801 cx.with_window(entity.entity_id(), |window, app| {
802 entity.update(app, |entity, cx| update(entity, window, cx))
803 })
804 .context("entity has no current window")
805 }
806
807 pub fn read_with<C, R>(&self, cx: &C, read: impl FnOnce(&T, &App) -> R) -> Result<R>
811 where
812 C: AppContext,
813 {
814 let entity = self.upgrade().context("entity released")?;
815 Ok(cx.read_entity(&entity, read))
816 }
817
818 #[inline]
820 pub fn new_invalid() -> Self {
821 Self {
822 any_entity: AnyWeakEntity::new_invalid(),
823 entity_type: PhantomData,
824 }
825 }
826}
827
828impl<T> Hash for WeakEntity<T> {
829 #[inline]
830 fn hash<H: Hasher>(&self, state: &mut H) {
831 self.any_entity.hash(state);
832 }
833}
834
835impl<T> PartialEq for WeakEntity<T> {
836 #[inline]
837 fn eq(&self, other: &Self) -> bool {
838 self.any_entity == other.any_entity
839 }
840}
841
842impl<T> Eq for WeakEntity<T> {}
843
844impl<T> PartialEq<Entity<T>> for WeakEntity<T> {
845 #[inline]
846 fn eq(&self, other: &Entity<T>) -> bool {
847 self.entity_id() == other.any_entity.entity_id()
848 }
849}
850
851impl<T: 'static> Ord for WeakEntity<T> {
852 #[inline]
853 fn cmp(&self, other: &Self) -> Ordering {
854 self.entity_id().cmp(&other.entity_id())
855 }
856}
857
858impl<T: 'static> PartialOrd for WeakEntity<T> {
859 #[inline]
860 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
861 Some(self.cmp(other))
862 }
863}
864
/// Whether the leak detector captures an allocation backtrace for every handle.
///
/// True when the `LEAK_BACKTRACE` environment variable is set to a non-empty (valid
/// Unicode) value. It is read once, on first use.
#[cfg(any(test, feature = "leak-detection"))]
static LEAK_BACKTRACE: std::sync::LazyLock<bool> =
    std::sync::LazyLock::new(|| std::env::var("LEAK_BACKTRACE").is_ok_and(|b| !b.is_empty()));

/// Identifies one individual strong handle, as opposed to the entity it points to.
/// Issued sequentially by `LeakDetector::handle_created`.
#[cfg(any(test, feature = "leak-detection"))]
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub(crate) struct HandleId {
    id: u64,
}
882
/// Tracks every live strong handle, per entity, so leaked handles can be reported.
#[cfg(any(test, feature = "leak-detection"))]
pub(crate) struct LeakDetector {
    // Next `HandleId` value to hand out.
    next_handle_id: u64,
    // Live handles per entity. An entry is removed when its last handle is released.
    entity_handles: HashMap<EntityId, EntityLeakData>,
}

/// The set of entity ids that had live handles when a leak-detector snapshot was
/// taken.
#[cfg(any(test, feature = "leak-detection"))]
pub struct LeakDetectorSnapshot {
    entity_ids: collections::HashSet<EntityId>,
}

/// Live handles of a single entity.
#[cfg(any(test, feature = "leak-detection"))]
struct EntityLeakData {
    // One entry per live handle. The backtrace is captured only when LEAK_BACKTRACE is set.
    handles: HashMap<HandleId, Option<backtrace::Backtrace>>,
    // Type name given when the entry was created, or "<unknown>" if none was given.
    type_name: &'static str,
}
953
#[cfg(any(test, feature = "leak-detection"))]
impl LeakDetector {
    /// Registers a new strong handle for `entity_id` and returns its id.
    ///
    /// `type_name` is recorded only when this creates the entity's entry; an existing
    /// entry keeps its name. An unresolved backtrace of the allocation site is stored
    /// when `LEAK_BACKTRACE` is enabled.
    #[track_caller]
    pub fn handle_created(
        &mut self,
        entity_id: EntityId,
        type_name: Option<&'static str>,
    ) -> HandleId {
        let handle_id = HandleId {
            id: gpui_util::post_inc(&mut self.next_handle_id),
        };
        let leak_data = self
            .entity_handles
            .entry(entity_id)
            .or_insert_with(|| EntityLeakData {
                handles: HashMap::default(),
                type_name: type_name.unwrap_or("<unknown>"),
            });
        let allocation_site = LEAK_BACKTRACE.then(backtrace::Backtrace::new_unresolved);
        leak_data.handles.insert(handle_id, allocation_site);
        handle_id
    }

    /// Forgets `handle_id`. The entity's entry is removed once no handles remain;
    /// unknown ids are ignored.
    pub fn handle_released(&mut self, entity_id: EntityId, handle_id: HandleId) {
        use std::collections::hash_map::Entry;

        let Entry::Occupied(mut entry) = self.entity_handles.entry(entity_id) else {
            return;
        };
        let handles = &mut entry.get_mut().handles;
        handles.remove(&handle_id);
        if handles.is_empty() {
            entry.remove();
        }
    }

    /// Panics with a report of every still-live handle of `entity_id`, if any.
    ///
    /// The entity's entry is removed before panicking.
    pub fn assert_released(&mut self, entity_id: EntityId) {
        use std::fmt::Write as _;

        let Some(data) = self.entity_handles.remove(&entity_id) else {
            return;
        };

        let mut report = String::new();
        for backtrace in data.handles.into_values() {
            match backtrace {
                Some(mut backtrace) => {
                    backtrace.resolve();
                    writeln!(report, "Leaked handle:\n{:?}", BacktraceFormatter(backtrace))
                        .unwrap();
                }
                None => {
                    writeln!(
                        report,
                        "Leaked handle: (export LEAK_BACKTRACE to find allocation site)"
                    )
                    .unwrap();
                }
            }
        }
        panic!("Handles for {} leaked:\n{report}", data.type_name);
    }

    /// Captures the ids of all entities that currently have live handles.
    pub fn snapshot(&self) -> LeakDetectorSnapshot {
        let entity_ids = self.entity_handles.keys().copied().collect();
        LeakDetectorSnapshot { entity_ids }
    }

    /// Panics with a report if any entity absent from `snapshot` has live handles.
    pub fn assert_no_new_leaks(&self, snapshot: &LeakDetectorSnapshot) {
        use std::fmt::Write as _;

        let mut report = String::new();
        let new_entities = self
            .entity_handles
            .iter()
            .filter(|(entity_id, _)| !snapshot.entity_ids.contains(*entity_id));
        for (entity_id, data) in new_entities {
            for backtrace in data.handles.values() {
                match backtrace {
                    Some(backtrace) => {
                        // The stored backtrace is shared, so resolve a copy.
                        let mut backtrace = backtrace.clone();
                        backtrace.resolve();
                        writeln!(
                            report,
                            "Leaked handle for entity {} ({entity_id:?}):\n{:?}",
                            data.type_name,
                            BacktraceFormatter(backtrace)
                        )
                        .unwrap();
                    }
                    None => {
                        writeln!(
                            report,
                            "Leaked handle for entity {} ({entity_id:?}): (export LEAK_BACKTRACE to find allocation site)",
                            data.type_name
                        )
                        .unwrap();
                    }
                }
            }
        }

        if !report.is_empty() {
            panic!("New entity leaks detected since snapshot:\n{report}");
        }
    }
}
1084
#[cfg(any(test, feature = "leak-detection"))]
impl Drop for LeakDetector {
    /// Panics with a report of every handle still alive when the detector is dropped.
    /// It does nothing if there are none, or if the thread is already panicking.
    fn drop(&mut self) {
        use std::fmt::Write;

        if std::thread::panicking() || self.entity_handles.is_empty() {
            return;
        }

        let mut report = String::new();
        for (entity_id, EntityLeakData { handles, type_name }) in self.entity_handles.drain() {
            for backtrace in handles.into_values() {
                match backtrace {
                    Some(mut backtrace) => {
                        backtrace.resolve();
                        writeln!(
                            report,
                            "Leaked handle for entity {} ({entity_id:?}):\n{:?}",
                            type_name,
                            BacktraceFormatter(backtrace)
                        )
                        .unwrap();
                    }
                    None => {
                        writeln!(
                            report,
                            "Leaked handle for entity {} ({entity_id:?}): (export LEAK_BACKTRACE to find allocation site)",
                            type_name
                        )
                        .unwrap();
                    }
                }
            }
        }
        panic!("Exited with leaked handles:\n{report}");
    }
}
1119
/// Debug-formats a leak backtrace, keeping only the frames relevant to where the
/// leaked handle was created.
#[cfg(any(test, feature = "leak-detection"))]
struct BacktraceFormatter(backtrace::Backtrace);

#[cfg(any(test, feature = "leak-detection"))]
impl fmt::Debug for BacktraceFormatter {
    /// `{:#?}` prints full paths (`PrintFmt::Full`). `{:?}` prints the short form,
    /// with paths made relative to the current directory when possible.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        use backtrace::{BacktraceFmt, BytesOrWideString, PrintFmt};

        let style = if fmt.alternate() {
            PrintFmt::Full
        } else {
            PrintFmt::Short
        };

        // In the short style, strip the cwd prefix from file paths when it matches.
        let cwd = std::env::current_dir();
        let mut print_path = move |fmt: &mut fmt::Formatter<'_>, path: BytesOrWideString<'_>| {
            let path = path.into_path_buf();
            if style != PrintFmt::Full {
                if let Ok(cwd) = &cwd {
                    if let Ok(suffix) = path.strip_prefix(cwd) {
                        return fmt::Display::fmt(&suffix.display(), fmt);
                    }
                }
            }
            fmt::Display::fmt(&path.display(), fmt)
        };

        let mut f = BacktraceFmt::new(fmt, style, &mut print_path);
        f.add_context()?;
        // Frame filter, innermost frame first:
        // - frames start hidden (backtrace/leak-detector internals);
        // - printing starts after the `handle_created` frame (itself not printed);
        // - printing stops again at the test runner / executor poll frame;
        // - `zed::main` is printed, and stripping resumes after it.
        let mut strip = true;
        for frame in self.0.frames() {
            if let [symbol, ..] = frame.symbols()
                && let Some(name) = symbol.name()
                && let Some(filename) = name.as_str()
            {
                match filename {
                    "test::run_test_in_process"
                    | "scheduler::executor::spawn_local_with_source_location::impl$1::poll<core::pin::Pin<alloc::boxed::Box<dyn$<core::future::future::Future<assoc$<Output,enum2$<core::result::Result<workspace::OpenResult,anyhow::Error> > > > >,alloc::alloc::Global> > >" => {
                        strip = true
                    }
                    "gpui::app::entity_map::LeakDetector::handle_created" => {
                        strip = false;
                        continue;
                    }
                    "zed::main" => {
                        strip = true;
                        f.frame().backtrace_frame(frame)?;
                    }
                    _ => {}
                }
            }
            if strip {
                continue;
            }
            f.frame().backtrace_frame(frame)?;
        }
        f.finish()?;
        Ok(())
    }
}
1184
#[cfg(test)]
mod test {
    use crate::EntityMap;

    struct TestEntity {
        pub i: i32,
    }

    /// Reserving a new entity after another was dropped, but before `take_dropped`
    /// runs, must not reuse the dropped entity's slot. Both values come back from
    /// `take_dropped`, in drop order.
    #[test]
    fn test_entity_map_slot_assignment_before_cleanup() {
        let mut entity_map = EntityMap::new();

        let slot = entity_map.reserve::<TestEntity>();
        // The returned handle is dropped immediately, queueing the entity for cleanup.
        entity_map.insert(slot, TestEntity { i: 1 });

        let slot = entity_map.reserve::<TestEntity>();
        entity_map.insert(slot, TestEntity { i: 2 });

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 2);

        assert_eq!(
            dropped
                .into_iter()
                .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
                .collect::<Vec<i32>>(),
            vec![1, 2],
        );
    }

    /// A weak handle must not upgrade once the last strong handle is gone, even
    /// though the value has not been reclaimed by `take_dropped` yet.
    #[test]
    fn test_entity_map_weak_upgrade_before_cleanup() {
        let mut entity_map = EntityMap::new();

        let slot = entity_map.reserve::<TestEntity>();
        let handle = entity_map.insert(slot, TestEntity { i: 1 });
        let weak = handle.downgrade();
        drop(handle);

        let strong = weak.upgrade();
        assert_eq!(strong, None);

        let dropped = entity_map.take_dropped();
        assert_eq!(dropped.len(), 1);

        assert_eq!(
            dropped
                .into_iter()
                .map(|(_, entity)| entity.downcast::<TestEntity>().unwrap().i)
                .collect::<Vec<i32>>(),
            vec![1],
        );
    }

    /// An entity created and fully released after a snapshot is not a leak. Entities
    /// that predate the snapshot are ignored.
    #[test]
    fn test_leak_detector_snapshot_no_leaks() {
        let mut entity_map = EntityMap::new();

        let slot = entity_map.reserve::<TestEntity>();
        let pre_existing = entity_map.insert(slot, TestEntity { i: 1 });

        let snapshot = entity_map.leak_detector_snapshot();

        let slot = entity_map.reserve::<TestEntity>();
        let temporary = entity_map.insert(slot, TestEntity { i: 2 });
        drop(temporary);

        entity_map.assert_no_new_leaks(&snapshot);

        drop(pre_existing);
    }

    /// An entity created after the snapshot that is still alive is reported.
    #[test]
    #[should_panic(expected = "New entity leaks detected since snapshot")]
    fn test_leak_detector_snapshot_detects_new_leak() {
        let mut entity_map = EntityMap::new();

        let slot = entity_map.reserve::<TestEntity>();
        let pre_existing = entity_map.insert(slot, TestEntity { i: 1 });

        let snapshot = entity_map.leak_detector_snapshot();

        let slot = entity_map.reserve::<TestEntity>();
        let leaked = entity_map.insert(slot, TestEntity { i: 2 });

        entity_map.assert_no_new_leaks(&snapshot);

        drop(pre_existing);
        drop(leaked);
    }
}