devtools_core/
server.rs

1use crate::{Command, Watcher};
2pub use devtools_wire_format as wire;
3use devtools_wire_format::instrument;
4use devtools_wire_format::instrument::instrument_server::InstrumentServer;
5use devtools_wire_format::instrument::{instrument_server, InstrumentRequest};
6use devtools_wire_format::meta::metadata_server;
7use devtools_wire_format::meta::metadata_server::MetadataServer;
8use devtools_wire_format::sources::sources_server::SourcesServer;
9use devtools_wire_format::tauri::tauri_server;
10use devtools_wire_format::tauri::tauri_server::TauriServer;
11use futures::{FutureExt, TryStreamExt};
12use http::HeaderValue;
13use hyper::Body;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::sync::{Arc, Mutex};
17use std::task::{Context, Poll};
18use tokio::sync::mpsc;
19use tonic::body::BoxBody;
20use tonic::codegen::http::Method;
21use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
22use tonic::codegen::BoxStream;
23use tonic::{Request, Response, Status};
24use tonic_health::pb::health_server::{Health, HealthServer};
25use tonic_health::server::HealthReporter;
26use tonic_health::ServingStatus;
27use tower::Service;
28use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
29use tower_layer::Layer;
30
/// Capacity of the bounded channel that carries events from a [`Server`]
/// to a single subscribed client (4096 messages by default).
///
/// Every watching client gets its own channel of this size. Once it is full,
/// the client is treated as inactive and may be disconnected.
const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 4_096;
37
38/// The `gRPC` server that exposes the instrumenting API
39pub struct Server {
40    router: tonic::transport::server::Router<
41        tower_layer::Stack<DynamicCorsLayer, tower_layer::Identity>,
42    >,
43    handle: ServerHandle,
44}
45
46/// A handle to a server that is allowed to modify its properties (such as CORS allowed origins)
47#[allow(clippy::module_name_repetitions)]
48#[derive(Clone)]
49pub struct ServerHandle {
50    allowed_origins: Arc<Mutex<Vec<HeaderValue>>>,
51}
52
53impl ServerHandle {
54    /// Allow the given origin in the instrumentation server CORS.
55    #[allow(clippy::missing_panics_doc)]
56    pub fn allow_origin(&self, origin: HeaderValue) {
57        self.allowed_origins.lock().unwrap().push(origin);
58    }
59}
60
61struct InstrumentService {
62    tx: mpsc::Sender<Command>,
63    health_reporter: HealthReporter,
64}
65
66#[derive(Clone)]
67struct DynamicCorsLayer {
68    allowed_origins: Arc<Mutex<Vec<HeaderValue>>>,
69}
70
71impl<S> Layer<S> for DynamicCorsLayer {
72    type Service = DynamicCors<S>;
73
74    fn layer(&self, service: S) -> Self::Service {
75        DynamicCors {
76            inner: service,
77            allowed_origins: self.allowed_origins.clone(),
78        }
79    }
80}
81
82#[derive(Debug, Clone)]
83struct DynamicCors<S> {
84    inner: S,
85    allowed_origins: Arc<Mutex<Vec<HeaderValue>>>,
86}
87
/// Heap-allocated, type-erased `Send` future resolving to `T` that may borrow for `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
89
90impl<S> Service<hyper::Request<Body>> for DynamicCors<S>
91where
92    S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
93    S::Future: Send + 'static,
94{
95    type Response = S::Response;
96    type Error = S::Error;
97    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
98
99    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100        self.inner.poll_ready(cx)
101    }
102
103    fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
104        let allowed_origins = self.allowed_origins.lock().unwrap().clone();
105        let cors = CorsLayer::new()
106            // allow `GET` and `POST` when accessing the resource
107            .allow_methods([Method::GET, Method::POST])
108            .allow_headers(AllowHeaders::any())
109            .allow_origin(if allowed_origins.iter().any(|o| o == "*") {
110                AllowOrigin::any()
111            } else {
112                AllowOrigin::list(allowed_origins)
113            });
114
115        Box::pin(cors.layer(self.inner.clone()).call(req))
116    }
117}
118
119impl Server {
120    #[allow(clippy::missing_panics_doc)]
121    pub fn new(
122        cmd_tx: mpsc::Sender<Command>,
123        mut health_reporter: HealthReporter,
124        health_service: HealthServer<impl Health>,
125        tauri_server: impl tauri_server::Tauri,
126        metadata_server: impl metadata_server::Metadata,
127        sources_server: impl wire::sources::sources_server::Sources,
128    ) -> Self {
129        health_reporter
130            .set_serving::<InstrumentServer<InstrumentService>>()
131            .now_or_never();
132
133        let allowed_origins = Arc::new(Mutex::new(
134            if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() {
135                vec![HeaderValue::from_str("*").unwrap()]
136            } else {
137                vec![
138                    HeaderValue::from_str("https://devtools.crabnebula.dev").unwrap(),
139                    HeaderValue::from_str("tauri://localhost").unwrap(),
140                    #[cfg(windows)]
141                    HeaderValue::from_str("http://tauri.localhost").unwrap(),
142                ]
143            },
144        ));
145
146        let router = tonic::transport::Server::builder()
147            .accept_http1(true)
148            .layer(DynamicCorsLayer {
149                allowed_origins: allowed_origins.clone(),
150            })
151            .add_service(tonic_web::enable(health_service))
152            .add_service(tonic_web::enable(InstrumentServer::new(
153                InstrumentService {
154                    tx: cmd_tx,
155                    health_reporter,
156                },
157            )))
158            .add_service(tonic_web::enable(TauriServer::new(tauri_server)))
159            .add_service(tonic_web::enable(MetadataServer::new(metadata_server)))
160            .add_service(tonic_web::enable(SourcesServer::new(sources_server)));
161
162        Self {
163            router,
164            handle: ServerHandle { allowed_origins },
165        }
166    }
167
168    #[must_use]
169    pub fn handle(&self) -> ServerHandle {
170        self.handle.clone()
171    }
172
173    /// Consumes this [`Server`] and returns a future that will execute the server.
174    ///
175    /// # Errors
176    ///
177    /// This function fails if the address is already in use or if we fail to start the server.
178    pub async fn run(self, addr: SocketAddr) -> crate::Result<()> {
179        tracing::info!("Listening on {}", addr);
180
181        self.router.serve(addr).await?;
182
183        Ok(())
184    }
185}
186
187impl InstrumentService {
188    async fn set_status(&self, status: ServingStatus) {
189        let mut r = self.health_reporter.clone();
190        r.set_service_status("rs.devtools.instrument.Instrument", status)
191            .await;
192    }
193}
194
195#[tonic::async_trait]
196impl instrument_server::Instrument for InstrumentService {
197    type WatchUpdatesStream = BoxStream<instrument::Update>;
198
199    async fn watch_updates(
200        &self,
201        req: Request<InstrumentRequest>,
202    ) -> Result<Response<Self::WatchUpdatesStream>, Status> {
203        if let Some(addr) = req.remote_addr() {
204            tracing::debug!(client.addr = %addr, "starting a new watch");
205        } else {
206            tracing::debug!(client.addr = %"<unknown>", "starting a new watch");
207        }
208
209        // reserve capacity to message the aggregator
210        let Ok(permit) = self.tx.reserve().await else {
211            self.set_status(ServingStatus::NotServing).await;
212            return Err(Status::internal(
213                "cannot start new watch, aggregation task is not running",
214            ));
215        };
216
217        // create output channel and send tx to the aggregator for tracking
218        let (tx, rx) = mpsc::channel(DEFAULT_CLIENT_BUFFER_CAPACITY);
219
220        permit.send(Command::Instrument(Watcher { tx }));
221
222        tracing::debug!("watch started");
223
224        let stream = ReceiverStream::new(rx).or_else(|err| async move {
225            tracing::error!("Aggregator failed with error {err:?}");
226
227            // TODO set the health service status to NotServing here
228
229            Err(Status::internal("boom"))
230        });
231
232        Ok(Response::new(Box::pin(stream)))
233    }
234}
235
#[cfg(test)]
mod test {
    use super::*;
    use devtools_wire_format::instrument::instrument_server::Instrument;

    /// A successful `watch_updates` call must send a `Command::Instrument`
    /// to the aggregator's command channel.
    #[tokio::test]
    async fn subscription() {
        let (reporter, _) = tonic_health::server::health_reporter();
        let (tx, mut rx) = mpsc::channel(1);
        let service = InstrumentService {
            tx,
            health_reporter: reporter,
        };

        let request = Request::new(InstrumentRequest {
            log_filter: None,
            span_filter: None,
        });

        // Bind the stream to a named variable (not `_`) so it lives until the end of the test.
        let _updates = service.watch_updates(request).await.unwrap();

        let received = rx.recv().await.unwrap();
        assert!(matches!(received, Command::Instrument(_)));
    }
}
262}