1use super::node::ErasedNode;
2use super::node_update::{NodeUpdateDelayed, OnUpdateHandler};
3use super::stabilisation_num::StabilisationNum;
4use super::state::{IncrStatus, State};
5use crate::node::Node;
6use crate::node_update::HandleUpdate;
7use std::cell::RefCell;
8use std::collections::HashMap;
9use std::fmt::{Debug, Display};
10use std::hash::Hash;
11use std::rc::Rc;
12use std::{cell::Cell, rc::Weak};
13
14use super::{CellIncrement, Incr};
15use super::{NodeRef, Value};
16use crate::incrsan::NotObserver;
17
18use self::ObserverState::*;
19
/// Identifier for an observer, drawn from a counter that is local to the creating thread.
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct ObserverId(usize);

impl ObserverId {
    /// Allocates the next id from the per-thread counter; the first id handed out on a
    /// thread is 1, and each subsequent call yields the previous value plus one.
    fn next() -> Self {
        thread_local! {
            static OBSERVER_ID: Cell<usize> = Cell::new(0);
        }

        OBSERVER_ID.with(|counter| {
            counter.set(counter.get() + 1);
            Self(counter.get())
        })
    }
}
35
36pub(crate) struct InternalObserver<T> {
37 id: ObserverId,
38 pub(crate) state: Cell<ObserverState>,
39 observing: Incr<T>,
40 weak_self: Weak<Self>,
41 on_update_handlers: RefCell<HashMap<SubscriptionToken, OnUpdateHandler<T>>>,
42 next_subscriber: Cell<SubscriptionToken>,
43}
44
45pub(crate) type WeakObserver = Weak<dyn ErasedObserver>;
46pub(crate) type StrongObserver = Rc<dyn ErasedObserver>;
47
48pub(crate) trait ErasedObserver: Debug + NotObserver {
49 fn id(&self) -> ObserverId;
50 fn state(&self) -> &Cell<ObserverState>;
51 fn observing_packed(&self) -> NodeRef;
52 fn observing_erased(&self) -> &Node;
53 fn disallow_future_use(&self, state: &State);
54 fn num_handlers(&self) -> i32;
55 fn add_to_observed_node(&self);
56 fn remove_from_observed_node(&self);
57 fn unsubscribe(&self, token: SubscriptionToken) -> Result<(), ObserverError>;
58 fn run_all(&self, input: &Node, node_update: NodeUpdateDelayed, now: StabilisationNum);
59}
60
61impl<T: Value> ErasedObserver for InternalObserver<T> {
62 fn id(&self) -> ObserverId {
63 self.id
64 }
65 fn state(&self) -> &Cell<ObserverState> {
66 &self.state
67 }
68 fn observing_packed(&self) -> NodeRef {
69 self.observing.node.clone().packed()
70 }
71 fn observing_erased(&self) -> &Node {
72 self.observing.node.erased()
73 }
74 fn disallow_future_use(&self, state: &State) {
75 match self.state.get() {
76 Disallowed | Unlinked => {}
77 Created => {
78 state
79 .num_active_observers
80 .set(state.num_active_observers.get() - 1);
81 self.state.set(Unlinked);
82 let mut ouh = self.on_update_handlers.borrow_mut();
83 ouh.clear();
84 }
85 InUse => {
86 state
87 .num_active_observers
88 .set(state.num_active_observers.get() - 1);
89 self.state.set(Disallowed);
90 let mut dobs = state.disallowed_observers.borrow_mut();
91 dobs.push(self.weak_self.clone());
92 }
93 }
94 }
95 fn num_handlers(&self) -> i32 {
96 self.on_update_handlers.borrow().len() as i32
97 }
98 fn add_to_observed_node(&self) {
99 let node = &self.observing.node;
100 node.add_observer(self.id(), self.weak_self.clone());
101 let num = node.num_on_update_handlers();
102 num.set(num.get() + self.num_handlers());
103 }
104 fn remove_from_observed_node(&self) {
105 let node = &self.observing.node;
106 node.remove_observer(self.id());
107 let num = node.num_on_update_handlers();
108 num.set(num.get() - self.num_handlers());
109 }
110
111 fn unsubscribe(&self, token: SubscriptionToken) -> Result<(), ObserverError> {
113 if token.0 != self.id {
114 return Err(ObserverError::Mismatch);
115 }
116 match self.state.get() {
117 Disallowed | Unlinked => Ok(()),
122 Created | InUse => {
123 self.on_update_handlers.borrow_mut().remove(&token);
125
126 match self.state.get() {
127 Created => {
128 Ok(())
130 }
131 InUse => {
132 let observing = self.observing_erased();
133 let num = observing.num_on_update_handlers();
134 num.increment();
135 Ok(())
136 }
137 _ => unreachable!(),
138 }
139 }
140 }
141 }
142 fn run_all(&self, input: &Node, node_update: NodeUpdateDelayed, now: StabilisationNum) {
143 let mut handlers = self.on_update_handlers.borrow_mut();
144 for (id, handler) in handlers.iter_mut() {
145 tracing::trace!("running update handler with id {id:?}");
146 match self.state.get() {
150 Created | Unlinked => panic!(),
151 Disallowed => (),
152 InUse => handler.run(input, node_update, now),
153 }
154 }
155 }
156}
157
158impl<T: Value> Debug for InternalObserver<T> {
159 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160 f.debug_struct("InternalObserver")
161 .field("state", &self.state.get())
162 .field("value", &self.try_get_value())
163 .finish()
164 }
165}
166
167impl<T: Value> InternalObserver<T> {
168 pub(crate) fn incr_state(&self) -> Option<Rc<State>> {
169 self.observing.node.state_opt()
170 }
171 pub(crate) fn new(observing: Incr<T>) -> Rc<Self> {
172 let id = ObserverId::next();
173 Rc::new_cyclic(|weak_self| Self {
174 id,
175 state: Cell::new(Created),
176 observing,
177 on_update_handlers: Default::default(),
178 weak_self: weak_self.clone(),
179 next_subscriber: SubscriptionToken(id, 1).into(),
180 })
181 }
182 pub(crate) fn try_get_value(&self) -> Result<T, ObserverError> {
183 let t = self.incr_state();
184 match t {
185 Some(t) => match t.status.get() {
186 IncrStatus::NotStabilising | IncrStatus::RunningOnUpdateHandlers => {
187 self.value_inner()
188 }
189 IncrStatus::Stabilising => Err(ObserverError::CurrentlyStabilising),
190 },
191 None => Err(ObserverError::ObservingInvalid),
193 }
194 }
195 pub(crate) fn value_inner(&self) -> Result<T, ObserverError> {
196 match self.state.get() {
197 Created => Err(ObserverError::NeverStabilised),
198 InUse => self
199 .observing
200 .node
201 .value_opt()
202 .ok_or(ObserverError::ObservingInvalid),
203 Disallowed | Unlinked => Err(ObserverError::Disallowed),
204 }
205 }
206 pub(crate) fn subscribe(
207 &self,
208 handler: OnUpdateHandler<T>,
209 ) -> Result<SubscriptionToken, ObserverError> {
210 match self.state.get() {
211 Disallowed | Unlinked => Err(ObserverError::Disallowed),
212 Created | InUse => {
213 let token = self.next_subscriber.get();
214 self.next_subscriber.set(token.succ());
215 self.on_update_handlers.borrow_mut().insert(token, handler);
216 match self.state.get() {
217 Created => {
218 }
221 InUse => {
222 let observing = self.observing_erased();
223 let num = observing.num_on_update_handlers();
224 num.set(num.get() + 1);
225 }
226 _ => unreachable!(),
227 }
228 Ok(token)
229 }
230 }
231 }
232}
233
234#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
235pub struct SubscriptionToken(ObserverId, i32);
236
237impl SubscriptionToken {
238 fn succ(&self) -> Self {
239 Self(self.0, self.1 + 1)
240 }
241 pub(crate) fn observer_id(&self) -> ObserverId {
242 self.0
243 }
244}
245
/// Lifecycle state of an observer.
///
/// A new observer starts as `Created`. Disallowing it moves `Created` to `Unlinked` and
/// `InUse` to `Disallowed`; both of those are terminal for reads and subscriptions.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum ObserverState {
    /// Freshly constructed; reading its value yields `NeverStabilised`.
    Created,
    /// Active: its value can be read and its handlers are run.
    InUse,
    /// Disallowed while `InUse`; queued on the state's disallowed-observer list.
    Disallowed,
    /// Disallowed while still `Created`; its handlers were cleared immediately.
    Unlinked,
}
261
/// Reasons an observer operation (reading its value, subscribing, unsubscribing) can fail.
#[derive(Debug, PartialEq, Eq, Clone)]
#[non_exhaustive]
pub enum ObserverError {
    /// The value was requested while the incremental graph is stabilising.
    CurrentlyStabilising,
    /// The observer has not yet been stabilised, so it has no value.
    NeverStabilised,
    /// The observer has been disallowed and can no longer be read or subscribed to.
    Disallowed,
    /// The observed node has no reachable state or no value.
    ObservingInvalid,
    /// A subscription token issued by a different observer was passed to `unsubscribe`.
    Mismatch,
}
271
272impl Display for ObserverError {
273 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274 match self {
275 Self::CurrentlyStabilising => write!(f, "Incremental is currently stabilising. You cannot call Observer::value inside e.g. a map or bind function."),
276 Self::NeverStabilised => write!(f, "Incremental has never stabilised. Observer does not yet have a value."),
277 Self::Disallowed => write!(f, "Observer has been disallowed"),
278 Self::ObservingInvalid => write!(f, "observing an invalid Incr"),
279 Self::Mismatch => write!(f, "called unsubscribe with the wrong observer"),
280 }
281 }
282}
283impl std::error::Error for ObserverError {}
284
285#[cfg(debug_assertions)]
286impl<T> Drop for InternalObserver<T> {
287 fn drop(&mut self) {
288 let count = Rc::strong_count(&self.observing.node);
289 tracing::info!(
290 "dropping InternalObserver with id {:?}, observing node with strong_count {count}",
291 self.id
292 );
293 debug_assert!(matches!(self.state.get(), Disallowed | Unlinked));
294 }
295}