use std::sync::Arc;
use tokio::sync::OnceCell;
use tokio_util::sync::{CancellationToken, DropGuard};
pub struct ProxyHandle {
pub url: String,
_shutdown: DropGuard,
}
/// Factory that produces a fresh proxy `ConfigBuilder` for each startup attempt.
type ConfigBuilderFn = dyn Fn() -> objectiveai_mcp_proxy::ConfigBuilder + Send + Sync + 'static;

/// Lazily starts a local MCP proxy on first use and shares it afterwards.
pub struct ProxySpawner {
    /// Filled once the proxy has started successfully; holds the shared handle.
    cell: OnceCell<Arc<ProxyHandle>>,
    /// Runtime to spawn the server task on; `None` means the ambient runtime.
    handle: Option<tokio::runtime::Handle>,
    /// Produces the base configuration; startup overrides address, port and
    /// output suppression before building it.
    config_builder_fn: Box<ConfigBuilderFn>,
}
impl ProxySpawner {
    /// Creates a spawner that starts the proxy on the ambient Tokio runtime
    /// (via `tokio::spawn`) the first time [`ProxySpawner::get`] is called.
    ///
    /// `config_builder_fn` is invoked on every initialization attempt to produce
    /// a fresh `ConfigBuilder`. Its address, port and output-suppression
    /// settings are then overridden before the config is built.
    pub fn new<F>(config_builder_fn: F) -> Self
    where
        F: Fn() -> objectiveai_mcp_proxy::ConfigBuilder + Send + Sync + 'static,
    {
        Self {
            cell: OnceCell::new(),
            handle: None,
            config_builder_fn: Box::new(config_builder_fn),
        }
    }

    /// Like [`ProxySpawner::new`], but the proxy task is spawned on `handle`
    /// instead of the ambient runtime.
    pub fn new_with_handle<F>(handle: tokio::runtime::Handle, config_builder_fn: F) -> Self
    where
        F: Fn() -> objectiveai_mcp_proxy::ConfigBuilder + Send + Sync + 'static,
    {
        Self {
            cell: OnceCell::new(),
            handle: Some(handle),
            config_builder_fn: Box::new(config_builder_fn),
        }
    }

    /// Returns the shared proxy handle, starting the proxy on first use.
    ///
    /// The first caller performs the startup sequence:
    /// - builds the config, bound to `127.0.0.1` on an OS-assigned port with
    ///   output suppressed;
    /// - spawns the server task;
    /// - waits for the task to report its bound address.
    ///
    /// Concurrent callers wait on the same initialization. Later callers receive
    /// a clone of the cached `Arc`.
    ///
    /// The server keeps running until this spawner (which caches one `Arc`) and
    /// every returned clone have been dropped.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `objectiveai_mcp_proxy::setup` fails;
    /// - reading the listener's local address fails;
    /// - the server task exits before reporting an address.
    ///
    /// On error nothing is cached, so a later call retries.
    pub async fn get(&self) -> std::io::Result<Arc<ProxyHandle>> {
        self.cell
            .get_or_try_init(|| self.spawn_proxy())
            .await
            .map(Arc::clone)
    }

    /// Builds the loopback config, spawns the server task, and waits for the
    /// bound address; used as the one-time initializer for `get`.
    async fn spawn_proxy(&self) -> std::io::Result<Arc<ProxyHandle>> {
        let mut builder = (self.config_builder_fn)();
        builder.address = Some("127.0.0.1".into());
        builder.port = Some(0);
        builder.suppress_output = Some(true);
        let config = builder.build();

        let token = CancellationToken::new();
        // Arm the shutdown guard *before* spawning. Suppose this future is
        // dropped while awaiting the address, e.g. because the caller timed
        // out. The task may already have sent its address and started serving.
        // The guard's drop then cancels `token`, so the server shuts down
        // instead of running forever with nothing able to stop it. (Dropping a
        // bare `CancellationToken` does not cancel it.)
        let shutdown = token.clone().drop_guard();

        let (addr_tx, addr_rx) =
            tokio::sync::oneshot::channel::<std::io::Result<std::net::SocketAddr>>();

        let task = async move {
            let (listener, router) = match objectiveai_mcp_proxy::setup(config).await {
                Ok(parts) => parts,
                Err(e) => {
                    let _ = addr_tx.send(Err(e));
                    return;
                }
            };
            let addr = match listener.local_addr() {
                Ok(addr) => addr,
                Err(e) => {
                    let _ = addr_tx.send(Err(e));
                    return;
                }
            };
            // Receiver gone: initialization was abandoned, so don't start serving.
            if addr_tx.send(Ok(addr)).is_err() {
                return;
            }
            // Serve errors are discarded; nobody awaits this detached task.
            let _ = axum::serve(listener, router)
                .with_graceful_shutdown(token.cancelled_owned())
                .await;
        };

        match &self.handle {
            Some(handle) => {
                handle.spawn(task);
            }
            None => {
                tokio::spawn(task);
            }
        }

        let addr = addr_rx
            .await
            .map_err(|_| std::io::Error::other("proxy task dropped before reporting addr"))??;
        Ok(Arc::new(ProxyHandle {
            url: format!("http://{addr}"),
            _shutdown: shutdown,
        }))
    }
}