kcl_lib/engine/
conn.rs

1//! Functions for setting up our WebSocket and WebRTC connections for communications with the
2//! engine.
3
4use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10    websocket::{
11        BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12        SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13    },
14    ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21use super::{EngineStats, ExecutionKind};
22use crate::{
23    engine::EngineManager,
24    errors::{KclError, KclErrorDetails},
25    execution::{ArtifactCommand, DefaultPlanes, IdGenerator},
26    SourceRange,
27};
28
/// Whether the read side of the engine WebSocket is still usable.
///
/// Shared between the reader task spawned in `EngineConnection::new` and the code that waits
/// for responses or for the connection to close.
#[derive(Debug, PartialEq)]
enum SocketHealth {
    /// Initial state of a freshly created connection.
    Active,
    /// Set by the reader task once a read from the WebSocket has failed.
    Inactive,
}
34
/// The write half obtained by splitting the engine's WebSocket stream.
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A connection to the engine over a WebSocket.
///
/// Outgoing requests go through a write actor (see `start_write_actor`); incoming messages are
/// collected by a background reader task into the shared maps below.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of requests for the write actor to send to the engine.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Tells the write actor to send a WebSocket close message and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Responses received from the engine, keyed by request id.
    /// An entry is removed when the command waiting on that id picks it up.
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
    /// Messages from failure responses that carried no request id (deduplicated).
    pending_errors: Arc<RwLock<Vec<String>>>,
    /// Handle to the reader task. Never read directly; the task is aborted when the
    /// `TcpReadHandle` is dropped.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Flipped to `Inactive` by the reader task when reading from the WebSocket fails.
    socket_health: Arc<RwLock<SocketHealth>>,
    /// Shared list handed out by `EngineManager::batch`.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    /// Shared map handed out by `EngineManager::batch_end`.
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    /// Shared list handed out by `EngineManager::artifact_commands`.
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,

    /// The default planes for the scene.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// If the server sends session data, it'll be copied to here.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    /// The current execution kind; read and swapped through the `EngineManager` accessors.
    execution_kind: Arc<RwLock<ExecutionKind>>,
    /// Engine statistics, exposed through `EngineManager::stats`.
    stats: EngineStats,
}
57
/// Reads and decodes engine responses from the read half of the WebSocket.
pub struct TcpRead {
    /// The read half obtained by splitting the engine's WebSocket stream.
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
61
/// Occurs when client couldn't read from the WebSocket to the engine.
// #[derive(Debug)]
pub enum WebSocketReadError {
    /// Could not read a message due to WebSocket errors.
    /// `TcpRead::read` only uses this variant for tungstenite `Protocol` errors.
    Read(tokio_tungstenite::tungstenite::Error),
    /// WebSocket message didn't contain a valid message that the KCL Executor could parse.
    /// `TcpRead::read` also reports every other failure this way: the stream ending,
    /// non-protocol read errors, and messages that are neither text nor binary.
    Deser(anyhow::Error),
}
70
71impl From<anyhow::Error> for WebSocketReadError {
72    fn from(e: anyhow::Error) -> Self {
73        Self::Deser(e)
74    }
75}
76
77impl TcpRead {
78    pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
79        let Some(msg) = self.stream.next().await else {
80            return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
81        };
82        let msg = match msg {
83            Ok(msg) => msg,
84            Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
85                return Err(WebSocketReadError::Read(e))
86            }
87            Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
88        };
89        let msg: WebSocketResponse = match msg {
90            WsMsg::Text(text) => serde_json::from_str(&text)
91                .map_err(anyhow::Error::from)
92                .map_err(WebSocketReadError::from)?,
93            WsMsg::Binary(bin) => bson::from_slice(&bin)
94                .map_err(anyhow::Error::from)
95                .map_err(WebSocketReadError::from)?,
96            other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
97        };
98        Ok(msg)
99    }
100}
101
/// Owns the handle of the background task that reads from the WebSocket.
/// Dropping this value aborts that task.
pub struct TcpReadHandle {
    /// Join handle of the reader task.
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
105
106impl std::fmt::Debug for TcpReadHandle {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        write!(f, "TcpReadHandle")
109    }
110}
111
112impl Drop for TcpReadHandle {
113    fn drop(&mut self) {
114        // Drop the read handle.
115        self.handle.abort();
116    }
117}
118
/// Requests to send to the engine, and a way to await a response.
///
/// This is the message type consumed by the write actor (`start_write_actor`).
struct ToEngineReq {
    /// The request to send
    req: WebSocketRequest,
    /// Reports the outcome of the send attempt back to the caller.
    /// If this resolves to Ok, the request was sent.
    /// If this resolves to Err, the request could not be sent.
    /// If this has not yet resolved, the request has not been sent yet.
    request_sent: oneshot::Sender<Result<()>>,
}
128
129impl EngineConnection {
130    /// Start waiting for incoming engine requests, and send each one over the WebSocket to the engine.
131    async fn start_write_actor(
132        mut tcp_write: WebSocketTcpWrite,
133        mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
134        mut shutdown_rx: mpsc::Receiver<()>,
135    ) {
136        loop {
137            tokio::select! {
138                maybe_req = engine_req_rx.recv() => {
139                    match maybe_req {
140                        Some(ToEngineReq { req, request_sent }) => {
141                            // Decide whether to send as binary or text,
142                            // then send to the engine.
143                            let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
144                                cmd: ModelingCmd::ImportFiles { .. },
145                                cmd_id: _,
146                            }) = &req
147                            {
148                                Self::inner_send_to_engine_binary(req, &mut tcp_write).await
149                            } else {
150                                Self::inner_send_to_engine(req, &mut tcp_write).await
151                            };
152
153                            // Let the caller know we’ve sent the request (ok or error).
154                            let _ = request_sent.send(res);
155                        }
156                        None => {
157                            // The engine_req_rx channel has closed, so no more requests.
158                            // We'll gracefully exit the loop and close the engine.
159                            break;
160                        }
161                    }
162                },
163
164                // If we get a shutdown signal, close the engine immediately and return.
165                _ = shutdown_rx.recv() => {
166                    let _ = Self::inner_close_engine(&mut tcp_write).await;
167                    return;
168                }
169            }
170        }
171
172        // If we exit the loop (e.g. engine_req_rx was closed),
173        // still gracefully close the engine before returning.
174        let _ = Self::inner_close_engine(&mut tcp_write).await;
175    }
176
177    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
178    async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
179        tcp_write
180            .send(WsMsg::Close(None))
181            .await
182            .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
183        Ok(())
184    }
185
186    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
187    async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
188        let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
189        tcp_write
190            .send(WsMsg::Text(msg))
191            .await
192            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
193        Ok(())
194    }
195
196    /// Send the given `request` to the engine via the WebSocket connection `tcp_write` as binary.
197    async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
198        let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
199        tcp_write
200            .send(WsMsg::Binary(msg))
201            .await
202            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
203        Ok(())
204    }
205
    /// Builds an [`EngineConnection`] on top of an already-upgraded connection.
    ///
    /// Wraps `ws` in a client-role WebSocket stream, splits it, and spawns two background tasks:
    /// the write actor (`start_write_actor`), which owns the write half, and a reader task that
    /// files incoming messages into `responses`, `session_data` or `pending_errors` until a read
    /// fails.
    ///
    /// Nothing in this function currently produces an `Err`.
    pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
        let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
            // Raise both size limits to `usize::MAX` (platform dependent).
            max_message_size: Some(usize::MAX),
            max_frame_size: Some(usize::MAX),
            ..Default::default()
        };

        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
            ws,
            tokio_tungstenite::tungstenite::protocol::Role::Client,
            Some(wsconfig),
        )
        .await;

        // The write half goes to the write actor; the read half to the reader task below.
        let (tcp_write, tcp_read) = ws_stream.split();
        let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
        tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));

        let mut tcp_read = TcpRead { stream: tcp_read };

        // Shared state: one handle stays in the returned connection, the clone moves into the
        // reader task.
        let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
        let session_data2 = session_data.clone();
        let responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>> = Arc::new(RwLock::new(IndexMap::new()));
        let responses_clone = responses.clone();
        let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
        let pending_errors = Arc::new(RwLock::new(Vec::new()));
        let pending_errors_clone = pending_errors.clone();

        let socket_health_tcp_read = socket_health.clone();
        let tcp_read_handle = tokio::spawn(async move {
            // Get Websocket messages from API server
            loop {
                match tcp_read.read().await {
                    Ok(ws_resp) => {
                        // If we got a batch response, add all the inner responses.
                        let id = ws_resp.request_id();
                        match &ws_resp {
                            WebSocketResponse::Success(SuccessWebSocketResponse {
                                resp: OkWebSocketResponseData::ModelingBatch { responses },
                                ..
                            }) =>
                            {
                                // Each inner response is stored under its own id, rewrapped as a
                                // standalone success or failure response.
                                #[expect(
                                    clippy::iter_over_hash_type,
                                    reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
                                )]
                                for (resp_id, batch_response) in responses {
                                    let id: uuid::Uuid = (*resp_id).into();
                                    match batch_response {
                                        BatchResponse::Success { response } => {
                                            responses_clone.write().await.insert(
                                                id,
                                                WebSocketResponse::Success(SuccessWebSocketResponse {
                                                    success: true,
                                                    request_id: Some(id),
                                                    resp: OkWebSocketResponseData::Modeling {
                                                        modeling_response: response.clone(),
                                                    },
                                                }),
                                            );
                                        }
                                        BatchResponse::Failure { errors } => {
                                            responses_clone.write().await.insert(
                                                id,
                                                WebSocketResponse::Failure(FailureWebSocketResponse {
                                                    success: false,
                                                    request_id: Some(id),
                                                    errors: errors.clone(),
                                                }),
                                            );
                                        }
                                    }
                                }
                            }
                            WebSocketResponse::Success(SuccessWebSocketResponse {
                                resp: OkWebSocketResponseData::ModelingSessionData { session },
                                ..
                            }) => {
                                // Keep the most recent session data sent by the server.
                                let mut sd = session_data2.write().await;
                                sd.replace(session.clone());
                            }
                            WebSocketResponse::Failure(FailureWebSocketResponse {
                                success: _,
                                request_id,
                                errors,
                            }) => {
                                if let Some(id) = request_id {
                                    // NOTE(review): if `request_id()` above returns this same id,
                                    // the insert after this `match` overwrites this entry, which
                                    // would make this insert redundant — confirm.
                                    responses_clone.write().await.insert(
                                        *id,
                                        WebSocketResponse::Failure(FailureWebSocketResponse {
                                            success: false,
                                            request_id: *request_id,
                                            errors: errors.clone(),
                                        }),
                                    );
                                } else {
                                    // Add it to our pending errors.
                                    // Messages already present are skipped.
                                    let mut pe = pending_errors_clone.write().await;
                                    for error in errors {
                                        if !pe.contains(&error.message) {
                                            pe.push(error.message.clone());
                                        }
                                    }
                                    drop(pe);
                                }
                            }
                            _ => {}
                        }

                        // Whatever the message was, also store it as-is under its own request
                        // id, when it has one.
                        if let Some(id) = id {
                            responses_clone.write().await.insert(id, ws_resp.clone());
                        }
                    }
                    Err(e) => {
                        // Any read failure ends the reader task; the socket is marked inactive
                        // first so that code polling `socket_health` can stop waiting.
                        match &e {
                            WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
                            WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
                        }
                        *socket_health_tcp_read.write().await = SocketHealth::Inactive;
                        return Err(e);
                    }
                }
            }
        });

        Ok(EngineConnection {
            engine_req_tx,
            shutdown_tx,
            tcp_read_handle: Arc::new(TcpReadHandle {
                handle: Arc::new(tcp_read_handle),
            }),
            responses,
            pending_errors,
            socket_health,
            batch: Arc::new(RwLock::new(Vec::new())),
            batch_end: Arc::new(RwLock::new(IndexMap::new())),
            artifact_commands: Arc::new(RwLock::new(Vec::new())),
            default_planes: Default::default(),
            session_data,
            execution_kind: Default::default(),
            stats: Default::default(),
        })
    }
