1
2use std::cell::{Cell, RefCell};
3use std::rc::Rc;
4use std::sync::atomic::{AtomicU64, Ordering};
5
6use dioxus::prelude::{Signal, WritableExt, dioxus_core::Task};
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9
10use crate::types::{
11 ConnectionState, ForgeClientError, ForgeError, RpcEnvelopeRaw, SseEnvelopeRaw, StreamEvent,
12};
13
14type TokenProvider = Rc<dyn Fn() -> Option<String>>;
15type AuthErrorHandler = Rc<dyn Fn(ForgeError)>;
16
17static NEXT_SUBSCRIPTION_ID: AtomicU64 = AtomicU64::new(1);
18
19#[derive(Clone)]
20pub struct ForgeClientConfig {
21 pub url: String,
22 pub get_token: Option<TokenProvider>,
23 pub on_auth_error: Option<AuthErrorHandler>,
24 pub(crate) connection_state: Option<Signal<ConnectionState>>,
25}
26
27impl ForgeClientConfig {
28 pub fn new(url: impl Into<String>) -> Self {
29 Self {
30 url: url.into(),
31 get_token: None,
32 on_auth_error: None,
33 connection_state: None,
34 }
35 }
36
37 pub fn with_token_provider(mut self, provider: impl Fn() -> Option<String> + 'static) -> Self {
38 self.get_token = Some(Rc::new(provider));
39 self
40 }
41
42 pub fn with_auth_error_handler(
43 mut self,
44 handler: impl Fn(ForgeError) + 'static,
45 ) -> Self {
46 self.on_auth_error = Some(Rc::new(handler));
47 self
48 }
49
50 pub(crate) fn with_connection_state(mut self, state: Signal<ConnectionState>) -> Self {
51 self.connection_state = Some(state);
52 self
53 }
54}
55
56#[derive(Clone)]
57pub struct ForgeClient {
58 inner: Rc<ForgeClientInner>,
59}
60
61struct ForgeClientInner {
62 url: String,
63 get_token: Option<TokenProvider>,
64 on_auth_error: Option<AuthErrorHandler>,
65 connection_state: Option<Signal<ConnectionState>>,
66}
67
68impl ForgeClient {
69 pub fn new(config: ForgeClientConfig) -> Self {
70 Self {
71 inner: Rc::new(ForgeClientInner {
72 url: config.url.trim_end_matches('/').to_string(),
73 get_token: config.get_token,
74 on_auth_error: config.on_auth_error,
75 connection_state: config.connection_state,
76 }),
77 }
78 }
79
80 pub async fn call<TArgs, TResult>(
81 &self,
82 function_name: &str,
83 args: TArgs,
84 ) -> Result<TResult, ForgeClientError>
85 where
86 TArgs: Serialize,
87 TResult: DeserializeOwned,
88 {
89 let body = serde_json::json!({ "args": args });
90 let envelope = platform::request_json(
91 self,
92 &format!("{}/_api/rpc/{}", self.inner.url, function_name),
93 body,
94 )
95 .await?;
96 self.decode_envelope(envelope)
97 }
98
99 #[cfg(target_arch = "wasm32")]
100 pub async fn call_multipart<TResult>(
101 &self,
102 function_name: &str,
103 form: web_sys::FormData,
104 ) -> Result<TResult, ForgeClientError>
105 where
106 TResult: DeserializeOwned,
107 {
108 let envelope = platform::request_multipart(
109 self,
110 &format!("{}/_api/rpc/{}/upload", self.inner.url, function_name),
111 form,
112 )
113 .await?;
114 self.decode_envelope(envelope)
115 }
116
117 #[cfg(not(target_arch = "wasm32"))]
118 pub async fn call_multipart<TResult>(
119 &self,
120 function_name: &str,
121 form: reqwest::multipart::Form,
122 ) -> Result<TResult, ForgeClientError>
123 where
124 TResult: DeserializeOwned,
125 {
126 let envelope = platform::request_multipart(
127 self,
128 &format!("{}/_api/rpc/{}/upload", self.inner.url, function_name),
129 form,
130 )
131 .await?;
132 self.decode_envelope(envelope)
133 }
134
135 pub fn subscribe_query<TArgs, TResult, F>(
136 &self,
137 function_name: &str,
138 args: TArgs,
139 callback: F,
140 ) -> SubscriptionHandle
141 where
142 TArgs: Serialize + Clone + 'static,
143 TResult: DeserializeOwned + Clone + 'static,
144 F: FnMut(StreamEvent<TResult>) + 'static,
145 {
146 platform::subscribe_query(self.clone(), function_name.to_string(), args, callback)
147 }
148
149 pub fn subscribe_job<TResult, F>(&self, job_id: String, callback: F) -> SubscriptionHandle
150 where
151 TResult: DeserializeOwned + Clone + 'static,
152 F: FnMut(StreamEvent<TResult>) + 'static,
153 {
154 self.subscribe_tracker(
155 "job",
156 serde_json::json!({ "job_id": job_id }),
157 "/_api/subscribe-job",
158 callback,
159 )
160 }
161
162 pub fn subscribe_workflow<TResult, F>(
163 &self,
164 workflow_id: String,
165 callback: F,
166 ) -> SubscriptionHandle
167 where
168 TResult: DeserializeOwned + Clone + 'static,
169 F: FnMut(StreamEvent<TResult>) + 'static,
170 {
171 self.subscribe_tracker(
172 "wf",
173 serde_json::json!({ "workflow_id": workflow_id }),
174 "/_api/subscribe-workflow",
175 callback,
176 )
177 }
178
179 fn subscribe_tracker<TResult, F>(
180 &self,
181 prefix: &str,
182 payload: serde_json::Value,
183 endpoint: &str,
184 callback: F,
185 ) -> SubscriptionHandle
186 where
187 TResult: DeserializeOwned + Clone + 'static,
188 F: FnMut(StreamEvent<TResult>) + 'static,
189 {
190 platform::subscribe_tracker(
191 self.clone(),
192 prefix.to_string(),
193 payload,
194 endpoint.to_string(),
195 callback,
196 )
197 }
198
199 fn get_token(&self) -> Option<String> {
200 self.inner
201 .get_token
202 .as_ref()
203 .and_then(|provider| provider())
204 .filter(|t| !t.is_empty())
205 }
206
207 fn emit_connection<TValue, T>(&self, callback: &Rc<RefCell<T>>, state: ConnectionState)
208 where
209 T: FnMut(StreamEvent<TValue>),
210 {
211 if let Some(mut signal) = self.inner.connection_state {
212 signal.set(state);
213 }
214 (callback.borrow_mut())(StreamEvent::Connection(state));
215 }
216
217 fn emit_error<TValue, T>(&self, callback: &Rc<RefCell<T>>, error: ForgeClientError)
218 where
219 T: FnMut(StreamEvent<TValue>),
220 {
221 if error.code == "UNAUTHORIZED" {
222 if let Some(handler) = &self.inner.on_auth_error {
223 handler(error.as_forge_error());
224 }
225 }
226 (callback.borrow_mut())(StreamEvent::Error(error));
227 }
228
229 fn decode_envelope<TResult>(
230 &self,
231 envelope: RpcEnvelopeRaw,
232 ) -> Result<TResult, ForgeClientError>
233 where
234 TResult: DeserializeOwned,
235 {
236 if !envelope.success {
237 let error = envelope.error.unwrap_or(ForgeError {
238 code: "UNKNOWN".to_string(),
239 message: "Unknown error".to_string(),
240 details: None,
241 });
242 return Err(ForgeClientError::new(error.code, error.message, error.details));
243 }
244
245 let data = envelope.data.ok_or_else(|| {
246 ForgeClientError::new("EMPTY_RESPONSE", "Server returned no data", None)
247 })?;
248 serde_json::from_value(data)
249 .map_err(|err| ForgeClientError::new("DESERIALIZATION_ERROR", err.to_string(), None))
250 }
251
252 fn random_id(&self, prefix: &str) -> String {
253 let id = NEXT_SUBSCRIPTION_ID.fetch_add(1, Ordering::Relaxed);
254 format!("{prefix}-{id}")
255 }
256}
257
258#[derive(Clone)]
259pub struct SubscriptionHandle {
260 closed: Rc<Cell<bool>>,
261 task: Rc<RefCell<Option<Task>>>,
262}
263
264impl SubscriptionHandle {
265 fn new() -> Self {
266 Self {
267 closed: Rc::new(Cell::new(false)),
268 task: Rc::new(RefCell::new(None)),
269 }
270 }
271
272 fn set_task(&self, task: Task) {
273 *self.task.borrow_mut() = Some(task);
274 }
275
276 fn finish(&self) {
277 self.closed.set(true);
278 self.task.borrow_mut().take();
279 }
280
281 pub fn close(&self) {
282 self.closed.set(true);
283 if let Some(task) = self.task.borrow_mut().take() {
284 task.cancel();
285 }
286 }
287
288 pub fn is_closed(&self) -> bool {
289 self.closed.get()
290 }
291}
292
293impl Drop for SubscriptionHandle {
294 fn drop(&mut self) {
295 self.close();
296 }
297}
298
299fn parse_json_str<T>(raw: &str) -> Result<T, ForgeClientError>
300where
301 T: DeserializeOwned,
302{
303 serde_json::from_str(raw)
304 .map_err(|err| ForgeClientError::new("INVALID_SSE_PAYLOAD", err.to_string(), None))
305}
306
307fn emit_sse_error<TValue, T>(
308 client: &ForgeClient,
309 callback: &Rc<RefCell<T>>,
310 envelope: SseEnvelopeRaw,
311) where
312 T: FnMut(StreamEvent<TValue>),
313{
314 client.emit_error(
315 callback,
316 ForgeClientError::new(
317 envelope.code.unwrap_or_else(|| "SSE_ERROR".to_string()),
318 envelope
319 .message
320 .unwrap_or_else(|| "Subscription error".to_string()),
321 None,
322 ),
323 );
324}
325
/// Browser (`wasm32`) transport: HTTP through `gloo_net::http`, server-sent events through
/// the `gloo_net` `EventSource` wrapper.
#[cfg(target_arch = "wasm32")]
mod platform {
    use std::cell::RefCell;
    use std::rc::Rc;

