use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use std::task::{Context, Poll};
use crate::{
get_state, http,
http::server::{HttpBindingConfig, HttpServer, IncomingHttpRequest, WsBindingConfig},
logging::{error, info},
set_state, timer, Address, BuildError, LazyLoadBlob, Message, Request, SendError,
};
use futures_channel::oneshot;
use futures_util::task::{waker_ref, ArcWake};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
thread_local! {
// Futures submitted through `spawn`; `Executor::poll_all_tasks` drains this
// queue into its own task list at the start of every polling round.
static SPAWN_QUEUE: RefCell<Vec<Pin<Box<dyn Future<Output = ()>>>>> = RefCell::new(Vec::new());
// Per-thread application context. Starts with no `HiddenState` and an
// executor that owns no tasks.
pub static APP_CONTEXT: RefCell<AppContext> = RefCell::new(AppContext {
hidden_state: None,
executor: Executor::new(),
});
// Response payloads keyed by correlation id. `ResponseFuture::poll` removes
// the entry matching its id; the code that inserts entries is not visible in
// this part of the file.
pub static RESPONSE_REGISTRY: RefCell<HashMap<String, Vec<u8>>> = RefCell::new(HashMap::new());
// Correlation ids recorded by `ResponseFuture`'s `Drop` impl when a future is
// dropped before it resolved.
pub static CANCELLED_RESPONSES: RefCell<HashSet<String>> = RefCell::new(HashSet::new());
// Ambient per-thread handler data (server pointer, current message, HTTP
// request context). Everything is unset initially.
pub static APP_HELPERS: RefCell<AppHelpers> = RefCell::new(AppHelpers {
current_server: None,
current_message: None,
current_http_context: None,
});
}
/// The HTTP request being handled together with the response metadata that
/// `set_response_headers`, `add_response_header` and `set_response_status`
/// mutate while the request is in flight.
#[derive(Clone)]
pub struct HttpRequestContext {
/// The incoming request as received.
pub request: IncomingHttpRequest,
/// Headers accumulated for the response.
pub response_headers: HashMap<String, String>,
/// Status code chosen for the response.
pub response_status: http::StatusCode,
}
/// Value stored in the `APP_CONTEXT` thread-local.
pub struct AppContext {
/// Persistence bookkeeping; `None` until some code installs a `HiddenState`.
/// `store_old_state` and `maybe_save_state` do nothing while this is `None`.
pub hidden_state: Option<HiddenState>,
/// Executor that owns and polls spawned tasks.
pub executor: Executor,
}
/// Value stored in the `APP_HELPERS` thread-local: ambient data that the
/// free-function accessors in this file read and write.
pub struct AppHelpers {
/// Raw pointer to the active server. `get_server` dereferences it without
/// any check, so whoever stores it must guarantee it is valid while set.
pub current_server: Option<*mut HttpServer>,
/// Message currently being handled; `source()` panics when this is `None`.
pub current_message: Option<Message>,
/// Context of the HTTP request currently being handled, if any.
pub current_http_context: Option<HttpRequestContext>,
}
/// Returns the path of the HTTP request currently being handled.
///
/// Yields `None` when no HTTP request context is set, or when the request's
/// `path()` accessor reports an error.
pub fn get_path() -> Option<String> {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let ctx = helpers.current_http_context.as_ref()?;
        ctx.request.path().ok()
    })
}
/// Returns a mutable reference to the server whose pointer is stored in
/// `APP_HELPERS`, or `None` when no pointer is stored.
///
/// NOTE(review): the `'static mut` lifetime is not backed by anything in this
/// function. Calling it twice hands out two aliasing `&mut` references, and
/// the reference dangles if the server is dropped while the pointer is still
/// stored. Callers must keep the result short-lived.
pub fn get_server() -> Option<&'static mut HttpServer> {
    let raw = APP_HELPERS.with(|cell| cell.borrow().current_server)?;
    // SAFETY: relies entirely on the code that sets `current_server` to store
    // a pointer to a live `HttpServer` and to clear it before that server goes
    // away. Nothing here validates the pointer.
    Some(unsafe { &mut *raw })
}
/// Looks up the address the current server associates with `channel_id`.
///
/// Returns `None` when no server is available or the server has no entry for
/// that channel.
pub fn get_ws_channel_addr(channel_id: u32) -> Option<String> {
    let server = get_server()?;
    server.get_ws_channel_addr(channel_id).cloned()
}
/// Returns a clone of the HTTP request currently being handled, or `None`
/// when no HTTP request context is set.
pub fn get_http_request() -> Option<IncomingHttpRequest> {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let ctx = helpers.current_http_context.as_ref()?;
        Some(ctx.request.clone())
    })
}
/// Returns the method of the current HTTP request rendered as a string.
///
/// Yields `None` when no HTTP request context is set, or when the request's
/// `method()` accessor reports an error.
pub fn get_http_method() -> Option<String> {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let ctx = helpers.current_http_context.as_ref()?;
        let method = ctx.request.method().ok()?;
        Some(method.to_string())
    })
}
/// Returns the value of the request header called `name`.
///
/// Yields `None` when there is no HTTP request context, when `name` is not a
/// valid header name, when the header is absent, or when its value cannot be
/// represented as a string.
pub fn get_request_header(name: &str) -> Option<String> {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let ctx = helpers.current_http_context.as_ref()?;
        // The name is only parsed once a context is known to exist.
        let header_name = http::HeaderName::from_bytes(name.as_bytes()).ok()?;
        let headers = ctx.request.headers();
        let value = headers.get(&header_name)?.to_str().ok()?;
        Some(value.to_string())
    })
}
/// Returns the URL of the current HTTP request rendered as a string.
///
/// Yields `None` when no HTTP request context is set, or when the request's
/// `url()` accessor reports an error.
pub fn get_request_url() -> Option<String> {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let ctx = helpers.current_http_context.as_ref()?;
        let url = ctx.request.url().ok()?;
        Some(url.to_string())
    })
}
/// Replaces the whole response header map of the current HTTP request
/// context with `headers`.
///
/// Does nothing when no HTTP request context is set.
pub fn set_response_headers(headers: HashMap<String, String>) {
    APP_HELPERS.with(|cell| {
        let mut helpers = cell.borrow_mut();
        if let Some(ctx) = helpers.current_http_context.as_mut() {
            ctx.response_headers = headers;
        }
    })
}
/// Inserts one response header into the current HTTP request context,
/// overwriting any existing value stored under the same `key`.
///
/// Does nothing when no HTTP request context is set.
pub fn add_response_header(key: String, value: String) {
    APP_HELPERS.with(|cell| {
        let mut helpers = cell.borrow_mut();
        if let Some(ctx) = helpers.current_http_context.as_mut() {
            ctx.response_headers.insert(key, value);
        }
    })
}
/// Sets the response status code of the current HTTP request context.
///
/// Does nothing when no HTTP request context is set.
pub fn set_response_status(status: http::StatusCode) {
    APP_HELPERS.with(|cell| {
        let mut helpers = cell.borrow_mut();
        if let Some(ctx) = helpers.current_http_context.as_mut() {
            ctx.response_status = status;
        }
    })
}
/// Discards the current HTTP request context, if any, so that the HTTP
/// accessors in this file return `None` until a new context is installed.
pub fn clear_http_request_context() {
    APP_HELPERS.with(|cell| {
        let mut helpers = cell.borrow_mut();
        helpers.current_http_context = None;
    })
}
/// Returns the source address of the message currently being handled.
///
/// # Panics
///
/// Panics with "No message in current context" when no current message is
/// set in `APP_HELPERS`.
pub fn source() -> Address {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let message = helpers
            .current_message
            .as_ref()
            .expect("No message in current context");
        message.source().clone()
    })
}
/// Returns a copy of the query parameters of the current HTTP request, or
/// `None` when no HTTP request context is set.
pub fn get_query_params() -> Option<HashMap<String, String>> {
    APP_HELPERS.with(|cell| {
        let helpers = cell.borrow();
        let ctx = helpers.current_http_context.as_ref()?;
        Some(ctx.request.query_params().clone())
    })
}
/// Executor that owns the futures queued through `spawn` and drives them in
/// `poll_all_tasks`.
pub struct Executor {
/// Tasks that have been picked up from the spawn queue and have not yet
/// completed.
tasks: Vec<Pin<Box<dyn Future<Output = ()>>>>,
}
/// Flag that records whether a waker fired since the flag was last read.
struct ExecutorWakeFlag {
    /// `true` while a wake-up is pending and has not been consumed by `take`.
    triggered: AtomicBool,
}

