1use std::{collections::HashMap, sync::Arc};
5
6use anyhow::{anyhow, Result};
7use futures::{SinkExt, StreamExt};
8use indexmap::IndexMap;
9use kcmc::{
10 websocket::{
11 BatchResponse, FailureWebSocketResponse, ModelingCmdReq, ModelingSessionData, OkWebSocketResponseData,
12 SuccessWebSocketResponse, WebSocketRequest, WebSocketResponse,
13 },
14 ModelingCmd,
15};
16use kittycad_modeling_cmds::{self as kcmc};
17use tokio::sync::{mpsc, oneshot, RwLock};
18use tokio_tungstenite::tungstenite::Message as WsMsg;
19use uuid::Uuid;
20
21#[cfg(feature = "artifact-graph")]
22use crate::execution::ArtifactCommand;
23use crate::{
24 engine::{AsyncTasks, EngineManager, EngineStats},
25 errors::{KclError, KclErrorDetails},
26 execution::{DefaultPlanes, IdGenerator},
27 SourceRange,
28};
29
/// Health of the engine WebSocket as observed by the background read task.
///
/// The connection starts out `Active`; the read task switches it to `Inactive`
/// after its first read failure, and nothing in this file switches it back.
#[derive(Debug, PartialEq)]
enum SocketHealth {
    /// The read task is still receiving messages.
    Active,
    /// The read task hit an error and has stopped.
    Inactive,
}
35
/// Write half of the engine WebSocket running over the upgraded `reqwest` connection.
type WebSocketTcpWrite = futures::stream::SplitSink<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>, WsMsg>;
/// A live connection to the modeling engine over a WebSocket.
///
/// Writes go through a dedicated actor task, fed by `engine_req_tx`. Reads are
/// handled by a background task that records responses into the shared state held
/// here.
#[derive(Debug)]
pub struct EngineConnection {
    /// Queue of outgoing requests consumed by the write actor.
    engine_req_tx: mpsc::Sender<ToEngineReq>,
    /// Signals the write actor to send a Close frame and stop.
    shutdown_tx: mpsc::Sender<()>,
    /// Responses received from the engine, keyed by request id.
    responses: ResponseInformation,
    /// Messages from failure responses that carried no request id; they are
    /// reported once the socket becomes inactive.
    pending_errors: Arc<RwLock<Vec<String>>>,
    // Never read: held only so that dropping the connection drops the handle,
    // whose `Drop` impl aborts the background read task.
    #[allow(dead_code)]
    tcp_read_handle: Arc<TcpReadHandle>,
    /// Set to `Inactive` by the read task once reading fails.
    socket_health: Arc<RwLock<SocketHealth>>,
    // NOTE(review): `batch` / `batch_end` are only handed out via `EngineManager`;
    // how they are filled and flushed lives outside this file — confirm there.
    batch: Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>>,
    batch_end: Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>>,
    // Presumably populated by `EngineManager` default methods — verify.
    #[cfg(feature = "artifact-graph")]
    artifact_commands: Arc<RwLock<Vec<ArtifactCommand>>>,
    // Presumably ids of commands fired without waiting for a response — verify.
    ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>>,

    /// Cached default planes, replaced after each scene clear.
    default_planes: Arc<RwLock<Option<DefaultPlanes>>>,
    /// Most recent session data reported by the engine.
    session_data: Arc<RwLock<Option<ModelingSessionData>>>,

    stats: EngineStats,

    async_tasks: AsyncTasks,

