1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use crate::errors::CharybdisError;
5use crate::model::BaseModel;
6use futures::{Stream, StreamExt, TryStreamExt};
7use scylla::client::pager::TypedRowStream;
8use scylla::errors::NextRowError;
9
10pub struct CharybdisModelStream<T: BaseModel + 'static> {
11 inner: TypedRowStream<T>,
12 query_string: &'static str,
13}
14
15impl<T: BaseModel> CharybdisModelStream<T> {
16 pub(crate) fn query_string(&mut self, query_string: &'static str) {
17 self.query_string = query_string;
18 }
19}
20
21impl<T: BaseModel> From<TypedRowStream<T>> for CharybdisModelStream<T> {
22 fn from(iter: TypedRowStream<T>) -> Self {
23 CharybdisModelStream {
24 inner: iter,
25 query_string: "",
26 }
27 }
28}
29
30impl<T: BaseModel> Stream for CharybdisModelStream<T> {
31 type Item = Result<T, CharybdisError>;
32
33 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
34 self.inner
35 .poll_next_unpin(cx)
36 .map_err(|e| CharybdisError::NextRowError(self.query_string, e))
37 }
38}
39
40impl<T: BaseModel> CharybdisModelStream<T> {
41 pub async fn try_collect(self) -> Result<Vec<T>, CharybdisError> {
42 let results: Result<Vec<T>, NextRowError> = self.inner.try_collect().await;
43
44 results.map_err(|e| CharybdisError::NextRowError(self.query_string, e))
45 }
46}