Skip to main content

async_graphql/dynamic/
subscription.rs

1use std::{borrow::Cow, fmt, fmt::Debug, sync::Arc};
2
3use futures_util::{
4    Future, FutureExt, Stream, StreamExt, TryStreamExt, future::BoxFuture, stream::BoxStream,
5};
6use indexmap::IndexMap;
7
8use crate::{
9    ContextSelectionSet, Data, QueryPathNode, QueryPathSegment, Response, Result, ServerResult,
10    Value,
11    dynamic::{
12        FieldValue, InputValue, ObjectAccessor, ResolverContext, Schema, SchemaError, TypeRef,
13        resolve::resolve,
14    },
15    extensions::ResolveInfo,
16    parser::types::Selection,
17    registry::{Deprecation, MetaField, MetaType, Registry},
18    subscription::BoxFieldStream,
19};
20
21type BoxResolveFut<'a> = BoxFuture<'a, Result<BoxStream<'a, Result<FieldValue<'a>>>>>;
22
23/// A future that returned from field resolver
24pub struct SubscriptionFieldFuture<'a>(pub(crate) BoxResolveFut<'a>);
25
26impl<'a> SubscriptionFieldFuture<'a> {
27    /// Create a ResolverFuture
28    pub fn new<Fut, S, T>(future: Fut) -> Self
29    where
30        Fut: Future<Output = Result<S>> + Send + 'a,
31        S: Stream<Item = Result<T>> + Send + 'a,
32        T: Into<FieldValue<'a>> + Send + 'a,
33    {
34        Self(
35            async move {
36                let res = future.await?.map_ok(Into::into);
37                Ok(res.boxed())
38            }
39            .boxed(),
40        )
41    }
42}
43
44type BoxResolverFn =
45    Arc<dyn for<'a> Fn(ResolverContext<'a>) -> SubscriptionFieldFuture<'a> + Send + Sync>;
46
47/// A GraphQL subscription field
48pub struct SubscriptionField {
49    pub(crate) name: String,
50    pub(crate) description: Option<String>,
51    pub(crate) arguments: IndexMap<String, InputValue>,
52    pub(crate) ty: TypeRef,
53    pub(crate) resolver_fn: BoxResolverFn,
54    pub(crate) deprecation: Deprecation,
55}
56
57impl SubscriptionField {
58    /// Create a GraphQL subscription field
59    pub fn new<N, T, F>(name: N, ty: T, resolver_fn: F) -> Self
60    where
61        N: Into<String>,
62        T: Into<TypeRef>,
63        F: for<'a> Fn(ResolverContext<'a>) -> SubscriptionFieldFuture<'a> + Send + Sync + 'static,
64    {
65        Self {
66            name: name.into(),
67            description: None,
68            arguments: Default::default(),
69            ty: ty.into(),
70            resolver_fn: Arc::new(resolver_fn),
71            deprecation: Deprecation::NoDeprecated,
72        }
73    }
74
75    impl_set_description!();
76    impl_set_deprecation!();
77
78    /// Add an argument to the subscription field
79    #[inline]
80    pub fn argument(mut self, input_value: InputValue) -> Self {
81        self.arguments.insert(input_value.name.clone(), input_value);
82        self
83    }
84}
85
86impl Debug for SubscriptionField {
87    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
88        f.debug_struct("Field")
89            .field("name", &self.name)
90            .field("description", &self.description)
91            .field("arguments", &self.arguments)
92            .field("ty", &self.ty)
93            .field("deprecation", &self.deprecation)
94            .finish()
95    }
96}
97
98/// A GraphQL subscription type
99#[derive(Debug)]
100pub struct Subscription {
101    pub(crate) name: String,
102    pub(crate) description: Option<String>,
103    pub(crate) fields: IndexMap<String, SubscriptionField>,
104}
105
106impl Subscription {
107    /// Create a GraphQL object type
108    #[inline]
109    pub fn new(name: impl Into<String>) -> Self {
110        Self {
111            name: name.into(),
112            description: None,
113            fields: Default::default(),
114        }
115    }
116
117    impl_set_description!();
118
119    /// Add an field to the object
120    #[inline]
121    pub fn field(mut self, field: SubscriptionField) -> Self {
122        assert!(
123            !self.fields.contains_key(&field.name),
124            "Field `{}` already exists",
125            field.name
126        );
127        self.fields.insert(field.name.clone(), field);
128        self
129    }
130
131    /// Returns the type name
132    #[inline]
133    pub fn type_name(&self) -> &str {
134        &self.name
135    }
136
137    pub(crate) fn register(&self, registry: &mut Registry) -> Result<(), SchemaError> {
138        let mut fields = IndexMap::new();
139
140        for field in self.fields.values() {
141            let mut args = IndexMap::new();
142
143            for argument in field.arguments.values() {
144                args.insert(argument.name.clone(), argument.to_meta_input_value());
145            }
146
147            fields.insert(
148                field.name.clone(),
149                MetaField {
150                    name: field.name.clone(),
151                    description: field.description.clone(),
152                    args,
153                    ty: field.ty.to_string(),
154                    deprecation: field.deprecation.clone(),
155                    cache_control: Default::default(),
156                    external: false,
157                    requires: None,
158                    provides: None,
159                    visible: None,
160                    shareable: false,
161                    inaccessible: false,
162                    tags: vec![],
163                    override_from: None,
164                    compute_complexity: None,
165                    directive_invocations: vec![],
166                    requires_scopes: vec![],
167                },
168            );
169        }
170
171        registry.types.insert(
172            self.name.clone(),
173            MetaType::Object {
174                name: self.name.clone(),
175                description: self.description.clone(),
176                fields,
177                cache_control: Default::default(),
178                extends: false,
179                shareable: false,
180                resolvable: true,
181                keys: None,
182                visible: None,
183                inaccessible: false,
184                interface_object: false,
185                tags: vec![],
186                is_subscription: true,
187                rust_typename: None,
188                directive_invocations: vec![],
189                requires_scopes: vec![],
190            },
191        );
192
193        Ok(())
194    }
195
196    pub(crate) fn collect_streams<'a>(
197        &self,
198        schema: &Schema,
199        ctx: &ContextSelectionSet<'a>,
200        streams: &mut Vec<BoxFieldStream<'a>>,
201        root_value: &'a FieldValue<'static>,
202    ) {
203        for selection in &ctx.item.node.items {
204            if let Selection::Field(field) = &selection.node
205                && let Some(field_def) = self.fields.get(field.node.name.node.as_str())
206            {
207                let schema = schema.clone();
208                let field_type = field_def.ty.clone();
209                let resolver_fn = field_def.resolver_fn.clone();
210                let ctx = ctx.clone();
211
212                streams.push(
213                    asynk_strim::try_stream_fn(move |mut yielder| async move {
214                        let ctx_field = ctx.with_field(field);
215                        let field_name = ctx_field.item.node.response_key().node.clone();
216                        let arguments = ObjectAccessor(Cow::Owned(
217                            field
218                                .node
219                                .arguments
220                                .iter()
221                                .map(|(name, value)| {
222                                    // Drop omitted variable-backed arguments instead of
223                                    // materializing them as `null`.
224                                    ctx_field
225                                        .resolve_input_value(value.clone())
226                                        .map(|value| value.map(|value| (name.node.clone(), value)))
227                                })
228                                .collect::<ServerResult<Vec<_>>>()?
229                                .into_iter()
230                                .flatten()
231                                .collect::<IndexMap<_, _>>(),
232                        ));
233
234                        let mut stream = resolver_fn(ResolverContext {
235                            ctx: &ctx_field,
236                            args: arguments,
237                            parent_value: root_value,
238                        })
239                        .0
240                        .await
241                        .map_err(|err| {
242                            ctx_field.set_error_path(err.into_server_error(ctx_field.item.pos))
243                        })?;
244
245                        while let Some(value) = stream.next().await.transpose().map_err(|err| {
246                            ctx_field.set_error_path(err.into_server_error(ctx_field.item.pos))
247                        })? {
248                            let f = |execute_data: Option<Data>| {
249                                let schema = schema.clone();
250                                let field_name = field_name.clone();
251                                let field_type = field_type.clone();
252                                let ctx_field = ctx_field.clone();
253
254                                async move {
255                                    let mut ctx_field = ctx_field.clone();
256                                    ctx_field.execute_data = execute_data.as_ref();
257                                    let ri = ResolveInfo {
258                                        path_node: &QueryPathNode {
259                                            parent: None,
260                                            segment: QueryPathSegment::Name(&field_name),
261                                        },
262                                        parent_type: schema
263                                            .0
264                                            .env
265                                            .registry
266                                            .subscription_type
267                                            .as_ref()
268                                            .unwrap(),
269                                        return_type: &field_type.to_string(),
270                                        name: field.node.name.node.as_str(),
271                                        alias: field
272                                            .node
273                                            .alias
274                                            .as_ref()
275                                            .map(|alias| alias.node.as_str()),
276                                        is_for_introspection: false,
277                                        field: &field.node,
278                                    };
279                                    let resolve_fut =
280                                        resolve(&schema, &ctx_field, &field_type, Some(&value));
281                                    futures_util::pin_mut!(resolve_fut);
282                                    let value = ctx_field
283                                        .query_env
284                                        .extensions
285                                        .resolve(ri, &mut resolve_fut)
286                                        .await;
287
288                                    match value {
289                                        Ok(value) => {
290                                            let mut map = IndexMap::new();
291                                            map.insert(
292                                                field_name.clone(),
293                                                value.unwrap_or_default(),
294                                            );
295                                            Response::new(Value::Object(map))
296                                        }
297                                        Err(err) => Response::from_errors(vec![err]),
298                                    }
299                                }
300                            };
301                            let resp = ctx_field
302                                .query_env
303                                .extensions
304                                .execute(ctx_field.query_env.operation_name.as_deref(), f)
305                                .await;
306                            let is_err = !resp.errors.is_empty();
307                            yielder.yield_ok(resp).await;
308                            if is_err {
309                                break;
310                            }
311                        }
312
313                        Ok(())
314                    })
315                    .map(|res| res.unwrap_or_else(|err| Response::from_errors(vec![err])))
316                    .boxed(),
317                );
318            }
319        }
320    }
321}
322
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures_util::StreamExt;

    use crate::{Value, dynamic::*, value};

    /// A subscription whose stream yields object values. Each of the ten
    /// stream items must resolve the nested `value` field from the
    /// `MyObjData` carried by that item.
    #[tokio::test]
    async fn subscription() {
        // Backing data for `MyObject`; downcast from the parent value below.
        struct MyObjData {
            value: i32,
        }

        let my_obj = Object::new("MyObject").field(Field::new(
            "value",
            TypeRef::named_nn(TypeRef::INT),
            |ctx| {
                FieldFuture::new(async {
                    Ok(Some(Value::from(
                        ctx.parent_value.try_downcast_ref::<MyObjData>()?.value,
                    )))
                })
            },
        ));

        // A query root is needed to build the schema; its field is not queried here.
        let query = Object::new("Query").field(Field::new(
            "value",
            TypeRef::named_nn(TypeRef::INT),
            |_| FieldFuture::new(async { Ok(FieldValue::none()) }),
        ));

        // Emits `MyObjData { value: 0 }` through `MyObjData { value: 9 }`, 100 ms apart.
        let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
            "obj",
            TypeRef::named_nn(my_obj.type_name()),
            |_| {
                SubscriptionFieldFuture::new(async {
                    Ok(asynk_strim::try_stream_fn(|mut yielder| async move {
                        for i in 0..10 {
                            tokio::time::sleep(Duration::from_millis(100)).await;
                            yielder
                                .yield_ok(FieldValue::owned_any(MyObjData { value: i }))
                                .await;
                        }

                        Ok(())
                    }))
                })
            },
        ));

        let schema = Schema::build(query.type_name(), None, Some(subscription.type_name()))
            .register(my_obj)
            .register(query)
            .register(subscription)
            .finish()
            .unwrap();

        // Every response must carry the i-th value nested under `obj`.
        let mut stream = schema.execute_stream("subscription { obj { value } }");
        for i in 0..10 {
            assert_eq!(
                stream.next().await.unwrap().into_result().unwrap().data,
                value!({
                    "obj": { "value": i }
                })
            );
        }
    }

    /// The subscription stream reads schema data (`State`) through the
    /// resolver context `ctx`, which it captures and uses across awaits.
    #[tokio::test]
    async fn borrow_context() {
        // Schema-level data registered via `.data(...)` below.
        struct State {
            value: i32,
        }

        // A query root is needed to build the schema; its field is not queried here.
        let query =
            Object::new("Query").field(Field::new("value", TypeRef::named(TypeRef::INT), |_| {
                FieldFuture::new(async { Ok(FieldValue::NONE) })
            }));

        // Emits `State.value + i` for i in 0..10, reading `State` from `ctx` each time.
        let subscription = Subscription::new("Subscription").field(SubscriptionField::new(
            "values",
            TypeRef::named_nn(TypeRef::INT),
            |ctx| {
                SubscriptionFieldFuture::new(async move {
                    Ok(asynk_strim::try_stream_fn(|mut yielder| async move {
                        for i in 0..10 {
                            tokio::time::sleep(Duration::from_millis(100)).await;
                            yielder
                                .yield_ok(FieldValue::value(
                                    ctx.data_unchecked::<State>().value + i,
                                ))
                                .await;
                        }

                        Ok(())
                    }))
                })
            },
        ));

        let schema = Schema::build("Query", None, Some(subscription.type_name()))
            .register(query)
            .register(subscription)
            .data(State { value: 123 })
            .finish()
            .unwrap();

        // Expect 123, 124, ..., 132.
        let mut stream = schema.execute_stream("subscription { values }");
        for i in 0..10 {
            assert_eq!(
                stream.next().await.unwrap().into_result().unwrap().data,
                value!({ "values": i + 123 })
            );
        }
    }
}