Skip to main content

forge_dioxus/
client.rs

1
2use std::cell::{Cell, RefCell};
3use std::rc::Rc;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use dioxus::prelude::{Signal, WritableExt, dioxus_core::Task};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9
10use crate::types::{
11    ConnectionState, ForgeClientError, ForgeError, RpcEnvelopeRaw, SseEnvelopeRaw, StreamEvent,
12};
13
14type TokenProvider = Rc<dyn Fn() -> Option<String>>;
15type AuthErrorHandler = Rc<dyn Fn(ForgeError)>;
16
17static NEXT_SUBSCRIPTION_ID: AtomicU64 = AtomicU64::new(1);
18
19#[derive(Clone)]
20pub struct ForgeClientConfig {
21    pub url: String,
22    pub get_token: Option<TokenProvider>,
23    pub on_auth_error: Option<AuthErrorHandler>,
24    pub(crate) connection_state: Option<Signal<ConnectionState>>,
25}
26
27impl ForgeClientConfig {
28    pub fn new(url: impl Into<String>) -> Self {
29        Self {
30            url: url.into(),
31            get_token: None,
32            on_auth_error: None,
33            connection_state: None,
34        }
35    }
36
37    pub fn with_token_provider(mut self, provider: impl Fn() -> Option<String> + 'static) -> Self {
38        self.get_token = Some(Rc::new(provider));
39        self
40    }
41
42    pub fn with_auth_error_handler(
43        mut self,
44        handler: impl Fn(ForgeError) + 'static,
45    ) -> Self {
46        self.on_auth_error = Some(Rc::new(handler));
47        self
48    }
49
50    pub(crate) fn with_connection_state(mut self, state: Signal<ConnectionState>) -> Self {
51        self.connection_state = Some(state);
52        self
53    }
54}
55
56#[derive(Clone)]
57pub struct ForgeClient {
58    inner: Rc<ForgeClientInner>,
59}
60
61struct ForgeClientInner {
62    url: String,
63    get_token: Option<TokenProvider>,
64    on_auth_error: Option<AuthErrorHandler>,
65    connection_state: Option<Signal<ConnectionState>>,
66}
67
68impl ForgeClient {
69    pub fn new(config: ForgeClientConfig) -> Self {
70        Self {
71            inner: Rc::new(ForgeClientInner {
72                url: config.url.trim_end_matches('/').to_string(),
73                get_token: config.get_token,
74                on_auth_error: config.on_auth_error,
75                connection_state: config.connection_state,
76            }),
77        }
78    }
79
80    pub async fn call<TArgs, TResult>(
81        &self,
82        function_name: &str,
83        args: TArgs,
84    ) -> Result<TResult, ForgeClientError>
85    where
86        TArgs: Serialize,
87        TResult: DeserializeOwned,
88    {
89        let body = serde_json::json!({ "args": args });
90        let envelope = platform::request_json(
91            self,
92            &format!("{}/_api/rpc/{}", self.inner.url, function_name),
93            body,
94        )
95        .await?;
96        self.decode_envelope(envelope)
97    }
98
99    #[cfg(target_arch = "wasm32")]
100    pub async fn call_multipart<TResult>(
101        &self,
102        function_name: &str,
103        form: web_sys::FormData,
104    ) -> Result<TResult, ForgeClientError>
105    where
106        TResult: DeserializeOwned,
107    {
108        let envelope = platform::request_multipart(
109            self,
110            &format!("{}/_api/rpc/{}/upload", self.inner.url, function_name),
111            form,
112        )
113        .await?;
114        self.decode_envelope(envelope)
115    }
116
117    #[cfg(not(target_arch = "wasm32"))]
118    pub async fn call_multipart<TResult>(
119        &self,
120        function_name: &str,
121        form: reqwest::multipart::Form,
122    ) -> Result<TResult, ForgeClientError>
123    where
124        TResult: DeserializeOwned,
125    {
126        let envelope = platform::request_multipart(
127            self,
128            &format!("{}/_api/rpc/{}/upload", self.inner.url, function_name),
129            form,
130        )
131        .await?;
132        self.decode_envelope(envelope)
133    }
134
135    pub fn subscribe_query<TArgs, TResult, F>(
136        &self,
137        function_name: &str,
138        args: TArgs,
139        callback: F,
140    ) -> SubscriptionHandle
141    where
142        TArgs: Serialize + Clone + 'static,
143        TResult: DeserializeOwned + Clone + 'static,
144        F: FnMut(StreamEvent<TResult>) + 'static,
145    {
146        platform::subscribe_query(self.clone(), function_name.to_string(), args, callback)
147    }
148
149    pub fn subscribe_job<TResult, F>(&self, job_id: String, callback: F) -> SubscriptionHandle
150    where
151        TResult: DeserializeOwned + Clone + 'static,
152        F: FnMut(StreamEvent<TResult>) + 'static,
153    {
154        self.subscribe_tracker(
155            "job",
156            serde_json::json!({ "job_id": job_id }),
157            "/_api/subscribe-job",
158            callback,
159        )
160    }
161
162    pub fn subscribe_workflow<TResult, F>(
163        &self,
164        workflow_id: String,
165        callback: F,
166    ) -> SubscriptionHandle
167    where
168        TResult: DeserializeOwned + Clone + 'static,
169        F: FnMut(StreamEvent<TResult>) + 'static,
170    {
171        self.subscribe_tracker(
172            "wf",
173            serde_json::json!({ "workflow_id": workflow_id }),
174            "/_api/subscribe-workflow",
175            callback,
176        )
177    }
178
179    fn subscribe_tracker<TResult, F>(
180        &self,
181        prefix: &str,
182        payload: serde_json::Value,
183        endpoint: &str,
184        callback: F,
185    ) -> SubscriptionHandle
186    where
187        TResult: DeserializeOwned + Clone + 'static,
188        F: FnMut(StreamEvent<TResult>) + 'static,
189    {
190        platform::subscribe_tracker(
191            self.clone(),
192            prefix.to_string(),
193            payload,
194            endpoint.to_string(),
195            callback,
196        )
197    }
198
199    fn get_token(&self) -> Option<String> {
200        self.inner
201            .get_token
202            .as_ref()
203            .and_then(|provider| provider())
204            .filter(|t| !t.is_empty())
205    }
206
207    fn emit_connection<TValue, T>(&self, callback: &Rc<RefCell<T>>, state: ConnectionState)
208    where
209        T: FnMut(StreamEvent<TValue>),
210    {
211        if let Some(mut signal) = self.inner.connection_state {
212            signal.set(state);
213        }
214        (callback.borrow_mut())(StreamEvent::Connection(state));
215    }
216
217    fn emit_error<TValue, T>(&self, callback: &Rc<RefCell<T>>, error: ForgeClientError)
218    where
219        T: FnMut(StreamEvent<TValue>),
220    {
221        if error.code == "UNAUTHORIZED" {
222            if let Some(handler) = &self.inner.on_auth_error {
223                handler(error.as_forge_error());
224            }
225        }
226        (callback.borrow_mut())(StreamEvent::Error(error));
227    }
228
229    fn decode_envelope<TResult>(
230        &self,
231        envelope: RpcEnvelopeRaw,
232    ) -> Result<TResult, ForgeClientError>
233    where
234        TResult: DeserializeOwned,
235    {
236        if !envelope.success {
237            let error = envelope.error.unwrap_or(ForgeError {
238                code: "UNKNOWN".to_string(),
239                message: "Unknown error".to_string(),
240                details: None,
241            });
242            return Err(ForgeClientError::new(error.code, error.message, error.details));
243        }
244
245        let data = envelope.data.ok_or_else(|| {
246            ForgeClientError::new("EMPTY_RESPONSE", "Server returned no data", None)
247        })?;
248        serde_json::from_value(data)
249            .map_err(|err| ForgeClientError::new("DESERIALIZATION_ERROR", err.to_string(), None))
250    }
251
252    fn random_id(&self, prefix: &str) -> String {
253        let id = NEXT_SUBSCRIPTION_ID.fetch_add(1, Ordering::Relaxed);
254        format!("{prefix}-{id}")
255    }
256}
257
258#[derive(Clone)]
259pub struct SubscriptionHandle {
260    closed: Rc<Cell<bool>>,
261    task: Rc<RefCell<Option<Task>>>,
262}
263
264impl SubscriptionHandle {
265    fn new() -> Self {
266        Self {
267            closed: Rc::new(Cell::new(false)),
268            task: Rc::new(RefCell::new(None)),
269        }
270    }
271
272    fn set_task(&self, task: Task) {
273        *self.task.borrow_mut() = Some(task);
274    }
275
276    fn finish(&self) {
277        self.closed.set(true);
278        self.task.borrow_mut().take();
279    }
280
281    pub fn close(&self) {
282        self.closed.set(true);
283        if let Some(task) = self.task.borrow_mut().take() {
284            task.cancel();
285        }
286    }
287
288    pub fn is_closed(&self) -> bool {
289        self.closed.get()
290    }
291}
292
293impl Drop for SubscriptionHandle {
294    fn drop(&mut self) {
295        self.close();
296    }
297}
298
299fn parse_json_str<T>(raw: &str) -> Result<T, ForgeClientError>
300where
301    T: DeserializeOwned,
302{
303    serde_json::from_str(raw)
304        .map_err(|err| ForgeClientError::new("INVALID_SSE_PAYLOAD", err.to_string(), None))
305}
306
307fn emit_sse_error<TValue, T>(
308    client: &ForgeClient,
309    callback: &Rc<RefCell<T>>,
310    envelope: SseEnvelopeRaw,
311) where
312    T: FnMut(StreamEvent<TValue>),
313{
314    client.emit_error(
315        callback,
316        ForgeClientError::new(
317            envelope.code.unwrap_or_else(|| "SSE_ERROR".to_string()),
318            envelope
319                .message
320                .unwrap_or_else(|| "Subscription error".to_string()),
321            None,
322        ),
323    );
324}
325
#[cfg(target_arch = "wasm32")]
mod platform {
    //! Browser transport: HTTP through `gloo_net::http`, server-sent events through
    //! `gloo_net::eventsource`.

