use std::future::Future;
use std::pin::Pin;
use crate::app::subscription::BoxedSubscription;
use crate::overlay::Overlay;
use tokio_util::sync::CancellationToken;
/// A bundle of zero or more actions to be carried out by a [`CommandHandler`].
///
/// Build one with the constructors on `Command` (for example [`Command::none`],
/// [`Command::message`] or [`Command::perform_async`]) and join several with
/// [`Command::combine`] or [`Command::and`].
pub struct Command<M> {
    actions: Vec<CommandAction<M>>,
}

/// The default command is empty, for any message type `M`.
///
/// Implemented by hand because `#[derive(Default)]` would add an unnecessary
/// `M: Default` bound, making `Command::default()` unusable for message types
/// that do not implement `Default`.
impl<M> Default for Command<M> {
    fn default() -> Self {
        Self {
            actions: Vec::new(),
        }
    }
}
pub use crate::error::BoxedError;
/// Output of a fallible async command.
///
/// `Ok(Some(msg))` carries a message to deliver, `Ok(None)` means there is nothing
/// to deliver, and `Err` carries a boxed error to be reported.
pub type AsyncFallibleResult<M> = std::result::Result<Option<M>, BoxedError>;
/// One unit of work inside a [`Command`].
///
/// Boxed trait objects below use the default `'static` object lifetime that Rust
/// applies to `Box<dyn Trait>` in type positions.
pub(crate) enum CommandAction<M> {
    // --- Immediate messages -------------------------------------------------
    /// A single message.
    Message(M),
    /// Several messages produced together.
    Batch(Vec<M>),

    // --- Deferred work ------------------------------------------------------
    /// A synchronous callback that may yield a message.
    Callback(Box<dyn FnOnce() -> Option<M> + Send>),
    /// A future whose `Some` output is a message.
    Async(Pin<Box<dyn Future<Output = Option<M>> + Send>>),
    /// A future whose output may be a message or an error.
    AsyncFallible(Pin<Box<dyn Future<Output = AsyncFallibleResult<M>> + Send>>),
    /// A callback that turns a cancellation token into a message.
    RequestCancelToken(Box<dyn FnOnce(CancellationToken) -> M + Send>),
    /// A subscription to register.
    Subscribe(BoxedSubscription<M>),

    // --- Overlays -----------------------------------------------------------
    /// An overlay to push.
    PushOverlay(Box<dyn Overlay<M> + Send>),
    /// A request to pop an overlay.
    PopOverlay,

    // --- Lifecycle ----------------------------------------------------------
    /// A request to quit.
    Quit,
}
impl<M> CommandAction<M> {
    /// Returns a short snake_case label naming this action's variant.
    ///
    /// Used as the `action` field of the debug event emitted while executing commands.
    #[cfg(feature = "tracing")]
    pub(crate) fn kind_name(&self) -> &'static str {
        // Matching on `*self` with only wildcard sub-patterns moves nothing.
        match *self {
            Self::Message(_) => "message",
            Self::Batch(_) => "batch",
            Self::Callback(_) => "callback",
            Self::Async(_) => "async",
            Self::AsyncFallible(_) => "async_fallible",
            Self::RequestCancelToken(_) => "request_cancel_token",
            Self::Subscribe(_) => "subscribe",
            Self::PushOverlay(_) => "push_overlay",
            Self::PopOverlay => "pop_overlay",
            Self::Quit => "quit",
        }
    }
}
impl<M> Command<M> {
    /// Returns a command with no actions.
    pub fn none() -> Self {
        Self {
            actions: Vec::new(),
        }
    }

    /// Returns `true` if this command carries no actions.
    pub fn is_none(&self) -> bool {
        self.actions.is_empty()
    }

    /// Returns `true` if any action is a quit request (see [`Command::quit`]).
    pub fn is_quit(&self) -> bool {
        self.actions
            .iter()
            .any(|a| matches!(a, CommandAction::Quit))
    }

    /// Returns `true` if any action is a single message (see [`Command::message`]).
    ///
    /// Message batches are not counted here; use [`Command::is_batch`] for those.
    pub fn is_message(&self) -> bool {
        self.actions
            .iter()
            .any(|a| matches!(a, CommandAction::Message(_)))
    }

