1#![forbid(unsafe_code)]
10#![warn(missing_docs)]
11#![warn(clippy::large_futures)]
12#![warn(rustdoc::bare_urls)]
13
14use std::collections::HashMap;
15use std::net::{IpAddr, Ipv4Addr, SocketAddr, SocketAddrV4};
16use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
17use std::sync::Arc;
18use std::time::{Duration, SystemTime, UNIX_EPOCH};
19
20use atomic_destructor::{AtomicDestroyer, AtomicDestructor};
21use bytes::Bytes;
22use http_body_util::combinators::BoxBody;
23use http_body_util::{BodyExt, Full};
24use hyper::body::Incoming;
25use hyper::server::conn::http1;
26use hyper::service::service_fn;
27use hyper::{Method, Request, Response, StatusCode};
28use hyper_util::rt::TokioIo;
29use nostr::prelude::{BoxedFuture, SignerBackend};
30use nostr::{Event, NostrSigner, PublicKey, SignerError, UnsignedEvent};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize, Serializer};
33use serde_json::{json, Value};
34use tokio::net::{TcpListener, TcpStream};
35use tokio::sync::oneshot::Sender;
36use tokio::sync::{oneshot, Mutex, Notify};
37use tokio::time;
38use uuid::Uuid;
39
40mod error;
41pub mod prelude;
42
43pub use self::error::Error;
44
/// Proxy web page, served at `GET /`.
const HTML: &str = include_str!("../index.html");
/// Script of the proxy web page, served at `GET /proxy.js`.
const JS: &str = include_str!("../proxy.js");
/// Stylesheet of the proxy web page, served at `GET /style.css`.
const CSS: &str = include_str!("../style.css");

/// Reply channels keyed by request id. Each sender receives the browser's
/// `result` as `Ok` or its `error` message as `Err` (see `Message::into_result`).
type PendingResponseMap = HashMap<Uuid, Sender<Result<Value, String>>>;
50
51#[derive(Debug, Deserialize)]
52struct Message {
53 id: Uuid,
54 error: Option<String>,
55 result: Option<Value>,
56}
57
58impl Message {
59 fn into_result(self) -> Result<Value, String> {
60 if let Some(error) = self.error {
61 Err(error)
62 } else {
63 Ok(self.result.unwrap_or(Value::Null))
64 }
65 }
66}
67
/// Operations the proxy can ask the browser page to perform.
#[derive(Debug, Clone, Copy)]
enum RequestMethod {
    /// Ask for the signer's public key.
    GetPublicKey,
    /// Ask the signer to sign an unsigned event.
    SignEvent,
    /// NIP-04 encryption.
    Nip04Encrypt,
    /// NIP-04 decryption.
    Nip04Decrypt,
    /// NIP-44 encryption.
    Nip44Encrypt,
    /// NIP-44 decryption.
    Nip44Decrypt,
}