    use dioxus::prelude::spawn;
    use futures_util::{StreamExt, stream};
    use gloo_net::eventsource::futures::{EventSource, EventSourceSubscription};
    use gloo_net::http::Request;
    use js_sys::encode_uri_component;
    use serde::Serialize;
    use serde::de::DeserializeOwned;

    use super::{ForgeClient, SubscriptionHandle, emit_sse_error, parse_json_str};
    use crate::types::{
        ConnectedEvent, ConnectionState, ForgeClientError, RpcEnvelopeRaw, SseEnvelopeRaw,
        StreamEvent,
    };

    /// POSTs `body` as JSON to `url`, adding a bearer token when one is available, and
    /// parses the response body as an envelope.
    ///
    /// The HTTP status code is not checked; transport and JSON-decoding failures both become
    /// `REQUEST_FAILED`.
    pub(super) async fn request_json(
        client: &ForgeClient,
        url: &str,
        body: serde_json::Value,
    ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
        let mut request = Request::post(url).header("Content-Type", "application/json");
        if let Some(token) = client.get_token() {
            request = request.header("Authorization", &format!("Bearer {token}"));
        }

        let request = request.body(body.to_string()).map_err(request_error)?;
        request
            .send()
            .await
            .map_err(request_error)?
            .json()
            .await
            .map_err(request_error)
    }

