kcl_lib/engine/
conn.rs

1//! Functions for setting up our WebSocket and WebRTC connections for communications with the
2//! engine.
3
4use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10    websocket::{
11        BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12        SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13    },
14    ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21#[cfg(feature = "artifact-graph")]
22use crate::execution::ArtifactCommand;
23use crate::{
24    engine::{AsyncTasks, EngineManager, EngineStats},
25    errors::{KclError, KclErrorDetails},
26    execution::{DefaultPlanes, IdGenerator},
27    SourceRange,
28};
29
/// Whether the engine WebSocket can still be read from.
///
/// Starts as `Active`; the read task switches it to `Inactive` once reading from the socket
/// fails.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SocketHealth {
    /// The read task is still receiving messages.
    Active,
    /// Reading from the socket failed; the read task has stopped and no more responses arrive.
    Inactive,
}
35
/// Write half of the engine WebSocket; owned by the writer task (`EngineConnection::start_write_actor`).
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A live connection to the engine over a WebSocket.
///
/// Outgoing requests are handed to a writer task through `engine_req_tx`; incoming messages are
/// processed by a read task that fills in the shared state held below.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of requests for the writer task, which sends them over the WebSocket.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Tells the writer task to send a Close frame and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Engine responses stored by the read task, keyed by request id.
    responses: ResponseInformation,
    /// Deduplicated messages from engine failures that carried no request id.
    pending_errors: Arc<RwLock<Vec<String>>>,
    // Never read. It is kept only so that, when the last `Arc` is dropped, `TcpReadHandle`'s
    // `Drop` impl aborts the read task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `Inactive` by the read task once reading from the socket fails.
    socket_health: Arc<RwLock<SocketHealth>>,
    /// Exposed through `EngineManager::batch`.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::batch_end`.
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    /// Exposed through `EngineManager::artifact_commands`.
    #[cfg(feature = "artifact-graph")]
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,
    /// Exposed through `EngineManager::ids_of_async_commands`.
    ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,

    /// The default planes for the scene.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// If the server sends session data, it'll be copied to here.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    /// Exposed through `EngineManager::stats`.
    stats: EngineStats,

    /// Exposed through `EngineManager::async_tasks`.
    async_tasks: AsyncTasks,

    /// The most recent `Debug` response from the engine, stored by the read task.
    debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
}
63
/// Read half of the engine WebSocket; `TcpRead::read` decodes one message at a time.
pub struct TcpRead {
    /// The underlying stream of WebSocket messages.
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
67
/// Occurs when the client couldn't read a message from the engine's WebSocket.
// NOTE(review): `Debug` is not derived; a commented-out `#[derive(Debug)]` used to sit here.
// Confirm why it was dropped before adding it back.
// Presumably allowed because `tungstenite::Error` (in `Read`) is much larger than
// `anyhow::Error` (in `Deser`) — confirm.
#[allow(clippy::large_enum_variant)]
pub enum WebSocketReadError {
    /// Could not read a message due to WebSocket errors.
    /// `TcpRead::read` currently only produces this for protocol errors; other transport errors
    /// are reported as `Deser`.
    Read(tokio_tungstenite::tungstenite::Error),
    /// WebSocket message didn't contain a valid message that the KCL Executor could parse.
    /// `TcpRead::read` also uses this for an ended stream, non-protocol transport errors and
    /// unexpected frame types.
    Deser(anyhow::Error),
}
77
78impl From<anyhow::Error> for WebSocketReadError {
79    fn from(e: anyhow::Error) -> Self {
80        Self::Deser(e)
81    }
82}
83
84impl TcpRead {
85    pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
86        let Some(msg) = self.stream.next().await else {
87            return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
88        };
89        let msg = match msg {
90            Ok(msg) => msg,
91            Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
92                return Err(WebSocketReadError::Read(e))
93            }
94            Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
95        };
96        let msg: WebSocketResponse = match msg {
97            WsMsg::Text(text) => serde_json::from_str(&text)
98                .map_err(anyhow::Error::from)
99                .map_err(WebSocketReadError::from)?,
100            WsMsg::Binary(bin) => bson::from_slice(&bin)
101                .map_err(anyhow::Error::from)
102                .map_err(WebSocketReadError::from)?,
103            other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
104        };
105        Ok(msg)
106    }
107}
108
/// Owns the spawned task that reads and processes engine messages.
///
/// Dropping this value aborts that task (see its `Drop` impl).
pub struct TcpReadHandle {
    /// Join handle of the read task; the task's output is the `WebSocketReadError` that ended
    /// its read loop.
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
112
113impl std::fmt::Debug for TcpReadHandle {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        write!(f, "TcpReadHandle")
116    }
117}
118
impl Drop for TcpReadHandle {
    /// Abort the read task so it stops once this handle is gone.
    fn drop(&mut self) {
        // `JoinHandle::abort` only needs `&self`, so calling it through the `Arc` works.
        self.handle.abort();
    }
}
125
/// Information about the responses from the engine.
///
/// Clones share the same map through the `Arc`; the read task holds one clone and the
/// connection holds another.
#[derive(Clone, Debug)]
struct ResponseInformation {
    /// The responses from the engine, keyed by request id.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
}
132
133impl ResponseInformation {
134    pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
135        self.responses.write().await.insert(id, response);
136    }
137}
138
/// Requests to send to the engine, and a way to await a response.
///
/// Consumed by the writer task (`EngineConnection::start_write_actor`).
struct ToEngineReq {
    /// The request to send
    req: WebSocketRequest,
    /// If this resolves to Ok, the request was sent.
    /// If this resolves to Err, the request could not be sent.
    /// If this has not yet resolved, the request has not been sent yet.
    request_sent: oneshot::Sender<Result<()>>,
}
148
149impl EngineConnection {
150    /// Start waiting for incoming engine requests, and send each one over the WebSocket to the engine.
151    async fn start_write_actor(
152        mut tcp_write: WebSocketTcpWrite,
153        mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
154        mut shutdown_rx: mpsc::Receiver<()>,
155    ) {
156        loop {
157            tokio::select! {
158                maybe_req = engine_req_rx.recv() => {
159                    match maybe_req {
160                        Some(ToEngineReq { req, request_sent }) => {
161                            // Decide whether to send as binary or text,
162                            // then send to the engine.
163                            let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
164                                cmd: ModelingCmd::ImportFiles { .. },
165                                cmd_id: _,
166                            }) = &req
167                            {
168                                Self::inner_send_to_engine_binary(req, &mut tcp_write).await
169                            } else {
170                                Self::inner_send_to_engine(req, &mut tcp_write).await
171                            };
172
173                            // Let the caller know we’ve sent the request (ok or error).
174                            let _ = request_sent.send(res);
175                        }
176                        None => {
177                            // The engine_req_rx channel has closed, so no more requests.
178                            // We'll gracefully exit the loop and close the engine.
179                            break;
180                        }
181                    }
182                },
183
184                // If we get a shutdown signal, close the engine immediately and return.
185                _ = shutdown_rx.recv() => {
186                    let _ = Self::inner_close_engine(&mut tcp_write).await;
187                    return;
188                }
189            }
190        }
191
192        // If we exit the loop (e.g. engine_req_rx was closed),
193        // still gracefully close the engine before returning.
194        let _ = Self::inner_close_engine(&mut tcp_write).await;
195    }
196
197    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
198    async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
199        tcp_write
200            .send(WsMsg::Close(None))
201            .await
202            .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
203        Ok(())
204    }
205
206    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
207    async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
208        let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
209        tcp_write
210            .send(WsMsg::Text(msg.into()))
211            .await
212            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
213        Ok(())
214    }
215
216    /// Send the given `request` to the engine via the WebSocket connection `tcp_write` as binary.
217    async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
218        let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
219        tcp_write
220            .send(WsMsg::Binary(msg.into()))
221            .await
222            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
223        Ok(())
224    }
225
226    pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
227        let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
228            // 4294967296 bytes, which is around 4.2 GB.
229            .max_message_size(Some(usize::MAX))
230            .max_frame_size(Some(usize::MAX));
231
232        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
233            ws,
234            tokio_tungstenite::tungstenite::protocol::Role::Client,
235            Some(wsconfig),
236        )
237        .await;
238
239        let (tcp_write, tcp_read) = ws_stream.split();
240        let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
241        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
242        tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
243
244        let mut tcp_read = TcpRead { stream: tcp_read };
245
246        let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
247        let session_data2 = session_data.clone();
248        let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
249        let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
250        let pending_errors = Arc::new(RwLock::new(Vec::new()));
251        let pending_errors_clone = pending_errors.clone();
252        let response_information = ResponseInformation {
253            responses: Arc::new(RwLock::new(IndexMap::new())),
254        };
255        let response_information_cloned = response_information.clone();
256        let debug_info = Arc::new(RwLock::new(None));
257        let debug_info_cloned = debug_info.clone();
258
259        let socket_health_tcp_read = socket_health.clone();
260        let tcp_read_handle = tokio::spawn(async move {
261            // Get Websocket messages from API server
262            loop {
263                match tcp_read.read().await {
264                    Ok(ws_resp) => {
265                        // If we got a batch response, add all the inner responses.
266                        let id = ws_resp.request_id();
267                        match &ws_resp {
268                            WebSocketResponse::Success(SuccessWebSocketResponse {
269                                resp: OkWebSocketResponseData::ModelingBatch { responses },
270                                ..
271                            }) => {
272                                #[expect(
273                                    clippy::iter_over_hash_type,
274                                    reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
275                                )]
276                                for (resp_id, batch_response) in responses {
277                                    let id: uuid::Uuid = (*resp_id).into();
278                                    match batch_response {
279                                        BatchResponse::Success { response } => {
280                                            // If the id is in our ids of async commands, remove
281                                            // it.
282                                            response_information_cloned
283                                                .add(
284                                                    id,
285                                                    WebSocketResponse::Success(SuccessWebSocketResponse {
286                                                        success: true,
287                                                        request_id: Some(id),
288                                                        resp: OkWebSocketResponseData::Modeling {
289                                                            modeling_response: response.clone(),
290                                                        },
291                                                    }),
292                                                )
293                                                .await;
294                                        }
295                                        BatchResponse::Failure { errors } => {
296                                            response_information_cloned
297                                                .add(
298                                                    id,
299                                                    WebSocketResponse::Failure(FailureWebSocketResponse {
300                                                        success: false,
301                                                        request_id: Some(id),
302                                                        errors: errors.clone(),
303                                                    }),
304                                                )
305                                                .await;
306                                        }
307                                    }
308                                }
309                            }
310                            WebSocketResponse::Success(SuccessWebSocketResponse {
311                                resp: OkWebSocketResponseData::ModelingSessionData { session },
312                                ..
313                            }) => {
314                                let mut sd = session_data2.write().await;
315                                sd.replace(session.clone());
316                            }
317                            WebSocketResponse::Failure(FailureWebSocketResponse {
318                                success: _,
319                                request_id,
320                                errors,
321                            }) => {
322                                if let Some(id) = request_id {
323                                    response_information_cloned
324                                        .add(
325                                            *id,
326                                            WebSocketResponse::Failure(FailureWebSocketResponse {
327                                                success: false,
328                                                request_id: *request_id,
329                                                errors: errors.clone(),
330                                            }),
331                                        )
332                                        .await;
333                                } else {
334                                    // Add it to our pending errors.
335                                    let mut pe = pending_errors_clone.write().await;
336                                    for error in errors {
337                                        if !pe.contains(&error.message) {
338                                            pe.push(error.message.clone());
339                                        }
340                                    }
341                                    drop(pe);
342                                }
343                            }
344                            WebSocketResponse::Success(SuccessWebSocketResponse {
345                                resp: debug @ OkWebSocketResponseData::Debug { .. },
346                                ..
347                            }) => {
348                                let mut handle = debug_info_cloned.write().await;
349                                *handle = Some(debug.clone());
350                            }
351                            _ => {}
352                        }
353
354                        if let Some(id) = id {
355                            response_information_cloned.add(id, ws_resp.clone()).await;
356                        }
357                    }
358                    Err(e) => {
359                        match &e {
360                            WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
361                            WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
362                        }
363                        *socket_health_tcp_read.write().await = SocketHealth::Inactive;
364                        return Err(e);
365                    }
366                }
367            }
368        });
369
370        Ok(EngineConnection {
371            engine_req_tx,
372            shutdown_tx,
373            tcp_read_handle: Arc::new(TcpReadHandle {
374                handle: Arc::new(tcp_read_handle),
375            }),
376            responses: response_information,
377            pending_errors,
378            socket_health,
379            batch: Arc::new(RwLock::new(Vec::new())),
380            batch_end: Arc::new(RwLock::new(IndexMap::new())),
381            #[cfg(feature = "artifact-graph")]
382            artifact_commands: Arc::new(RwLock::new(Vec::new())),
383            ids_of_async_commands,
384            default_planes: Default::default(),
385            session_data,
386            stats: Default::default(),
387            async_tasks: AsyncTasks::new(),
388            debug_info,
389        })
390    }
391}
392
393#[async_trait::async_trait]
394impl EngineManager for EngineConnection {
395    fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
396        self.batch.clone()
397    }
398
399    fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
400        self.batch_end.clone()
401    }
402
403    fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
404        self.responses.responses.clone()
405    }
406
407    #[cfg(feature = "artifact-graph")]
408    fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
409        self.artifact_commands.clone()
410    }
411
412    fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
413        self.ids_of_async_commands.clone()
414    }
415
416    fn async_tasks(&self) -> AsyncTasks {
417        self.async_tasks.clone()
418    }
419
420    fn stats(&self) -> &EngineStats {
421        &self.stats
422    }
423
424    fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
425        self.default_planes.clone()
426    }
427
428    async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
429        self.debug_info.read().await.clone()
430    }
431
432    async fn fetch_debug(&self) -> Result<(), KclError> {
433        let (tx, rx) = oneshot::channel();
434
435        self.engine_req_tx
436            .send(ToEngineReq {
437                req: WebSocketRequest::Debug {},
438                request_sent: tx,
439            })
440            .await
441            .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {}", e), vec![])))?;
442
443        let _ = rx.await;
444        Ok(())
445    }
446
447    async fn clear_scene_post_hook(
448        &self,
449        id_generator: &mut IdGenerator,
450        source_range: SourceRange,
451    ) -> Result<(), KclError> {
452        // Remake the default planes, since they would have been removed after the scene was cleared.
453        let new_planes = self.new_default_planes(id_generator, source_range).await?;
454        *self.default_planes.write().await = Some(new_planes);
455
456        Ok(())
457    }
458
459    async fn inner_fire_modeling_cmd(
460        &self,
461        _id: uuid::Uuid,
462        source_range: SourceRange,
463        cmd: WebSocketRequest,
464        _id_to_source_range: HashMap<Uuid, SourceRange>,
465    ) -> Result<(), KclError> {
466        let (tx, rx) = oneshot::channel();
467
468        // Send the request to the engine, via the actor.
469        self.engine_req_tx
470            .send(ToEngineReq {
471                req: cmd.clone(),
472                request_sent: tx,
473            })
474            .await
475            .map_err(|e| {
476                KclError::new_engine(KclErrorDetails::new(
477                    format!("Failed to send modeling command: {}", e),
478                    vec![source_range],
479                ))
480            })?;
481
482        // Wait for the request to be sent.
483        rx.await
484            .map_err(|e| {
485                KclError::new_engine(KclErrorDetails::new(
486                    format!("could not send request to the engine actor: {e}"),
487                    vec![source_range],
488                ))
489            })?
490            .map_err(|e| {
491                KclError::new_engine(KclErrorDetails::new(
492                    format!("could not send request to the engine: {e}"),
493                    vec![source_range],
494                ))
495            })?;
496
497        Ok(())
498    }
499
500    async fn inner_send_modeling_cmd(
501        &self,
502        id: uuid::Uuid,
503        source_range: SourceRange,
504        cmd: WebSocketRequest,
505        id_to_source_range: HashMap<Uuid, SourceRange>,
506    ) -> Result<WebSocketResponse, KclError> {
507        self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
508            .await?;
509
510        // Wait for the response.
511        let current_time = std::time::Instant::now();
512        while current_time.elapsed().as_secs() < 60 {
513            let guard = self.socket_health.read().await;
514            if *guard == SocketHealth::Inactive {
515                // Check if we have any pending errors.
516                let pe = self.pending_errors.read().await;
517                if !pe.is_empty() {
518                    return Err(KclError::new_engine(KclErrorDetails::new(
519                        pe.join(", ").to_string(),
520                        vec![source_range],
521                    )));
522                } else {
523                    return Err(KclError::new_engine(KclErrorDetails::new(
524                        "Modeling command failed: websocket closed early".to_string(),
525                        vec![source_range],
526                    )));
527                }
528            }
529
530            #[cfg(feature = "artifact-graph")]
531            {
532                // We cannot pop here or it will break the artifact graph.
533                if let Some(resp) = self.responses.responses.read().await.get(&id) {
534                    return Ok(resp.clone());
535                }
536            }
537            #[cfg(not(feature = "artifact-graph"))]
538            {
539                if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
540                    return Ok(resp);
541                }
542            }
543        }
544
545        Err(KclError::new_engine(KclErrorDetails::new(
546            format!("Modeling command timed out `{}`", id),
547            vec![source_range],
548        )))
549    }
550
551    async fn get_session_data(&self) -> Option<ModelingSessionData> {
552        self.session_data.read().await.clone()
553    }
554
555    async fn close(&self) {
556        let _ = self.shutdown_tx.send(()).await;
557        loop {
558            let guard = self.socket_health.read().await;
559            if *guard == SocketHealth::Inactive {
560                return;
561            }
562        }
563    }
564}