use crate::{Command, Watcher};
pub use devtools_wire_format as wire;
use devtools_wire_format::instrument;
use devtools_wire_format::instrument::instrument_server::InstrumentServer;
use devtools_wire_format::instrument::{instrument_server, InstrumentRequest};
use devtools_wire_format::meta::metadata_server;
use devtools_wire_format::meta::metadata_server::MetadataServer;
use devtools_wire_format::sources::sources_server::SourcesServer;
use devtools_wire_format::tauri::tauri_server;
use devtools_wire_format::tauri::tauri_server::TauriServer;
use futures::{FutureExt, TryStreamExt};
use http::HeaderValue;
use hyper::Body;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tonic::body::BoxBody;
use tonic::codegen::http::Method;
use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
use tonic::codegen::BoxStream;
use tonic::{Request, Response, Status};
use tonic_health::pb::health_server::{Health, HealthServer};
use tonic_health::server::HealthReporter;
use tonic_health::ServingStatus;
use tower::Service;
use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
use tower_layer::Layer;
/// Capacity of the per-client channel created for each `watch_updates` subscriber.
const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 4 * 1024;
/// Tonic router with the dynamic CORS middleware applied as its only layer.
type CorsRouter =
    tonic::transport::server::Router<tower_layer::Stack<DynamicCorsLayer, tower_layer::Identity>>;

/// The devtools gRPC-web server.
///
/// Owns the fully assembled router (health, instrument, tauri, metadata and sources services)
/// plus a cloneable handle for adjusting the CORS origin list.
pub struct Server {
    router: CorsRouter,
    handle: ServerHandle,
}
/// Cloneable handle for adjusting a [`Server`]'s CORS configuration.
///
/// Every clone shares the same origin list that the CORS middleware reads on each request, so
/// origins added through any clone apply to requests handled after the call.
// The `Server` prefix is kept deliberately so the type reads clearly when re-exported.
#[allow(clippy::module_name_repetitions)]
#[derive(Clone)]
pub struct ServerHandle {
    allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
}

impl ServerHandle {
    /// Appends `origin` to the shared list of allowed CORS origins.
    ///
    /// NOTE(review): tower-http's `CorsLayer::allow_origin` replaces the previously configured
    /// origin instead of adding to it, and the middleware applies the stored entries one after
    /// another, so it looks like only the most recently added origin ends up effective. Confirm
    /// before relying on several origins at once.
    #[allow(clippy::missing_panics_doc)]
    pub fn allow_origin(&self, origin: impl Into<AllowOrigin>) {
        let new_origin: AllowOrigin = origin.into();
        let mut origins = self.allowed_origins.lock().unwrap();
        origins.push(new_origin);
    }
}
struct InstrumentService {
tx: mpsc::Sender<Command>,
health_reporter: HealthReporter,
}
/// Tower layer that wraps each service in [`DynamicCors`].
///
/// Holds the origin list shared with [`ServerHandle`]. Every wrapped service receives another
/// reference to that same list.
#[derive(Clone)]
struct DynamicCorsLayer {
    allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
}

impl<S> Layer<S> for DynamicCorsLayer {
    type Service = DynamicCors<S>;

    fn layer(&self, inner: S) -> Self::Service {
        let allowed_origins = Arc::clone(&self.allowed_origins);
        DynamicCors {
            inner,
            allowed_origins,
        }
    }
}
/// Middleware that builds CORS rules from the shared origin list on every request.
///
/// The list is re-read on each `call`, so origins registered through a [`ServerHandle`] after
/// the server started apply to later requests.
#[derive(Debug, Clone)]
struct DynamicCors<S> {
    inner: S,
    allowed_origins: Arc<Mutex<Vec<AllowOrigin>>>,
}

/// Boxed `Send` future, used as the response future of [`DynamicCors`].
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;

