1mod request;
2mod response;
3
4use crate::{client::StateModification, tensor::OwnedDecthingsTensor};
5
6pub use request::*;
7pub use response::*;
8
9pub struct DebugRpc {
10 rpc: crate::client::DecthingsClientRpc,
11}
12
13impl DebugRpc {
14 pub(crate) fn new(rpc: crate::client::DecthingsClientRpc) -> Self {
15 Self { rpc }
16 }
17
18 pub async fn launch_debug_session(
19 &self,
20 params: LaunchDebugSessionParams<'_>,
21 ) -> Result<LaunchDebugSessionResult, crate::client::DecthingsRpcError<LaunchDebugSessionError>>
22 {
23 #[cfg(feature = "events")]
24 let subscribe_to_events = params.subscribe_to_events != Some(false);
25
26 #[cfg(feature = "events")]
27 let protocol = if subscribe_to_events {
28 crate::client::RpcProtocol::Ws
29 } else {
30 crate::client::RpcProtocol::Http
31 };
32
33 #[cfg(not(feature = "events"))]
34 let protocol = crate::client::RpcProtocol::Http;
35
36 let (tx, rx) = tokio::sync::oneshot::channel();
37 self.rpc
38 .raw_method_call::<_, _, &[u8]>(
39 "Debug",
40 "launchDebugSession",
41 params,
42 &[],
43 protocol,
44 move |x| {
45 match x {
46 Ok(val) => {
47 let res: Result<
48 super::Response<
49 response::LaunchDebugSessionResult,
50 response::LaunchDebugSessionError,
51 >,
52 crate::client::DecthingsRpcError<LaunchDebugSessionError>,
53 > = serde_json::from_slice(&val.0).map_err(Into::into);
54 match res {
55 Ok(super::Response::Result(val)) => {
56 #[cfg(feature = "events")]
57 let debug_session_id = val.debug_session_id.clone();
58
59 tx.send(Ok(val)).ok();
60
61 #[cfg(feature = "events")]
62 if subscribe_to_events {
63 return StateModification {
64 add_events: vec![debug_session_id],
65 remove_events: vec![],
66 };
67 }
68 }
69 Ok(super::Response::Error(val)) => {
70 tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
71 .ok();
72 }
73 Err(e) => {
74 tx.send(Err(e)).ok();
75 }
76 }
77 }
78 Err(err) => {
79 tx.send(Err(err.into())).ok();
80 }
81 }
82 StateModification::empty()
83 },
84 )
85 .await;
86 rx.await.unwrap()
87 }
88
89 pub async fn get_debug_sessions(
90 &self,
91 params: GetDebugSessionsParams<'_, impl AsRef<str>>,
92 ) -> Result<GetDebugSessionsResult, crate::client::DecthingsRpcError<GetDebugSessionsError>>
93 {
94 let (tx, rx) = tokio::sync::oneshot::channel();
95 self.rpc
96 .raw_method_call::<_, _, &[u8]>(
97 "Debug",
98 "getDebugSessions",
99 params,
100 &[],
101 crate::client::RpcProtocol::Http,
102 |x| {
103 tx.send(x).ok();
104 StateModification::empty()
105 },
106 )
107 .await;
108 rx.await
109 .unwrap()
110 .map_err(crate::client::DecthingsRpcError::Request)
111 .and_then(|x| {
112 let res: super::Response<
113 response::GetDebugSessionsResult,
114 response::GetDebugSessionsError,
115 > = serde_json::from_slice(&x.0)?;
116 match res {
117 super::Response::Result(val) => Ok(val),
118 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
119 }
120 })
121 }
122
123 pub async fn terminate_debug_session(
124 &self,
125 params: TerminateDebugSessionParams<'_>,
126 ) -> Result<
127 TerminateDebugSessionResult,
128 crate::client::DecthingsRpcError<TerminateDebugSessionError>,
129 > {
130 #[cfg(feature = "events")]
131 let debug_session_id_owned = params.debug_session_id.to_owned();
132
133 let (tx, rx) = tokio::sync::oneshot::channel();
134 self.rpc
135 .raw_method_call::<_, _, &[u8]>(
136 "Debug",
137 "terminateDebugSession",
138 params,
139 &[],
140 crate::client::RpcProtocol::Http,
141 move |x| {
142 match x {
143 Ok(val) => {
144 let res: Result<
145 super::Response<
146 response::TerminateDebugSessionResult,
147 response::TerminateDebugSessionError,
148 >,
149 crate::client::DecthingsRpcError<TerminateDebugSessionError>,
150 > = serde_json::from_slice(&val.0).map_err(Into::into);
151 match res {
152 Ok(super::Response::Result(val)) => {
153 tx.send(Ok(val)).ok();
154
155 #[cfg(feature = "events")]
156 return StateModification {
157 add_events: vec![],
158 remove_events: vec![debug_session_id_owned],
159 };
160 }
161 Ok(super::Response::Error(val)) => {
162 tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
163 .ok();
164 }
165 Err(e) => {
166 tx.send(Err(e)).ok();
167 }
168 }
169 }
170 Err(err) => {
171 tx.send(Err(err.into())).ok();
172 }
173 }
174 StateModification::empty()
175 },
176 )
177 .await;
178 rx.await.unwrap()
179 }
180
181 pub async fn call_create_model_state<D>(
182 &self,
183 params: CallCreateModelStateParams<'_>,
184 ) -> Result<
185 CallCreateModelStateResult,
186 crate::client::DecthingsRpcError<CallCreateModelStateError>,
187 > {
188 let (tx, rx) = tokio::sync::oneshot::channel();
189 let serialized = crate::client::serialize_parameter_provider_list(params.params.iter());
190 self.rpc
191 .raw_method_call(
192 "Debug",
193 "callCreateModelState",
194 params,
195 serialized,
196 crate::client::RpcProtocol::Http,
197 |x| {
198 tx.send(x).ok();
199 StateModification::empty()
200 },
201 )
202 .await;
203 rx.await
204 .unwrap()
205 .map_err(crate::client::DecthingsRpcError::Request)
206 .and_then(|x| {
207 let res: super::Response<
208 response::CallCreateModelStateResult,
209 response::CallCreateModelStateError,
210 > = serde_json::from_slice(&x.0)?;
211 match res {
212 super::Response::Result(val) => Ok(val),
213 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
214 }
215 })
216 }
217
218 pub async fn call_instantiate_model(
219 &self,
220 params: CallInstantiateModelParams<'_, impl AsRef<[u8]>>,
221 ) -> Result<
222 CallInstantiateModelResult,
223 crate::client::DecthingsRpcError<CallInstantiateModelError>,
224 > {
225 let (tx, rx) = tokio::sync::oneshot::channel();
226 let serialized = match ¶ms.state_data {
227 StateDataProvider::Data { data } => *data,
228 _ => &[],
229 };
230 self.rpc
231 .raw_method_call(
232 "Debug",
233 "callInstantiateModel",
234 params,
235 serialized.iter().map(|x| &x.data).collect::<Vec<_>>(),
236 crate::client::RpcProtocol::Http,
237 |x| {
238 tx.send(x).ok();
239 StateModification::empty()
240 },
241 )
242 .await;
243 rx.await
244 .unwrap()
245 .map_err(crate::client::DecthingsRpcError::Request)
246 .and_then(|x| {
247 let res: super::Response<
248 response::CallInstantiateModelResult,
249 response::CallInstantiateModelError,
250 > = serde_json::from_slice(&x.0)?;
251 match res {
252 super::Response::Result(val) => Ok(val),
253 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
254 }
255 })
256 }
257
258 pub async fn call_train<D>(
259 &self,
260 params: CallTrainParams<'_>,
261 ) -> Result<CallTrainResult, crate::client::DecthingsRpcError<CallTrainError>> {
262 let (tx, rx) = tokio::sync::oneshot::channel();
263 let serialized = crate::client::serialize_parameter_provider_list(params.params.iter());
264 self.rpc
265 .raw_method_call(
266 "Debug",
267 "callTrain",
268 params,
269 serialized,
270 crate::client::RpcProtocol::Http,
271 |x| {
272 tx.send(x).ok();
273 StateModification::empty()
274 },
275 )
276 .await;
277 rx.await
278 .unwrap()
279 .map_err(crate::client::DecthingsRpcError::Request)
280 .and_then(|x| {
281 let res: super::Response<response::CallTrainResult, response::CallTrainError> =
282 serde_json::from_slice(&x.0)?;
283 match res {
284 super::Response::Result(val) => Ok(val),
285 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
286 }
287 })
288 }
289
290 pub async fn get_training_status(
291 &self,
292 params: DebugGetTrainingStatusParams<'_>,
293 ) -> Result<
294 DebugGetTrainingStatusResult,
295 crate::client::DecthingsRpcError<DebugGetTrainingStatusError>,
296 > {
297 let (tx, rx) = tokio::sync::oneshot::channel();
298 self.rpc
299 .raw_method_call::<_, _, &[u8]>(
300 "Debug",
301 "getTrainingStatus",
302 params,
303 &[],
304 crate::client::RpcProtocol::Http,
305 |x| {
306 tx.send(x).ok();
307 StateModification::empty()
308 },
309 )
310 .await;
311 rx.await
312 .unwrap()
313 .map_err(crate::client::DecthingsRpcError::Request)
314 .and_then(|x| {
315 let res: super::Response<
316 response::DebugGetTrainingStatusResult,
317 response::DebugGetTrainingStatusError,
318 > = serde_json::from_slice(&x.0)?;
319 match res {
320 super::Response::Result(val) => Ok(val),
321 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
322 }
323 })
324 }
325
326 pub async fn get_training_metrics(
327 &self,
328 params: DebugGetTrainingMetricsParams<'_>,
329 ) -> Result<
330 DebugGetTrainingMetricsResult,
331 crate::client::DecthingsRpcError<DebugGetTrainingMetricsError>,
332 > {
333 let (tx, rx) = tokio::sync::oneshot::channel();
334 self.rpc
335 .raw_method_call::<_, _, &[u8]>(
336 "Debug",
337 "getTrainingMetrics",
338 params,
339 &[],
340 crate::client::RpcProtocol::Http,
341 |x| {
342 tx.send(x).ok();
343 StateModification::empty()
344 },
345 )
346 .await;
347 rx.await
348 .unwrap()
349 .map_err(crate::client::DecthingsRpcError::Request)
350 .and_then(|x| {
351 let res: super::Response<
352 response::DebugGetTrainingMetricsResult,
353 response::DebugGetTrainingMetricsError,
354 > = serde_json::from_slice(&x.0)?;
355 match res {
356 super::Response::Result(mut val) => {
357 if val.metrics.iter().map(|x| x.entries.len()).sum::<usize>() != x.1.len() {
358 return Err(crate::client::DecthingsClientError::InvalidMessage.into());
359 }
360 for (entry, data) in
361 val.metrics.iter_mut().flat_map(|x| &mut x.entries).zip(x.1)
362 {
363 entry.data = OwnedDecthingsTensor::from_bytes(data)
364 .map_err(|_| crate::client::DecthingsClientError::InvalidMessage)?;
365 }
366 Ok(val)
367 }
368 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
369 }
370 })
371 }
372
373 pub async fn cancel_training_session(
374 &self,
375 params: DebugCancelTrainingSessionParams<'_>,
376 ) -> Result<
377 DebugCancelTrainingSessionResult,
378 crate::client::DecthingsRpcError<DebugCancelTrainingSessionError>,
379 > {
380 let (tx, rx) = tokio::sync::oneshot::channel();
381 self.rpc
382 .raw_method_call::<_, _, &[u8]>(
383 "Debug",
384 "cancelTrainingSession",
385 params,
386 &[],
387 crate::client::RpcProtocol::Http,
388 |x| {
389 tx.send(x).ok();
390 StateModification::empty()
391 },
392 )
393 .await;
394 rx.await
395 .unwrap()
396 .map_err(crate::client::DecthingsRpcError::Request)
397 .and_then(|x| {
398 let res: super::Response<
399 response::DebugCancelTrainingSessionResult,
400 response::DebugCancelTrainingSessionError,
401 > = serde_json::from_slice(&x.0)?;
402 match res {
403 super::Response::Result(val) => Ok(val),
404 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
405 }
406 })
407 }
408
409 pub async fn call_evaluate(
410 &self,
411 params: CallEvaluateParams<'_>,
412 ) -> Result<CallEvaluateResult, crate::client::DecthingsRpcError<CallEvaluateError>> {
413 let (tx, rx) = tokio::sync::oneshot::channel();
414 let serialized = crate::client::serialize_parameter_provider_list(params.params.iter());
415 self.rpc
416 .raw_method_call(
417 "Debug",
418 "callEvaluate",
419 params,
420 serialized,
421 crate::client::RpcProtocol::Http,
422 |x| {
423 tx.send(x).ok();
424 StateModification::empty()
425 },
426 )
427 .await;
428 rx.await
429 .unwrap()
430 .map_err(crate::client::DecthingsRpcError::Request)
431 .and_then(|x| {
432 let res: super::Response<
433 response::CallEvaluateResult,
434 response::CallEvaluateError,
435 > = serde_json::from_slice(&x.0)?;
436 match res {
437 super::Response::Result(mut val) => {
438 if val.output.len() != x.1.len() {
439 return Err(crate::client::DecthingsClientError::InvalidMessage.into());
440 }
441 for (entry, data) in val.output.iter_mut().zip(x.1) {
442 entry.data = super::many_decthings_tensors_from_bytes(data)
443 .map_err(|_| crate::client::DecthingsClientError::InvalidMessage)?;
444 }
445 Ok(val)
446 }
447 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
448 }
449 })
450 }
451
452 pub async fn call_get_model_state(
453 &self,
454 params: CallGetModelStateParams<'_>,
455 ) -> Result<CallGetModelStateResult, crate::client::DecthingsRpcError<CallGetModelStateError>>
456 {
457 let (tx, rx) = tokio::sync::oneshot::channel();
458 self.rpc
459 .raw_method_call::<_, _, &[u8]>(
460 "Debug",
461 "callGetModelState",
462 params,
463 &[],
464 crate::client::RpcProtocol::Http,
465 |x| {
466 tx.send(x).ok();
467 StateModification::empty()
468 },
469 )
470 .await;
471 rx.await
472 .unwrap()
473 .map_err(crate::client::DecthingsRpcError::Request)
474 .and_then(|x| {
475 let res: super::Response<
476 response::CallGetModelStateResult,
477 response::CallGetModelStateError,
478 > = serde_json::from_slice(&x.0)?;
479 match res {
480 super::Response::Result(val) => Ok(val),
481 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
482 }
483 })
484 }
485
486 pub async fn download_state_data(
487 &self,
488 params: DownloadStateDataParams<'_, impl AsRef<str>>,
489 ) -> Result<DownloadStateDataResult, crate::client::DecthingsRpcError<DownloadStateDataError>>
490 {
491 let (tx, rx) = tokio::sync::oneshot::channel();
492 self.rpc
493 .raw_method_call::<_, _, &[u8]>(
494 "Debug",
495 "downloadStateData",
496 params,
497 &[],
498 crate::client::RpcProtocol::Http,
499 |x| {
500 tx.send(x).ok();
501 StateModification::empty()
502 },
503 )
504 .await;
505 rx.await
506 .unwrap()
507 .map_err(crate::client::DecthingsRpcError::Request)
508 .and_then(|x| {
509 let res: super::Response<
510 response::DownloadStateDataResult,
511 response::DownloadStateDataError,
512 > = serde_json::from_slice(&x.0)?;
513 match res {
514 super::Response::Result(val) => Ok(DownloadStateDataResult {
515 data: val
516 .data
517 .into_iter()
518 .zip(x.1)
519 .map(|(key, data)| super::StateKeyData { key: key.key, data })
520 .collect(),
521 }),
522 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
523 }
524 })
525 }
526
527 pub async fn send_to_remote_inspector(
528 &self,
529 params: SendToRemoteInspectorParams<'_, impl AsRef<[u8]>>,
530 ) -> Result<
531 SendToRemoteInspectorResult,
532 crate::client::DecthingsRpcError<SendToRemoteInspectorError>,
533 > {
534 let (tx, rx) = tokio::sync::oneshot::channel();
535 self.rpc
536 .raw_method_call(
537 "Debug",
538 "sendToRemoteInspector",
539 ¶ms,
540 [¶ms.data],
541 crate::client::RpcProtocol::Http,
542 |x| {
543 tx.send(x).ok();
544 StateModification::empty()
545 },
546 )
547 .await;
548 rx.await
549 .unwrap()
550 .map_err(crate::client::DecthingsRpcError::Request)
551 .and_then(|x| {
552 let res: super::Response<
553 response::SendToRemoteInspectorResult,
554 response::SendToRemoteInspectorError,
555 > = serde_json::from_slice(&x.0)?;
556 match res {
557 super::Response::Result(val) => Ok(val),
558 super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
559 }
560 })
561 }
562
563 #[cfg(feature = "events")]
564 pub async fn subscribe_to_events(
565 &self,
566 params: DebugSubscribeToEventsParams<'_>,
567 ) -> Result<
568 DebugSubscribeToEventsResult,
569 crate::client::DecthingsRpcError<DebugSubscribeToEventsError>,
570 > {
571 let (tx, rx) = tokio::sync::oneshot::channel();
572 let debug_session_id_owned = params.debug_session_id.to_owned();
573 self.rpc
574 .raw_method_call::<_, _, &[u8]>(
575 "Debug",
576 "subscribeToEvents",
577 params,
578 &[],
579 crate::client::RpcProtocol::Ws,
580 move |x| {
581 match x {
582 Ok(val) => {
583 let res: Result<
584 super::Response<
585 response::DebugSubscribeToEventsResult,
586 response::DebugSubscribeToEventsError,
587 >,
588 crate::client::DecthingsRpcError<DebugSubscribeToEventsError>,
589 > = serde_json::from_slice(&val.0).map_err(Into::into);
590 match res {
591 Ok(super::Response::Result(val)) => {
592 tx.send(Ok(val)).ok();
593 return StateModification {
594 add_events: vec![debug_session_id_owned],
595 remove_events: vec![],
596 };
597 }
598 Ok(super::Response::Error(val)) => {
599 tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
600 .ok();
601 }
602 Err(e) => {
603 tx.send(Err(e)).ok();
604 }
605 }
606 }
607 Err(err) => {
608 tx.send(Err(err.into())).ok();
609 }
610 }
611 StateModification::empty()
612 },
613 )
614 .await;
615 rx.await.unwrap()
616 }
617
618 #[cfg(feature = "events")]
619 pub async fn unsubscribe_from_events(
620 &self,
621 params: DebugUnsubscribeFromEventsParams<'_>,
622 ) -> Result<
623 DebugUnsubscribeFromEventsResult,
624 crate::client::DecthingsRpcError<DebugUnsubscribeFromEventsError>,
625 > {
626 let (tx, rx) = tokio::sync::oneshot::channel();
627 let debug_session_id_owned = params.debug_session_id.to_owned();
628 let did_call = self
629 .rpc
630 .raw_method_call::<_, _, &[u8]>(
631 "Debug",
632 "unsubscribeFromEvents",
633 params,
634 &[],
635 crate::client::RpcProtocol::WsIfAvailableOtherwiseNone,
636 move |x| {
637 match x {
638 Ok(val) => {
639 let res: Result<
640 super::Response<
641 response::DebugUnsubscribeFromEventsResult,
642 response::DebugUnsubscribeFromEventsError,
643 >,
644 crate::client::DecthingsRpcError<DebugUnsubscribeFromEventsError>,
645 > = serde_json::from_slice(&val.0).map_err(Into::into);
646 match res {
647 Ok(super::Response::Result(val)) => {
648 tx.send(Ok(val)).ok();
649 return StateModification {
650 add_events: vec![],
651 remove_events: vec![debug_session_id_owned],
652 };
653 }
654 Ok(super::Response::Error(val)) => {
655 tx.send(Err(crate::client::DecthingsRpcError::Rpc(val)))
656 .ok();
657 }
658 Err(e) => {
659 tx.send(Err(e)).ok();
660 }
661 }
662 }
663 Err(err) => {
664 tx.send(Err(err.into())).ok();
665 }
666 }
667 StateModification::empty()
668 },
669 )
670 .await;
671 if !did_call {
672 return Err(crate::client::DecthingsRpcError::Rpc(
673 DebugUnsubscribeFromEventsError::NotSubscribed,
674 ));
675 }
676 rx.await.unwrap()
677 }
678}