    use std::cell::RefCell;
    use std::rc::Rc;

    use dioxus::prelude::spawn;
    use futures_util::{StreamExt, stream};
    use gloo_net::eventsource::futures::{EventSource, EventSourceSubscription};
    use gloo_net::http::Request;
    use js_sys::encode_uri_component;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use super::{ForgeClient, SubscriptionHandle, emit_sse_error, parse_json_str};
    use crate::types::{
        ConnectedEvent, ConnectionState, ForgeClientError, RpcEnvelopeRaw, SseEnvelopeRaw,
        StreamEvent,
    };

    /// POSTs `body` as JSON to `url`, adding a bearer token when the client has one, and
    /// parses the response body as an RPC envelope.
    ///
    /// The HTTP status is not inspected. Transport and parse failures become `REQUEST_FAILED`.
    pub(super) async fn request_json(
        client: &ForgeClient,
        url: &str,
        body: serde_json::Value,
    ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
        let mut request = Request::post(url).header("Content-Type", "application/json");
        if let Some(token) = client.get_token() {
            request = request.header("Authorization", &format!("Bearer {token}"));
        }

        let request = request.body(body.to_string()).map_err(request_error)?;
        request
            .send()
            .await
            .map_err(request_error)?
            .json()
            .await
            .map_err(request_error)
    }

