1use std::{
4 net::SocketAddr,
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc,
8 },
9 time::Duration,
10};
11
12use hyper::{
13 body::Bytes,
14 header::CONTENT_TYPE,
15 service::{make_service_fn, service_fn},
16 Body, Error, Response, Server,
17};
18use kcl_lib::{test_server::RequestBody, ExecState, ExecutorContext, Program};
19use tokio::{
20 sync::{mpsc, oneshot},
21 task::JoinHandle,
22 time::sleep,
23};
24
/// Configuration for the KCL snapshot test server, as produced by [`ServerArgs::parse`].
///
/// The field order is significant: the derived `Debug` output is printed at startup.
#[derive(Debug)]
pub struct ServerArgs {
    /// Address the HTTP server listens on (`--listen-on`; `parse` defaults it to `0.0.0.0:3333`).
    pub listen_on: SocketAddr,
    /// How many engine workers to start (`--num-engine-conns`; `parse` defaults it to 1).
    pub num_engine_conns: u8,
    /// Engine address handed to every worker (`--engine-address`).
    /// `parse` overrides it with `$ZOO_HOST` when that variable is set.
    /// NOTE(review): what `None` selects is decided by `ExecutorContext::new_for_unit_test` — confirm.
    pub engine_address: Option<String>,
}
37
38impl ServerArgs {
39 pub fn parse(mut pargs: pico_args::Arguments) -> Result<Self, pico_args::Error> {
40 let mut args = ServerArgs {
41 listen_on: pargs
42 .opt_value_from_str("--listen-on")?
43 .unwrap_or("0.0.0.0:3333".parse().unwrap()),
44 num_engine_conns: pargs.opt_value_from_str("--num-engine-conns")?.unwrap_or(1),
45 engine_address: pargs.opt_value_from_str("--engine-address")?,
46 };
47 if let Ok(addr) = std::env::var("ZOO_HOST") {
48 println!("Overriding engine address via $ZOO_HOST");
49 args.engine_address = Some(addr);
50 }
51 println!("Config is {args:?}");
52 Ok(args)
53 }
54}
55
56struct WorkerReq {
58 body: Bytes,
60 resp: oneshot::Sender<Response<Body>>,
62}
63
64fn start_worker(i: u8, engine_addr: Option<String>) -> mpsc::Sender<WorkerReq> {
68 println!("Starting worker {i}");
69 let (tx, mut rx) = mpsc::channel(1);
71 tokio::task::spawn(async move {
72 let state = ExecutorContext::new_for_unit_test(engine_addr).await.unwrap();
73 println!("Worker {i} ready");
74 while let Some(req) = rx.recv().await {
75 let req: WorkerReq = req;
76 let resp = snapshot_endpoint(req.body, state.clone()).await;
77 if req.resp.send(resp).is_err() {
78 println!("\tWorker {i} exiting");
79 }
80 }
81 println!("\tWorker {i} exiting");
82 });
83 tx
84}
85
86struct ServerState {
87 workers: Vec<mpsc::Sender<WorkerReq>>,
88 req_num: AtomicUsize,
89}
90
91pub async fn start_server(args: ServerArgs) -> anyhow::Result<()> {
92 let ServerArgs {
93 listen_on,
94 num_engine_conns,
95 engine_address,
96 } = args;
97 let workers: Vec<_> = (0..num_engine_conns)
98 .map(|i| start_worker(i, engine_address.clone()))
99 .collect();
100 let state = Arc::new(ServerState {
101 workers,
102 req_num: 0.into(),
103 });
104 let make_service = make_service_fn(
107 move |_conn_info| {
109 let state = state.clone();
111 async move {
112 Ok::<_, Error>(service_fn(move |req| {
116 let state = state.clone();
118 async move { handle_request(req, state).await }
119 }))
120 }
121 },
122 );
123 let server = Server::bind(&listen_on).serve(make_service);
124 println!("Listening on {listen_on}");
125 println!("PID is {}", std::process::id());
126 if let Err(e) = server.await {
127 eprintln!("Server error: {e}");
128 return Err(e.into());
129 }
130 Ok(())
131}
132
133async fn handle_request(req: hyper::Request<Body>, state3: Arc<ServerState>) -> Result<Response<Body>, Error> {
134 let body = hyper::body::to_bytes(req.into_body()).await?;
135
136 let req_num = state3.req_num.fetch_add(1, Ordering::Relaxed);
138 let worker_id = req_num % state3.workers.len();
139 let worker = state3.workers[worker_id].clone();
141 let (tx, rx) = oneshot::channel();
142 let req_sent = worker.send(WorkerReq { body, resp: tx }).await;
143 req_sent.unwrap();
144 let resp = rx.await.unwrap();
145 Ok(resp)
146}
147
148async fn snapshot_endpoint(body: Bytes, ctxt: ExecutorContext) -> Response<Body> {
153 let body = match serde_json::from_slice::<RequestBody>(body.as_ref()) {
154 Ok(bd) => bd,
155 Err(e) => return bad_request(format!("Invalid request JSON: {e}")),
156 };
157 let RequestBody { kcl_program, test_name } = body;
158
159 let program = match Program::parse_no_errs(&kcl_program) {
160 Ok(pr) => pr,
161 Err(e) => return bad_request(format!("Parse error: {e}")),
162 };
163
164 eprintln!("Executing {test_name}");
165 let mut exec_state = ExecState::new(&ctxt);
166 if let Err(e) = ctxt
169 .send_clear_scene(&mut exec_state, kcl_lib::SourceRange::default())
170 .await
171 {
172 return kcl_err(e);
173 }
174 let (done_tx, done_rx) = oneshot::channel::<()>();
176 let timer = time_until(done_rx);
177 if let Err(e) = ctxt.run(&program, &mut exec_state).await {
178 return kcl_err(e);
179 }
180 let snapshot = match ctxt.prepare_snapshot().await {
181 Ok(s) => s,
182 Err(e) => return kcl_err(e),
183 };
184 let _ = done_tx.send(());
185 timer.abort();
186 eprintln!("\tServing response");
187 let png_bytes = snapshot.contents.0;
188 let mut resp = Response::new(Body::from(png_bytes));
189 resp.headers_mut().insert(CONTENT_TYPE, "image/png".parse().unwrap());
190 resp
191}
192
193fn bad_request(msg: String) -> Response<Body> {
194 eprintln!("\tBad request");
195 let mut resp = Response::new(Body::from(msg));
196 *resp.status_mut() = hyper::StatusCode::BAD_REQUEST;
197 resp
198}
199
200fn bad_gateway(msg: String) -> Response<Body> {
201 eprintln!("\tBad gateway");
202 let mut resp = Response::new(Body::from(msg));
203 *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
204 resp
205}
206
207fn kcl_err(err: impl std::fmt::Display) -> Response<Body> {
208 eprintln!("\tBad KCL");
209 bad_gateway(format!("{err}"))
210}
211
212fn time_until(done: oneshot::Receiver<()>) -> JoinHandle<()> {
213 tokio::task::spawn(async move {
214 let period = 10;
215 tokio::pin!(done);
216 for i in 1..=3 {
217 tokio::select! {
218 biased;
219 _ = &mut done => return,
221 _ = sleep(Duration::from_secs(period)) => {
222 eprintln!("\tTest has taken {}s", period * i);
223 },
224 };
225 }
226 })
227}