Skip to main content

kcl_lib/engine/
conn.rs

1//! Functions for setting up our WebSocket and WebRTC connections for communications with the
2//! engine.
3
4use std::collections::HashMap;
5use std::sync::Arc;
6use std::time::Duration;
7
8use anyhow::Result;
9use anyhow::anyhow;
10use futures::SinkExt;
11use futures::StreamExt;
12use indexmap::IndexMap;
13use kcmc::ModelingCmd;
14use kcmc::websocket::BatchResponse;
15use kcmc::websocket::FailureWebSocketResponse;
16use kcmc::websocket::ModelingCmdReq;
17use kcmc::websocket::ModelingSessionData;
18use kcmc::websocket::OkWebSocketResponseData;
19use kcmc::websocket::SuccessWebSocketResponse;
20use kcmc::websocket::WebSocketRequest;
21use kcmc::websocket::WebSocketResponse;
22use kittycad_modeling_cmds::{self as kcmc};
23use tokio::sync::RwLock;
24use tokio::sync::mpsc;
25use tokio::sync::oneshot;
26use tokio_tungstenite::tungstenite::Message as WsMsg;
27use uuid::Uuid;
28
29use crate::SourceRange;
30use crate::engine::AsyncTasks;
31use crate::engine::EngineBatchContext;
32use crate::engine::EngineManager;
33use crate::engine::EngineStats;
34use crate::errors::KclError;
35use crate::errors::KclErrorDetails;
36use crate::execution::DefaultPlanes;
37use crate::execution::IdGenerator;
38use crate::log::logln;
39
/// Whether the engine WebSocket is still considered usable.
///
/// The reader task flips this to `Inactive` as soon as a read from the socket fails.
#[derive(Debug, PartialEq)]
enum SocketHealth {
    /// Reads are still succeeding.
    Active,
    /// A read failed; treat the connection as closed.
    Inactive,
}
45
46type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
47#[derive(Debug)]
48pub struct EngineConnection {
49    engine_req_tx: mpsc::Sender<ToEngineReq>,
50    shutdown_tx: mpsc::Sender<()>,
51    responses: ResponseInformation,
52    pending_errors: Arc<RwLock<Vec<String>>>,
53    #[allow(dead_code)]
54    tcp_read_handle: Arc<TcpReadHandle>,
55    socket_health: Arc<RwLock<SocketHealth>>,
56    ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
57
58    /// The default planes for the scene.
59    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
60    /// If the server sends session data, it'll be copied to here.
61    session_data: Arc<RwLock<Option<ModelingSessionData>>>,
62
63    stats: EngineStats,
64
65    async_tasks: AsyncTasks,
66
67    debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
68}
69
70pub struct TcpRead {
71    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
72}
73
74/// Occurs when client couldn't read from the WebSocket to the engine.
75// #[derive(Debug)]
76#[allow(clippy::large_enum_variant)]
77pub enum WebSocketReadError {
78    /// Could not read a message due to WebSocket errors.
79    Read(tokio_tungstenite::tungstenite::Error),
80    /// WebSocket message didn't contain a valid message that the KCL Executor could parse.
81    Deser(anyhow::Error),
82}
83
84impl From<anyhow::Error> for WebSocketReadError {
85    fn from(e: anyhow::Error) -> Self {
86        Self::Deser(e)
87    }
88}
89
90impl TcpRead {
91    pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
92        let Some(msg) = self.stream.next().await else {
93            return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
94        };
95        let msg = match msg {
96            Ok(msg) => msg,
97            Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
98                return Err(WebSocketReadError::Read(e));
99            }
100            Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
101        };
102        let msg: WebSocketResponse = match msg {
103            WsMsg::Text(text) => serde_json::from_str(&text)
104                .map_err(anyhow::Error::from)
105                .map_err(WebSocketReadError::from)?,
106            WsMsg::Binary(bin) => rmp_serde::from_slice(&bin)
107                .map_err(anyhow::Error::from)
108                .map_err(WebSocketReadError::from)?,
109            other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
110        };
111        Ok(msg)
112    }
113}
114
115pub struct TcpReadHandle {
116    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
117}
118
119impl std::fmt::Debug for TcpReadHandle {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(f, "TcpReadHandle")
122    }
123}
124
125impl Drop for TcpReadHandle {
126    fn drop(&mut self) {
127        // Drop the read handle.
128        self.handle.abort();
129    }
130}
131
132/// Information about the responses from the engine.
133#[derive(Clone, Debug)]
134struct ResponseInformation {
135    /// The responses from the engine.
136    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
137}
138
139impl ResponseInformation {
140    pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
141        self.responses.write().await.insert(id, response);
142    }
143}
144
145/// Requests to send to the engine, and a way to await a response.
146struct ToEngineReq {
147    /// The request to send
148    req: WebSocketRequest,
149    /// If this resolves to Ok, the request was sent.
150    /// If this resolves to Err, the request could not be sent.
151    /// If this has not yet resolved, the request has not been sent yet.
152    request_sent: oneshot::Sender<Result<()>>,
153}
154
155impl EngineConnection {
156    /// Start waiting for incoming engine requests, and send each one over the WebSocket to the engine.
157    /// If heartbeats is Some(N), sends a heartbeat to keep the WebSocket active, every N seconds.
158    /// If None, no heartbeats will be sent.
159    async fn start_write_actor(
160        mut tcp_write: WebSocketTcpWrite,
161        mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
162        mut shutdown_rx: mpsc::Receiver<()>,
163        heartbeats: Option<u64>,
164    ) {
165        let heartbeats = heartbeats.unwrap_or_default();
166        let send_heartbeats = heartbeats != 0;
167        let period_seconds = if heartbeats == 0 { 5 * 60 } else { heartbeats };
168        let period = Duration::from_secs(period_seconds);
169        let mut heartbeats_stream = tokio::time::interval(period);
170
171        loop {
172            tokio::select! {
173                maybe_req = engine_req_rx.recv() => {
174                    match maybe_req {
175                        Some(ToEngineReq { req, request_sent }) => {
176                            // Decide whether to send as binary or text,
177                            // then send to the engine.
178                            let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
179                                cmd: ModelingCmd::ImportFiles { .. },
180                                cmd_id: _,
181                            }) = &req
182                            {
183                                Self::inner_send_to_engine_binary(req, &mut tcp_write).await
184                            } else {
185                                Self::inner_send_to_engine(req, &mut tcp_write).await
186                            };
187
188                            // Let the caller know we’ve sent the request (ok or error).
189                            let _ = request_sent.send(res);
190                        }
191                        None => {
192                            // The engine_req_rx channel has closed, so no more requests.
193                            // We'll gracefully exit the loop and close the engine.
194                            break;
195                        }
196                    }
197                },
198
199                // If we get a shutdown signal, close the engine immediately and return.
200                _ = shutdown_rx.recv() => {
201                    let _ = Self::inner_close_engine(&mut tcp_write).await;
202                    return;
203                }
204
205                // Send heartbeats periodically.
206                _ = heartbeats_stream.tick(), if send_heartbeats => {
207                    // Send a heartbeat.
208                    let res = Self::inner_send_to_engine(WebSocketRequest::Ping {}, &mut tcp_write).await;
209                    // We don't really care if a heartbeat fails, we'll just try again soon.
210                    let _ = res;
211                }
212            }
213        }
214
215        // If we exit the loop (e.g. engine_req_rx was closed),
216        // still gracefully close the engine before returning.
217        let _ = Self::inner_close_engine(&mut tcp_write).await;
218    }
219
220    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
221    async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
222        tcp_write
223            .send(WsMsg::Close(None))
224            .await
225            .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
226        Ok(())
227    }
228
229    /// Send the given `request` to the engine via the WebSocket connection `tcp_write`.
230    async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
231        let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
232        tcp_write
233            .send(WsMsg::Text(msg.into()))
234            .await
235            .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
236        Ok(())
237    }
238
239    /// Send the given `request` to the engine via the WebSocket connection `tcp_write` as binary.
240    async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
241        let msg = rmp_serde::to_vec_named(&request).map_err(|e| anyhow!("could not serialize msgpack: {e}"))?;
242        tcp_write
243            .send(WsMsg::Binary(msg.into()))
244            .await
245            .map_err(|e| anyhow!("could not send MsgPack over websocket: {e}"))?;
246        Ok(())
247    }
248
249    pub async fn new(ws: reqwest::Upgraded, heartbeats: Option<u64>) -> Result<EngineConnection> {
250        let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig::default()
251            // 4294967296 bytes, which is around 4.2 GB.
252            .max_message_size(Some(usize::MAX))
253            .max_frame_size(Some(usize::MAX));
254
255        let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
256            ws,
257            tokio_tungstenite::tungstenite::protocol::Role::Client,
258            Some(wsconfig),
259        )
260        .await;
261
262        let (tcp_write, tcp_read) = ws_stream.split();
263        let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
264        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
265        tokio::task::spawn(Self::start_write_actor(
266            tcp_write,
267            engine_req_rx,
268            shutdown_rx,
269            heartbeats,
270        ));
271
272        let mut tcp_read = TcpRead { stream: tcp_read };
273
274        let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
275        let session_data2 = session_data.clone();
276        let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
277        let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
278        let pending_errors = Arc::new(RwLock::new(Vec::new()));
279        let pending_errors_clone = pending_errors.clone();
280        let response_information = ResponseInformation {
281            responses: Arc::new(RwLock::new(IndexMap::new())),
282        };
283        let response_information_cloned = response_information.clone();
284        let debug_info = Arc::new(RwLock::new(None));
285        let debug_info_cloned = debug_info.clone();
286
287        let socket_health_tcp_read = socket_health.clone();
288        let tcp_read_handle = tokio::spawn(async move {
289            // Get Websocket messages from API server
290            loop {
291                match tcp_read.read().await {
292                    Ok(ws_resp) => {
293                        // If we got a batch response, add all the inner responses.
294                        let id = ws_resp.request_id();
295                        match &ws_resp {
296                            WebSocketResponse::Success(SuccessWebSocketResponse {
297                                resp: OkWebSocketResponseData::ModelingBatch { responses },
298                                ..
299                            }) => {
300                                #[expect(
301                                    clippy::iter_over_hash_type,
302                                    reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
303                                )]
304                                for (resp_id, batch_response) in responses {
305                                    let id: uuid::Uuid = (*resp_id).into();
306                                    match batch_response {
307                                        BatchResponse::Success { response } => {
308                                            // If the id is in our ids of async commands, remove
309                                            // it.
310                                            response_information_cloned
311                                                .add(
312                                                    id,
313                                                    WebSocketResponse::Success(SuccessWebSocketResponse {
314                                                        success: true,
315                                                        request_id: Some(id),
316                                                        resp: OkWebSocketResponseData::Modeling {
317                                                            modeling_response: response.clone(),
318                                                        },
319                                                    }),
320                                                )
321                                                .await;
322                                        }
323                                        BatchResponse::Failure { errors } => {
324                                            response_information_cloned
325                                                .add(
326                                                    id,
327                                                    WebSocketResponse::Failure(FailureWebSocketResponse {
328                                                        success: false,
329                                                        request_id: Some(id),
330                                                        errors: errors.clone(),
331                                                    }),
332                                                )
333                                                .await;
334                                        }
335                                    }
336                                }
337                            }
338                            WebSocketResponse::Success(SuccessWebSocketResponse {
339                                resp: OkWebSocketResponseData::ModelingSessionData { session },
340                                ..
341                            }) => {
342                                let mut sd = session_data2.write().await;
343                                sd.replace(session.clone());
344                                logln!("API Call ID: {}", session.api_call_id);
345                            }
346                            WebSocketResponse::Failure(FailureWebSocketResponse {
347                                success: _,
348                                request_id,
349                                errors,
350                            }) => {
351                                if let Some(id) = request_id {
352                                    response_information_cloned
353                                        .add(
354                                            *id,
355                                            WebSocketResponse::Failure(FailureWebSocketResponse {
356                                                success: false,
357                                                request_id: *request_id,
358                                                errors: errors.clone(),
359                                            }),
360                                        )
361                                        .await;
362                                } else {
363                                    // Add it to our pending errors.
364                                    let mut pe = pending_errors_clone.write().await;
365                                    for error in errors {
366                                        if !pe.contains(&error.message) {
367                                            pe.push(error.message.clone());
368                                        }
369                                    }
370                                    drop(pe);
371                                }
372                            }
373                            WebSocketResponse::Success(SuccessWebSocketResponse {
374                                resp: debug @ OkWebSocketResponseData::Debug { .. },
375                                ..
376                            }) => {
377                                let mut handle = debug_info_cloned.write().await;
378                                *handle = Some(debug.clone());
379                            }
380                            _ => {}
381                        }
382
383                        if let Some(id) = id {
384                            response_information_cloned.add(id, ws_resp.clone()).await;
385                        }
386                    }
387                    Err(e) => {
388                        match &e {
389                            WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
390                            WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
391                        }
392                        *socket_health_tcp_read.write().await = SocketHealth::Inactive;
393                        return Err(e);
394                    }
395                }
396            }
397        });
398
399        Ok(EngineConnection {
400            engine_req_tx,
401            shutdown_tx,
402            tcp_read_handle: Arc::new(TcpReadHandle {
403                handle: Arc::new(tcp_read_handle),
404            }),
405            responses: response_information,
406            pending_errors,
407            socket_health,
408            ids_of_async_commands,
409            default_planes: Default::default(),
410            session_data,
411            stats: Default::default(),
412            async_tasks: AsyncTasks::new(),
413            debug_info,
414        })
415    }
416}
417
418#[async_trait::async_trait]
419impl EngineManager for EngineConnection {
420    fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
421        self.responses.responses.clone()
422    }
423
424    fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
425        self.ids_of_async_commands.clone()
426    }
427
428    fn async_tasks(&self) -> AsyncTasks {
429        self.async_tasks.clone()
430    }
431
432    fn stats(&self) -> &EngineStats {
433        &self.stats
434    }
435
436    fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
437        self.default_planes.clone()
438    }
439
440    async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
441        self.debug_info.read().await.clone()
442    }
443
444    async fn fetch_debug(&self) -> Result<(), KclError> {
445        let (tx, rx) = oneshot::channel();
446
447        self.engine_req_tx
448            .send(ToEngineReq {
449                req: WebSocketRequest::Debug {},
450                request_sent: tx,
451            })
452            .await
453            .map_err(|e| KclError::new_engine(KclErrorDetails::new(format!("Failed to send debug: {e}"), vec![])))?;
454
455        let _ = rx.await;
456        Ok(())
457    }
458
459    async fn clear_scene_post_hook(
460        &self,
461        batch_context: &EngineBatchContext,
462        id_generator: &mut IdGenerator,
463        source_range: SourceRange,
464    ) -> Result<(), KclError> {
465        // Remake the default planes, since they would have been removed after the scene was cleared.
466        let new_planes = self
467            .new_default_planes(batch_context, id_generator, source_range)
468            .await?;
469        *self.default_planes.write().await = Some(new_planes);
470
471        Ok(())
472    }
473
474    async fn inner_fire_modeling_cmd(
475        &self,
476        _id: uuid::Uuid,
477        source_range: SourceRange,
478        cmd: WebSocketRequest,
479        _id_to_source_range: HashMap<Uuid, SourceRange>,
480    ) -> Result<(), KclError> {
481        let (tx, rx) = oneshot::channel();
482
483        // Send the request to the engine, via the actor.
484        self.engine_req_tx
485            .send(ToEngineReq {
486                req: cmd.clone(),
487                request_sent: tx,
488            })
489            .await
490            .map_err(|e| {
491                KclError::new_engine(KclErrorDetails::new(
492                    format!("Failed to send modeling command: {e}"),
493                    vec![source_range],
494                ))
495            })?;
496
497        // Wait for the request to be sent.
498        rx.await
499            .map_err(|e| {
500                KclError::new_engine_hangup(
501                    KclErrorDetails::new(
502                        format!("could not send request to the engine actor: {e}"),
503                        vec![source_range],
504                    ),
505                    None,
506                )
507            })?
508            .map_err(|e| {
509                KclError::new_engine_hangup(
510                    KclErrorDetails::new(format!("could not send request to the engine: {e}"), vec![source_range]),
511                    None,
512                )
513            })?;
514
515        Ok(())
516    }
517
518    async fn inner_send_modeling_cmd(
519        &self,
520        id: uuid::Uuid,
521        source_range: SourceRange,
522        cmd: WebSocketRequest,
523        id_to_source_range: HashMap<Uuid, SourceRange>,
524    ) -> Result<WebSocketResponse, KclError> {
525        self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
526            .await?;
527
528        // Wait for the response.
529        let response_timeout = 300;
530        let current_time = std::time::Instant::now();
531        while current_time.elapsed().as_secs() < response_timeout {
532            let guard = self.socket_health.read().await;
533            if *guard == SocketHealth::Inactive {
534                // Get the API call ID from session data if available
535                let session_data = self.session_data.read().await;
536                let api_call_id = session_data.as_ref().map(|session| session.api_call_id.to_string());
537                let api_call_id_msg = if let Some(ref id) = api_call_id {
538                    format!(" (API call ID: {})", id)
539                } else {
540                    String::new()
541                };
542
543                // Check if we have any pending errors.
544                let pe = self.pending_errors.read().await;
545                if !pe.is_empty() {
546                    return Err(KclError::new_engine(KclErrorDetails::new(
547                        format!("{}{}", pe.join(", "), api_call_id_msg),
548                        vec![source_range],
549                    )));
550                } else {
551                    return Err(KclError::new_engine_hangup(
552                        KclErrorDetails::new(
553                            format!("Modeling command failed: websocket closed early{}", api_call_id_msg),
554                            vec![source_range],
555                        ),
556                        api_call_id,
557                    ));
558                }
559            }
560
561            // We cannot pop here or it will break the artifact graph.
562            if let Some(resp) = self.responses.responses.read().await.get(&id) {
563                return Ok(resp.clone());
564            }
565        }
566
567        // Get the API call ID from session data if available for timeout error
568        let session_data = self.session_data.read().await;
569        let api_call_id_msg = if let Some(session) = session_data.as_ref() {
570            format!(" (API call ID: {})", session.api_call_id)
571        } else {
572            String::new()
573        };
574
575        Err(KclError::new_engine(KclErrorDetails::new(
576            format!("Modeling command timed out `{id}`{}", api_call_id_msg),
577            vec![source_range],
578        )))
579    }
580
581    async fn get_session_data(&self) -> Option<ModelingSessionData> {
582        self.session_data.read().await.clone()
583    }
584
585    async fn close(&self) {
586        let _ = self.shutdown_tx.send(()).await;
587        loop {
588            let guard = self.socket_health.read().await;
589            if *guard == SocketHealth::Inactive {
590                return;
591            }
592        }
593    }
594}