    /// Returns `true` if any action is a message batch (see [`Command::batch`]).
    pub fn is_batch(&self) -> bool {
        self.actions
            .iter()
            .any(|a| matches!(a, CommandAction::Batch(_)))
    }

    /// Returns `true` if any action is a future, whether infallible or fallible.
    pub fn is_async(&self) -> bool {
        self.actions
            .iter()
            .any(|a| matches!(a, CommandAction::Async(_) | CommandAction::AsyncFallible(_)))
    }

    /// Returns `true` if any action pushes an overlay.
    pub fn is_overlay_push(&self) -> bool {
        self.actions
            .iter()
            .any(|a| matches!(a, CommandAction::PushOverlay(_)))
    }

    /// Returns `true` if any action pops an overlay.
    pub fn is_overlay_pop(&self) -> bool {
        self.actions
            .iter()
            .any(|a| matches!(a, CommandAction::PopOverlay))
    }

    /// Returns the number of actions in this command.
    ///
    /// A batch counts as one action regardless of how many messages it holds.
    pub fn action_count(&self) -> usize {
        self.actions.len()
    }

    /// Returns a command that delivers `msg`.
    pub fn message(msg: M) -> Self {
        Self {
            actions: vec![CommandAction::Message(msg)],
        }
    }

    /// Returns a command that delivers every message yielded by `messages` as one batch.
    ///
    /// An empty iterator yields [`Command::none`] rather than an empty batch.
    pub fn batch(messages: impl IntoIterator<Item = M>) -> Self {
        let msgs: Vec<M> = messages.into_iter().collect();
        if msgs.is_empty() {
            Self::none()
        } else {
            Self {
                actions: vec![CommandAction::Batch(msgs)],
            }
        }
    }

    /// Returns a command that asks the application to quit.
    pub fn quit() -> Self {
        Self {
            actions: vec![CommandAction::Quit],
        }
    }

    /// Returns a command that runs the synchronous callback `f`.
    ///
    /// A `Some` result is a message to deliver; `None` delivers nothing.
    pub fn perform<F>(f: F) -> Self
    where
        F: FnOnce() -> Option<M> + Send + 'static,
    {
        Self {
            actions: vec![CommandAction::Callback(Box::new(f))],
        }
    }

    /// Returns a command that runs `future` in the background.
    ///
    /// If the future resolves to `Some(msg)`, `msg` is delivered as a message;
    /// `None` delivers nothing.
    pub fn perform_async<Fut>(future: Fut) -> Self
    where
        Fut: Future<Output = Option<M>> + Send + 'static,
    {
        Self {
            actions: vec![CommandAction::Async(Box::pin(future))],
        }
    }

    /// Alias for [`Command::perform_async`].
    pub fn future<Fut>(future: Fut) -> Self
    where
        Fut: Future<Output = Option<M>> + Send + 'static,
    {
        Self::perform_async(future)
    }

