// opendp/interactive/mod.rs

1use crate::core::{Domain, Measure, Measurement, Metric, MetricSpace};
2use std::any::{Any, type_name};
3use std::cell::RefCell;
4use std::ops::Deref;
5use std::rc::Rc;
6
7use crate::error::*;
8
9/// A queryable is like a state machine:
10/// 1. it takes an input of type `Query<Q>`,
11/// 2. updates its internal state,
12/// 3. and emits an answer of type `Answer<A>`
13pub struct Queryable<Q: ?Sized, A>(Rc<RefCell<dyn FnMut(&Self, Query<Q>) -> Fallible<Answer<A>>>>);
14
15impl<Q: ?Sized, A> Queryable<Q, A> {
16    pub fn eval(&mut self, query: &Q) -> Fallible<A> {
17        match self.eval_query(Query::External(query))? {
18            Answer::External(ext) => Ok(ext),
19            Answer::Internal(_) => fallible!(
20                FailedFunction,
21                "cannot return internal answer from an external query"
22            ),
23        }
24    }
25
26    pub fn eval_wrap(&mut self, query: &Q, wrapper: Option<Wrapper>) -> Fallible<A> {
27        if let Some(w) = wrapper {
28            wrap(w, || self.eval(query))
29        } else {
30            self.eval(query)
31        }
32    }
33
34    #[allow(dead_code)]
35    pub(crate) fn eval_internal<'a, AI: 'static>(&mut self, query: &'a dyn Any) -> Fallible<AI> {
36        match self.eval_query(Query::Internal(query))? {
37            Answer::Internal(value) => value.downcast::<AI>().map(|v| *v).map_err(|_| {
38                err!(
39                    FailedCast,
40                    "could not downcast answer to {}",
41                    type_name::<AI>()
42                )
43            }),
44            Answer::External(_) => fallible!(
45                FailedFunction,
46                "cannot return external answer from an internal query"
47            ),
48        }
49    }
50
51    #[inline]
52    pub(crate) fn eval_query(&mut self, query: Query<Q>) -> Fallible<Answer<A>> {
53        let mut transition = self.0.as_ref().try_borrow_mut().map_err(|_| {
54            err!(
55                FailedFunction,
56                "a queryable may only execute one query at a time"
57            )
58        })?;
59
60        (transition)(self, query)
61    }
62}
63
/// A query submitted to a queryable's transition.
// The `'a` lifetime never appears in the `Queryable` struct definition: there it is
// introduced by the higher-ranked bound implied by `dyn FnMut(&Self, Query<Q>)` and elided.
#[derive(Debug)]
pub(crate) enum Query<'a, Q: ?Sized> {
    /// A query of the queryable's public query type `Q` (see `Queryable::eval`).
    External(&'a Q),
    /// A type-erased query (see `Queryable::eval_internal`).
    #[allow(dead_code)]
    Internal(&'a dyn Any),
}
71
/// The answer produced by a queryable's transition.
pub(crate) enum Answer<A> {
    /// An answer of the queryable's public answer type `A`, returned for external queries.
    External(A),
    /// A type-erased answer, returned for internal queries.
    Internal(Box<dyn Any>),
}

impl<A> Answer<A> {
    /// Boxes `value` as a type-erased internal answer.
    #[allow(dead_code)]
    pub fn internal<T: 'static>(value: T) -> Self {
        let erased: Box<dyn Any> = Box::new(value);
        Answer::Internal(erased)
    }
}
83
84thread_local! {
85    pub(crate) static WRAPPER: RefCell<Option<Rc<dyn Fn(PolyQueryable) -> Fallible<PolyQueryable>>>> = RefCell::new(None);
86}
87
88pub(crate) fn wrap<T>(wrapper: Wrapper, f: impl FnOnce() -> T) -> T {
89    let prev_wrapper = WRAPPER.with(|w| w.borrow_mut().take());
90
91    let new_wrapper = Some(if let Some(prev) = prev_wrapper.clone() {
92        Rc::new(move |qbl| (prev)((wrapper)(qbl)?)) as Rc<_>
93    } else {
94        wrapper.0
95    });
96
97    WRAPPER.with(|w| *w.borrow_mut() = new_wrapper);
98    let res = f();
99    WRAPPER.with(|w| *w.borrow_mut() = prev_wrapper);
100    res
101}
102
103impl<DI: Domain, MI: Metric, MO: Measure, TO> Measurement<DI, MI, MO, TO>
104where
105    (DI, MI): MetricSpace,
106{
107    pub fn invoke_wrap(&self, arg: &DI::Carrier, wrapper: Option<Wrapper>) -> Fallible<TO> {
108        if let Some(w) = wrapper {
109            wrap(w, || self.invoke(arg))
110        } else {
111            self.invoke(arg)
112        }
113    }
114}
115
116#[derive(Clone)]
117pub struct Wrapper(pub Rc<dyn Fn(PolyQueryable) -> Fallible<PolyQueryable>>);
118
119impl Wrapper {
120    pub fn new(wrapper: impl Fn(PolyQueryable) -> Fallible<PolyQueryable> + 'static) -> Self {
121        Wrapper(Rc::new(wrapper))
122    }
123
124    /// Creates a recursive wrapper that recursively applies itself to child queryables.
125    /// `hook` is called any time the wrapped queryable or any of its children are queried.
126    pub fn new_recursive_pre_hook(hook: impl FnMut() -> Fallible<()> + Clone + 'static) -> Wrapper {
127        RecursiveWrapper(Rc::new(move |recursive_wrapper, mut inner_qbl| {
128            let mut hook = hook.clone();
129            Ok(Queryable::new_raw(move |_, query: Query<dyn Any>| {
130                // call the hook
131                hook()?;
132
133                // evaluate the query and wrap the answer
134                let out = wrap(recursive_wrapper.to_wrapper(), || {
135                    inner_qbl.eval_query(query)
136                });
137                out
138            }))
139        }))
140        .to_wrapper()
141    }
142}
143
144// make Wrapper callable as a function
145impl Deref for Wrapper {
146    type Target = dyn Fn(PolyQueryable) -> Fallible<PolyQueryable>;
147
148    fn deref(&self) -> &Self::Target {
149        &*self.0
150    }
151}
152
153/// RecursiveWrapper is a utility for constructing a closure that wraps a Queryable,
154/// in a way that recursively wraps any children of the Queryable.
155// The use of a struct avoids an infinite recursion in the type system,
156// as the first argument to the closure is the same type as the closure itself.
157#[derive(Clone)]
158struct RecursiveWrapper(pub Rc<dyn Fn(RecursiveWrapper, PolyQueryable) -> Fallible<PolyQueryable>>);
159
160impl RecursiveWrapper {
161    fn to_wrapper(&self) -> Wrapper {
162        let self_ = self.clone();
163        Wrapper::new(move |qbl: PolyQueryable| self_.0(self_.clone(), qbl))
164    }
165}
166
167impl<Q: ?Sized, A> Queryable<Q, A>
168where
169    Self: IntoPolyQueryable + FromPolyQueryable,
170{
171    pub(crate) fn new(
172        transition: impl FnMut(&Self, Query<Q>) -> Fallible<Answer<A>> + 'static,
173    ) -> Fallible<Self> {
174        let queryable = Queryable::new_raw(transition);
175        let wrapper = WRAPPER.with(|w| w.borrow().clone());
176        Ok(match wrapper {
177            None => queryable,
178            Some(w) => Queryable::from_poly(w(queryable.into_poly())?),
179        })
180    }
181
182    pub(crate) fn new_raw(
183        transition: impl FnMut(&Self, Query<Q>) -> Fallible<Answer<A>> + 'static,
184    ) -> Self {
185        Queryable(Rc::new(RefCell::new(transition)))
186    }
187
188    #[allow(dead_code)]
189    pub(crate) fn new_external(
190        mut transition: impl FnMut(&Q) -> Fallible<A> + 'static,
191    ) -> Fallible<Self> {
192        Queryable::new(
193            move |_self: &Self, query: Query<Q>| -> Fallible<Answer<A>> {
194                match query {
195                    Query::External(q) => transition(q).map(Answer::External),
196                    Query::Internal(_) => fallible!(FailedFunction, "unrecognized internal query"),
197                }
198            },
199        )
200    }
201
202    #[allow(dead_code)]
203    pub(crate) fn new_raw_external(
204        mut transition: impl FnMut(&Q) -> Fallible<A> + 'static,
205    ) -> Self {
206        Queryable::new_raw(
207            move |_self: &Self, query: Query<Q>| -> Fallible<Answer<A>> {
208                match query {
209                    Query::External(q) => transition(q).map(Answer::External),
210                    Query::Internal(_) => fallible!(FailedFunction, "unrecognized internal query"),
211                }
212            },
213        )
214    }
215}
216
217// manually implemented instead of derived so that Q and A don't have to be Clone
218impl<Q: ?Sized, A> Clone for Queryable<Q, A> {
219    fn clone(&self) -> Self {
220        Self(self.0.clone())
221    }
222}
223
224pub type PolyQueryable = Queryable<dyn Any, Box<dyn Any>>;
225
226pub trait IntoPolyQueryable {
227    fn into_poly(self) -> PolyQueryable;
228}
229
230impl<Q: 'static, A: 'static> IntoPolyQueryable for Queryable<Q, A> {
231    fn into_poly(mut self) -> PolyQueryable {
232        Queryable::new_raw(move |_self: &PolyQueryable, query: Query<dyn Any>| {
233            Ok(match query {
234                Query::External(q) => {
235                    let answer = self.eval(q.downcast_ref::<Q>().ok_or_else(|| {
236                        err!(FailedCast, "query must be of type {}", type_name::<Q>())
237                    })?)?;
238                    Answer::External(Box::new(answer))
239                }
240                Query::Internal(q) => {
241                    let Answer::Internal(a) = self.eval_query(Query::Internal(q))? else {
242                        return fallible!(
243                            FailedFunction,
244                            "internal query returned external answer"
245                        );
246                    };
247                    Answer::Internal(a)
248                }
249            })
250        })
251    }
252}
253
254// The previous impl over all Q has an implicit `Sized` trait bound, whereas this parameterizes Q as dyn Any, which is not Sized.
255// Therefore, the compiler recognizes these impls as disjoint.
256impl IntoPolyQueryable for PolyQueryable {
257    fn into_poly(self) -> PolyQueryable {
258        // if already a PolyQueryable, no need to do anything.
259        self
260    }
261}
262
263pub trait FromPolyQueryable {
264    fn from_poly(v: PolyQueryable) -> Self;
265}
266
267impl<Q: 'static, A: 'static> FromPolyQueryable for Queryable<Q, A> {
268    fn from_poly(mut self_: PolyQueryable) -> Self {
269        Queryable::new_raw(move |_self: &Queryable<Q, A>, query: Query<Q>| {
270            Ok(match query {
271                Query::External(query) => {
272                    let answer = self_.eval(query)?;
273
274                    let answer = *answer.downcast::<A>().map_err(|_| {
275                        err!(FailedCast, "failed to downcast to {:?}", type_name::<A>())
276                    })?;
277                    Answer::External(answer)
278                }
279                Query::Internal(q) => {
280                    let Answer::Internal(a) = self_.eval_query(Query::Internal(q))? else {
281                        return fallible!(
282                            FailedFunction,
283                            "internal query returned external answer"
284                        );
285                    };
286                    Answer::Internal(a)
287                }
288            })
289        })
290    }
291}
292
293// The previous impl over all Q has an implicit `Sized` trait bound, whereas this parameterizes Q as dyn Any, which is not Sized.
294// Therefore, the compiler recognizes these impls as disjoint.
295impl FromPolyQueryable for PolyQueryable {
296    fn from_poly(self_: Self) -> Self {
297        // if already a PolyQueryable, no need to do anything.
298        self_
299    }
300}
301
302impl<Q: ?Sized> Queryable<Q, Box<dyn Any>> {
303    /// Evaluates a polymorphic query and downcasts to the given type.
304    pub fn eval_poly<A: 'static>(&mut self, query: &Q) -> Fallible<A> {
305        self.eval(query)?
306            .downcast()
307            .map_err(|_| {
308                err!(
309                    FailedCast,
310                    "eval_poly failed to downcast to {}",
311                    std::any::type_name::<A>()
312                )
313            })
314            .map(|b| *b)
315    }
316}