// azure_functions/worker.rs

1use crate::{
2    backtrace::Backtrace,
3    codegen::{Function, InvokerFn},
4    context::Context,
5    logger,
6    registry::Registry,
7    rpc::{
8        client::FunctionRpcClient, status_result::Status, streaming_message::Content,
9        FunctionLoadRequest, FunctionLoadResponse, InvocationRequest, InvocationResponse,
10        StartStream, StatusResult, StreamingMessage, WorkerInitResponse, WorkerStatusRequest,
11        WorkerStatusResponse,
12    },
13};
14use futures::{channel::mpsc::unbounded, future::FutureExt, stream::StreamExt};
15use http::uri::Uri;
16use log::error;
17use std::{
18    cell::RefCell,
19    future::Future,
20    panic::{catch_unwind, set_hook, AssertUnwindSafe, PanicInfo},
21    pin::Pin,
22    task::Poll,
23};
24use tokio::future::poll_fn;
25use tokio_executor::threadpool::blocking;
26use tonic::Request;
27
28pub type Sender = futures::channel::mpsc::UnboundedSender<StreamingMessage>;
29
30struct ContextFuture<F> {
31    inner: F,
32    invocation_id: String,
33    function_id: String,
34    function_name: &'static str,
35    sender: Sender,
36}
37
38impl<F> ContextFuture<F> {
39    pub fn new(
40        inner: F,
41        invocation_id: String,
42        function_id: String,
43        function_name: &'static str,
44        sender: Sender,
45    ) -> Self {
46        ContextFuture {
47            inner,
48            invocation_id,
49            function_id,
50            function_name,
51            sender,
52        }
53    }
54}
55
56impl<F: Future<Output = InvocationResponse> + Unpin> Future for ContextFuture<F> {
57    type Output = ();
58
59    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
60        let _guard = Context::set(&self.invocation_id, &self.function_id, self.function_name);
61
62        let res = match catch_unwind(AssertUnwindSafe(|| self.inner.poll_unpin(cx))) {
63            Ok(p) => match p {
64                Poll::Ready(res) => res,
65                Poll::Pending => return Poll::Pending,
66            },
67            Err(_) => InvocationResponse {
68                invocation_id: self.invocation_id.clone(),
69                result: Some(StatusResult {
70                    status: Status::Failure as i32,
71                    result: "Azure Function panicked: see log for more information.".to_string(),
72                    ..Default::default()
73                }),
74                ..Default::default()
75            },
76        };
77
78        self.sender
79            .unbounded_send(StreamingMessage {
80                content: Some(Content::InvocationResponse(res)),
81                ..Default::default()
82            })
83            .expect("failed to send invocation response");
84
85        Poll::Ready(())
86    }
87}
88
89pub struct Worker;
90
91impl Worker {
92    pub fn run(host: &str, port: u16, worker_id: &str, mut registry: Registry<'static>) {
93        let host_uri: Uri = format!("http://{0}:{1}", host, port).parse().unwrap();
94        let (sender, receiver) = unbounded::<StreamingMessage>();
95
96        tokio::runtime::Runtime::new().unwrap().block_on(async {
97            let mut client = FunctionRpcClient::connect(host_uri)
98                .await
99                .map_err(|e| panic!("failed to connect to host: {}", e))
100                .unwrap();
101
102            // Start by sending a start stream message to the channel
103            // This will be sent to the host upon connection
104            sender
105                .unbounded_send(StreamingMessage {
106                    content: Some(Content::StartStream(StartStream {
107                        worker_id: worker_id.to_owned(),
108                    })),
109                    ..Default::default()
110                })
111                .unwrap();
112
113            let mut stream = client
114                .event_stream(Request::new(receiver))
115                .await
116                .map_err(|e| panic!("failed to start event stream: {}", e))
117                .unwrap()
118                .into_inner();
119
120            let init_req = stream
121                .next()
122                .await
123                .expect("expected a worker init request")
124                .map_err(|e| panic!("failed to read event stream response: {}", e))
125                .unwrap();
126
127            Worker::handle_worker_init_request(sender.clone(), init_req).await;
128
129            stream
130                .for_each(move |req| {
131                    Worker::handle_request(
132                        &mut registry,
133                        sender.clone(),
134                        req.expect("expected a request"),
135                    );
136                    futures::future::ready(())
137                })
138                .await;
139        });
140    }
141
142    async fn handle_worker_init_request(sender: Sender, req: StreamingMessage) {
143        match req.content {
144            Some(Content::WorkerInitRequest(req)) => {
145                println!(
146                    "Connected to Azure Functions host version {}.",
147                    req.host_version
148                );
149
150                // TODO: use the level requested by the Azure functions host
151                log::set_boxed_logger(Box::new(logger::Logger::new(
152                    log::Level::Info,
153                    sender.clone(),
154                )))
155                .expect("failed to set the global logger instance");
156
157                set_hook(Box::new(Worker::handle_panic));
158
159                log::set_max_level(log::LevelFilter::Trace);
160
161                sender
162                    .unbounded_send(StreamingMessage {
163                        content: Some(Content::WorkerInitResponse(WorkerInitResponse {
164                            worker_version: env!("CARGO_PKG_VERSION").to_owned(),
165                            result: Some(StatusResult {
166                                status: Status::Success as i32,
167                                ..Default::default()
168                            }),
169                            ..Default::default()
170                        })),
171                        ..Default::default()
172                    })
173                    .unwrap();
174            }
175            _ => panic!("expected a worker init request message from the host"),
176        };
177    }
178
179    fn handle_request(registry: &mut Registry<'static>, sender: Sender, req: StreamingMessage) {
180        match req.content {
181            Some(Content::FunctionLoadRequest(req)) => {
182                Worker::handle_function_load_request(registry, sender, req)
183            }
184            Some(Content::InvocationRequest(req)) => {
185                Worker::handle_invocation_request(registry, sender, req)
186            }
187            Some(Content::WorkerStatusRequest(req)) => {
188                Worker::handle_worker_status_request(sender, req)
189            }
190            Some(Content::FileChangeEventRequest(_)) => {}
191            Some(Content::InvocationCancel(_)) => {}
192            Some(Content::FunctionEnvironmentReloadRequest(_)) => {}
193            _ => panic!("unexpected message from host: {:?}.", req),
194        };
195    }
196
197    fn handle_function_load_request(
198        registry: &mut Registry<'static>,
199        sender: Sender,
200        req: FunctionLoadRequest,
201    ) {
202        let mut result = StatusResult::default();
203
204        match req.metadata.as_ref() {
205            Some(metadata) => {
206                if registry.register(&req.function_id, &metadata.name) {
207                    result.status = Status::Success as i32;
208                } else {
209                    result.status = Status::Failure as i32;
210                    result.result = format!("Function '{}' does not exist.", metadata.name);
211                }
212            }
213            None => {
214                result.status = Status::Failure as i32;
215                result.result = "Function load request metadata is missing.".to_string();
216            }
217        };
218
219        sender
220            .unbounded_send(StreamingMessage {
221                content: Some(Content::FunctionLoadResponse(FunctionLoadResponse {
222                    function_id: req.function_id,
223                    result: Some(result),
224                    ..Default::default()
225                })),
226                ..Default::default()
227            })
228            .expect("failed to send function load response");
229    }
230
231    fn handle_invocation_request(
232        registry: &Registry<'static>,
233        sender: Sender,
234        req: InvocationRequest,
235    ) {
236        if let Some(func) = registry.get(&req.function_id) {
237            Worker::invoke_function(func, sender, req);
238            return;
239        }
240
241        let error = format!("Function with id '{}' does not exist.", req.function_id);
242
243        sender
244            .unbounded_send(StreamingMessage {
245                content: Some(Content::InvocationResponse(InvocationResponse {
246                    invocation_id: req.invocation_id,
247                    result: Some(StatusResult {
248                        status: Status::Failure as i32,
249                        result: error,
250                        ..Default::default()
251                    }),
252                    ..Default::default()
253                })),
254                ..Default::default()
255            })
256            .expect("failed to send invocation response");
257    }
258
259    fn handle_worker_status_request(sender: Sender, _: WorkerStatusRequest) {
260        sender
261            .unbounded_send(StreamingMessage {
262                content: Some(Content::WorkerStatusResponse(WorkerStatusResponse {})),
263                ..Default::default()
264            })
265            .expect("failed to send worker status response");
266    }
267
268    fn invoke_function(func: &'static Function, sender: Sender, req: InvocationRequest) {
269        match func
270            .invoker
271            .as_ref()
272            .expect("function must have an invoker")
273            .invoker_fn
274        {
275            InvokerFn::Sync(invoker_fn) => {
276                // `poll_fn` takes FnMut and `blocking` takes FnOnce
277                // Wrap the request with a RefCell so we can move the request to the invoked function
278                let id = req.invocation_id.clone();
279                let func_id = req.function_id.clone();
280                let req = RefCell::new(Some(req));
281
282                tokio::spawn(ContextFuture::new(
283                    poll_fn(move |_| {
284                        blocking(|| {
285                            invoker_fn.expect("invoker must have a callback")(
286                                req.replace(None).expect("only a single call to invoker"),
287                            )
288                        })
289                    })
290                    .map(|r| r.expect("expected a response")),
291                    id,
292                    func_id,
293                    &func.name,
294                    sender,
295                ));
296            }
297            InvokerFn::Async(invoker_fn) => {
298                let id = req.invocation_id.clone();
299                let func_id = req.function_id.clone();
300
301                tokio::spawn(ContextFuture::new(
302                    invoker_fn.expect("invoker must have a callback")(req),
303                    id,
304                    func_id,
305                    &func.name,
306                    sender,
307                ));
308            }
309        };
310    }
311
312    fn handle_panic(info: &PanicInfo) {
313        let backtrace = Backtrace::new();
314        match info.location() {
315            Some(location) => {
316                error!(
317                    "Azure Function '{}' panicked with '{}', {}:{}:{}{}",
318                    crate::context::CURRENT.with(|c| c.borrow().function_name),
319                    info.payload()
320                        .downcast_ref::<&str>()
321                        .cloned()
322                        .unwrap_or_else(|| info
323                            .payload()
324                            .downcast_ref::<String>()
325                            .map(String::as_str)
326                            .unwrap_or("")),
327                    location.file(),
328                    location.line(),
329                    location.column(),
330                    backtrace
331                );
332            }
333            None => {
334                error!(
335                    "Azure Function '{}' panicked with '{}'{}",
336                    crate::context::CURRENT.with(|c| c.borrow().function_name),
337                    info.payload()
338                        .downcast_ref::<&str>()
339                        .cloned()
340                        .unwrap_or_else(|| info
341                            .payload()
342                            .downcast_ref::<String>()
343                            .map(String::as_str)
344                            .unwrap_or("")),
345                    backtrace
346                );
347            }
348        };
349    }
350}