351}
352
353#[async_trait::async_trait]
354impl EngineManager for EngineConnection {
355    fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
356        self.batch.clone()
357    }
358
359    fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
360        self.batch_end.clone()
361    }
362
363    fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
364        self.responses.clone()
365    }
366
367    fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
368        self.artifact_commands.clone()
369    }
370
371    async fn execution_kind(&self) -> ExecutionKind {
372        let guard = self.execution_kind.read().await;
373        *guard
374    }
375
376    async fn replace_execution_kind(&self, execution_kind: ExecutionKind) -> ExecutionKind {
377        let mut guard = self.execution_kind.write().await;
378        let original = *guard;
379        *guard = execution_kind;
380        original
381    }
382
383    fn stats(&self) -> &EngineStats {
384        &self.stats
385    }
386
387    fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
388        self.default_planes.clone()
389    }
390
391    async fn clear_scene_post_hook(
392        &self,
393        id_generator: &mut IdGenerator,
394        source_range: SourceRange,
395    ) -> Result<(), KclError> {
396        // Remake the default planes, since they would have been removed after the scene was cleared.
397        let new_planes = self.new_default_planes(id_generator, source_range).await?;
398        *self.default_planes.write().await = Some(new_planes);
399
400        Ok(())
401    }
402
403    async fn inner_send_modeling_cmd(
404        &self,
405        id: uuid::Uuid,
406        source_range: SourceRange,
407        cmd: WebSocketRequest,
408        _id_to_source_range: HashMap<Uuid, SourceRange>,
409    ) -> Result<WebSocketResponse, KclError> {
410        let (tx, rx) = oneshot::channel();
411
412        // Send the request to the engine, via the actor.
413        self.engine_req_tx
414            .send(ToEngineReq {
415                req: cmd.clone(),
416                request_sent: tx,
417            })
418            .await
419            .map_err(|e| {
420                KclError::Engine(KclErrorDetails {
421                    message: format!("Failed to send modeling command: {}", e),
422                    source_ranges: vec![source_range],
423                })
424            })?;
425
426        // Wait for the request to be sent.
427        rx.await
428            .map_err(|e| {
429                KclError::Engine(KclErrorDetails {
430                    message: format!("could not send request to the engine actor: {e}"),
431                    source_ranges: vec![source_range],
432                })
433            })?
434            .map_err(|e| {
435                KclError::Engine(KclErrorDetails {
436                    message: format!("could not send request to the engine: {e}"),
437                    source_ranges: vec![source_range],
438                })
439            })?;
440
441        // Wait for the response.
442        let current_time = std::time::Instant::now();
443        while current_time.elapsed().as_secs() < 60 {
444            let guard = self.socket_health.read().await;
445            if *guard == SocketHealth::Inactive {
446                // Check if we have any pending errors.
447                let pe = self.pending_errors.read().await;
448                if !pe.is_empty() {
449                    return Err(KclError::Engine(KclErrorDetails {
450                        message: pe.join(", ").to_string(),
451                        source_ranges: vec![source_range],
452                    }));
453                } else {
454                    return Err(KclError::Engine(KclErrorDetails {
455                        message: "Modeling command failed: websocket closed early".to_string(),
456                        source_ranges: vec![source_range],
457                    }));
458                }
459            }
460            // We pop off the responses to cleanup our mappings.
461            if let Some(resp) = self.responses.write().await.shift_remove(&id) {
462                return Ok(resp);
463            }
464        }
465
466        Err(KclError::Engine(KclErrorDetails {
467            message: format!("Modeling command timed out `{}`", id),
468            source_ranges: vec![source_range],
469        }))
470    }
471
472    async fn get_session_data(&self) -> Option<ModelingSessionData> {
473        self.session_data.read().await.clone()
474    }
475
476    async fn close(&self) {
477        let _ = self.shutdown_tx.send(()).await;
478        loop {
479            let guard = self.socket_health.read().await;
480            if *guard == SocketHealth::Inactive {
481                return;
482            }
483        }
484    }
485}