Skip to main content

kcl_lib/engine/
conn.rs

1//! Functions for setting up our WebSocket and WebRTC connections for communications with the
2//! engine.
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use anyhow::Result;
8use anyhow::anyhow;
9use futures::SinkExt;
10use futures::StreamExt;
11use indexmap::IndexMap;
12use kcmc::ModelingCmd;
13use kcmc::websocket::BatchResponse;
14use kcmc::websocket::FailureWebSocketResponse;
15use kcmc::websocket::ModelingCmdReq;
16use kcmc::websocket::ModelingSessionData;
17use kcmc::websocket::OkWebSocketResponseData;
18use kcmc::websocket::SuccessWebSocketResponse;
19use kcmc::websocket::WebSocketRequest;
20use kcmc::websocket::WebSocketResponse;
21use kittycad_modeling_cmds::{self as kcmc};
22use tokio::sync::RwLock;
23use tokio::sync::mpsc;
24use tokio::sync::oneshot;
25use tokio_tungstenite::tungstenite::Message as WsMsg;
26use uuid::Uuid;
27
28use crate::SourceRange;
29use crate::engine::AsyncTasks;
30use crate::engine::EngineManager;
31use crate::engine::EngineStats;
32use crate::errors::KclError;
33use crate::errors::KclErrorDetails;
34use crate::execution::DefaultPlanes;
35use crate::execution::IdGenerator;
36use crate::log::logln;
37
/// Whether the WebSocket connection to the engine can still be read from.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// The initial state of a new connection.
    Active,
    /// Reading from the WebSocket has failed; no more responses will arrive.
    Inactive,
}
43
/// Write half of the WebSocket connection to the engine; owned by the write actor.
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A connection to the engine over a WebSocket.
///
/// Outgoing requests are handed to a write actor task through `engine_req_tx`; a separate
/// read task decodes the engine's messages and records responses in `responses`.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of requests for the write actor to send to the engine.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Tells the write actor to close the WebSocket.
    shutdown_tx: mpsc::Sender<()>,
    /// Responses from the engine, keyed by request id.
    responses: ResponseInformation,
    /// De-duplicated messages from engine failures that carried no request id.
    pending_errors: Arc<RwLock<Vec<String>>>,
    // Never read directly; held so the read task is aborted once the last reference to the
    // handle is dropped.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `Inactive` by the read task once reading from the WebSocket fails.
    socket_health: Arc<RwLock<SocketHealth>>,
    // NOTE(review): `batch`, `batch_end` and `ids_of_async_commands` are only created here and
    // exposed through `EngineManager` accessors; their semantics live with that trait.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,

    /// The default planes for the scene.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// If the server sends session data, it'll be copied to here.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    stats: EngineStats,

    async_tasks: AsyncTasks,

    /// The most recent `Debug` response received from the engine, if any.
    debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
}
69
/// Read half of the WebSocket connection to the engine.
pub struct TcpRead {
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
73
/// Occurs when client couldn't read from the WebSocket to the engine.
// NOTE(review): `Debug` is not derived for this type (the derive below is commented out);
// confirm whether that is still intentional.
// #[derive(Debug)]
// Presumably allowed because `tungstenite::Error` is much larger than `anyhow::Error`;
// TODO confirm before deciding whether to box the `Read` payload.
#[allow(clippy::large_enum_variant)]
pub enum WebSocketReadError {
    /// Could not read a message due to WebSocket errors.
    ///
    /// `TcpRead::read` only reports tungstenite *protocol* errors this way.
    Read(tokio_tungstenite::tungstenite::Error),
    /// WebSocket message didn't contain a valid message that the KCL Executor could parse.
    ///
    /// `TcpRead::read` also uses this for a finished stream, non-protocol read errors, and
    /// unexpected message types.
    Deser(anyhow::Error),
}
83
84impl From<anyhow::Error> for WebSocketReadError {
85    fn from(e: anyhow::Error) -> Self {
86        Self::Deser(e)
87    }
88}
89
90impl TcpRead {
91    pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
92        let Some(msg) = self.stream.next().await else {
93            return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
94        };
95        let msg = match msg {
96            Ok(msg) => msg,
97            Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
98                return Err(WebSocketReadError::Read(e));
99            }
100            Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
101        };
102        let msg: WebSocketResponse = match msg {
103            WsMsg::Text(text) => serde_json::from_str(&text)
104                .map_err(anyhow::Error::from)
105                .map_err(WebSocketReadError::from)?,
106            WsMsg::Binary(bin) => rmp_serde::from_slice(&bin)
107                .map_err(anyhow::Error::from)
108                .map_err(WebSocketReadError::from)?,
109            other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
110        };
111        Ok(msg)
112    }
113}
114
/// Owns the spawned task that reads engine messages off the WebSocket.
///
/// Dropping this handle aborts that task.
pub struct TcpReadHandle {
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
118
119impl std::fmt::Debug for TcpReadHandle {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(f, "TcpReadHandle")
122    }
123}
124
125impl Drop for TcpReadHandle {
126    fn drop(&mut self) {
127        // Drop the read handle.
128        self.handle.abort();
129    }
130}
131
/// Information about the responses from the engine.
///
/// The map sits behind an `Arc`, so clones of this struct share the same responses.
#[derive(Clone, Debug)]
struct ResponseInformation {
    /// The responses from the engine, keyed by request id.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
}
138
139impl ResponseInformation {
140    pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
141        self.responses.write().await.insert(id, response);
142    }
143}
144
/// A request for the write actor to send to the engine, plus a channel that reports whether
/// the send succeeded. This only tracks the send; the engine's response arrives separately.
struct ToEngineReq {
    /// The request to send
    req: WebSocketRequest,
    /// If this resolves to Ok, the request was sent.
    /// If this resolves to Err, the request could not be sent.
    /// If this has not yet resolved, the request has not been sent yet.
    request_sent: oneshot::Sender<Result<()>>,
}
154
155impl EngineConnection {
156    /// Start waiting for incoming engine requests, and send each one over the WebSocket to the engine.
157    async fn start_write_actor(
158        mut tcp_write: WebSocketTcpWrite,
159        mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
160        mut shutdown_rx: mpsc::Receiver<()>,
161    ) {
162        loop {
163            tokio::select! {
164                maybe_req = engine_req_rx.recv() => {
165                    match maybe_req {
166                        Some(ToEngineReq { req, request_sent }) => {
167                            // Decide whether to send as binary or text,
168                            // then send to the engine.
169                            let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
170                                cmd: ModelingCmd::ImportFiles { .. },
171                                cmd_id: _,
172                            }) = &req
173                            {
174                                Self::inner_send_to_engine_binary(req, &mut tcp_write).await
175                            } else {
176                                Self::inner_send_to_engine(req, &mut tcp_write).await
177                            };
178
179                            // Let the caller know we’ve sent the request (ok or error).
180                            let _ = request_sent.send(res);
181                        }
182                        None => {
183                            // The engine_req_rx channel has closed, so no more requests.
184                            // We'll gracefully exit the loop and close the engine.
185                            break;
186                        }
187                    }
188                },
189
190                // If we get a shutdown signal, close the engine immediately and return.
191                _ = shutdown_rx.recv() => {
192                    let _ = Self::inner_close_engine(&mut tcp_write).await;
193                    return;
194                }
195            }
196        }
197
198        // If we exit the loop (e.g. engine_req_rx was closed),
199        // still gracefully close the engine before returning.
200        let _ = Self::inner_close_engine(&mut tcp_write).await;
201    }
202
203    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
204    async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
205        tcp_write
206            .send(WsMsg::Close(None))
207            .await
208            .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
209        Ok(())
210    }
211
212    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
213    async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
214        let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
215        tcp_write
216            .send(WsMsg::Text(msg.into()))
217            .await
218            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
219        Ok(())
220    }
221
222    /// Send the given `request` to the engine via the WebSocket connection `tcp_write` as binary.
223    async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
224        let msg = rmp_serde::to_vec_named(&request).map_err(|e| anyhow!("could not serialize msgpack: {e}"))?;
225        tcp_write
226            .send(WsMsg::Binary(msg.into()))
227            .await
228            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
229        Ok(())
230    }
231
232    pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
233        let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
234            // 4294967296 bytes, which is around 4.2 GB.
235            .max_message_size(Some(usize::MAX))
236            .max_frame_size(Some(usize::MAX));
237
238        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
239            ws,
240            tokio_tungstenite::tungstenite::protocol::Role::Client,
241            Some(wsconfig),
242        )
243        .await;
244
245        let (tcp_write, tcp_read) = ws_stream.split();
246        let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
247        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
248        tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
249
250        let mut tcp_read = TcpRead { stream: tcp_read };
251
252        let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
253        let session_data2 = session_data.clone();
254        let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
255        let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
256        let pending_errors = Arc::new(RwLock::new(Vec::new()));
257        let pending_errors_clone = pending_errors.clone();
258        let response_information = ResponseInformation {
259            responses: Arc::new(RwLock::new(IndexMap::new())),
260        };
261        let response_information_cloned = response_information.clone();
262        let debug_info = Arc::new(RwLock::new(None));
263        let debug_info_cloned = debug_info.clone();
264
265        let socket_health_tcp_read = socket_health.clone();
266        let tcp_read_handle = tokio::spawn(async move {
267            // Get Websocket messages from API server
268            loop {
269                match tcp_read.read().await {
270                    Ok(ws_resp) => {
271                        // If we got a batch response, add all the inner responses.
272                        let id = ws_resp.request_id();
273                        match &ws_resp {
274                            WebSocketResponse::Success(SuccessWebSocketResponse {
275                                resp: OkWebSocketResponseData::ModelingBatch { responses },
276                                ..
277                            }) => {
278                                #[expect(
279                                    clippy::iter_over_hash_type,
280                                    reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
281                                )]
282                                for (resp_id, batch_response) in responses {
283                                    let id: uuid::Uuid = (*resp_id).into();
284                                    match batch_response {
285                                        BatchResponse::Success { response } => {
286                                            // If the id is in our ids of async commands, remove
287                                            // it.
288                                            response_information_cloned
289                                                .add(
290                                                    id,
291                                                    WebSocketResponse::Success(SuccessWebSocketResponse {
292                                                        success: true,
293                                                        request_id: Some(id),
294                                                        resp: OkWebSocketResponseData::Modeling {
295                                                            modeling_response: response.clone(),
296                                                        },
297                                                    }),
298                                                )
299                                                .await;
300                                        }
301                                        BatchResponse::Failure { errors } => {
302                                            response_information_cloned
303                                                .add(
304                                                    id,
305                                                    WebSocketResponse::Failure(FailureWebSocketResponse {
306                                                        success: false,
307                                                        request_id: Some(id),
308                                                        errors: errors.clone(),
309                                                    }),
310                                                )
311                                                .await;
312                                        }
313                                    }
314                                }
315                            }
316                            WebSocketResponse::Success(SuccessWebSocketResponse {
317                                resp: OkWebSocketResponseData::ModelingSessionData { session },
318                                ..
319                            }) => {
320                                let mut sd = session_data2.write().await;
321                                sd.replace(session.clone());
322                                logln!("API Call ID: {}", session.api_call_id);
323                            }
324                            WebSocketResponse::Failure(FailureWebSocketResponse {
325                                success: _,
326                                request_id,
327                                errors,
328                            }) => {
329                                if let Some(id) = request_id {
330                                    response_information_cloned
331                                        .add(
332                                            *id,
333                                            WebSocketResponse::Failure(FailureWebSocketResponse {
334                                                success: false,
335                                                request_id: *request_id,
336                                                errors: errors.clone(),
337                                            }),
338                                        )
339                                        .await;
340                                } else {
341                                    // Add it to our pending errors.
342                                    let mut pe = pending_errors_clone.write().await;
343                                    for error in errors {
344                                        if !pe.contains(&error.message) {
345                                            pe.push(error.message.clone());
346                                        }
347                                    }
348                                    drop(pe);
349                                }
350                            }
351                            WebSocketResponse::Success(SuccessWebSocketResponse {
352                                resp: debug @ OkWebSocketResponseData::Debug { .. },
353                                ..
354                            }) => {
355                                let mut handle = debug_info_cloned.write().await;
356                                *handle = Some(debug.clone());
357                            }
358                            _ => {}
359                        }
360
361                        if let Some(id) = id {
362                            response_information_cloned.add(id, ws_resp.clone()).await;
363                        }
364                    }
365                    Err(e) => {
366                        match &e {
367                            WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
368                            WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
369                        }
370                        *socket_health_tcp_read.write().await = SocketHealth::Inactive;
371                        return Err(e);
372                    }
373                }
374            }
375        });
376
377        Ok(EngineConnection {
378            engine_req_tx,
379            shutdown_tx,
380            tcp_read_handle: Arc::new(TcpReadHandle {
381                handle: Arc::new(tcp_read_handle),
382            }),
383            responses: response_information,
384            pending_errors,
385            socket_health,
386            batch: Arc::new(RwLock::new(Vec::new())),
387            batch_end: Arc::new(RwLock::new(IndexMap::new())),
388            ids_of_async_commands,
389            default_planes: Default::default(),
390            session_data,
391            stats: Default::default(),
392            async_tasks: AsyncTasks::new(),
393            debug_info,
394        })
395    }
396}
397
398#[async_trait::async_trait]
399impl EngineManager for EngineConnection {
400    fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
401        self.batch.clone()
402    }
403
404    fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
405        self.batch_end.clone()
406    }
407
408    fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
409        self.responses.responses.clone()
410    }
411
412    fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
413        self.ids_of_async_commands.clone()
414    }
415
416    fn async_tasks(&self) -> AsyncTasks {
417        self.async_tasks.clone()
418    }
419
420    fn stats(&self) -> &EngineStats {
421        &self.stats
422    }
423
424    fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
425        self.default_planes.clone()
426    }
427
428    async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
429        self.debug_info.read().await.clone()
430    }
431
432    async fn fetch_debug(&self) -> Result<(), KclError> {
433        let (tx, rx) = oneshot::channel();
434
435        self.engine_req_tx
436            .send(ToEngineReq {
437                req: WebSocketRequest::Debug {},
438                request_sent: tx,
439            })
440            .await
441            .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
442
443        let _ = rx.await;
444        Ok(())
445    }
446
447    async fn clear_scene_post_hook(
448        &self,
449        id_generator: &mut IdGenerator,
450        source_range: SourceRange,
451    ) -> Result<(), KclError> {
452        // Remake the default planes, since they would have been removed after the scene was cleared.
453        let new_planes = self.new_default_planes(id_generator, source_range).await?;
454        *self.default_planes.write().await = Some(new_planes);
455
456        Ok(())
457    }
458
459    async fn inner_fire_modeling_cmd(
460        &self,
461        _id: uuid::Uuid,
462        source_range: SourceRange,
463        cmd: WebSocketRequest,
464        _id_to_source_range: HashMap<Uuid, SourceRange>,
465    ) -> Result<(), KclError> {
466        let (tx, rx) = oneshot::channel();
467
468        // Send the request to the engine, via the actor.
469        self.engine_req_tx
470            .send(ToEngineReq {
471                req: cmd.clone(),
472                request_sent: tx,
473            })
474            .await
475            .map_err(|e| {
476                KclError::new_engine(KclErrorDetails::new(
477                    format!("Failed to send modeling command: {e}"),
478                    vec![source_range],
479                ))
480            })?;
481
482        // Wait for the request to be sent.
483        rx.await
484            .map_err(|e| {
485                KclError::new_engine(KclErrorDetails::new(
486                    format!("could not send request to the engine actor: {e}"),
487                    vec![source_range],
488                ))
489            })?
490            .map_err(|e| {
491                KclError::new_engine(KclErrorDetails::new(
492                    format!("could not send request to the engine: {e}"),
493                    vec![source_range],
494                ))
495            })?;
496
497        Ok(())
498    }
499
500    async fn inner_send_modeling_cmd(
501        &self,
502        id: uuid::Uuid,
503        source_range: SourceRange,
504        cmd: WebSocketRequest,
505        id_to_source_range: HashMap<Uuid, SourceRange>,
506    ) -> Result<WebSocketResponse, KclError> {
507        self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
508            .await?;
509
510        // Wait for the response.
511        let response_timeout = 300;
512        let current_time = std::time::Instant::now();
513        while current_time.elapsed().as_secs() < response_timeout {
514            let guard = self.socket_health.read().await;
515            if *guard == SocketHealth::Inactive {
516                // Check if we have any pending errors.
517                let pe = self.pending_errors.read().await;
518                if !pe.is_empty() {
519                    return Err(KclError::new_engine(KclErrorDetails::new(
520                        pe.join(", "),
521                        vec![source_range],
522                    )));
523                } else {
524                    return Err(KclError::new_engine_hangup(KclErrorDetails::new(
525                        "Modeling command failed: websocket closed early".to_string(),
526                        vec![source_range],
527                    )));
528                }
529            }
530
531            #[cfg(feature = "artifact-graph")]
532            {
533                // We cannot pop here or it will break the artifact graph.
534                if let Some(resp) = self.responses.responses.read().await.get(&id) {
535                    return Ok(resp.clone());
536                }
537            }
538            #[cfg(not(feature = "artifact-graph"))]
539            {
540                if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
541                    return Ok(resp);
542                }
543            }
544        }
545
546        Err(KclError::new_engine(KclErrorDetails::new(
547            format!("Modeling command timed out `{id}`"),
548            vec![source_range],
549        )))
550    }
551
552    async fn get_session_data(&self) -> Option<ModelingSessionData> {
553        self.session_data.read().await.clone()
554    }
555
556    async fn close(&self) {
557        let _ = self.shutdown_tx.send(()).await;
558        loop {
559            let guard = self.socket_health.read().await;
560            if *guard == SocketHealth::Inactive {
561                return;
562            }
563        }
564    }
565}