1use std::{borrow::Cow, fmt, fmt::Debug, sync::Arc};
2
3use futures_util::{
4 Future, FutureExt, Stream, StreamExt, TryStreamExt, future::BoxFuture, stream::BoxStream,
5};
6use indexmap::IndexMap;
7
8use crate::{
9 ContextSelectionSet, Data, QueryPathNode, QueryPathSegment, Response, Result, ServerResult,
10 Value,
11 dynamic::{
12 FieldValue, InputValue, ObjectAccessor, ResolverContext, Schema, SchemaError, TypeRef,
13 resolve::resolve,
14 },
15 extensions::ResolveInfo,
16 parser::types::Selection,
17 registry::{Deprecation, MetaField, MetaType, Registry},
18 subscription::BoxFieldStream,
19};
20
21type BoxResolveFut<'a> = BoxFuture<'a, Result<BoxStream<'a, Result<FieldValue<'a>>>>>;
22
23pub struct SubscriptionFieldFuture<'a>(pub(crate) BoxResolveFut<'a>);
25
26impl<'a> SubscriptionFieldFuture<'a> {
27 pub fn new<Fut, S, T>(future: Fut) -> Self
29 where
30 Fut: Future<Output = Result<S>> + Send + 'a,
31 S: Stream<Item = Result<T>> + Send + 'a,
32 T: Into<FieldValue<'a>> + Send + 'a,
33 {
34 Self(
35 async move {
36 let res = future.await?.map_ok(Into::into);
37 Ok(res.boxed())
38 }
39 .boxed(),
40 )
41 }
42}
43
44type BoxResolverFn =
45 Arc<dyn for<'a> Fn(ResolverContext<'a>) -> SubscriptionFieldFuture<'a> + Send + Sync>;
46
47pub struct SubscriptionField {
49 pub(crate) name: String,
50 pub(crate) description: Option<String>,
51 pub(crate) arguments: IndexMap<String, InputValue>,
52 pub(crate) ty: TypeRef,
53 pub(crate) resolver_fn: BoxResolverFn,
54 pub(crate) deprecation: Deprecation,
55}
56
57impl SubscriptionField {
58 pub fn new<N, T, F>(name: N, ty: T, resolver_fn: F) -> Self
60 where
61 N: Into<String>,
62 T: Into<TypeRef>,
63 F: for<'a> Fn(ResolverContext<'a>) -> SubscriptionFieldFuture<'a> + Send + Sync + 'static,
64 {
65 Self {
66 name: name.into(),
67 description: None,
68 arguments: Default::default(),
69 ty: ty.into(),
70 resolver_fn: Arc::new(resolver_fn),
71 deprecation: Deprecation::NoDeprecated,
72 }
73 }
74
75 impl_set_description!();
76 impl_set_deprecation!();
77
78 #[inline]
80 pub fn argument(mut self, input_value: InputValue) -> Self {
81 self.arguments.insert(input_value.name.clone(), input_value);
82 self
83 }
84}
85
86impl Debug for SubscriptionField {
87 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88 f.debug_struct("Field")
89 .field("name", &self.name)
90 .field("description", &self.description)
91 .field("arguments", &self.arguments)
92 .field("ty", &self.ty)
93 .field("deprecation", &self.deprecation)
94 .finish()
95 }
96}
97
98#[derive(Debug)]
100pub struct Subscription {
101 pub(crate) name: String,
102 pub(crate) description: Option<String>,
103 pub(crate) fields: IndexMap<String, SubscriptionField>,
104}
105
106impl Subscription {
107 #[inline]
109 pub fn new(name: impl Into<String>) -> Self {
110 Self {
111 name: name.into(),
112 description: None,
113 fields: Default::default(),
114 }
115 }
116
117 impl_set_description!();
118
119 #[inline]
121 pub fn field(mut self, field: SubscriptionField) -> Self {
122 assert!(
123 !self.fields.contains_key(&field.name),
124 "Field `{}` already exists",
125 field.name
126 );
127 self.fields.insert(field.name.clone(), field);
128 self
129 }
130
131 #[inline]
133 pub fn type_name(&self) -> &str {
134 &self.name
135 }
136
137 pub(crate) fn register(&self, registry: &mut Registry) -> Result<(), SchemaError> {
138 let mut fields = IndexMap::new();
139
140 for field in self.fields.values() {
141 let mut args = IndexMap::new();
142
143 for argument in field.arguments.values() {
144 args.insert(argument.name.clone(), argument.to_meta_input_value());
145 }
146
147 fields.insert(
148 field.name.clone(),
149 MetaField {
150 name: field.name.clone(),
151 description: field.description.clone(),
152 args,
153 ty: field.ty.to_string(),
154 deprecation: field.deprecation.clone(),
155 cache_control: Default::default(),
156 external: false,
157 requires: None,
158 provides: None,
159 visible: None,
160 shareable: false,
161 inaccessible: false,
162 tags: vec![],
163 override_from: None,
164 compute_complexity: None,
165 directive_invocations: vec![],
166 requires_scopes: vec![],
167 },
168 );
169 }
170
171 registry.types.insert(
172 self.name.clone(),
173 MetaType::Object {
174 name: self.name.clone(),
175 description: self.description.clone(),
176 fields,
177 cache_control: Default::default(),
178 extends: false,
179 shareable: false,
180 resolvable: true,
181 keys: None,
182 visible: None,
183 inaccessible: false,
184 interface_object: false,
185 tags: vec![],
186 is_subscription: true,
187 rust_typename: None,
188 directive_invocations: vec![],
189 requires_scopes: vec![],
190 },
191 );
192
193 Ok(())
194 }
195
196 pub(crate) fn collect_streams<'a>(
197 &self,
198 schema: &Schema,
199 ctx: &ContextSelectionSet<'a>,
200 streams: &mut Vec<BoxFieldStream<'a>>,
201 root_value: &'a FieldValue<'static>,
202 ) {
203 for selection in &ctx.item.node.items {
204 if let Selection::Field(field) = &selection.node
205 && let Some(field_def) = self.fields.get(field.node.name.node.as_str())
206 {
207 let schema = schema.clone();
208 let field_type = field_def.ty.clone();
209 let resolver_fn = field_def.resolver_fn.clone();
210 let ctx = ctx.clone();
211
212 streams.push(
213 asynk_strim::try_stream_fn(move |mut yielder| async move {
214 let ctx_field = ctx.with_field(field);
215 let field_name = ctx_field.item.node.response_key().node.clone();
216 let arguments = ObjectAccessor(Cow::Owned(
217 field
218 .node
219 .arguments
220 .iter()
221 .map(|(name, value)| {
222 ctx_field
225 .resolve_input_value(value.clone())
226 .map(|value| value.map(|value| (name.node.clone(), value)))
227 })
228 .collect::<ServerResult<Vec<_>>>()?
229 .into_iter()
230 .flatten()
231 .collect::<IndexMap<_, _>>(),
232 ));
233
234 let mut stream = resolver_fn(ResolverContext {
235 ctx: &ctx_field,
236 args: arguments,
237 parent_value: root_value,
238 })
239 .0
240 .await
241 .map_err(|err| {
242 ctx_field.set_error_path(err.into_server_error(ctx_field.item.pos))
243 })?;
244
245 while let Some(value) = stream.next().await.transpose().map_err(|err| {
246 ctx_field.set_error_path(err.into_server_error(ctx_field.item.pos))
247 })? {
248 let f = |execute_data: Option<Data>| {
249 let schema = schema.clone();
250 let field_name = field_name.clone();
251 let field_type = field_type.clone();
252 let ctx_field = ctx_field.clone();
253
254 async move {
255 let mut ctx_field = ctx_field.clone();
256 ctx_field.execute_data = execute_data.as_ref();
257 let ri = ResolveInfo {
258 path_node: &QueryPathNode {
259 parent: None,
260 segment: QueryPathSegment::Name(&field_name),
261 },
262 parent_type: schema
263 .0
264 .env
265 .registry
266 .subscription_type
267 .as_ref()
268 .unwrap(),
269 return_type: &field_type.to_string(),
270 name: field.node.name.node.as_str(),
271 alias: field
272 .node
273 .alias
274 .as_ref()
275 .map(|alias| alias.node.as_str()),
276 is_for_introspection: false,
277 field: &field.node,
278 };
279 let resolve_fut =
280 resolve(&schema, &ctx_field, &field_type, Some(&value));
281 futures_util::pin_mut!(resolve_fut);
282 let value = ctx_field
283 .query_env
284 .extensions
285 .resolve(ri, &mut resolve_fut)
286 .await;
287
288 match value {
289 Ok(value) => {
290 let mut map = IndexMap::new();
291 map.insert(
292 field_name.clone(),
293 value.unwrap_or_default(),
294 );
295 Response::new(Value::Object(map))
296 }
297 Err(err) => Response::from_errors(vec![err]),
298 }
299 }
300 };
301 let resp = ctx_field
302 .query_env
303 .extensions
304 .execute(ctx_field.query_env.operation_name.as_deref(), f)
305 .await;
306 let is_err = !resp.errors.is_empty();
307 yielder.yield_ok(resp).await;
308 if is_err {
309 break;
310 }
311 }
312
313 Ok(())
314 })
315 .map(|res| res.unwrap_or_else(|err| Response::from_errors(vec![err])))
316 .boxed(),
317 );
318 }
319 }
320 }
321}
322
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_util::StreamExt;

    use crate::{Value, dynamic::*, value};

    /// Each item of the resolver's stream becomes one response whose data contains the
    /// selected sub-field of the yielded object.
    #[tokio::test]
    async fn subscription() {
        struct MyObjData {
            value: i32,
        }

        let my_obj = Object::new("MyObject").field(Field::new(
            "value",
            TypeRef::named_nn(TypeRef::INT),
            |ctx| {
                FieldFuture::new(async {
                    Ok(Some(Value::from(
                        ctx.parent_value.try_downcast_ref::<MyObjData>()?.value,
                    )))
                })
            },
        ));

        let query = Object::new("Query").field(Field::new(
            "value",
            TypeRef::named_nn(TypeRef::INT),
            |_| FieldFuture::new(async { Ok(FieldValue::none()) }),
        ));

        let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
            "obj",
            TypeRef::named_nn(my_obj.type_name()),
            |_| {
                SubscriptionFieldFuture::new(async {
                    Ok(asynk_strim::try_stream_fn(|mut yielder| async move {
                        for n in 0..10 {
                            tokio::time::sleep(Duration::from_millis(100)).await;
                            yielder
                                .yield_ok(FieldValue::owned_any(MyObjData { value: n }))
                                .await;
                        }

                        Ok(())
                    }))
                })
            },
        ));

        let schema = Schema::build(query.type_name(), None, Some(subscription.type_name()))
            .register(my_obj)
            .register(query)
            .register(subscription)
            .finish()
            .unwrap();

        let mut stream = schema.execute_stream("subscription { obj { value } }");
        for expected in 0..10 {
            let data = stream.next().await.unwrap().into_result().unwrap().data;
            assert_eq!(data, value!({ "obj": { "value": expected } }));
        }
    }

    /// A subscription stream may borrow from the resolver context, for example to read
    /// schema data on every event.
    #[tokio::test]
    async fn borrow_context() {
        struct State {
            value: i32,
        }

        let query =
            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
                FieldFuture::new(async { Ok(FieldValue::NONE) })
            }));

        let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
            "values",
            TypeRef::named_nn(TypeRef::INT),
            |ctx| {
                SubscriptionFieldFuture::new(async move {
                    Ok(asynk_strim::try_stream_fn(|mut yielder| async move {
                        for n in 0..10 {
                            tokio::time::sleep(Duration::from_millis(100)).await;
                            yielder
                                .yield_ok(FieldValue::value(
                                    ctx.data_unchecked::<State>().value + n,
                                ))
                                .await;
                        }

                        Ok(())
                    }))
                })
            },
        ));

        let schema = Schema::build("Query", None, Some(subscription.type_name()))
            .register(query)
            .register(subscription)
            .data(State { value: 123 })
            .finish()
            .unwrap();

        let mut stream = schema.execute_stream("subscription { values }");
        for offset in 0..10 {
            let data = stream.next().await.unwrap().into_result().unwrap().data;
            assert_eq!(data, value!({ "values": offset + 123 }));
        }
    }
}