kcl_lib/engine/
conn.rs

1//! Functions for setting up our WebSocket and WebRTC connections for communications with the
2//! engine.
3
4use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10    websocket::{
11        BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12        SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13    },
14    ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21use super::ExecutionKind;
22use crate::{
23    engine::EngineManager,
24    errors::{KclError, KclErrorDetails},
25    execution::{ArtifactCommand, DefaultPlanes, IdGenerator},
26    SourceRange,
27};
28
/// Whether the reader task still considers the engine WebSocket usable.
///
/// Starts as `Active`; the reader task switches it to `Inactive` once reading from the socket
/// fails, after which no further responses will be recorded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocketHealth {
    /// The reader task is still receiving messages.
    Active,
    /// Reading from the socket failed; the connection should be treated as dead.
    Inactive,
}
34
/// Write half of the engine WebSocket; owned by the write actor spawned in `EngineConnection::new`.
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// Connection to the engine over a WebSocket.
///
/// Outgoing requests are handed to a write-actor task through `engine_req_tx`; a separate
/// reader task records incoming responses in `responses`, keyed by request id.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of requests for the write actor to send to the engine.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Signals the write actor to send a Close frame and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Engine responses keyed by request id. Filled by the reader task; entries are removed
    /// by `inner_send_modeling_cmd` once picked up.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
    /// Deduplicated messages from failure responses that carried no request id. They are
    /// reported when a command finds the socket inactive.
    pending_errors: Arc<RwLock<Vec<String>>>,
    // Not read anywhere in this file: the field exists so that, when the last reference to the
    // handle goes away, `TcpReadHandle`'s `Drop` impl aborts the reader task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `Inactive` by the reader task once reading from the socket fails.
    socket_health: Arc<RwLock<SocketHealth>>,
    /// Exposed through `EngineManager::batch`; presumably requests queued for batched sending
    /// (the semantics live in the trait — TODO confirm).
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::batch_end`; keyed by id, semantics defined by the trait
    /// (TODO confirm).
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::artifact_commands`.
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,

    /// The default planes for the scene. Created lazily by `default_planes()` and recreated by
    /// `clear_scene_post_hook()`.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// If the server sends session data, it'll be copied to here.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    /// Current execution kind, read by `execution_kind()` and swapped by `replace_execution_kind()`.
    execution_kind: Arc<RwLock<ExecutionKind>>,
}
56
/// Read half of the engine WebSocket.
///
/// `EngineConnection::new` moves this into the reader task; the matching write half goes to
/// the write actor.
pub struct TcpRead {
    /// Incoming WebSocket messages from the engine.
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
60
61/// Occurs when client couldn't read from the WebSocket to the engine.
62// #[derive(Debug)]
63pub enum WebSocketReadError {
64    /// Could not read a message due to WebSocket errors.
65    Read(tokio_tungstenite::tungstenite::Error),
66    /// WebSocket message didn't contain a valid message that the KCL Executor could parse.
67    Deser(anyhow::Error),
68}
69
70impl From<anyhow::Error> for WebSocketReadError {
71    fn from(e: anyhow::Error) -> Self {
72        Self::Deser(e)
73    }
74}
75
76impl TcpRead {
77    pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
78        let Some(msg) = self.stream.next().await else {
79            return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
80        };
81        let msg = match msg {
82            Ok(msg) => msg,
83            Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
84                return Err(WebSocketReadError::Read(e))
85            }
86            Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
87        };
88        let msg: WebSocketResponse = match msg {
89            WsMsg::Text(text) => serde_json::from_str(&text)
90                .map_err(anyhow::Error::from)
91                .map_err(WebSocketReadError::from)?,
92            WsMsg::Binary(bin) => bson::from_slice(&bin)
93                .map_err(anyhow::Error::from)
94                .map_err(WebSocketReadError::from)?,
95            other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
96        };
97        Ok(msg)
98    }
99}
100
101pub struct TcpReadHandle {
102    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
103}
104
105impl std::fmt::Debug for TcpReadHandle {
106    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107        write!(f, "TcpReadHandle")
108    }
109}
110
111impl Drop for TcpReadHandle {
112    fn drop(&mut self) {
113        // Drop the read handle.
114        self.handle.abort();
115    }
116}
117
/// A request bound for the engine, plus a channel reporting whether it was sent.
///
/// Consumed by the write actor (`EngineConnection::start_write_actor`), which reports the
/// outcome of the send through `request_sent`. The engine's reply arrives separately, via the
/// reader task.
struct ToEngineReq {
    /// The request to send
    req: WebSocketRequest,
    /// If this resolves to Ok, the request was sent.
    /// If this resolves to Err, the request could not be sent.
    /// If this has not yet resolved, the request has not been sent yet.
    request_sent: oneshot::Sender<Result<()>>,
}
127
128impl EngineConnection {
129    /// Start waiting for incoming engine requests, and send each one over the WebSocket to the engine.
130    async fn start_write_actor(
131        mut tcp_write: WebSocketTcpWrite,
132        mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
133        mut shutdown_rx: mpsc::Receiver<()>,
134    ) {
135        loop {
136            tokio::select! {
137                maybe_req = engine_req_rx.recv() => {
138                    match maybe_req {
139                        Some(ToEngineReq { req, request_sent }) => {
140                            // Decide whether to send as binary or text,
141                            // then send to the engine.
142                            let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
143                                cmd: ModelingCmd::ImportFiles { .. },
144                                cmd_id: _,
145                            }) = &req
146                            {
147                                Self::inner_send_to_engine_binary(req, &mut tcp_write).await
148                            } else {
149                                Self::inner_send_to_engine(req, &mut tcp_write).await
150                            };
151
152                            // Let the caller know we’ve sent the request (ok or error).
153                            let _ = request_sent.send(res);
154                        }
155                        None => {
156                            // The engine_req_rx channel has closed, so no more requests.
157                            // We'll gracefully exit the loop and close the engine.
158                            break;
159                        }
160                    }
161                },
162
163                // If we get a shutdown signal, close the engine immediately and return.
164                _ = shutdown_rx.recv() => {
165                    let _ = Self::inner_close_engine(&mut tcp_write).await;
166                    return;
167                }
168            }
169        }
170
171        // If we exit the loop (e.g. engine_req_rx was closed),
172        // still gracefully close the engine before returning.
173        let _ = Self::inner_close_engine(&mut tcp_write).await;
174    }
175
176    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
177    async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
178        tcp_write
179            .send(WsMsg::Close(None))
180            .await
181            .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
182        Ok(())
183    }
184
185    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
186    async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
187        let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
188        tcp_write
189            .send(WsMsg::Text(msg))
190            .await
191            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
192        Ok(())
193    }
194
195    /// Send the given `request` to the engine via the WebSocket connection `tcp_write` as binary.
196    async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
197        let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
198        tcp_write
199            .send(WsMsg::Binary(msg))
200            .await
201            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
202        Ok(())
203    }
204
205    pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
206        let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
207            // 4294967296 bytes, which is around 4.2 GB.
208            max_message_size: Some(usize::MAX),
209            max_frame_size: Some(usize::MAX),
210            ..Default::default()
211        };
212
213        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
214            ws,
215            tokio_tungstenite::tungstenite::protocol::Role::Client,
216            Some(wsconfig),
217        )
218        .await;
219
220        let (tcp_write, tcp_read) = ws_stream.split();
221        let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
222        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
223        tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
224
225        let mut tcp_read = TcpRead { stream: tcp_read };
226
227        let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
228        let session_data2 = session_data.clone();
229        let responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>> = Arc::new(RwLock::new(IndexMap::new()));
230        let responses_clone = responses.clone();
231        let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
232        let pending_errors = Arc::new(RwLock::new(Vec::new()));
233        let pending_errors_clone = pending_errors.clone();
234
235        let socket_health_tcp_read = socket_health.clone();
236        let tcp_read_handle = tokio::spawn(async move {
237            // Get Websocket messages from API server
238            loop {
239                match tcp_read.read().await {
240                    Ok(ws_resp) => {
241                        // If we got a batch response, add all the inner responses.
242                        let id = ws_resp.request_id();
243                        match &ws_resp {
244                            WebSocketResponse::Success(SuccessWebSocketResponse {
245                                resp: OkWebSocketResponseData::ModelingBatch { responses },
246                                ..
247                            }) =>
248                            {
249                                #[expect(
250                                    clippy::iter_over_hash_type,
251                                    reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
252                                )]
253                                for (resp_id, batch_response) in responses {
254                                    let id: uuid::Uuid = (*resp_id).into();
255                                    match batch_response {
256                                        BatchResponse::Success { response } => {
257                                            responses_clone.write().await.insert(
258                                                id,
259                                                WebSocketResponse::Success(SuccessWebSocketResponse {
260                                                    success: true,
261                                                    request_id: Some(id),
262                                                    resp: OkWebSocketResponseData::Modeling {
263                                                        modeling_response: response.clone(),
264                                                    },
265                                                }),
266                                            );
267                                        }
268                                        BatchResponse::Failure { errors } => {
269                                            responses_clone.write().await.insert(
270                                                id,
271                                                WebSocketResponse::Failure(FailureWebSocketResponse {
272                                                    success: false,
273                                                    request_id: Some(id),
274                                                    errors: errors.clone(),
275                                                }),
276                                            );
277                                        }
278                                    }
279                                }
280                            }
281                            WebSocketResponse::Success(SuccessWebSocketResponse {
282                                resp: OkWebSocketResponseData::ModelingSessionData { session },
283                                ..
284                            }) => {
285                                let mut sd = session_data2.write().await;
286                                sd.replace(session.clone());
287                            }
288                            WebSocketResponse::Failure(FailureWebSocketResponse {
289                                success: _,
290                                request_id,
291                                errors,
292                            }) => {
293                                if let Some(id) = request_id {
294                                    responses_clone.write().await.insert(
295                                        *id,
296                                        WebSocketResponse::Failure(FailureWebSocketResponse {
297                                            success: false,
298                                            request_id: *request_id,
299                                            errors: errors.clone(),
300                                        }),
301                                    );
302                                } else {
303                                    // Add it to our pending errors.
304                                    let mut pe = pending_errors_clone.write().await;
305                                    for error in errors {
306                                        if !pe.contains(&error.message) {
307                                            pe.push(error.message.clone());
308                                        }
309                                    }
310                                    drop(pe);
311                                }
312                            }
313                            _ => {}
314                        }
315
316                        if let Some(id) = id {
317                            responses_clone.write().await.insert(id, ws_resp.clone());
318                        }
319                    }
320                    Err(e) => {
321                        match &e {
322                            WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
323                            WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
324                        }
325                        *socket_health_tcp_read.write().await = SocketHealth::Inactive;
326                        return Err(e);
327                    }
328                }
329            }
330        });
331
332        Ok(EngineConnection {
333            engine_req_tx,
334            shutdown_tx,
335            tcp_read_handle: Arc::new(TcpReadHandle {
336                handle: Arc::new(tcp_read_handle),
337            }),
338            responses,
339            pending_errors,
340            socket_health,
341            batch: Arc::new(RwLock::new(Vec::new())),
342            batch_end: Arc::new(RwLock::new(IndexMap::new())),
343            artifact_commands: Arc::new(RwLock::new(Vec::new())),
344            default_planes: Default::default(),
345            session_data,
346            execution_kind: Default::default(),
347        })
348    }
349}
350
351#[async_trait::async_trait]
352impl EngineManager for EngineConnection {
353    fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
354        self.batch.clone()
355    }
356
357    fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
358        self.batch_end.clone()
359    }
360
361    fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
362        self.responses.clone()
363    }
364
365    fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
366        self.artifact_commands.clone()
367    }
368
369    async fn execution_kind(&self) -> ExecutionKind {
370        let guard = self.execution_kind.read().await;
371        *guard
372    }
373
374    async fn replace_execution_kind(&self, execution_kind: ExecutionKind) -> ExecutionKind {
375        let mut guard = self.execution_kind.write().await;
376        let original = *guard;
377        *guard = execution_kind;
378        original
379    }
380
381    async fn default_planes(
382        &self,
383        id_generator: &mut IdGenerator,
384        source_range: SourceRange,
385    ) -> Result<DefaultPlanes, KclError> {
386        {
387            let opt = self.default_planes.read().await.as_ref().cloned();
388            if let Some(planes) = opt {
389                return Ok(planes);
390            }
391        } // drop the read lock
392
393        let new_planes = self.new_default_planes(id_generator, source_range).await?;
394        *self.default_planes.write().await = Some(new_planes.clone());
395
396        Ok(new_planes)
397    }
398
399    async fn clear_scene_post_hook(
400        &self,
401        id_generator: &mut IdGenerator,
402        source_range: SourceRange,
403    ) -> Result<(), KclError> {
404        // Remake the default planes, since they would have been removed after the scene was cleared.
405        let new_planes = self.new_default_planes(id_generator, source_range).await?;
406        *self.default_planes.write().await = Some(new_planes);
407
408        Ok(())
409    }
410
411    async fn inner_send_modeling_cmd(
412        &self,
413        id: uuid::Uuid,
414        source_range: SourceRange,
415        cmd: WebSocketRequest,
416        _id_to_source_range: HashMap<Uuid, SourceRange>,
417    ) -> Result<WebSocketResponse, KclError> {
418        let (tx, rx) = oneshot::channel();
419
420        // Send the request to the engine, via the actor.
421        self.engine_req_tx
422            .send(ToEngineReq {
423                req: cmd.clone(),
424                request_sent: tx,
425            })
426            .await
427            .map_err(|e| {
428                KclError::Engine(KclErrorDetails {
429                    message: format!("Failed to send modeling command: {}", e),
430                    source_ranges: vec![source_range],
431                })
432            })?;
433
434        // Wait for the request to be sent.
435        rx.await
436            .map_err(|e| {
437                KclError::Engine(KclErrorDetails {
438                    message: format!("could not send request to the engine actor: {e}"),
439                    source_ranges: vec![source_range],
440                })
441            })?
442            .map_err(|e| {
443                KclError::Engine(KclErrorDetails {
444                    message: format!("could not send request to the engine: {e}"),
445                    source_ranges: vec![source_range],
446                })
447            })?;
448
449        // Wait for the response.
450        let current_time = std::time::Instant::now();
451        while current_time.elapsed().as_secs() < 60 {
452            let guard = self.socket_health.read().await;
453            if *guard == SocketHealth::Inactive {
454                // Check if we have any pending errors.
455                let pe = self.pending_errors.read().await;
456                if !pe.is_empty() {
457                    return Err(KclError::Engine(KclErrorDetails {
458                        message: pe.join(", ").to_string(),
459                        source_ranges: vec![source_range],
460                    }));
461                } else {
462                    return Err(KclError::Engine(KclErrorDetails {
463                        message: "Modeling command failed: websocket closed early".to_string(),
464                        source_ranges: vec![source_range],
465                    }));
466                }
467            }
468            // We pop off the responses to cleanup our mappings.
469            if let Some(resp) = self.responses.write().await.shift_remove(&id) {
470                return Ok(resp);
471            }
472        }
473
474        Err(KclError::Engine(KclErrorDetails {
475            message: format!("Modeling command timed out `{}`", id),
476            source_ranges: vec![source_range],
477        }))
478    }
479
480    async fn get_session_data(&self) -> Option<ModelingSessionData> {
481        self.session_data.read().await.clone()
482    }
483
484    async fn close(&self) {
485        let _ = self.shutdown_tx.send(()).await;
486        loop {
487            let guard = self.socket_health.read().await;
488            if *guard == SocketHealth::Inactive {
489                return;
490            }
491        }
492    }
493}