Skip to main content

flow_lib/
context.rs

//! Provides services and information for nodes to use.
//!
//! Services are abstracted with the [`tower::Service`] trait, using our
//! [`TowerClient`][crate::utils::TowerClient] utility to make them easier to use.
//!
//! Each service is defined in a separate module:
//! - [`get_jwt`]
//! - [`execute`]
//! - [`signer`]
//! - [`api_input`]
10
11use crate::{
12    ContextConfig, FlowRunId, HttpClientConfig, NodeId, SolanaClientConfig, UserId, ValueSet,
13    config::{Endpoints, client::FlowRunOrigin},
14    flow_run_events::{self, NodeLogContent, NodeLogSender},
15    solana::Instructions,
16    utils::{Extensions, tower_client::unimplemented_svc},
17};
18use bytes::Bytes;
19use chrono::Utc;
20use futures::channel::mpsc;
21use reqwest::header::HeaderMap;
22use schemars::JsonSchema;
23use serde::{Deserialize, Serialize};
24use solana_pubkey::Pubkey;
25use solana_rpc_client::nonblocking::rpc_client::RpcClient as SolanaClient;
26use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
27use tower::{Service, ServiceExt};
28
29pub use spo_helius::Helius;
30
/// Environment variable names, declared once as string constants.
///
/// Only the names live here; nothing in this module reads the variables.
pub mod env {
    // Kept in alphabetical order so new names have an obvious home.
    pub const COMPUTE_BUDGET: &str = "COMPUTE_BUDGET";
    pub const DEVNET_LOOKUP_TABLE: &str = "DEVNET_LOOKUP_TABLE";
    pub const EXECUTE_ON: &str = "EXECUTE_ON";
    pub const FALLBACK_COMPUTE_BUDGET: &str = "FALLBACK_COMPUTE_BUDGET";
    pub const MAINNET_LOOKUP_TABLE: &str = "MAINNET_LOOKUP_TABLE";
    pub const OVERWRITE_FEEPAYER: &str = "OVERWRITE_FEEPAYER";
    pub const PRIORITY_FEE: &str = "PRIORITY_FEE";
    pub const RUST_LOG: &str = "RUST_LOG";
    pub const TX_COMMITMENT_LEVEL: &str = "TX_COMMITMENT_LEVEL";
    pub const WAIT_COMMITMENT_LEVEL: &str = "WAIT_COMMITMENT_LEVEL";
}
43
/// Ask for a value to be supplied to a waiting node from outside the flow.
///
/// NOTE(review): how the value is actually delivered (HTTP API, webhook, UI)
/// is decided by the service implementation, which is not in this file.
pub mod api_input {
    use std::time::Duration;

    use crate::{
        FlowRunId, NodeId,
        utils::{TowerClient, tower_client::CommonError},
    };
    use reqwest::header::HeaderMap;
    use thiserror::Error as ThisError;
    use value::Value;

    /// Parameters of one api-input request.
    #[derive(Debug)]
    pub struct Request {
        /// Flow run the requesting node belongs to.
        pub flow_run_id: FlowRunId,
        /// Node that is waiting for input.
        pub node_id: NodeId,
        /// Node invocation counter, copied from the command context.
        pub times: u32,
        /// How long to wait. `CommandContext::api_input` passes
        /// `Duration::MAX` when the caller supplies no timeout.
        pub timeout: Duration,
        /// Optional webhook URL (its use is up to the service — confirm there).
        pub webhook_url: Option<String>,
        /// Optional headers to accompany the webhook.
        pub webhook_headers: Option<HeaderMap>,
        /// Extra JSON fields passed through to the service unchanged.
        pub extra: Option<serde_json::Map<String, serde_json::Value>>,
    }

    /// The value supplied for the node.
    pub struct Response {
        pub value: Value,
    }

    /// Ways an api-input request can fail.
    #[derive(ThisError, Debug, Clone)]
    pub enum Error {
        /// The request was canceled before a value arrived.
        #[error("canceled by user")]
        Canceled,
        /// No value arrived within [`Request::timeout`].
        #[error("timeout")]
        Timeout,
        /// Shared service-level error.
        #[error(transparent)]
        Common(#[from] CommonError),
    }

