1use crate::repl_state::ReplState;
16use async_trait::async_trait;
17use golem_wasm_ast::analysis::AnalysedType;
18use golem_wasm_rpc::ValueAndType;
19use rib::{
20 ComponentDependencyKey, EvaluatedFnArgs, EvaluatedFqFn, EvaluatedWorkerName, InstructionId,
21 RibComponentFunctionInvoke, RibFunctionInvokeResult,
22};
23use std::sync::Arc;
24use uuid::Uuid;
25
/// Asynchronously invokes a function of a component, optionally targeting a named worker.
///
/// This is the low-level call surface used when the REPL needs to run a function that a Rib
/// expression refers to: in this file, `ReplRibFunctionInvoke` forwards the Rib call's
/// component, worker name, function name, arguments and expected return type to
/// `ReplState::worker_function_invoke().invoke(..)`, whose argument list matches this method.
#[async_trait]
pub trait WorkerFunctionInvoke {
    /// Calls `function_name` on a worker of the component identified by
    /// `component_id` / `component_name`.
    ///
    /// * `component_id` - id of the target component.
    /// * `component_name` - name of the target component.
    /// * `worker_name` - name of the target worker.
    ///   NOTE(review): the meaning of `None` is not visible here; presumably the
    ///   implementation picks or creates a worker itself. Confirm against implementors.
    /// * `function_name` - the function to call. The Rib adapter in this file passes the
    ///   evaluated fully-qualified function name.
    /// * `args` - already-evaluated arguments, passed through unchanged.
    /// * `return_type` - the expected return type, when one is known.
    ///
    /// Returns `Ok(Some(value))` with the function's result. `Ok(None)` presumably means
    /// the function produced no value (TODO confirm). Failures are reported as
    /// `anyhow::Error`.
    async fn invoke(
        &self,
        component_id: Uuid,
        component_name: &str,
        worker_name: Option<String>,
        function_name: &str,
        args: Vec<ValueAndType>,
        return_type: Option<AnalysedType>,
    ) -> anyhow::Result<Option<ValueAndType>>;
}
38
/// Adapter that implements Rib's `RibComponentFunctionInvoke` on top of the REPL state.
///
/// Each Rib function call is identified by its `InstructionId`.
///
/// * A call whose index is not beyond the REPL's last executed instruction is answered from
///   the result cache held in `ReplState`.
/// * Other calls are forwarded to the REPL's worker invoker, and a successful result is
///   written back to that cache.
pub(crate) struct ReplRibFunctionInvoke {
    /// Shared REPL state. It supplies:
    ///
    /// * the worker invoker;
    /// * the per-instruction result cache;
    /// * the last executed instruction.
    repl_state: Arc<ReplState>,
}
47
48impl ReplRibFunctionInvoke {
49 pub fn new(repl_state: Arc<ReplState>) -> Self {
50 Self { repl_state }
51 }
52
53 fn get_cached_result(&self, instruction_id: &InstructionId) -> Option<Option<ValueAndType>> {
54 if instruction_id.index > self.repl_state.last_executed_instruction().index {
58 None
59 } else {
60 self.repl_state.invocation_results().get(instruction_id)
61 }
62 }
63}
64
65#[async_trait]
66impl RibComponentFunctionInvoke for ReplRibFunctionInvoke {
67 async fn invoke(
68 &self,
69 component_dependency: ComponentDependencyKey,
70 instruction_id: &InstructionId,
71 worker_name: Option<EvaluatedWorkerName>,
72 function_name: EvaluatedFqFn,
73 args: EvaluatedFnArgs,
74 return_type: Option<AnalysedType>,
75 ) -> RibFunctionInvokeResult {
76 match self.get_cached_result(instruction_id) {
77 Some(result) => Ok(result),
78 None => {
79 let rib_invocation_result = self
80 .repl_state
81 .worker_function_invoke()
82 .invoke(
83 component_dependency.component_id,
84 component_dependency.component_name.as_str(),
85 worker_name.map(|x| x.0),
86 function_name.0.as_str(),
87 args.0,
88 return_type,
89 )
90 .await;
91
92 match rib_invocation_result {
93 Ok(result) => {
94 self.repl_state
95 .update_cache(instruction_id.clone(), result.clone());
96
97 Ok(result)
98 }
99 Err(err) => Err(err.into()),
100 }
101 }
102 }
103 }
104}