1use std::borrow::Cow;
2use std::collections::HashMap;
3use std::future::IntoFuture;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use anyhow::bail;
8use futures::StreamExt;
9use futures::future::Either;
10use futures::stream::SelectAll;
11use indexmap::IndexMap;
12use serde::Serialize;
13use serde::de::DeserializeOwned;
14use uuid::Uuid;
15
16use super::transaction::WithTransaction;
17use super::{Stream, live};
18use crate::api::conn::Command;
19use crate::api::err::Error;
20use crate::api::method::BoxFuture;
21use crate::api::{self, Connection, ExtraFeatures, Result, opt};
22use crate::core::expr::{LogicalPlan, TopLevelExpr};
23use crate::core::val;
24use crate::method::{OnceLockExt, Stats, WithStats};
25use crate::value::Notification;
26use crate::{Surreal, Value};
27
28#[derive(Debug)]
30#[must_use = "futures do nothing unless you `.await` or poll them"]
31pub struct Query<'r, C: Connection> {
32 pub(crate) txn: Option<Uuid>,
33 pub(crate) client: Cow<'r, Surreal<C>>,
34 pub(crate) inner: Result<ValidQuery>,
35}
36
37impl<C> WithTransaction for Query<'_, C>
38where
39 C: Connection,
40{
41 fn with_transaction(mut self, id: Uuid) -> Self {
42 self.txn = Some(id);
43 self
44 }
45}
46
47#[derive(Debug)]
48pub(crate) enum ValidQuery {
49 Raw {
50 query: Cow<'static, str>,
51 bindings: val::Object,
52 },
53 Normal {
54 query: Vec<TopLevelExpr>,
55 register_live_queries: bool,
56 bindings: val::Object,
57 },
58}
59
60impl<'r, C> Query<'r, C>
61where
62 C: Connection,
63{
64 pub(crate) fn normal(
65 client: Cow<'r, Surreal<C>>,
66 query: Vec<TopLevelExpr>,
67 bindings: val::Object,
68 register_live_queries: bool,
69 ) -> Self {
70 Query {
71 txn: None,
72 client,
73 inner: Ok(ValidQuery::Normal {
74 query,
75 bindings,
76 register_live_queries,
77 }),
78 }
79 }
80
81 pub(crate) fn map_valid<F>(self, f: F) -> Self
82 where
83 F: FnOnce(ValidQuery) -> Result<ValidQuery>,
84 {
85 match self.inner {
86 Ok(x) => Query {
87 txn: self.txn,
88 client: self.client,
89 inner: f(x),
90 },
91 x => Query {
92 txn: self.txn,
93 client: self.client,
94 inner: x,
95 },
96 }
97 }
98
99 pub fn into_owned(self) -> Query<'static, C> {
102 Query {
103 txn: self.txn,
104 client: Cow::Owned(self.client.into_owned()),
105 inner: self.inner,
106 }
107 }
108}
109
110impl<'r, Client> IntoFuture for Query<'r, Client>
111where
112 Client: Connection,
113{
114 type Output = Result<Response>;
115 type IntoFuture = BoxFuture<'r, Self::Output>;
116
117 fn into_future(self) -> Self::IntoFuture {
118 Box::pin(async move {
119 let router = self.client.inner.router.extract()?;
121
122 match self.inner? {
123 ValidQuery::Raw {
124 query,
125 bindings,
126 } => {
127 router
128 .execute_query(Command::RawQuery {
129 query,
130 txn: self.txn,
131 variables: bindings,
132 })
133 .await
134 }
135 ValidQuery::Normal {
136 query,
137 register_live_queries,
138 bindings,
139 } => {
140 let query_indicies = if register_live_queries {
142 query
143 .iter()
144 .filter(|x| {
146 !matches!(
147 x,
148 TopLevelExpr::Begin
149 | TopLevelExpr::Commit | TopLevelExpr::Cancel
150 )
151 })
152 .enumerate()
153 .filter(|(_, x)| matches!(x, TopLevelExpr::Live(_)))
154 .map(|(i, _)| i)
155 .collect()
156 } else {
157 Vec::new()
158 };
159
160 if !query_indicies.is_empty()
162 && !router.features.contains(&ExtraFeatures::LiveQueries)
163 {
164 return Err(Error::LiveQueriesNotSupported.into());
165 }
166
167 let query = LogicalPlan {
168 expressions: query,
169 };
170
171 let mut response = router
172 .execute_query(Command::Query {
173 txn: self.txn,
174 query,
175 variables: bindings,
176 })
177 .await?;
178
179 for idx in query_indicies {
180 let Some((_, result)) = response.results.get(&idx) else {
181 continue;
182 };
183
184 let res = match result {
187 Ok(id) => {
188 let val::Value::Uuid(uuid) = id else {
189 bail!(Error::InternalError(
190 "successfull live query did not return a uuid".to_string(),
191 ));
192 };
193 live::register(router, uuid.0).await.map(|rx| {
194 Stream::new(self.client.inner.clone().into(), uuid.0, Some(rx))
195 })
196 }
197 Err(_) => Err(anyhow::Error::new(Error::NotLiveQuery(idx))),
198 };
199 response.live_queries.insert(idx, res);
200 }
201
202 Ok(response)
203 }
204 }
205 })
206 }
207}
208
209impl<'r, Client> IntoFuture for WithStats<Query<'r, Client>>
210where
211 Client: Connection,
212{
213 type Output = Result<WithStats<Response>>;
214 type IntoFuture = BoxFuture<'r, Self::Output>;
215
216 fn into_future(self) -> Self::IntoFuture {
217 Box::pin(async move {
218 let response = self.0.await?;
219 Ok(WithStats(response))
220 })
221 }
222}
223
224impl<C> Query<'_, C>
225where
226 C: Connection,
227{
228 pub fn query(self, surql: impl opt::IntoQuery) -> Self {
230 let client = self.client.clone();
231 self.map_valid(move |valid| match valid {
232 ValidQuery::Raw {
233 ..
234 } => {
235 Err(Error::InvalidParams("Appending to raw queries is not supported".to_owned())
236 .into())
237 }
238 ValidQuery::Normal {
239 mut query,
240 register_live_queries,
241 bindings,
242 } => match client.query(surql).inner {
243 Ok(ValidQuery::Normal {
244 query: stmts,
245 ..
246 }) => {
247 query.extend(stmts);
248 Ok(ValidQuery::Normal {
249 query,
250 register_live_queries,
251 bindings,
252 })
253 }
254 Ok(ValidQuery::Raw {
255 ..
256 }) => Err(Error::InvalidParams("Appending raw queries is not supported".to_owned())
257 .into()),
258 Err(error) => Err(error),
259 },
260 })
261 }
262
263 pub const fn with_stats(self) -> WithStats<Self> {
265 WithStats(self)
266 }
267
268 pub fn bind(self, bindings: impl Serialize + 'static) -> Self {
307 self.map_valid(move |mut valid| {
308 let current_bindings = match &mut valid {
309 ValidQuery::Raw {
310 bindings,
311 ..
312 } => bindings,
313 ValidQuery::Normal {
314 bindings,
315 ..
316 } => bindings,
317 };
318 let bindings = api::value::to_core_value(bindings)?;
319 match bindings {
320 val::Value::Object(mut map) => current_bindings.append(&mut map.0),
321 val::Value::Array(array) => {
322 if array.len() != 2 || !matches!(array[0], val::Value::Strand(_)) {
323 let bindings = val::Value::Array(array);
324 let bindings = Value::from_inner(bindings);
325 return Err(Error::InvalidBindings(bindings).into());
326 }
327
328 let mut iter = array.into_iter();
329 let Some(val::Value::Strand(key)) = iter.next() else {
330 unreachable!()
331 };
332 let Some(value) = iter.next() else {
333 unreachable!()
334 };
335
336 current_bindings.insert(key.into_string(), value);
337 }
338 _ => {
339 let bindings = Value::from_inner(bindings);
340 return Err(Error::InvalidBindings(bindings).into());
341 }
342 }
343
344 Ok(valid)
345 })
346 }
347}
348
349pub(crate) type QueryResult = Result<val::Value>;
350
351#[derive(Debug)]
353pub struct Response {
354 pub(crate) results: IndexMap<usize, (Stats, QueryResult)>,
355 pub(crate) live_queries: IndexMap<usize, Result<Stream<Value>>>,
356}
357
358#[derive(Debug)]
360#[must_use = "streams do nothing unless you poll them"]
361pub struct QueryStream<R>(pub(crate) Either<Stream<R>, SelectAll<Stream<R>>>);
362
363impl futures::Stream for QueryStream<Value> {
364 type Item = Notification<Value>;
365
366 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
367 self.as_mut().0.poll_next_unpin(cx)
368 }
369}
370
371impl<R> futures::Stream for QueryStream<Notification<R>>
372where
373 R: DeserializeOwned + Unpin,
374{
375 type Item = Result<Notification<R>>;
376
377 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
378 self.as_mut().0.poll_next_unpin(cx)
379 }
380}
381
382impl Response {
383 pub(crate) fn new() -> Self {
384 Self {
385 results: Default::default(),
386 live_queries: Default::default(),
387 }
388 }
389
390 pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Result<R>
452 where
453 R: DeserializeOwned,
454 {
455 index.query_result(self)
456 }
457
458 pub fn stream<R>(&mut self, index: impl opt::QueryStream<R>) -> Result<QueryStream<R>> {
503 index.query_stream(self)
504 }
505
506 pub fn take_errors(&mut self) -> HashMap<usize, anyhow::Error> {
525 let mut keys = Vec::new();
526 for (key, result) in &self.results {
527 if result.1.is_err() {
528 keys.push(*key);
529 }
530 }
531 let mut errors = HashMap::with_capacity(keys.len());
532 for key in keys {
533 if let Some((_, Err(error))) = self.results.swap_remove(&key) {
534 errors.insert(key, error);
535 }
536 }
537 errors
538 }
539
540 pub fn check(mut self) -> Result<Self> {
556 let mut first_error = None;
557 for (key, result) in &self.results {
558 if result.1.is_err() {
559 first_error = Some(*key);
560 break;
561 }
562 }
563 if let Some(key) = first_error {
564 if let Some((_, Err(error))) = self.results.swap_remove(&key) {
565 return Err(error);
566 }
567 }
568 Ok(self)
569 }
570
571 pub fn num_statements(&self) -> usize {
587 self.results.len()
588 }
589}
590
591impl WithStats<Response> {
592 pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Option<(Stats, Result<R>)>
655 where
656 R: DeserializeOwned,
657 {
658 let stats = index.stats(&self.0)?;
659 let result = index.query_result(&mut self.0);
660 Some((stats, result))
661 }
662
663 pub fn take_errors(&mut self) -> HashMap<usize, (Stats, anyhow::Error)> {
682 let mut keys = Vec::new();
683 for (key, result) in &self.0.results {
684 if result.1.is_err() {
685 keys.push(*key);
686 }
687 }
688 let mut errors = HashMap::with_capacity(keys.len());
689 for key in keys {
690 if let Some((stats, Err(error))) = self.0.results.swap_remove(&key) {
691 errors.insert(key, (stats, error));
692 }
693 }
694 errors
695 }
696
697 pub fn check(self) -> Result<Self> {
713 let response = self.0.check()?;
714 Ok(Self(response))
715 }
716
717 pub fn num_statements(&self) -> usize {
733 self.0.num_statements()
734 }
735
736 pub fn into_inner(self) -> Response {
738 self.0
739 }
740}
741
#[cfg(test)]
mod tests {
    use serde::Deserialize;

