use std::error::Error;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Mutex, OnceLock};
use tokio::sync::{mpsc, oneshot};
pub use bmux_plugin_sdk::prompt::{
PromptEvent, PromptField, PromptFormField, PromptFormFieldKind, PromptFormSection,
PromptFormValue, PromptOption, PromptPolicy, PromptRequest, PromptResponse, PromptValidation,
PromptValue, PromptWidth,
};
/// A prompt forwarded to the registered host, bundled with the channels the
/// host uses to answer it.
#[derive(Debug)]
pub struct PromptHostRequest {
    /// The prompt submitted by the caller.
    pub request: PromptRequest,
    /// One-shot channel for the host's final [`PromptResponse`]. If the host
    /// drops it without sending, [`request`] resolves to
    /// [`PromptSubmitError::HostDisconnected`].
    pub response_tx: oneshot::Sender<PromptResponse>,
    /// `Some` only for submissions made through [`submit_with_events`] (or its
    /// alias [`request_with_events`]); the host can push [`PromptEvent`]s here.
    pub event_tx: Option<mpsc::UnboundedSender<PromptEvent>>,
}
/// Reasons a prompt could not be delivered to, or answered by, the prompt host.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromptSubmitError {
    /// No prompt host is currently registered.
    HostUnavailable,
    /// A host is registered but its request channel is closed, or the host
    /// dropped the response channel before answering.
    HostDisconnected,
}

impl fmt::Display for PromptSubmitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::HostUnavailable => "prompt host unavailable",
            Self::HostDisconnected => "prompt host disconnected",
        };
        f.write_str(message)
    }
}

impl Error for PromptSubmitError {}
/// The active host registration: the host's request channel plus the id of the
/// [`PromptHostGuard`] that owns it.
#[derive(Clone)]
struct PromptHostRegistration {
    /// Equal to the owning guard's id, so that a stale guard cannot remove a
    /// registration made after it.
    id: u64,
    /// Channel that receives every forwarded [`PromptHostRequest`].
    sender: mpsc::UnboundedSender<PromptHostRequest>,
}
/// Process-wide slot holding the active host registration, if any. It is
/// created lazily by [`host_registry`].
static HOST_REGISTRY: OnceLock<Mutex<Option<PromptHostRegistration>>> = OnceLock::new();
/// Counter handing out registration ids, starting at 1.
static HOST_REGISTRATION_SEQUENCE: AtomicU64 = AtomicU64::new(1);

/// Returns the registry slot, initialising it to an empty `Mutex<None>` on first use.
fn host_registry() -> &'static Mutex<Option<PromptHostRegistration>> {
    HOST_REGISTRY.get_or_init(Mutex::default)
}
/// Handle returned by [`register_host`] that keeps the host registered.
///
/// Dropping the guard unregisters its host, unless a different host has been
/// registered since, in which case the newer registration is left untouched.
#[derive(Debug)]
#[must_use = "dropping the guard immediately unregisters the prompt host"]
pub struct PromptHostGuard {
    /// Id of the registration this guard owns.
    id: u64,
}
impl Drop for PromptHostGuard {
    /// Clears the registry, but only if it still holds this guard's
    /// registration. A host registered after this guard is not affected.
    fn drop(&mut self) {
        // Recover a poisoned lock rather than skipping the cleanup. The slot is
        // only ever overwritten as a whole `Option`, so its value is still valid,
        // and skipping would leave a stale host registered.
        let mut slot = host_registry()
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if slot
            .as_ref()
            .is_some_and(|registration| registration.id == self.id)
        {
            *slot = None;
        }
    }
}
/// Registers `sender` as the prompt host, replacing any previously registered host.
///
/// Subsequent calls to [`submit`], [`submit_with_events`], [`request`] and
/// [`request_with_events`] forward their [`PromptHostRequest`] to `sender`. This
/// lasts until the returned guard is dropped or another host is registered.
///
/// Keep the returned [`PromptHostGuard`] alive for as long as the host should
/// receive prompts. Dropping it unregisters this host, but leaves a host that
/// registered later untouched.
pub fn register_host(sender: mpsc::UnboundedSender<PromptHostRequest>) -> PromptHostGuard {
    // The id only has to be unique, which `fetch_add` guarantees under any ordering.
    let id = HOST_REGISTRATION_SEQUENCE.fetch_add(1, Ordering::Relaxed);
    // Recover a poisoned lock instead of silently skipping the write. The slot
    // is only ever overwritten as a whole, so its contents remain valid.
    // Skipping would hand back a guard for a host that was never registered.
    let mut slot = host_registry()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *slot = Some(PromptHostRegistration { id, sender });
    drop(slot);
    PromptHostGuard { id }
}
pub fn submit(
request: PromptRequest,
) -> std::result::Result<oneshot::Receiver<PromptResponse>, PromptSubmitError> {
submit_inner(request, None)
}
/// Forwards `request` to the registered prompt host together with an event
/// channel, so the host can push [`PromptEvent`]s back to the caller.
///
/// Returns the response receiver and the event receiver, in that order.
///
/// # Errors
///
/// Same as [`submit`]. On failure the freshly created event channel is discarded.
pub fn submit_with_events(
    request: PromptRequest,
) -> std::result::Result<
    (
        oneshot::Receiver<PromptResponse>,
        mpsc::UnboundedReceiver<PromptEvent>,
    ),
    PromptSubmitError,