    /// POSTs `form` as multipart data to `url`, adding a bearer token when available.
    ///
    /// No `Content-Type` header is set here; presumably the browser supplies it along with
    /// the multipart boundary. The response is handled like `request_json`.
    pub(super) async fn request_multipart(
        client: &ForgeClient,
        url: &str,
        form: web_sys::FormData,
    ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
        let mut request = Request::post(url);
        if let Some(token) = client.get_token() {
            request = request.header("Authorization", &format!("Bearer {token}"));
        }

        let response = request.body(form).map_err(request_error)?;
        response
            .send()
            .await
            .map_err(request_error)?
            .json()
            .await
            .map_err(request_error)
    }

    /// An open event source plus its `update` and `error` event streams.
    struct SseConnection {
        event_source: EventSource,
        update_stream: EventSourceSubscription,
        error_stream: EventSourceSubscription,
    }

    /// Opens the events stream and waits for the server's `connected` handshake.
    ///
    /// On any failure this reports the error, emits `Disconnected`, finishes `handle_task`
    /// and returns `None`. It also returns `None` (after closing the source) if the handle
    /// was closed while waiting.
    async fn open_sse_connection<TValue, F>(
        client: &ForgeClient,
        callback: &Rc<RefCell<F>>,
        handle_task: &SubscriptionHandle,
    ) -> Option<(SseConnection, ConnectedEvent)>
    where
        F: FnMut(StreamEvent<TValue>),
    {
        let mut event_source = match EventSource::new(&events_url(client)) {
            Ok(source) => source,
            Err(err) => {
                client.emit_error(
                    callback,
                    ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
                );
                client.emit_connection(callback, ConnectionState::Disconnected);
                handle_task.finish();
                return None;
            }
        };

        // Subscribes to one event type, or tears the subscription down and returns `None`.
        macro_rules! subscribe_or_bail {
            ($event_type:expr) => {
                match event_source.subscribe($event_type) {
                    Ok(stream) => stream,
                    Err(err) => {
                        client.emit_error(
                            callback,
                            ForgeClientError::new(
                                "SSE_SUBSCRIBE_FAILED",
                                err.to_string(),
                                None,
                            ),
                        );
                        client.emit_connection(callback, ConnectionState::Disconnected);
                        handle_task.finish();
                        return None;
                    }
                }
            };
        }

        // All three streams exist before awaiting `connected`, so update/error events sent
        // right after the handshake are (presumably) queued on their streams, not lost.
        let mut connected_stream = subscribe_or_bail!("connected");
        let update_stream = subscribe_or_bail!("update");
        let error_stream = subscribe_or_bail!("error");

        let connected_event = match connected_stream.next().await {
            Some(Ok((_kind, message))) => {
                let Some(raw) = message.data().as_string() else {
                    client.emit_error(
                        callback,
                        ForgeClientError::new(
                            "INVALID_SSE_PAYLOAD",
                            "SSE payload was not a string",
                            None,
                        ),
                    );
                    client.emit_connection(callback, ConnectionState::Disconnected);
                    handle_task.finish();
                    return None;
                };
                match parse_json_str::<ConnectedEvent>(&raw) {
                    Ok(event) => event,
                    Err(err) => {
                        client.emit_error(callback, err);
                        client.emit_connection(callback, ConnectionState::Disconnected);
                        handle_task.finish();
                        return None;
                    }
                }
            }
            Some(Err(err)) => {
                client.emit_error(
                    callback,
                    ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
                );
                client.emit_connection(callback, ConnectionState::Disconnected);
                handle_task.finish();
                return None;
            }
            None => {
                client.emit_connection(callback, ConnectionState::Disconnected);
                handle_task.finish();
                return None;
            }
        };

        if handle_task.is_closed() {
            event_source.close();
            handle_task.finish();
            return None;
        }

        Some((SseConnection { event_source, update_stream, error_stream }, connected_event))
    }

