1use std::{
4 net::SocketAddr,
5 sync::{
6 atomic::{AtomicUsize, Ordering},
7 Arc,
8 },
9 time::Duration,
10};
11
12use hyper::{
13 body::Bytes,
14 header::CONTENT_TYPE,
15 service::{make_service_fn, service_fn},
16 Body, Error, Response, Server,
17};
18use kcl_lib::{test_server::RequestBody, ExecState, ExecutorContext, Program, UnitLength};
19use tokio::{
20 sync::{mpsc, oneshot},
21 task::JoinHandle,
22 time::sleep,
23};
24
/// Configuration for the test server, normally built by [`ServerArgs::parse`].
#[derive(Debug)]
pub struct ServerArgs {
    /// Socket address the HTTP server binds to.
    pub listen_on: SocketAddr,
    /// Number of workers to start; every worker creates its own executor context.
    pub num_engine_conns: u8,
    /// Engine address handed to each worker's executor context, if one was given.
    pub engine_address: Option<String>,
}
37
38impl ServerArgs {
39 pub fn parse(mut pargs: pico_args::Arguments) -> Result<Self, pico_args::Error> {
40 let mut args = ServerArgs {
41 listen_on: pargs
42 .opt_value_from_str("--listen-on")?
43 .unwrap_or("0.0.0.0:3333".parse().unwrap()),
44 num_engine_conns: pargs.opt_value_from_str("--num-engine-conns")?.unwrap_or(1),
45 engine_address: pargs.opt_value_from_str("--engine-address")?,
46 };
47 if let Ok(addr) = std::env::var("ZOO_HOST") {
48 println!("Overriding engine address via $ZOO_HOST");
49 args.engine_address = Some(addr);
50 }
51 println!("Config is {args:?}");
52 Ok(args)
53 }
54}
55
56struct WorkerReq {
58 body: Bytes,
60 resp: oneshot::Sender<Response<Body>>,
62}
63
64fn start_worker(i: u8, engine_addr: Option<String>) -> mpsc::Sender<WorkerReq> {
68 println!("Starting worker {i}");
69 let (tx, mut rx) = mpsc::channel(1);
71 tokio::task::spawn(async move {
72 let state = ExecutorContext::new_for_unit_test(UnitLength::Mm, engine_addr)
73 .await
74 .unwrap();
75 println!("Worker {i} ready");
76 while let Some(req) = rx.recv().await {
77 let req: WorkerReq = req;
78 let resp = snapshot_endpoint(req.body, state.clone()).await;
79 if req.resp.send(resp).is_err() {
80 println!("\tWorker {i} exiting");
81 }
82 }
83 println!("\tWorker {i} exiting");
84 });
85 tx
86}
87
88struct ServerState {
89 workers: Vec<mpsc::Sender<WorkerReq>>,
90 req_num: AtomicUsize,
91}
92
93pub async fn start_server(args: ServerArgs) -> anyhow::Result<()> {
94 let ServerArgs {
95 listen_on,
96 num_engine_conns,
97 engine_address,
98 } = args;
99 let workers: Vec<_> = (0..num_engine_conns)
100 .map(|i| start_worker(i, engine_address.clone()))
101 .collect();
102 let state = Arc::new(ServerState {
103 workers,
104 req_num: 0.into(),
105 });
106 let make_service = make_service_fn(
109 move |_conn_info| {
111 let state = state.clone();
113 async move {
114 Ok::<_, Error>(service_fn(move |req| {
118 let state = state.clone();
120 async move { handle_request(req, state).await }
121 }))
122 }
123 },
124 );
125 let server = Server::bind(&listen_on).serve(make_service);
126 println!("Listening on {listen_on}");
127 println!("PID is {}", std::process::id());
128 if let Err(e) = server.await {
129 eprintln!("Server error: {e}");
130 return Err(e.into());
131 }
132 Ok(())
133}
134
135async fn handle_request(req: hyper::Request<Body>, state3: Arc<ServerState>) -> Result<Response<Body>, Error> {
136 let body = hyper::body::to_bytes(req.into_body()).await?;
137
138 let req_num = state3.req_num.fetch_add(1, Ordering::Relaxed);
140 let worker_id = req_num % state3.workers.len();
141 let worker = state3.workers[worker_id].clone();
143 let (tx, rx) = oneshot::channel();
144 let req_sent = worker.send(WorkerReq { body, resp: tx }).await;
145 req_sent.unwrap();
146 let resp = rx.await.unwrap();
147 Ok(resp)
148}
149
150async fn snapshot_endpoint(body: Bytes, ctxt: ExecutorContext) -> Response<Body> {
155 let body = match serde_json::from_slice::<RequestBody>(body.as_ref()) {
156 Ok(bd) => bd,
157 Err(e) => return bad_request(format!("Invalid request JSON: {e}")),
158 };
159 let RequestBody { kcl_program, test_name } = body;
160
161 let program = match Program::parse_no_errs(&kcl_program) {
162 Ok(pr) => pr,
163 Err(e) => return bad_request(format!("Parse error: {e}")),
164 };
165
166 eprintln!("Executing {test_name}");
167 let mut exec_state = ExecState::new(&ctxt);
168 if let Err(e) = ctxt
171 .send_clear_scene(&mut exec_state, kcl_lib::SourceRange::default())
172 .await
173 {
174 return kcl_err(e);
175 }
176 let (done_tx, done_rx) = oneshot::channel::<()>();
178 let timer = time_until(done_rx);
179 if let Err(e) = ctxt.run(&program, &mut exec_state).await {
180 return kcl_err(e);
181 }
182 let snapshot = match ctxt.prepare_snapshot().await {
183 Ok(s) => s,
184 Err(e) => return kcl_err(e),
185 };
186 let _ = done_tx.send(());
187 timer.abort();
188 eprintln!("\tServing response");
189 let png_bytes = snapshot.contents.0;
190 let mut resp = Response::new(Body::from(png_bytes));
191 resp.headers_mut().insert(CONTENT_TYPE, "image/png".parse().unwrap());
192 resp
193}
194
195fn bad_request(msg: String) -> Response<Body> {
196 eprintln!("\tBad request");
197 let mut resp = Response::new(Body::from(msg));
198 *resp.status_mut() = hyper::StatusCode::BAD_REQUEST;
199 resp
200}
201
202fn bad_gateway(msg: String) -> Response<Body> {
203 eprintln!("\tBad gateway");
204 let mut resp = Response::new(Body::from(msg));
205 *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
206 resp
207}
208
209fn kcl_err(err: impl std::fmt::Display) -> Response<Body> {
210 eprintln!("\tBad KCL");
211 bad_gateway(format!("{err}"))
212}
213
214fn time_until(done: oneshot::Receiver<()>) -> JoinHandle<()> {
215 tokio::task::spawn(async move {
216 let period = 10;
217 tokio::pin!(done);
218 for i in 1..=3 {
219 tokio::select! {
220 biased;
221 _ = &mut done => return,
223 _ = sleep(Duration::from_secs(period)) => {
224 eprintln!("\tTest has taken {}s", period * i);
225 },
226 };
227 }
228 })
229}