> {
    let (events_for_host, events_for_caller) = mpsc::unbounded_channel();
    submit_inner(request, Some(events_for_host))
        .map(|response_rx| (response_rx, events_for_caller))
}
/// Shared implementation of [`submit`] and [`submit_with_events`].
///
/// Forwards `request` and the optional event sender to the registered host,
/// then returns the receiver for the host's response.
///
/// # Errors
///
/// * [`PromptSubmitError::HostUnavailable`] if no host is registered.
/// * [`PromptSubmitError::HostDisconnected`] if the host's request channel is closed.
fn submit_inner(
    request: PromptRequest,
    event_tx: Option<mpsc::UnboundedSender<PromptEvent>>,
) -> std::result::Result<oneshot::Receiver<PromptResponse>, PromptSubmitError> {
    // The lock guard is a temporary of this statement, so it is held only while
    // the sender is cloned. A poisoned lock is recovered: the slot is only ever
    // overwritten as a whole, so its value is still meaningful, and failing here
    // would disable prompting for the rest of the process.
    let sender = host_registry()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .as_ref()
        .map(|registration| registration.sender.clone())
        .ok_or(PromptSubmitError::HostUnavailable)?;
    let (response_tx, response_rx) = oneshot::channel();
    sender
        .send(PromptHostRequest {
            request,
            response_tx,
            event_tx,
        })
        .map_err(|_| PromptSubmitError::HostDisconnected)?;
    Ok(response_rx)
}
/// Submits `request` to the registered host and waits for its answer.
///
/// # Errors
///
/// * [`PromptSubmitError::HostUnavailable`] if no host is registered.
/// * [`PromptSubmitError::HostDisconnected`] if the host's request channel is
///   closed, or the host drops the response sender without replying.
pub async fn request(
    request: PromptRequest,
) -> std::result::Result<PromptResponse, PromptSubmitError> {
    let pending = submit(request)?;
    match pending.await {
        Ok(response) => Ok(response),
        Err(_closed) => Err(PromptSubmitError::HostDisconnected),
    }
}
/// Alias of [`submit_with_events`].
///
/// Despite the name, this is not `async` and does not wait for the answer,
/// unlike [`request`]. It returns the response receiver and the event receiver,
/// and the caller awaits the response receiver itself.
///
/// # Errors
///
/// Same as [`submit_with_events`].
pub fn request_with_events(
    request: PromptRequest,
) -> std::result::Result<
    (
        oneshot::Receiver<PromptResponse>,
        mpsc::UnboundedReceiver<PromptEvent>,
    ),
    PromptSubmitError,
> {
    submit_with_events(request)
}
#[cfg(test)]
mod tests {
    use super::*;

    // All tests share the process-wide host registry, hence `serial`.

    #[tokio::test]
    #[serial_test::serial]
    async fn request_fails_when_no_host_is_registered() {
        let response = request(PromptRequest::confirm("missing host")).await;
        assert_eq!(response, Err(PromptSubmitError::HostUnavailable));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn request_routes_through_registered_host() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _guard = register_host(tx);
        let client_task =
            tokio::spawn(async { request(PromptRequest::confirm("quit session?")).await });
        let host_request = rx.recv().await.expect("host should receive request");
        assert_eq!(host_request.request.title, "quit session?");
        host_request
            .response_tx
            .send(PromptResponse::Cancelled)
            .expect("host should send response");
        let response = client_task
            .await
            .expect("request task should complete")
            .expect("request should resolve");
        assert_eq!(response, PromptResponse::Cancelled);
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn dropping_host_guard_unregisters_the_host() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let guard = register_host(tx);
        drop(guard);
        let receiver = submit(PromptRequest::confirm("hello"));
        assert!(matches!(receiver, Err(PromptSubmitError::HostUnavailable)));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn stale_guard_does_not_unregister_newer_host() {
        let (first_tx, _first_rx) = mpsc::unbounded_channel();
        let first_guard = register_host(first_tx);
        let (second_tx, mut second_rx) = mpsc::unbounded_channel();
        let _second_guard = register_host(second_tx);
        // The first registration was already replaced, so dropping its guard
        // must leave the second host in place.
        drop(first_guard);
        let _response_rx = submit(PromptRequest::confirm("still routed"))
            .expect("newer host should still be registered");
        let host_request = second_rx
            .recv()
            .await
            .expect("newer host should receive request");
        assert_eq!(host_request.request.title, "still routed");
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn submit_reports_disconnected_host() {
        let (tx, rx) = mpsc::unbounded_channel();
        let _guard = register_host(tx);
        drop(rx);
        let receiver = submit(PromptRequest::confirm("nobody listening"));
        assert!(matches!(receiver, Err(PromptSubmitError::HostDisconnected)));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn request_reports_host_dropping_response() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _guard = register_host(tx);
        let client_task =
            tokio::spawn(async { request(PromptRequest::confirm("ignored")).await });
        let host_request = rx.recv().await.expect("host should receive request");
        // Dropping the request drops `response_tx` without answering.
        drop(host_request);
        let response = client_task.await.expect("request task should complete");
        assert_eq!(response, Err(PromptSubmitError::HostDisconnected));
    }

    #[tokio::test]
    #[serial_test::serial]
    async fn only_submit_with_events_forwards_event_channel() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let _guard = register_host(tx);

        let (_response_rx, _event_rx) =
            submit_with_events(PromptRequest::confirm("with events"))
                .expect("submission with events should succeed");
        let with_events = rx.recv().await.expect("host should receive request");
        assert!(with_events.event_tx.is_some());

        let _plain_rx =
            submit(PromptRequest::confirm("plain")).expect("plain submission should succeed");
        let plain = rx.recv().await.expect("host should receive request");
        assert!(plain.event_tx.is_none());
    }
}