use crate::{
backtrace::Backtrace,
codegen::{Function, InvokerFn},
context::Context,
logger,
registry::Registry,
rpc::{
client::FunctionRpcClient, status_result::Status, streaming_message::Content,
FunctionLoadRequest, FunctionLoadResponse, InvocationRequest, InvocationResponse,
StartStream, StatusResult, StreamingMessage, WorkerInitResponse, WorkerStatusRequest,
WorkerStatusResponse,
},
};
use futures::{channel::mpsc::unbounded, future::FutureExt, stream::StreamExt};
use http::uri::Uri;
use log::error;
use std::{
cell::RefCell,
future::Future,
panic::{catch_unwind, set_hook, AssertUnwindSafe, PanicInfo},
pin::Pin,
task::Poll,
};
use tokio::future::poll_fn;
use tokio_executor::threadpool::blocking;
use tonic::Request;
/// Sending half of the unbounded channel whose receiver feeds the host event stream.
pub type Sender = futures::channel::mpsc::UnboundedSender<StreamingMessage>;

/// Wraps a single function invocation.
///
/// It remembers which invocation it belongs to and where to deliver the final
/// `InvocationResponse`.
struct ContextFuture<F> {
    /// The future that runs the user's function and yields its response.
    inner: F,
    /// Identifier of the invocation being executed.
    invocation_id: String,
    /// Identifier of the function being invoked.
    function_id: String,
    /// Name of the function being invoked.
    function_name: &'static str,
    /// Channel used to deliver the invocation response.
    sender: Sender,
}

impl<F> ContextFuture<F> {
    /// Builds a `ContextFuture` from the inner future and the invocation's identifying data.
    pub fn new(
        inner: F,
        invocation_id: String,
        function_id: String,
        function_name: &'static str,
        sender: Sender,
    ) -> Self {
        Self {
            sender,
            function_name,
            function_id,
            invocation_id,
            inner,
        }
    }
}
impl<F: Future<Output = InvocationResponse> + Unpin> Future for ContextFuture<F> {
    type Output = ();

