Skip to main content

kcl_lib/engine/
conn.rs

1//! Functions for setting up our WebSocket and WebRTC connections for communications with the
2//! engine.
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use anyhow::Result;
8use anyhow::anyhow;
9use futures::SinkExt;
10use futures::StreamExt;
11use indexmap::IndexMap;
12use kcmc::ModelingCmd;
13use kcmc::websocket::BatchResponse;
14use kcmc::websocket::FailureWebSocketResponse;
15use kcmc::websocket::ModelingCmdReq;
16use kcmc::websocket::ModelingSessionData;
17use kcmc::websocket::OkWebSocketResponseData;
18use kcmc::websocket::SuccessWebSocketResponse;
19use kcmc::websocket::WebSocketRequest;
20use kcmc::websocket::WebSocketResponse;
21use kittycad_modeling_cmds::{self as kcmc};
22use tokio::sync::RwLock;
23use tokio::sync::mpsc;
24use tokio::sync::oneshot;
25use tokio_tungstenite::tungstenite::Message as WsMsg;
26use uuid::Uuid;
27
28use crate::SourceRange;
29use crate::engine::AsyncTasks;
30use crate::engine::EngineBatchContext;
31use crate::engine::EngineManager;
32use crate::engine::EngineStats;
33use crate::errors::KclError;
34use crate::errors::KclErrorDetails;
35use crate::execution::DefaultPlanes;
36use crate::execution::IdGenerator;
37use crate::log::logln;
38
/// Whether the background reader task can still receive from the engine WebSocket.
///
/// The reader flips this to `Inactive` when a read fails; callers waiting on responses
/// treat that as the connection having gone away.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// The reader task is still receiving messages.
    Active,
    /// A read failed; no further responses will arrive.
    Inactive,
}
44
45type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
46#[derive(Debug)]
47pub struct EngineConnection {
48    engine_req_tx: mpsc::Sender<ToEngineReq>,
49    shutdown_tx: mpsc::Sender<()>,
50    responses: ResponseInformation,
51    pending_errors: Arc<RwLock<Vec<String>>>,
52    #[allow(dead_code)]
53    tcp_read_handle: Arc<TcpReadHandle>,
54    socket_health: Arc<RwLock<SocketHealth>>,
55    ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
56
57    /// The default planes for the scene.
58    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
59    /// If the server sends session data, it'll be copied to here.
60    session_data: Arc<RwLock<Option<ModelingSessionData>>>,
61
62    stats: EngineStats,
63
64    async_tasks: AsyncTasks,
65
66    debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
67}
68
69pub struct TcpRead {
70    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
71}
72
73/// Occurs when client couldn't read from the WebSocket to the engine.
74// #[derive(Debug)]
75#[allow(clippy::large_enum_variant)]
76pub enum WebSocketReadError {
77    /// Could not read a message due to WebSocket errors.
78    Read(tokio_tungstenite::tungstenite::Error),
79    /// WebSocket message didn't contain a valid message that the KCL Executor could parse.
80    Deser(anyhow::Error),
81}
82
83impl From<anyhow::Error> for WebSocketReadError {
84    fn from(e: anyhow::Error) -> Self {
85        Self::Deser(e)
86    }
87}
88
89impl TcpRead {
90    pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
91        let Some(msg) = self.stream.next().await else {
92            return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
93        };
94        let msg = match msg {
95            Ok(msg) => msg,
96            Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
97                return Err(WebSocketReadError::Read(e));
98            }
99            Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
100        };
101        let msg: WebSocketResponse = match msg {
102            WsMsg::Text(text) => serde_json::from_str(&text)
103                .map_err(anyhow::Error::from)
104                .map_err(WebSocketReadError::from)?,
105            WsMsg::Binary(bin) => rmp_serde::from_slice(&bin)
106                .map_err(anyhow::Error::from)
107                .map_err(WebSocketReadError::from)?,
108            other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
109        };
110        Ok(msg)
111    }
112}
113
114pub struct TcpReadHandle {
115    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
116}
117
118impl std::fmt::Debug for TcpReadHandle {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        write!(f, "TcpReadHandle")
121    }
122}
123
124impl Drop for TcpReadHandle {
125    fn drop(&mut self) {
126        // Drop the read handle.
127        self.handle.abort();
128    }
129}
130
131/// Information about the responses from the engine.
132#[derive(Clone, Debug)]
133struct ResponseInformation {
134    /// The responses from the engine.
135    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
136}
137
138impl ResponseInformation {
139    pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
140        self.responses.write().await.insert(id, response);
141    }
142}
143
144/// Requests to send to the engine, and a way to await a response.
145struct ToEngineReq {
146    /// The request to send
147    req: WebSocketRequest,
148    /// If this resolves to Ok, the request was sent.
149    /// If this resolves to Err, the request could not be sent.
150    /// If this has not yet resolved, the request has not been sent yet.
151    request_sent: oneshot::Sender<Result<()>>,
152}
153
154impl EngineConnection {
155    /// Start waiting for incoming engine requests, and send each one over the WebSocket to the engine.
156    async fn start_write_actor(
157        mut tcp_write: WebSocketTcpWrite,
158        mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
159        mut shutdown_rx: mpsc::Receiver<()>,
160    ) {
161        loop {
162            tokio::select! {
163                maybe_req = engine_req_rx.recv() => {
164                    match maybe_req {
165                        Some(ToEngineReq { req, request_sent }) => {
166                            // Decide whether to send as binary or text,
167                            // then send to the engine.
168                            let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
169                                cmd: ModelingCmd::ImportFiles { .. },
170                                cmd_id: _,
171                            }) = &req
172                            {
173                                Self::inner_send_to_engine_binary(req, &mut tcp_write).await
174                            } else {
175                                Self::inner_send_to_engine(req, &mut tcp_write).await
176                            };
177
178                            // Let the caller know we’ve sent the request (ok or error).
179                            let _ = request_sent.send(res);
180                        }
181                        None => {
182                            // The engine_req_rx channel has closed, so no more requests.
183                            // We'll gracefully exit the loop and close the engine.
184                            break;
185                        }
186                    }
187                },
188
189                // If we get a shutdown signal, close the engine immediately and return.
190                _ = shutdown_rx.recv() => {
191                    let _ = Self::inner_close_engine(&mut tcp_write).await;
192                    return;
193                }
194            }
195        }
196
197        // If we exit the loop (e.g. engine_req_rx was closed),
198        // still gracefully close the engine before returning.
199        let _ = Self::inner_close_engine(&mut tcp_write).await;
200    }
201
202    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
203    async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
204        tcp_write
205            .send(WsMsg::Close(None))
206            .await
207            .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
208        Ok(())
209    }
210
211    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
212    async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
213        let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
214        tcp_write
215            .send(WsMsg::Text(msg.into()))
216            .await
217            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
218        Ok(())
219    }
220
221    /// Send the given `request` to the engine via the WebSocket connection `tcp_write` as binary.
222    async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
223        let msg = rmp_serde::to_vec_named(&request).map_err(|e| anyhow!("could not serialize msgpack: {e}"))?;
224        tcp_write
225            .send(WsMsg::Binary(msg.into()))
226            .await
227            .map_err(|e| anyhow!("could not send MsgPack over websocket: {e}"))?;
228        Ok(())
229    }
230
231    pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
232        let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
233            // 4294967296 bytes, which is around 4.2 GB.
234            .max_message_size(Some(usize::MAX))
235            .max_frame_size(Some(usize::MAX));
236
237        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
238            ws,
239            tokio_tungstenite::tungstenite::protocol::Role::Client,
240            Some(wsconfig),
241        )
242        .await;
243
244        let (tcp_write, tcp_read) = ws_stream.split();
245        let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
246        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
247        tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
248
249        let mut tcp_read = TcpRead { stream: tcp_read };
250
251        let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
252        let session_data2 = session_data.clone();
253        let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
254        let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
255        let pending_errors = Arc::new(RwLock::new(Vec::new()));
256        let pending_errors_clone = pending_errors.clone();
257        let response_information = ResponseInformation {
258            responses: Arc::new(RwLock::new(IndexMap::new())),
259        };
260        let response_information_cloned = response_information.clone();
261        let debug_info = Arc::new(RwLock::new(None));
262        let debug_info_cloned = debug_info.clone();
263
264        let socket_health_tcp_read = socket_health.clone();
265        let tcp_read_handle = tokio::spawn(async move {
266            // Get Websocket messages from API server
267            loop {
268                match tcp_read.read().await {
269                    Ok(ws_resp) => {
270                        // If we got a batch response, add all the inner responses.
271                        let id = ws_resp.request_id();
272                        match &ws_resp {
273                            WebSocketResponse::Success(SuccessWebSocketResponse {
274                                resp: OkWebSocketResponseData::ModelingBatch { responses },
275                                ..
276                            }) => {
277                                #[expect(
278                                    clippy::iter_over_hash_type,
279                                    reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
280                                )]
281                                for (resp_id, batch_response) in responses {
282                                    let id: uuid::Uuid = (*resp_id).into();
283                                    match batch_response {
284                                        BatchResponse::Success { response } => {
285                                            // If the id is in our ids of async commands, remove
286                                            // it.
287                                            response_information_cloned
288                                                .add(
289                                                    id,
290                                                    WebSocketResponse::Success(SuccessWebSocketResponse {
291                                                        success: true,
292                                                        request_id: Some(id),
293                                                        resp: OkWebSocketResponseData::Modeling {
294                                                            modeling_response: response.clone(),
295                                                        },
296                                                    }),
297                                                )
298                                                .await;
299                                        }
300                                        BatchResponse::Failure { errors } => {
301                                            response_information_cloned
302                                                .add(
303                                                    id,
304                                                    WebSocketResponse::Failure(FailureWebSocketResponse {
305                                                        success: false,
306                                                        request_id: Some(id),
307                                                        errors: errors.clone(),
308                                                    }),
309                                                )
310                                                .await;
311                                        }
312                                    }
313                                }
314                            }
315                            WebSocketResponse::Success(SuccessWebSocketResponse {
316                                resp: OkWebSocketResponseData::ModelingSessionData { session },
317                                ..
318                            }) => {
319                                let mut sd = session_data2.write().await;
320                                sd.replace(session.clone());
321                                logln!("API Call ID: {}", session.api_call_id);
322                            }
323                            WebSocketResponse::Failure(FailureWebSocketResponse {
324                                success: _,
325                                request_id,
326                                errors,
327                            }) => {
328                                if let Some(id) = request_id {
329                                    response_information_cloned
330                                        .add(
331                                            *id,
332                                            WebSocketResponse::Failure(FailureWebSocketResponse {
333                                                success: false,
334                                                request_id: *request_id,
335                                                errors: errors.clone(),
336                                            }),
337                                        )
338                                        .await;
339                                } else {
340                                    // Add it to our pending errors.
341                                    let mut pe = pending_errors_clone.write().await;
342                                    for error in errors {
343                                        if !pe.contains(&error.message) {
344                                            pe.push(error.message.clone());
345                                        }
346                                    }
347                                    drop(pe);
348                                }
349                            }
350                            WebSocketResponse::Success(SuccessWebSocketResponse {
351                                resp: debug @ OkWebSocketResponseData::Debug { .. },
352                                ..
353                            }) => {
354                                let mut handle = debug_info_cloned.write().await;
355                                *handle = Some(debug.clone());
356                            }
357                            _ => {}
358                        }
359
360                        if let Some(id) = id {
361                            response_information_cloned.add(id, ws_resp.clone()).await;
362                        }
363                    }
364                    Err(e) => {
365                        match &e {
366                            WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
367                            WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
368                        }
369                        *socket_health_tcp_read.write().await = SocketHealth::Inactive;
370                        return Err(e);
371                    }
372                }
373            }
374        });
375
376        Ok(EngineConnection {
377            engine_req_tx,
378            shutdown_tx,
379            tcp_read_handle: Arc::new(TcpReadHandle {
380                handle: Arc::new(tcp_read_handle),
381            }),
382            responses: response_information,
383            pending_errors,
384            socket_health,
385            ids_of_async_commands,
386            default_planes: Default::default(),
387            session_data,
388            stats: Default::default(),
389            async_tasks: AsyncTasks::new(),
390            debug_info,
391        })
392    }
393}
394
395#[async_trait::async_trait]
396impl EngineManager for EngineConnection {
397    fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
398        self.responses.responses.clone()
399    }
400
401    fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
402        self.ids_of_async_commands.clone()
403    }
404
405    fn async_tasks(&self) -> AsyncTasks {
406        self.async_tasks.clone()
407    }
408
409    fn stats(&self) -> &EngineStats {
410        &self.stats
411    }
412
413    fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
414        self.default_planes.clone()
415    }
416
417    async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
418        self.debug_info.read().await.clone()
419    }
420
421    async fn fetch_debug(&self) -> Result<(), KclError> {
422        let (tx, rx) = oneshot::channel();
423
424        self.engine_req_tx
425            .send(ToEngineReq {
426                req: WebSocketRequest::Debug {},
427                request_sent: tx,
428            })
429            .await
430            .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
431
432        let _ = rx.await;
433        Ok(())
434    }
435
436    async fn clear_scene_post_hook(
437        &self,
438        batch_context: &EngineBatchContext,
439        id_generator: &mut IdGenerator,
440        source_range: SourceRange,
441    ) -> Result<(), KclError> {
442        // Remake the default planes, since they would have been removed after the scene was cleared.
443        let new_planes = self
444            .new_default_planes(batch_context, id_generator, source_range)
445            .await?;
446        *self.default_planes.write().await = Some(new_planes);
447
448        Ok(())
449    }
450
451    async fn inner_fire_modeling_cmd(
452        &self,
453        _id: uuid::Uuid,
454        source_range: SourceRange,
455        cmd: WebSocketRequest,
456        _id_to_source_range: HashMap<Uuid, SourceRange>,
457    ) -> Result<(), KclError> {
458        let (tx, rx) = oneshot::channel();
459
460        // Send the request to the engine, via the actor.
461        self.engine_req_tx
462            .send(ToEngineReq {
463                req: cmd.clone(),
464                request_sent: tx,
465            })
466            .await
467            .map_err(|e| {
468                KclError::new_engine(KclErrorDetails::new(
469                    format!("Failed to send modeling command: {e}"),
470                    vec![source_range],
471                ))
472            })?;
473
474        // Wait for the request to be sent.
475        rx.await
476            .map_err(|e| {
477                KclError::new_engine_hangup(
478                    KclErrorDetails::new(
479                        format!("could not send request to the engine actor: {e}"),
480                        vec![source_range],
481                    ),
482                    None,
483                )
484            })?
485            .map_err(|e| {
486                KclError::new_engine_hangup(
487                    KclErrorDetails::new(format!("could not send request to the engine: {e}"), vec![source_range]),
488                    None,
489                )
490            })?;
491
492        Ok(())
493    }
494
495    async fn inner_send_modeling_cmd(
496        &self,
497        id: uuid::Uuid,
498        source_range: SourceRange,
499        cmd: WebSocketRequest,
500        id_to_source_range: HashMap<Uuid, SourceRange>,
501    ) -> Result<WebSocketResponse, KclError> {
502        self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
503            .await?;
504
505        // Wait for the response.
506        let response_timeout = 300;
507        let current_time = std::time::Instant::now();
508        while current_time.elapsed().as_secs() < response_timeout {
509            let guard = self.socket_health.read().await;
510            if *guard == SocketHealth::Inactive {
511                // Get the API call ID from session data if available
512                let session_data = self.session_data.read().await;
513                let api_call_id = session_data.as_ref().map(|session| session.api_call_id.to_string());
514                let api_call_id_msg = if let Some(ref id) = api_call_id {
515                    format!(" (API call ID: {})", id)
516                } else {
517                    String::new()
518                };
519
520                // Check if we have any pending errors.
521                let pe = self.pending_errors.read().await;
522                if !pe.is_empty() {
523                    return Err(KclError::new_engine(KclErrorDetails::new(
524                        format!("{}{}", pe.join(", "), api_call_id_msg),
525                        vec![source_range],
526                    )));
527                } else {
528                    return Err(KclError::new_engine_hangup(
529                        KclErrorDetails::new(
530                            format!("Modeling command failed: websocket closed early{}", api_call_id_msg),
531                            vec![source_range],
532                        ),
533                        api_call_id,
534                    ));
535                }
536            }
537
538            #[cfg(feature = "artifact-graph")]
539            {
540                // We cannot pop here or it will break the artifact graph.
541                if let Some(resp) = self.responses.responses.read().await.get(&id) {
542                    return Ok(resp.clone());
543                }
544            }
545            #[cfg(not(feature = "artifact-graph"))]
546            {
547                if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
548                    return Ok(resp);
549                }
550            }
551        }
552
553        // Get the API call ID from session data if available for timeout error
554        let session_data = self.session_data.read().await;
555        let api_call_id_msg = if let Some(session) = session_data.as_ref() {
556            format!(" (API call ID: {})", session.api_call_id)
557        } else {
558            String::new()
559        };
560
561        Err(KclError::new_engine(KclErrorDetails::new(
562            format!("Modeling command timed out `{id}`{}", api_call_id_msg),
563            vec![source_range],
564        )))
565    }
566
567    async fn get_session_data(&self) -> Option<ModelingSessionData> {
568        self.session_data.read().await.clone()
569    }
570
571    async fn close(&self) {
572        let _ = self.shutdown_tx.send(()).await;
573        loop {
574            let guard = self.socket_health.read().await;
575            if *guard == SocketHealth::Inactive {
576                return;
577            }
578        }
579    }
580}