    /// Pumps `update` and `error` events into `callback`.
    ///
    /// Runs until the handle is closed, the merged stream ends, or a connection error
    /// occurs. Malformed payloads are reported and skipped.
    async fn process_sse_events<TResult, F>(
        update_stream: EventSourceSubscription,
        error_stream: EventSourceSubscription,
        client: &ForgeClient,
        callback: &Rc<RefCell<F>>,
        handle_task: &SubscriptionHandle,
    ) where
        TResult: DeserializeOwned + 'static,
        F: FnMut(StreamEvent<TResult>),
    {
        let mut events = stream::select(update_stream, error_stream);
        while !handle_task.is_closed() {
            let Some(event) = events.next().await else {
                break;
            };

            match event {
                Ok((kind, message)) if kind == "update" => {
                    let Some(raw) = message.data().as_string() else {
                        client.emit_error(
                            callback,
                            ForgeClientError::new(
                                "INVALID_SSE_PAYLOAD",
                                "SSE payload was not a string",
                                None,
                            ),
                        );
                        continue;
                    };
                    let envelope = match parse_json_str::<SseEnvelopeRaw>(&raw) {
                        Ok(value) => value,
                        Err(err) => {
                            client.emit_error(callback, err);
                            continue;
                        }
                    };
                    if let Some(data) = envelope.payload {
                        let parsed = match serde_json::from_value::<TResult>(data) {
                            Ok(value) => value,
                            Err(err) => {
                                client.emit_error(
                                    callback,
                                    ForgeClientError::new(
                                        "INVALID_SSE_PAYLOAD",
                                        err.to_string(),
                                        None,
                                    ),
                                );
                                continue;
                            }
                        };
                        (callback.borrow_mut())(StreamEvent::Data(parsed));
                    }
                }
                // Every other item comes from the `error` subscription.
                Ok((_kind, message)) => {
                    let Some(raw) = message.data().as_string() else {
                        client.emit_error(
                            callback,
                            ForgeClientError::new(
                                "INVALID_SSE_PAYLOAD",
                                "SSE payload was not a string",
                                None,
                            ),
                        );
                        continue;
                    };
                    let envelope = match parse_json_str::<SseEnvelopeRaw>(&raw) {
                        Ok(value) => value,
                        Err(err) => {
                            client.emit_error(callback, err);
                            continue;
                        }
                    };
                    emit_sse_error(client, callback, envelope);
                }
                Err(err) => {
                    client.emit_error(
                        callback,
                        ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
                    );
                    break;
                }
            }
        }
    }