    /// Most recent `Debug` response received from the engine.
    debug_info: Arc<RwLock<Option<OkWebSocketResponseData>>>,
}
63
/// Read half of the engine WebSocket, obtained by splitting the stream in
/// `EngineConnection::new`.
pub struct TcpRead {
    stream: futures::stream::SplitStream<tokio_tungstenite::WebSocketStream<reqwest::Upgraded>>,
}
67
68pub enum WebSocketReadError {
71 Read(tokio_tungstenite::tungstenite::Error),
73 Deser(anyhow::Error),
75}
76
77impl From<anyhow::Error> for WebSocketReadError {
78 fn from(e: anyhow::Error) -> Self {
79 Self::Deser(e)
80 }
81}
82
83impl TcpRead {
84 pub async fn read(&mut self) -> std::result::Result<WebSocketResponse, WebSocketReadError> {
85 let Some(msg) = self.stream.next().await else {
86 return Err(anyhow::anyhow!("Failed to read from WebSocket").into());
87 };
88 let msg = match msg {
89 Ok(msg) => msg,
90 Err(e) if matches!(e, tokio_tungstenite::tungstenite::Error::Protocol(_)) => {
91 return Err(WebSocketReadError::Read(e))
92 }
93 Err(e) => return Err(anyhow::anyhow!("Error reading from engine's WebSocket: {e}").into()),
94 };
95 let msg: WebSocketResponse = match msg {
96 WsMsg::Text(text) => serde_json::from_str(&text)
97 .map_err(anyhow::Error::from)
98 .map_err(WebSocketReadError::from)?,
99 WsMsg::Binary(bin) => bson::from_slice(&bin)
100 .map_err(anyhow::Error::from)
101 .map_err(WebSocketReadError::from)?,
102 other => return Err(anyhow::anyhow!("Unexpected WebSocket message from engine API: {other}").into()),
103 };
104 Ok(msg)
105 }
106}
107
/// Owns the `JoinHandle` of the background read task spawned in
/// `EngineConnection::new`; dropping this value aborts that task.
pub struct TcpReadHandle {
    handle: Arc<tokio::task::JoinHandle<Result<(), WebSocketReadError>>>,
}
111
112impl std::fmt::Debug for TcpReadHandle {
113 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 write!(f, "TcpReadHandle")
115 }
116}
117
118impl Drop for TcpReadHandle {
119 fn drop(&mut self) {
120 self.handle.abort();
122 }
123}
124
/// Shared map from request id to the engine's response for that request.
///
/// Cloning shares the same underlying map (it is behind an `Arc`). The read task
/// writes into it, and `EngineConnection` reads from it.
#[derive(Clone, Debug)]
struct ResponseInformation {
    responses: Arc<RwLock<IndexMap<uuid::Uuid, WebSocketResponse>>>,
}
131
132impl ResponseInformation {
133 pub async fn add(&self, id: Uuid, response: WebSocketResponse) {
134 self.responses.write().await.insert(id, response);
135 }
136}
137
/// A unit of work for the write actor.
struct ToEngineReq {
    /// The request to serialize and write to the WebSocket.
    req: WebSocketRequest,
    /// Receives the outcome of the write (not the engine's response).
    request_sent: oneshot::Sender<Result<()>>,
}
147
148impl EngineConnection {
149 async fn start_write_actor(
151 mut tcp_write: WebSocketTcpWrite,
152 mut engine_req_rx: mpsc::Receiver<ToEngineReq>,
153 mut shutdown_rx: mpsc::Receiver<()>,
154 ) {
155 loop {
156 tokio::select! {
157 maybe_req = engine_req_rx.recv() => {
158 match maybe_req {
159 Some(ToEngineReq { req, request_sent }) => {
160 let res = if let WebSocketRequest::ModelingCmdReq(ModelingCmdReq {
163 cmd: ModelingCmd::ImportFiles { .. },
164 cmd_id: _,
165 }) = &req
166 {
167 Self::inner_send_to_engine_binary(req, &mut tcp_write).await
168 } else {
169 Self::inner_send_to_engine(req, &mut tcp_write).await
170 };
171
172 let _ = request_sent.send(res);
174 }
175 None => {
176 break;
179 }
180 }
181 },
182
183 _ = shutdown_rx.recv() => {
185 let _ = Self::inner_close_engine(&mut tcp_write).await;
186 return;
187 }
188 }
189 }
190
191 let _ = Self::inner_close_engine(&mut tcp_write).await;
194 }
195
196 async fn inner_close_engine(tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
198 tcp_write
199 .send(WsMsg::Close(None))
200 .await
201 .map_err(|e| anyhow!("could not send close over websocket: {e}"))?;
202 Ok(())
203 }
204
205 async fn inner_send_to_engine(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
207 let msg = serde_json::to_string(&request).map_err(|e| anyhow!("could not serialize json: {e}"))?;
208 tcp_write
209 .send(WsMsg::Text(msg))
210 .await
211 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
212 Ok(())
213 }
214
215 async fn inner_send_to_engine_binary(request: WebSocketRequest, tcp_write: &mut WebSocketTcpWrite) -> Result<()> {
217 let msg = bson::to_vec(&request).map_err(|e| anyhow!("could not serialize bson: {e}"))?;
218 tcp_write
219 .send(WsMsg::Binary(msg))
220 .await
221 .map_err(|e| anyhow!("could not send json over websocket: {e}"))?;
222 Ok(())
223 }
224
225 pub async fn new(ws: reqwest::Upgraded) -> Result<EngineConnection> {
226 let wsconfig = tokio_tungstenite::tungstenite::protocol::WebSocketConfig {
227 max_message_size: Some(usize::MAX),
229 max_frame_size: Some(usize::MAX),
230 ..Default::default()
231 };
232
233 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
234 ws,
235 tokio_tungstenite::tungstenite::protocol::Role::Client,
236 Some(wsconfig),
237 )
238 .await;
239
240 let (tcp_write, tcp_read) = ws_stream.split();
241 let (engine_req_tx, engine_req_rx) = mpsc::channel(10);
242 let (shutdown_tx, shutdown_rx) = mpsc::channel(1);
243 tokio::task::spawn(Self::start_write_actor(tcp_write, engine_req_rx, shutdown_rx));
244
245 let mut tcp_read = TcpRead { stream: tcp_read };
246
247 let session_data: Arc<RwLock<Option<ModelingSessionData>>> = Arc::new(RwLock::new(None));
248 let session_data2 = session_data.clone();
249 let ids_of_async_commands: Arc<RwLock<IndexMap<Uuid, SourceRange>>> = Arc::new(RwLock::new(IndexMap::new()));
250 let socket_health = Arc::new(RwLock::new(SocketHealth::Active));
251 let pending_errors = Arc::new(RwLock::new(Vec::new()));
252 let pending_errors_clone = pending_errors.clone();
253 let response_information = ResponseInformation {
254 responses: Arc::new(RwLock::new(IndexMap::new())),
255 };
256 let response_information_cloned = response_information.clone();
257 let debug_info = Arc::new(RwLock::new(None));
258 let debug_info_cloned = debug_info.clone();
259
260 let socket_health_tcp_read = socket_health.clone();
261 let tcp_read_handle = tokio::spawn(async move {
262 loop {
264 match tcp_read.read().await {
265 Ok(ws_resp) => {
266 let id = ws_resp.request_id();
268 match &ws_resp {
269 WebSocketResponse::Success(SuccessWebSocketResponse {
270 resp: OkWebSocketResponseData::ModelingBatch { responses },
271 ..
272 }) => {
273 #[expect(
274 clippy::iter_over_hash_type,
275 reason = "modeling command uses a HashMap and keys are random, so we don't really have a choice"
276 )]
277 for (resp_id, batch_response) in responses {
278 let id: uuid::Uuid = (*resp_id).into();
279 match batch_response {
280 BatchResponse::Success { response } => {
281 response_information_cloned
284 .add(
285 id,
286 WebSocketResponse::Success(SuccessWebSocketResponse {
287 success: true,
288 request_id: Some(id),
289 resp: OkWebSocketResponseData::Modeling {
290 modeling_response: response.clone(),
291 },
292 }),
293 )
294 .await;
295 }
296 BatchResponse::Failure { errors } => {
297 response_information_cloned
298 .add(
299 id,
300 WebSocketResponse::Failure(FailureWebSocketResponse {
301 success: false,
302 request_id: Some(id),
303 errors: errors.clone(),
304 }),
305 )
306 .await;
307 }
308 }
309 }
310 }
311 WebSocketResponse::Success(SuccessWebSocketResponse {
312 resp: OkWebSocketResponseData::ModelingSessionData { session },
313 ..
314 }) => {
315 let mut sd = session_data2.write().await;
316 sd.replace(session.clone());
317 }
318 WebSocketResponse::Failure(FailureWebSocketResponse {
319 success: _,
320 request_id,
321 errors,
322 }) => {
323 if let Some(id) = request_id {
324 response_information_cloned
325 .add(
326 *id,
327 WebSocketResponse::Failure(FailureWebSocketResponse {
328 success: false,
329 request_id: *request_id,
330 errors: errors.clone(),
331 }),
332 )
333 .await;
334 } else {
335 let mut pe = pending_errors_clone.write().await;
337 for error in errors {
338 if !pe.contains(&error.message) {
339 pe.push(error.message.clone());
340 }
341 }
342 drop(pe);
343 }
344 }
345 WebSocketResponse::Success(SuccessWebSocketResponse {
346 resp: debug @ OkWebSocketResponseData::Debug { .. },
347 ..
348 }) => {
349 let mut handle = debug_info_cloned.write().await;
350 *handle = Some(debug.clone());
351 }
352 _ => {}
353 }
354
355 if let Some(id) = id {
356 response_information_cloned.add(id, ws_resp.clone()).await;
357 }
358 }
359 Err(e) => {
360 match &e {
361 WebSocketReadError::Read(e) => crate::logln!("could not read from WS: {:?}", e),
362 WebSocketReadError::Deser(e) => crate::logln!("could not deserialize msg from WS: {:?}", e),
363 }
364 *socket_health_tcp_read.write().await = SocketHealth::Inactive;
365 return Err(e);
366 }
367 }
368 }
369 });
370
371 Ok(EngineConnection {
372 engine_req_tx,
373 shutdown_tx,
374 tcp_read_handle: Arc::new(TcpReadHandle {
375 handle: Arc::new(tcp_read_handle),
376 }),
377 responses: response_information,
378 pending_errors,
379 socket_health,
380 batch: Arc::new(RwLock::new(Vec::new())),
381 batch_end: Arc::new(RwLock::new(IndexMap::new())),
382 #[cfg(feature = "artifact-graph")]
383 artifact_commands: Arc::new(RwLock::new(Vec::new())),
384 ids_of_async_commands,
385 default_planes: Default::default(),
386 session_data,
387 stats: Default::default(),
388 async_tasks: AsyncTasks::new(),
389 debug_info,
390 })
391 }
392}
393
394#[async_trait::async_trait]
395impl EngineManager for EngineConnection {
396 fn batch(&self) -> Arc<RwLock<Vec<(WebSocketRequest, SourceRange)>>> {
397 self.batch.clone()
398 }
399
400 fn batch_end(&self) -> Arc<RwLock<IndexMap<uuid::Uuid, (WebSocketRequest, SourceRange)>>> {
401 self.batch_end.clone()
402 }
403
404 fn responses(&self) -> Arc<RwLock<IndexMap<Uuid, WebSocketResponse>>> {
405 self.responses.responses.clone()
406 }
407
408 #[cfg(feature = "artifact-graph")]
409 fn artifact_commands(&self) -> Arc<RwLock<Vec<ArtifactCommand>>> {
410 self.artifact_commands.clone()
411 }
412
413 fn ids_of_async_commands(&self) -> Arc<RwLock<IndexMap<Uuid, SourceRange>>> {
414 self.ids_of_async_commands.clone()
415 }
416
417 fn async_tasks(&self) -> AsyncTasks {
418 self.async_tasks.clone()
419 }
420
421 fn stats(&self) -> &EngineStats {
422 &self.stats
423 }
424
425 fn get_default_planes(&self) -> Arc<RwLock<Option<DefaultPlanes>>> {
426 self.default_planes.clone()
427 }
428
429 async fn get_debug(&self) -> Option<OkWebSocketResponseData> {
430 self.debug_info.read().await.clone()
431 }
432
433 async fn fetch_debug(&self) -> Result<(), KclError> {
434 let (tx, rx) = oneshot::channel();
435
436 self.engine_req_tx
437 .send(ToEngineReq {
438 req: WebSocketRequest::Debug {},
439 request_sent: tx,
440 })
441 .await
442 .map_err(|e| {
443 KclError::Engine(KclErrorDetails {
444 message: format!("Failed to send debug: {}", e),
445 source_ranges: vec![],
446 })
447 })?;
448
449 let _ = rx.await;
450 Ok(())
451 }
452
453 async fn clear_scene_post_hook(
454 &self,
455 id_generator: &mut IdGenerator,
456 source_range: SourceRange,
457 ) -> Result<(), KclError> {
458 let new_planes = self.new_default_planes(id_generator, source_range).await?;
460 *self.default_planes.write().await = Some(new_planes);
461
462 Ok(())
463 }
464
465 async fn inner_fire_modeling_cmd(
466 &self,
467 _id: uuid::Uuid,
468 source_range: SourceRange,
469 cmd: WebSocketRequest,
470 _id_to_source_range: HashMap<Uuid, SourceRange>,
471 ) -> Result<(), KclError> {
472 let (tx, rx) = oneshot::channel();
473
474 self.engine_req_tx
476 .send(ToEngineReq {
477 req: cmd.clone(),
478 request_sent: tx,
479 })
480 .await
481 .map_err(|e| {
482 KclError::Engine(KclErrorDetails {
483 message: format!("Failed to send modeling command: {}", e),
484 source_ranges: vec![source_range],
485 })
486 })?;
487
488 rx.await
490 .map_err(|e| {
491 KclError::Engine(KclErrorDetails {
492 message: format!("could not send request to the engine actor: {e}"),
493 source_ranges: vec![source_range],
494 })
495 })?
496 .map_err(|e| {
497 KclError::Engine(KclErrorDetails {
498 message: format!("could not send request to the engine: {e}"),
499 source_ranges: vec![source_range],
500 })
501 })?;
502
503 Ok(())
504 }
505
506 async fn inner_send_modeling_cmd(
507 &self,
508 id: uuid::Uuid,
509 source_range: SourceRange,
510 cmd: WebSocketRequest,
511 id_to_source_range: HashMap<Uuid, SourceRange>,
512 ) -> Result<WebSocketResponse, KclError> {
513 self.inner_fire_modeling_cmd(id, source_range, cmd, id_to_source_range)
514 .await?;
515
516 let current_time = std::time::Instant::now();
518 while current_time.elapsed().as_secs() < 60 {
519 let guard = self.socket_health.read().await;
520 if *guard == SocketHealth::Inactive {
521 let pe = self.pending_errors.read().await;
523 if !pe.is_empty() {
524 return Err(KclError::Engine(KclErrorDetails {
525 message: pe.join(", ").to_string(),
526 source_ranges: vec![source_range],
527 }));
528 } else {
529 return Err(KclError::Engine(KclErrorDetails {
530 message: "Modeling command failed: websocket closed early".to_string(),
531 source_ranges: vec![source_range],
532 }));
533 }
534 }
535
536 #[cfg(feature = "artifact-graph")]
537 {
538 if let Some(resp) = self.responses.responses.read().await.get(&id) {
540 return Ok(resp.clone());
541 }
542 }
543 #[cfg(not(feature = "artifact-graph"))]
544 {
545 if let Some(resp) = self.responses.responses.write().await.shift_remove(&id) {
546 return Ok(resp);
547 }
548 }
549 }
550
551 Err(KclError::Engine(KclErrorDetails {
552 message: format!("Modeling command timed out `{}`", id),
553 source_ranges: vec![source_range],
554 }))
555 }
556
557 async fn get_session_data(&self) -> Option<ModelingSessionData> {
558 self.session_data.read().await.clone()
559 }
560
561 async fn close(&self) {
562 let _ = self.shutdown_tx.send(()).await;
563 loop {
564 let guard = self.socket_health.read().await;
565 if *guard == SocketHealth::Inactive {
566 return;
567 }
568 }
569 }
570}