flow_lib/
context.rs

//! Provides services and information for nodes to use.
//!
//! Services are abstracted with the [`tower::Service`] trait and wrapped in our
//! [`TowerClient`][crate::utils::TowerClient] utility to make them easier to use.
//!
//! Each service is defined in a separate module:
//! - [`api_input`]
//! - [`get_jwt`]
//! - [`execute`]
//! - [`signer`]
10
11use crate::{
12    ContextConfig, FlowRunId, HttpClientConfig, NodeId, SolanaClientConfig, UserId, ValueSet,
13    config::{Endpoints, client::FlowRunOrigin},
14    flow_run_events::{self, Event, NodeLogContent, NodeLogSender},
15    solana::Instructions,
16    utils::{Extensions, tower_client::unimplemented_svc},
17};
18use bytes::Bytes;
19use chrono::Utc;
20use futures::channel::mpsc;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use solana_pubkey::Pubkey;
24use solana_rpc_client::nonblocking::rpc_client::RpcClient as SolanaClient;
25use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
26use tower::{Service, ServiceExt};
27
/// Well-known environment variable names.
///
/// Only the names live here; reading and interpreting the values happens
/// elsewhere (presumably via `FlowContextData::environment` — confirm).
pub mod env {
    // Logging.
    pub const RUST_LOG: &str = "RUST_LOG";

    // Transaction construction and submission settings.
    pub const OVERWRITE_FEEPAYER: &str = "OVERWRITE_FEEPAYER";
    pub const COMPUTE_BUDGET: &str = "COMPUTE_BUDGET";
    pub const FALLBACK_COMPUTE_BUDGET: &str = "FALLBACK_COMPUTE_BUDGET";
    pub const PRIORITY_FEE: &str = "PRIORITY_FEE";
    pub const TX_COMMITMENT_LEVEL: &str = "TX_COMMITMENT_LEVEL";
    pub const WAIT_COMMITMENT_LEVEL: &str = "WAIT_COMMITMENT_LEVEL";
    pub const EXECUTE_ON: &str = "EXECUTE_ON";

