1use std::{
4    net::SocketAddr,
5    sync::{
6        atomic::{AtomicUsize, Ordering},
7        Arc,
8    },
9    time::Duration,
10};
11
12use hyper::{
13    body::Bytes,
14    header::CONTENT_TYPE,
15    service::{make_service_fn, service_fn},
16    Body, Error, Response, Server,
17};
18use kcl_lib::{test_server::RequestBody, ExecState, ExecutorContext, Program};
19use tokio::{
20    sync::{mpsc, oneshot},
21    task::JoinHandle,
22    time::sleep,
23};
24
/// Configuration for the snapshot test server, normally built by `ServerArgs::parse`.
#[derive(Debug)]
pub struct ServerArgs {
    /// Socket address the HTTP server binds to.
    pub listen_on: SocketAddr,
    /// How many worker tasks to start; each worker builds its own `ExecutorContext`.
    pub num_engine_conns: u8,
    /// Engine address handed to every worker's `ExecutorContext`; `None` when not given.
    pub engine_address: Option<String>,
}
37
38impl ServerArgs {
39    pub fn parse(mut pargs: pico_args::Arguments) -> Result<Self, pico_args::Error> {
40        let mut args = ServerArgs {
41            listen_on: pargs
42                .opt_value_from_str("--listen-on")?
43                .unwrap_or("0.0.0.0:3333".parse().unwrap()),
44            num_engine_conns: pargs.opt_value_from_str("--num-engine-conns")?.unwrap_or(1),
45            engine_address: pargs.opt_value_from_str("--engine-address")?,
46        };
47        if let Ok(addr) = std::env::var("ZOO_HOST") {
48            println!("Overriding engine address via $ZOO_HOST");
49            args.engine_address = Some(addr);
50        }
51        println!("Config is {args:?}");
52        Ok(args)
53    }
54}
55
/// A single request queued for a worker task.
struct WorkerReq {
    /// Raw HTTP request body; the worker decodes it as JSON `RequestBody`.
    body: Bytes,
    /// Channel on which the worker sends back the finished HTTP response.
    resp: oneshot::Sender<Response<Body>>,
}
63
64fn start_worker(i: u8, engine_addr: Option<String>) -> mpsc::Sender<WorkerReq> {
68    println!("Starting worker {i}");
69    let (tx, mut rx) = mpsc::channel(1);
71    tokio::task::spawn(async move {
72        let state = ExecutorContext::new_for_unit_test(engine_addr).await.unwrap();
73        println!("Worker {i} ready");
74        while let Some(req) = rx.recv().await {
75            let req: WorkerReq = req;
76            let resp = snapshot_endpoint(req.body, state.clone()).await;
77            if req.resp.send(resp).is_err() {
78                println!("\tWorker {i} exiting");
79            }
80        }
81        println!("\tWorker {i} exiting");
82    });
83    tx
84}
85
/// State shared (via `Arc`) by every connection and request handler.
struct ServerState {
    /// Request queues, one per worker task; index `k` is worker `k`.
    workers: Vec<mpsc::Sender<WorkerReq>>,
    /// Incremented once per request; used to pick a worker round-robin.
    req_num: AtomicUsize,
}
90
91pub async fn start_server(args: ServerArgs) -> anyhow::Result<()> {
92    let ServerArgs {
93        listen_on,
94        num_engine_conns,
95        engine_address,
96    } = args;
97    let workers: Vec<_> = (0..num_engine_conns)
98        .map(|i| start_worker(i, engine_address.clone()))
99        .collect();
100    let state = Arc::new(ServerState {
101        workers,
102        req_num: 0.into(),
103    });
104    let make_service = make_service_fn(
107        move |_conn_info| {
109            let state = state.clone();
111            async move {
112                Ok::<_, Error>(service_fn(move |req| {
116                    let state = state.clone();
118                    async move { handle_request(req, state).await }
119                }))
120            }
121        },
122    );
123    let server = Server::bind(&listen_on).serve(make_service);
124    println!("Listening on {listen_on}");
125    println!("PID is {}", std::process::id());
126    if let Err(e) = server.await {
127        eprintln!("Server error: {e}");
128        return Err(e.into());
129    }
130    Ok(())
131}
132
133async fn handle_request(req: hyper::Request<Body>, state3: Arc<ServerState>) -> Result<Response<Body>, Error> {
134    let body = hyper::body::to_bytes(req.into_body()).await?;
135
136    let req_num = state3.req_num.fetch_add(1, Ordering::Relaxed);
138    let worker_id = req_num % state3.workers.len();
139    let worker = state3.workers[worker_id].clone();
141    let (tx, rx) = oneshot::channel();
142    let req_sent = worker.send(WorkerReq { body, resp: tx }).await;
143    req_sent.unwrap();
144    let resp = rx.await.unwrap();
145    Ok(resp)
146}
147
148async fn snapshot_endpoint(body: Bytes, ctxt: ExecutorContext) -> Response<Body> {
153    let body = match serde_json::from_slice::<RequestBody>(body.as_ref()) {
154        Ok(bd) => bd,
155        Err(e) => return bad_request(format!("Invalid request JSON: {e}")),
156    };
157    let RequestBody { kcl_program, test_name } = body;
158
159    let program = match Program::parse_no_errs(&kcl_program) {
160        Ok(pr) => pr,
161        Err(e) => return bad_request(format!("Parse error: {e}")),
162    };
163
164    eprintln!("Executing {test_name}");
165    let mut exec_state = ExecState::new(&ctxt);
166    if let Err(e) = ctxt
169        .send_clear_scene(&mut exec_state, kcl_lib::SourceRange::default())
170        .await
171    {
172        return kcl_err(e);
173    }
174    let (done_tx, done_rx) = oneshot::channel::<()>();
176    let timer = time_until(done_rx);
177    if let Err(e) = ctxt.run(&program, &mut exec_state).await {
178        return kcl_err(e);
179    }
180    let snapshot = match ctxt.prepare_snapshot().await {
181        Ok(s) => s,
182        Err(e) => return kcl_err(e),
183    };
184    let _ = done_tx.send(());
185    timer.abort();
186    eprintln!("\tServing response");
187    let png_bytes = snapshot.contents.0;
188    let mut resp = Response::new(Body::from(png_bytes));
189    resp.headers_mut().insert(CONTENT_TYPE, "image/png".parse().unwrap());
190    resp
191}
192
193fn bad_request(msg: String) -> Response<Body> {
194    eprintln!("\tBad request");
195    let mut resp = Response::new(Body::from(msg));
196    *resp.status_mut() = hyper::StatusCode::BAD_REQUEST;
197    resp
198}
199
200fn bad_gateway(msg: String) -> Response<Body> {
201    eprintln!("\tBad gateway");
202    let mut resp = Response::new(Body::from(msg));
203    *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
204    resp
205}
206
207fn kcl_err(err: impl std::fmt::Display) -> Response<Body> {
208    eprintln!("\tBad KCL");
209    bad_gateway(format!("{err}"))
210}
211
212fn time_until(done: oneshot::Receiver<()>) -> JoinHandle<()> {
213    tokio::task::spawn(async move {
214        let period = 10;
215        tokio::pin!(done);
216        for i in 1..=3 {
217            tokio::select! {
218                biased;
219                _ = &mut done => return,
221                _ = sleep(Duration::from_secs(period)) => {
222                    eprintln!("\tTest has taken {}s", period * i);
223                },
224            };
225        }
226    })
227}