use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use futures_util::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::sync::{Mutex, broadcast, mpsc};
use tokio::task::JoinHandle;
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::accept_hdr_async;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
use tokio_tungstenite::tungstenite::http::{Response as HttpResponse, StatusCode};
use crate::codec::{FrameDecoder, encode_frame};
use crate::launcher::{self, BrowserOptions, Launched};
use crate::transport::Child;
use crate::{Error, Result};
/// Capacity handed to `broadcast::channel` for the frames decoded from the
/// browser pipe. A subscriber that falls further behind than this receives
/// `RecvError::Lagged` and loses the overwritten frames.
const FRAME_CHANNEL_CAPACITY: usize = 4096;
/// A launched browser process exposed through a local WebSocket listener.
///
/// Created by [`BrowserServer::launch`] / [`BrowserServer::launch_on`]; torn
/// down by [`BrowserServer::stop`] or, on a best-effort basis, by `Drop`.
pub struct BrowserServer {
/// `ws://` URL assembled in `launch_on` from the bound address and the token.
endpoint: String,
/// Port of the bound listener (the real port even when `0` was requested).
port: u16,
/// Path segment a client must request; caller-supplied or generated.
token: String,
/// Browser child handle; `take`n out (leaving `None`) when it is killed.
child: Mutex<Option<Child>>,
/// Profile directory reported by the launcher.
profile_dir: PathBuf,
/// When `true`, `profile_dir` is deleted on `stop` and on drop.
profile_is_temp: bool,
/// Handle of the spawned `accept_loop` task, aborted on `stop` and on drop.
accept_task: JoinHandle<()>,
}
impl BrowserServer {
    /// Launches a browser and serves it on `127.0.0.1` with an OS-assigned
    /// port and a generated path token.
    ///
    /// Shorthand for `launch_on(opts, "127.0.0.1", 0, None)`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`BrowserServer::launch_on`] returns.
    pub async fn launch(opts: BrowserOptions) -> Result<Self> {
        Self::launch_on(opts, "127.0.0.1", 0, None).await
    }

    /// Launches a browser and serves it over WebSocket on `host:port`.
    ///
    /// * `host` / `port` – address to bind; port `0` lets the OS choose.
    /// * `ws_path` – path clients must request. Leading `/` characters are
    ///   stripped. When `None`, a token is generated.
    ///
    /// The listener is bound *before* the browser is started so that a bind
    /// failure cannot leave an already-launched browser (and its profile
    /// directory) behind with nobody responsible for cleaning it up.
    ///
    /// # Errors
    ///
    /// * `Error::Transport` when binding the listener or reading its local
    ///   address fails.
    /// * Any error returned by `launcher::launch`.
    pub async fn launch_on(
        opts: BrowserOptions,
        host: &str,
        port: u16,
        ws_path: Option<&str>,
    ) -> Result<Self> {
        let listener = TcpListener::bind((host, port))
            .await
            .map_err(|e| Error::Transport(format!("ws 服务端绑定 {host}:{port} 失败: {e}")))?;
        let addr = listener
            .local_addr()
            .map_err(|e| Error::Transport(format!("读取 ws 监听地址失败: {e}")))?;

        let Launched {
            child,
            writer,
            reader,
            profile_dir,
            profile_is_temp,
        } = launcher::launch(&opts).await?;

        // Client -> browser: messages are queued here and framed by `pipe_writer`.
        let (pipe_tx, pipe_rx) = mpsc::unbounded_channel::<Vec<u8>>();
        tokio::spawn(pipe_writer(writer, pipe_rx));

        // Browser -> client: decoded frames are broadcast to the connected client.
        let (frame_tx, _) = broadcast::channel::<Vec<u8>>(FRAME_CHANNEL_CAPACITY);
        tokio::spawn(pipe_reader(reader, frame_tx.clone()));

        let token = ws_path
            .map(|s| s.trim_start_matches('/').to_string())
            .unwrap_or_else(random_token);
        let want_path = format!("/{token}");
        let accept_task = tokio::spawn(accept_loop(listener, want_path, pipe_tx, frame_tx));

        // `SocketAddr`'s `Display` brackets IPv6 addresses (`[::1]:9222`), which
        // a URL authority requires; formatting ip and port separately would not.
        let endpoint = format!("ws://{addr}/{token}");
        tracing::info!(%endpoint, "BrowserServer 已就绪");

        Ok(Self {
            endpoint,
            port: addr.port(),
            token,
            child: Mutex::new(Some(child)),
            profile_dir,
            profile_is_temp,
            accept_task,
        })
    }

