1use axum::{
7 Router,
8 body::Body,
9 http::{Response, StatusCode, header},
10 routing::get,
11};
12use bytes::Bytes;
13use std::net::SocketAddr;
14use std::sync::{Arc, Mutex};
15use tokio::net::TcpListener;
16use tokio::sync::{Notify, broadcast, mpsc};
17use tokio_stream::wrappers::ReceiverStream;
18
/// Settings that control where the streaming server listens and how long it lives.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host or IP address the listener binds to.
    pub host: String,
    /// TCP port to bind; `0` lets the operating system pick a free port.
    pub port: u16,
    /// When `false`, the server requests shutdown after a response has been
    /// streamed to completion; when `true`, it keeps accepting clients.
    pub persist: bool,
}

impl Default for ServerConfig {
    /// Loopback address, OS-assigned port, non-persistent.
    fn default() -> Self {
        ServerConfig {
            host: String::from("127.0.0.1"),
            port: 0,
            persist: false,
        }
    }
}
36
37struct State {
41 buffer: Mutex<Vec<Bytes>>,
43 live_tx: broadcast::Sender<Bytes>,
45 done: Notify,
47 persist: bool,
48 served: Arc<Notify>,
50 shutdown: Arc<Notify>,
52}
53
54pub struct ServerHandle {
58 state: Arc<State>,
59 pub addr: SocketAddr,
61}
62
63impl ServerHandle {
64 pub fn send(&self, chunk: Bytes) {
66 let mut buf = self.state.buffer.lock().unwrap();
67 buf.push(chunk.clone());
68 drop(buf);
69 let _ = self.state.live_tx.send(chunk);
70 }
71
72 pub fn finish(&self) {
78 self.state.done.notify_one();
81 }
84
85 pub async fn wait_served(&self) {
88 if !self.state.persist {
89 self.state.served.notified().await;
90 }
91 }
92}
93
94pub async fn serve<F>(cfg: ServerConfig, on_bind: F) -> ServerHandle
98where
99 F: FnOnce(SocketAddr),
100{
101 let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
102 .parse()
103 .expect("invalid bind address");
104
105 let listener = TcpListener::bind(addr)
106 .await
107 .expect("failed to bind server");
108
109 let bound = listener.local_addr().expect("no local addr");
110 on_bind(bound);
111
112 let (live_tx, _) = broadcast::channel::<Bytes>(256);
113 let shutdown = Arc::new(Notify::new());
114 let served = Arc::new(Notify::new());
115
116 let state = Arc::new(State {
117 buffer: Mutex::new(Vec::new()),
118 live_tx,
119 done: Notify::new(),
120 persist: cfg.persist,
121 served: Arc::clone(&served),
122 shutdown: Arc::clone(&shutdown),
123 });
124
125 let state_clone = Arc::clone(&state);
126 let shutdown_clone = Arc::clone(&shutdown);
127
128 tokio::spawn(async move {
130 let app = Router::new()
131 .route("/", get(handle_root))
132 .with_state(Arc::clone(&state_clone));
133
134 axum::serve(listener, app)
135 .with_graceful_shutdown(async move {
136 shutdown_clone.notified().await;
137 })
138 .await
139 .unwrap();
140 });
141
142 ServerHandle { state, addr: bound }
143}
144
145async fn handle_root(
148 axum::extract::State(state): axum::extract::State<Arc<State>>,
149) -> Response<Body> {
150 let (tx, rx) = mpsc::channel::<Result<Bytes, std::convert::Infallible>>(256);
152
153 let state_clone = Arc::clone(&state);
154 tokio::spawn(async move {
155 let buffered: Vec<Bytes> = state_clone.buffer.lock().unwrap().clone();
157 for chunk in buffered {
158 if tx.send(Ok(chunk)).await.is_err() {
159 return;
160 }
161 }
162
163 let mut live_rx = state_clone.live_tx.subscribe();
165 loop {
166 tokio::select! {
167 chunk = live_rx.recv() => {
168 match chunk {
169 Ok(bytes) => {
170 if tx.send(Ok(bytes)).await.is_err() {
171 return;
172 }
173 }
174 Err(broadcast::error::RecvError::Closed) => break,
175 Err(broadcast::error::RecvError::Lagged(_)) => continue,
176 }
177 }
178 _ = state_clone.done.notified() => {
179 while let Ok(bytes) = live_rx.try_recv() {
181 let _ = tx.send(Ok(bytes)).await;
182 }
183 break;
184 }
185 }
186 }
187
188 state_clone.served.notify_one();
191 if !state_clone.persist {
192 state_clone.shutdown.notify_one();
193 }
194 });
195
196 let stream = ReceiverStream::new(rx);
197 let body = Body::from_stream(stream);
198
199 Response::builder()
200 .status(StatusCode::OK)
201 .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
202 .header(header::TRANSFER_ENCODING, "chunked")
203 .header(header::CACHE_CONTROL, "no-cache")
204 .body(body)
205 .unwrap()
206}