1use std::collections::HashMap;
4use std::convert::TryInto;
5use std::fmt;
6use std::marker::PhantomData;
7
8use async_trait::async_trait;
9use destream::de;
10use futures::future::TryFutureExt;
11use futures::stream::{FuturesUnordered, TryStreamExt};
12use log::debug;
13use safecast::{CastInto, TryCastFrom, TryCastInto};
14
15use tc_error::*;
16use tc_scalar::{Executor, OpDef, OpDefType, OpRef, Scalar, SELF};
17use tc_transact::hash::{AsyncHash, Digest, Hash, Output, Sha256};
18use tc_transact::public::{DeleteHandler, GetHandler, Handler, PostHandler, PutHandler};
19use tc_transact::{Gateway, IntoView, Transaction, TxnId};
20use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf};
21
22use super::view::StateView;
23use super::{CacheBlock, State};
24
25pub struct Closure<Txn> {
27 context: Map<State<Txn>>,
28 op: OpDef,
29}
30
31impl<Txn> Clone for Closure<Txn> {
32 fn clone(&self) -> Self {
33 Self {
34 context: self.context.clone(),
35 op: self.op.clone(),
36 }
37 }
38}
39
40impl<Txn> Closure<Txn> {
41 pub fn into_inner(self) -> (Map<State<Txn>>, OpDef) {
43 (self.context, self.op)
44 }
45}
46
47impl<Txn> Closure<Txn>
48where
49 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
50{
51 pub fn dereference_self(self, path: &TCPathBuf) -> Self {
53 let mut context = self.context;
54 context.remove::<Id>(&SELF.into());
55
56 let op = self.op.dereference_self::<State<Txn>>(path);
57
58 Self { context, op }
59 }
60
61 pub fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
63 self.op.is_inter_service_write::<State<Txn>>(cluster_path)
64 }
65
66 pub fn reference_self(self, path: &TCPathBuf) -> Self {
68 let before = self.op.clone();
69 let op = self.op.reference_self::<State<Txn>>(path);
70
71 let context = if op == before {
72 self.context
73 } else {
74 let op_ref = OpRef::Get((path.clone().into(), Scalar::default()));
75 let mut context = self.context;
76 context.insert(SELF.into(), op_ref.into());
77 context
78 };
79
80 Self { context, op }
81 }
82
83 pub async fn call(self, txn: &Txn, args: State<Txn>) -> TCResult<State<Txn>> {
85 let capture = if let Some(capture) = self.op.last().cloned() {
86 capture
87 } else {
88 return Ok(State::default());
89 };
90
91 let mut context = self.context;
92 let subject = context.remove::<Id>(&SELF.into());
93
94 debug!("call Closure with state {:?} and args {:?}", context, args);
95
96 match self.op {
97 OpDef::Get((key_name, op_def)) => {
98 let key = args.try_cast_into(|s| TCError::unexpected(s, "a Value"))?;
99
100 context.insert(key_name, key);
101
102 Executor::with_context(txn, subject.as_ref(), context, op_def)
103 .capture(capture)
104 .await
105 }
106 OpDef::Put((key_name, value_name, op_def)) => {
107 let (key, value) =
108 args.try_cast_into(|s| TCError::unexpected(s, "arguments for PUT Op"))?;
109
110 context.insert(key_name, key);
111 context.insert(value_name, value);
112
113 Executor::with_context(txn, subject.as_ref(), context, op_def)
114 .capture(capture)
115 .await
116 }
117 OpDef::Post(op_def) => {
118 let params: Map<State<Txn>> = args.try_into()?;
119 context.extend(params);
120
121 Executor::with_context(txn, subject.as_ref(), context, op_def)
122 .capture(capture)
123 .await
124 }
125 OpDef::Delete((key_name, op_def)) => {
126 let key = args.try_cast_into(|s| TCError::unexpected(s, "a Value"))?;
127 context.insert(key_name, key);
128
129 Executor::with_context(txn, subject.as_ref(), context, op_def)
130 .capture(capture)
131 .await
132 }
133 }
134 }
135
136 pub async fn call_owned(self, txn: Txn, args: State<Txn>) -> TCResult<State<Txn>> {
138 self.call(&txn, args).await
139 }
140}
141
142#[async_trait]
143impl<Txn> tc_transact::public::ClosureInstance<State<Txn>> for Closure<Txn>
144where
145 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
146{
147 async fn call(self: Box<Self>, txn: Txn, args: State<Txn>) -> TCResult<State<Txn>> {
148 self.call_owned(txn, args).await
149 }
150}
151
152impl<Txn> From<(Map<State<Txn>>, OpDef)> for Closure<Txn> {
153 fn from(tuple: (Map<State<Txn>>, OpDef)) -> Self {
154 let (context, op) = tuple;
155
156 Self { context, op }
157 }
158}
159
160impl<'a, Txn> Handler<'a, State<Txn>> for Closure<Txn>
161where
162 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
163{
164 fn get<'b>(self: Box<Self>) -> Option<GetHandler<'a, 'b, Txn, State<Txn>>>
165 where
166 'b: 'a,
167 {
168 if self.op.class() == OpDefType::Get {
169 Some(Box::new(|txn, key| Box::pin(self.call(txn, key.into()))))
170 } else {
171 None
172 }
173 }
174
175 fn put<'b>(self: Box<Self>) -> Option<PutHandler<'a, 'b, Txn, State<Txn>>>
176 where
177 'b: 'a,
178 {
179 if self.op.class() == OpDefType::Put {
180 Some(Box::new(|txn, key, value| {
181 Box::pin(self.call(txn, (key, value).cast_into()).map_ok(|_| ()))
182 }))
183 } else {
184 None
185 }
186 }
187
188 fn post<'b>(self: Box<Self>) -> Option<PostHandler<'a, 'b, Txn, State<Txn>>>
189 where
190 'b: 'a,
191 {
192 if self.op.class() == OpDefType::Post {
193 Some(Box::new(|txn, params| {
194 Box::pin(self.call(txn, params.into()))
195 }))
196 } else {
197 None
198 }
199 }
200
201 fn delete<'b>(self: Box<Self>) -> Option<DeleteHandler<'a, 'b, Txn>>
202 where
203 'b: 'a,
204 {
205 if self.op.class() == OpDefType::Delete {
206 Some(Box::new(|txn, key| {
207 Box::pin(self.call(txn, key.into()).map_ok(|_| ()))
208 }))
209 } else {
210 None
211 }
212 }
213}
214
215#[async_trait]
216impl<Txn> AsyncHash for Closure<Txn>
217where
218 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
219{
220 async fn hash(&self, txn_id: TxnId) -> TCResult<Output<Sha256>> {
221 let mut hasher = Sha256::default();
222 hasher.update(self.context.hash(txn_id).await?);
223 hasher.update(&Hash::<Sha256>::hash(&self.op));
224 Ok(hasher.finalize())
225 }
226}
227
228#[async_trait]
229impl<'en, Txn> IntoView<'en, CacheBlock> for Closure<Txn>
230where
231 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
232{
233 type Txn = Txn;
234 type View = (HashMap<Id, StateView<'en>>, OpDef);
235
236 async fn into_view(self, txn: Self::Txn) -> TCResult<Self::View> {
237 let mut context = HashMap::with_capacity(self.context.len());
238 let mut resolvers: FuturesUnordered<_> = self
239 .context
240 .into_iter()
241 .map(|(id, state)| state.into_view(txn.clone()).map_ok(|view| (id, view)))
242 .collect();
243
244 while let Some((id, state)) = resolvers.try_next().await? {
245 context.insert(id, state);
246 }
247
248 Ok((context, self.op))
249 }
250}
251
252#[async_trait]
253impl<Txn> de::FromStream for Closure<Txn>
254where
255 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
256{
257 type Context = Txn;
258
259 async fn from_stream<D: de::Decoder>(txn: Txn, decoder: &mut D) -> Result<Self, D::Error> {
260 decoder
261 .decode_seq(ClosureVisitor {
262 txn,
263 phantom: PhantomData,
264 })
265 .await
266 }
267}
268
269impl<Txn> From<OpDef> for Closure<Txn> {
270 fn from(op: OpDef) -> Self {
271 Self {
272 context: Map::default(),
273 op,
274 }
275 }
276}
277
278impl<Txn> TryCastFrom<Scalar> for Closure<Txn> {
279 fn can_cast_from(scalar: &Scalar) -> bool {
280 match scalar {
281 Scalar::Op(_) => true,
282 _ => false,
283 }
284 }
285
286 fn opt_cast_from(scalar: Scalar) -> Option<Self> {
287 match scalar {
288 Scalar::Op(op) => Some(Self {
289 context: Map::default(),
290 op,
291 }),
292 _ => None,
293 }
294 }
295}
296
297impl<Txn> fmt::Debug for Closure<Txn> {
298 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299 write!(f, "closure over {:?}: {:?}", self.context, self.op)
300 }
301}
302
303struct ClosureVisitor<Txn> {
304 txn: Txn,
305 phantom: PhantomData<CacheBlock>,
306}
307
308#[async_trait]
309impl<Txn> de::Visitor for ClosureVisitor<Txn>
310where
311 Txn: Transaction<CacheBlock> + Gateway<State<Txn>>,
312{
313 type Value = Closure<Txn>;
314
315 fn expecting() -> &'static str {
316 "a Closure"
317 }
318
319 async fn visit_seq<A: de::SeqAccess>(self, mut seq: A) -> Result<Self::Value, A::Error> {
320 let context = match seq.next_element(self.txn).await? {
321 Some(State::Map(context)) => Ok(context),
322 Some(other) => Err(de::Error::invalid_type(
323 format!("{other:?}"),
324 "a Closure context",
325 )),
326 None => Err(de::Error::invalid_length(0, "a Closure context and Op")),
327 }?;
328
329 let op = seq.next_element(()).await?;
330 let op = op.ok_or_else(|| de::Error::invalid_length(1, "a Closure Op"))?;
331 Ok(Closure { context, op })
332 }
333}