    /// Returns a command that runs `future` in the background purely for its side
    /// effects; it never produces a message.
    pub fn spawn<Fut>(future: Fut) -> Self
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        Self::perform_async(async move {
            future.await;
            None
        })
    }

    /// Returns a command that runs a fallible `future` and always converts its
    /// result, success or failure, into a message with `on_result`.
    ///
    /// Errors are handled by `on_result` rather than being reported as errors; see
    /// [`Command::try_perform_async`] for the latter.
    pub fn perform_async_fallible<Fut, T, E, F>(future: Fut, on_result: F) -> Self
    where
        Fut: Future<Output = Result<T, E>> + Send + 'static,
        F: FnOnce(Result<T, E>) -> M + Send + 'static,
        M: Send + 'static,
    {
        Self {
            actions: vec![CommandAction::Async(Box::pin(async move {
                let result = future.await;
                Some(on_result(result))
            }))],
        }
    }

    /// Returns a command that runs a fallible `future`.
    ///
    /// On `Ok(value)`, `on_success(value)` decides whether a message is delivered.
    /// On `Err(e)`, `e` is boxed into a [`BoxedError`] and reported as an error
    /// instead of a message.
    pub fn try_perform_async<Fut, T, E, F>(future: Fut, on_success: F) -> Self
    where
        Fut: Future<Output = Result<T, E>> + Send + 'static,
        E: std::error::Error + Send + Sync + 'static,
        F: FnOnce(T) -> Option<M> + Send + 'static,
        M: Send + 'static,
    {
        Self {
            actions: vec![CommandAction::AsyncFallible(Box::pin(async move {
                match future.await {
                    Ok(value) => Ok(on_success(value)),
                    Err(e) => Err(Box::new(e) as BoxedError),
                }
            }))],
        }
    }

    /// Returns a command that pushes `overlay`.
    pub fn push_overlay(overlay: impl Overlay<M> + 'static) -> Self {
        Self {
            actions: vec![CommandAction::PushOverlay(Box::new(overlay))],
        }
    }

    /// Returns a command that pops an overlay.
    pub fn pop_overlay() -> Self {
        Self {
            actions: vec![CommandAction::PopOverlay],
        }
    }

    /// Returns a command that requests a [`CancellationToken`] and converts it into a
    /// message with `f`.
    ///
    /// Which token is supplied is decided by whoever drains the pending cancel-token
    /// requests from the command handler.
    pub fn request_cancel_token<F>(f: F) -> Self
    where
        F: FnOnce(CancellationToken) -> M + Send + 'static,
    {
        Self {
            actions: vec![CommandAction::RequestCancelToken(Box::new(f))],
        }
    }

    /// Returns a command that hands `subscription` to the command handler.
    pub fn subscribe(subscription: BoxedSubscription<M>) -> Self
    where
        M: Send + 'static,
    {
        Self {
            actions: vec![CommandAction::Subscribe(subscription)],
        }
    }

    /// Returns a command that serializes `state` to JSON and writes it to `path` in
    /// the background.
    ///
    /// Serialization happens immediately, because `state` is only borrowed and cannot
    /// be moved into the background task.
    ///
    /// # Errors
    ///
    /// A serialization failure and a write failure are both reported as errors, the
    /// same way as [`Command::try_perform_async`]. A successful save delivers no message.
    #[cfg(feature = "serialization")]
    pub fn save_state<S: serde::Serialize>(
        state: &S,
        path: impl Into<std::path::PathBuf>,
    ) -> Command<M>
    where
        M: Send + 'static,
    {
        let path = path.into();
        match serde_json::to_string(state) {
            Ok(json) => Command::try_perform_async(
                async move { tokio::fs::write(path, json).await },
                |_| None,
            ),
            // Surface the failure through the error path instead of discarding it.
            Err(err) => Command::try_perform_async(async move { Err::<(), _>(err) }, |_| None),
        }
    }

    /// Concatenates the actions of all `commands`, preserving their order.
    pub fn combine(commands: impl IntoIterator<Item = Command<M>>) -> Self {
        Self {
            actions: commands.into_iter().flat_map(|cmd| cmd.actions).collect(),
        }
    }

    /// Appends `other`'s actions after this command's actions.
    pub fn and(mut self, other: Command<M>) -> Self {
        self.actions.extend(other.actions);
        self
    }

    /// Consumes the command and returns its actions in order.
    pub(crate) fn into_actions(self) -> Vec<CommandAction<M>> {
        self.actions
    }

    /// Converts a `Command<M>` into a `Command<N>` by applying `f` to every message it
    /// can produce.
    ///
    /// Messages and batches are converted immediately. Callbacks, futures and
    /// cancel-token callbacks are wrapped so that `f` is applied to whatever they
    /// produce later. `f` is cloned into each wrapper, hence the `Clone` bound.
    ///
    /// # Dropped actions
    ///
    /// Overlay pushes and subscriptions are typed by `M` and cannot be converted, so
    /// they are **silently dropped** from the result. Quit requests and overlay pops
    /// carry no message and are kept.
    pub fn map<N, F>(self, f: F) -> Command<N>
    where
        F: Fn(M) -> N + Clone + Send + 'static,
        M: Send + 'static,
        N: Send + 'static,
    {
        let actions = self
            .actions
            .into_iter()
            .filter_map(|action| match action {
                CommandAction::Message(m) => Some(CommandAction::Message(f(m))),
                CommandAction::Batch(msgs) => {
                    Some(CommandAction::Batch(msgs.into_iter().map(&f).collect()))
                }
                CommandAction::Quit => Some(CommandAction::Quit),
                CommandAction::Callback(cb) => {
                    let f = f.clone();
                    Some(CommandAction::Callback(Box::new(move || cb().map(&f))))
                }
                CommandAction::Async(fut) => {
                    let f = f.clone();
                    Some(CommandAction::Async(Box::pin(async move {
                        fut.await.map(&f)
                    })))
                }
                CommandAction::AsyncFallible(fut) => {
                    let f = f.clone();
                    Some(CommandAction::AsyncFallible(Box::pin(async move {
                        fut.await.map(|opt| opt.map(&f))
                    })))
                }
                // An overlay is an `Overlay<M>` and cannot be retargeted to `N`.
                CommandAction::PushOverlay(_) => None,
                CommandAction::PopOverlay => Some(CommandAction::PopOverlay),
                CommandAction::RequestCancelToken(cb) => {
                    let f = f.clone();
                    Some(CommandAction::RequestCancelToken(Box::new(move |token| {
                        f(cb(token))
                    })))
                }
                // A subscription is typed by `M` and cannot be retargeted to `N`.
                CommandAction::Subscribe(_) => None,
            })
            .collect();
        Command { actions }
    }
}
impl<M> std::fmt::Debug for Command<M> {
    /// Formats the command as `Command { action_count: N }`.
    ///
    /// The actions themselves (closures, futures, overlays) are not printable, so
    /// only their number is shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let action_count = self.actions.len();
        let mut out = formatter.debug_struct("Command");
        out.field("action_count", &action_count);
        out.finish()
    }
}
// The trait objects below are boxed in type position, so their lifetime bound
// defaults to `'static`.

