1use crate::{Command, Watcher};
2pub use devtools_wire_format as wire;
3use devtools_wire_format::instrument;
4use devtools_wire_format::instrument::instrument_server::InstrumentServer;
5use devtools_wire_format::instrument::{instrument_server, InstrumentRequest};
6use devtools_wire_format::meta::metadata_server;
7use devtools_wire_format::meta::metadata_server::MetadataServer;
8use devtools_wire_format::sources::sources_server::SourcesServer;
9use devtools_wire_format::tauri::tauri_server;
10use devtools_wire_format::tauri::tauri_server::TauriServer;
11use futures::{FutureExt, TryStreamExt};
12use http::HeaderValue;
13use hyper::Body;
14use std::net::SocketAddr;
15use std::pin::Pin;
16use std::sync::{Arc, Mutex};
17use std::task::{Context, Poll};
18use tokio::sync::mpsc;
19use tonic::body::BoxBody;
20use tonic::codegen::http::Method;
21use tonic::codegen::tokio_stream::wrappers::ReceiverStream;
22use tonic::codegen::BoxStream;
23use tonic::{Request, Response, Status};
24use tonic_health::pb::health_server::{Health, HealthServer};
25use tonic_health::server::HealthReporter;
26use tonic_health::ServingStatus;
27use tower::Service;
28use tower_http::cors::{AllowHeaders, AllowOrigin, CorsLayer};
29use tower_layer::Layer;
30
/// Capacity of the per-client update channel created for every new watch (4 KiB worth of slots).
const DEFAULT_CLIENT_BUFFER_CAPACITY: usize = 4 * 1024;
37
38pub struct Server {
40 router: tonic::transport::server::Router<
41 tower_layer::Stack<DynamicCorsLayer, tower_layer::Identity>,
42 >,
43 handle: ServerHandle,
44}
45
46#[allow(clippy::module_name_repetitions)]
48#[derive(Clone)]
49pub struct ServerHandle {
50 allowed_origins: Arc<Mutex<Vec<HeaderValue>>>,
51}
52
53impl ServerHandle {
54 #[allow(clippy::missing_panics_doc)]
56 pub fn allow_origin(&self, origin: HeaderValue) {
57 self.allowed_origins.lock().unwrap().push(origin);
58 }
59}
60
61struct InstrumentService {
62 tx: mpsc::Sender<Command>,
63 health_reporter: HealthReporter,
64}
65
66#[derive(Clone)]
67struct DynamicCorsLayer {
68 allowed_origins: Arc<Mutex<Vec<HeaderValue>>>,
69}
70
71impl<S> Layer<S> for DynamicCorsLayer {
72 type Service = DynamicCors<S>;
73
74 fn layer(&self, service: S) -> Self::Service {
75 DynamicCors {
76 inner: service,
77 allowed_origins: self.allowed_origins.clone(),
78 }
79 }
80}
81
82#[derive(Debug, Clone)]
83struct DynamicCors<S> {
84 inner: S,
85 allowed_origins: Arc<Mutex<Vec<HeaderValue>>>,
86}
87
/// Heap-allocated, pinned, `Send` future yielding `T` and borrowing for at most `'a`.
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + 'a + Send>>;
89
90impl<S> Service<hyper::Request<Body>> for DynamicCors<S>
91where
92 S: Service<hyper::Request<Body>, Response = hyper::Response<BoxBody>> + Clone + Send + 'static,
93 S::Future: Send + 'static,
94{
95 type Response = S::Response;
96 type Error = S::Error;
97 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
98
99 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
100 self.inner.poll_ready(cx)
101 }
102
103 fn call(&mut self, req: hyper::Request<Body>) -> Self::Future {
104 let allowed_origins = self.allowed_origins.lock().unwrap().clone();
105 let cors = CorsLayer::new()
106 .allow_methods([Method::GET, Method::POST])
108 .allow_headers(AllowHeaders::any())
109 .allow_origin(if allowed_origins.iter().any(|o| o == "*") {
110 AllowOrigin::any()
111 } else {
112 AllowOrigin::list(allowed_origins)
113 });
114
115 Box::pin(cors.layer(self.inner.clone()).call(req))
116 }
117}
118
119impl Server {
120 #[allow(clippy::missing_panics_doc)]
121 pub fn new(
122 cmd_tx: mpsc::Sender<Command>,
123 mut health_reporter: HealthReporter,
124 health_service: HealthServer<impl Health>,
125 tauri_server: impl tauri_server::Tauri,
126 metadata_server: impl metadata_server::Metadata,
127 sources_server: impl wire::sources::sources_server::Sources,
128 ) -> Self {
129 health_reporter
130 .set_serving::<InstrumentServer<InstrumentService>>()
131 .now_or_never();
132
133 let allowed_origins = Arc::new(Mutex::new(
134 if option_env!("__DEVTOOLS_LOCAL_DEVELOPMENT").is_some() {
135 vec![HeaderValue::from_str("*").unwrap()]
136 } else {
137 vec![
138 HeaderValue::from_str("https://devtools.crabnebula.dev").unwrap(),
139 HeaderValue::from_str("tauri://localhost").unwrap(),
140 #[cfg(windows)]
141 HeaderValue::from_str("http://tauri.localhost").unwrap(),
142 ]
143 },
144 ));
145
146 let router = tonic::transport::Server::builder()
147 .accept_http1(true)
148 .layer(DynamicCorsLayer {
149 allowed_origins: allowed_origins.clone(),
150 })
151 .add_service(tonic_web::enable(health_service))
152 .add_service(tonic_web::enable(InstrumentServer::new(
153 InstrumentService {
154 tx: cmd_tx,
155 health_reporter,
156 },
157 )))
158 .add_service(tonic_web::enable(TauriServer::new(tauri_server)))
159 .add_service(tonic_web::enable(MetadataServer::new(metadata_server)))
160 .add_service(tonic_web::enable(SourcesServer::new(sources_server)));
161
162 Self {
163 router,
164 handle: ServerHandle { allowed_origins },
165 }
166 }
167
168 #[must_use]
169 pub fn handle(&self) -> ServerHandle {
170 self.handle.clone()
171 }
172
173 pub async fn run(self, addr: SocketAddr) -> crate::Result<()> {
179 tracing::info!("Listening on {}", addr);
180
181 self.router.serve(addr).await?;
182
183 Ok(())
184 }
185}
186
187impl InstrumentService {
188 async fn set_status(&self, status: ServingStatus) {
189 let mut r = self.health_reporter.clone();
190 r.set_service_status("rs.devtools.instrument.Instrument", status)
191 .await;
192 }
193}
194
195#[tonic::async_trait]
196impl instrument_server::Instrument for InstrumentService {
197 type WatchUpdatesStream = BoxStream<instrument::Update>;
198
199 async fn watch_updates(
200 &self,
201 req: Request<InstrumentRequest>,
202 ) -> Result<Response<Self::WatchUpdatesStream>, Status> {
203 if let Some(addr) = req.remote_addr() {
204 tracing::debug!(client.addr = %addr, "starting a new watch");
205 } else {
206 tracing::debug!(client.addr = %"<unknown>", "starting a new watch");
207 }
208
209 let Ok(permit) = self.tx.reserve().await else {
211 self.set_status(ServingStatus::NotServing).await;
212 return Err(Status::internal(
213 "cannot start new watch, aggregation task is not running",
214 ));
215 };
216
217 let (tx, rx) = mpsc::channel(DEFAULT_CLIENT_BUFFER_CAPACITY);
219
220 permit.send(Command::Instrument(Watcher { tx }));
221
222 tracing::debug!("watch started");
223
224 let stream = ReceiverStream::new(rx).or_else(|err| async move {
225 tracing::error!("Aggregator failed with error {err:?}");
226
227 Err(Status::internal("boom"))
230 });
231
232 Ok(Response::new(Box::pin(stream)))
233 }
234}
235
#[cfg(test)]
mod test {
    use super::*;
    use devtools_wire_format::instrument::instrument_server::Instrument;

    /// Starting a watch must hand a `Command::Instrument` to the aggregation task.
    #[tokio::test]
    async fn subscription() {
        let (health_reporter, _) = tonic_health::server::health_reporter();
        let (command_tx, mut command_rx) = mpsc::channel(1);
        let service = InstrumentService {
            tx: command_tx,
            health_reporter,
        };

        let request = Request::new(InstrumentRequest {
            log_filter: None,
            span_filter: None,
        });
        let _stream = service.watch_updates(request).await.unwrap();

        let received = command_rx.recv().await.unwrap();
        assert!(matches!(received, Command::Instrument(_)));
    }
}