impl<S> Service<hyper::Request<Body>> for DynamicCors<S>
where
    S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
        let mut cors = CorsLayer::new()
            // Allow `GET` and `POST` on the resource, with any request header.
            .allow_methods([Method::GET, Method::POST])
            .allow_headers(AllowHeaders::any());

        {
            let origins = self.allowed_origins.lock().unwrap();
            // NOTE(review): tower-http's `CorsLayer::allow_origin` replaces the previous value
            // rather than accumulating, so this loop appears to leave only the last entry in
            // effect. Confirm and consider merging origins into a single `AllowOrigin`.
            for origin in origins.iter() {
                cors = cors.allow_origin(origin.clone());
            }
        } // release the lock before building the middleware

        // `poll_ready` was driven on `self.inner`, so that instance must handle this request.
        // Calling a fresh clone instead would use a service that was never polled ready, which
        // violates tower's `Service` contract. Leave the clone behind for the next cycle.
        let fresh = self.inner.clone();
        let ready = std::mem::replace(&mut self.inner, fresh);

        Box::pin(cors.layer(ready).call(req))
    }
}
impl Server {
    /// Assembles the devtools gRPC-web server.
    ///
    /// Every service is wrapped with `tonic_web::enable`, and the router accepts HTTP/1. All
    /// services sit behind a dynamic CORS layer. Its origin list starts with one default entry
    /// and is shared with the handle returned by [`Server::handle`].
    ///
    /// The instrument service is reported as serving through `health_reporter` before the
    /// router is built.
    pub fn new(
        cmd_tx: mpsc::Sender<Command>,
        mut health_reporter: HealthReporter,
        health_service: HealthServer<impl Health>,
        tauri_server: impl tauri_server::Tauri,
        metadata_server: impl metadata_server::Metadata,
        sources_server: impl wire::sources::sources_server::Sources,
    ) -> Self {
        // `set_serving` is async but is only polled once here.
        // NOTE(review): if that single poll is ever `Pending`, the status update is dropped.
        health_reporter
            .set_serving::<InstrumentServer<InstrumentService>>()
            .now_or_never();

        // Builds compiled with `__DEVTOOLS_LOCAL_DEVELOPMENT` set accept any origin. Otherwise
        // only the hosted devtools UI is allowed by default. The origin is a literal, so
        // `from_static` replaces the runtime parse plus `unwrap`.
        let default_origin = if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() {
            AllowOrigin::from(tower_http::cors::Any)
        } else {
            AllowOrigin::from(HeaderValue::from_static("https://devtools.crabnebula.dev"))
        };
        let allowed_origins = Arc::new(Mutex::new(vec![default_origin]));

        let router = tonic::transport::Server::builder()
            .accept_http1(true)
            .layer(DynamicCorsLayer {
                allowed_origins: allowed_origins.clone(),
            })
            .add_service(tonic_web::enable(health_service))
            .add_service(tonic_web::enable(InstrumentServer::new(
                InstrumentService {
                    tx: cmd_tx,
                    health_reporter,
                },
            )))
            .add_service(tonic_web::enable(TauriServer::new(tauri_server)))
            .add_service(tonic_web::enable(MetadataServer::new(metadata_server)))
            .add_service(tonic_web::enable(SourcesServer::new(sources_server)));

        Self {
            router,
            handle: ServerHandle { allowed_origins },
        }
    }

    /// Returns a handle that shares this server's CORS origin list.
    #[must_use]
    pub fn handle(&self) -> ServerHandle {
        self.handle.clone()
    }

    /// Serves all registered services on `addr` until the transport stops.
    ///
    /// # Errors
    ///
    /// Returns an error if serving on `addr` fails.
    pub async fn run(self, addr: SocketAddr) -> crate::Result<()> {
        tracing::info!("Listening on {}", addr);
        self.router.serve(addr).await?;
        Ok(())
    }
}
impl InstrumentService {
async fn set_status(&self, status: ServingStatus) {
let mut r = self.health_reporter.clone();
r.set_service_status("rs.devtools.instrument.Instrument", status)
.await;
}
}
#[tonic::async_trait]
impl instrument_server::Instrument for InstrumentService {
    type WatchUpdatesStream = BoxStream<instrument::Update>;

    /// Subscribes the caller to updates from the aggregation task.
    ///
    /// A dedicated per-client channel is created and its sender is handed to the aggregator
    /// inside a `Command::Instrument`. The receiving end becomes the response stream. Any error
    /// item arriving on that stream is logged and replaced by a generic internal status.
    ///
    /// # Errors
    ///
    /// Returns `Status::internal` when the aggregator's command channel cannot accept the
    /// subscription. In that case the instrument service is first reported as `NotServing`.
    async fn watch_updates(
        &self,
        req: Request<InstrumentRequest>,
    ) -> Result<Response<Self::WatchUpdatesStream>, Status> {
        match req.remote_addr() {
            Some(addr) => tracing::debug!(client.addr = %addr, "starting a new watch"),
            None => tracing::debug!(client.addr = %"<unknown>", "starting a new watch"),
        }

        // Reserve a slot first, so a dead aggregator is detected before any channel is created.
        let permit = match self.tx.reserve().await {
            Ok(permit) => permit,
            Err(_) => {
                self.set_status(ServingStatus::NotServing).await;
                return Err(Status::internal(
                    "cannot start new watch, aggregation task is not running",
                ));
            }
        };

        let (client_tx, client_rx) = mpsc::channel(DEFAULT_CLIENT_BUFFER_CAPACITY);
        permit.send(Command::Instrument(Watcher { tx: client_tx }));
        tracing::debug!("watch started");

        // NOTE(review): the original error is only logged; clients receive an opaque "boom".
        let updates = ReceiverStream::new(client_rx).or_else(|err| async move {
            tracing::error!("Aggregator failed with error {err:?}");
            Err(Status::internal("boom"))
        });

        Ok(Response::new(Box::pin(updates)))
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use devtools_wire_format::instrument::instrument_server::Instrument;

    /// A successful `watch_updates` call must forward an `Instrument` command to the aggregator.
    #[tokio::test]
    async fn subscription() {
        let (reporter, _) = tonic_health::server::health_reporter();
        let (command_tx, mut command_rx) = mpsc::channel(1);
        let service = InstrumentService {
            tx: command_tx,
            health_reporter: reporter,
        };

        let request = Request::new(InstrumentRequest {
            log_filter: None,
            span_filter: None,
        });
        // Keep the response stream alive for the rest of the test.
        let _updates = service.watch_updates(request).await.unwrap();

        let received = command_rx.recv().await.unwrap();
        assert!(matches!(received, Command::Instrument(_)));
    }
}