// decthings_api/client/rpc/dataset/mod.rs

1mod request;
2mod response;
3
4use std::collections::HashSet;
5
6use crate::{client::StateModification, tensor::OwnedDecthingsTensor};
7
8pub use request::*;
9pub use response::*;
10
11/// *data* has one element per dataset key, where each of these elements contains a list of
12/// DecthingsTensor. We first assert that the keys are valid (more than one and unique), and that
13/// the same number of elements are to be added to each key. Then, the data is serialized by first
14/// sorting the data by key (this allows for optimization in the server). Then, the data is
15/// serialized using *.serialize()* on the DecthingsTensor. The data is serialized in order by
16/// sorted key, and grouped by element. That is, one element for each key is serialized and added
17/// after each other, before moving on to the next element. Each element is returned as a separate
18/// vec.
19fn serialize_add_dataset_data(data: &[DataToAddForKey<'_>]) -> Result<Vec<Vec<u8>>, String> {
20    if data.is_empty() {
21        return Err(
22            "Failed to serialize data: Got zero keys, but a dataset always has at least one key."
23                .to_string(),
24        );
25    }
26
27    let num_entries = data[0].data.len();
28    for x in data.iter() {
29        if x.data.len() != num_entries {
30            return Err(format!(
31                    "Failed to serialize data: All keys must contain the same amount of data. Key {} had {num_entries} elements, but key {} had {} elements.",
32                    data[0].key,
33                    x.key,
34                    x.data.len()
35                ));
36        }
37    }
38
39    let mut sorted_keys: Vec<_> = data.iter().map(|x| x.key).collect();
40    sorted_keys.sort();
41
42    {
43        let mut uniq = HashSet::new();
44        if !sorted_keys.iter().all(|x| uniq.insert(x)) {
45            return Err(format!(
46                "Failed to serialize data: Got duplicate keys. Keys were: {:?}",
47                data.iter().map(|x| x.key).collect::<Vec<_>>()
48            ));
49        }
50    }
51
52    let mut res = Vec::with_capacity(num_entries * sorted_keys.len());
53
54    for i in 0..num_entries {
55        for &key in &sorted_keys {
56            let element = &data.iter().find(|x| x.key == key).unwrap().data[i];
57            res.push(element.serialize());
58        }
59    }
60
61    Ok(res)
62}
63
/// Client for the `Dataset` RPC API.
///
/// Each method issues the corresponding `Dataset` call through the wrapped client RPC handle.
pub struct DatasetRpc {
    /// Handle used to issue every `Dataset` method call.
    rpc: crate::client::DecthingsClientRpc,
}
67
68impl DatasetRpc {
69    pub(crate) fn new(rpc: crate::client::DecthingsClientRpc) -> Self {
70        Self { rpc }
71    }
72
73    pub async fn create_dataset(
74        &self,
75        params: CreateDatasetParams<'_>,
76    ) -> Result<CreateDatasetResult, crate::client::DecthingsRpcError<CreateDatasetError>> {
77        let (tx, rx) = tokio::sync::oneshot::channel();
78        self.rpc
79            .raw_method_call::<_, _, &[u8]>(
80                "Dataset",
81                "createDataset",
82                params,
83                &[],
84                crate::client::RpcProtocol::Http,
85                |x| {
86                    tx.send(x).ok();
87                    StateModification::empty()
88                },
89            )
90            .await;
91        rx.await
92            .unwrap()
93            .map_err(crate::client::DecthingsRpcError::Request)
94            .and_then(|x| {
95                let res: super::Response<CreateDatasetResult, CreateDatasetError> =
96                    serde_json::from_slice(&x.0)?;
97                match res {
98                    super::Response::Result(val) => Ok(val),
99                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
100                }
101            })
102    }
103
104    pub async fn update_dataset(
105        &self,
106        params: UpdateDatasetParams<'_>,
107    ) -> Result<UpdateDatasetResult, crate::client::DecthingsRpcError<UpdateDatasetError>> {
108        let (tx, rx) = tokio::sync::oneshot::channel();
109        self.rpc
110            .raw_method_call::<_, _, &[u8]>(
111                "Dataset",
112                "updateDataset",
113                params,
114                &[],
115                crate::client::RpcProtocol::Http,
116                |x| {
117                    tx.send(x).ok();
118                    StateModification::empty()
119                },
120            )
121            .await;
122        rx.await
123            .unwrap()
124            .map_err(crate::client::DecthingsRpcError::Request)
125            .and_then(|x| {
126                let res: super::Response<UpdateDatasetResult, UpdateDatasetError> =
127                    serde_json::from_slice(&x.0)?;
128                match res {
129                    super::Response::Result(val) => Ok(val),
130                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
131                }
132            })
133    }
134
135    pub async fn delete_dataset(
136        &self,
137        params: DeleteDatasetParams<'_>,
138    ) -> Result<DeleteDatasetResult, crate::client::DecthingsRpcError<DeleteDatasetError>> {
139        let (tx, rx) = tokio::sync::oneshot::channel();
140        self.rpc
141            .raw_method_call::<_, _, &[u8]>(
142                "Dataset",
143                "deleteDataset",
144                params,
145                &[],
146                crate::client::RpcProtocol::Http,
147                |x| {
148                    tx.send(x).ok();
149                    StateModification::empty()
150                },
151            )
152            .await;
153        rx.await
154            .unwrap()
155            .map_err(crate::client::DecthingsRpcError::Request)
156            .and_then(|x| {
157                let res: super::Response<DeleteDatasetResult, DeleteDatasetError> =
158                    serde_json::from_slice(&x.0)?;
159                match res {
160                    super::Response::Result(val) => Ok(val),
161                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
162                }
163            })
164    }
165
166    pub async fn get_datasets(
167        &self,
168        params: GetDatasetsParams<'_, impl AsRef<str>>,
169    ) -> Result<GetDatasetsResult, crate::client::DecthingsRpcError<GetDatasetsError>> {
170        let (tx, rx) = tokio::sync::oneshot::channel();
171        self.rpc
172            .raw_method_call::<_, _, &[u8]>(
173                "Dataset",
174                "getDatasets",
175                params,
176                &[],
177                crate::client::RpcProtocol::Http,
178                |x| {
179                    tx.send(x).ok();
180                    StateModification::empty()
181                },
182            )
183            .await;
184        rx.await
185            .unwrap()
186            .map_err(crate::client::DecthingsRpcError::Request)
187            .and_then(|x| {
188                let res: super::Response<GetDatasetsResult, GetDatasetsError> =
189                    serde_json::from_slice(&x.0)?;
190                match res {
191                    super::Response::Result(val) => Ok(val),
192                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
193                }
194            })
195    }
196
197    pub async fn add_entries(
198        &self,
199        params: AddEntriesParams<'_>,
200    ) -> Result<AddEntriesResult, crate::client::DecthingsRpcError<AddEntriesError>> {
201        let (tx, rx) = tokio::sync::oneshot::channel();
202        let serialized = serialize_add_dataset_data(&params.keys).map_err(|e| {
203            crate::client::DecthingsRpcError::Rpc(AddEntriesError::InvalidParameter {
204                parameter_name: "params.keys".to_string(),
205                reason: e,
206            })
207        })?;
208        self.rpc
209            .raw_method_call(
210                "Dataset",
211                "addEntries",
212                params,
213                serialized,
214                crate::client::RpcProtocol::Http,
215                |x| {
216                    tx.send(x).ok();
217                    StateModification::empty()
218                },
219            )
220            .await;
221        rx.await
222            .unwrap()
223            .map_err(crate::client::DecthingsRpcError::Request)
224            .and_then(|x| {
225                let res: super::Response<AddEntriesResult, AddEntriesError> =
226                    serde_json::from_slice(&x.0)?;
227                match res {
228                    super::Response::Result(val) => Ok(val),
229                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
230                }
231            })
232    }
233
234    pub async fn add_entries_to_needs_review(
235        &self,
236        params: AddEntriesToNeedsReviewParams<'_>,
237    ) -> Result<
238        AddEntriesToNeedsReviewResult,
239        crate::client::DecthingsRpcError<AddEntriesToNeedsReviewError>,
240    > {
241        let (tx, rx) = tokio::sync::oneshot::channel();
242        let serialized = serialize_add_dataset_data(&params.keys).map_err(|e| {
243            crate::client::DecthingsRpcError::Rpc(AddEntriesToNeedsReviewError::InvalidParameter {
244                parameter_name: "params.keys".to_string(),
245                reason: e,
246            })
247        })?;
248        self.rpc
249            .raw_method_call(
250                "Dataset",
251                "addEntriesToNeedsReview",
252                params,
253                serialized,
254                crate::client::RpcProtocol::Http,
255                |x| {
256                    tx.send(x).ok();
257                    StateModification::empty()
258                },
259            )
260            .await;
261        rx.await
262            .unwrap()
263            .map_err(crate::client::DecthingsRpcError::Request)
264            .and_then(|x| {
265                let res: super::Response<
266                    AddEntriesToNeedsReviewResult,
267                    AddEntriesToNeedsReviewError,
268                > = serde_json::from_slice(&x.0)?;
269                match res {
270                    super::Response::Result(val) => Ok(val),
271                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
272                }
273            })
274    }
275
276    pub async fn finalize_needs_review_entries(
277        &self,
278        params: FinalizeNeedsReviewEntriesParams<'_>,
279    ) -> Result<
280        FinalizeNeedsReviewEntriesResult,
281        crate::client::DecthingsRpcError<FinalizeNeedsReviewEntriesError>,
282    > {
283        let (tx, rx) = tokio::sync::oneshot::channel();
284        let serialized = serialize_add_dataset_data(&params.keys).map_err(|e| {
285            crate::client::DecthingsRpcError::Rpc(
286                FinalizeNeedsReviewEntriesError::InvalidParameter {
287                    parameter_name: "params.keys".to_string(),
288                    reason: e,
289                },
290            )
291        })?;
292        if params.indexes.len() != params.keys[0].data.len() {
293            return Err(crate::client::DecthingsRpcError::Rpc(
294                FinalizeNeedsReviewEntriesError::InvalidParameter {
295                    parameter_name: "params.keys".to_string(),
296                    reason: format!(
297                        "The number of indexes to remove must equal the number of elements to add. Attempted to remove {} indexes and add {} elements.",
298                        params.indexes.len(),
299                        params.keys[0].data.len()
300                    )
301                },
302            ));
303        }
304        self.rpc
305            .raw_method_call(
306                "Dataset",
307                "finalizeNeedsReviewEntries",
308                params,
309                serialized,
310                crate::client::RpcProtocol::Http,
311                |x| {
312                    tx.send(x).ok();
313                    StateModification::empty()
314                },
315            )
316            .await;
317        rx.await
318            .unwrap()
319            .map_err(crate::client::DecthingsRpcError::Request)
320            .and_then(|x| {
321                let res: super::Response<
322                    FinalizeNeedsReviewEntriesResult,
323                    FinalizeNeedsReviewEntriesError,
324                > = serde_json::from_slice(&x.0)?;
325                match res {
326                    super::Response::Result(val) => Ok(val),
327                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
328                }
329            })
330    }
331
332    pub async fn get_entries(
333        &self,
334        params: GetEntriesParams<'_>,
335    ) -> Result<GetEntriesResult, crate::client::DecthingsRpcError<GetEntriesError>> {
336        let (tx, rx) = tokio::sync::oneshot::channel();
337        self.rpc
338            .raw_method_call::<_, _, &[u8]>(
339                "Dataset",
340                "getEntries",
341                params,
342                &[],
343                crate::client::RpcProtocol::Http,
344                |x| {
345                    tx.send(x).ok();
346                    StateModification::empty()
347                },
348            )
349            .await;
350        rx.await
351            .unwrap()
352            .map_err(crate::client::DecthingsRpcError::Request)
353            .and_then(|mut x| {
354                let inner_res: super::Response<InnerGetEntriesResult, GetEntriesError> =
355                    serde_json::from_slice(&x.0)?;
356                match inner_res {
357                    super::Response::Result(val) => {
358                        if x.1.len() != val.indexes.len() * val.keys.len() {
359                            return Err(crate::client::DecthingsClientError::InvalidMessage.into());
360                        }
361                        let mut res = GetEntriesResult {
362                            keys: val
363                                .keys
364                                .into_iter()
365                                .map(|x| KeyData {
366                                    name: x,
367                                    data: Vec::with_capacity(val.indexes.len()),
368                                })
369                                .collect(),
370                        };
371
372                        for index in val.indexes {
373                            for key in res.keys.iter_mut() {
374                                key.data.push(FetchedEntry {
375                                    index,
376                                    data: OwnedDecthingsTensor::from_bytes(x.1.remove(0)).map_err(
377                                        |_| crate::client::DecthingsClientError::InvalidMessage,
378                                    )?,
379                                });
380                            }
381                        }
382                        Ok(res)
383                    }
384                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
385                }
386            })
387    }
388
389    pub async fn get_needs_review_entries(
390        &self,
391        params: GetNeedsReviewEntriesParams<'_>,
392    ) -> Result<
393        GetNeedsReviewEntriesResult,
394        crate::client::DecthingsRpcError<GetNeedsReviewEntriesError>,
395    > {
396        let (tx, rx) = tokio::sync::oneshot::channel();
397        self.rpc
398            .raw_method_call::<_, _, &[u8]>(
399                "Dataset",
400                "getNeedsReviewEntries",
401                params,
402                &[],
403                crate::client::RpcProtocol::Http,
404                |x| {
405                    tx.send(x).ok();
406                    StateModification::empty()
407                },
408            )
409            .await;
410        rx.await
411            .unwrap()
412            .map_err(crate::client::DecthingsRpcError::Request)
413            .and_then(|mut x| {
414                let res: super::Response<
415                    InnerGetNeedsReviewEntriesResult,
416                    GetNeedsReviewEntriesError,
417                > = serde_json::from_slice(&x.0)?;
418                match res {
419                    super::Response::Result(val) => {
420                        if x.1.len() != val.indexes.len() * val.keys.len() {
421                            return Err(crate::client::DecthingsClientError::InvalidMessage.into());
422                        }
423                        let mut res = GetNeedsReviewEntriesResult {
424                            keys: val
425                                .keys
426                                .into_iter()
427                                .map(|x| KeyData {
428                                    name: x,
429                                    data: Vec::with_capacity(val.indexes.len()),
430                                })
431                                .collect(),
432                        };
433                        for index in val.indexes {
434                            for key in res.keys.iter_mut() {
435                                key.data.push(FetchedEntry {
436                                    index,
437                                    data: OwnedDecthingsTensor::from_bytes(x.1.remove(0)).map_err(
438                                        |_| crate::client::DecthingsClientError::InvalidMessage,
439                                    )?,
440                                });
441                            }
442                        }
443                        Ok(res)
444                    }
445                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
446                }
447            })
448    }
449
450    pub async fn remove_entries(
451        &self,
452        params: RemoveEntriesParams<'_>,
453    ) -> Result<RemoveEntriesResult, crate::client::DecthingsRpcError<RemoveEntriesError>> {
454        let (tx, rx) = tokio::sync::oneshot::channel();
455        self.rpc
456            .raw_method_call::<_, _, &[u8]>(
457                "Dataset",
458                "removeEntries",
459                params,
460                &[],
461                crate::client::RpcProtocol::Http,
462                |x| {
463                    tx.send(x).ok();
464                    StateModification::empty()
465                },
466            )
467            .await;
468        rx.await
469            .unwrap()
470            .map_err(crate::client::DecthingsRpcError::Request)
471            .and_then(|x| {
472                let res: super::Response<RemoveEntriesResult, RemoveEntriesError> =
473                    serde_json::from_slice(&x.0)?;
474                match res {
475                    super::Response::Result(val) => Ok(val),
476                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
477                }
478            })
479    }
480
481    pub async fn remove_needs_review_entries(
482        &self,
483        params: RemoveNeedsReviewEntriesParams<'_>,
484    ) -> Result<
485        RemoveNeedsReviewEntriesResult,
486        crate::client::DecthingsRpcError<RemoveNeedsReviewEntriesError>,
487    > {
488        let (tx, rx) = tokio::sync::oneshot::channel();
489        self.rpc
490            .raw_method_call::<_, _, &[u8]>(
491                "Dataset",
492                "removeNeedsReviewEntries",
493                params,
494                &[],
495                crate::client::RpcProtocol::Http,
496                |x| {
497                    tx.send(x).ok();
498                    StateModification::empty()
499                },
500            )
501            .await;
502        rx.await
503            .unwrap()
504            .map_err(crate::client::DecthingsRpcError::Request)
505            .and_then(|x| {
506                let res: super::Response<
507                    RemoveNeedsReviewEntriesResult,
508                    RemoveNeedsReviewEntriesError,
509                > = serde_json::from_slice(&x.0)?;
510                match res {
511                    super::Response::Result(val) => Ok(val),
512                    super::Response::Error(val) => Err(crate::client::DecthingsRpcError::Rpc(val)),
513                }
514            })
515    }
516}