    use super::*;
    use crate::value::to_value;

    /// Record with a single field.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Summary {
        title: String,
    }

    /// Record with two fields.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct Article {
        title: String,
        body: String,
    }

    /// Keys every result by its position and pairs it with default stats.
    fn to_map(vec: Vec<QueryResult>) -> IndexMap<usize, (Stats, QueryResult)> {
        vec.into_iter()
            .enumerate()
            .map(|(index, result)| {
                let stats = Stats {
                    execution_time: Default::default(),
                };
                (index, (stats, result))
            })
            .collect()
    }

    /// Builds a response holding exactly `results` and no live queries.
    fn response_of(results: Vec<QueryResult>) -> Response {
        Response {
            results: to_map(results),
            ..Response::new()
        }
    }

    /// The two-element array `[true, false]` as a core value.
    fn bools() -> val::Value {
        val::Value::from(vec![val::Value::from(true), val::Value::from(false)])
    }

    #[test]
    fn take_from_an_empty_response() {
        let mut response = Response::new();
        let value: Value = response.take(0).unwrap();
        assert!(value.into_inner().is_none());

        let mut response = Response::new();
        let option: Option<String> = response.take(0).unwrap();
        assert!(option.is_none());

        let mut response = Response::new();
        let vec: Vec<String> = response.take(0).unwrap();
        assert!(vec.is_empty());
    }

    #[test]
    fn take_from_an_errored_query() {
        let mut response = response_of(vec![Err(Error::ConnectionUninitialised.into())]);
        let outcome: Result<Option<()>> = response.take(0);
        outcome.unwrap_err();
    }

    #[test]
    fn take_from_empty_records() {
        let mut response = response_of(vec![]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value, Default::default());

        let mut response = response_of(vec![]);
        let option: Option<String> = response.take(0).unwrap();
        assert!(option.is_none());

        let mut response = response_of(vec![]);
        let vec: Vec<String> = response.take(0).unwrap();
        assert!(vec.is_empty());
    }

    #[test]
    fn take_from_a_scalar_response() {
        let scalar = 265;

        let mut response = response_of(vec![Ok(scalar.into())]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value.into_inner(), val::Value::from(scalar));

        let mut response = response_of(vec![Ok(scalar.into())]);
        let option: Option<_> = response.take(0).unwrap();
        assert_eq!(option, Some(scalar));

        let mut response = response_of(vec![Ok(scalar.into())]);
        let vec: Vec<i64> = response.take(0).unwrap();
        assert_eq!(vec, vec![scalar]);

        let scalar = true;

        let mut response = response_of(vec![Ok(scalar.into())]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value.into_inner(), val::Value::from(scalar));

        let mut response = response_of(vec![Ok(scalar.into())]);
        let option: Option<_> = response.take(0).unwrap();
        assert_eq!(option, Some(scalar));

        let mut response = response_of(vec![Ok(scalar.into())]);
        let vec: Vec<bool> = response.take(0).unwrap();
        assert_eq!(vec, vec![scalar]);
    }

    #[test]
    fn take_preserves_order() {
        let mut response = response_of(vec![
            Ok(0.into()),
            Ok(1.into()),
            Ok(2.into()),
            Ok(3.into()),
            Ok(4.into()),
            Ok(5.into()),
            Ok(6.into()),
            Ok(7.into()),
        ]);

        let four: Option<i32> = response.take(4).unwrap();
        assert_eq!(four.unwrap_or_else(|| panic!("query not found")), 4);

        let six: Option<i32> = response.take(6).unwrap();
        assert_eq!(six.unwrap_or_else(|| panic!("query not found")), 6);

        let zero: Option<i32> = response.take(0).unwrap();
        assert_eq!(zero.unwrap_or_else(|| panic!("query not found")), 0);

        let one: Value = response.take(1).unwrap();
        assert_eq!(one.into_inner(), val::Value::from(1));
    }

    #[test]
    fn take_key() {
        let summary = Summary {
            title: "Lorem Ipsum".to_owned(),
        };
        let value = to_value(summary.clone()).unwrap();

        let mut response = response_of(vec![Ok(value.clone().into_inner())]);
        let title: Value = response.take("title").unwrap();
        assert_eq!(title.into_inner(), val::Value::from(summary.title.as_str()));

        let mut response = response_of(vec![Ok(value.clone().into_inner())]);
        let title: Option<String> = response.take("title").unwrap();
        assert_eq!(title.unwrap_or_else(|| panic!("title not found")), summary.title);

        let mut response = response_of(vec![Ok(value.into_inner())]);
        let vec: Vec<String> = response.take("title").unwrap();
        assert_eq!(vec, vec![summary.title]);

        let article = Article {
            title: "Lorem Ipsum".to_owned(),
            body: "Lorem Ipsum Lorem Ipsum".to_owned(),
        };
        let value = to_value(article.clone()).unwrap();

        let mut response = response_of(vec![Ok(value.clone().into_inner())]);
        let title: Option<String> = response.take("title").unwrap();
        assert_eq!(title.unwrap_or_else(|| panic!("title not found")), article.title);
        let body: Option<String> = response.take("body").unwrap();
        assert_eq!(body.unwrap_or_else(|| panic!("body not found")), article.body);

        let mut response = response_of(vec![Ok(value.clone().into_inner())]);
        let vec: Vec<String> = response.take("title").unwrap();
        assert_eq!(vec, vec![article.title.clone()]);

        let mut response = response_of(vec![Ok(value.into_inner())]);
        let value: Value = response.take("title").unwrap();
        assert_eq!(value.into_inner(), val::Value::from(article.title));
    }

    #[test]
    fn take_key_multi() {
        let article = Article {
            title: "Lorem Ipsum".to_owned(),
            body: "Lorem Ipsum Lorem Ipsum".to_owned(),
        };
        let value = to_value(article.clone()).unwrap();

        let mut response = response_of(vec![Ok(value.clone().into_inner())]);
        let title: Vec<String> = response.take("title").unwrap();
        assert_eq!(title, vec![article.title.clone()]);
        let body: Vec<String> = response.take("body").unwrap();
        assert_eq!(body, vec![article.body]);

        let mut response = response_of(vec![Ok(value.clone().into_inner())]);
        let vec: Vec<String> = response.take("title").unwrap();
        assert_eq!(vec, vec![article.title]);
    }

    #[test]
    fn take_partial_records() {
        let mut response = response_of(vec![Ok(bools())]);
        let value: Value = response.take(0).unwrap();
        assert_eq!(value.into_inner(), bools());

        let mut response = response_of(vec![Ok(bools())]);
        let vec: Vec<bool> = response.take(0).unwrap();
        assert_eq!(vec, vec![true, false]);

        let mut response = response_of(vec![Ok(bools())]);

        // Taking two records as a single `Option` must fail and hand the
        // untouched response back inside the error.
        let Err(error) = response.take::<Option<bool>>(0) else {
            panic!("silently dropping records not allowed");
        };
        let Ok(Error::LossyTake(Response {
            results: mut returned,
            ..
        })) = error.downcast()
        else {
            panic!("silently dropping records not allowed");
        };

        let records = returned.swap_remove(&0).unwrap().1.unwrap();
        assert_eq!(records, bools());
    }

    #[test]
    fn check_returns_the_first_error() {
        let response = response_of(vec![
            Ok(0.into()),
            Ok(1.into()),
            Ok(2.into()),
            Err(Error::ConnectionUninitialised.into()),
            Ok(3.into()),
            Ok(4.into()),
            Ok(5.into()),
            Err(Error::BackupsNotSupported.into()),
            Ok(6.into()),
            Ok(7.into()),
            Err(Error::DuplicateRequestId(0).into()),
        ]);

        let error = response.check().unwrap_err();
        if !matches!(error.downcast_ref(), Some(Error::ConnectionUninitialised)) {
            panic!("check did not return the first error");
        }
    }

    #[test]
    fn take_errors() {
        let mut response = response_of(vec![
            Ok(0.into()),
            Ok(1.into()),
            Ok(2.into()),
            Err(Error::ConnectionUninitialised.into()),
            Ok(3.into()),
            Ok(4.into()),
            Ok(5.into()),
            Err(Error::BackupsNotSupported.into()),
            Ok(6.into()),
            Ok(7.into()),
            Err(Error::DuplicateRequestId(0).into()),
        ]);

        let errors = response.take_errors();
        assert_eq!(response.num_statements(), 8);
        assert_eq!(errors.len(), 3);

        if !matches!(errors[&10].downcast_ref(), Some(Error::DuplicateRequestId(0))) {
            panic!("index `10` is not `DuplicateRequestId`");
        }
        if !matches!(errors[&7].downcast_ref(), Some(Error::BackupsNotSupported)) {
            panic!("index `7` is not `BackupsNotSupported`");
        }
        if !matches!(errors[&3].downcast_ref(), Some(Error::ConnectionUninitialised)) {
            panic!("index `3` is not `ConnectionUninitialised`");
        }

        // Successful results keep their original indices after the errors
        // have been removed.
        let value: Option<i32> = response.take(2).unwrap();
        assert_eq!(value.unwrap_or_else(|| panic!("statement not found")), 2);
        let value: Value = response.take(4).unwrap();
        assert_eq!(value.into_inner(), val::Value::from(3));
    }
}