    /// Client handle for the api-input service.
    pub type Svc = TowerClient<Request, Response, Error>;
}
82
83/// Get user's JWT, require
84/// [`user_token`][crate::config::node::Permissions::user_tokens] permission.
85pub mod get_jwt {
86    use crate::{UserId, utils::TowerClient, utils::tower_client::CommonError};
87    use std::future::Ready;
88    use thiserror::Error as ThisError;
89
90    #[derive(Clone, Copy)]
91    pub struct Request {
92        pub user_id: UserId,
93    }
94
95    #[derive(Clone, Debug)]
96    pub struct Response {
97        pub access_token: String,
98    }
99
100    #[derive(ThisError, Debug, Clone)]
101    pub enum Error {
102        #[error("not allowed")]
103        NotAllowed,
104        #[error("user not found")]
105        UserNotFound,
106        #[error("wrong recipient")]
107        WrongRecipient,
108        #[error("{}: {}", error, error_description)]
109        Supabase {
110            error: String,
111            error_description: String,
112        },
113        #[error(transparent)]
114        Common(#[from] CommonError),
115    }
116
117    impl From<actix::MailboxError> for Error {
118        fn from(value: actix::MailboxError) -> Self {
119            CommonError::from(value).into()
120        }
121    }
122
123    impl actix::Message for Request {
124        type Result = Result<Response, Error>;
125    }
126
127    pub type Svc = TowerClient<Request, Response, Error>;
128
129    pub fn not_allowed() -> Svc {
130        Svc::new(tower::service_fn(|_| {
131            std::future::ready(Result::<Response, _>::Err(Error::NotAllowed))
132        }))
133    }
134
135    #[derive(Clone, Copy, Debug)]
136    pub struct RetryPolicy(pub usize);
137
138    impl Default for RetryPolicy {
139        fn default() -> Self {
140            Self(1)
141        }
142    }
143
144    impl tower::retry::Policy<Request, Response, Error> for RetryPolicy {
145        type Future = Ready<()>;
146
147        fn retry(
148            &mut self,
149            _: &mut Request,
150            result: &mut Result<Response, Error>,
151        ) -> Option<Self::Future> {
152            match result {
153                Err(Error::Supabase {
154                    error_description, ..
155                }) if error_description.contains("Refresh Token") && self.0 > 0 => {
156                    tracing::error!("get_jwt error: {}, retrying", error_description);
157                    self.0 -= 1;
158                    Some(std::future::ready(()))
159                }
160                _ => None,
161            }
162        }
163
164        fn clone_request(&mut self, req: &Request) -> Option<Request> {
165            Some(*req)
166        }
167    }
168}
169
170/// Request Solana signature from external wallets.
171pub mod signer {
172    use crate::{
173        FlowRunId,
174        utils::{TowerClient, tower_client::CommonError},
175    };
176    use actix::MailboxError;
177    use chrono::{DateTime, Utc};
178    use serde::{Deserialize, Serialize};
179    use serde_with::{DisplayFromStr, DurationSecondsWithFrac, base64::Base64, serde_as};
180    use solana_presigner::Presigner as SdkPresigner;
181    use solana_pubkey::Pubkey;
182    use solana_signature::Signature;
183    use std::time::Duration;
184    use thiserror::Error as ThisError;
185
186    #[derive(ThisError, Debug)]
187    pub enum Error {
188        #[error("can't sign for pubkey: {}", .0)]
189        Pubkey(String),
190        #[error("can't sign for this user")]
191        User,
192        #[error("timeout")]
193        Timeout,
194        #[error(transparent)]
195        Common(#[from] CommonError),
196    }
197
198    impl From<MailboxError> for Error {
199        fn from(value: MailboxError) -> Self {
200            CommonError::from(value).into()
201        }
202    }
203
204    pub type Svc = TowerClient<SignatureRequest, SignatureResponse, Error>;
205
206    #[serde_as]
207    #[derive(Debug, Clone, Serialize, Deserialize)]
208    pub struct Presigner {
209        #[serde_as(as = "DisplayFromStr")]
210        pub pubkey: Pubkey,
211        #[serde_as(as = "DisplayFromStr")]
212        pub signature: Signature,
213    }
214
215    impl From<Presigner> for SdkPresigner {
216        fn from(value: Presigner) -> Self {
217            SdkPresigner::new(&value.pubkey, &value.signature)
218        }
219    }
220
221    #[serde_as]
222    #[derive(Debug, Clone, Serialize, Deserialize)]
223    pub struct SignatureRequest {
224        pub id: Option<i64>,
225        #[serde(with = "chrono::serde::ts_milliseconds")]
226        pub time: DateTime<Utc>,
227        pub token: Option<String>,
228        #[serde_as(as = "DisplayFromStr")]
229        pub pubkey: Pubkey,
230        #[serde_as(as = "Base64")]
231        pub message: bytes::Bytes,
232        #[serde_as(as = "DurationSecondsWithFrac<f64>")]
233        pub timeout: Duration,
234        pub flow_run_id: Option<FlowRunId>,
235        pub signatures: Option<Vec<Presigner>>,
236    }
237
238    impl actix::Message for SignatureRequest {
239        type Result = Result<SignatureResponse, Error>;
240    }
241
242    #[serde_as]
243    #[derive(Debug, Clone, Serialize, Deserialize)]
244    pub struct SignatureResponse {
245        #[serde_as(as = "DisplayFromStr")]
246        pub signature: Signature,
247        #[serde_as(as = "Option<Base64>")]
248        pub new_message: Option<bytes::Bytes>,
249    }
250
251    impl bincode::Encode for SignatureResponse {
252        fn encode<E: bincode::enc::Encoder>(
253            &self,
254            encoder: &mut E,
255        ) -> Result<(), bincode::error::EncodeError> {
256            self.signature.as_array().encode(encoder)?;
257            self.new_message
258                .as_ref()
259                .map(|m| m.as_ref())
260                .encode(encoder)?;
261            Ok(())
262        }
263    }
264
265    impl<C> bincode::Decode<C> for SignatureResponse {
266        fn decode<D: bincode::de::Decoder<Context = C>>(
267            decoder: &mut D,
268        ) -> Result<Self, bincode::error::DecodeError> {
269            let signature = Signature::from(<[u8; 64]>::decode(decoder)?);
270            let new_message = Option::<Vec<u8>>::decode(decoder)?.map(Into::into);
271            Ok(Self {
272                signature,
273                new_message,
274            })
275        }
276    }
277
278    impl<'de, C> bincode::BorrowDecode<'de, C> for SignatureResponse {
279        fn borrow_decode<D: bincode::de::BorrowDecoder<'de, Context = C>>(
280            decoder: &mut D,
281        ) -> Result<Self, bincode::error::DecodeError> {
282            bincode::Decode::decode(decoder)
283        }
284    }
285}
286
287/// Output values and Solana instructions to be executed.
288pub mod execute {
289    use crate::{
290        solana::Instructions,
291        utils::{
292            TowerClient,
293            tower_client::{CommonError, CommonErrorExt},
294        },
295    };
296    use futures::channel::oneshot::Canceled;
297    use serde::{Deserialize, Serialize};
298    use serde_with::{DisplayFromStr, base64::Base64, serde_as};
299    use solana_instruction_error::InstructionError;
300    use solana_message::CompileError;
301    use solana_sanitize::SanitizeError;
302
303    use solana_rpc_client_api::client_error::Error as ClientError;
304    use solana_signature::Signature;
305    use solana_signer::SignerError;
306    use std::sync::Arc;
307    use thiserror::Error as ThisError;
308
309    use super::signer;
310
311    pub type Svc = TowerClient<Request, Response, Error>;
312
313    #[derive(Deserialize)]
314    #[serde(try_from = "RequestRepr")]
315    pub struct Request {
316        pub instructions: Instructions,
317        pub output: value::Map,
318    }
319
320    impl bincode::Encode for Request {
321        fn encode<E: bincode::enc::Encoder>(
322            &self,
323            encoder: &mut E,
324        ) -> Result<(), bincode::error::EncodeError> {
325            self.instructions.encode(encoder)?;
326            value::bincode_impl::MapBincode::from(&self.output).encode(encoder)?;
327            Ok(())
328        }
329    }
330
331    impl<C> bincode::Decode<C> for Request {
332        fn decode<D: bincode::de::Decoder<Context = C>>(
333            decoder: &mut D,
334        ) -> Result<Self, bincode::error::DecodeError> {
335            Ok(Self {
336                instructions: Instructions::decode(decoder)?,
337                output: value::bincode_impl::MapBincode::decode(decoder)?
338                    .0
339                    .into_owned(),
340            })
341        }
342    }
343
344    #[serde_as]
345    #[derive(Deserialize)]
346    struct RequestRepr {
347        #[serde_as(as = "Base64")]
348        instructions: Vec<u8>,
349        output: value::Map,
350    }
351
352    impl TryFrom<RequestRepr> for Request {
353        type Error = rmp_serde::decode::Error;
354        fn try_from(value: RequestRepr) -> Result<Self, Self::Error> {
355            Ok(Self {
356                instructions: rmp_serde::from_slice(&value.instructions)?,
357                output: value.output,
358            })
359        }
360    }
361
362    #[serde_as]
363    #[derive(Serialize, Clone, Copy)]
364    pub struct Response {
365        #[serde_as(as = "Option<DisplayFromStr>")]
366        pub signature: Option<Signature>,
367    }
368
369    impl bincode::Encode for Response {
370        fn encode<E: bincode::enc::Encoder>(
371            &self,
372            encoder: &mut E,
373        ) -> Result<(), bincode::error::EncodeError> {
374            self.signature.map(|s| *s.as_array()).encode(encoder)?;
375            Ok(())
376        }
377    }
378
379    impl<C> bincode::Decode<C> for Response {
380        fn decode<D: bincode::de::Decoder<Context = C>>(
381            decoder: &mut D,
382        ) -> Result<Self, bincode::error::DecodeError> {
383            let value = Option::<[u8; 64]>::decode(decoder)?;
384            Ok(Self {
385                signature: value.map(Signature::from),
386            })
387        }
388    }
389
390    fn unwrap(s: &Option<String>) -> &str {
391        s.as_ref().map(|v| v.as_str()).unwrap_or_default()
392    }
393
394    #[serde_as]
395    #[derive(ThisError, Debug, Clone, Serialize, Deserialize)]
396    pub enum Error {
397        #[error("canceled {}", unwrap(.0))]
398        Canceled(Option<String>),
399        #[error("collected")]
400        Collected,
401        #[error("some node failed to provide instructions")]
402        TxIncomplete,
403        #[error("time out")]
404        Timeout,
405        #[error("insufficient solana balance, needed={needed}; have={balance};")]
406        InsufficientSolanaBalance { needed: u64, balance: u64 },
407        #[error("transaction simulation failed")]
408        TxSimFailed,
409        #[error("{}", crate::utils::verbose_solana_error(.error))]
410        Solana {
411            #[source]
412            #[serde_as(as = "Arc<crate::errors::AsClientError>")]
413            error: Arc<ClientError>,
414            inserted: usize,
415        },
416        #[error(transparent)]
417        Signer(
418            #[from]
419            #[serde_as(as = "Arc<crate::errors::AsSignerError>")]
420            Arc<SignerError>,
421        ),
422        #[error(transparent)]
423        CompileError(
424            #[from]
425            #[serde_as(as = "Arc<crate::errors::AsCompileError>")]
426            Arc<CompileError>,
427        ),
428        #[error(transparent)]
429        InstructionError(#[from] Arc<InstructionError>),
430        #[error(transparent)]
431        SanitizeError(
432            #[from]
433            #[serde_as(as = "Arc<crate::errors::AsSanitizeError>")]
434            Arc<SanitizeError>,
435        ),
436        #[error(transparent)]
437        ChannelClosed(
438            #[from]
439            #[serde_as(as = "crate::errors::AsCancelled")]
440            Canceled,
441        ),
442        #[error(transparent)]
443        Common(#[from] CommonError),
444    }
445
446    impl From<actix::MailboxError> for Error {
447        fn from(value: actix::MailboxError) -> Self {
448            CommonError::from(value).into()
449        }
450    }
451
452    impl Error {
453        pub fn solana(error: ClientError, inserted: usize) -> Self {
454            Self::Solana {
455                error: Arc::new(error),
456                inserted,
457            }
458        }
459    }
460
461    impl From<signer::Error> for Error {
462        fn from(value: signer::Error) -> Self {
463            match value {
464                e @ signer::Error::Pubkey(_) => Self::other(e),
465                e @ signer::Error::User => Self::other(e),
466                signer::Error::Timeout => Self::Timeout,
467                signer::Error::Common(error) => Self::Common(error),
468            }
469        }
470    }
471
472    impl From<SignerError> for Error {
473        fn from(value: SignerError) -> Self {
474            Error::Signer(Arc::new(value))
475        }
476    }
477
478    impl From<CompileError> for Error {
479        fn from(value: CompileError) -> Self {
480            Error::CompileError(Arc::new(value))
481        }
482    }
483
484    impl From<InstructionError> for Error {
485        fn from(value: InstructionError) -> Self {
486            Error::InstructionError(Arc::new(value))
487        }
488    }
489
490    impl From<SanitizeError> for Error {
491        fn from(value: SanitizeError) -> Self {
492            Error::SanitizeError(Arc::new(value))
493        }
494    }
495}
496
/// Flow-set-level context data, nested under [`FlowContextData::set`].
#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
pub struct FlowSetContextData {
    /// Owner of the flow; [`CommandContext::get_jwt_header`] fetches this user's JWT.
    pub flow_owner: User,
    /// User who started the flow run.
    pub started_by: User,
    /// Service endpoints, exposed via [`CommandContext::endpoints`].
    pub endpoints: Endpoints,
    /// Solana client configuration, exposed via [`CommandContext::solana_config`].
    pub solana: SolanaClientConfig,
    /// HTTP client configuration.
    pub http: HttpClientConfig,
}
505
/// Flow-run-level context data, nested under [`CommandContextData::flow`].
#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
pub struct FlowContextData {
    /// Id of the current flow run.
    pub flow_run_id: FlowRunId,
    /// String key/value environment for the run (NOTE(review): presumably where
    /// names from [`env`] are looked up — confirm with readers of this map).
    pub environment: HashMap<String, String>,
    /// Flow-level inputs, exposed via [`CommandContext::flow_inputs`].
    pub inputs: ValueSet,
    /// Data shared at the flow-set level.
    pub set: FlowSetContextData,
}
513
/// Serializable context data for one node invocation.
#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
pub struct CommandContextData {
    /// Node being run.
    pub node_id: NodeId,
    /// Invocation counter for this node (0 in `test_context`; exact semantics
    /// — e.g. loop iteration index — TODO confirm).
    pub times: u32,
    /// Data of the enclosing flow run.
    pub flow: FlowContextData,
}
520
/// Service handles and clients at the flow-set level, nested under
/// [`FlowServices::set`].
#[derive(Clone, bon::Builder)]
pub struct FlowSetServices {
    /// HTTP client, exposed via [`CommandContext::http`].
    pub http: reqwest::Client,
    /// Solana RPC client, exposed via [`CommandContext::solana_client`].
    pub solana_client: Arc<SolanaClient>,
    /// Optional Helius client (`None` in `CommandContext::test_context`).
    pub helius: Option<Arc<Helius>>,
    /// Type-keyed extensions, read with [`CommandContext::get`].
    pub extensions: Arc<Extensions>,
    /// Client for the [`api_input`] service.
    pub api_input: api_input::Svc,
}
529
/// Flow-level service handles: the [`signer`] client plus the flow-set services.
#[derive(Clone, bon::Builder)]
pub struct FlowServices {
    /// Client for the [`signer`] service.
    pub signer: signer::Svc,
    /// Services shared at the flow-set level.
    pub set: FlowSetServices,
}
535
/// Everything a node command can use while it runs.
///
/// Holds the node's context data plus handles to the [`execute`], [`get_jwt`],
/// [`signer`] and [`api_input`] services. Built with the derived `bon`
/// builder, or with [`CommandContext::test_context`] in tests.
#[derive(Clone, bon::Builder)]
pub struct CommandContext {
    /// Serializable context data (ids, inputs, configuration).
    data: CommandContextData,
    /// Client for the [`execute`] service.
    execute: execute::Svc,
    /// Client for the [`get_jwt`] service.
    get_jwt: get_jwt::Svc,
    /// Sender for this node's log entries.
    node_log: NodeLogSender,
    /// Flow-level services (signer, HTTP, Solana, extensions, api input).
    flow: FlowServices,
}
544
545impl CommandContext {
546    pub fn test_context() -> Self {
547        let config = ContextConfig::default();
548        let solana_client = Arc::new(config.solana_client.build_client(None));
549        let node_id = NodeId::nil();
550        let times = 0;
551        let (tx, _) = flow_run_events::channel();
552        Self {
553            data: CommandContextData {
554                node_id,
555                times,
556                flow: FlowContextData {
557                    flow_run_id: FlowRunId::nil(),
558                    environment: HashMap::new(),
559                    inputs: ValueSet::default(),
560                    set: FlowSetContextData {
561                        flow_owner: User::default(),
562                        started_by: User::default(),
563                        endpoints: Endpoints::default(),
564                        solana: config.solana_client,
565                        http: config.http_client,
566                    },
567                },
568            },
569            execute: unimplemented_svc(),
570            get_jwt: unimplemented_svc(),
571            node_log: NodeLogSender::new(tx, node_id, times),
572            flow: FlowServices {
573                signer: unimplemented_svc(),
574                set: FlowSetServices {
575                    http: reqwest::Client::new(),
576                    solana_client,
577                    helius: None,
578                    extensions: <_>::default(),
579                    api_input: unimplemented_svc(),
580                },
581            },
582        }
583    }
584
585    pub fn log(&self, log: NodeLogContent) -> Result<(), mpsc::SendError> {
586        self.node_log.send(log)
587    }
588
589    pub fn flow_inputs(&self) -> &value::Map {
590        &self.data.flow.inputs
591    }
592
593    pub fn new_interflow_origin(&self) -> FlowRunOrigin {
594        FlowRunOrigin::Interflow {
595            flow_run_id: *self.flow_run_id(),
596            node_id: *self.node_id(),
597            times: *self.times(),
598        }
599    }
600
601    pub fn flow_run_id(&self) -> &FlowRunId {
602        &self.data.flow.flow_run_id
603    }
604
605    pub fn node_id(&self) -> &NodeId {
606        &self.data.node_id
607    }
608
609    pub fn times(&self) -> &u32 {
610        &self.data.times
611    }
612
613    pub fn environment(&self) -> &HashMap<String, String> {
614        &self.data.flow.environment
615    }
616
617    pub fn endpoints(&self) -> &Endpoints {
618        &self.data.flow.set.endpoints
619    }
620
621    pub fn flow_owner(&self) -> &User {
622        &self.data.flow.set.flow_owner
623    }
624
625    pub fn started_by(&self) -> &User {
626        &self.data.flow.set.started_by
627    }
628
629    pub fn solana_config(&self) -> &SolanaClientConfig {
630        &self.data.flow.set.solana
631    }
632
633    pub fn solana_client(&self) -> &Arc<SolanaClient> {
634        &self.flow.set.solana_client
635    }
636
637    pub fn http(&self) -> &reqwest::Client {
638        &self.flow.set.http
639    }
640
641    pub async fn api_input(
642        &mut self,
643        timeout: Option<Duration>,
644        webhook_url: Option<String>,
645        webhook_headers: Option<HeaderMap>,
646        extra: Option<serde_json::Map<String, serde_json::Value>>,
647    ) -> Result<api_input::Response, api_input::Error> {
648        let req = api_input::Request {
649            flow_run_id: *self.flow_run_id(),
650            node_id: *self.node_id(),
651            times: *self.times(),
652            timeout: timeout.unwrap_or(Duration::MAX),
653            webhook_url,
654            webhook_headers,
655            extra,
656        };
657        self.flow.set.api_input.ready().await?.call(req).await
658    }
659
660    /// Call [`get_jwt`] service, the result will have `Bearer ` prefix.
661    pub async fn get_jwt_header(&mut self) -> Result<String, get_jwt::Error> {
662        let user_id = self.flow_owner().id;
663        let access_token = self
664            .get_jwt
665            .ready()
666            .await?
667            .call(get_jwt::Request { user_id })
668            .await?
669            .access_token;
670        Ok(["Bearer ", &access_token].concat())
671    }
672
673    /// Call [`execute`] service.
674    pub async fn execute(
675        &mut self,
676        instructions: Instructions,
677        output: value::Map,
678    ) -> Result<execute::Response, execute::Error> {
679        self.execute
680            .ready()
681            .await?
682            .call(execute::Request {
683                instructions,
684                output,
685            })
686            .await
687    }
688
689    /// Call [`signer`] service.
690    pub async fn request_signature(
691        &mut self,
692        pubkey: Pubkey,
693        token: Option<String>,
694        message: Bytes,
695        timeout: Duration,
696    ) -> Result<signer::SignatureResponse, signer::Error> {
697        self.flow
698            .signer
699            .ready()
700            .await?
701            .call(signer::SignatureRequest {
702                id: None,
703                time: Utc::now(),
704                pubkey,
705                token,
706                message,
707                timeout,
708                flow_run_id: Some(self.data.flow.flow_run_id),
709                signatures: None,
710            })
711            .await
712    }
713
714    /// Get an extension by type.
715    pub fn get<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
716        self.flow.set.extensions.get::<T>()
717    }
718
719    pub fn extensions_mut(&mut self) -> Option<&mut Extensions> {
720        Arc::get_mut(&mut self.flow.set.extensions)
721    }
722
723    pub fn raw(&self) -> RawContext<'_> {
724        RawContext {
725            data: &self.data,
726            services: RawServices {
727                signer: &self.flow.signer,
728                execute: &self.execute,
729                get_jwt: &self.get_jwt,
730            },
731        }
732    }
733}
734
/// Borrowed service clients of a [`CommandContext`], produced by
/// [`CommandContext::raw`].
pub struct RawServices<'a> {
    /// Client for the [`signer`] service.
    pub signer: &'a signer::Svc,
    /// Client for the [`execute`] service.
    pub execute: &'a execute::Svc,
    /// Client for the [`get_jwt`] service.
    pub get_jwt: &'a get_jwt::Svc,
}
740
/// Borrowed view of a [`CommandContext`]: its data plus raw service clients.
pub struct RawContext<'a> {
    /// The context's serializable data.
    pub data: &'a CommandContextData,
    /// The context's service clients.
    pub services: RawServices<'a>,
}
745
746impl Default for CommandContext {
747    fn default() -> Self {
748        Self::test_context()
749    }
750}
751
/// A user, identified by its [`UserId`].
#[derive(Clone, Copy, Serialize, Deserialize, Debug, JsonSchema)]
pub struct User {
    /// The user's id.
    pub id: UserId,
}
756
757impl User {
758    pub fn new(id: UserId) -> Self {
759        Self { id }
760    }
761}
762
763impl Default for User {
764    /// For testing
765    fn default() -> Self {
766        User {
767            id: uuid::Uuid::nil(),
768        }
769    }
770}