    /// Opens an SSE session and registers the query `function_name(args)` on it.
    ///
    /// Registration goes to `/_api/subscribe`; the initial data is delivered from its response.
    pub(super) fn subscribe_query<TArgs, TResult, F>(
        client: ForgeClient,
        function_name: String,
        args: TArgs,
        callback: F,
    ) -> SubscriptionHandle
    where
        TArgs: Serialize + Clone + 'static,
        TResult: DeserializeOwned + Clone + 'static,
        F: FnMut(StreamEvent<TResult>) + 'static,
    {
        let handle = SubscriptionHandle::new();
        let handle_task = handle.clone();
        let callback = Rc::new(RefCell::new(callback));

        let task = spawn(async move {
            client.emit_connection(&callback, ConnectionState::Connecting);

            let args_value = match serde_json::to_value(args) {
                Ok(value) => value,
                Err(err) => {
                    client.emit_error(
                        &callback,
                        ForgeClientError::new("SERIALIZATION_ERROR", err.to_string(), None),
                    );
                    client.emit_connection(&callback, ConnectionState::Disconnected);
                    handle_task.finish();
                    return;
                }
            };

            let Some((sse, connected)) =
                open_sse_connection(&client, &callback, &handle_task).await
            else {
                return;
            };

            let register_payload = serde_json::json!({
                "session_id": connected.session_id,
                "session_secret": connected.session_secret,
                "id": client.random_id("sub"),
                "function": function_name,
                "args": args_value,
            });

            match request_json(
                &client,
                &format!("{}/_api/subscribe", client.inner.url),
                register_payload,
            )
            .await
            {
                Ok(envelope) => match client.decode_envelope::<TResult>(envelope) {
                    Ok(data) => {
                        client.emit_connection(&callback, ConnectionState::Connected);
                        (callback.borrow_mut())(StreamEvent::Data(data));
                    }
                    Err(err) => {
                        // Close explicitly instead of relying on drop.
                        sse.event_source.close();
                        client.emit_error(&callback, err);
                        client.emit_connection(&callback, ConnectionState::Disconnected);
                        handle_task.finish();
                        return;
                    }
                },
                Err(err) => {
                    sse.event_source.close();
                    client.emit_error(&callback, err);
                    client.emit_connection(&callback, ConnectionState::Disconnected);
                    handle_task.finish();
                    return;
                }
            }

            process_sse_events::<TResult, _>(
                sse.update_stream,
                sse.error_stream,
                &client,
                &callback,
                &handle_task,
            )
            .await;

            sse.event_source.close();
            client.emit_connection(&callback, ConnectionState::Disconnected);
            handle_task.finish();
        });

        handle.set_task(task);
        handle
    }

    /// Opens an SSE session and registers a job/workflow tracker on it.
    ///
    /// `payload` must be a JSON object. Session credentials and a generated `id` are added
    /// to it before it is POSTed to `endpoint`.
    pub(super) fn subscribe_tracker<TResult, F>(
        client: ForgeClient,
        prefix: String,
        payload: serde_json::Value,
        endpoint: String,
        callback: F,
    ) -> SubscriptionHandle
    where
        TResult: DeserializeOwned + Clone + 'static,
        F: FnMut(StreamEvent<TResult>) + 'static,
    {
        let handle = SubscriptionHandle::new();
        let handle_task = handle.clone();
        let callback = Rc::new(RefCell::new(callback));

        let task = spawn(async move {
            client.emit_connection(&callback, ConnectionState::Connecting);

            let Some((sse, connected)) =
                open_sse_connection(&client, &callback, &handle_task).await
            else {
                return;
            };

            let mut register_payload = payload;
            let register_object = register_payload
                .as_object_mut()
                .expect("tracker payload must be an object");
            register_object.insert(
                "session_id".to_string(),
                serde_json::Value::String(connected.session_id.unwrap_or_default()),
            );
            register_object.insert(
                "session_secret".to_string(),
                serde_json::Value::String(connected.session_secret.unwrap_or_default()),
            );
            register_object.insert(
                "id".to_string(),
                serde_json::Value::String(client.random_id(&prefix)),
            );

            let envelope = match request_json(
                &client,
                &format!("{}{}", client.inner.url, endpoint),
                register_payload,
            )
            .await
            {
                Ok(envelope) => envelope,
                Err(err) => {
                    sse.event_source.close();
                    client.emit_error(&callback, err);
                    client.emit_connection(&callback, ConnectionState::Disconnected);
                    handle_task.finish();
                    return;
                }
            };

            // Report a rejected registration instead of claiming `Connected`. Otherwise the
            // caller would presumably wait on events that never arrive.
            if !envelope.success {
                let error = match envelope.error {
                    Some(err) => ForgeClientError::new(err.code, err.message, err.details),
                    None => ForgeClientError::new("UNKNOWN", "Unknown error", None),
                };
                sse.event_source.close();
                client.emit_error(&callback, error);
                client.emit_connection(&callback, ConnectionState::Disconnected);
                handle_task.finish();
                return;
            }

            client.emit_connection(&callback, ConnectionState::Connected);
            // An initial snapshot is optional. A snapshot of the wrong shape is reported,
            // but it does not end the subscription.
            if let Some(data) = envelope.data {
                match serde_json::from_value::<TResult>(data) {
                    Ok(parsed) => {
                        (callback.borrow_mut())(StreamEvent::Data(parsed));
                    }
                    Err(err) => {
                        client.emit_error(
                            &callback,
                            ForgeClientError::new("DESERIALIZATION_ERROR", err.to_string(), None),
                        );
                    }
                }
            }

            process_sse_events::<TResult, _>(
                sse.update_stream,
                sse.error_stream,
                &client,
                &callback,
                &handle_task,
            )
            .await;

            sse.event_source.close();
            client.emit_connection(&callback, ConnectionState::Disconnected);
            handle_task.finish();
        });

        handle.set_task(task);
        handle
    }