    // Address lookup tables, one variable per cluster.
    pub const DEVNET_LOOKUP_TABLE: &str = "DEVNET_LOOKUP_TABLE";
    pub const MAINNET_LOOKUP_TABLE: &str = "MAINNET_LOOKUP_TABLE";
}
40
41pub mod api_input {
42    use std::time::Duration;
43
44    use crate::{
45        FlowRunId, NodeId,
46        utils::{TowerClient, tower_client::CommonError},
47    };
48    use thiserror::Error as ThisError;
49    use value::Value;
50
51    pub struct Request {
52        pub flow_run_id: FlowRunId,
53        pub node_id: NodeId,
54        pub times: u32,
55        pub timeout: Duration,
56    }
57
58    pub struct Response {
59        pub value: Value,
60    }
61
62    #[derive(ThisError, Debug, Clone)]
63    pub enum Error {
64        #[error("canceled by user")]
65        Canceled,
66        #[error("timeout")]
67        Timeout,
68        #[error(transparent)]
69        Common(#[from] CommonError),
70    }
71
72    pub type Svc = TowerClient<Request, Response, Error>;
73}
74
75/// Get user's JWT, require
76/// [`user_token`][crate::config::node::Permissions::user_tokens] permission.
77pub mod get_jwt {
78    use crate::{UserId, utils::TowerClient, utils::tower_client::CommonError};
79    use std::future::Ready;
80    use thiserror::Error as ThisError;
81
82    #[derive(Clone, Copy)]
83    pub struct Request {
84        pub user_id: UserId,
85    }
86
87    #[derive(Clone, Debug)]
88    pub struct Response {
89        pub access_token: String,
90    }
91
92    #[derive(ThisError, Debug, Clone)]
93    pub enum Error {
94        #[error("not allowed")]
95        NotAllowed,
96        #[error("user not found")]
97        UserNotFound,
98        #[error("wrong recipient")]
99        WrongRecipient,
100        #[error("{}: {}", error, error_description)]
101        Supabase {
102            error: String,
103            error_description: String,
104        },
105        #[error(transparent)]
106        Common(#[from] CommonError),
107    }
108
109    impl From<actix::MailboxError> for Error {
110        fn from(value: actix::MailboxError) -> Self {
111            CommonError::from(value).into()
112        }
113    }
114
115    impl actix::Message for Request {
116        type Result = Result<Response, Error>;
117    }
118
119    pub type Svc = TowerClient<Request, Response, Error>;
120
121    pub fn not_allowed() -> Svc {
122        Svc::new(tower::service_fn(|_| {
123            std::future::ready(Result::<Response, _>::Err(Error::NotAllowed))
124        }))
125    }
126
127    #[derive(Clone, Copy, Debug)]
128    pub struct RetryPolicy(pub usize);
129
130    impl Default for RetryPolicy {
131        fn default() -> Self {
132            Self(1)
133        }
134    }
135
136    impl tower::retry::Policy<Request, Response, Error> for RetryPolicy {
137        type Future = Ready<()>;
138
139        fn retry(
140            &mut self,
141            _: &mut Request,
142            result: &mut Result<Response, Error>,
143        ) -> Option<Self::Future> {
144            match result {
145                Err(Error::Supabase {
146                    error_description, ..
147                }) if error_description.contains("Refresh Token") && self.0 > 0 => {
148                    tracing::error!("get_jwt error: {}, retrying", error_description);
149                    self.0 -= 1;
150                    Some(std::future::ready(()))
151                }
152                _ => None,
153            }
154        }
155
156        fn clone_request(&mut self, req: &Request) -> Option<Request> {
157            Some(*req)
158        }
159    }
160}
161
162/// Request Solana signature from external wallets.
163pub mod signer {
164    use crate::{
165        FlowRunId,
166        utils::{TowerClient, tower_client::CommonError},
167    };
168    use actix::MailboxError;
169    use chrono::{DateTime, Utc};
170    use serde::{Deserialize, Serialize};
171    use serde_with::{DisplayFromStr, DurationSecondsWithFrac, base64::Base64, serde_as};
172    use solana_presigner::Presigner as SdkPresigner;
173    use solana_pubkey::Pubkey;
174    use solana_signature::Signature;
175    use std::time::Duration;
176    use thiserror::Error as ThisError;
177
178    #[derive(ThisError, Debug)]
179    pub enum Error {
180        #[error("can't sign for pubkey: {}", .0)]
181        Pubkey(String),
182        #[error("can't sign for this user")]
183        User,
184        #[error("timeout")]
185        Timeout,
186        #[error(transparent)]
187        Common(#[from] CommonError),
188    }
189
190    impl From<MailboxError> for Error {
191        fn from(value: MailboxError) -> Self {
192            CommonError::from(value).into()
193        }
194    }
195
196    pub type Svc = TowerClient<SignatureRequest, SignatureResponse, Error>;
197
198    #[serde_as]
199    #[derive(Debug, Clone, Serialize, Deserialize)]
200    pub struct Presigner {
201        #[serde_as(as = "DisplayFromStr")]
202        pub pubkey: Pubkey,
203        #[serde_as(as = "DisplayFromStr")]
204        pub signature: Signature,
205    }
206
207    impl From<Presigner> for SdkPresigner {
208        fn from(value: Presigner) -> Self {
209            SdkPresigner::new(&value.pubkey, &value.signature)
210        }
211    }
212
213    #[serde_as]
214    #[derive(Debug, Clone, Serialize, Deserialize)]
215    pub struct SignatureRequest {
216        pub id: Option<i64>,
217        #[serde(with = "chrono::serde::ts_milliseconds")]
218        pub time: DateTime<Utc>,
219        #[serde_as(as = "DisplayFromStr")]
220        pub pubkey: Pubkey,
221        #[serde_as(as = "Base64")]
222        pub message: bytes::Bytes,
223        #[serde_as(as = "DurationSecondsWithFrac<f64>")]
224        pub timeout: Duration,
225        pub flow_run_id: Option<FlowRunId>,
226        pub signatures: Option<Vec<Presigner>>,
227    }
228
229    impl actix::Message for SignatureRequest {
230        type Result = Result<SignatureResponse, Error>;
231    }
232
233    #[serde_as]
234    #[derive(Debug, Clone, Serialize, Deserialize)]
235    pub struct SignatureResponse {
236        #[serde_as(as = "DisplayFromStr")]
237        pub signature: Signature,
238        #[serde_as(as = "Option<Base64>")]
239        pub new_message: Option<bytes::Bytes>,
240    }
241
242    impl bincode::Encode for SignatureResponse {
243        fn encode<E: bincode::enc::Encoder>(
244            &self,
245            encoder: &mut E,
246        ) -> Result<(), bincode::error::EncodeError> {
247            self.signature.as_array().encode(encoder)?;
248            self.new_message
249                .as_ref()
250                .map(|m| m.as_ref())
251                .encode(encoder)?;
252            Ok(())
253        }
254    }
255
256    impl<C> bincode::Decode<C> for SignatureResponse {
257        fn decode<D: bincode::de::Decoder<Context = C>>(
258            decoder: &mut D,
259        ) -> Result<Self, bincode::error::DecodeError> {
260            let signature = Signature::from(<[u8; 64]>::decode(decoder)?);
261            let new_message = Option::<Vec<u8>>::decode(decoder)?.map(Into::into);
262            Ok(Self {
263                signature,
264                new_message,
265            })
266        }
267    }
268
269    impl<'de, C> bincode::BorrowDecode<'de, C> for SignatureResponse {
270        fn borrow_decode<D: bincode::de::BorrowDecoder<'de, Context = C>>(
271            decoder: &mut D,
272        ) -> Result<Self, bincode::error::DecodeError> {
273            bincode::Decode::decode(decoder)
274        }
275    }
276}
277
278/// Output values and Solana instructions to be executed.
279pub mod execute {
280    use crate::{
281        solana::Instructions,
282        utils::{
283            TowerClient,
284            tower_client::{CommonError, CommonErrorExt},
285        },
286    };
287    use futures::channel::oneshot::Canceled;
288    use serde::{Deserialize, Serialize};
289    use serde_with::{DisplayFromStr, base64::Base64, serde_as};
290    use solana_instruction_error::InstructionError;
291    use solana_message::CompileError;
292    use solana_sanitize::SanitizeError;
293
294    use solana_rpc_client_api::client_error::Error as ClientError;
295    use solana_signature::Signature;
296    use solana_signer::SignerError;
297    use std::sync::Arc;
298    use thiserror::Error as ThisError;
299
300    use super::signer;
301
302    pub type Svc = TowerClient<Request, Response, Error>;
303
304    #[derive(Deserialize)]
305    #[serde(try_from = "RequestRepr")]
306    pub struct Request {
307        pub instructions: Instructions,
308        pub output: value::Map,
309    }
310
311    impl bincode::Encode for Request {
312        fn encode<E: bincode::enc::Encoder>(
313            &self,
314            encoder: &mut E,
315        ) -> Result<(), bincode::error::EncodeError> {
316            self.instructions.encode(encoder)?;
317            value::bincode_impl::MapBincode::from(&self.output).encode(encoder)?;
318            Ok(())
319        }
320    }
321
322    impl<C> bincode::Decode<C> for Request {
323        fn decode<D: bincode::de::Decoder<Context = C>>(
324            decoder: &mut D,
325        ) -> Result<Self, bincode::error::DecodeError> {
326            Ok(Self {
327                instructions: Instructions::decode(decoder)?,
328                output: value::bincode_impl::MapBincode::decode(decoder)?
329                    .0
330                    .into_owned(),
331            })
332        }
333    }
334
335    #[serde_as]
336    #[derive(Deserialize)]
337    struct RequestRepr {
338        #[serde_as(as = "Base64")]
339        instructions: Vec<u8>,
340        output: value::Map,
341    }
342
343    impl TryFrom<RequestRepr> for Request {
344        type Error = rmp_serde::decode::Error;
345        fn try_from(value: RequestRepr) -> Result<Self, Self::Error> {
346            Ok(Self {
347                instructions: rmp_serde::from_slice(&value.instructions)?,
348                output: value.output,
349            })
350        }
351    }
352
353    #[serde_as]
354    #[derive(Serialize, Clone, Copy)]
355    pub struct Response {
356        #[serde_as(as = "Option<DisplayFromStr>")]
357        pub signature: Option<Signature>,
358    }
359
360    impl bincode::Encode for Response {
361        fn encode<E: bincode::enc::Encoder>(
362            &self,
363            encoder: &mut E,
364        ) -> Result<(), bincode::error::EncodeError> {
365            self.signature.map(|s| *s.as_array()).encode(encoder)?;
366            Ok(())
367        }
368    }
369
370    impl<C> bincode::Decode<C> for Response {
371        fn decode<D: bincode::de::Decoder<Context = C>>(
372            decoder: &mut D,
373        ) -> Result<Self, bincode::error::DecodeError> {
374            let value = Option::<[u8; 64]>::decode(decoder)?;
375            Ok(Self {
376                signature: value.map(Signature::from),
377            })
378        }
379    }
380
381    fn unwrap(s: &Option<String>) -> &str {
382        s.as_ref().map(|v| v.as_str()).unwrap_or_default()
383    }
384
385    #[serde_as]
386    #[derive(ThisError, Debug, Clone, Serialize, Deserialize)]
387    pub enum Error {
388        #[error("canceled {}", unwrap(.0))]
389        Canceled(Option<String>),
390        #[error("collected")]
391        Collected,
392        #[error("some node failed to provide instructions")]
393        TxIncomplete,
394        #[error("time out")]
395        Timeout,
396        #[error("insufficient solana balance, needed={needed}; have={balance};")]
397        InsufficientSolanaBalance { needed: u64, balance: u64 },
398        #[error("transaction simulation failed")]
399        TxSimFailed,
400        #[error("{}", crate::utils::verbose_solana_error(.error))]
401        Solana {
402            #[source]
403            #[serde_as(as = "Arc<crate::errors::AsClientError>")]
404            error: Arc<ClientError>,
405            inserted: usize,
406        },
407        #[error(transparent)]
408        Signer(
409            #[from]
410            #[serde_as(as = "Arc<crate::errors::AsSignerError>")]
411            Arc<SignerError>,
412        ),
413        #[error(transparent)]
414        CompileError(
415            #[from]
416            #[serde_as(as = "Arc<crate::errors::AsCompileError>")]
417            Arc<CompileError>,
418        ),
419        #[error(transparent)]
420        InstructionError(#[from] Arc<InstructionError>),
421        #[error(transparent)]
422        SanitizeError(
423            #[from]
424            #[serde_as(as = "Arc<crate::errors::AsSanitizeError>")]
425            Arc<SanitizeError>,
426        ),
427        #[error(transparent)]
428        ChannelClosed(
429            #[from]
430            #[serde_as(as = "crate::errors::AsCancelled")]
431            Canceled,
432        ),
433        #[error(transparent)]
434        Common(#[from] CommonError),
435    }
436
437    impl From<actix::MailboxError> for Error {
438        fn from(value: actix::MailboxError) -> Self {
439            CommonError::from(value).into()
440        }
441    }
442
443    impl Error {
444        pub fn solana(error: ClientError, inserted: usize) -> Self {
445            Self::Solana {
446                error: Arc::new(error),
447                inserted,
448            }
449        }
450    }
451
452    impl From<signer::Error> for Error {
453        fn from(value: signer::Error) -> Self {
454            match value {
455                e @ signer::Error::Pubkey(_) => Self::other(e),
456                e @ signer::Error::User => Self::other(e),
457                signer::Error::Timeout => Self::Timeout,
458                signer::Error::Common(error) => Self::Common(error),
459            }
460        }
461    }
462
463    impl From<SignerError> for Error {
464        fn from(value: SignerError) -> Self {
465            Error::Signer(Arc::new(value))
466        }
467    }
468
469    impl From<CompileError> for Error {
470        fn from(value: CompileError) -> Self {
471            Error::CompileError(Arc::new(value))
472        }
473    }
474
475    impl From<InstructionError> for Error {
476        fn from(value: InstructionError) -> Self {
477            Error::InstructionError(Arc::new(value))
478        }
479    }
480
481    impl From<SanitizeError> for Error {
482        fn from(value: SanitizeError) -> Self {
483            Error::SanitizeError(Arc::new(value))
484        }
485    }
486}
487
488#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
489pub struct FlowSetContextData {
490    pub flow_owner: User,
491    pub started_by: User,
492    pub endpoints: Endpoints,
493    pub solana: SolanaClientConfig,
494    pub http: HttpClientConfig,
495}
496
497#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
498pub struct FlowContextData {
499    pub flow_run_id: FlowRunId,
500    pub environment: HashMap<String, String>,
501    pub inputs: ValueSet,
502    pub set: FlowSetContextData,
503}
504
505#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
506pub struct CommandContextData {
507    pub node_id: NodeId,
508    pub times: u32,
509    pub flow: FlowContextData,
510}
511
512#[derive(Clone, bon::Builder)]
513pub struct FlowSetServices {
514    pub http: reqwest::Client,
515    pub solana_client: Arc<SolanaClient>,
516    pub extensions: Arc<Extensions>,
517    pub api_input: api_input::Svc,
518}
519
520#[derive(Clone, bon::Builder)]
521pub struct FlowServices {
522    pub signer: signer::Svc,
523    pub set: FlowSetServices,
524}
525
526#[derive(Clone, bon::Builder)]
527pub struct CommandContext {
528    data: CommandContextData,
529    execute: execute::Svc,
530    get_jwt: get_jwt::Svc,
531    node_log: NodeLogSender,
532    flow: FlowServices,
533}
534
535impl CommandContext {
536    pub fn test_context() -> Self {
537        let config = ContextConfig::default();
538        let solana_client = Arc::new(config.solana_client.build_client(None));
539        let node_id = NodeId::nil();
540        let times = 0;
541        let (tx, _) = flow_run_events::channel();
542        Self {
543            data: CommandContextData {
544                node_id,
545                times,
546                flow: FlowContextData {
547                    flow_run_id: FlowRunId::nil(),
548                    environment: HashMap::new(),
549                    inputs: ValueSet::default(),
550                    set: FlowSetContextData {
551                        flow_owner: User::default(),
552                        started_by: User::default(),
553                        endpoints: Endpoints::default(),
554                        solana: config.solana_client,
555                        http: config.http_client,
556                    },
557                },
558            },
559            execute: unimplemented_svc(),
560            get_jwt: unimplemented_svc(),
561            node_log: NodeLogSender::new(tx, node_id, times),
562            flow: FlowServices {
563                signer: unimplemented_svc(),
564                set: FlowSetServices {
565                    http: reqwest::Client::new(),
566                    solana_client,
567                    extensions: <_>::default(),
568                    api_input: unimplemented_svc(),
569                },
570            },
571        }
572    }
573
574    pub fn log(&self, log: NodeLogContent) -> Result<(), mpsc::TrySendError<Event>> {
575        self.node_log.send(log)
576    }
577
578    pub fn flow_inputs(&self) -> &value::Map {
579        &self.data.flow.inputs
580    }
581
582    pub fn new_interflow_origin(&self) -> FlowRunOrigin {
583        FlowRunOrigin::Interflow {
584            flow_run_id: *self.flow_run_id(),
585            node_id: *self.node_id(),
586            times: *self.times(),
587        }
588    }
589
590    pub fn flow_run_id(&self) -> &FlowRunId {
591        &self.data.flow.flow_run_id
592    }
593
594    pub fn node_id(&self) -> &NodeId {
595        &self.data.node_id
596    }
597
598    pub fn times(&self) -> &u32 {
599        &self.data.times
600    }
601
602    pub fn environment(&self) -> &HashMap<String, String> {
603        &self.data.flow.environment
604    }
605
606    pub fn endpoints(&self) -> &Endpoints {
607        &self.data.flow.set.endpoints
608    }
609
610    pub fn flow_owner(&self) -> &User {
611        &self.data.flow.set.flow_owner
612    }
613
614    pub fn started_by(&self) -> &User {
615        &self.data.flow.set.started_by
616    }
617
618    pub fn solana_config(&self) -> &SolanaClientConfig {
619        &self.data.flow.set.solana
620    }
621
622    pub fn solana_client(&self) -> &Arc<SolanaClient> {
623        &self.flow.set.solana_client
624    }
625
626    pub fn http(&self) -> &reqwest::Client {
627        &self.flow.set.http
628    }
629
630    pub async fn api_input(
631        &mut self,
632        timeout: Option<Duration>,
633    ) -> Result<api_input::Response, api_input::Error> {
634        let req = api_input::Request {
635            flow_run_id: *self.flow_run_id(),
636            node_id: *self.node_id(),
637            times: *self.times(),
638            timeout: timeout.unwrap_or(Duration::MAX),
639        };
640        self.flow.set.api_input.ready().await?.call(req).await
641    }
642
643    /// Call [`get_jwt`] service, the result will have `Bearer ` prefix.
644    pub async fn get_jwt_header(&mut self) -> Result<String, get_jwt::Error> {
645        let user_id = self.flow_owner().id;
646        let access_token = self
647            .get_jwt
648            .ready()
649            .await?
650            .call(get_jwt::Request { user_id })
651            .await?
652            .access_token;
653        Ok(["Bearer ", &access_token].concat())
654    }
655
656    /// Call [`execute`] service.
657    pub async fn execute(
658        &mut self,
659        instructions: Instructions,
660        output: value::Map,
661    ) -> Result<execute::Response, execute::Error> {
662        self.execute
663            .ready()
664            .await?
665            .call(execute::Request {
666                instructions,
667                output,
668            })
669            .await
670    }
671
672    /// Call [`signer`] service.
673    pub async fn request_signature(
674        &mut self,
675        pubkey: Pubkey,
676        message: Bytes,
677        timeout: Duration,
678    ) -> Result<signer::SignatureResponse, signer::Error> {
679        self.flow
680            .signer
681            .ready()
682            .await?
683            .call(signer::SignatureRequest {
684                id: None,
685                time: Utc::now(),
686                pubkey,
687                message,
688                timeout,
689                flow_run_id: Some(self.data.flow.flow_run_id),
690                signatures: None,
691            })
692            .await
693    }
694
695    /// Get an extension by type.
696    pub fn get<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
697        self.flow.set.extensions.get::<T>()
698    }
699
700    pub fn extensions_mut(&mut self) -> Option<&mut Extensions> {
701        Arc::get_mut(&mut self.flow.set.extensions)
702    }
703
704    pub fn raw(&self) -> RawContext<'_> {
705        RawContext {
706            data: &self.data,
707            services: RawServices {
708                signer: &self.flow.signer,
709                execute: &self.execute,
710                get_jwt: &self.get_jwt,
711            },
712        }
713    }
714}
715
716pub struct RawServices<'a> {
717    pub signer: &'a signer::Svc,
718    pub execute: &'a execute::Svc,
719    pub get_jwt: &'a get_jwt::Svc,
720}
721
722pub struct RawContext<'a> {
723    pub data: &'a CommandContextData,
724    pub services: RawServices<'a>,
725}
726
727impl Default for CommandContext {
728    fn default() -> Self {
729        Self::test_context()
730    }
731}
732
733#[derive(Clone, Copy, Serialize, Deserialize, Debug, JsonSchema)]
734pub struct User {
735    pub id: UserId,
736}
737
738impl User {
739    pub fn new(id: UserId) -> Self {
740        Self { id }
741    }
742}
743
744impl Default for User {
745    /// For testing
746    fn default() -> Self {
747        User {
748            id: uuid::Uuid::nil(),
749        }
750    }
751}