impl ExecutorWakeFlag {
    /// Creates a flag with no wake-up pending.
    fn new() -> Self {
        let triggered = AtomicBool::new(false);
        Self { triggered }
    }

    /// Reports whether a wake-up was pending and clears it in the same atomic
    /// step, so a wake-up is observed at most once.
    fn take(&self) -> bool {
        let was_triggered = self.triggered.swap(false, Ordering::SeqCst);
        was_triggered
    }
}
/// Waking only raises the flag; the executor loop reads it with `take` after
/// each polling round to decide whether another round is needed.
impl ArcWake for ExecutorWakeFlag {
    fn wake_by_ref(arc_self: &Arc<Self>) {
        let flag: &ExecutorWakeFlag = arc_self;
        flag.triggered.store(true, Ordering::SeqCst);
    }
}
/// Handle returned by `spawn`. Awaiting it yields the spawned future's output,
/// or `oneshot::Canceled` if the sending half was dropped without a value.
pub struct JoinHandle<T> {
/// Receiving half of the channel the spawned task sends its result through.
receiver: oneshot::Receiver<T>,
}
/// Polling a `JoinHandle` simply polls the underlying oneshot receiver.
impl<T> Future for JoinHandle<T> {
    type Output = Result<T, oneshot::Canceled>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Mutable access through the pin relies on `JoinHandle<T>` being
        // `Unpin`, exactly as the `get_mut` form does.
        Pin::new(&mut self.receiver).poll(cx)
    }
}
pub fn spawn<T>(fut: impl Future<Output = T> + 'static) -> JoinHandle<T>
where
T: 'static,
{
let (sender, receiver) = oneshot::channel();
SPAWN_QUEUE.with(|queue| {
queue.borrow_mut().push(Box::pin(async move {
let result = fut.await;
let _ = sender.send(result);
}));
});
JoinHandle { receiver }
}
impl Executor {
    /// Creates an executor that owns no tasks.
    pub fn new() -> Self {
        Self { tasks: Vec::new() }
    }

    /// Polls every task until no further progress can be made right now.
    ///
    /// Each round first moves newly spawned futures from the spawn queue into
    /// the task list, then polls every task once in order, dropping those that
    /// complete. Another round runs when a task spawned new work during the
    /// round or when the shared waker fired; otherwise the method returns and
    /// the remaining tasks stay pending until the next call.
    pub fn poll_all_tasks(&mut self) {
        let wake_flag = Arc::new(ExecutorWakeFlag::new());
        loop {
            SPAWN_QUEUE.with(|queue| {
                self.tasks.append(&mut queue.borrow_mut());
            });

            {
                let waker = waker_ref(&wake_flag);
                let mut ctx = Context::from_waker(&waker);
                // One in-place pass: poll each task in order and keep only the
                // ones still pending. This replaces collecting indices and
                // calling `Vec::remove` per finished task, which is quadratic.
                self.tasks
                    .retain_mut(|task| task.as_mut().poll(&mut ctx).is_pending());
            }

            let has_new_tasks = SPAWN_QUEUE.with(|queue| !queue.borrow().is_empty());
            // Always consume the flag so a stale wake-up cannot leak into the
            // next round.
            let was_woken = wake_flag.take();
            if !has_new_tasks && !was_woken {
                break;
            }
        }
    }
}
/// Future that resolves to the response payload registered under a
/// correlation id in `RESPONSE_REGISTRY`.
struct ResponseFuture {
/// Key looked up in `RESPONSE_REGISTRY` on every poll.
correlation_id: String,
/// HTTP request context captured when the future was created; reinstalled
/// as the current context when the future resolves.
http_context: Option<HttpRequestContext>,
/// Set once the payload has been handed out, so `Drop` can tell a completed
/// future from a cancelled one.
resolved: bool,
}
impl ResponseFuture {
    /// Creates an unresolved future for `correlation_id`, capturing a clone of
    /// whatever HTTP request context is current at construction time.
    fn new(correlation_id: String) -> Self {
        let captured_context =
            APP_HELPERS.with(|cell| cell.borrow().current_http_context.clone());
        Self {
            correlation_id,
            http_context: captured_context,
            resolved: false,
        }
    }
}
impl Future for ResponseFuture {
type Output = Vec<u8>;
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let maybe_bytes = RESPONSE_REGISTRY.with(|registry| {
let mut registry_mut = registry.borrow_mut();
registry_mut.remove(&this.correlation_id)
});
if let Some(bytes) = maybe_bytes {
this.resolved = true;
if let Some(ref context) = this.http_context {
APP_HELPERS.with(|helpers| {
helpers.borrow_mut().current_http_context = Some(context.clone());
});
}
Poll::Ready(bytes)
} else {
Poll::Pending
}
}
}
/// Cleans up after a future that is dropped before it resolved: any payload
/// already registered for it is discarded, and its correlation id is recorded
/// in `CANCELLED_RESPONSES`.
impl Drop for ResponseFuture {
    fn drop(&mut self) {
        if !self.resolved {
            RESPONSE_REGISTRY.with(|registry| {
                registry.borrow_mut().remove(&self.correlation_id);
            });
            CANCELLED_RESPONSES.with(|cancelled| {
                cancelled.borrow_mut().insert(self.correlation_id.clone());
            });
        }
    }
}
/// Error returned by `sleep`, `send` and `send_rmp`.
#[derive(Debug, Clone, Serialize, Deserialize, Error)]
pub enum AppSendError {
/// The response payload deserialized as a `SendError`.
#[error("SendError: {0}")]
SendError(SendError),
/// Sending the request failed. `send` and `send_rmp` also return
/// `BuildError::NoBody` in this variant when the response payload could not
/// be deserialized at all.
#[error("BuildError: {0}")]
BuildError(BuildError),
}
/// Suspends the calling task by sending a `SetTimer(sleep_ms)` request to the
/// timer process and waiting for the matching response.
///
/// The request's response timeout is `sleep_ms / 1000 + 1`, so it always
/// exceeds the requested sleep. The content of the response is ignored.
///
/// # Errors
///
/// Returns `AppSendError::BuildError` when the request cannot be sent.
pub async fn sleep(sleep_ms: u64) -> Result<(), AppSendError> {
    let response_timeout = (sleep_ms / 1_000) + 1;
    let request = Request::to(("our", "timer", "distro", "sys"))
        .body(timer::TimerAction::SetTimer(sleep_ms))
        .expects_response(response_timeout);

    let correlation_id = Uuid::new_v4().to_string();
    request
        .context(correlation_id.as_bytes().to_vec())
        .send()
        .map_err(AppSendError::BuildError)?;

    let _ = ResponseFuture::new(correlation_id).await;
    Ok(())
}
/// Sends `request` and awaits its response, decoding the payload as JSON.
///
/// A request without a timeout gets `expects_response(30)`. The payload is
/// first decoded as `R`; only if that fails is it decoded as a `SendError`.
///
/// # Errors
///
/// * `AppSendError::BuildError` when the request cannot be sent.
/// * `AppSendError::SendError` when the payload decodes as a `SendError`.
/// * `AppSendError::BuildError(BuildError::NoBody)` when the payload decodes
///   as neither; the failure is logged first.
pub async fn send<R>(request: Request) -> Result<R, AppSendError>
where
    R: serde::de::DeserializeOwned,
{
    let has_timeout = request.timeout.is_some();
    let request = if has_timeout {
        request
    } else {
        request.expects_response(30)
    };

    let correlation_id = Uuid::new_v4().to_string();
    request
        .context(correlation_id.as_bytes().to_vec())
        .send()
        .map_err(AppSendError::BuildError)?;

    let response_bytes = ResponseFuture::new(correlation_id).await;
    match serde_json::from_slice::<R>(&response_bytes) {
        Ok(value) => Ok(value),
        Err(_) => match serde_json::from_slice::<SendError>(&response_bytes) {
            Ok(send_error) => Err(AppSendError::SendError(send_error)),
            Err(err) => {
                error!(
                    "Failed to deserialize response in send(): {} (payload: {:?})",
                    err, response_bytes
                );
                Err(AppSendError::BuildError(BuildError::NoBody))
            }
        },
    }
}
/// Sends `request` and awaits its response, decoding the payload as
/// MessagePack.
///
/// A request without a timeout gets `expects_response(30)`. The payload is
/// first decoded as `R`; only if that fails is it decoded as a `SendError`.
///
/// # Errors
///
/// * `AppSendError::BuildError` when the request cannot be sent.
/// * `AppSendError::SendError` when the payload decodes as a `SendError`.
/// * `AppSendError::BuildError(BuildError::NoBody)` when the payload decodes
///   as neither; the failure is logged first.
pub async fn send_rmp<R>(request: Request) -> Result<R, AppSendError>
where
    R: serde::de::DeserializeOwned,
{
    let has_timeout = request.timeout.is_some();
    let request = if has_timeout {
        request
    } else {
        request.expects_response(30)
    };

    let correlation_id = Uuid::new_v4().to_string();
    request
        .context(correlation_id.as_bytes().to_vec())
        .send()
        .map_err(AppSendError::BuildError)?;

    let response_bytes = ResponseFuture::new(correlation_id).await;
    match rmp_serde::from_slice::<R>(&response_bytes) {
        Ok(value) => Ok(value),
        Err(_) => match rmp_serde::from_slice::<SendError>(&response_bytes) {
            Ok(send_error) => Err(AppSendError::SendError(send_error)),
            Err(err) => {
                error!(
                    "Failed to deserialize response in send_rmp(): {} (payload: {:?})",
                    err, response_bytes
                );
                Err(AppSendError::BuildError(BuildError::NoBody))
            }
        },
    }
}
/// Policy deciding when application state is persisted.
#[derive(Clone)]
pub enum SaveOptions {
    /// Never persist.
    Never,
    /// Persist after every message.
    EveryMessage,
    /// Persist once the message counter reaches the given count.
    EveryNMessage(u64),
    /// NOTE(review): `HiddenState::should_save_state` always answers `false`
    /// for this variant, so it never triggers a save from there.
    EveryNSeconds(u64),
    /// Persist when the serialized state changed; the comparison is done in
    /// `maybe_save_state`, not in `HiddenState::should_save_state`.
    OnDiff,
}