impl RequestMethod {
    /// Wire name of the method, as written into the `method` field of a queued request.
    ///
    /// The names are string literals, so the result is `'static`. Callers do not
    /// have to keep `self` borrowed to use the returned name.
    #[inline]
    const fn as_str(&self) -> &'static str {
        match self {
            Self::GetPublicKey => "get_public_key",
            Self::SignEvent => "sign_event",
            Self::Nip04Encrypt => "nip04_encrypt",
            Self::Nip04Decrypt => "nip04_decrypt",
            Self::Nip44Encrypt => "nip44_encrypt",
            Self::Nip44Decrypt => "nip44_decrypt",
        }
    }
}
90
91impl Serialize for RequestMethod {
92 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
93 where
94 S: Serializer,
95 {
96 serializer.serialize_str(self.as_str())
97 }
98}
99
100#[derive(Debug, Clone, Serialize)]
101struct RequestData {
102 id: Uuid,
103 method: RequestMethod,
104 params: Value,
105}
106
107impl RequestData {
108 #[inline]
109 fn new(method: RequestMethod, params: Value) -> Self {
110 Self {
111 id: Uuid::new_v4(),
112 method,
113 params,
114 }
115 }
116}
117
118#[derive(Serialize)]
119struct Requests<'a> {
120 requests: &'a [RequestData],
121}
122
123impl<'a> Requests<'a> {
124 #[inline]
125 fn new(requests: &'a [RequestData]) -> Self {
126 Self { requests }
127 }
128
129 #[inline]
130 fn len(&self) -> usize {
131 self.requests.len()
132 }
133}
134
135#[derive(Serialize)]
137struct CryptoParams<'a> {
138 public_key: &'a PublicKey,
139 content: &'a str,
140}
141
142impl<'a> CryptoParams<'a> {
143 #[inline]
144 fn new(public_key: &'a PublicKey, content: &'a str) -> Self {
145 Self {
146 public_key,
147 content,
148 }
149 }
150}
151
/// State shared between the signer API and the HTTP server tasks.
#[derive(Debug)]
struct ProxyState {
    /// Requests waiting to be fetched; drained (cleared) by each `GET /api/pending`.
    pub outgoing_requests: Mutex<Vec<RequestData>>,
    /// Reply channels keyed by request id; an entry is removed when
    /// `POST /api/response` delivers the reply with the matching id.
    pub pending_responses: Mutex<PendingResponseMap>,
    /// UNIX time (whole seconds) of the most recent `GET /api/pending` poll;
    /// `0` until the first poll.
    pub last_pending_request: Arc<AtomicU64>,
}
161
/// Configuration for [`BrowserSignerProxy`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BrowserSignerProxyOptions {
    /// How long a signer request waits for the browser's reply before failing
    /// with a timeout (default: 30 seconds).
    pub timeout: Duration,
    /// Socket address the local HTTP server listens on (default: `127.0.0.1:7400`).
    pub addr: SocketAddr,
}
170
/// Internal state wrapped by [`BrowserSignerProxy`].
#[derive(Debug, Clone)]
struct InnerBrowserSignerProxy {
    /// Configuration passed to `BrowserSignerProxy::new`.
    options: BrowserSignerProxyOptions,
    /// Request/response queues, also handed to the HTTP server tasks.
    state: Arc<ProxyState>,
    /// Notified by `shutdown()` to stop the accept loop and open connections.
    shutdown: Arc<Notify>,
    /// Set by `shutdown()`; once set, `start()` returns `Error::Shutdown`.
    is_shutdown: Arc<AtomicBool>,
    /// Set when `start()` claims the server (reset if binding fails).
    is_started: Arc<AtomicBool>,
}
184
185impl AtomicDestroyer for InnerBrowserSignerProxy {
186 fn on_destroy(&self) {
187 self.shutdown();
188 }
189}
190
191impl InnerBrowserSignerProxy {
192 #[inline]
193 fn is_shutdown(&self) -> bool {
194 self.is_shutdown.load(Ordering::SeqCst)
195 }
196
197 fn shutdown(&self) {
198 self.is_shutdown.store(true, Ordering::SeqCst);
200
201 self.shutdown.notify_one();
203 self.shutdown.notify_waiters();
204 }
205}
206
/// A [`NostrSigner`] that forwards requests to a web page served by a local HTTP server.
///
/// Each signer call is queued until it is fetched from `GET /api/pending`. It completes
/// when the reply with the matching id is posted to `POST /api/response`, or fails
/// once the configured timeout elapses. Open [`BrowserSignerProxy::url`] in a
/// browser to load the page.
///
/// The server is started lazily by the first signer call, or explicitly with
/// [`BrowserSignerProxy::start`]. It is stopped from `AtomicDestroyer::on_destroy`,
/// presumably when the last clone is dropped (verify against `atomic_destructor`).
#[derive(Debug, Clone)]
pub struct BrowserSignerProxy {
    inner: AtomicDestructor<InnerBrowserSignerProxy>,
}
214
215impl Default for BrowserSignerProxyOptions {
216 fn default() -> Self {
217 Self {
218 timeout: Duration::from_secs(30),
219 addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 7400)),
221 }
222 }
223}
224
225impl BrowserSignerProxyOptions {
226 pub const fn timeout(mut self, timeout: Duration) -> Self {
228 self.timeout = timeout;
229 self
230 }
231
232 pub const fn ip_addr(mut self, new_ip: IpAddr) -> Self {
234 self.addr = SocketAddr::new(new_ip, self.addr.port());
235 self
236 }
237
238 pub const fn port(mut self, new_port: u16) -> Self {
240 self.addr = SocketAddr::new(self.addr.ip(), new_port);
241 self
242 }
243}
244
245impl BrowserSignerProxy {
246 pub fn new(options: BrowserSignerProxyOptions) -> Self {
248 let state = ProxyState {
249 outgoing_requests: Mutex::new(Vec::new()),
250 pending_responses: Mutex::new(HashMap::new()),
251 last_pending_request: Arc::new(AtomicU64::new(0)),
252 };
253
254 Self {
255 inner: AtomicDestructor::new(InnerBrowserSignerProxy {
256 options,
257 state: Arc::new(state),
258 shutdown: Arc::new(Notify::new()),
259 is_shutdown: Arc::new(AtomicBool::new(false)),
260 is_started: Arc::new(AtomicBool::new(false)),
261 }),
262 }
263 }
264
265 #[inline]
267 pub fn is_started(&self) -> bool {
268 self.inner.is_started.load(Ordering::SeqCst)
269 }
270
271 #[inline]
274 pub fn is_session_active(&self) -> bool {
275 current_time() - self.inner.state.last_pending_request.load(Ordering::SeqCst) < 2
276 }
277
278 #[inline]
280 pub fn url(&self) -> String {
281 format!("http://{}", self.inner.options.addr)
282 }
283
284 pub async fn start(&self) -> Result<(), Error> {
288 if self.inner.is_shutdown() {
290 return Err(Error::Shutdown);
291 }
292
293 let is_started: bool = self.inner.is_started.swap(true, Ordering::SeqCst);
295
296 if is_started {
298 return Ok(());
299 }
300
301 let listener: TcpListener = match TcpListener::bind(self.inner.options.addr).await {
302 Ok(listener) => listener,
303 Err(e) => {
304 self.inner.is_started.store(false, Ordering::SeqCst);
306
307 return Err(Error::from(e));
309 }
310 };
311
312 let addr: SocketAddr = self.inner.options.addr;
313 let state: Arc<ProxyState> = self.inner.state.clone();
314 let shutdown: Arc<Notify> = self.inner.shutdown.clone();
315
316 tokio::spawn(async move {
317 tracing::info!("Starting proxy server on {addr}");
318
319 loop {
320 tokio::select! {
321 res = listener.accept() => {
322 let stream: TcpStream = match res {
323 Ok((stream, ..)) => stream,
324 Err(e) => {
325 tracing::error!("Failed to accept connection: {}", e);
326 continue;
327 }
328 };
329
330 let io: TokioIo<TcpStream> = TokioIo::new(stream);
331 let state: Arc<ProxyState> = state.clone();
332 let shutdown: Arc<Notify> = shutdown.clone();
333
334 tokio::spawn(async move {
335 let service = service_fn(move |req| {
336 handle_request(req, state.clone())
337 });
338
339 tokio::select! {
340 res = http1::Builder::new().serve_connection(io, service) => {
341 if let Err(e) = res {
342 tracing::error!("Error serving connection: {e}");
343 }
344 }
345 _ = shutdown.notified() => {
346 tracing::debug!("Closing connection, proxy server is shutting down.");
347 }
348 }
349 });
350 },
351 _ = shutdown.notified() => {
352 break;
353 }
354 }
355 }
356
357 tracing::info!("Shutting down proxy server.");
358 });
359
360 Ok(())
361 }
362
363 #[inline]
364 async fn store_pending_response(&self, id: Uuid, tx: Sender<Result<Value, String>>) {
365 let mut pending_responses = self.inner.state.pending_responses.lock().await;
366 pending_responses.insert(id, tx);
367 }
368
369 #[inline]
370 async fn store_outgoing_request(&self, request: RequestData) {
371 let mut outgoing_requests = self.inner.state.outgoing_requests.lock().await;
372 outgoing_requests.push(request);
373 }
374
375 async fn request<T>(&self, method: RequestMethod, params: Value) -> Result<T, Error>
376 where
377 T: DeserializeOwned,
378 {
379 self.start().await?;
381
382 let request: RequestData = RequestData::new(method, params);
384
385 let (tx, rx) = oneshot::channel();
387
388 self.store_pending_response(request.id, tx).await;
390
391 self.store_outgoing_request(request).await;
393
394 match time::timeout(self.inner.options.timeout, rx)
396 .await
397 .map_err(|_| Error::Timeout)??
398 {
399 Ok(res) => Ok(serde_json::from_value(res)?),
400 Err(error) => Err(Error::Generic(error)),
401 }
402 }
403
404 #[inline]
405 async fn _get_public_key(&self) -> Result<PublicKey, Error> {
406 self.request(RequestMethod::GetPublicKey, json!({})).await
407 }
408
409 #[inline]
410 async fn _sign_event(&self, event: UnsignedEvent) -> Result<Event, Error> {
411 let event: Event = self
412 .request(RequestMethod::SignEvent, serde_json::to_value(event)?)
413 .await?;
414 event.verify()?;
415 Ok(event)
416 }
417
418 #[inline]
419 async fn _nip04_encrypt(&self, public_key: &PublicKey, content: &str) -> Result<String, Error> {
420 let params = CryptoParams::new(public_key, content);
421 self.request(RequestMethod::Nip04Encrypt, serde_json::to_value(params)?)
422 .await
423 }
424
425 #[inline]
426 async fn _nip04_decrypt(&self, public_key: &PublicKey, content: &str) -> Result<String, Error> {
427 let params = CryptoParams::new(public_key, content);
428 self.request(RequestMethod::Nip04Decrypt, serde_json::to_value(params)?)
429 .await
430 }
431
432 #[inline]
433 async fn _nip44_encrypt(&self, public_key: &PublicKey, content: &str) -> Result<String, Error> {
434 let params = CryptoParams::new(public_key, content);
435 self.request(RequestMethod::Nip44Encrypt, serde_json::to_value(params)?)
436 .await
437 }
438
439 #[inline]
440 async fn _nip44_decrypt(&self, public_key: &PublicKey, content: &str) -> Result<String, Error> {
441 let params = CryptoParams::new(public_key, content);
442 self.request(RequestMethod::Nip44Decrypt, serde_json::to_value(params)?)
443 .await
444 }
445}
446
447impl NostrSigner for BrowserSignerProxy {
448 fn backend(&self) -> SignerBackend {
449 SignerBackend::BrowserExtension
450 }
451
452 #[inline]
453 fn get_public_key(&self) -> BoxedFuture<Result<PublicKey, SignerError>> {
454 Box::pin(async move { self._get_public_key().await.map_err(SignerError::backend) })
455 }
456
457 #[inline]
458 fn sign_event(&self, unsigned: UnsignedEvent) -> BoxedFuture<Result<Event, SignerError>> {
459 Box::pin(async move {
460 self._sign_event(unsigned)
461 .await
462 .map_err(SignerError::backend)
463 })
464 }
465
466 #[inline]
467 fn nip04_encrypt<'a>(
468 &'a self,
469 public_key: &'a PublicKey,
470 content: &'a str,
471 ) -> BoxedFuture<'a, Result<String, SignerError>> {
472 Box::pin(async move {
473 self._nip04_encrypt(public_key, content)
474 .await
475 .map_err(SignerError::backend)
476 })
477 }
478
479 #[inline]
480 fn nip04_decrypt<'a>(
481 &'a self,
482 public_key: &'a PublicKey,
483 encrypted_content: &'a str,
484 ) -> BoxedFuture<'a, Result<String, SignerError>> {
485 Box::pin(async move {
486 self._nip04_decrypt(public_key, encrypted_content)
487 .await
488 .map_err(SignerError::backend)
489 })
490 }
491
492 #[inline]
493 fn nip44_encrypt<'a>(
494 &'a self,
495 public_key: &'a PublicKey,
496 content: &'a str,
497 ) -> BoxedFuture<'a, Result<String, SignerError>> {
498 Box::pin(async move {
499 self._nip44_encrypt(public_key, content)
500 .await
501 .map_err(SignerError::backend)
502 })
503 }
504
505 #[inline]
506 fn nip44_decrypt<'a>(
507 &'a self,
508 public_key: &'a PublicKey,
509 payload: &'a str,
510 ) -> BoxedFuture<'a, Result<String, SignerError>> {
511 Box::pin(async move {
512 self._nip44_decrypt(public_key, payload)
513 .await
514 .map_err(SignerError::backend)
515 })
516 }
517}
518
519async fn handle_request(
520 req: Request<Incoming>,
521 state: Arc<ProxyState>,
522) -> Result<Response<BoxBody<Bytes, Error>>, Error> {
523 match (req.method(), req.uri().path()) {
524 (&Method::GET, "/") => Ok(Response::builder()
526 .header("Content-Type", "text/html")
527 .body(full(HTML))?),
528 (&Method::GET, "/style.css") => Ok(Response::builder()
530 .header("Content-Type", "text/css")
531 .body(full(CSS))?),
532 (&Method::GET, "/proxy.js") => Ok(Response::builder()
534 .header("Content-Type", "application/javascript")
535 .body(full(JS))?),
536 (&Method::GET, "/api/pending") => {
538 state
539 .last_pending_request
540 .store(current_time(), Ordering::SeqCst);
541
542 let mut outgoing = state.outgoing_requests.lock().await;
543
544 let requests: Requests<'_> = Requests::new(&outgoing);
545 let json: String = serde_json::to_string(&requests)?;
546
547 tracing::debug!("Sending {} pending requests to browser", requests.len());
548
549 outgoing.clear();
551
552 Ok(Response::builder()
553 .header("Content-Type", "application/json")
554 .header("Access-Control-Allow-Origin", "*")
555 .body(full(json))?)
556 }
557 (&Method::POST, "/api/response") => {
559 let body_bytes: Bytes = match req.into_body().collect().await {
561 Ok(collected) => collected.to_bytes(),
562 Err(e) => {
563 tracing::error!("Failed to read body: {e}");
564 let response = Response::builder()
565 .status(StatusCode::BAD_REQUEST)
566 .body(full("Failed to read body"))?;
567 return Ok(response);
568 }
569 };
570
571 let message: Message = match serde_json::from_slice(&body_bytes) {
573 Ok(json) => json,
574 Err(_) => {
575 let response = Response::builder()
576 .status(StatusCode::BAD_REQUEST)
577 .body(full("Invalid JSON"))?;
578 return Ok(response);
579 }
580 };
581
582 tracing::debug!("Received response from browser: {message:?}");
583
584 let id: Uuid = message.id;
585 let mut pending = state.pending_responses.lock().await;
586
587 match pending.remove(&id) {
588 Some(sender) => {
589 let _ = sender.send(message.into_result());
590 tracing::info!("Forwarded response for request {id}");
591 }
592 None => tracing::warn!("No pending request found for {id}"),
593 }
594
595 let response = Response::builder()
596 .header("Access-Control-Allow-Origin", "*")
597 .body(full("OK"))?;
598 Ok(response)
599 }
600 (&Method::OPTIONS, _) => {
601 let response = Response::builder()
603 .header("Access-Control-Allow-Origin", "*")
604 .header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
605 .header("Access-Control-Allow-Headers", "Content-Type")
606 .body(full(""))?;
607 Ok(response)
608 }
609 _ => {
611 let response = Response::builder()
612 .status(StatusCode::NOT_FOUND)
613 .body(full("Not Found"))?;
614 Ok(response)
615 }
616 }
617}
618
619#[inline]
620fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, Error> {
621 Full::new(chunk.into())
622 .map_err(|never| match never {})
623 .boxed()
624}
625
/// Current UNIX time in whole seconds, or `0` if the system clock reports a time
/// before the UNIX epoch.
#[inline]
fn current_time() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}