1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
use std::ops::Deref;

use futures::{Future, Stream, StreamExt};
use jsonrpsee::{
    types::{Params, ResponsePayload},
    RpcModule, SubscriptionCloseResponse, SubscriptionMessage,
};
use serde::Serialize;
use serde_json::json;

use crate::FromContext;

/// Builder for an [`RpcModule`].
///
/// Handlers are added with [`RpcBuilder::query`] and [`RpcBuilder::subscription`]. The builder owns
/// the [`RpcModule`] under construction together with an optional namespace, and uses that
/// namespace to derive the names that handlers are registered under.
///
/// This is generally not constructed by hand; it is intended to be driven by the
/// [`qubit_macros::handler`] macro.
pub struct RpcBuilder<Ctx> {
    /// The [`RpcModule`] that handlers are being registered on.
    module: RpcModule<Ctx>,

    /// Optional namespace; when present it is prepended onto every handler name.
    namespace: Option<&'static str>,
}

impl<Ctx> RpcBuilder<Ctx>
where
    Ctx: Clone + Send + Sync + 'static,
{
    /// Create a builder with the provided namespace.
    pub(crate) fn with_namespace(ctx: Ctx, namespace: Option<&'static str>) -> Self {
        Self {
            namespace,
            module: RpcModule::new(ctx),
        }
    }

    /// Consume the builder to produce the internal [`RpcModule`], ready to be used.
    pub(crate) fn build(self) -> RpcModule<Ctx> {
        self.module
    }

    /// Register a new query handler with the provided name.
    ///
    /// The `handler` can take its own `Ctx`, so long as it implements [`FromContext`]. It must
    /// return a future which outputs a serializable value.
    pub fn query<T, C, F, Fut>(mut self, name: &'static str, handler: F) -> Self
    where
        T: Serialize + Clone + 'static,
        C: FromContext<Ctx>,
        F: Fn(C, Params<'static>) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future<Output = T> + Send + 'static,
    {
        self.module
            .register_async_method(self.namespace_str(name), move |params, ctx, _extensions| {
                // NOTE: Handler has to be cloned in since `register_async_method` takes `Fn`, not
                // `FnOnce`. Not sure if it's better to be an `Rc`/leaked/???
                let handler = handler.clone();

                async move {
                    // Build the context
                    let ctx = match C::from_app_ctx(ctx.deref().clone()).await {
                        Ok(ctx) => ctx,
                        Err(e) => {
                            // Handle any error building the context by turning it into a response
                            // payload.
                            return ResponsePayload::Error(e.into());
                        }
                    };

                    // Run the actual handler
                    ResponsePayload::success(handler(ctx, params).await)
                }
            })
            .unwrap();

        self
    }

    /// Register a new subscription handler with the provided name.
    ///
    /// The `handler` can take its own `Ctx`, so long as it implements [`FromContext`]. It must
    /// return a future that outputs a stream of serializable values.
    pub fn subscription<T, C, F, Fut, S>(
        mut self,
        name: &'static str,
        notification_name: &'static str,
        unsubscribe_name: &'static str,
        handler: F,
    ) -> Self
    where
        T: Serialize + Send + Clone + 'static,
        C: FromContext<Ctx>,
        F: Fn(C, Params<'static>) -> Fut + Send + Sync + Clone + 'static,
        Fut: Future<Output = S> + Send + 'static,
        S: Stream<Item = T> + Send + 'static,
    {
        self.module
            .register_subscription(
                self.namespace_str(name),
                self.namespace_str(notification_name),
                self.namespace_str(unsubscribe_name),
                move |params, subscription, ctx, _extensions| {
                    // NOTE: Same deal here with cloning the handler as in the query registration.
                    let handler = handler.clone();

                    async move {
                        // Accept the subscription
                        let subscription = subscription.accept().await.unwrap();

                        // Set up a channel to avoid cloning the subscription
                        let (tx, mut rx) = tokio::sync::mpsc::channel(10);

                        // Track the number of items emitted through the subscription
                        let mut count = 0;
                        let subscription_id = subscription.subscription_id();

                        // Recieve values on a new thread, sending them onwards to the subscription
                        tokio::spawn(async move {
                            while let Some(value) = rx.recv().await {
                                if subscription.is_closed() {
                                    // Don't continue processing items once the web socket is
                                    // closed
                                    break;
                                }

                                subscription
                                    .send(SubscriptionMessage::from_json(&value).unwrap())
                                    .await
                                    .unwrap();
                            }
                        });

                        // Build the context
                        // NOTE: It won't be held across await so that `C` doesn't have to be
                        // `Send`
                        let ctx = match C::from_app_ctx(ctx.deref().clone()).await {
                            Ok(ctx) => ctx,
                            Err(e) => {
                                // Handle any error building the context by turning it into a
                                // subscriptions close response
                                return SubscriptionCloseResponse::NotifErr(
                                    SubscriptionMessage::from_json(&e).unwrap(),
                                );
                            }
                        };

                        // Run the handler, capturing each of the values sand forwarding it onwards
                        // to the channel
                        let mut stream = Box::pin(handler(ctx, params).await);

                        while let Some(value) = stream.next().await {
                            if tx.send(value).await.is_ok() {
                                count += 1;
                            } else {
                                break;
                            }
                        }

                        // Notify that stream is closing
                        SubscriptionCloseResponse::Notif(
                            SubscriptionMessage::from_json(
                                &json!({ "close_stream": subscription_id, "count": count }),
                            )
                            .unwrap(),
                        )
                    }
                },
            )
            .unwrap();

        self
    }

    /// Helper to 'resolve' some string with the namespace of this module (if it's present)
    fn namespace_str(&self, s: &'static str) -> &'static str {
        if let Some(namespace) = self.namespace {
            Box::leak(format!("{namespace}.{s}").into_boxed_str())
        } else {
            s
        }
    }
}