tc_state/
closure.rs

//! An [`OpDef`] which closes over zero or more [`State`]s.
2
3use std::collections::HashMap;
4use std::convert::TryInto;
5use std::fmt;
6use std::marker::PhantomData;
7
8use async_trait::async_trait;
9use destream::de;
10use futures::future::TryFutureExt;
11use futures::stream::{FuturesUnordered, TryStreamExt};
12use log::debug;
13use safecast::{CastInto, TryCastFrom, TryCastInto};
14
15use tc_error::*;
16use tc_scalar::{Executor, OpDef, OpDefType, OpRef, Scalar, SELF};
17use tc_transact::hash::{AsyncHash, Digest, Hash, Output, Sha256};
18use tc_transact::public::{DeleteHandler, GetHandler, Handler, PostHandler, PutHandler};
19use tc_transact::{Gateway, IntoView, Transaction, TxnId};
20use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf};
21
22use super::view::StateView;
23use super::{CacheBlock, State};
24
25/// An [`OpDef`] which closes over zero or more [`State`]s
26pub struct Closure<Txn> {
27    context: Map<State<Txn>>,
28    op: OpDef,
29}
30
31impl<Txn> Clone for Closure<Txn> {
32    fn clone(&self) -> Self {
33        Self {
34            context: self.context.clone(),
35            op: self.op.clone(),
36        }
37    }
38}
39
40impl<Txn> Closure<Txn> {
41    /// Return the context and [`OpDef`] which define this `Closure`.
42    pub fn into_inner(self) -> (Map<State<Txn>>, OpDef) {
43        (self.context, self.op)
44    }
45}
46
47impl<Txn> Closure<Txn>
48where
49    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
50{
51    /// Replace references to `$self` with the given `path`.
52    pub fn dereference_self(self, path: &TCPathBuf) -> Self {
53        let mut context = self.context;
54        context.remove::<Id>(&SELF.into());
55
56        let op = self.op.dereference_self::<State<Txn>>(path);
57
58        Self { context, op }
59    }
60
61    /// Return `true` if this `Closure` may write to service other than where it's defined
62    pub fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
63        self.op.is_inter_service_write::<State<Txn>>(cluster_path)
64    }
65
66    /// Replace references to the given `path` with `$self`
67    pub fn reference_self(self, path: &TCPathBuf) -> Self {
68        let before = self.op.clone();
69        let op = self.op.reference_self::<State<Txn>>(path);
70
71        let context = if op == before {
72            self.context
73        } else {
74            let op_ref = OpRef::Get((path.clone().into(), Scalar::default()));
75            let mut context = self.context;
76            context.insert(SELF.into(), op_ref.into());
77            context
78        };
79
80        Self { context, op }
81    }
82
83    /// Execute this `Closure` with the given `args`
84    pub async fn call(self, txn: &Txn, args: State<Txn>) -> TCResult<State<Txn>> {
85        let capture = if let Some(capture) = self.op.last().cloned() {
86            capture
87        } else {
88            return Ok(State::default());
89        };
90
91        let mut context = self.context;
92        let subject = context.remove::<Id>(&SELF.into());
93
94        debug!("call Closure with state {:?} and args {:?}", context, args);
95
96        match self.op {
97            OpDef::Get((key_name, op_def)) => {
98                let key = args.try_cast_into(|s| TCError::unexpected(s, "a Value"))?;
99
100                context.insert(key_name, key);
101
102                Executor::with_context(txn, subject.as_ref(), context, op_def)
103                    .capture(capture)
104                    .await
105            }
106            OpDef::Put((key_name, value_name, op_def)) => {
107                let (key, value) =
108                    args.try_cast_into(|s| TCError::unexpected(s, "arguments for PUT Op"))?;
109
110                context.insert(key_name, key);
111                context.insert(value_name, value);
112
113                Executor::with_context(txn, subject.as_ref(), context, op_def)
114                    .capture(capture)
115                    .await
116            }
117            OpDef::Post(op_def) => {
118                let params: Map<State<Txn>> = args.try_into()?;
119                context.extend(params);
120
121                Executor::with_context(txn, subject.as_ref(), context, op_def)
122                    .capture(capture)
123                    .await
124            }
125            OpDef::Delete((key_name, op_def)) => {
126                let key = args.try_cast_into(|s| TCError::unexpected(s, "a Value"))?;
127                context.insert(key_name, key);
128
129                Executor::with_context(txn, subject.as_ref(), context, op_def)
130                    .capture(capture)
131                    .await
132            }
133        }
134    }
135
136    /// Execute this `Closure` with an owned `txn` and the given `args`.
137    pub async fn call_owned(self, txn: Txn, args: State<Txn>) -> TCResult<State<Txn>> {
138        self.call(&txn, args).await
139    }
140}
141
142#[async_trait]
143impl<Txn> tc_transact::public::ClosureInstance<State<Txn>> for Closure<Txn>
144where
145    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
146{
147    async fn call(self: Box<Self>, txn: Txn, args: State<Txn>) -> TCResult<State<Txn>> {
148        self.call_owned(txn, args).await
149    }
150}
151
152impl<Txn> From<(Map<State<Txn>>, OpDef)> for Closure<Txn> {
153    fn from(tuple: (Map<State<Txn>>, OpDef)) -> Self {
154        let (context, op) = tuple;
155
156        Self { context, op }
157    }
158}
159
160impl<'a, Txn> Handler<'a, State<Txn>> for Closure<Txn>
161where
162    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
163{
164    fn get<'b>(self: Box<Self>) -> Option<GetHandler<'a, 'b, Txn, State<Txn>>>
165    where
166        'b: 'a,
167    {
168        if self.op.class() == OpDefType::Get {
169            Some(Box::new(|txn, key| Box::pin(self.call(txn, key.into()))))
170        } else {
171            None
172        }
173    }
174
175    fn put<'b>(self: Box<Self>) -> Option<PutHandler<'a, 'b, Txn, State<Txn>>>
176    where
177        'b: 'a,
178    {
179        if self.op.class() == OpDefType::Put {
180            Some(Box::new(|txn, key, value| {
181                Box::pin(self.call(txn, (key, value).cast_into()).map_ok(|_| ()))
182            }))
183        } else {
184            None
185        }
186    }
187
188    fn post<'b>(self: Box<Self>) -> Option<PostHandler<'a, 'b, Txn, State<Txn>>>
189    where
190        'b: 'a,
191    {
192        if self.op.class() == OpDefType::Post {
193            Some(Box::new(|txn, params| {
194                Box::pin(self.call(txn, params.into()))
195            }))
196        } else {
197            None
198        }
199    }
200
201    fn delete<'b>(self: Box<Self>) -> Option<DeleteHandler<'a, 'b, Txn>>
202    where
203        'b: 'a,
204    {
205        if self.op.class() == OpDefType::Delete {
206            Some(Box::new(|txn, key| {
207                Box::pin(self.call(txn, key.into()).map_ok(|_| ()))
208            }))
209        } else {
210            None
211        }
212    }
213}
214
215#[async_trait]
216impl<Txn> AsyncHash for Closure<Txn>
217where
218    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
219{
220    async fn hash(&self, txn_id: TxnId) -> TCResult<Output<Sha256>> {
221        let mut hasher = Sha256::default();
222        hasher.update(self.context.hash(txn_id).await?);
223        hasher.update(&Hash::<Sha256>::hash(&self.op));
224        Ok(hasher.finalize())
225    }
226}
227
228#[async_trait]
229impl<'en, Txn> IntoView<'en, CacheBlock> for Closure<Txn>
230where
231    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
232{
233    type Txn = Txn;
234    type View = (HashMap<Id, StateView<'en>>, OpDef);
235
236    async fn into_view(self, txn: Self::Txn) -> TCResult<Self::View> {
237        let mut context = HashMap::with_capacity(self.context.len());
238        let mut resolvers: FuturesUnordered<_> = self
239            .context
240            .into_iter()
241            .map(|(id, state)| state.into_view(txn.clone()).map_ok(|view| (id, view)))
242            .collect();
243
244        while let Some((id, state)) = resolvers.try_next().await? {
245            context.insert(id, state);
246        }
247
248        Ok((context, self.op))
249    }
250}
251
252#[async_trait]
253impl<Txn> de::FromStream for Closure<Txn>
254where
255    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
256{
257    type Context = Txn;
258
259    async fn from_stream<D: de::Decoder>(txn: Txn, decoder: &mut D) -> Result<Self, D::Error> {
260        decoder
261            .decode_seq(ClosureVisitor {
262                txn,
263                phantom: PhantomData,
264            })
265            .await
266    }
267}
268
269impl<Txn> From<OpDef> for Closure<Txn> {
270    fn from(op: OpDef) -> Self {
271        Self {
272            context: Map::default(),
273            op,
274        }
275    }
276}
277
278impl<Txn> TryCastFrom<Scalar> for Closure<Txn> {
279    fn can_cast_from(scalar: &Scalar) -> bool {
280        match scalar {
281            Scalar::Op(_) => true,
282            _ => false,
283        }
284    }
285
286    fn opt_cast_from(scalar: Scalar) -> Option<Self> {
287        match scalar {
288            Scalar::Op(op) => Some(Self {
289                context: Map::default(),
290                op,
291            }),
292            _ => None,
293        }
294    }
295}
296
297impl<Txn> fmt::Debug for Closure<Txn> {
298    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299        write!(f, "closure over {:?}: {:?}", self.context, self.op)
300    }
301}
302
303struct ClosureVisitor<Txn> {
304    txn: Txn,
305    phantom: PhantomData<CacheBlock>,
306}
307
308#[async_trait]
309impl<Txn> de::Visitor for ClosureVisitor<Txn>
310where
311    Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
312{
313    type Value = Closure<Txn>;
314
315    fn expecting() -> &'static str {
316        "a Closure"
317    }
318
319    async fn visit_seq<A: de::SeqAccess>(self, mut seq: A) -> Result<Self::Value, A::Error> {
320        let context = match seq.next_element(self.txn).await? {
321            Some(State::Map(context)) => Ok(context),
322            Some(other) => Err(de::Error::invalid_type(
323                format!("{other:?}"),
324                "a Closure context",
325            )),
326            None => Err(de::Error::invalid_length(0, "a Closure context and Op")),
327        }?;
328
329        let op = seq.next_element(()).await?;
330        let op = op.ok_or_else(|| de::Error::invalid_length(1, "a Closure Op"))?;
331        Ok(Closure { context, op })
332    }
333}