Skip to main content

tierkreis_runtime/operations/
mod.rs

1//! Implements operations as processes each acting asynchronously on a stream
2//! of inputs to produce a stream of outputs.
3use self::eval::run_eval;
4use self::function::run_fn;
5use self::graph::GraphOperation;
6use self::r#box::run_box;
7use crate::util::JoinHandleWithDrop;
8use crate::workers::EscapeHatch;
9use crate::Runtime;
10use anyhow::{anyhow, bail};
11use futures::future::{self, AbortHandle, BoxFuture};
12use futures::stream::BoxStream;
13use futures::{Future, FutureExt, Stream, StreamExt};
14use std::collections::HashSet;
15use std::pin::Pin;
16use std::{collections::HashMap, sync::Arc};
17use thiserror::Error;
18use tierkreis_core::graph::{Edge, Graph, GraphBuilder, Node, Value};
19use tierkreis_core::prelude::TryInto;
20use tierkreis_core::symbol::{Label, Location, SymbolError};
21use tierkreis_proto::messages::{Callback, Completed, GraphTrace, Status};
22use tokio::sync::mpsc;
23use tokio::sync::watch;
24use tokio_stream::wrappers::UnboundedReceiverStream;
25use tracing::Instrument;
26
27pub(crate) mod r#box;
28pub(crate) mod eval;
29pub(crate) mod function;
30pub(crate) mod graph;
31pub(crate) mod variant;
32
33pub use graph::checkpoint_client::CheckpointClient;
34
35/// A stream of inputs (one per wire) arriving at an operation,
36/// plus a notification after the last input (wire) has arrived.
37pub struct OperationInputs(BoxStream<'static, Input>);
38
39impl Stream for OperationInputs {
40    type Item = Input;
41
42    fn poll_next(
43        mut self: std::pin::Pin<&mut Self>,
44        cx: &mut std::task::Context<'_>,
45    ) -> std::task::Poll<Option<Self::Item>> {
46        Pin::new(&mut self.as_mut().0).poll_next(cx)
47    }
48}
49
50/// Stream of outputs generated by a running operation.
51///
52/// When this stream is dropped the operation will be cancelled.
53pub struct OperationOutputs {
54    stream: BoxStream<'static, Output>,
55    abort_input: AbortHandle,
56}
57
58impl OperationOutputs {
59    /// Views the operation as a process
60    pub fn into_task(self) -> TaskHandle {
61        TaskHandle::new(self)
62    }
63}
64
65impl Drop for OperationOutputs {
66    fn drop(&mut self) {
67        self.abort_input.abort()
68    }
69}
70
71impl Stream for OperationOutputs {
72    type Item = Output;
73
74    fn poll_next(
75        mut self: Pin<&mut Self>,
76        cx: &mut std::task::Context<'_>,
77    ) -> std::task::Poll<Option<Self::Item>> {
78        Pin::new(&mut self.as_mut().stream).poll_next(cx)
79    }
80}
81
82/// A single operation that can be started (executed asynchronously) at runtime,
83/// given some inputs and a context including a stream of outputs. Does not
84/// necessarily have to be related to a specific node in a Tierkreis graph,
85/// although generally it is.
86pub struct RuntimeOperation {
87    start: Box<
88        dyn FnOnce(OperationContext, OperationInputs) -> BoxFuture<'static, anyhow::Result<()>>
89            + Send,
90    >,
91}
92
93impl RuntimeOperation {
94    fn new<F, FF>(f: F) -> Self
95    where
96        F: FnOnce(OperationContext, OperationInputs) -> FF + Send + 'static,
97        FF: Future<Output = anyhow::Result<()>> + Send + 'static,
98    {
99        Self {
100            start: Box::new(|ctx, inputs| f(ctx, inputs).boxed()),
101        }
102    }
103
104    /// Creates a RuntimeOperation that just produces a constant value (given).
105    pub fn new_const(value: Value) -> RuntimeOperation {
106        operation_const(value)
107    }
108
109    /// Creates a RuntimeOperation that runs a graph (found in a box) at some location,
110    /// i.e. perhaps remotely.
111    /// Note: if `loc` is [Location::local] then this is largely equivalent to [Self::new_graph].
112    pub fn new_box(loc: Location, graph: Graph) -> RuntimeOperation {
113        RuntimeOperation::new(|ctx, inputs| run_box(loc, graph, ctx, inputs))
114    }
115
116    /// Creates a RuntimeOperation that runs a graph (in the local runtime, interleaved
117    /// with other nodes/etc. executing). Inputs to the operation will become outputs of
118    /// the graph Input node; the operation outputs are those received by the graph Output node.    
119    pub fn new_graph(graph: Graph) -> RuntimeOperation {
120        RuntimeOperation::new(move |ctx, inputs| GraphOperation::new(graph, ctx, inputs).run())
121    }
122
123    /// Creates a RuntimeOperation that runs a `match` operation, i.e. expects
124    /// to be given a variant value and a Graph for each *possible* variant, and
125    /// runs the appropriate one of those Graphs.
126    pub(crate) fn new_match() -> RuntimeOperation {
127        RuntimeOperation::new(variant::run_match)
128    }
129
130    /// Creates a RuntimeOperation that tags a value to make a [Value::Variant]
131    pub(crate) fn new_tag(tag: Label) -> RuntimeOperation {
132        RuntimeOperation::new(move |ctx, inputs| variant::run_tag(tag, ctx, inputs))
133    }
134
135    /// Create a new runtime operation from a simple function.
136    ///
137    /// Do not use this for blocking or expensive CPU bound functions.
138    pub fn new_fn_simple<F>(f: F) -> Self
139    where
140        F: FnOnce(HashMap<Label, Value>, OperationContext) -> anyhow::Result<HashMap<Label, Value>>
141            + Send
142            + 'static,
143    {
144        let f = |inputs, ctx| futures::future::ready(f(inputs, ctx));
145        RuntimeOperation::new(move |ctx, inputs| run_fn(f, ctx, inputs))
146    }
147
148    /// Create a new runtime operation from an async function.
149    ///
150    /// Do not use this for blocking or expensive CPU bound functions.
151    pub fn new_fn_async<F, FF>(f: F) -> Self
152    where
153        F: FnOnce(HashMap<Label, Value>, OperationContext) -> FF + Send + 'static,
154        FF: Future<Output = anyhow::Result<HashMap<Label, Value>>> + Send + 'static,
155    {
156        let f = |inputs, ctx| {
157            let span = tracing::Span::current();
158            let handle = tokio::spawn(f(inputs, ctx).instrument(span));
159            JoinHandleWithDrop::from(handle).map(|r| r.unwrap_or_else(|e| Err(e.into())))
160        };
161        RuntimeOperation::new(move |ctx, inputs| run_fn(f, ctx, inputs))
162    }
163
164    /// Create a new runtime operation from a blocking function.
165    ///
166    /// The function is run via [`tokio::task::spawn_blocking`].
167    pub fn new_fn_blocking<F>(f: F) -> Self
168    where
169        F: FnOnce(HashMap<Label, Value>, OperationContext) -> anyhow::Result<HashMap<Label, Value>>
170            + Send
171            + 'static,
172    {
173        let f = |inputs, ctx| {
174            let span = tracing::Span::current();
175            let handle = tokio::task::spawn_blocking(move || span.in_scope(|| f(inputs, ctx)));
176            JoinHandleWithDrop::from(handle).map(|r| r.unwrap_or_else(|e| Err(e.into())))
177        };
178        RuntimeOperation::new(move |ctx, inputs| run_fn(f, ctx, inputs))
179    }
180
181    /// Starts the operation i.e. so it will execute asynchronously (using `tokio::spawn`).
182    /// Returns the stream of outputs.
183    pub fn run<S>(
184        self,
185        runtime: Runtime,
186        callback: Callback,
187        escape: EscapeHatch,
188        inputs: S,
189        stack_trace: GraphTrace,
190        checkpoint_client: Option<CheckpointClient>,
191    ) -> OperationOutputs
192    where
193        S: Stream<Item = Input> + Send + 'static,
194    {
195        let inputs = inputs.chain(futures::stream::pending());
196        let (inputs, abort_handle) = futures::stream::abortable(inputs);
197
198        let (output_tx, output_rx) = mpsc::unbounded_channel();
199
200        let context = OperationContext {
201            output: output_tx.clone(),
202            callback,
203            escape,
204            runtime,
205            graph_trace: stack_trace,
206            checkpoint_client,
207        };
208
209        let span = tracing::Span::current();
210
211        // TODO: Use the join handle to react to panics
212        tokio::spawn(
213            async move {
214                let result = (self.start)(context, OperationInputs(inputs.boxed())).await;
215
216                let _ = match result {
217                    Ok(()) => output_tx.send(Output::Success),
218                    Err(error) => output_tx.send(Output::Failure { error }),
219                };
220            }
221            .instrument(span),
222        );
223
224        OperationOutputs {
225            stream: UnboundedReceiverStream::new(output_rx).boxed(),
226            abort_input: abort_handle,
227        }
228    }
229
230    /// Like [`RuntimeOperation::run`] but the inputs are available immediately.
231    pub fn run_simple<I>(
232        self,
233        runtime: Runtime,
234        callback: Callback,
235        escape: EscapeHatch,
236        inputs: I,
237        stack_trace: GraphTrace,
238        checkpoint_client: Option<CheckpointClient>,
239    ) -> OperationOutputs
240    where
241        I: IntoIterator<Item = (Label, Value)>,
242    {
243        let inputs: HashMap<_, _> = inputs.into_iter().collect();
244
245        let inputs_stream = futures::stream::iter(inputs)
246            .map(|(port, value)| Input::Input { port, value })
247            .chain(tokio_stream::once(Input::Complete));
248
249        self.run(
250            runtime,
251            callback,
252            escape,
253            inputs_stream,
254            stack_trace,
255            checkpoint_client,
256        )
257    }
258}
259
260/// A task-level view of a process, allowing only to poll or block for completion
261/// and to request the task to abort, rather than a stream of results.
262#[derive(Clone)]
263pub struct TaskHandle {
264    status: watch::Receiver<Status>,
265    abort: mpsc::UnboundedSender<()>,
266}
267
268impl std::fmt::Debug for TaskHandle {
269    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
270        f.debug_struct("TaskHandle").finish()
271    }
272}
273
274impl TaskHandle {
275    fn new(mut output_stream: OperationOutputs) -> Self {
276        let (status_tx, status_rx) = watch::channel(Status::Running);
277        let (abort_tx, mut abort_rx) = mpsc::unbounded_channel();
278
279        tokio::spawn(async move {
280            let mut outputs = HashMap::new();
281
282            loop {
283                let output = tokio::select! { biased;
284                    _ = abort_rx.recv() => break,
285                    msg = output_stream.next() => {
286                        match msg {
287                            Some(msg) => msg,
288                            None => break,
289                        }
290                    }
291                };
292
293                match output {
294                    Output::Output { port, value } => {
295                        outputs.insert(port, value);
296                    }
297                    Output::Success => {
298                        let _ = status_tx.send(Status::Completed(Ok(Arc::new(outputs))));
299
300                        return;
301                    }
302                    Output::Failure { error } => {
303                        let _ = status_tx.send(Status::Completed(Err(Arc::new(error))));
304                        return;
305                    }
306                }
307            }
308
309            let error = anyhow!("task was cancelled");
310            let _ = status_tx.send(Status::Completed(Err(Arc::new(error))));
311        });
312
313        Self {
314            status: status_rx,
315            abort: abort_tx,
316        }
317    }
318
319    /// Polls the current status of the process.
320    /// (Note this copies the entire result map, if there is one.)
321    pub fn status(&self) -> Status {
322        self.status.borrow().clone()
323    }
324
325    /// Cancels a running task even when there are other copies of the task handle.
326    pub fn cancel(&self) {
327        let _ = self.abort.send(());
328    }
329
330    /// Waits for the task to complete and returns the result (outputs or error)
331    pub async fn complete(&mut self) -> Completed {
332        loop {
333            if let Status::Completed(result) = self.status() {
334                match result {
335                    Ok(_) => tracing::debug!("complete"),
336                    Err(_) => tracing::warn!("complete with error"),
337                }
338                return result;
339            }
340            tracing::debug!("still running");
341
342            let event = self.status.changed().await;
343
344            if event.is_err() {
345                unreachable!("watch was closed before setting a completed status");
346            }
347        }
348    }
349}
350
351/// An event or signal sent to a process about its inputs
352#[derive(Debug)]
353pub enum Input {
354    /// A value is available at an input port.
355    #[allow(missing_docs)]
356    Input { port: Label, value: Value },
357    /// Values for all input ports have been sent.
358    Complete,
359}
360
361/// An event or signal sent from a process about its outputs
362#[derive(Debug)]
363pub enum Output {
364    /// A value is available at an output port.
365    #[allow(missing_docs)]
366    Output { port: Label, value: Value },
367    /// The operation has finished processing successfully.
368    Success,
369    /// The operation has finished processing with an error.
370    #[allow(missing_docs)]
371    Failure { error: anyhow::Error },
372}
373
374impl Output {
375    /// Adds a line to the error trace that will be displayed if this is an [Output::Failure]
376    pub fn context<C>(self, context: C) -> Self
377    where
378        C: std::fmt::Display + Send + Sync + 'static,
379    {
380        match self {
381            Self::Output { port, value } => Self::Output { port, value },
382            Self::Success => Self::Success,
383            Self::Failure { error } => {
384                let error = error.context(context);
385                Self::Failure { error }
386            }
387        }
388    }
389}
390
391/// The context available to an actor running an operation.
392#[derive(Clone)]
393pub struct OperationContext {
394    output: mpsc::UnboundedSender<Output>,
395    /// Runtime within which the operation executes; allows access to features
396    /// outside the particular graph being run, such as connected workers.
397    pub runtime: Runtime,
398    /// FunctionWorkers (executing run_function requests) can use this to make run_graph
399    /// calls (also infer_type and other RuntimeWorker operations, but not run_function).
400    /// The Callback will route requests back up the tree of runtimes to the
401    /// closest explicitly-enclosing Scope specified by the user's Graph.
402    pub callback: Callback,
403    /// Runtimes use this to run any function(-name)s they don't recognize themselves;
404    /// it will route requests back up to the original root  of the client request.
405    pub escape: EscapeHatch,
406    /// The trace of evaluations by which we came to be executing this operation (node *or* graph)
407    pub graph_trace: GraphTrace,
408    /// The checkpointing client for the current job.
409    pub checkpoint_client: Option<CheckpointClient>,
410}
411
412impl OperationContext {
413    /// Notify the graph actor that a value is available at one of the node's output ports.
414    pub fn set_output(&self, port: impl Into<Label>, value: Value) {
415        let _ = self.output.send(Output::Output {
416            port: port.into(),
417            value,
418        });
419    }
420
421    // check if the graph is the outermost graph and has a checkpoint client
422    fn outer_graph_checkpoint(&mut self) -> Option<&mut CheckpointClient> {
423        if self.graph_trace == GraphTrace::Root {
424            self.checkpoint_client.as_mut()
425        } else {
426            None
427        }
428    }
429}
430
431pub(crate) fn operation_eval() -> RuntimeOperation {
432    RuntimeOperation::new(run_eval)
433}
434
435pub(crate) fn operation_const(value: Value) -> RuntimeOperation {
436    RuntimeOperation::new_fn_simple(move |_inputs, _context| {
437        let mut outputs = HashMap::new();
438        outputs.insert(Label::value(), value);
439        Ok(outputs)
440    })
441}
442
443pub(crate) fn operation_id() -> RuntimeOperation {
444    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
445        let mut outputs = HashMap::new();
446        outputs.insert(Label::value(), take_input(&mut inputs, Label::value())?);
447        Ok(outputs)
448    })
449}
450
451pub(crate) fn operation_sleep() -> RuntimeOperation {
452    RuntimeOperation::new_fn_async(|mut inputs, _context| async move {
453        let delay = validate_float_input(&mut inputs, "delay_secs")?;
454        tokio::time::sleep(tokio::time::Duration::from_secs_f64(delay)).await;
455        let mut outputs = HashMap::new();
456        outputs.insert(Label::value(), take_input(&mut inputs, Label::value())?);
457        Ok(outputs)
458    })
459}
460
461pub(crate) fn operation_copy() -> RuntimeOperation {
462    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
463        let value = take_input(&mut inputs, Label::value())?;
464        let mut outputs = HashMap::new();
465        outputs.insert(TryInto::try_into("value_0")?, value.clone());
466        outputs.insert(TryInto::try_into("value_1")?, value);
467        Ok(outputs)
468    })
469}
470
471pub(crate) fn operation_discard() -> RuntimeOperation {
472    RuntimeOperation::new_fn_simple(|_inputs, _context| Ok(HashMap::new()))
473}
474
475pub(crate) fn operation_equality() -> RuntimeOperation {
476    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
477        let val0 = take_input(&mut inputs, "value_0")?;
478        let val1 = take_input(&mut inputs, "value_1")?;
479        let mut outputs = HashMap::new();
480        outputs.insert(TryInto::try_into("result")?, Value::Bool(val0 == val1));
481        Ok(outputs)
482    })
483}
484
485pub(crate) fn operation_not_equality() -> RuntimeOperation {
486    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
487        let val0 = take_input(&mut inputs, "value_0")?;
488        let val1 = take_input(&mut inputs, "value_1")?;
489        let mut outputs = HashMap::new();
490        outputs.insert(TryInto::try_into("result")?, Value::Bool(val0 != val1));
491        Ok(outputs)
492    })
493}
494
495pub(crate) fn operation_switch() -> RuntimeOperation {
496    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
497        let predicate = validate_bool_input(&mut inputs, "pred")?;
498
499        let branch_true = take_input(&mut inputs, "if_true")?;
500        let branch_false = take_input(&mut inputs, "if_false")?;
501
502        let result = if predicate { branch_true } else { branch_false };
503
504        let mut outputs = HashMap::new();
505        outputs.insert(Label::value(), result);
506        Ok(outputs)
507    })
508}
509
510pub(crate) fn operation_make_pair() -> RuntimeOperation {
511    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
512        let first = take_input(&mut inputs, "first")?;
513        let second = take_input(&mut inputs, "second")?;
514
515        let mut outputs = HashMap::new();
516        outputs.insert(
517            TryInto::try_into("pair")?,
518            Value::Pair(Box::new((first, second))),
519        );
520        Ok(outputs)
521    })
522}
523
524pub(crate) fn operation_unpack_pair() -> RuntimeOperation {
525    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
526        let (first, second) = validate_input(&mut inputs, "pair", |x| match x {
527            Value::Pair(pair) => Some((pair.0, pair.1)),
528            _ => None,
529        })?;
530
531        let mut outputs = HashMap::new();
532        outputs.insert(TryInto::try_into("first")?, first);
533        outputs.insert(TryInto::try_into("second")?, second);
534        Ok(outputs)
535    })
536}
537
538pub(crate) fn operation_push() -> RuntimeOperation {
539    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
540        let vec_l: Label = TryInto::try_into("vec")?;
541        let mut vec = match take_input(&mut inputs, vec_l)? {
542            Value::Vec(vec) => vec,
543            _ => bail!("Push function expected vector input."),
544        };
545
546        let item = take_input(&mut inputs, "item")?;
547
548        vec.push(item);
549        let mut outputs = HashMap::new();
550        outputs.insert(vec_l, Value::Vec(vec));
551        Ok(outputs)
552    })
553}
554
555pub(crate) fn operation_pop() -> RuntimeOperation {
556    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
557        let vec_l: Label = TryInto::try_into("vec")?;
558        let mut vec = match take_input(&mut inputs, vec_l)? {
559            Value::Vec(vec) => {
560                if vec.is_empty() {
561                    Err(RuntimeError::EmptyVector)
562                } else {
563                    Ok(vec)
564                }
565            }
566            _ => Err(RuntimeError::InvalidInput(vec_l)),
567        }?;
568
569        let item = vec.pop().unwrap();
570        let mut outputs = HashMap::new();
571        outputs.insert(TryInto::try_into("item")?, item);
572        outputs.insert(vec_l, Value::Vec(vec));
573
574        Ok(outputs)
575    })
576}
577
578pub(crate) fn operation_loop() -> RuntimeOperation {
579    RuntimeOperation::new_fn_async(|mut inputs, context| async move {
580        let _ = &context;
581        let body = validate_graph_input(&mut inputs, "body")?;
582        let mut value = take_input(&mut inputs, Label::value())?;
583
584        let node_trace = context
585            .graph_trace
586            .as_node_trace()
587            .map_err(|_| anyhow!("loop function expected stack trace to correspond to a node"))?;
588        for iteration in 1.. {
589            // Run body
590            let graph_trace = node_trace.clone().loop_iter(iteration);
591
592            let body_output = RuntimeOperation::new_graph(body.clone())
593                .run_simple(
594                    context.runtime.clone(),
595                    context.callback.clone(),
596                    context.escape.clone(),
597                    [(Label::value(), value)],
598                    graph_trace,
599                    context.checkpoint_client.clone(),
600                )
601                .into_task()
602                .complete()
603                .await
604                .map_err(|err| {
605                    // Formatting with debug gets us more/deeper context than anyhow! alone
606                    let e = anyhow!(format!("{:?}", err.as_ref()));
607                    e.context(format!("loop body (iteration {})", iteration))
608                })?;
609
610            let body_output = body_output.as_ref();
611            if let Some(Value::Variant(label, b)) = body_output.get(&Label::value()) {
612                value = *b.clone();
613                if label == &Label::continue_() {
614                    continue;
615                } else if label == &Label::break_() {
616                    break;
617                }
618            };
619            // Both expected variants will have jumped elsewhere
620            bail!(
621                "loop node expected body to output a Variant (break | continue) on port 'value' (iteration {})",
622                iteration
623            )
624        }
625        Ok(HashMap::from([(Label::value(), value)]))
626    })
627}
628
629pub(crate) fn operation_sequence() -> RuntimeOperation {
630    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
631        let first = validate_graph_input(&mut inputs, "first")?;
632
633        let second = validate_graph_input(&mut inputs, "second")?;
634
635        let g3 = {
636            let mut builder = GraphBuilder::new();
637            let [input, output] = Graph::boundary();
638            let inputs: Vec<Edge> = first.node_outputs(input).cloned().collect();
639            let second_input_ports: HashSet<_> =
640                second.node_outputs(input).map(|e| e.source.port).collect();
641            // only wire up shared ports
642            let middle: Vec<Edge> = first
643                .node_inputs(output)
644                .filter(|e| second_input_ports.contains(&e.target.port))
645                .cloned()
646                .collect();
647            let outputs: Vec<Edge> = second.node_inputs(output).cloned().collect();
648
649            let b1 = builder.add_node(Node::local_box(first))?;
650            let b2 = builder.add_node(Node::local_box(second))?;
651
652            for input_edge in inputs {
653                builder.add_edge(
654                    (input, input_edge.source.port),
655                    (b1, input_edge.source.port),
656                    input_edge.edge_type,
657                )?;
658            }
659            for seq_edge in middle {
660                builder.add_edge(
661                    (b1, seq_edge.target.port),
662                    (b2, seq_edge.target.port),
663                    seq_edge.edge_type,
664                )?;
665            }
666
667            for output_edge in outputs {
668                builder.add_edge(
669                    (b2, output_edge.target.port),
670                    (output, output_edge.target.port),
671                    output_edge.edge_type,
672                )?;
673            }
674            builder.build()?
675        };
676
677        let mut outputs = HashMap::new();
678        outputs.insert(TryInto::try_into("sequenced")?, Value::Graph(g3));
679
680        Ok(outputs)
681    })
682}
683
684pub(crate) fn operation_parallel() -> RuntimeOperation {
685    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
686        let left = validate_graph_input(&mut inputs, "left")?;
687
688        let right = validate_graph_input(&mut inputs, "right")?;
689
690        let g3 = {
691            let mut builder = GraphBuilder::new();
692            let [input, output] = Graph::boundary();
693
694            let inputs_left: Vec<Edge> = left.node_outputs(input).cloned().collect();
695            let inputs_right: Vec<Edge> = right.node_outputs(input).cloned().collect();
696            let outputs_left: Vec<Edge> = left.node_inputs(output).cloned().collect();
697            let outputs_right: Vec<Edge> = right.node_inputs(output).cloned().collect();
698
699            let b_left = builder.add_node(Node::local_box(left))?;
700            let b_right = builder.add_node(Node::local_box(right))?;
701
702            for left_input_edge in inputs_left {
703                builder.add_edge(
704                    (input, left_input_edge.source.port),
705                    (b_left, left_input_edge.source.port),
706                    left_input_edge.edge_type,
707                )?;
708            }
709
710            for right_input_edge in inputs_right {
711                builder.add_edge(
712                    (input, right_input_edge.source.port),
713                    (b_right, right_input_edge.source.port),
714                    right_input_edge.edge_type,
715                )?;
716            }
717
718            for left_output_edge in outputs_left {
719                builder.add_edge(
720                    (b_left, left_output_edge.target.port),
721                    (output, left_output_edge.target.port),
722                    left_output_edge.edge_type,
723                )?;
724            }
725
726            for right_output_edge in outputs_right {
727                builder.add_edge(
728                    (b_right, right_output_edge.target.port),
729                    (output, right_output_edge.target.port),
730                    right_output_edge.edge_type,
731                )?;
732            }
733
734            builder.build()?
735        };
736
737        let outputs = HashMap::from([(Label::value(), Value::Graph(g3))]);
738
739        Ok(outputs)
740    })
741}
742
743pub(crate) fn operation_make_struct() -> RuntimeOperation {
744    RuntimeOperation::new_fn_simple(|inputs, _context| {
745        let struc = Value::Struct(inputs);
746        let mut outputs = HashMap::new();
747        outputs.insert(TryInto::try_into("struct")?, struc);
748        Ok(outputs)
749    })
750}
751
752pub(crate) fn operation_unpack_struct() -> RuntimeOperation {
753    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
754        let outputs = validate_input(&mut inputs, "struct", |x| match x {
755            Value::Struct(fields) => Some(fields),
756            _ => None,
757        })?;
758        Ok(outputs)
759    })
760}
761
762pub(crate) fn operation_insert_key() -> RuntimeOperation {
763    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
764        let map_l: Label = TryInto::try_into("map")?;
765        let mut map = validate_input(&mut inputs, map_l, |x| match x {
766            Value::Map(map) => Some(map),
767            _ => None,
768        })?;
769
770        let key = take_input(&mut inputs, "key")?;
771        let val = take_input(&mut inputs, "val")?;
772
773        map.insert(key, val);
774        let mut outputs = HashMap::new();
775        outputs.insert(map_l, Value::Map(map));
776        Ok(outputs)
777    })
778}
779
780pub(crate) fn operation_remove_key() -> RuntimeOperation {
781    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
782        let map_l: Label = TryInto::try_into("map")?;
783        let mut map = validate_input(&mut inputs, map_l, |x| match x {
784            Value::Map(map) => Some(map),
785            _ => None,
786        })?;
787
788        let key = take_input(&mut inputs, "key")?;
789        let val = map.remove(&key).ok_or(RuntimeError::KeyNotFound(key))?;
790        let mut outputs = HashMap::new();
791        outputs.insert(map_l, Value::Map(map));
792        outputs.insert(TryInto::try_into("val")?, val);
793        Ok(outputs)
794    })
795}
796
797// numeric operations
798
799pub(crate) fn binary_int_operation_with_error<F>(f: F) -> RuntimeOperation
800where
801    F: FnOnce(i64, i64) -> anyhow::Result<i64> + Sync + Send + 'static,
802{
803    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
804        let a = validate_int_input(&mut inputs, "a")?;
805        let b = validate_int_input(&mut inputs, "b")?;
806
807        let result = f(a, b)?;
808        let mut outputs = HashMap::new();
809        outputs.insert(Label::value(), Value::Int(result));
810        Ok(outputs)
811    })
812}
813
814pub(crate) fn binary_int_operation<F>(f: F) -> RuntimeOperation
815where
816    F: FnOnce(i64, i64) -> i64 + Sync + Send + 'static,
817{
818    binary_int_operation_with_error(|a, b| Ok(f(a, b)))
819}
820
821pub(crate) fn binary_flt_operation<F>(f: F) -> RuntimeOperation
822where
823    F: FnOnce(f64, f64) -> f64 + Sync + Send + 'static,
824{
825    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
826        let a = validate_float_input(&mut inputs, "a")?;
827        let b = validate_float_input(&mut inputs, "b")?;
828
829        let mut outputs = HashMap::new();
830        outputs.insert(Label::value(), Value::Float(f(a, b)));
831        Ok(outputs)
832    })
833}
834
835pub(crate) fn binary_int_comparison<F>(f: F) -> RuntimeOperation
836where
837    F: FnOnce(i64, i64) -> bool + Sync + Send + 'static,
838{
839    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
840        let a = validate_int_input(&mut inputs, "a")?;
841        let b = validate_int_input(&mut inputs, "b")?;
842
843        let mut outputs = HashMap::new();
844        outputs.insert(Label::value(), Value::Bool(f(a, b)));
845        Ok(outputs)
846    })
847}
848
849pub(crate) fn binary_flt_comparison<F>(f: F) -> RuntimeOperation
850where
851    F: FnOnce(f64, f64) -> bool + Sync + Send + 'static,
852{
853    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
854        let a = validate_float_input(&mut inputs, "a")?;
855        let b = validate_float_input(&mut inputs, "b")?;
856
857        let mut outputs = HashMap::new();
858        outputs.insert(Label::value(), Value::Bool(f(a, b)));
859        Ok(outputs)
860    })
861}
862
863pub(crate) fn binary_bool_operation<F>(f: F) -> RuntimeOperation
864where
865    F: FnOnce(bool, bool) -> bool + Sync + Send + 'static,
866{
867    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
868        let a = validate_bool_input(&mut inputs, "a")?;
869        let b = validate_bool_input(&mut inputs, "b")?;
870
871        let mut outputs = HashMap::new();
872        outputs.insert(Label::value(), Value::Bool(f(a, b)));
873        Ok(outputs)
874    })
875}
876
877pub(crate) fn operation_int_to_float() -> RuntimeOperation {
878    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
879        let int = validate_int_input(&mut inputs, "int")?;
880
881        let mut outputs = HashMap::new();
882        outputs.insert(Label::value(), Value::Float(int as f64));
883        Ok(outputs)
884    })
885}
886
887pub(crate) fn operation_float_to_int() -> RuntimeOperation {
888    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
889        let flt = validate_float_input(&mut inputs, "float")?;
890
891        let mut outputs = HashMap::new();
892        outputs.insert(Label::value(), Value::Int(flt as i64));
893        Ok(outputs)
894    })
895}
896
897pub(crate) fn operation_partial() -> RuntimeOperation {
898    RuntimeOperation::new_fn_simple(|mut inputs, _context| {
899        let thunk = validate_graph_input(&mut inputs, Label::thunk())?;
900
901        let g3 = {
902            let mut builder = GraphBuilder::new();
903            let [input, output] = Graph::boundary();
904
905            let input_edges: Vec<Edge> = thunk.node_outputs(input).cloned().collect();
906            let output_edges: Vec<Edge> = thunk.node_inputs(output).cloned().collect();
907            let b1 = builder.add_node(Node::local_box(thunk))?;
908
909            // Note that if values are provided to 'partial' that do NOT
910            // correspond to inputs to the thunk, this code will drop them here
911            // (at closure-creation time), whereas the old code would have
912            // passed them into the closure (which I think would have ignored them)
913            for input_edge in input_edges {
914                let port = input_edge.source.port;
915                let source = match inputs.remove(&port) {
916                    Some(value) => {
917                        let new_const = builder.add_node(Node::Const(value))?;
918                        (new_const, Label::value())
919                    }
920                    None => (input, port),
921                };
922                builder.add_edge(source, (b1, port), input_edge.edge_type)?;
923            }
924            //So: are there any ports in inputs.keys() that do not have
925            // an edge in input_edges? If so, maybe raise an error?
926
927            for out_edge in output_edges {
928                builder.add_edge(
929                    (b1, out_edge.target.port),
930                    (output, out_edge.target.port),
931                    out_edge.edge_type,
932                )?;
933            }
934
935            builder.build()?
936        };
937
938        let mut outputs = HashMap::new();
939        outputs.insert(Label::value(), Value::Graph(g3));
940
941        Ok(outputs)
942    })
943}
944
945pub(crate) fn operation_map() -> RuntimeOperation {
946    RuntimeOperation::new_fn_async(|mut inputs, context| async move {
947        let _ = &context;
948        let thunk = validate_graph_input(&mut inputs, Label::thunk())?;
949
950        let list = validate_input(&mut inputs, Label::value(), |x| match x {
951            Value::Vec(a) => Some(a),
952            _ => None,
953        })?;
954
955        let mut tasks = Vec::new();
956
957        let node_trace = context
958            .graph_trace
959            .as_node_trace()
960            .map_err(|_| anyhow!("map function expected stack trace to correspond to a node"))?;
961        for (idx, x) in list.into_iter().enumerate() {
962            let thunk_clone = thunk.clone();
963            let runtime = context.runtime.clone();
964            let callback = context.callback.clone();
965            let escape = context.escape.clone();
966            let checkpoint = context.checkpoint_client.clone();
967            let graph_trace = node_trace.clone().list_elem(idx as u32);
968
969            let span = tracing::Span::current();
970
971            let t = tokio::spawn(
972                async move {
973                    let value: Value = RuntimeOperation::new_graph(thunk_clone)
974                        .run_simple(
975                            runtime,
976                            callback,
977                            escape,
978                            [(Label::value(), x)],
979                            graph_trace,
980                            checkpoint,
981                        )
982                        .into_task()
983                        .complete()
984                        .await
985                        .map_err(|err| {
986                            let e = anyhow!(format!("{:?}", err.as_ref()));
987                            e.context("map body".to_string())
988                        })?
989                        .get(&Label::value())
990                        .ok_or_else(|| anyhow!("map thunk should output on value port"))?
991                        .clone();
992                    Ok::<Value, anyhow::Error>(value)
993                }
994                .instrument(span),
995            );
996            tasks.push(t);
997        }
998        // .collect();
999        let x = future::join_all(tasks).await;
1000        let y: Result<Vec<_>, _> = x.into_iter().collect();
1001        let z: Result<Vec<Value>, _> = y?.into_iter().collect();
1002
1003        let outputs = HashMap::from([(Label::value(), Value::Vec(z?))]);
1004
1005        Ok(outputs)
1006    })
1007}
1008
1009fn take_input<E>(
1010    inputs: &mut HashMap<Label, Value>,
1011    port: impl TryInto<Label, Error = E>,
1012) -> anyhow::Result<Value>
1013where
1014    E: Into<SymbolError>,
1015{
1016    let port = TryInto::try_into(port).map_err(|e| e.into())?;
1017    inputs
1018        .remove(&port)
1019        .ok_or_else(|| RuntimeError::MissingInput(port).into())
1020}
1021
1022fn validate_input<E, T>(
1023    inputs: &mut HashMap<Label, Value>,
1024    port: impl TryInto<Label, Error = E>,
1025    validation: impl FnOnce(Value) -> Option<T>,
1026) -> anyhow::Result<T>
1027where
1028    E: Into<SymbolError>,
1029{
1030    let port = TryInto::try_into(port).map_err(|e| e.into())?;
1031    let input = inputs
1032        .remove(&port)
1033        .ok_or(RuntimeError::MissingInput(port))?;
1034    match validation(input) {
1035        Some(v) => Ok(v),
1036        None => Err(anyhow!(RuntimeError::InvalidInput(port))),
1037    }
1038}
1039
1040fn validate_int_input<E>(
1041    inputs: &mut HashMap<Label, Value>,
1042    port: impl TryInto<Label, Error = E>,
1043) -> anyhow::Result<i64>
1044where
1045    E: Into<SymbolError>,
1046{
1047    validate_input(inputs, port, |x| match x {
1048        Value::Int(a) => Some(a),
1049        _ => None,
1050    })
1051}
1052
1053fn validate_float_input<E>(
1054    inputs: &mut HashMap<Label, Value>,
1055    port: impl TryInto<Label, Error = E>,
1056) -> anyhow::Result<f64>
1057where
1058    E: Into<SymbolError>,
1059{
1060    validate_input(inputs, port, |x| match x {
1061        Value::Float(a) => Some(a),
1062        _ => None,
1063    })
1064}
1065
1066fn validate_bool_input<E>(
1067    inputs: &mut HashMap<Label, Value>,
1068    port: impl TryInto<Label, Error = E>,
1069) -> anyhow::Result<bool>
1070where
1071    E: Into<SymbolError>,
1072{
1073    validate_input(inputs, port, |x| match x {
1074        Value::Bool(a) => Some(a),
1075        _ => None,
1076    })
1077}
1078
1079fn validate_graph_input<E>(
1080    inputs: &mut HashMap<Label, Value>,
1081    port: impl TryInto<Label, Error = E>,
1082) -> anyhow::Result<Graph>
1083where
1084    E: Into<SymbolError>,
1085{
1086    validate_input(inputs, port, |x| match x {
1087        Value::Graph(a) => Some(a),
1088        _ => None,
1089    })
1090}
1091
1092#[derive(Debug, Error)]
1093enum RuntimeError {
1094    #[error("Missing input on port {0}.")]
1095    MissingInput(Label),
1096    #[error("Invalid input on port {0}.")]
1097    InvalidInput(Label),
1098    #[error("Vector is empty.")]
1099    EmptyVector,
1100    #[error("Key not found in map.")]
1101    KeyNotFound(Value),
1102}