// kcl_test_server/lib.rs

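//! Executes KCL programs received over HTTP and responds with PNG snapshots.
//! The server starts a fixed pool of workers, each holding its own engine
//! session, and reuses those sessions for every KCL program it receives.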
3use std::net::SocketAddr;
4use std::sync::Arc;
5use std::sync::atomic::AtomicUsize;
6use std::sync::atomic::Ordering;
7use std::time::Duration;
8
9use hyper::Body;
10use hyper::Error;
11use hyper::Response;
12use hyper::Server;
13use hyper::body::Bytes;
14use hyper::header::CONTENT_TYPE;
15use hyper::service::make_service_fn;
16use hyper::service::service_fn;
17use kcl_lib::ExecState;
18use kcl_lib::ExecutorContext;
19use kcl_lib::Program;
20use kcl_lib::test_server::RequestBody;
21use tokio::sync::mpsc;
22use tokio::sync::oneshot;
23use tokio::task::JoinHandle;
24use tokio::time::sleep;
25
/// Command-line configuration for the KCL test server.
#[derive(Debug)]
pub struct ServerArgs {
    /// Socket address the HTTP server binds to.
    pub listen_on: SocketAddr,
    /// Number of engine connections to open; one worker is started per connection.
    pub num_engine_conns: u8,
    /// Address of the engine to talk to. When `None`, the production engine is used,
    /// which makes a non-`None` value useful for testing against a local engine.
    /// If `$ZOO_HOST` is set, it takes precedence over whatever was passed here.
    pub engine_address: Option<String>,
}
38
39impl ServerArgs {
40    pub fn parse(mut pargs: pico_args::Arguments) -> Result<Self, pico_args::Error> {
41        let mut args = ServerArgs {
42            listen_on: pargs
43                .opt_value_from_str("--listen-on")?
44                .unwrap_or("0.0.0.0:3333".parse().unwrap()),
45            num_engine_conns: pargs.opt_value_from_str("--num-engine-conns")?.unwrap_or(1),
46            engine_address: pargs.opt_value_from_str("--engine-address")?,
47        };
48        if let Ok(addr) = std::env::var("ZOO_HOST") {
49            println!("Overriding engine address via $ZOO_HOST");
50            args.engine_address = Some(addr);
51        }
52        println!("Config is {args:?}");
53        Ok(args)
54    }
55}
56
57/// Sent from the server to each worker.
58struct WorkerReq {
59    /// A KCL program, in UTF-8.
60    body: Bytes,
61    /// A channel to send the HTTP response back.
62    resp: oneshot::Sender<Response<Body>>,
63}
64
65/// Each worker has a connection to the engine, and accepts
66/// KCL programs. When it receives one (over the mpsc channel)
67/// it executes it and returns the result via a oneshot channel.
68fn start_worker(i: u8, engine_addr: Option<String>) -> mpsc::Sender<WorkerReq> {
69    println!("Starting worker {i}");
70    // Make a work queue for this worker.
71    let (tx, mut rx) = mpsc::channel(1);
72    tokio::task::spawn(async move {
73        let state = ExecutorContext::new_for_unit_test(engine_addr).await.unwrap();
74        println!("Worker {i} ready");
75        while let Some(req) = rx.recv().await {
76            let req: WorkerReq = req;
77            let resp = snapshot_endpoint(req.body, state.clone()).await;
78            if req.resp.send(resp).is_err() {
79                println!("\tWorker {i} exiting");
80            }
81        }
82        println!("\tWorker {i} exiting");
83    });
84    tx
85}
86
87struct ServerState {
88    workers: Vec<mpsc::Sender<WorkerReq>>,
89    req_num: AtomicUsize,
90}
91
92pub async fn start_server(args: ServerArgs) -> anyhow::Result<()> {
93    let ServerArgs {
94        listen_on,
95        num_engine_conns,
96        engine_address,
97    } = args;
98    let workers: Vec<_> = (0..num_engine_conns)
99        .map(|i| start_worker(i, engine_address.clone()))
100        .collect();
101    let state = Arc::new(ServerState {
102        workers,
103        req_num: 0.into(),
104    });
105    // In hyper, a `MakeService` is basically your server.
106    // It makes a `Service` for each connection, which manages the connection.
107    let make_service = make_service_fn(
108        // This closure is run for each connection.
109        move |_conn_info| {
110            // eprintln!("Connected to a client");
111            let state = state.clone();
112            async move {
113                // This is the `Service` which handles the connection.
114                // `service_fn` converts a function which returns a Response
115                // into a `Service`.
116                Ok::<_, Error>(service_fn(move |req| {
117                    // eprintln!("Received a request");
118                    let state = state.clone();
119                    async move { handle_request(req, state).await }
120                }))
121            }
122        },
123    );
124    let server = Server::bind(&listen_on).serve(make_service);
125    println!("Listening on {listen_on}");
126    println!("PID is {}", std::process::id());
127    if let Err(e) = server.await {
128        eprintln!("Server error: {e}");
129        return Err(e.into());
130    }
131    Ok(())
132}
133
134async fn handle_request(req: hyper::Request<Body>, state3: Arc<ServerState>) -> Result<Response<Body>, Error> {
135    let body = hyper::body::to_bytes(req.into_body()).await?;
136
137    // Round robin requests between each available worker.
138    let req_num = state3.req_num.fetch_add(1, Ordering::Relaxed);
139    let worker_id = req_num % state3.workers.len();
140    // println!("Sending request {req_num} to worker {worker_id}");
141    let worker = state3.workers[worker_id].clone();
142    let (tx, rx) = oneshot::channel();
143    let req_sent = worker.send(WorkerReq { body, resp: tx }).await;
144    req_sent.unwrap();
145    let resp = rx.await.unwrap();
146    Ok(resp)
147}
148
149/// Execute a KCL program, then respond with a PNG snapshot.
150/// KCL errors (from engine or the executor) respond with HTTP Bad Gateway.
151/// Malformed requests are HTTP Bad Request.
152/// Successful requests contain a PNG as the body.
153async fn snapshot_endpoint(body: Bytes, ctxt: ExecutorContext) -> Response<Body> {
154    let body = match serde_json::from_slice::<RequestBody>(body.as_ref()) {
155        Ok(bd) => bd,
156        Err(e) => return bad_request(format!("Invalid request JSON: {e}")),
157    };
158    let RequestBody { kcl_program, test_name } = body;
159
160    let program = match Program::parse_no_errs(&kcl_program) {
161        Ok(pr) => pr,
162        Err(e) => return bad_request(format!("Parse error: {e}")),
163    };
164
165    eprintln!("Executing {test_name}");
166    let mut exec_state = ExecState::new(&ctxt);
167    // This is a shitty source range, I don't know what else to use for it though.
168    // There's no actual KCL associated with this reset_scene call.
169    if let Err(e) = ctxt
170        .send_clear_scene(&mut exec_state, kcl_lib::SourceRange::default())
171        .await
172    {
173        return kcl_err(e);
174    }
175    // Let users know if the test is taking a long time.
176    let (done_tx, done_rx) = oneshot::channel::<()>();
177    let timer = time_until(done_rx);
178    if let Err(e) = ctxt.run(&program, &mut exec_state).await {
179        return kcl_err(e);
180    }
181    let snapshot = match ctxt.prepare_snapshot().await {
182        Ok(s) => s,
183        Err(e) => return kcl_err(e),
184    };
185    let _ = done_tx.send(());
186    timer.abort();
187    eprintln!("\tServing response");
188    let png_bytes = snapshot.contents.0;
189    let mut resp = Response::new(Body::from(png_bytes));
190    resp.headers_mut().insert(CONTENT_TYPE, "image/png".parse().unwrap());
191    resp
192}
193
194fn bad_request(msg: String) -> Response<Body> {
195    eprintln!("\tBad request");
196    let mut resp = Response::new(Body::from(msg));
197    *resp.status_mut() = hyper::StatusCode::BAD_REQUEST;
198    resp
199}
200
201fn bad_gateway(msg: String) -> Response<Body> {
202    eprintln!("\tBad gateway");
203    let mut resp = Response::new(Body::from(msg));
204    *resp.status_mut() = hyper::StatusCode::BAD_GATEWAY;
205    resp
206}
207
208fn kcl_err(err: impl std::fmt::Display) -> Response<Body> {
209    eprintln!("\tBad KCL");
210    bad_gateway(format!("{err}"))
211}
212
213fn time_until(done: oneshot::Receiver<()>) -> JoinHandle<()> {
214    tokio::task::spawn(async move {
215        let period = 10;
216        tokio::pin!(done);
217        for i in 1..=3 {
218            tokio::select! {
219                biased;
220                // If the test is done, no need for this timer anymore.
221                _ = &mut done => return,
222                _ = sleep(Duration::from_secs(period)) => {
223                    eprintln!("\tTest has taken {}s", period * i);
224                },
225            };
226        }
227    })
228}