Skip to main content

statefun/
function_registry.rs

1//! The function registry keeps a mapping from `FunctionType` to stateful functions.
2
3use std::collections::HashMap;
4
5use failure::format_err;
6use protobuf::well_known_types::Any;
7use protobuf::Message;
8
9use crate::{Context, Effects, FunctionType};
10
11/// Keeps a mapping from `FunctionType` to stateful functions. Use this together with a
12/// [Transport](crate::transport::Transport) to serve stateful functions.
13///
14/// Use `register_fn()` to register functions before handing the registry over to a `Transport` for
15/// serving.
16pub struct FunctionRegistry {
17    functions: HashMap<FunctionType, Box<dyn InvokableFunction + Send>>,
18}
19
20#[allow(clippy::new_without_default)]
21impl FunctionRegistry {
22    /// Creates a new empty `FunctionRegistry`.
23    pub fn new() -> FunctionRegistry {
24        FunctionRegistry {
25            functions: HashMap::new(),
26        }
27    }
28
29    /// Registers the given function under the `function_type`.
30    pub fn register_fn<I: Message, F: Fn(Context, I) -> Effects + Send + 'static>(
31        &mut self,
32        function_type: FunctionType,
33        function: F,
34    ) {
35        let callable_function = FnInvokableFunction {
36            function,
37            marker: ::std::marker::PhantomData,
38        };
39        self.functions
40            .insert(function_type, Box::new(callable_function));
41    }
42
43    /// Invokes the function that is registered for the given `FunctionType`. This will return
44    /// `Err` if no function is registered under the given type.
45    pub fn invoke(
46        &self,
47        target_function: FunctionType,
48        context: Context,
49        message: Any,
50    ) -> Result<Effects, failure::Error> {
51        let function = self.functions.get(&target_function);
52        match function {
53            Some(fun) => fun.invoke(context, message),
54            None => Err(format_err!(
55                "No function registered under {}",
56                target_function
57            )),
58        }
59    }
60}
61
62/// A function that can be invoked. This is used as trait objects in the `FunctionRegistry`.
63trait InvokableFunction {
64    fn invoke(&self, context: Context, message: Any) -> Result<Effects, failure::Error>;
65}
66
67/// An `InvokableFunction` that is backed by a `Fn`.
68struct FnInvokableFunction<I: Message, F: Fn(Context, I) -> Effects> {
69    function: F,
70    marker: ::std::marker::PhantomData<I>,
71}
72
73impl<I: Message, F: Fn(Context, I) -> Effects> InvokableFunction for FnInvokableFunction<I, F> {
74    fn invoke(&self, context: Context, message: Any) -> Result<Effects, failure::Error> {
75        let unpacked_argument: I = message.unpack()?.unwrap();
76        let effects = (self.function)(context, unpacked_argument);
77        Ok(effects)
78    }
79}
80
#[cfg(test)]
mod tests {
    use crate::FunctionRegistry;
    use crate::*;
    use protobuf::well_known_types::StringValue;

    /// Invoking a function that was registered succeeds.
    #[test]
    fn call_registered_function() -> Result<(), failure::Error> {
        let state = HashMap::new();
        let address = address_foo().into_proto();
        let context = Context::new(&state, &address, &address);

        let mut registry = FunctionRegistry::new();
        registry.register_fn(function_type_foo(), |_context, _message: StringValue| {
            Effects::new()
        });

        let packed_argument = Any::pack(&StringValue::new())?;
        let _effects = registry.invoke(function_type_foo(), context, packed_argument)?;

        Ok(())
    }

    /// Invoking a function type that was never registered yields an error.
    #[test]
    fn call_unknown_function() -> Result<(), failure::Error> {
        let state = HashMap::new();
        let address = address_foo().into_proto();
        let context = Context::new(&state, &address, &address);

        let registry = FunctionRegistry::new();

        let packed_argument = Any::pack(&StringValue::new())?;
        let outcome = registry.invoke(function_type_bar(), context, packed_argument);

        assert!(outcome.is_err());

        Ok(())
    }

    /// With two functions registered, each function type dispatches to its own function.
    #[test]
    fn call_correct_function() -> Result<(), failure::Error> {
        let state = HashMap::new();

        let mut registry = FunctionRegistry::new();
        registry.register_fn(function_type_foo(), |context, _message: StringValue| {
            let mut reply = StringValue::new();
            reply.set_value("function_foo".to_owned());

            let mut effects = Effects::new();
            effects.send(context.self_address(), reply);
            effects
        });

        registry.register_fn(function_type_bar(), |context, _message: StringValue| {
            let mut reply = StringValue::new();
            reply.set_value("function_bar".to_owned());

            let mut effects = Effects::new();
            effects.send(context.self_address(), reply);
            effects
        });

        let address_foo = address_foo().into_proto();
        let context = Context::new(&state, &address_foo, &address_foo);
        let packed_argument = Any::pack(&StringValue::new())?;
        let effects_foo = registry.invoke(function_type_foo(), context, packed_argument)?;
        assert_eq!(first_sent_value(&effects_foo), "function_foo");

        let address_bar = address_bar().into_proto();
        let context = Context::new(&state, &address_bar, &address_bar);
        let packed_argument = Any::pack(&StringValue::new())?;
        let effects_bar = registry.invoke(function_type_bar(), context, packed_argument)?;
        assert_eq!(first_sent_value(&effects_bar), "function_bar");

        Ok(())
    }

    /// Returns the string carried by the first invocation recorded in `effects`. Panics if
    /// there is no invocation or if its message is not a `StringValue`.
    fn first_sent_value(effects: &Effects) -> String {
        effects.invocations[0]
            .1
            .unpack::<StringValue>()
            .unwrap()
            .unwrap()
            .get_value()
            .to_owned()
    }

    /// Function type used for the "foo" test function.
    fn function_type_foo() -> FunctionType {
        FunctionType::new("namespace", "foo")
    }

    /// Function type used for the "bar" test function.
    fn function_type_bar() -> FunctionType {
        FunctionType::new("namespace", "bar")
    }

    /// Address of the "foo" function with id "doctor".
    fn address_foo() -> Address {
        Address::new(function_type_foo(), "doctor")
    }

    /// Address of the "bar" function with id "doctor".
    fn address_bar() -> Address {
        Address::new(function_type_bar(), "doctor")
    }
}