1use crate::{
2 backtrace::Backtrace,
3 codegen::{Function, InvokerFn},
4 context::Context,
5 logger,
6 registry::Registry,
7 rpc::{
8 client::FunctionRpcClient, status_result::Status, streaming_message::Content,
9 FunctionLoadRequest, FunctionLoadResponse, InvocationRequest, InvocationResponse,
10 StartStream, StatusResult, StreamingMessage, WorkerInitResponse, WorkerStatusRequest,
11 WorkerStatusResponse,
12 },
13};
14use futures::{channel::mpsc::unbounded, future::FutureExt, stream::StreamExt};
15use http::uri::Uri;
16use log::error;
17use std::{
18 cell::RefCell,
19 future::Future,
20 panic::{catch_unwind, set_hook, AssertUnwindSafe, PanicInfo},
21 pin::Pin,
22 task::Poll,
23};
24use tokio::future::poll_fn;
25use tokio_executor::threadpool::blocking;
26use tonic::Request;
27
28pub type Sender = futures::channel::mpsc::UnboundedSender<StreamingMessage>;
29
30struct ContextFuture<F> {
31 inner: F,
32 invocation_id: String,
33 function_id: String,
34 function_name: &'static str,
35 sender: Sender,
36}
37
38impl<F> ContextFuture<F> {
39 pub fn new(
40 inner: F,
41 invocation_id: String,
42 function_id: String,
43 function_name: &'static str,
44 sender: Sender,
45 ) -> Self {
46 ContextFuture {
47 inner,
48 invocation_id,
49 function_id,
50 function_name,
51 sender,
52 }
53 }
54}
55
56impl<F: Future<Output = InvocationResponse> + Unpin> Future for ContextFuture<F> {
57 type Output = ();
58
59 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
60 let _guard = Context::set(&self.invocation_id, &self.function_id, self.function_name);
61
62 let res = match catch_unwind(AssertUnwindSafe(|| self.inner.poll_unpin(cx))) {
63 Ok(p) => match p {
64 Poll::Ready(res) => res,
65 Poll::Pending => return Poll::Pending,
66 },
67 Err(_) => InvocationResponse {
68 invocation_id: self.invocation_id.clone(),
69 result: Some(StatusResult {
70 status: Status::Failure as i32,
71 result: "Azure Function panicked: see log for more information.".to_string(),
72 ..Default::default()
73 }),
74 ..Default::default()
75 },
76 };
77
78 self.sender
79 .unbounded_send(StreamingMessage {
80 content: Some(Content::InvocationResponse(res)),
81 ..Default::default()
82 })
83 .expect("failed to send invocation response");
84
85 Poll::Ready(())
86 }
87}
88
89pub struct Worker;
90
91impl Worker {
92 pub fn run(host: &str, port: u16, worker_id: &str, mut registry: Registry<'static>) {
93 let host_uri: Uri = format!("http://{0}:{1}", host, port).parse().unwrap();
94 let (sender, receiver) = unbounded::<StreamingMessage>();
95
96 tokio::runtime::Runtime::new().unwrap().block_on(async {
97 let mut client = FunctionRpcClient::connect(host_uri)
98 .await
99 .map_err(|e| panic!("failed to connect to host: {}", e))
100 .unwrap();
101
102 sender
105 .unbounded_send(StreamingMessage {
106 content: Some(Content::StartStream(StartStream {
107 worker_id: worker_id.to_owned(),
108 })),
109 ..Default::default()
110 })
111 .unwrap();
112
113 let mut stream = client
114 .event_stream(Request::new(receiver))
115 .await
116 .map_err(|e| panic!("failed to start event stream: {}", e))
117 .unwrap()
118 .into_inner();
119
120 let init_req = stream
121 .next()
122 .await
123 .expect("expected a worker init request")
124 .map_err(|e| panic!("failed to read event stream response: {}", e))
125 .unwrap();
126
127 Worker::handle_worker_init_request(sender.clone(), init_req).await;
128
129 stream
130 .for_each(move |req| {
131 Worker::handle_request(
132 &mut registry,
133 sender.clone(),
134 req.expect("expected a request"),
135 );
136 futures::future::ready(())
137 })
138 .await;
139 });
140 }
141
142 async fn handle_worker_init_request(sender: Sender, req: StreamingMessage) {
143 match req.content {
144 Some(Content::WorkerInitRequest(req)) => {
145 println!(
146 "Connected to Azure Functions host version {}.",
147 req.host_version
148 );
149
150 log::set_boxed_logger(Box::new(logger::Logger::new(
152 log::Level::Info,
153 sender.clone(),
154 )))
155 .expect("failed to set the global logger instance");
156
157 set_hook(Box::new(Worker::handle_panic));
158
159 log::set_max_level(log::LevelFilter::Trace);
160
161 sender
162 .unbounded_send(StreamingMessage {
163 content: Some(Content::WorkerInitResponse(WorkerInitResponse {
164 worker_version: env!("CARGO_PKG_VERSION").to_owned(),
165 result: Some(StatusResult {
166 status: Status::Success as i32,
167 ..Default::default()
168 }),
169 ..Default::default()
170 })),
171 ..Default::default()
172 })
173 .unwrap();
174 }
175 _ => panic!("expected a worker init request message from the host"),
176 };
177 }
178
179 fn handle_request(registry: &mut Registry<'static>, sender: Sender, req: StreamingMessage) {
180 match req.content {
181 Some(Content::FunctionLoadRequest(req)) => {
182 Worker::handle_function_load_request(registry, sender, req)
183 }
184 Some(Content::InvocationRequest(req)) => {
185 Worker::handle_invocation_request(registry, sender, req)
186 }
187 Some(Content::WorkerStatusRequest(req)) => {
188 Worker::handle_worker_status_request(sender, req)
189 }
190 Some(Content::FileChangeEventRequest(_)) => {}
191 Some(Content::InvocationCancel(_)) => {}
192 Some(Content::FunctionEnvironmentReloadRequest(_)) => {}
193 _ => panic!("unexpected message from host: {:?}.", req),
194 };
195 }
196
197 fn handle_function_load_request(
198 registry: &mut Registry<'static>,
199 sender: Sender,
200 req: FunctionLoadRequest,
201 ) {
202 let mut result = StatusResult::default();
203
204 match req.metadata.as_ref() {
205 Some(metadata) => {
206 if registry.register(&req.function_id, &metadata.name) {
207 result.status = Status::Success as i32;
208 } else {
209 result.status = Status::Failure as i32;
210 result.result = format!("Function '{}' does not exist.", metadata.name);
211 }
212 }
213 None => {
214 result.status = Status::Failure as i32;
215 result.result = "Function load request metadata is missing.".to_string();
216 }
217 };
218
219 sender
220 .unbounded_send(StreamingMessage {
221 content: Some(Content::FunctionLoadResponse(FunctionLoadResponse {
222 function_id: req.function_id,
223 result: Some(result),
224 ..Default::default()
225 })),
226 ..Default::default()
227 })
228 .expect("failed to send function load response");
229 }
230
231 fn handle_invocation_request(
232 registry: &Registry<'static>,
233 sender: Sender,
234 req: InvocationRequest,
235 ) {
236 if let Some(func) = registry.get(&req.function_id) {
237 Worker::invoke_function(func, sender, req);
238 return;
239 }
240
241 let error = format!("Function with id '{}' does not exist.", req.function_id);
242
243 sender
244 .unbounded_send(StreamingMessage {
245 content: Some(Content::InvocationResponse(InvocationResponse {
246 invocation_id: req.invocation_id,
247 result: Some(StatusResult {
248 status: Status::Failure as i32,
249 result: error,
250 ..Default::default()
251 }),
252 ..Default::default()
253 })),
254 ..Default::default()
255 })
256 .expect("failed to send invocation response");
257 }
258
259 fn handle_worker_status_request(sender: Sender, _: WorkerStatusRequest) {
260 sender
261 .unbounded_send(StreamingMessage {
262 content: Some(Content::WorkerStatusResponse(WorkerStatusResponse {})),
263 ..Default::default()
264 })
265 .expect("failed to send worker status response");
266 }
267
268 fn invoke_function(func: &'static Function, sender: Sender, req: InvocationRequest) {
269 match func
270 .invoker
271 .as_ref()
272 .expect("function must have an invoker")
273 .invoker_fn
274 {
275 InvokerFn::Sync(invoker_fn) => {
276 let id = req.invocation_id.clone();
279 let func_id = req.function_id.clone();
280 let req = RefCell::new(Some(req));
281
282 tokio::spawn(ContextFuture::new(
283 poll_fn(move |_| {
284 blocking(|| {
285 invoker_fn.expect("invoker must have a callback")(
286 req.replace(None).expect("only a single call to invoker"),
287 )
288 })
289 })
290 .map(|r| r.expect("expected a response")),
291 id,
292 func_id,
293 &func.name,
294 sender,
295 ));
296 }
297 InvokerFn::Async(invoker_fn) => {
298 let id = req.invocation_id.clone();
299 let func_id = req.function_id.clone();
300
301 tokio::spawn(ContextFuture::new(
302 invoker_fn.expect("invoker must have a callback")(req),
303 id,
304 func_id,
305 &func.name,
306 sender,
307 ));
308 }
309 };
310 }
311
312 fn handle_panic(info: &PanicInfo) {
313 let backtrace = Backtrace::new();
314 match info.location() {
315 Some(location) => {
316 error!(
317 "Azure Function '{}' panicked with '{}', {}:{}:{}{}",
318 crate::context::CURRENT.with(|c| c.borrow().function_name),
319 info.payload()
320 .downcast_ref::<&str>()
321 .cloned()
322 .unwrap_or_else(|| info
323 .payload()
324 .downcast_ref::<String>()
325 .map(String::as_str)
326 .unwrap_or("")),
327 location.file(),
328 location.line(),
329 location.column(),
330 backtrace
331 );
332 }
333 None => {
334 error!(
335 "Azure Function '{}' panicked with '{}'{}",
336 crate::context::CURRENT.with(|c| c.borrow().function_name),
337 info.payload()
338 .downcast_ref::<&str>()
339 .cloned()
340 .unwrap_or_else(|| info
341 .payload()
342 .downcast_ref::<String>()
343 .map(String::as_str)
344 .unwrap_or("")),
345 backtrace
346 );
347 }
348 };
349 }
350}