1use crate::bindings::golem::durability::durability::{
16 begin_durable_function, current_durable_execution_state, end_durable_function,
17 observe_function_call, persist_typed_durable_function_invocation,
18 read_persisted_typed_durable_function_invocation, DurableExecutionState, DurableFunctionType,
19 OplogEntryVersion, OplogIndex, PersistedTypedDurableFunctionInvocation, PersistenceLevel,
20};
21use crate::value_and_type::{FromValueAndType, IntoValueAndType};
22use golem_wasm_rpc::golem_rpc_0_2_x::types::ValueAndType;
23use std::fmt::{Debug, Display};
24use std::marker::PhantomData;
25
26pub struct Durability<SOk, SErr> {
27 interface: &'static str,
28 function: &'static str,
29 function_type: DurableFunctionType,
30 begin_index: OplogIndex,
31 durable_execution_state: DurableExecutionState,
32 forced_commit: bool,
33 _sok: PhantomData<SOk>,
34 _serr: PhantomData<SErr>,
35}
36
37impl<SOk, SErr> Durability<SOk, SErr> {
38 pub fn new(
39 interface: &'static str,
40 function: &'static str,
41 function_type: DurableFunctionType,
42 ) -> Self {
43 observe_function_call(interface, function);
44
45 let begin_index = begin_durable_function(function_type);
46 let durable_execution_state = current_durable_execution_state();
47
48 Self {
49 interface,
50 function,
51 function_type,
52 begin_index,
53 durable_execution_state,
54 forced_commit: false,
55 _sok: PhantomData,
56 _serr: PhantomData,
57 }
58 }
59
60 pub fn enabled_forced_commit(&mut self) {
61 self.forced_commit = true;
62 }
63
64 pub fn is_live(&self) -> bool {
65 self.durable_execution_state.is_live
66 || matches!(
67 self.durable_execution_state.persistence_level,
68 PersistenceLevel::PersistNothing
69 )
70 }
71
72 pub fn persist<SIn, Ok, Err>(&self, input: SIn, result: Result<Ok, Err>) -> Result<Ok, Err>
73 where
74 Ok: Clone,
75 Err: From<SErr>,
76 SIn: Debug + IntoValueAndType,
77 SErr: Debug + for<'a> From<&'a Err>,
78 SOk: Debug + From<Ok>,
79 Result<SOk, SErr>: IntoValueAndType,
80 {
81 let serializable_result: Result<SOk, SErr> = result
82 .as_ref()
83 .map(|result| result.clone().into())
84 .map_err(|err| err.into());
85
86 self.persist_serializable(input, serializable_result);
87 result
88 }
89
90 pub fn persist_infallible<SIn, Ok>(&self, input: SIn, result: Ok) -> Ok
91 where
92 Ok: Clone,
93 SIn: Debug + IntoValueAndType,
94 SOk: Debug + From<Ok>,
95 SErr: Debug,
96 Result<SOk, SErr>: IntoValueAndType,
97 {
98 let serializable_result: Result<SOk, SErr> = Ok(result.clone().into());
99
100 self.persist_serializable(input, serializable_result);
101 result
102 }
103
104 pub fn persist_serializable<SIn>(&self, input: SIn, result: Result<SOk, SErr>)
105 where
106 SIn: Debug + IntoValueAndType,
107 Result<SOk, SErr>: IntoValueAndType,
108 {
109 let function_name = self.function_name();
110 if !matches!(
111 self.durable_execution_state.persistence_level,
112 PersistenceLevel::PersistNothing
113 ) {
114 persist_typed_durable_function_invocation(
115 &function_name,
116 &input.into_value_and_type(),
117 &result.into_value_and_type(),
118 self.function_type,
119 );
120 end_durable_function(self.function_type, self.begin_index, self.forced_commit);
121 }
122 }
123
124 pub fn replay_raw(&self) -> (ValueAndType, OplogEntryVersion) {
125 let oplog_entry = read_persisted_typed_durable_function_invocation();
126
127 let function_name = self.function_name();
128 Self::validate_oplog_entry(&oplog_entry, &function_name);
129
130 end_durable_function(self.function_type, self.begin_index, false);
131
132 (oplog_entry.response, oplog_entry.entry_version)
133 }
134
135 pub fn replay_serializable(&self) -> Result<SOk, SErr>
136 where
137 SOk: FromValueAndType,
138 SErr: FromValueAndType,
139 {
140 let (value_and_type, _) = self.replay_raw();
141 let result: Result<SOk, SErr> = FromValueAndType::from_value_and_type(value_and_type)
142 .unwrap_or_else(|err| panic!("Unexpected ImportedFunctionInvoked payload: {err}"));
143 result
144 }
145
146 pub fn replay<Ok, Err>(&self) -> Result<Ok, Err>
147 where
148 Ok: From<SOk>,
149 Err: From<SErr>,
150 SErr: Debug + FromValueAndType,
151 SOk: Debug + FromValueAndType,
152 {
153 Self::replay_serializable(self)
154 .map(|sok| sok.into())
155 .map_err(|serr| serr.into())
156 }
157
158 pub fn replay_infallible<Ok>(&self) -> Ok
159 where
160 Ok: From<SOk>,
161 SOk: FromValueAndType,
162 SErr: FromValueAndType + Display,
163 {
164 let result: Result<SOk, SErr> = self.replay_serializable();
165 result.map(|sok| sok.into()).unwrap_or_else(|err| {
166 panic!(
167 "Function {} previously failed with {}",
168 self.function_name(),
169 err
170 )
171 })
172 }
173
174 fn function_name(&self) -> String {
175 if self.interface.is_empty() {
176 self.function.to_string()
178 } else {
179 format!("{}::{}", self.interface, self.function)
180 }
181 }
182
183 fn validate_oplog_entry(
184 oplog_entry: &PersistedTypedDurableFunctionInvocation,
185 expected_function_name: &str,
186 ) {
187 if oplog_entry.function_name != expected_function_name {
188 panic!(
189 "Unexpected imported function call entry in oplog: expected {}, got {}",
190 expected_function_name, oplog_entry.function_name
191 );
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use crate::bindings::golem::durability::durability::DurableFunctionType;
    use crate::value_and_type::type_builder::TypeNodeBuilder;
    use crate::value_and_type::{FromValueAndType, IntoValue};
    use golem_wasm_rpc::{NodeBuilder, WitValueExtractor};
    use std::io::Error;

    // Usage example for the `Durability` API with a custom persisted error type.
    // The function carries no `#[test]` attribute and is never called here, so it
    // is presumably meant only to be type-checked rather than executed.
    #[allow(dead_code)]
    fn durability_interface_test() {
        /// Persisted error representation used by the example.
        #[derive(Debug)]
        enum CustomError {
            Error1,
            Error2,
        }

        // Every I/O error is recorded as `Error1` when persisting.
        impl From<&std::io::Error> for CustomError {
            fn from(_value: &Error) -> Self {
                CustomError::Error1
            }
        }

        // A replayed error is turned back into an I/O error carrying its debug text.
        impl From<CustomError> for std::io::Error {
            fn from(value: CustomError) -> Self {
                let description = format!("{value:?}");
                Error::other(description)
            }
        }

        impl IntoValue for CustomError {
            fn add_to_builder<T: NodeBuilder>(self, builder: T) -> T::Result {
                // Case indices follow the order declared in `add_to_type_builder`.
                let case_index = match self {
                    CustomError::Error1 => 0,
                    CustomError::Error2 => 1,
                };
                builder.enum_value(case_index)
            }

            fn add_to_type_builder<T: TypeNodeBuilder>(builder: T) -> T::Result {
                builder.r#enum(&["Error1", "Error2"])
            }
        }

        impl FromValueAndType for CustomError {
            fn from_extractor<'a, 'b>(
                extractor: &'a impl WitValueExtractor<'a, 'b>,
            ) -> Result<Self, String> {
                let decoded = match extractor.enum_value() {
                    Some(0) => CustomError::Error1,
                    Some(1) => CustomError::Error2,
                    _ => return Err("Invalid enum value".to_string()),
                };
                Ok(decoded)
            }
        }

        /// Persists a fixed value when live; otherwise replays the stored result.
        fn durable_fn() -> Result<u64, std::io::Error> {
            let durability = super::Durability::<u64, CustomError>::new(
                "custom",
                "random-number-generator",
                DurableFunctionType::ReadLocal,
            );

            if !durability.is_live() {
                return durability.replay();
            }

            durability.persist("input".to_string(), Ok(1234))
        }
    }
}