    /// POSTs a multipart `form` to `url`, adding a bearer token when one is available.
    ///
    /// No `Content-Type` is set here; presumably the browser supplies the multipart boundary
    /// from the `FormData` body.
    pub(super) async fn request_multipart(
        client: &ForgeClient,
        url: &str,
        form: web_sys::FormData,
    ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
        let mut request = Request::post(url);
        if let Some(token) = client.get_token() {
            request = request.header("Authorization", &format!("Bearer {token}"));
        }

        let request = request.body(form).map_err(request_error)?;
        request
            .send()
            .await
            .map_err(request_error)?
            .json()
            .await
            .map_err(request_error)
    }

    /// An open event source together with its `update` and `error` subscriptions.
    struct SseConnection {
        event_source: EventSource,
        update_stream: EventSourceSubscription,
        error_stream: EventSourceSubscription,
    }

    /// Ends a subscription that cannot continue: reports `error`, publishes `Disconnected`
    /// and marks the handle finished.
    fn fail_subscription<TValue, F>(
        client: &ForgeClient,
        callback: &Rc<RefCell<F>>,
        handle_task: &SubscriptionHandle,
        error: ForgeClientError,
    ) where
        F: FnMut(StreamEvent<TValue>),
    {
        client.emit_error(callback, error);
        client.emit_connection(callback, ConnectionState::Disconnected);
        handle_task.finish();
    }

