decthings_api/client/rpc/debug/
mod.rs

1mod request;
2mod response;
3
4use crate::{client::StateModification, tensor::OwnedDecthingsTensor};
5
6pub use request::*;
7pub use response::*;
8
/// Entry point for the methods of the `Debug` RPC API.
///
/// Every method on this type forwards its parameters to the wrapped
/// `DecthingsClientRpc` and decodes the JSON response it gets back.
pub struct DebugRpc {
    // RPC handle that all methods of this type issue their calls through.
    rpc: crate::client::DecthingsClientRpc,
}
12
13impl DebugRpc {
14    pub(crate) fn new(rpc: crate::client::DecthingsClientRpc) -> Self {
15        Self { rpc }
16    }
17
18    pub async fn launch_debug_session(
19        &self,
20        params: LaunchDebugSessionParams<'_>,
21    ) -> Result<LaunchDebugSessionResult, crate::client::DecthingsRpcError<LaunchDebugSessionError>>
22    {
23        #[cfg(feature = "events")]
24        let subscribe_to_events = params.subscribe_to_events != Some(false);
25
26        #[cfg(feature = "events")]
27        let protocol = if subscribe_to_events {
28            crate::client::RpcProtocol::Ws
29        } else {
30            crate::client::RpcProtocol::Http
31        };
32
33        #[cfg(not(feature = "events"))]
34        let protocol = crate::client::RpcProtocol::Http;
35
36        let (tx, rx) = tokio::sync::oneshot::channel();
37        self.rpc
38            .raw_method_call::<_, _, &[u8]>(
39                "Debug",
40                "launchDebugSession",
41                params,
42                &[],
43                protocol,
44                move |x| {
45                    match x {
46                        Ok(val) => {
47                            let res: Result<
48                                super::Response<
49                                    response::LaunchDebugSessionResult,
50                                    response::LaunchDebugSessionError,
51                                >,
52                                crate::client::DecthingsRpcError<LaunchDebugSessionError>,
53                            > = serde_json::from_slice(&val.0).map_err(Into::into);
54                            match res {
55                                Ok(super::Response::Result(val)) => {
56                                    #[cfg(feature = "events")]
57                                    let debug_session_id = val.debug_session_id.clone();
58
59                                    tx.send(Ok(val)).ok();
60
61                                    #[cfg(feature = "events")]
62                                    if subscribe_to_events {
63                                        return StateModification {
64                                            add_events: vec![debug_session_id],
65                                            remove_events: vec![],
66                                        };
67                                    }
68                                }
69                                Ok(super::Response::Error(val)) => {
70                                    tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
71                                        .ok();
72                                }
73                                Err(e) => {
74                                    tx.send(Err(e)).ok();
75                                }
76                            }
77                        }
78                        Err(err) => {
79                            tx.send(Err(err.into())).ok();
80                        }
81                    }
82                    StateModification::empty()
83                },
84            )
85            .await;
86        rx.await.unwrap()
87    }
88
89    pub async fn get_debug_sessions(
90        &self,
91        params: GetDebugSessionsParams<'_, impl AsRef<str>>,
92    ) -> Result<GetDebugSessionsResult, crate::client::DecthingsRpcError<GetDebugSessionsError>>
93    {
94        let (tx, rx) = tokio::sync::oneshot::channel();
95        self.rpc
96            .raw_method_call::<_, _, &[u8]>(
97                "Debug",
98                "getDebugSessions",
99                params,
100                &[],
101                crate::client::RpcProtocol::Http,
102                |x| {
103                    tx.send(x).ok();
104                    StateModification::empty()
105                },
106            )
107            .await;
108        rx.await
109            .unwrap()
110            .map_err(crate::client::DecthingsRpcError::Request)
111            .and_then(|x| {
112                let res: super::Response<
113                    response::GetDebugSessionsResult,
114                    response::GetDebugSessionsError,
115                > = serde_json::from_slice(&x.0)?;
116                match res {
117                    super::Response::Result(val) => Ok(val),
118                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
119                }
120            })
121    }
122
123    pub async fn terminate_debug_session(
124        &self,
125        params: TerminateDebugSessionParams<'_>,
126    ) -> Result<
127        TerminateDebugSessionResult,
128        crate::client::DecthingsRpcError<TerminateDebugSessionError>,
129    > {
130        #[cfg(feature = "events")]
131        let debug_session_id_owned = params.debug_session_id.to_owned();
132
133        let (tx, rx) = tokio::sync::oneshot::channel();
134        self.rpc
135            .raw_method_call::<_, _, &[u8]>(
136                "Debug",
137                "terminateDebugSession",
138                params,
139                &[],
140                crate::client::RpcProtocol::Http,
141                move |x| {
142                    match x {
143                        Ok(val) => {
144                            let res: Result<
145                                super::Response<
146                                    response::TerminateDebugSessionResult,
147                                    response::TerminateDebugSessionError,
148                                >,
149                                crate::client::DecthingsRpcError<TerminateDebugSessionError>,
150                            > = serde_json::from_slice(&val.0).map_err(Into::into);
151                            match res {
152                                Ok(super::Response::Result(val)) => {
153                                    tx.send(Ok(val)).ok();
154
155                                    #[cfg(feature = "events")]
156                                    return StateModification {
157                                        add_events: vec![],
158                                        remove_events: vec![debug_session_id_owned],
159                                    };
160                                }
161                                Ok(super::Response::Error(val)) => {
162                                    tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
163                                        .ok();
164                                }
165                                Err(e) => {
166                                    tx.send(Err(e)).ok();
167                                }
168                            }
169                        }
170                        Err(err) => {
171                            tx.send(Err(err.into())).ok();
172                        }
173                    }
174                    StateModification::empty()
175                },
176            )
177            .await;
178        rx.await.unwrap()
179    }
180
181    pub async fn call_create_model_state<D>(
182        &self,
183        params: CallCreateModelStateParams<'_>,
184    ) -> Result<
185        CallCreateModelStateResult,
186        crate::client::DecthingsRpcError<CallCreateModelStateError>,
187    > {
188        let (tx, rx) = tokio::sync::oneshot::channel();
189        let serialized = crate::client::serialize_parameter_provider_list(params.params.iter());
190        self.rpc
191            .raw_method_call(
192                "Debug",
193                "callCreateModelState",
194                params,
195                serialized,
196                crate::client::RpcProtocol::Http,
197                |x| {
198                    tx.send(x).ok();
199                    StateModification::empty()
200                },
201            )
202            .await;
203        rx.await
204            .unwrap()
205            .map_err(crate::client::DecthingsRpcError::Request)
206            .and_then(|x| {
207                let res: super::Response<
208                    response::CallCreateModelStateResult,
209                    response::CallCreateModelStateError,
210                > = serde_json::from_slice(&x.0)?;
211                match res {
212                    super::Response::Result(val) => Ok(val),
213                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
214                }
215            })
216    }
217
218    pub async fn call_instantiate_model(
219        &self,
220        params: CallInstantiateModelParams<'_, impl AsRef<[u8]>>,
221    ) -> Result<
222        CallInstantiateModelResult,
223        crate::client::DecthingsRpcError<CallInstantiateModelError>,
224    > {
225        let (tx, rx) = tokio::sync::oneshot::channel();
226        let serialized = match &params.state_data {
227            StateDataProvider::Data { data } => *data,
228            _ => &[],
229        };
230        self.rpc
231            .raw_method_call(
232                "Debug",
233                "callInstantiateModel",
234                params,
235                serialized.iter().map(|x| &x.data).collect::<Vec<_>>(),
236                crate::client::RpcProtocol::Http,
237                |x| {
238                    tx.send(x).ok();
239                    StateModification::empty()
240                },
241            )
242            .await;
243        rx.await
244            .unwrap()
245            .map_err(crate::client::DecthingsRpcError::Request)
246            .and_then(|x| {
247                let res: super::Response<
248                    response::CallInstantiateModelResult,
249                    response::CallInstantiateModelError,
250                > = serde_json::from_slice(&x.0)?;
251                match res {
252                    super::Response::Result(val) => Ok(val),
253                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
254                }
255            })
256    }
257
258    pub async fn call_train<D>(
259        &self,
260        params: CallTrainParams<'_>,
261    ) -> Result<CallTrainResult, crate::client::DecthingsRpcError<CallTrainError>> {
262        let (tx, rx) = tokio::sync::oneshot::channel();
263        let serialized = crate::client::serialize_parameter_provider_list(params.params.iter());
264        self.rpc
265            .raw_method_call(
266                "Debug",
267                "callTrain",
268                params,
269                serialized,
270                crate::client::RpcProtocol::Http,
271                |x| {
272                    tx.send(x).ok();
273                    StateModification::empty()
274                },
275            )
276            .await;
277        rx.await
278            .unwrap()
279            .map_err(crate::client::DecthingsRpcError::Request)
280            .and_then(|x| {
281                let res: super::Response<response::CallTrainResult, response::CallTrainError> =
282                    serde_json::from_slice(&x.0)?;
283                match res {
284                    super::Response::Result(val) => Ok(val),
285                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
286                }
287            })
288    }
289
290    pub async fn get_training_status(
291        &self,
292        params: DebugGetTrainingStatusParams<'_>,
293    ) -> Result<
294        DebugGetTrainingStatusResult,
295        crate::client::DecthingsRpcError<DebugGetTrainingStatusError>,
296    > {
297        let (tx, rx) = tokio::sync::oneshot::channel();
298        self.rpc
299            .raw_method_call::<_, _, &[u8]>(
300                "Debug",
301                "getTrainingStatus",
302                params,
303                &[],
304                crate::client::RpcProtocol::Http,
305                |x| {
306                    tx.send(x).ok();
307                    StateModification::empty()
308                },
309            )
310            .await;
311        rx.await
312            .unwrap()
313            .map_err(crate::client::DecthingsRpcError::Request)
314            .and_then(|x| {
315                let res: super::Response<
316                    response::DebugGetTrainingStatusResult,
317                    response::DebugGetTrainingStatusError,
318                > = serde_json::from_slice(&x.0)?;
319                match res {
320                    super::Response::Result(val) => Ok(val),
321                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
322                }
323            })
324    }
325
326    pub async fn get_training_metrics(
327        &self,
328        params: DebugGetTrainingMetricsParams<'_>,
329    ) -> Result<
330        DebugGetTrainingMetricsResult,
331        crate::client::DecthingsRpcError<DebugGetTrainingMetricsError>,
332    > {
333        let (tx, rx) = tokio::sync::oneshot::channel();
334        self.rpc
335            .raw_method_call::<_, _, &[u8]>(
336                "Debug",
337                "getTrainingMetrics",
338                params,
339                &[],
340                crate::client::RpcProtocol::Http,
341                |x| {
342                    tx.send(x).ok();
343                    StateModification::empty()
344                },
345            )
346            .await;
347        rx.await
348            .unwrap()
349            .map_err(crate::client::DecthingsRpcError::Request)
350            .and_then(|x| {
351                let res: super::Response<
352                    response::DebugGetTrainingMetricsResult,
353                    response::DebugGetTrainingMetricsError,
354                > = serde_json::from_slice(&x.0)?;
355                match res {
356                    super::Response::Result(mut val) => {
357                        if val.metrics.iter().map(|x| x.entries.len()).sum::<usize>() != x.1.len() {
358                            return Err(crate::client::DecthingsClientError::InvalidMessage.into());
359                        }
360                        for (entry, data) in
361                            val.metrics.iter_mut().flat_map(|x| &mut x.entries).zip(x.1)
362                        {
363                            entry.data = OwnedDecthingsTensor::from_bytes(data)
364                                .map_err(|_| crate::client::DecthingsClientError::InvalidMessage)?;
365                        }
366                        Ok(val)
367                    }
368                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
369                }
370            })
371    }
372
373    pub async fn cancel_training_session(
374        &self,
375        params: DebugCancelTrainingSessionParams<'_>,
376    ) -> Result<
377        DebugCancelTrainingSessionResult,
378        crate::client::DecthingsRpcError<DebugCancelTrainingSessionError>,
379    > {
380        let (tx, rx) = tokio::sync::oneshot::channel();
381        self.rpc
382            .raw_method_call::<_, _, &[u8]>(
383                "Debug",
384                "cancelTrainingSession",
385                params,
386                &[],
387                crate::client::RpcProtocol::Http,
388                |x| {
389                    tx.send(x).ok();
390                    StateModification::empty()
391                },
392            )
393            .await;
394        rx.await
395            .unwrap()
396            .map_err(crate::client::DecthingsRpcError::Request)
397            .and_then(|x| {
398                let res: super::Response<
399                    response::DebugCancelTrainingSessionResult,
400                    response::DebugCancelTrainingSessionError,
401                > = serde_json::from_slice(&x.0)?;
402                match res {
403                    super::Response::Result(val) => Ok(val),
404                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
405                }
406            })
407    }
408
409    pub async fn call_evaluate(
410        &self,
411        params: CallEvaluateParams<'_>,
412    ) -> Result<CallEvaluateResult, crate::client::DecthingsRpcError<CallEvaluateError>> {
413        let (tx, rx) = tokio::sync::oneshot::channel();
414        let serialized = crate::client::serialize_parameter_provider_list(params.params.iter());
415        self.rpc
416            .raw_method_call(
417                "Debug",
418                "callEvaluate",
419                params,
420                serialized,
421                crate::client::RpcProtocol::Http,
422                |x| {
423                    tx.send(x).ok();
424                    StateModification::empty()
425                },
426            )
427            .await;
428        rx.await
429            .unwrap()
430            .map_err(crate::client::DecthingsRpcError::Request)
431            .and_then(|x| {
432                let res: super::Response<
433                    response::CallEvaluateResult,
434                    response::CallEvaluateError,
435                > = serde_json::from_slice(&x.0)?;
436                match res {
437                    super::Response::Result(mut val) => {
438                        if val.output.len() != x.1.len() {
439                            return Err(crate::client::DecthingsClientError::InvalidMessage.into());
440                        }
441                        for (entry, data) in val.output.iter_mut().zip(x.1) {
442                            entry.data = super::many_decthings_tensors_from_bytes(data)
443                                .map_err(|_| crate::client::DecthingsClientError::InvalidMessage)?;
444                        }
445                        Ok(val)
446                    }
447                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
448                }
449            })
450    }
451
452    pub async fn call_get_model_state(
453        &self,
454        params: CallGetModelStateParams<'_>,
455    ) -> Result<CallGetModelStateResult, crate::client::DecthingsRpcError<CallGetModelStateError>>
456    {
457        let (tx, rx) = tokio::sync::oneshot::channel();
458        self.rpc
459            .raw_method_call::<_, _, &[u8]>(
460                "Debug",
461                "callGetModelState",
462                params,
463                &[],
464                crate::client::RpcProtocol::Http,
465                |x| {
466                    tx.send(x).ok();
467                    StateModification::empty()
468                },
469            )
470            .await;
471        rx.await
472            .unwrap()
473            .map_err(crate::client::DecthingsRpcError::Request)
474            .and_then(|x| {
475                let res: super::Response<
476                    response::CallGetModelStateResult,
477                    response::CallGetModelStateError,
478                > = serde_json::from_slice(&x.0)?;
479                match res {
480                    super::Response::Result(val) => Ok(val),
481                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
482                }
483            })
484    }
485
486    pub async fn download_state_data(
487        &self,
488        params: DownloadStateDataParams<'_, impl AsRef<str>>,
489    ) -> Result<DownloadStateDataResult, crate::client::DecthingsRpcError<DownloadStateDataError>>
490    {
491        let (tx, rx) = tokio::sync::oneshot::channel();
492        self.rpc
493            .raw_method_call::<_, _, &[u8]>(
494                "Debug",
495                "downloadStateData",
496                params,
497                &[],
498                crate::client::RpcProtocol::Http,
499                |x| {
500                    tx.send(x).ok();
501                    StateModification::empty()
502                },
503            )
504            .await;
505        rx.await
506            .unwrap()
507            .map_err(crate::client::DecthingsRpcError::Request)
508            .and_then(|x| {
509                let res: super::Response<
510                    response::DownloadStateDataResult,
511                    response::DownloadStateDataError,
512                > = serde_json::from_slice(&x.0)?;
513                match res {
514                    super::Response::Result(val) => Ok(DownloadStateDataResult {
515                        data: val
516                            .data
517                            .into_iter()
518                            .zip(x.1)
519                            .map(|(key, data)| super::StateKeyData { key: key.key, data })
520                            .collect(),
521                    }),
522                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
523                }
524            })
525    }
526
527    pub async fn send_to_remote_inspector(
528        &self,
529        params: SendToRemoteInspectorParams<'_, impl AsRef<[u8]>>,
530    ) -> Result<
531        SendToRemoteInspectorResult,
532        crate::client::DecthingsRpcError<SendToRemoteInspectorError>,
533    > {
534        let (tx, rx) = tokio::sync::oneshot::channel();
535        self.rpc
536            .raw_method_call(
537                "Debug",
538                "sendToRemoteInspector",
539                &params,
540                [&params.data],
541                crate::client::RpcProtocol::Http,
542                |x| {
543                    tx.send(x).ok();
544                    StateModification::empty()
545                },
546            )
547            .await;
548        rx.await
549            .unwrap()
550            .map_err(crate::client::DecthingsRpcError::Request)
551            .and_then(|x| {
552                let res: super::Response<
553                    response::SendToRemoteInspectorResult,
554                    response::SendToRemoteInspectorError,
555                > = serde_json::from_slice(&x.0)?;
556                match res {
557                    super::Response::Result(val) => Ok(val),
558                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
559                }
560            })
561    }
562
563    #[cfg(feature = "events")]
564    pub async fn subscribe_to_events(
565        &self,
566        params: DebugSubscribeToEventsParams<'_>,
567    ) -> Result<
568        DebugSubscribeToEventsResult,
569        crate::client::DecthingsRpcError<DebugSubscribeToEventsError>,
570    > {
571        let (tx, rx) = tokio::sync::oneshot::channel();
572        let debug_session_id_owned = params.debug_session_id.to_owned();
573        self.rpc
574            .raw_method_call::<_, _, &[u8]>(
575                "Debug",
576                "subscribeToEvents",
577                params,
578                &[],
579                crate::client::RpcProtocol::Ws,
580                move |x| {
581                    match x {
582                        Ok(val) => {
583                            let res: Result<
584                                super::Response<
585                                    response::DebugSubscribeToEventsResult,
586                                    response::DebugSubscribeToEventsError,
587                                >,
588                                crate::client::DecthingsRpcError<DebugSubscribeToEventsError>,
589                            > = serde_json::from_slice(&val.0).map_err(Into::into);
590                            match res {
591                                Ok(super::Response::Result(val)) => {
592                                    tx.send(Ok(val)).ok();
593                                    return StateModification {
594                                        add_events: vec![debug_session_id_owned],
595                                        remove_events: vec![],
596                                    };
597                                }
598                                Ok(super::Response::Error(val)) => {
599                                    tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
600                                        .ok();
601                                }
602                                Err(e) => {
603                                    tx.send(Err(e)).ok();
604                                }
605                            }
606                        }
607                        Err(err) => {
608                            tx.send(Err(err.into())).ok();
609                        }
610                    }
611                    StateModification::empty()
612                },
613            )
614            .await;
615        rx.await.unwrap()
616    }
617
618    #[cfg(feature = "events")]
619    pub async fn unsubscribe_from_events(
620        &self,
621        params: DebugUnsubscribeFromEventsParams<'_>,
622    ) -> Result<
623        DebugUnsubscribeFromEventsResult,
624        crate::client::DecthingsRpcError<DebugUnsubscribeFromEventsError>,
625    > {
626        let (tx, rx) = tokio::sync::oneshot::channel();
627        let debug_session_id_owned = params.debug_session_id.to_owned();
628        let did_call = self
629            .rpc
630            .raw_method_call::<_, _, &[u8]>(
631                "Debug",
632                "unsubscribeFromEvents",
633                params,
634                &[],
635                crate::client::RpcProtocol::WsIfAvailableOtherwiseNone,
636                move |x| {
637                    match x {
638                        Ok(val) => {
639                            let res: Result<
640                                super::Response<
641                                    response::DebugUnsubscribeFromEventsResult,
642                                    response::DebugUnsubscribeFromEventsError,
643                                >,
644                                crate::client::DecthingsRpcError<DebugUnsubscribeFromEventsError>,
645                            > = serde_json::from_slice(&val.0).map_err(Into::into);
646                            match res {
647                                Ok(super::Response::Result(val)) => {
648                                    tx.send(Ok(val)).ok();
649                                    return StateModification {
650                                        add_events: vec![],
651                                        remove_events: vec![debug_session_id_owned],
652                                    };
653                                }
654                                Ok(super::Response::Error(val)) => {
655                                    tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
656                                        .ok();
657                                }
658                                Err(e) => {
659                                    tx.send(Err(e)).ok();
660                                }
661                            }
662                        }
663                        Err(err) => {
664                            tx.send(Err(err.into())).ok();
665                        }
666                    }
667                    StateModification::empty()
668                },
669            )
670            .await;
671        if !did_call {
672            return Err(crate::client::DecthingsRpcError::Rpc(
673                DebugUnsubscribeFromEventsError::NotSubscribed,
674            ));
675        }
676        rx.await.unwrap()
677    }
678}