1use axum::{
7 body::Body,
8 http::{header, Response, StatusCode},
9 routing::get,
10 Router,
11};
12use bytes::Bytes;
13use std::net::SocketAddr;
14use std::sync::{Arc, Mutex};
15use tokio::net::TcpListener;
16use tokio::sync::{broadcast, mpsc, Notify};
17use tokio_stream::wrappers::ReceiverStream;
18
/// Configuration accepted by [`serve`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// Host to bind the listener to. Defaults to `"127.0.0.1"`.
    pub host: String,
    /// Port to bind to; `0` (the default) lets the OS pick a free port.
    pub port: u16,
    /// Keep serving after the first completed response. Defaults to `false`.
    pub persist: bool,
}

impl Default for ServerConfig {
    /// Loopback host (`127.0.0.1`), OS-assigned port, non-persistent server.
    fn default() -> Self {
        Self {
            host: "127.0.0.1".into(),
            port: 0,
            persist: false,
        }
    }
}
36
37struct State {
41 buffer: Mutex<Vec<Bytes>>,
43 live_tx: broadcast::Sender<Bytes>,
45 done: Notify,
47 persist: bool,
48 served: Arc<Notify>,
50 shutdown: Arc<Notify>,
52}
53
54pub struct ServerHandle {
58 state: Arc<State>,
59 pub addr: SocketAddr,
61}
62
63impl ServerHandle {
64 pub fn send(&self, chunk: Bytes) {
66 let mut buf = self.state.buffer.lock().unwrap();
67 buf.push(chunk.clone());
68 drop(buf);
69 let _ = self.state.live_tx.send(chunk);
70 }
71
72 pub fn finish(&self) {
78 self.state.done.notify_one();
81 }
84
85 pub async fn wait_served(&self) {
88 if !self.state.persist {
89 self.state.served.notified().await;
90 }
91 }
92}
93
94pub async fn serve<F>(cfg: ServerConfig, on_bind: F) -> ServerHandle
98where
99 F: FnOnce(SocketAddr),
100{
101 let addr: SocketAddr = format!("{}:{}", cfg.host, cfg.port)
102 .parse()
103 .expect("invalid bind address");
104
105 let listener = TcpListener::bind(addr)
106 .await
107 .expect("failed to bind server");
108
109 let bound = listener.local_addr().expect("no local addr");
110 on_bind(bound);
111
112 let (live_tx, _) = broadcast::channel::<Bytes>(256);
113 let shutdown = Arc::new(Notify::new());
114 let served = Arc::new(Notify::new());
115
116 let state = Arc::new(State {
117 buffer: Mutex::new(Vec::new()),
118 live_tx,
119 done: Notify::new(),
120 persist: cfg.persist,
121 served: Arc::clone(&served),
122 shutdown: Arc::clone(&shutdown),
123 });
124
125 let state_clone = Arc::clone(&state);
126 let shutdown_clone = Arc::clone(&shutdown);
127
128 tokio::spawn(async move {
130 let app = Router::new()
131 .route("/", get(handle_root))
132 .with_state(Arc::clone(&state_clone));
133
134 axum::serve(listener, app)
135 .with_graceful_shutdown(async move {
136 shutdown_clone.notified().await;
137 })
138 .await
139 .unwrap();
140 });
141
142 ServerHandle {
143 state,
144 addr: bound,
145 }
146}
147
148async fn handle_root(
151 axum::extract::State(state): axum::extract::State<Arc<State>>,
152) -> Response<Body> {
153 let (tx, rx) = mpsc::channel::<Result<Bytes, std::convert::Infallible>>(256);
155
156 let state_clone = Arc::clone(&state);
157 tokio::spawn(async move {
158 let buffered: Vec<Bytes> = state_clone.buffer.lock().unwrap().clone();
160 for chunk in buffered {
161 if tx.send(Ok(chunk)).await.is_err() {
162 return;
163 }
164 }
165
166 let mut live_rx = state_clone.live_tx.subscribe();
168 loop {
169 tokio::select! {
170 chunk = live_rx.recv() => {
171 match chunk {
172 Ok(bytes) => {
173 if tx.send(Ok(bytes)).await.is_err() {
174 return;
175 }
176 }
177 Err(broadcast::error::RecvError::Closed) => break,
178 Err(broadcast::error::RecvError::Lagged(_)) => continue,
179 }
180 }
181 _ = state_clone.done.notified() => {
182 while let Ok(bytes) = live_rx.try_recv() {
184 let _ = tx.send(Ok(bytes)).await;
185 }
186 break;
187 }
188 }
189 }
190
191 state_clone.served.notify_one();
194 if !state_clone.persist {
195 state_clone.shutdown.notify_one();
196 }
197 });
198
199 let stream = ReceiverStream::new(rx);
200 let body = Body::from_stream(stream);
201
202 Response::builder()
203 .status(StatusCode::OK)
204 .header(header::CONTENT_TYPE, "text/html; charset=utf-8")
205 .header(header::TRANSFER_ENCODING, "chunked")
206 .header(header::CACHE_CONTROL, "no-cache")
207 .body(body)
208 .unwrap()
209}