    /// Opens the events stream and waits for the server's `connected` event.
    ///
    /// Returns the connection and the parsed `connected` payload. On failure, or if the
    /// handle is already closed afterwards, the handle is finished and `None` is returned;
    /// failures are also reported through `callback`.
    async fn open_sse_connection<TValue, F>(
        client: &ForgeClient,
        callback: &Rc<RefCell<F>>,
        handle_task: &SubscriptionHandle,
    ) -> Option<(SseConnection, ConnectedEvent)>
    where
        F: FnMut(StreamEvent<TValue>),
    {
        let mut event_source = match EventSource::new(&events_url(client)) {
            Ok(source) => source,
            Err(err) => {
                fail_subscription(
                    client,
                    callback,
                    handle_task,
                    ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
                );
                return None;
            }
        };

        // Subscribes to one named event type, returning `None` from the function on failure.
        macro_rules! subscribe_or_bail {
            ($event_type:expr) => {
                match event_source.subscribe($event_type) {
                    Ok(stream) => stream,
                    Err(err) => {
                        fail_subscription(
                            client,
                            callback,
                            handle_task,
                            ForgeClientError::new(
                                "SSE_SUBSCRIBE_FAILED",
                                err.to_string(),
                                None,
                            ),
                        );
                        return None;
                    }
                }
            };
        }

        let mut connected_stream = subscribe_or_bail!("connected");
        let update_stream = subscribe_or_bail!("update");
        let error_stream = subscribe_or_bail!("error");

        let connected_event = match connected_stream.next().await {
            Some(Ok((_kind, message))) => {
                let Some(raw) = message.data().as_string() else {
                    fail_subscription(
                        client,
                        callback,
                        handle_task,
                        ForgeClientError::new(
                            "INVALID_SSE_PAYLOAD",
                            "SSE payload was not a string",
                            None,
                        ),
                    );
                    return None;
                };
                match parse_json_str::<ConnectedEvent>(&raw) {
                    Ok(event) => event,
                    Err(err) => {
                        fail_subscription(client, callback, handle_task, err);
                        return None;
                    }
                }
            }
            Some(Err(err)) => {
                fail_subscription(
                    client,
                    callback,
                    handle_task,
                    ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
                );
                return None;
            }
            // The stream ended before `connected`; there is no error value to report.
            None => {
                client.emit_connection(callback, ConnectionState::Disconnected);
                handle_task.finish();
                return None;
            }
        };

        // Stop here if the handle was closed while we were waiting.
        if handle_task.is_closed() {
            event_source.close();
            handle_task.finish();
            return None;
        }

        Some((
            SseConnection {
                event_source,
                update_stream,
                error_stream,
            },
            connected_event,
        ))
    }

    /// Forwards `update` payloads as data and `error` envelopes as errors to `callback`
    /// until the handle is closed, both streams end, or a connection error occurs.
    ///
    /// Individual malformed messages are reported and skipped without ending the loop.
    async fn process_sse_events<TResult, F>(
        update_stream: EventSourceSubscription,
        error_stream: EventSourceSubscription,
        client: &ForgeClient,
        callback: &Rc<RefCell<F>>,
        handle_task: &SubscriptionHandle,
    ) where
        TResult: DeserializeOwned + 'static,
        F: FnMut(StreamEvent<TResult>),
    {
        let mut events = stream::select(update_stream, error_stream);
        while !handle_task.is_closed() {
            let Some(event) = events.next().await else {
                break;
            };

            match event {
                Ok((kind, message)) if kind == "update" => {
                    let Some(raw) = message.data().as_string() else {
                        client.emit_error(
                            callback,
                            ForgeClientError::new(
                                "INVALID_SSE_PAYLOAD",
                                "SSE payload was not a string",
                                None,
                            ),
                        );
                        continue;
                    };
                    let envelope = match parse_json_str::<SseEnvelopeRaw>(&raw) {
                        Ok(value) => value,
                        Err(err) => {
                            client.emit_error(callback, err);
                            continue;
                        }
                    };
                    if let Some(data) = envelope.payload {
                        let parsed = match serde_json::from_value::<TResult>(data) {
                            Ok(value) => value,
                            Err(err) => {
                                client.emit_error(
                                    callback,
                                    ForgeClientError::new(
                                        "INVALID_SSE_PAYLOAD",
                                        err.to_string(),
                                        None,
                                    ),
                                );
                                continue;
                            }
                        };
                        (callback.borrow_mut())(StreamEvent::Data(parsed));
                    }
                }
                // Only two streams are merged, so anything that is not `update` came from
                // the `error` subscription.
                Ok((_kind, message)) => {
                    let Some(raw) = message.data().as_string() else {
                        client.emit_error(
                            callback,
                            ForgeClientError::new(
                                "INVALID_SSE_PAYLOAD",
                                "SSE payload was not a string",
                                None,
                            ),
                        );
                        continue;
                    };
                    let envelope = match parse_json_str::<SseEnvelopeRaw>(&raw) {
                        Ok(value) => value,
                        Err(err) => {
                            client.emit_error(callback, err);
                            continue;
                        }
                    };
                    emit_sse_error(client, callback, envelope);
                }
                Err(err) => {
                    client.emit_error(
                        callback,
                        ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
                    );
                    break;
                }
            }
        }
    }