    /// Builds the events URL.
    ///
    /// The browser `EventSource` cannot send custom headers, so the token travels as a
    /// URL-encoded `token` query parameter.
    /// NOTE(review): query-string tokens may end up in server or proxy logs.
    fn events_url(client: &ForgeClient) -> String {
        match client.get_token() {
            Some(token) => format!(
                "{}/_api/events?token={}",
                client.inner.url,
                encode_uri_component(&token)
            ),
            None => format!("{}/_api/events", client.inner.url),
        }
    }

    /// Maps a `gloo_net` error to `REQUEST_FAILED`.
    fn request_error(err: gloo_net::Error) -> ForgeClientError {
        ForgeClientError::new("REQUEST_FAILED", err.to_string(), None)
    }
}
762
763#[cfg(not(target_arch = "wasm32"))]
764mod platform {
765    use std::cell::RefCell;
766    use std::rc::Rc;
767
768    use dioxus::prelude::spawn;
769    use futures_util::StreamExt;
770    use reqwest::Client;
771    use reqwest_eventsource::{Event, EventSource};
772    use serde::Serialize;
773    use serde::de::DeserializeOwned;
774
775    use super::{ForgeClient, SubscriptionHandle, emit_sse_error, parse_json_str};
776    use crate::types::{
777        ConnectedEvent, ConnectionState, ForgeClientError, RpcEnvelopeRaw, SseEnvelopeRaw,
778        StreamEvent,
779    };
780
781    pub(super) async fn request_json(
782        client: &ForgeClient,
783        url: &str,
784        body: serde_json::Value,
785    ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
786        let mut request = Client::new().post(url).json(&body);
787        if let Some(token) = client.get_token() {
788            request = request.bearer_auth(token);
789        }
790
791        request
792            .send()
793            .await
794            .map_err(request_error)?
795            .json()
796            .await
797            .map_err(request_error)
798    }
799
800    pub(super) async fn request_multipart(
801        client: &ForgeClient,
802        url: &str,
803        form: reqwest::multipart::Form,
804    ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
805        let mut request = Client::new().post(url).multipart(form);
806        if let Some(token) = client.get_token() {
807            request = request.bearer_auth(token);
808        }
809
810        request
811            .send()
812            .await
813            .map_err(request_error)?
814            .json()
815            .await
816            .map_err(request_error)
817    }
818
819    async fn process_sse_events<TResult, F>(
820        event_source: &mut EventSource,
821        client: &ForgeClient,
822        callback: &Rc<RefCell<F>>,
823        handle_task: &SubscriptionHandle,
824    ) where
825        TResult: DeserializeOwned + 'static,
826        F: FnMut(StreamEvent<TResult>),
827    {
828        while !handle_task.is_closed() {
829            let Some(event) = event_source.next().await else {
830                break;
831            };
832
833            match event {
834                Ok(Event::Open) => {}
835                Ok(Event::Message(message)) if message.event == "update" => {
836                    let envelope = match parse_json_str::<SseEnvelopeRaw>(&message.data) {
837                        Ok(value) => value,
838                        Err(err) => {
839                            client.emit_error(callback, err);
840                            continue;
841                        }
842                    };
843                    if let Some(data) = envelope.payload {
844                        let parsed = match serde_json::from_value::<TResult>(data) {
845                            Ok(value) => value,
846                            Err(err) => {
847                                client.emit_error(
848                                    callback,
849                                    ForgeClientError::new(
850                                        "INVALID_SSE_PAYLOAD",
851                                        err.to_string(),
852                                        None,
853                                    ),
854                                );
855                                continue;
856                            }
857                        };
858                        (callback.borrow_mut())(StreamEvent::Data(parsed));
859                    }
860                }
861                Ok(Event::Message(message)) if message.event == "error" => {
862                    let envelope = match parse_json_str::<SseEnvelopeRaw>(&message.data) {
863                        Ok(value) => value,
864                        Err(err) => {
865                            client.emit_error(callback, err);
866                            continue;
867                        }
868                    };
869                    emit_sse_error(client, callback, envelope);
870                }
871                Ok(Event::Message(_)) => {}
872                Err(err) => {
873                    client.emit_error(
874                        callback,
875                        ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
876                    );
877                    break;
878                }
879            }
880        }
881    }
882
883    async fn open_and_connect<TValue, F>(
884        client: &ForgeClient,
885        callback: &Rc<RefCell<F>>,
886        handle_task: &SubscriptionHandle,
887    ) -> Option<(EventSource, ConnectedEvent)>
888    where
889        F: FnMut(StreamEvent<TValue>),
890    {
891        let mut event_source = match open_event_source(client) {
892            Ok(source) => source,
893            Err(err) => {
894                client.emit_error(callback, err);
895                client.emit_connection(callback, ConnectionState::Disconnected);
896                handle_task.finish();
897                return None;
898            }
899        };
900
901        let connected_event =
902            match next_connected_event(&mut event_source, client, callback).await {
903                Ok(Some(event)) => event,
904                Ok(None) => {
905                    client.emit_connection(callback, ConnectionState::Disconnected);
906                    handle_task.finish();
907                    return None;
908                }
909                Err(err) => {
910                    client.emit_error(callback, err);
911                    client.emit_connection(callback, ConnectionState::Disconnected);
912                    handle_task.finish();
913                    return None;
914                }
915            };
916
917        if handle_task.is_closed() {
918            event_source.close();
919            handle_task.finish();
920            return None;
921        }
922
923        Some((event_source, connected_event))
924    }
925
926    pub(super) fn subscribe_query<TArgs, TResult, F>(
927        client: ForgeClient,
928        function_name: String,
929        args: TArgs,
930        callback: F,
931    ) -> SubscriptionHandle
932    where
933        TArgs: Serialize + Clone + 'static,
934        TResult: DeserializeOwned + Clone + 'static,
935        F: FnMut(StreamEvent<TResult>) + 'static,
936    {
937        let handle = SubscriptionHandle::new();
938        let handle_task = handle.clone();
939        let callback = Rc::new(RefCell::new(callback));
940
941        let task = spawn(async move {
942            client.emit_connection(&callback, ConnectionState::Connecting);
943
944            let args_value = match serde_json::to_value(args) {
945                Ok(value) => value,
946                Err(err) => {
947                    client.emit_error(
948                        &callback,
949                        ForgeClientError::new("SERIALIZATION_ERROR", err.to_string(), None),
950                    );
951                    client.emit_connection(&callback, ConnectionState::Disconnected);
952                    handle_task.finish();
953                    return;
954                }
955            };
956
957            let Some((mut event_source, connected)) =
958                open_and_connect(&client, &callback, &handle_task).await
959            else {
960                return;
961            };
962
963            let register_payload = serde_json::json!({
964                "session_id": connected.session_id,
965                "session_secret": connected.session_secret,
966                "id": client.random_id("sub"),
967                "function": function_name,
968                "args": args_value,
969            });
970
971            match request_json(
972                &client,
973                &format!("{}/_api/subscribe", client.inner.url),
974                register_payload,
975            )
976            .await
977            {
978                Ok(envelope) => match client.decode_envelope::<TResult>(envelope) {
979                    Ok(data) => {
980                        client.emit_connection(&callback, ConnectionState::Connected);
981                        (callback.borrow_mut())(StreamEvent::Data(data));
982                    }
983                    Err(err) => {
984                        client.emit_error(&callback, err);
985                        client.emit_connection(&callback, ConnectionState::Disconnected);
986                        handle_task.finish();
987                        return;
988                    }
989                },
990                Err(err) => {
991                    client.emit_error(&callback, err);
992                    client.emit_connection(&callback, ConnectionState::Disconnected);
993                    handle_task.finish();
994                    return;
995                }
996            }
997
998            process_sse_events::<TResult, _>(
999                &mut event_source,
1000                &client,
1001                &callback,
1002                &handle_task,
1003            )
1004            .await;
1005
1006            event_source.close();
1007            client.emit_connection(&callback, ConnectionState::Disconnected);
1008            handle_task.finish();
1009        });
1010
1011        handle.set_task(task);
1012        handle
1013    }
1014
1015    pub(super) fn subscribe_tracker<TResult, F>(
1016        client: ForgeClient,
1017        prefix: String,
1018        payload: serde_json::Value,
1019        endpoint: String,
1020        callback: F,
1021    ) -> SubscriptionHandle
1022    where
1023        TResult: DeserializeOwned + Clone + 'static,
1024        F: FnMut(StreamEvent<TResult>) + 'static,
1025    {
1026        let handle = SubscriptionHandle::new();
1027        let handle_task = handle.clone();
1028        let callback = Rc::new(RefCell::new(callback));
1029
1030        let task = spawn(async move {
1031            client.emit_connection(&callback, ConnectionState::Connecting);
1032
1033            let Some((mut event_source, connected)) =
1034                open_and_connect(&client, &callback, &handle_task).await
1035            else {
1036                return;
1037            };
1038
1039            let mut register_payload = payload;
1040            let register_object = register_payload
1041                .as_object_mut()
1042                .expect("tracker payload must be an object");
1043            register_object.insert(
1044                "session_id".to_string(),
1045                serde_json::Value::String(connected.session_id.unwrap_or_default()),
1046            );
1047            register_object.insert(
1048                "session_secret".to_string(),
1049                serde_json::Value::String(connected.session_secret.unwrap_or_default()),
1050            );
1051            register_object.insert(
1052                "id".to_string(),
1053                serde_json::Value::String(client.random_id(&prefix)),
1054            );
1055
1056            match request_json(
1057                &client,
1058                &format!("{}{}", client.inner.url, endpoint),
1059                register_payload,
1060            )
1061            .await
1062            {
1063                Ok(envelope) => {
1064                    client.emit_connection(&callback, ConnectionState::Connected);
1065                    if envelope.success {
1066                        if let Some(data) = envelope.data {
1067                            if let Ok(parsed) = serde_json::from_value::<TResult>(data) {
1068                                (callback.borrow_mut())(StreamEvent::Data(parsed));
1069                            }
1070                        }
1071                    }
1072                }
1073                Err(err) => {
1074                    client.emit_error(&callback, err);
1075                    client.emit_connection(&callback, ConnectionState::Disconnected);
1076                    handle_task.finish();
1077                    return;
1078                }
1079            }
1080
1081            process_sse_events::<TResult, _>(
1082                &mut event_source,
1083                &client,
1084                &callback,
1085                &handle_task,
1086            )
1087            .await;
1088
1089            event_source.close();
1090            client.emit_connection(&callback, ConnectionState::Disconnected);
1091            handle_task.finish();
1092        });
1093
1094        handle.set_task(task);
1095        handle
1096    }
1097
1098    fn open_event_source(client: &ForgeClient) -> Result<EventSource, ForgeClientError> {
1099        let mut request = Client::new().get(format!("{}/_api/events", client.inner.url));
1100        if let Some(token) = client.get_token() {
1101            request = request.bearer_auth(token);
1102        }
1103
1104        EventSource::new(request)
1105            .map_err(|err| ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None))
1106    }
1107
1108    async fn next_connected_event<TValue, T>(
1109        event_source: &mut EventSource,
1110        client: &ForgeClient,
1111        callback: &Rc<RefCell<T>>,
1112    ) -> Result<Option<ConnectedEvent>, ForgeClientError>
1113    where
1114        T: FnMut(StreamEvent<TValue>),
1115    {
1116        while let Some(event) = event_source.next().await {
1117            match event {
1118                Ok(Event::Open) => continue,
1119                Ok(Event::Message(message)) if message.event == "connected" => {
1120                    return parse_json_str::<ConnectedEvent>(&message.data).map(Some);
1121                }
1122                Ok(Event::Message(message)) if message.event == "error" => {
1123                    let envelope = parse_json_str::<SseEnvelopeRaw>(&message.data)?;
1124                    emit_sse_error(client, callback, envelope);
1125                }
1126                Ok(Event::Message(_)) => {}
1127                Err(err) => {
1128                    return Err(ForgeClientError::new(
1129                        "SSE_CONNECTION_FAILED",
1130                        err.to_string(),
1131                        None,
1132                    ));
1133                }
1134            }
1135        }
1136
1137        Ok(None)
1138    }
1139
1140    fn request_error(err: reqwest::Error) -> ForgeClientError {
1141        ForgeClientError::new("REQUEST_FAILED", err.to_string(), None)
1142    }
1143}