    /// Polls the wrapped invocation with its context installed.
    ///
    /// The context is set through `Context::set`, and the returned guard is held until this
    /// call returns. Once the inner future completes, its `InvocationResponse` is sent through
    /// `sender`.
    ///
    /// A panic raised while polling the inner future is caught and turned into a failed
    /// `InvocationResponse`, so a response is still sent for the invocation.
    ///
    /// # Panics
    ///
    /// Panics if the response cannot be sent (the channel's receiver is gone).
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context) -> Poll<Self::Output> {
        let _guard = Context::set(&self.invocation_id, &self.function_id, self.function_name);

        let response = match catch_unwind(AssertUnwindSafe(|| self.inner.poll_unpin(cx))) {
            Ok(Poll::Pending) => return Poll::Pending,
            Ok(Poll::Ready(response)) => response,
            // The panic itself is reported by the process-wide panic hook; here we only
            // convert it into a failure result for the host.
            Err(_) => InvocationResponse {
                invocation_id: self.invocation_id.clone(),
                result: Some(StatusResult {
                    status: Status::Failure as i32,
                    result: "Azure Function panicked: see log for more information.".to_string(),
                    ..Default::default()
                }),
                ..Default::default()
            },
        };

        let message = StreamingMessage {
            content: Some(Content::InvocationResponse(response)),
            ..Default::default()
        };
        self.sender
            .unbounded_send(message)
            .expect("failed to send invocation response");

        Poll::Ready(())
    }
}
pub struct Worker;
impl Worker {
pub fn run(host: &str, port: u16, worker_id: &str, mut registry: Registry<'static>) {
let host_uri: Uri = format!("http://{0}:{1}", host, port).parse().unwrap();
let (sender, receiver) = unbounded::<StreamingMessage>();
tokio::runtime::Runtime::new().unwrap().block_on(async {
let mut client = FunctionRpcClient::connect(host_uri)
.await
.map_err(|e| panic!("failed to connect to host: {}", e))
.unwrap();
sender
.unbounded_send(StreamingMessage {
content: Some(Content::StartStream(StartStream {
worker_id: worker_id.to_owned(),
})),
..Default::default()
})
.unwrap();
let mut stream = client
.event_stream(Request::new(receiver))
.await
.map_err(|e| panic!("failed to start event stream: {}", e))
.unwrap()
.into_inner();
let init_req = stream
.next()
.await
.expect("expected a worker init request")
.map_err(|e| panic!("failed to read event stream response: {}", e))
.unwrap();
Worker::handle_worker_init_request(sender.clone(), init_req).await;
stream
.for_each(move |req| {
Worker::handle_request(
&mut registry,
sender.clone(),
req.expect("expected a request"),
);
futures::future::ready(())
})
.await;
});
}
async fn handle_worker_init_request(sender: Sender, req: StreamingMessage) {
match req.content {
Some(Content::WorkerInitRequest(req)) => {
println!(
"Connected to Azure Functions host version {}.",
req.host_version
);
log::set_boxed_logger(Box::new(logger::Logger::new(
log::Level::Info,
sender.clone(),
)))
.expect("failed to set the global logger instance");
set_hook(Box::new(Worker::handle_panic));
log::set_max_level(log::LevelFilter::Trace);
sender
.unbounded_send(StreamingMessage {
content: Some(Content::WorkerInitResponse(WorkerInitResponse {
worker_version: env!("CARGO_PKG_VERSION").to_owned(),
result: Some(StatusResult {
status: Status::Success as i32,
..Default::default()
}),
..Default::default()
})),
..Default::default()
})
.unwrap();
}
_ => panic!("expected a worker init request message from the host"),
};
}
fn handle_request(registry: &mut Registry<'static>, sender: Sender, req: StreamingMessage) {
match req.content {
Some(Content::FunctionLoadRequest(req)) => {
Worker::handle_function_load_request(registry, sender, req)
}
Some(Content::InvocationRequest(req)) => {
Worker::handle_invocation_request(registry, sender, req)
}
Some(Content::WorkerStatusRequest(req)) => {
Worker::handle_worker_status_request(sender, req)
}
Some(Content::FileChangeEventRequest(_)) => {}
Some(Content::InvocationCancel(_)) => {}
Some(Content::FunctionEnvironmentReloadRequest(_)) => {}
_ => panic!("unexpected message from host: {:?}.", req),
};
}
fn handle_function_load_request(
registry: &mut Registry<'static>,
sender: Sender,
req: FunctionLoadRequest,
) {
let mut result = StatusResult::default();
match req.metadata.as_ref() {
Some(metadata) => {
if registry.register(&req.function_id, &metadata.name) {
result.status = Status::Success as i32;
} else {
result.status = Status::Failure as i32;
result.result = format!("Function '{}' does not exist.", metadata.name);
}
}
None => {
result.status = Status::Failure as i32;
result.result = "Function load request metadata is missing.".to_string();
}
};
sender
.unbounded_send(StreamingMessage {
content: Some(Content::FunctionLoadResponse(FunctionLoadResponse {
function_id: req.function_id,
result: Some(result),
..Default::default()
})),
..Default::default()
})
.expect("failed to send function load response");
}
fn handle_invocation_request(
registry: &Registry<'static>,
sender: Sender,
req: InvocationRequest,
) {
if let Some(func) = registry.get(&req.function_id) {
Worker::invoke_function(func, sender, req);
return;
}
let error = format!("Function with id '{}' does not exist.", req.function_id);
sender
.unbounded_send(StreamingMessage {
content: Some(Content::InvocationResponse(InvocationResponse {
invocation_id: req.invocation_id,
result: Some(StatusResult {
status: Status::Failure as i32,
result: error,
..Default::default()
}),
..Default::default()
})),
..Default::default()
})
.expect("failed to send invocation response");
}
fn handle_worker_status_request(sender: Sender, _: WorkerStatusRequest) {
sender
.unbounded_send(StreamingMessage {
content: Some(Content::WorkerStatusResponse(WorkerStatusResponse {})),
..Default::default()
})
.expect("failed to send worker status response");
}
fn invoke_function(func: &'static Function, sender: Sender, req: InvocationRequest) {
match func
.invoker
.as_ref()
.expect("function must have an invoker")
.invoker_fn
{
InvokerFn::Sync(invoker_fn) => {
let id = req.invocation_id.clone();
let func_id = req.function_id.clone();
let req = RefCell::new(Some(req));
tokio::spawn(ContextFuture::new(
poll_fn(move |_| {
blocking(|| {
invoker_fn.expect("invoker must have a callback")(
req.replace(None).expect("only a single call to invoker"),
)
})
})
.map(|r| r.expect("expected a response")),
id,
func_id,
&func.name,
sender,
));
}
InvokerFn::Async(invoker_fn) => {
let id = req.invocation_id.clone();
let func_id = req.function_id.clone();
tokio::spawn(ContextFuture::new(
invoker_fn.expect("invoker must have a callback")(req),
id,
func_id,
&func.name,
sender,
));
}
};
}
fn handle_panic(info: &PanicInfo) {
let backtrace = Backtrace::new();
match info.location() {
Some(location) => {
error!(
"Azure Function '{}' panicked with '{}', {}:{}:{}{}",
crate::context::CURRENT.with(|c| c.borrow().function_name),
info.payload()
.downcast_ref::<&str>()
.cloned()
.unwrap_or_else(|| info
.payload()
.downcast_ref::<String>()
.map(String::as_str)
.unwrap_or("")),
location.file(),
location.line(),
location.column(),
backtrace
);
}
None => {
error!(
"Azure Function '{}' panicked with '{}'{}",
crate::context::CURRENT.with(|c| c.borrow().function_name),
info.payload()
.downcast_ref::<&str>()
.cloned()
.unwrap_or_else(|| info
.payload()
.downcast_ref::<String>()
.map(String::as_str)
.unwrap_or("")),
backtrace
);
}
};
}
}