/// Bookkeeping that backs the persistence policy.
pub struct HiddenState {
    /// The configured policy.
    save_config: SaveOptions,
    /// Messages counted since the last `EveryNMessage` save.
    message_count: u64,
    /// Serialized snapshot used as the comparison baseline for `OnDiff`.
    old_state: Option<Vec<u8>>,
}

impl HiddenState {
    /// Creates bookkeeping for `save_config` with a zeroed counter and no
    /// snapshot.
    pub fn new(save_config: SaveOptions) -> Self {
        Self {
            save_config,
            message_count: 0,
            old_state: None,
        }
    }

    /// Decides whether state should be saved now under the counter-based
    /// policies. For `EveryNMessage` each call counts as one message, and the
    /// counter resets whenever the answer is `true`.
    fn should_save_state(&mut self) -> bool {
        match self.save_config {
            SaveOptions::EveryMessage => true,
            SaveOptions::EveryNMessage(n) => {
                self.message_count += 1;
                let due = self.message_count >= n;
                if due {
                    self.message_count = 0;
                }
                due
            }
            SaveOptions::Never | SaveOptions::EveryNSeconds(_) | SaveOptions::OnDiff => false,
        }
    }
}
/// Records a serialized snapshot of `state` as the baseline for the `OnDiff`
/// save policy.
///
/// The snapshot is taken only when a `HiddenState` exists, its policy is
/// `OnDiff`, and no baseline is currently stored. A serialization failure
/// leaves the baseline unset.
pub fn store_old_state<S>(state: &S)
where
    S: serde::Serialize,
{
    APP_CONTEXT.with(|ctx| {
        let mut app = ctx.borrow_mut();
        let hidden = match app.hidden_state.as_mut() {
            Some(hidden) => hidden,
            None => return,
        };

        let needs_baseline =
            matches!(hidden.save_config, SaveOptions::OnDiff) && hidden.old_state.is_none();
        if !needs_baseline {
            return;
        }

        if let Ok(snapshot) = rmp_serde::to_vec(state) {
            hidden.old_state = Some(snapshot);
        }
    });
}
/// Trait for state types that can construct a fresh instance.
///
/// NOTE(review): nothing visible in this part of the file uses this trait;
/// `initialize_state` relies on `Default` instead.
pub trait State {
/// Creates a new instance.
fn new() -> Self;
}
/// Loads previously persisted state, falling back to `S::default()` when
/// nothing has been persisted yet.
///
/// # Panics
///
/// Panics when persisted bytes exist but cannot be deserialized as `S`. This
/// is deliberate: falling back to a fresh value would overwrite the existing
/// state on the next save.
pub fn initialize_state<S>() -> S
where
    S: serde::de::DeserializeOwned + Default,
{
    let bytes = match get_state() {
        Some(bytes) => bytes,
        None => {
            info!("no existing state, creating new one");
            return S::default();
        }
    };

    rmp_serde::from_slice::<S>(&bytes).unwrap_or_else(|e| {
        panic!("error deserializing existing state: {e}. We're panicking because we don't want to nuke state by setting it to a new instance.")
    })
}
pub fn setup_server(
ui_config: Option<&HttpBindingConfig>,
ui_path: Option<String>,
endpoints: &[Binding],
) -> http::server::HttpServer {
let mut server = http::server::HttpServer::new(5);
if let Some(ui) = ui_config {
if let Err(e) = server.serve_ui(
&ui_path.unwrap_or_else(|| "ui".to_string()),
vec!["/"],
ui.clone(),
) {
panic!("failed to serve UI: {e}. Make sure that a ui folder is in /pkg");
}
}
let mut seen_paths = std::collections::HashSet::new();
for endpoint in endpoints.iter() {
let path = match endpoint {
Binding::Http { path, .. } => path,
Binding::Ws { path, .. } => path,
};
if !seen_paths.insert(path) {
panic!("duplicate path found: {}", path);
}
}
for endpoint in endpoints {
match endpoint {
Binding::Http { path, config } => {
server
.bind_http_path(path.to_string(), config.clone())
.expect("failed to serve API path");
}
Binding::Ws { path, config } => {
server
.bind_ws_path(path.to_string(), config.clone())
.expect("failed to bind WS path");
}
}
}
server
}
/// Logs `error` at error level as a multi-line `SendError { .. }` listing.
///
/// The message body is shown as a quoted string when it is valid UTF-8 and as
/// a debug-formatted byte list otherwise. The context is decoded lossily and
/// quoted, or shown as `None` when absent.
pub fn pretty_print_send_error(error: &SendError) {
    let kind = &error.kind;
    let target = &error.target;

    let body = match String::from_utf8(error.message.body().to_vec()) {
        Ok(text) => format!("\"{}\"", text),
        Err(_) => format!("{:?}", error.message.body()),
    };

    let context = match error.context.as_ref() {
        Some(bytes) => format!("\"{}\"", String::from_utf8_lossy(bytes)),
        None => "None".to_string(),
    };

    // The literal's continuation lines start at column 0 on purpose: leading
    // whitespace there would become part of the logged text.
    error!(
        "SendError {{
kind: {:?},
target: {},
body: {},
context: {}
}}",
        kind,
        target,
        body,
        context
    );
}
/// No-op initializer: accepts the state and does nothing with it.
pub fn no_init_fn<S>(_state: &mut S) {
}
/// No-op WebSocket handler: ignores every argument.
pub fn no_ws_handler<S>(
_state: &mut S,
_server: &mut http::server::HttpServer,
_channel_id: u32,
_msg_type: http::server::WsMessageType,
_blob: LazyLoadBlob,
) {
}
/// No-op HTTP API handler: ignores the state and the unit request.
pub fn no_http_api_call<S>(_state: &mut S, _req: ()) {
}
/// No-op handler with the local-request signature: ignores the message, the
/// state and the unit request.
pub fn no_local_request<S>(_msg: &Message, _state: &mut S, _req: ()) {
}
/// No-op handler with the remote-request signature: ignores the message, the
/// state and the unit request.
pub fn no_remote_request<S>(_msg: &Message, _state: &mut S, _req: ()) {
}
/// An endpoint that `setup_server` binds on the HTTP server. Paths must be
/// unique across both variants; `setup_server` panics on a duplicate.
#[derive(Clone, Debug)]
pub enum Binding {
/// An HTTP path, bound with `bind_http_path`.
Http {
/// Path to bind.
path: &'static str,
/// Configuration passed along when binding.
config: HttpBindingConfig,
},
/// A WebSocket path, bound with `bind_ws_path`.
Ws {
/// Path to bind.
path: &'static str,
/// Configuration passed along when binding.
config: WsBindingConfig,
},
}
pub fn maybe_save_state<S>(state: &S)
where
S: serde::Serialize,
{
APP_CONTEXT.with(|ctx| {
let mut ctx_mut = ctx.borrow_mut();
if let Some(ref mut hidden_state) = ctx_mut.hidden_state {
let should_save = if matches!(hidden_state.save_config, SaveOptions::OnDiff) {
if let Ok(current_bytes) = rmp_serde::to_vec(state) {
let state_changed = match &hidden_state.old_state {
Some(old_bytes) => old_bytes != ¤t_bytes,
None => true, };
if state_changed {
true
} else {
false
}
} else {
false
}
} else {
hidden_state.should_save_state()
};
if should_save {
if let Ok(s_bytes) = rmp_serde::to_vec(state) {
let _ = set_state(&s_bytes);
if matches!(hidden_state.save_config, SaveOptions::OnDiff) {
hidden_state.old_state = None;
}
}
}
}
});
}