1use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10 websocket::{
11 BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12 SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13 },
14 ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21#[cfg(feature = "artifact-graph")]
22use crate::execution::ArtifactCommand;
23use crate::{
24 engine::{AsyncTasks, EngineManager, EngineStats},
25 errors::{KclError, KclErrorDetails},
26 execution::{DefaultPlanes, IdGenerator},
27 SourceRange,
28};
29
/// Whether the engine WebSocket can still deliver responses.
///
/// A connection starts out `Active`; the background read task switches it to `Inactive`
/// when a read fails, after which that task stops.
#[derive(PartialEq, Debug)]
enum SocketHealth {
    /// The read task is still running.
    Active,
    /// A read failed and the read task has exited.
    Inactive,
}
35
36type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
37#[derive(Debug)]
38pub struct EngineConnection {
39 engine_req_tx: mpsc::Sender<ToEngineReq>,
40 shutdown_tx: mpsc::Sender<()>,
41 responses: ResponseInformation,
42 pending_errors: Arc<RwLock<Vec<String>>>,
43 #[allow(dead_code)]
44 tcp_read_handle: Arc<TcpReadHandle>,
45 socket_health: Arc<RwLock<SocketHealth>>,
46 batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
47 batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
48 #[cfg(feature = "artifact-graph")]
49 artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,
50 ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,
51
52 default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
54 session_data: Arc<RwLock<Option<ModelingSessionData>>>,
56
57 stats: EngineStats,
58
59 async_tasks: AsyncTasks,
60
61 debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
62}
63
64pub struct TcpRead {
65 stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
66}
67
68pub enum WebSocketReadError {
71 Read(tokio_tungstenite::tungstenite::Error),
73 Deser(anyhow::Error),
75}
76
77impl From<anyhow::Error> for WebSocketReadError {
78 fn from(e: anyhow::Error) -> Self {
79 Self::Deser(e)
80 }
81}
82
83impl TcpRead {
84 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
85 let Some(msg) = self.stream.next().await else {
86 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
87 };
88 let msg = match msg {
89 Ok(msg) => msg,
90 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
91 return Err(WebSocketReadError::Read(e))
92 }
93 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
94 };
95 let msg: WebSocketResponse = match msg {
96 WsMsg::Text(text) => serde_json::from_str(&text)
97 .map_err(anyhow::Error::from)
98 .map_err(WebSocketReadError::from)?,
99 WsMsg::Binary(bin) => bson::from_slice(&bin)
100 .map_err(anyhow::Error::from)
101 .map_err(WebSocketReadError::from)?,
102 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
103 };
104 Ok(msg)
105 }
106}
107
108pub struct TcpReadHandle {
109 handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
110}
111
112impl std::fmt::Debug for TcpReadHandle {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "TcpReadHandle")
115 }
116}
117
118impl Drop for TcpReadHandle {
119 fn drop(&mut self) {
120 self.handle.abort();
122 }
123}
124
125#[derive(Clone, Debug)]
127struct ResponseInformation {
128 responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
130}
131
132impl ResponseInformation {
133 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
134 self.responses.write().await.insert(id, response);
135 }
136}
137
138struct ToEngineReq {
140 req: WebSocketRequest,
142 request_sent: oneshot::Sender<Result<()>>,
146}
147
148impl EngineConnection {
149 async fn start_write_actor(
151 mut tcp_write: WebSocketTcpWrite,
152 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
153 mut shutdown_rx: mpsc::Receiver<()>,
154 ) {
155 loop {
156 tokio::select! {
157 maybe_req = engine_req_rx.recv() => {
158 match maybe_req {
159 Some(ToEngineReq { req, request_sent }) => {
160 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
163 cmd: ModelingCmd::ImportFiles { .. },
164 cmd_id: _,
165 }) = &req
166 {
167 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
168 } else {
169 Self::inner_send_to_engine(req, &mut tcp_write).await
170 };
171
172 let _ = request_sent.send(res);
174 }
175 None => {
176 break;
179 }
180 }
181 },
182
183 _ = shutdown_rx.recv() => {
185 let _ = Self::inner_close_engine(&mut tcp_write).await;
186 return;
187 }
188 }
189 }
190
191 let _ = Self::inner_close_engine(&mut tcp_write).await;
194 }
195
196 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
198 tcp_write
199 .send(WsMsg::Close(None))
200 .await
201 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
202 Ok(())
203 }
204
205 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
207 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
208 tcp_write
209 .send(WsMsg::Text(msg))
210 .await
211 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
212 Ok(())
213 }
214
215 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
217 let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
218 tcp_write
219 .send(WsMsg::Binary(msg))
220 .await
221 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
222 Ok(())
223 }
224
225 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
226 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
227 max_message_size: Some(usize::MAX),
229 max_frame_size: Some(usize::MAX),
230 ..Default::default()
231 };
232
233 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
234 ws,
235 tokio_tungstenite::tungstenite::protocol::Role::Client,
236 Some(wsconfig),
237 )
238 .await;
239
240 let (tcp_write, tcp_read) = ws_stream.split();
241 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
242 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
243 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
244
245 let mut tcp_read = TcpRead { stream: tcp_read };
246
247 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
248 let session_data2 = session_data.clone();
249 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
250 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
251 let pending_errors = Arc::new(RwLock::new(Vec::new()));
252 let pending_errors_clone = pending_errors.clone();
253 let response_information = ResponseInformation {
254 responses: Arc::new(RwLock::new(IndexMap::new())),
255 };
256 let response_information_cloned = response_information.clone();
257 let debug_info = Arc::new(RwLock::new(None));
258 let debug_info_cloned = debug_info.clone();
259
260 let socket_health_tcp_read = socket_health.clone();
261 let tcp_read_handle = tokio::spawn(async move {
262 loop {
264 match tcp_read.read().await {
265 Ok(ws_resp) => {
266 let id = ws_resp.request_id();
268 match &ws_resp {
269 WebSocketResponse::Success(SuccessWebSocketResponse {
270 resp: OkWebSocketResponseData::ModelingBatch { responses },
271 ..
272 }) => {
273 #[expect(
274 clippy::iter_over_hash_type,
275 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
276 )]
277 for (resp_id, batch_response) in responses {
278 let id: uuid::Uuid = (*resp_id).into();
279 match batch_response {
280 BatchResponse::Success { response } => {
281 response_information_cloned
284 .add(
285 id,
286 WebSocketResponse::Success(SuccessWebSocketResponse {
287 success: true,
288 request_id: Some(id),
289 resp: OkWebSocketResponseData::Modeling {
290 modeling_response: response.clone(),
291 },
292 }),
293 )
294 .await;
295 }
296 BatchResponse::Failure { errors } => {
297 response_information_cloned
298 .add(
299 id,
300 WebSocketResponse::Failure(FailureWebSocketResponse {
301 success: false,
302 request_id: Some(id),
303 errors: errors.clone(),
304 }),
305 )
306 .await;
307 }
308 }
309 }
310 }
311 WebSocketResponse::Success(SuccessWebSocketResponse {
312 resp: OkWebSocketResponseData::ModelingSessionData { session },
313 ..
314 }) => {
315 let mut sd = session_data2.write().await;
316 sd.replace(session.clone());
317 }
318 WebSocketResponse::Failure(FailureWebSocketResponse {
319 success: _,
320 request_id,
321 errors,
322 }) => {
323 if let Some(id) = request_id {
324 response_information_cloned
325 .add(
326 *id,
327 WebSocketResponse::Failure(FailureWebSocketResponse {
328 success: false,
329 request_id: *request_id,
330 errors: errors.clone(),
331 }),
332 )
333 .await;
334 } else {
335 let mut pe = pending_errors_clone.write().await;
337 for error in errors {
338 if !pe.contains(&error.message) {
339 pe.push(error.message.clone());
340 }
341 }
342 drop(pe);
343 }
344 }
345 WebSocketResponse::Success(SuccessWebSocketResponse {
346 resp: debug @ OkWebSocketResponseData::Debug { .. },
347 ..
348 }) => {
349 let mut handle = debug_info_cloned.write().await;
350 *handle = Some(debug.clone());
351 }
352 _ => {}
353 }
354
355 if let Some(id) = id {
356 response_information_cloned.add(id, ws_resp.clone()).await;
357 }
358 }
359 Err(e) => {
360 match &e {
361 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
362 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
363 }
364 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
365 return Err(e);
366 }
367 }
368 }
369 });
370
371 Ok(EngineConnection {
372 engine_req_tx,
373 shutdown_tx,
374 tcp_read_handle: Arc::new(TcpReadHandle {
375 handle: Arc::new(tcp_read_handle),
376 }),
377 responses: response_information,
378 pending_errors,
379 socket_health,
380 batch: Arc::new(RwLock::new(Vec::new())),
381 batch_end: Arc::new(RwLock::new(IndexMap::new())),
382 #[cfg(feature = "artifact-graph")]
383 artifact_commands: Arc::new(RwLock::new(Vec::new())),
384 ids_of_async_commands,
385 default_planes: Default::default(),
386 session_data,
387 stats: Default::default(),
388 async_tasks: AsyncTasks::new(),
389 debug_info,
390 })
391 }
392}
393
394#[async_trait::async_trait]
395impl EngineManager for EngineConnection {
396 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
397 self.batch.clone()
398 }
399
400 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
401 self.batch_end.clone()
402 }
403
404 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
405 self.responses.responses.clone()
406 }
407
408 #[cfg(feature = "artifact-graph")]
409 fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
410 self.artifact_commands.clone()
411 }
412
413 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
414 self.ids_of_async_commands.clone()
415 }
416
417 fn async_tasks(&self) -> AsyncTasks {
418 self.async_tasks.clone()
419 }
420
421 fn stats(&self) -> &EngineStats {
422 &self.stats
423 }
424
425 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
426 self.default_planes.clone()
427 }
428
429 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
430 self.debug_info.read().await.clone()
431 }
432
433 async fn fetch_debug(&self) -> Result<(), KclError> {
434 let (tx, rx) = oneshot::channel();
435
436 self.engine_req_tx
437 .send(ToEngineReq {
438 req: WebSocketRequest::Debug {},
439 request_sent: tx,
440 })
441 .await
442 .map_err(|e| KclError::Engine(KclErrorDetails::new(format!("Failed to send debug: {}", e), vec![])))?;
443
444 let _ = rx.await;
445 Ok(())
446 }
447
448 async fn clear_scene_post_hook(
449 &self,
450 id_generator: &mut IdGenerator,
451 source_range: SourceRange,
452 ) -> Result<(), KclError> {
453 let new_planes = self.new_default_planes(id_generator, source_range).await?;
455 *self.default_planes.write().await = Some(new_planes);
456
457 Ok(())
458 }
459
460 async fn inner_fire_modeling_cmd(
461 &self,
462 _id: uuid::Uuid,
463 source_range: SourceRange,
464 cmd: WebSocketRequest,
465 _id_to_source_range: HashMap<Uuid, SourceRange>,
466 ) -> Result<(), KclError> {
467 let (tx, rx) = oneshot::channel();
468
469 self.engine_req_tx
471 .send(ToEngineReq {
472 req: cmd.clone(),
473 request_sent: tx,
474 })
475 .await
476 .map_err(|e| {
477 KclError::Engine(KclErrorDetails::new(
478 format!("Failed to send modeling command: {}", e),
479 vec![source_range],
480 ))
481 })?;
482
483 rx.await
485 .map_err(|e| {
486 KclError::Engine(KclErrorDetails::new(
487 format!("could not send request to the engine actor: {e}"),
488 vec![source_range],
489 ))
490 })?
491 .map_err(|e| {
492 KclError::Engine(KclErrorDetails::new(
493 format!("could not send request to the engine: {e}"),
494 vec![source_range],
495 ))
496 })?;
497
498 Ok(())
499 }
500
501 async fn inner_send_modeling_cmd(
502 &self,
503 id: uuid::Uuid,
504 source_range: SourceRange,
505 cmd: WebSocketRequest,
506 id_to_source_range: HashMap<Uuid, SourceRange>,
507 ) -> Result<WebSocketResponse, KclError> {
508 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
509 .await?;
510
511 let current_time = std::time::Instant::now();
513 while current_time.elapsed().as_secs() < 60 {
514 let guard = self.socket_health.read().await;
515 if *guard == SocketHealth::Inactive {
516 let pe = self.pending_errors.read().await;
518 if !pe.is_empty() {
519 return Err(KclError::Engine(KclErrorDetails::new(
520 pe.join(", ").to_string(),
521 vec![source_range],
522 )));
523 } else {
524 return Err(KclError::Engine(KclErrorDetails::new(
525 "Modeling command failed: websocket closed early".to_string(),
526 vec![source_range],
527 )));
528 }
529 }
530
531 #[cfg(feature = "artifact-graph")]
532 {
533 if let Some(resp) = self.responses.responses.read().await.get(&id) {
535 return Ok(resp.clone());
536 }
537 }
538 #[cfg(not(feature = "artifact-graph"))]
539 {
540 if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
541 return Ok(resp);
542 }
543 }
544 }
545
546 Err(KclError::Engine(KclErrorDetails::new(
547 format!("Modeling command timed out `{}`", id),
548 vec![source_range],
549 )))
550 }
551
552 async fn get_session_data(&self) -> Option<ModelingSessionData> {
553 self.session_data.read().await.clone()
554 }
555
556 async fn close(&self) {
557 let _ = self.shutdown_tx.send(()).await;
558 loop {
559 let guard = self.socket_health.read().await;
560 if *guard == SocketHealth::Inactive {
561 return;
562 }
563 }
564 }
565}