    /// Browser implementation of `ForgeClient::subscribe_query`.
    ///
    /// Opens the SSE stream, registers the query with `POST /_api/subscribe` using the
    /// session credentials from the `connected` event, delivers the initial result, then
    /// forwards live updates until the stream ends or the handle is closed.
    pub(super) fn subscribe_query<TArgs, TResult, F>(
        client: ForgeClient,
        function_name: String,
        args: TArgs,
        callback: F,
    ) -> SubscriptionHandle
    where
        TArgs: Serialize + Clone + 'static,
        TResult: DeserializeOwned + Clone + 'static,
        F: FnMut(StreamEvent<TResult>) + 'static,
    {
        let handle = SubscriptionHandle::new();
        let handle_task = handle.clone();
        let callback = Rc::new(RefCell::new(callback));

        let task = spawn(async move {
            client.emit_connection(&callback, ConnectionState::Connecting);

            let args_value = match serde_json::to_value(args) {
                Ok(value) => value,
                Err(err) => {
                    fail_subscription(
                        &client,
                        &callback,
                        &handle_task,
                        ForgeClientError::new("SERIALIZATION_ERROR", err.to_string(), None),
                    );
                    return;
                }
            };

            let Some((sse, connected)) =
                open_sse_connection(&client, &callback, &handle_task).await
            else {
                return;
            };

            let register_payload = serde_json::json!({
                "session_id": connected.session_id,
                "session_secret": connected.session_secret,
                "id": client.random_id("sub"),
                "function": function_name,
                "args": args_value,
            });

            let registration = request_json(
                &client,
                &format!("{}/_api/subscribe", client.inner.url),
                register_payload,
            )
            .await
            .and_then(|envelope| client.decode_envelope::<TResult>(envelope));

            match registration {
                Ok(data) => {
                    client.emit_connection(&callback, ConnectionState::Connected);
                    (callback.borrow_mut())(StreamEvent::Data(data));
                }
                Err(err) => {
                    fail_subscription(&client, &callback, &handle_task, err);
                    sse.event_source.close();
                    return;
                }
            }

            process_sse_events::<TResult, _>(
                sse.update_stream,
                sse.error_stream,
                &client,
                &callback,
                &handle_task,
            )
            .await;

            sse.event_source.close();
            client.emit_connection(&callback, ConnectionState::Disconnected);
            handle_task.finish();
        });

        handle.set_task(task);
        handle
    }

    /// Browser implementation of job/workflow tracking.
    ///
    /// `payload` must be a JSON object (as built by the `subscribe_job` /
    /// `subscribe_workflow` callers); the session fields and an id with `prefix` are added
    /// before it is POSTed to `endpoint`. A rejected registration is reported and ends the
    /// subscription, consistent with `subscribe_query`.
    pub(super) fn subscribe_tracker<TResult, F>(
        client: ForgeClient,
        prefix: String,
        payload: serde_json::Value,
        endpoint: String,
        callback: F,
    ) -> SubscriptionHandle
    where
        TResult: DeserializeOwned + Clone + 'static,
        F: FnMut(StreamEvent<TResult>) + 'static,
    {
        let handle = SubscriptionHandle::new();
        let handle_task = handle.clone();
        let callback = Rc::new(RefCell::new(callback));

        let task = spawn(async move {
            client.emit_connection(&callback, ConnectionState::Connecting);

            let Some((sse, connected)) =
                open_sse_connection(&client, &callback, &handle_task).await
            else {
                return;
            };

            let mut register_payload = payload;
            let register_object = register_payload
                .as_object_mut()
                .expect("tracker payload must be an object");
            register_object.insert(
                "session_id".to_string(),
                serde_json::Value::String(connected.session_id.unwrap_or_default()),
            );
            register_object.insert(
                "session_secret".to_string(),
                serde_json::Value::String(connected.session_secret.unwrap_or_default()),
            );
            register_object.insert(
                "id".to_string(),
                serde_json::Value::String(client.random_id(&prefix)),
            );

            let registration = request_json(
                &client,
                &format!("{}{}", client.inner.url, endpoint),
                register_payload,
            )
            .await
            .and_then(|envelope| decode_tracker_snapshot::<TResult>(&client, envelope));

            match registration {
                Ok(snapshot) => {
                    client.emit_connection(&callback, ConnectionState::Connected);
                    if let Some(data) = snapshot {
                        (callback.borrow_mut())(StreamEvent::Data(data));
                    }
                }
                Err(err) => {
                    fail_subscription(&client, &callback, &handle_task, err);
                    sse.event_source.close();
                    return;
                }
            }

            process_sse_events::<TResult, _>(
                sse.update_stream,
                sse.error_stream,
                &client,
                &callback,
                &handle_task,
            )
            .await;

            sse.event_source.close();
            client.emit_connection(&callback, ConnectionState::Disconnected);
            handle_task.finish();
        });

        handle.set_task(task);
        handle
    }

