use std::collections::HashMap;
use std::convert::Infallible;
use std::net::TcpListener as StdTcpListener;
use std::ops::Deref;
use std::process;
use std::sync::{mpsc, Arc};
use std::thread;

use aws_smithy_http_server::{
    body::{Body, BoxBody},
    routing::IntoMakeService,
};
use http::{Request, Response};
use hyper::server::conn::AddrIncoming;
use parking_lot::Mutex;
use pyo3::{prelude::*, types::IntoPyDict};
use signal_hook::{consts::*, iterator::Signals};
use socket2::Socket;
use tokio::{net::TcpListener, runtime};
use tokio_rustls::TlsAcceptor;
use tower::{util::BoxCloneService, ServiceBuilder};

use crate::{
    context::{layer::AddPyContextLayer, PyContext},
    tls::{listener::Listener as TlsListener, PyTlsConfig},
    util::{error::rich_py_err, func_metadata},
    PySocket,
};
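
/// A Python handler function representation.
///
/// The Python business logic implementation needs to carry some information
/// to be executed properly, such as the number of its arguments and whether
/// it is a coroutine.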
#[pyclass]
#[derive(Debug, Clone)]
pub struct PyHandler {
pub func: PyObject,
pub args: usize,
pub is_coroutine: bool,
}
impl Deref for PyHandler {
type Target = PyObject;
fn deref(&self) -> &Self::Target {
&self.func
}
}
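
/// Boxed, cloneable [`tower::Service`] type served by every worker.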
type Service = BoxCloneService<Request<Body>, Response<BoxBody>, Infallible>;
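
/// Trait defining a Python application.
///
/// Implementors get the shared machinery to spawn worker processes, handle
/// Rust and Python signals, and serve the generated service with Hyper.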
pub trait PyApp: Clone + pyo3::IntoPy<PyObject> {
    /// Handle to the worker processes tracked by this application.
    fn workers(&self) -> &Mutex<Vec<PyObject>>;
    /// Optional Python context object made available to every handler.
    fn context(&self) -> &Option<PyObject>;
    /// Registered operation handlers, keyed by operation name.
    fn handlers(&mut self) -> &mut HashMap<String, PyHandler>;
    /// Build the protocol-specific [`Service`] for this application.
    fn build_service(&mut self, event_loop: &pyo3::PyAny) -> pyo3::PyResult<Service>;
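
    /// Gracefully terminate all workers: call `terminate()` on each one,
    /// falling back to `kill()` if termination fails, then exit the process.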
fn graceful_termination(&self, workers: &Mutex<Vec<PyObject>>) -> ! {
let workers = workers.lock();
for (idx, worker) in workers.iter().enumerate() {
let idx = idx + 1;
Python::with_gil(|py| {
let pid: isize = worker
.getattr(py, "pid")
.map(|pid| pid.extract(py).unwrap_or(-1))
.unwrap_or(-1);
tracing::debug!(idx, pid, "terminating worker");
match worker.call_method0(py, "terminate") {
Ok(_) => {}
Err(e) => {
tracing::error!(error = ?rich_py_err(e), idx, pid, "error terminating worker");
worker
.call_method0(py, "kill")
.map_err(|e| {
tracing::error!(
                                    error = ?rich_py_err(e), idx, pid, "unable to kill worker"
);
})
.unwrap();
}
}
});
}
process::exit(0);
}
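
    /// Immediately terminate all workers: call `kill()` on each one, then
    /// exit the process.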
fn immediate_termination(&self, workers: &Mutex<Vec<PyObject>>) -> ! {
let workers = workers.lock();
for (idx, worker) in workers.iter().enumerate() {
let idx = idx + 1;
Python::with_gil(|py| {
let pid: isize = worker
.getattr(py, "pid")
.map(|pid| pid.extract(py).unwrap_or(-1))
.unwrap_or(-1);
tracing::debug!(idx, pid, "killing worker");
worker
.call_method0(py, "kill")
.map_err(|e| {
                        tracing::error!(error = ?rich_py_err(e), idx, pid, "unable to kill worker");
})
.unwrap();
});
}
process::exit(0);
}
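
    /// Register signal handlers on the main Rust thread and block forever:
    /// `SIGINT` triggers immediate termination, `SIGTERM` and `SIGQUIT`
    /// trigger graceful termination, and everything else is logged and ignored.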
fn block_on_rust_signals(&self) {
let mut signals =
Signals::new([SIGINT, SIGHUP, SIGQUIT, SIGTERM, SIGUSR1, SIGUSR2, SIGWINCH])
.expect("Unable to register signals");
for sig in signals.forever() {
match sig {
SIGINT => {
tracing::info!(
sig = %sig, "termination signal received, all workers will be immediately terminated"
);
self.immediate_termination(self.workers());
}
SIGTERM | SIGQUIT => {
tracing::info!(
sig = %sig, "termination signal received, all workers will be gracefully terminated"
);
self.graceful_termination(self.workers());
}
_ => {
tracing::debug!(sig = %sig, "signal is ignored by this application");
}
}
}
}
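
    /// Register `SIGTERM` and `SIGINT` handlers on the Python event loop.
    /// The inlined `shutdown()` coroutine cancels every task still running
    /// on the loop, awaits their results, and stops the loop.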
fn register_python_signals(&self, py: Python, event_loop: PyObject) -> PyResult<()> {
let locals = [("event_loop", event_loop)].into_py_dict(py);
py.run(
r#"
import asyncio
import logging
import functools
import signal
async def shutdown(sig, event_loop):
# reimport asyncio and logging to be sure they are available when
# this handler runs on signal catching.
import asyncio
import logging
logging.info(f"Caught signal {sig.name}, cancelling tasks registered on this loop")
tasks = [task for task in asyncio.all_tasks() if task is not
asyncio.current_task()]
list(map(lambda task: task.cancel(), tasks))
results = await asyncio.gather(*tasks, return_exceptions=True)
logging.debug(f"Finished awaiting cancelled tasks, results: {results}")
event_loop.stop()
event_loop.add_signal_handler(signal.SIGTERM,
functools.partial(asyncio.ensure_future, shutdown(signal.SIGTERM, event_loop)))
event_loop.add_signal_handler(signal.SIGINT,
functools.partial(asyncio.ensure_future, shutdown(signal.SIGINT, event_loop)))
"#,
None,
Some(locals),
)?;
Ok(())
}
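
    /// Start a single worker: spawn a Tokio runtime on a background thread
    /// to serve the shared socket with Hyper (behind TLS when configured),
    /// then block the current thread on the Python event loop.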
fn start_hyper_worker(
&mut self,
py: Python,
socket: &PyCell<PySocket>,
event_loop: &PyAny,
service: Service,
worker_number: isize,
tls: Option<PyTlsConfig>,
) -> PyResult<()> {
let borrow = socket.try_borrow_mut()?;
let held_socket: &PySocket = &borrow;
let raw_socket = held_socket.get_socket()?;
self.register_python_signals(py, event_loop.to_object(py))?;
tracing::trace!("start the tokio runtime in a background task");
thread::spawn(move || {
let rt = runtime::Builder::new_multi_thread()
.enable_all()
.thread_name(format!("smithy-rs-tokio[{worker_number}]"))
.build()
.expect("unable to start a new tokio runtime for this process");
rt.block_on(async move {
let addr = addr_incoming_from_socket(raw_socket);
if let Some(config) = tls {
let (acceptor, acceptor_rx) = tls_config_reloader(config);
let listener = TlsListener::new(acceptor, addr, acceptor_rx);
let server =
hyper::Server::builder(listener).serve(IntoMakeService::new(service));
tracing::trace!("started tls hyper server from shared socket");
if let Err(err) = server.await {
tracing::error!(error = ?err, "server error");
}
} else {
let server = hyper::Server::builder(addr).serve(IntoMakeService::new(service));
tracing::trace!("started hyper server from shared socket");
if let Err(err) = server.await {
tracing::error!(error = ?err, "server error");
}
}
});
});
tracing::trace!("run and block on the python event loop until a signal is received");
event_loop.call_method0("run_forever")?;
Ok(())
}
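
    /// Register a Python function as the handler for the operation `name`,
    /// recording its number of arguments and whether it is a coroutine.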
fn register_operation(&mut self, py: Python, name: &str, func: PyObject) -> PyResult<()> {
let func_metadata = func_metadata(py, &func)?;
let handler = PyHandler {
func,
is_coroutine: func_metadata.is_coroutine,
args: func_metadata.num_args,
};
tracing::info!(
name,
is_coroutine = handler.is_coroutine,
args = handler.args,
"registering handler function",
);
self.handlers().insert(name.to_string(), handler);
Ok(())
}
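
    /// Create a new event loop for the current process, installing `uvloop`
    /// as the loop implementation when it is available.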
fn configure_python_event_loop<'py>(&self, py: Python<'py>) -> PyResult<&'py PyAny> {
let asyncio = py.import("asyncio")?;
match py.import("uvloop") {
Ok(uvloop) => {
uvloop.call_method0("install")?;
tracing::trace!("setting up uvloop for current process");
}
Err(_) => {
                tracing::warn!("uvloop not found, falling back to the Python standard event loop, which may perform worse than uvloop");
}
}
let event_loop = asyncio.call_method0("new_event_loop")?;
asyncio.call_method1("set_event_loop", (event_loop,))?;
Ok(event_loop)
}
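
    /// Main entrypoint: bind the shared socket and start one
    /// `multiprocessing.Process` worker per CPU (unless overridden), each
    /// running `start_worker` on a clone of the socket.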
fn run_server(
&mut self,
py: Python,
address: Option<String>,
port: Option<i32>,
backlog: Option<i32>,
workers: Option<usize>,
tls: Option<PyTlsConfig>,
) -> PyResult<()> {
let mp = py.import("multiprocessing")?;
mp.call_method0("allow_connection_pickling")?;
#[cfg(target_os = "macos")]
mp.call_method(
"set_start_method",
("fork",),
Some(vec![("force", true)].into_py_dict(py)),
)?;
let address = address.unwrap_or_else(|| String::from("127.0.0.1"));
let port = port.unwrap_or(13734);
let socket = PySocket::new(address, port, backlog)?;
let mut active_workers = self.workers().lock();
        for idx in 1..=workers.unwrap_or_else(num_cpus::get) {
let sock = socket.try_clone()?;
let tls = tls.clone();
let process = mp.getattr("Process")?;
let handle = process.call1((
py.None(),
self.clone().into_py(py).getattr(py, "start_worker")?,
format!("smithy-rs-worker[{idx}]"),
(sock.into_py(py), idx, tls.into_py(py)),
))?;
handle.call_method0("start")?;
active_workers.push(handle.to_object(py));
}
drop(active_workers);
tracing::trace!("rust python server started successfully");
self.block_on_rust_signals();
Ok(())
}
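
    /// Lambda entrypoint: run the service inside the `lambda_http` runtime
    /// on a background thread while the Python event loop blocks the
    /// current one.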
fn run_lambda_handler(&mut self, py: Python) -> PyResult<()> {
use aws_smithy_http_server::routing::LambdaHandler;
let event_loop = self.configure_python_event_loop(py)?;
self.register_python_signals(py, event_loop.to_object(py))?;
let service = self.build_and_configure_service(py, event_loop)?;
tracing::trace!("start the tokio runtime in a background task");
thread::spawn(move || {
let rt = runtime::Builder::new_multi_thread()
.enable_all()
.build()
.expect("unable to start a new tokio runtime for this process");
rt.block_on(async move {
let handler = LambdaHandler::new(service);
let lambda = lambda_http::run(handler);
tracing::debug!("starting lambda handler");
if let Err(err) = lambda.await {
tracing::error!(error = %err, "unable to start lambda handler");
}
});
});
tracing::trace!("run and block on the python event loop until a signal is received");
event_loop.call_method0("run_forever")?;
Ok(())
}
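
    /// Build the service and wrap it with the layer that injects the Python
    /// context object into every request.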
fn build_and_configure_service(
&mut self,
py: Python,
event_loop: &pyo3::PyAny,
) -> pyo3::PyResult<Service> {
let service = self.build_service(event_loop)?;
let context = PyContext::new(self.context().clone().unwrap_or_else(|| py.None()))?;
let service = ServiceBuilder::new()
.boxed_clone()
.layer(AddPyContextLayer::new(context))
.service(service);
Ok(service)
}
}
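
/// Turn the shared socket into a non-blocking Hyper [`AddrIncoming`].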
fn addr_incoming_from_socket(socket: Socket) -> AddrIncoming {
let std_listener: StdTcpListener = socket.into();
std_listener
.set_nonblocking(true)
.expect("unable to set `O_NONBLOCK=true` on `std::net::TcpListener`");
let listener = TcpListener::from_std(std_listener)
.expect("unable to create `tokio::net::TcpListener` from `std::net::TcpListener`");
AddrIncoming::from_listener(listener)
.expect("unable to create `AddrIncoming` from `TcpListener`")
}
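
/// Build the initial [`TlsAcceptor`] and spawn a background task that
/// periodically rebuilds the TLS config, sending fresh acceptors through the
/// returned channel so the listener can hot-reload certificates.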
fn tls_config_reloader(config: PyTlsConfig) -> (TlsAcceptor, mpsc::Receiver<TlsAcceptor>) {
let reload_dur = config.reload_duration();
let (tx, rx) = mpsc::channel();
let acceptor = TlsAcceptor::from(Arc::new(config.build().expect("invalid tls config")));
tokio::spawn(async move {
tracing::trace!(dur = ?reload_dur, "starting timer to reload tls config");
loop {
tokio::time::sleep(reload_dur).await;
tracing::trace!("reloading tls config");
match config.build() {
Ok(config) => {
let new_config = TlsAcceptor::from(Arc::new(config));
tx.send(new_config).expect("could not send new tls config")
}
Err(err) => {
tracing::error!(error = ?err, "could not reload tls config because it is invalid");
}
}
}
});
(acceptor, rx)
}