/// Pinned, boxed, `Send` future that may resolve to a message.
pub type BoxedFuture<M> = Pin<Box<dyn Future<Output = Option<M>> + Send>>;

/// Pinned, boxed, `Send` future whose output may be a message or an error.
pub type BoxedFallibleFuture<M> = Pin<Box<dyn Future<Output = AsyncFallibleResult<M>> + Send>>;

/// Boxed callback that turns a [`CancellationToken`] into a message.
pub(crate) type CancelTokenCallback<M> = Box<dyn FnOnce(CancellationToken) -> M + Send>;
/// Executes [`Command`]s and holds their deferred parts until they are drained.
pub struct CommandHandler<M> {
    /// Shared core that tracks messages, overlay pushes/pops, subscriptions and the
    /// quit flag, and hands back actions that must be deferred.
    core: super::command_core::CommandHandlerCore<M>,
    /// Infallible futures queued until `spawn_pending`.
    pending_futures: Vec<BoxedFuture<M>>,
    /// Fallible futures queued until `spawn_pending`.
    pending_fallible_futures: Vec<BoxedFallibleFuture<M>>,
    /// Cancel-token callbacks queued until `take_cancel_token_requests`.
    pending_cancel_token_requests: Vec<CancelTokenCallback<M>>,
}
impl<M: Send + 'static> CommandHandler<M> {
    /// Creates a handler with a fresh core and no queued work.
    pub fn new() -> Self {
        Self {
            core: super::command_core::CommandHandlerCore::new(),
            pending_futures: Vec::new(),
            pending_fallible_futures: Vec::new(),
            pending_cancel_token_requests: Vec::new(),
        }
    }

    /// Executes each action of `command` in order.
    ///
    /// Every action is passed to the core. Actions the core hands back are queued on
    /// this handler:
    /// - futures wait until [`CommandHandler::spawn_pending`];
    /// - cancel-token callbacks wait until `take_cancel_token_requests`.
    ///
    /// # Panics
    ///
    /// Panics if the core hands back any other kind of action, which would violate
    /// its contract.
    pub fn execute(&mut self, command: Command<M>) {
        for action in command.into_actions() {
            #[cfg(feature = "tracing")]
            tracing::debug!(action = action.kind_name(), "executing command action");
            if let Some(async_action) = self.core.execute_action(action) {
                match async_action {
                    CommandAction::Async(fut) => {
                        self.pending_futures.push(fut);
                    }
                    CommandAction::AsyncFallible(fut) => {
                        self.pending_fallible_futures.push(fut);
                    }
                    CommandAction::RequestCancelToken(cb) => {
                        self.pending_cancel_token_requests.push(cb);
                    }
                    _ => unreachable!("execute_action only returns async or cancel-token actions"),
                }
            }
        }
    }

    /// Spawns every queued future as a detached Tokio task and empties both future
    /// queues.
    ///
    /// Each task races its future against `cancel`. If cancellation wins, the future
    /// is dropped and nothing is sent. Otherwise the outcome is routed as follows:
    /// - a `Some(msg)` / `Ok(Some(msg))` result goes to `msg_tx`;
    /// - an `Err(e)` from a fallible future goes to `err_tx`;
    /// - `None` / `Ok(None)` sends nothing.
    ///
    /// Send failures (e.g. a dropped receiver) are ignored.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime, as `tokio::spawn` does.
    pub fn spawn_pending(
        &mut self,
        msg_tx: tokio::sync::mpsc::Sender<M>,
        err_tx: tokio::sync::mpsc::Sender<BoxedError>,
        cancel: tokio_util::sync::CancellationToken,
    ) {
        #[cfg(feature = "tracing")]
        {
            let regular = self.pending_futures.len();
            let fallible = self.pending_fallible_futures.len();
            if regular > 0 || fallible > 0 {
                tracing::debug!(regular, fallible, "spawning async command tasks");
            }
        }
        for fut in self.pending_futures.drain(..) {
            let tx = msg_tx.clone();
            let cancel = cancel.clone();
            tokio::spawn(async move {
                tokio::select! {
                    result = fut => {
                        if let Some(msg) = result {
                            // A closed receiver means nobody is listening any more.
                            let _ = tx.send(msg).await;
                        }
                    }
                    _ = cancel.cancelled() => {
                        // Cancelled: drop the future without reporting anything.
                    }
                }
            });
        }
        for fut in self.pending_fallible_futures.drain(..) {
            let msg_tx = msg_tx.clone();
            let err_tx = err_tx.clone();
            let cancel = cancel.clone();
            tokio::spawn(async move {
                tokio::select! {
                    result = fut => {
                        match result {
                            Ok(Some(msg)) => {
                                let _ = msg_tx.send(msg).await;
                            }
                            Ok(None) => {
                                // Success with nothing to deliver.
                            }
                            Err(e) => {
                                let _ = err_tx.send(e).await;
                            }
                        }
                    }
                    _ = cancel.cancelled() => {
                        // Cancelled: drop the future without reporting anything.
                    }
                }
            });
        }
    }

    /// Takes the messages collected by the core.
    pub fn take_messages(&mut self) -> Vec<M> {
        self.core.take_messages()
    }

    /// Takes the overlays the core has collected for pushing.
    pub fn take_overlay_pushes(&mut self) -> Vec<Box<dyn Overlay<M> + Send>> {
        self.core.take_overlay_pushes()
    }

    /// Takes the number of overlay pops the core has collected.
    pub fn take_overlay_pops(&mut self) -> usize {
        self.core.take_overlay_pops()
    }

    /// Takes the subscriptions collected by the core.
    pub(crate) fn take_subscriptions(&mut self) -> Vec<BoxedSubscription<M>> {
        self.core.take_subscriptions()
    }

    /// Takes all queued cancel-token callbacks, leaving the queue empty.
    pub(crate) fn take_cancel_token_requests(&mut self) -> Vec<CancelTokenCallback<M>> {
        std::mem::take(&mut self.pending_cancel_token_requests)
    }

    /// Returns `true` if any future, infallible or fallible, is waiting for
    /// [`CommandHandler::spawn_pending`].
    pub fn has_pending_futures(&self) -> bool {
        // Both queues are drained by `spawn_pending`, so both must be considered;
        // otherwise a handler holding only fallible futures would look idle.
        !self.pending_futures.is_empty() || !self.pending_fallible_futures.is_empty()
    }

    /// Returns the total number of futures, infallible and fallible, waiting for
    /// [`CommandHandler::spawn_pending`].
    pub fn pending_future_count(&self) -> usize {
        self.pending_futures.len() + self.pending_fallible_futures.len()
    }

    /// Returns `true` if the core has recorded a quit request.
    pub fn should_quit(&self) -> bool {
        self.core.should_quit()
    }

    /// Clears the core's quit request.
    pub fn reset_quit(&mut self) {
        self.core.reset_quit()
    }
}
impl<M: Send + 'static> Default for CommandHandler<M> {
fn default() -> Self {
Self::new()
}
}
// Unit tests for this module live in the out-of-line `tests` submodule file and
// are compiled only for test builds.
#[cfg(test)]
mod tests;