    /// Full WebSocket URL (address, port and token path) clients connect to.
    pub fn ws_endpoint(&self) -> &str {
        &self.endpoint
    }

    /// Port the listener is actually bound to.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Path token (without the leading `/`) a client must request.
    pub fn token(&self) -> &str {
        &self.token
    }

    /// Stops accepting connections, kills the browser and removes a temporary
    /// profile directory.
    ///
    /// Every step is best effort: kill and removal errors are ignored, so this
    /// currently always returns `Ok(())`. Calling it again is harmless because
    /// the child handle has already been taken.
    pub async fn stop(&self) -> Result<()> {
        self.accept_task.abort();
        if let Some(mut child) = self.child.lock().await.take() {
            let _ = child.kill().await;
        }
        if self.profile_is_temp {
            let _ = tokio::fs::remove_dir_all(&self.profile_dir).await;
        }
        Ok(())
    }
}
impl Drop for BrowserServer {
fn drop(&mut self) {
self.accept_task.abort();
if let Ok(mut g) = self.child.try_lock()
&& let Some(mut c) = g.take()
{
let _ = c.start_kill();
}
if self.profile_is_temp {
let _ = std::fs::remove_dir_all(&self.profile_dir);
}
}
}
/// Generates the hexadecimal path token that guards the WebSocket endpoint.
///
/// The token starts with the wall-clock nanoseconds, the process id and a
/// process-wide counter (which keeps tokens distinct within one process) and
/// ends with 128 bits taken from two randomly keyed `RandomState` hashers.
/// The clock, pid and counter alone are guessable by another local process;
/// the hasher output adds material that is not.
///
/// This is a best-effort hardening using only the standard library; it is
/// not a substitute for a cryptographically secure random generator.
fn random_token() -> String {
    use std::hash::{BuildHasher, Hasher};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    let n = COUNTER.fetch_add(1, Ordering::Relaxed);

    // Every `RandomState::new()` carries different random keys, so each call
    // yields an independent-looking 64-bit value.
    let entropy = || {
        let mut hasher = std::collections::hash_map::RandomState::new().build_hasher();
        hasher.write_u128(nanos);
        hasher.write_u64(n);
        hasher.finish()
    };

    format!(
        "{:x}{:x}{:x}{:016x}{:016x}",
        nanos,
        std::process::id(),
        n,
        entropy(),
        entropy()
    )
}
async fn pipe_writer<W>(mut writer: W, mut rx: mpsc::UnboundedReceiver<Vec<u8>>)
where
W: AsyncWrite + Unpin,
{
while let Some(json) = rx.recv().await {
let frame = encode_frame(&json);
if writer.write_all(&frame).await.is_err() {
break;
}
}
}
/// Reads the browser pipe, reassembles frames and broadcasts each one.
///
/// Bytes are read in chunks of up to 64 KiB and pushed into a `FrameDecoder`;
/// every complete frame it yields is sent on `frame_tx`. The task ends on
/// end-of-stream (a zero-length read) or on the first read error.
async fn pipe_reader<R>(mut reader: R, frame_tx: broadcast::Sender<Vec<u8>>)
where
    R: AsyncRead + Unpin,
{
    let mut decoder = FrameDecoder::new();
    let mut chunk = vec![0u8; 64 * 1024];

    while let Ok(len) = reader.read(&mut chunk).await {
        if len == 0 {
            break;
        }
        decoder.push(&chunk[..len]);
        while let Some(frame) = decoder.next_frame() {
            // The send result is ignored on purpose: the channel reports an
            // error whenever no client is currently subscribed.
            let _ = frame_tx.send(frame);
        }
    }
}
/// Accepts TCP connections and bridges at most one WebSocket client at a time
/// to the browser pipe.
///
/// * `want_path` – request path a client must use; any other path is answered
///   with `403 Forbidden` during the handshake.
/// * `pipe_tx` – queue of client messages headed for the browser.
/// * `frame_tx` – broadcast of frames coming from the browser.
///
/// Each connection is handled in its own task. After a successful handshake
/// the task claims the single `active` slot; if another client already holds
/// it, the new connection is closed immediately. The slot is released once
/// `bridge_client` returns.
///
/// The loop ends on the first `accept` error.
// The handshake callback must return `Result<Response, ErrorResponse>` because
// that is the signature `accept_hdr_async` expects; presumably the size of that
// error type is what triggers the lint.
#[allow(clippy::result_large_err)]
async fn accept_loop(
    listener: TcpListener,
    want_path: String,
    pipe_tx: mpsc::UnboundedSender<Vec<u8>>,
    frame_tx: broadcast::Sender<Vec<u8>>,
) {
    let active = Arc::new(AtomicBool::new(false));
    loop {
        let (stream, _addr) = match listener.accept().await {
            Ok(x) => x,
            Err(e) => {
                tracing::warn!(error = %e, "ws accept 失败,服务端监听退出");
                break;
            }
        };
        let pipe_tx = pipe_tx.clone();
        let frame_tx = frame_tx.clone();
        let want = want_path.clone();
        let active = active.clone();
        tokio::spawn(async move {
            // Handshake hook: only the expected path may upgrade.
            let check = move |req: &Request,
                              resp: Response|
                  -> std::result::Result<Response, ErrorResponse> {
                if req.uri().path() == want {
                    Ok(resp)
                } else {
                    let err = HttpResponse::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(Some("invalid juggler ws path".to_string()))
                        .expect("build error response");
                    Err(err)
                }
            };
            let ws = match accept_hdr_async(stream, check).await {
                Ok(ws) => ws,
                Err(e) => {
                    tracing::debug!(error = %e, "ws 握手失败(可能 token 不匹配)");
                    return;
                }
            };
            if active
                .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
                .is_err()
            {
                tracing::warn!("已有活动 ws 客户端,拒绝新连接(单实例)");
                let (mut sink, _stream) = ws.split();
                let _ = sink.close().await;
                return;
            }
            let frame_rx = frame_tx.subscribe();
            // Give up this task's sender handle before bridging. While the task
            // itself held a sender, its receiver could never observe `Closed`,
            // so the client was not told when every producer had gone away
            // (browser pipe ended and the accept task aborted).
            drop(frame_tx);
            bridge_client(ws, pipe_tx, frame_rx).await;
            active.store(false, Ordering::Release);
        });
    }
}
async fn bridge_client<S>(
ws: WebSocketStream<S>,
pipe_tx: mpsc::UnboundedSender<Vec<u8>>,
mut frame_rx: broadcast::Receiver<Vec<u8>>,
) where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let (mut sink, mut stream) = ws.split();
let to_ws = tokio::spawn(async move {
loop {
match frame_rx.recv().await {
Ok(json) => {
let text = match String::from_utf8(json) {
Ok(s) => s,
Err(_) => continue,
};
if sink.send(Message::text(text)).await.is_err() {
break;
}
}
Err(broadcast::error::RecvError::Lagged(_)) => continue,
Err(broadcast::error::RecvError::Closed) => break,
}
}
let _ = sink.close().await;
});
while let Some(item) = stream.next().await {
match item {
Ok(Message::Text(t)) => {
if pipe_tx.send(t.as_bytes().to_vec()).is_err() {
break;
}
}
Ok(Message::Binary(b)) => {
if pipe_tx.send(b.to_vec()).is_err() {
break;
}
}
Ok(Message::Close(_)) => break,
Ok(_) => {}
Err(_) => break,
}
}
to_ws.abort();
}