    /// Interprets a tracker registration response.
    ///
    /// An unsuccessful envelope is an error because the server rejected the registration.
    /// For a successful one, `data` is decoded as an initial snapshot on a best-effort basis:
    /// a missing or non-matching payload yields `Ok(None)`.
    fn decode_tracker_snapshot<TResult>(
        client: &ForgeClient,
        envelope: RpcEnvelopeRaw,
    ) -> Result<Option<TResult>, ForgeClientError>
    where
        TResult: DeserializeOwned,
    {
        if !envelope.success {
            // `decode_envelope` returns the server-reported error for unsuccessful envelopes.
            return client
                .decode_envelope::<serde_json::Value>(envelope)
                .map(|_| None);
        }
        Ok(envelope
            .data
            .and_then(|data| serde_json::from_value::<TResult>(data).ok()))
    }

    /// Builds the events URL, passing the bearer token (URI-encoded) as a `token` query
    /// parameter because the browser `EventSource` constructor accepts no custom headers.
    fn events_url(client: &ForgeClient) -> String {
        match client.get_token() {
            Some(token) => format!(
                "{}/_api/events?token={}",
                client.inner.url,
                encode_uri_component(&token)
            ),
            None => format!("{}/_api/events", client.inner.url),
        }
    }

    /// Maps a `gloo_net` error to `REQUEST_FAILED`.
    fn request_error(err: gloo_net::Error) -> ForgeClientError {
        ForgeClientError::new("REQUEST_FAILED", err.to_string(), None)
    }
}
762
763#[cfg(not(target_arch = "wasm32"))]
764mod platform {
765 use std::cell::RefCell;
766 use std::rc::Rc;
767
768 use dioxus::prelude::spawn;
769 use futures_util::StreamExt;
770 use reqwest::Client;
771 use reqwest_eventsource::{Event, EventSource};
772 use serde::Serialize;
773 use serde::de::DeserializeOwned;
774
775 use super::{ForgeClient, SubscriptionHandle, emit_sse_error, parse_json_str};
776 use crate::types::{
777 ConnectedEvent, ConnectionState, ForgeClientError, RpcEnvelopeRaw, SseEnvelopeRaw,
778 StreamEvent,
779 };
780
781 pub(super) async fn request_json(
782 client: &ForgeClient,
783 url: &str,
784 body: serde_json::Value,
785 ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
786 let mut request = Client::new().post(url).json(&body);
787 if let Some(token) = client.get_token() {
788 request = request.bearer_auth(token);
789 }
790
791 request
792 .send()
793 .await
794 .map_err(request_error)?
795 .json()
796 .await
797 .map_err(request_error)
798 }
799
800 pub(super) async fn request_multipart(
801 client: &ForgeClient,
802 url: &str,
803 form: reqwest::multipart::Form,
804 ) -> Result<RpcEnvelopeRaw, ForgeClientError> {
805 let mut request = Client::new().post(url).multipart(form);
806 if let Some(token) = client.get_token() {
807 request = request.bearer_auth(token);
808 }
809
810 request
811 .send()
812 .await
813 .map_err(request_error)?
814 .json()
815 .await
816 .map_err(request_error)
817 }
818
819 async fn process_sse_events<TResult, F>(
820 event_source: &mut EventSource,
821 client: &ForgeClient,
822 callback: &Rc<RefCell<F>>,
823 handle_task: &SubscriptionHandle,
824 ) where
825 TResult: DeserializeOwned + 'static,
826 F: FnMut(StreamEvent<TResult>),
827 {
828 while !handle_task.is_closed() {
829 let Some(event) = event_source.next().await else {
830 break;
831 };
832
833 match event {
834 Ok(Event::Open) => {}
835 Ok(Event::Message(message)) if message.event == "update" => {
836 let envelope = match parse_json_str::<SseEnvelopeRaw>(&message.data) {
837 Ok(value) => value,
838 Err(err) => {
839 client.emit_error(callback, err);
840 continue;
841 }
842 };
843 if let Some(data) = envelope.payload {
844 let parsed = match serde_json::from_value::<TResult>(data) {
845 Ok(value) => value,
846 Err(err) => {
847 client.emit_error(
848 callback,
849 ForgeClientError::new(
850 "INVALID_SSE_PAYLOAD",
851 err.to_string(),
852 None,
853 ),
854 );
855 continue;
856 }
857 };
858 (callback.borrow_mut())(StreamEvent::Data(parsed));
859 }
860 }
861 Ok(Event::Message(message)) if message.event == "error" => {
862 let envelope = match parse_json_str::<SseEnvelopeRaw>(&message.data) {
863 Ok(value) => value,
864 Err(err) => {
865 client.emit_error(callback, err);
866 continue;
867 }
868 };
869 emit_sse_error(client, callback, envelope);
870 }
871 Ok(Event::Message(_)) => {}
872 Err(err) => {
873 client.emit_error(
874 callback,
875 ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None),
876 );
877 break;
878 }
879 }
880 }
881 }
882
883 async fn open_and_connect<TValue, F>(
884 client: &ForgeClient,
885 callback: &Rc<RefCell<F>>,
886 handle_task: &SubscriptionHandle,
887 ) -> Option<(EventSource, ConnectedEvent)>
888 where
889 F: FnMut(StreamEvent<TValue>),
890 {
891 let mut event_source = match open_event_source(client) {
892 Ok(source) => source,
893 Err(err) => {
894 client.emit_error(callback, err);
895 client.emit_connection(callback, ConnectionState::Disconnected);
896 handle_task.finish();
897 return None;
898 }
899 };
900
901 let connected_event =
902 match next_connected_event(&mut event_source, client, callback).await {
903 Ok(Some(event)) => event,
904 Ok(None) => {
905 client.emit_connection(callback, ConnectionState::Disconnected);
906 handle_task.finish();
907 return None;
908 }
909 Err(err) => {
910 client.emit_error(callback, err);
911 client.emit_connection(callback, ConnectionState::Disconnected);
912 handle_task.finish();
913 return None;
914 }
915 };
916
917 if handle_task.is_closed() {
918 event_source.close();
919 handle_task.finish();
920 return None;
921 }
922
923 Some((event_source, connected_event))
924 }
925
926 pub(super) fn subscribe_query<TArgs, TResult, F>(
927 client: ForgeClient,
928 function_name: String,
929 args: TArgs,
930 callback: F,
931 ) -> SubscriptionHandle
932 where
933 TArgs: Serialize + Clone + 'static,
934 TResult: DeserializeOwned + Clone + 'static,
935 F: FnMut(StreamEvent<TResult>) + 'static,
936 {
937 let handle = SubscriptionHandle::new();
938 let handle_task = handle.clone();
939 let callback = Rc::new(RefCell::new(callback));
940
941 let task = spawn(async move {
942 client.emit_connection(&callback, ConnectionState::Connecting);
943
944 let args_value = match serde_json::to_value(args) {
945 Ok(value) => value,
946 Err(err) => {
947 client.emit_error(
948 &callback,
949 ForgeClientError::new("SERIALIZATION_ERROR", err.to_string(), None),
950 );
951 client.emit_connection(&callback, ConnectionState::Disconnected);
952 handle_task.finish();
953 return;
954 }
955 };
956
957 let Some((mut event_source, connected)) =
958 open_and_connect(&client, &callback, &handle_task).await
959 else {
960 return;
961 };
962
963 let register_payload = serde_json::json!({
964 "session_id": connected.session_id,
965 "session_secret": connected.session_secret,
966 "id": client.random_id("sub"),
967 "function": function_name,
968 "args": args_value,
969 });
970
971 match request_json(
972 &client,
973 &format!("{}/_api/subscribe", client.inner.url),
974 register_payload,
975 )
976 .await
977 {
978 Ok(envelope) => match client.decode_envelope::<TResult>(envelope) {
979 Ok(data) => {
980 client.emit_connection(&callback, ConnectionState::Connected);
981 (callback.borrow_mut())(StreamEvent::Data(data));
982 }
983 Err(err) => {
984 client.emit_error(&callback, err);
985 client.emit_connection(&callback, ConnectionState::Disconnected);
986 handle_task.finish();
987 return;
988 }
989 },
990 Err(err) => {
991 client.emit_error(&callback, err);
992 client.emit_connection(&callback, ConnectionState::Disconnected);
993 handle_task.finish();
994 return;
995 }
996 }
997
998 process_sse_events::<TResult, _>(
999 &mut event_source,
1000 &client,
1001 &callback,
1002 &handle_task,
1003 )
1004 .await;
1005
1006 event_source.close();
1007 client.emit_connection(&callback, ConnectionState::Disconnected);
1008 handle_task.finish();
1009 });
1010
1011 handle.set_task(task);
1012 handle
1013 }
1014
1015 pub(super) fn subscribe_tracker<TResult, F>(
1016 client: ForgeClient,
1017 prefix: String,
1018 payload: serde_json::Value,
1019 endpoint: String,
1020 callback: F,
1021 ) -> SubscriptionHandle
1022 where
1023 TResult: DeserializeOwned + Clone + 'static,
1024 F: FnMut(StreamEvent<TResult>) + 'static,
1025 {
1026 let handle = SubscriptionHandle::new();
1027 let handle_task = handle.clone();
1028 let callback = Rc::new(RefCell::new(callback));
1029
1030 let task = spawn(async move {
1031 client.emit_connection(&callback, ConnectionState::Connecting);
1032
1033 let Some((mut event_source, connected)) =
1034 open_and_connect(&client, &callback, &handle_task).await
1035 else {
1036 return;
1037 };
1038
1039 let mut register_payload = payload;
1040 let register_object = register_payload
1041 .as_object_mut()
1042 .expect("tracker payload must be an object");
1043 register_object.insert(
1044 "session_id".to_string(),
1045 serde_json::Value::String(connected.session_id.unwrap_or_default()),
1046 );
1047 register_object.insert(
1048 "session_secret".to_string(),
1049 serde_json::Value::String(connected.session_secret.unwrap_or_default()),
1050 );
1051 register_object.insert(
1052 "id".to_string(),
1053 serde_json::Value::String(client.random_id(&prefix)),
1054 );
1055
1056 match request_json(
1057 &client,
1058 &format!("{}{}", client.inner.url, endpoint),
1059 register_payload,
1060 )
1061 .await
1062 {
1063 Ok(envelope) => {
1064 client.emit_connection(&callback, ConnectionState::Connected);
1065 if envelope.success {
1066 if let Some(data) = envelope.data {
1067 if let Ok(parsed) = serde_json::from_value::<TResult>(data) {
1068 (callback.borrow_mut())(StreamEvent::Data(parsed));
1069 }
1070 }
1071 }
1072 }
1073 Err(err) => {
1074 client.emit_error(&callback, err);
1075 client.emit_connection(&callback, ConnectionState::Disconnected);
1076 handle_task.finish();
1077 return;
1078 }
1079 }
1080
1081 process_sse_events::<TResult, _>(
1082 &mut event_source,
1083 &client,
1084 &callback,
1085 &handle_task,
1086 )
1087 .await;
1088
1089 event_source.close();
1090 client.emit_connection(&callback, ConnectionState::Disconnected);
1091 handle_task.finish();
1092 });
1093
1094 handle.set_task(task);
1095 handle
1096 }
1097
1098 fn open_event_source(client: &ForgeClient) -> Result<EventSource, ForgeClientError> {
1099 let mut request = Client::new().get(format!("{}/_api/events", client.inner.url));
1100 if let Some(token) = client.get_token() {
1101 request = request.bearer_auth(token);
1102 }
1103
1104 EventSource::new(request)
1105 .map_err(|err| ForgeClientError::new("SSE_CONNECTION_FAILED", err.to_string(), None))
1106 }
1107
1108 async fn next_connected_event<TValue, T>(
1109 event_source: &mut EventSource,
1110 client: &ForgeClient,
1111 callback: &Rc<RefCell<T>>,
1112 ) -> Result<Option<ConnectedEvent>, ForgeClientError>
1113 where
1114 T: FnMut(StreamEvent<TValue>),
1115 {
1116 while let Some(event) = event_source.next().await {
1117 match event {
1118 Ok(Event::Open) => continue,
1119 Ok(Event::Message(message)) if message.event == "connected" => {
1120 return parse_json_str::<ConnectedEvent>(&message.data).map(Some);
1121 }
1122 Ok(Event::Message(message)) if message.event == "error" => {
1123 let envelope = parse_json_str::<SseEnvelopeRaw>(&message.data)?;
1124 emit_sse_error(client, callback, envelope);
1125 }
1126 Ok(Event::Message(_)) => {}
1127 Err(err) => {
1128 return Err(ForgeClientError::new(
1129 "SSE_CONNECTION_FAILED",
1130 err.to_string(),
1131 None,
1132 ));
1133 }
1134 }
1135 }
1136
1137 Ok(None)
1138 }
1139
1140 fn request_error(err: reqwest::Error) -> ForgeClientError {
1141 ForgeClientError::new("REQUEST_FAILED", err.to_string(), None)
1142 }
1143}