flow_lib/
context.rs

1//! Providing services and information for nodes to use.
2//!
//! Services are abstracted with the [`tower::Service`] trait, wrapped in our
//! [`TowerClient`][crate::utils::TowerClient] utility to make them easier to use.
//!
//! Each service is defined in a separate module:
//! - [`get_jwt`]
//! - [`execute`]
//! - [`signer`]
//! - [`api_input`]
10
11use crate::{
12    ContextConfig, FlowRunId, HttpClientConfig, NodeId, SolanaClientConfig, UserId, ValueSet,
13    config::{Endpoints, client::FlowRunOrigin},
14    flow_run_events::{self, Event, NodeLogContent, NodeLogSender},
15    solana::Instructions,
16    utils::{Extensions, tower_client::unimplemented_svc},
17};
18use bytes::Bytes;
19use chrono::Utc;
20use futures::channel::mpsc;
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use solana_pubkey::Pubkey;
24use solana_rpc_client::nonblocking::rpc_client::RpcClient as SolanaClient;
25use std::{any::Any, collections::HashMap, sync::Arc, time::Duration};
26use tower::{Service, ServiceExt};
27
/// Names of environment variables, exposed as string constants so they are
/// spelled consistently wherever they are read.
pub mod env {
    // Logging.
    pub const RUST_LOG: &'static str = "RUST_LOG";

    // Fee payer, compute budget and priority fee.
    pub const OVERWRITE_FEEPAYER: &'static str = "OVERWRITE_FEEPAYER";
    pub const COMPUTE_BUDGET: &'static str = "COMPUTE_BUDGET";
    pub const FALLBACK_COMPUTE_BUDGET: &'static str = "FALLBACK_COMPUTE_BUDGET";
    pub const PRIORITY_FEE: &'static str = "PRIORITY_FEE";

    // Commitment levels.
    pub const SIMULATION_COMMITMENT_LEVEL: &'static str = "SIMULATION_COMMITMENT_LEVEL";
    pub const TX_COMMITMENT_LEVEL: &'static str = "TX_COMMITMENT_LEVEL";
    pub const WAIT_COMMITMENT_LEVEL: &'static str = "WAIT_COMMITMENT_LEVEL";

    // Execution target.
    pub const EXECUTE_ON: &'static str = "EXECUTE_ON";

    // Address lookup tables.
    pub const DEVNET_LOOKUP_TABLE: &'static str = "DEVNET_LOOKUP_TABLE";
    pub const MAINNET_LOOKUP_TABLE: &'static str = "MAINNET_LOOKUP_TABLE";
}
41
42pub mod api_input {
43    use std::time::Duration;
44
45    use crate::{
46        FlowRunId, NodeId,
47        utils::{TowerClient, tower_client::CommonError},
48    };
49    use thiserror::Error as ThisError;
50    use value::Value;
51
52    pub struct Request {
53        pub flow_run_id: FlowRunId,
54        pub node_id: NodeId,
55        pub times: u32,
56        pub timeout: Duration,
57    }
58
59    pub struct Response {
60        pub value: Value,
61    }
62
63    #[derive(ThisError, Debug, Clone)]
64    pub enum Error {
65        #[error("canceled by user")]
66        Canceled,
67        #[error("timeout")]
68        Timeout,
69        #[error(transparent)]
70        Common(#[from] CommonError),
71    }
72
73    pub type Svc = TowerClient<Request, Response, Error>;
74}
75
76/// Get user's JWT, require
77/// [`user_token`][crate::config::node::Permissions::user_tokens] permission.
78pub mod get_jwt {
79    use crate::{UserId, utils::TowerClient, utils::tower_client::CommonError};
80    use std::future::Ready;
81    use thiserror::Error as ThisError;
82
83    #[derive(Clone, Copy)]
84    pub struct Request {
85        pub user_id: UserId,
86    }
87
88    #[derive(Clone, Debug)]
89    pub struct Response {
90        pub access_token: String,
91    }
92
93    #[derive(ThisError, Debug, Clone)]
94    pub enum Error {
95        #[error("not allowed")]
96        NotAllowed,
97        #[error("user not found")]
98        UserNotFound,
99        #[error("wrong recipient")]
100        WrongRecipient,
101        #[error("{}: {}", error, error_description)]
102        Supabase {
103            error: String,
104            error_description: String,
105        },
106        #[error(transparent)]
107        Common(#[from] CommonError),
108    }
109
110    impl From<actix::MailboxError> for Error {
111        fn from(value: actix::MailboxError) -> Self {
112            CommonError::from(value).into()
113        }
114    }
115
116    impl actix::Message for Request {
117        type Result = Result<Response, Error>;
118    }
119
120    pub type Svc = TowerClient<Request, Response, Error>;
121
122    pub fn not_allowed() -> Svc {
123        Svc::new(tower::service_fn(|_| {
124            std::future::ready(Result::<Response, _>::Err(Error::NotAllowed))
125        }))
126    }
127
128    #[derive(Clone, Copy, Debug)]
129    pub struct RetryPolicy(pub usize);
130
131    impl Default for RetryPolicy {
132        fn default() -> Self {
133            Self(1)
134        }
135    }
136
137    impl tower::retry::Policy<Request, Response, Error> for RetryPolicy {
138        type Future = Ready<()>;
139
140        fn retry(
141            &mut self,
142            _: &mut Request,
143            result: &mut Result<Response, Error>,
144        ) -> Option<Self::Future> {
145            match result {
146                Err(Error::Supabase {
147                    error_description, ..
148                }) if error_description.contains("Refresh Token") && self.0 > 0 => {
149                    tracing::error!("get_jwt error: {}, retrying", error_description);
150                    self.0 -= 1;
151                    Some(std::future::ready(()))
152                }
153                _ => None,
154            }
155        }
156
157        fn clone_request(&mut self, req: &Request) -> Option<Request> {
158            Some(*req)
159        }
160    }
161}
162
163/// Request Solana signature from external wallets.
164pub mod signer {
165    use crate::{
166        FlowRunId,
167        utils::{TowerClient, tower_client::CommonError},
168    };
169    use actix::MailboxError;
170    use chrono::{DateTime, Utc};
171    use serde::{Deserialize, Serialize};
172    use serde_with::{DisplayFromStr, DurationSecondsWithFrac, base64::Base64, serde_as};
173    use solana_presigner::Presigner as SdkPresigner;
174    use solana_pubkey::Pubkey;
175    use solana_signature::Signature;
176    use std::time::Duration;
177    use thiserror::Error as ThisError;
178
179    #[derive(ThisError, Debug)]
180    pub enum Error {
181        #[error("can't sign for pubkey: {}", .0)]
182        Pubkey(String),
183        #[error("can't sign for this user")]
184        User,
185        #[error("timeout")]
186        Timeout,
187        #[error(transparent)]
188        Common(#[from] CommonError),
189    }
190
191    impl From<MailboxError> for Error {
192        fn from(value: MailboxError) -> Self {
193            CommonError::from(value).into()
194        }
195    }
196
197    pub type Svc = TowerClient<SignatureRequest, SignatureResponse, Error>;
198
199    #[serde_as]
200    #[derive(Debug, Clone, Serialize, Deserialize)]
201    pub struct Presigner {
202        #[serde_as(as = "DisplayFromStr")]
203        pub pubkey: Pubkey,
204        #[serde_as(as = "DisplayFromStr")]
205        pub signature: Signature,
206    }
207
208    impl From<Presigner> for SdkPresigner {
209        fn from(value: Presigner) -> Self {
210            SdkPresigner::new(&value.pubkey, &value.signature)
211        }
212    }
213
214    #[serde_as]
215    #[derive(Debug, Clone, Serialize, Deserialize)]
216    pub struct SignatureRequest {
217        pub id: Option<i64>,
218        #[serde(with = "chrono::serde::ts_milliseconds")]
219        pub time: DateTime<Utc>,
220        #[serde_as(as = "DisplayFromStr")]
221        pub pubkey: Pubkey,
222        #[serde_as(as = "Base64")]
223        pub message: bytes::Bytes,
224        #[serde_as(as = "DurationSecondsWithFrac<f64>")]
225        pub timeout: Duration,
226        pub flow_run_id: Option<FlowRunId>,
227        pub signatures: Option<Vec<Presigner>>,
228    }
229
230    impl actix::Message for SignatureRequest {
231        type Result = Result<SignatureResponse, Error>;
232    }
233
234    #[serde_as]
235    #[derive(Debug, Clone, Serialize, Deserialize)]
236    pub struct SignatureResponse {
237        #[serde_as(as = "DisplayFromStr")]
238        pub signature: Signature,
239        #[serde_as(as = "Option<Base64>")]
240        pub new_message: Option<bytes::Bytes>,
241    }
242}
243
244/// Output values and Solana instructions to be executed.
245pub mod execute {
246    use crate::{
247        solana::Instructions,
248        utils::{
249            TowerClient,
250            tower_client::{CommonError, CommonErrorExt},
251        },
252    };
253    use futures::channel::oneshot::Canceled;
254    use serde::{Deserialize, Serialize};
255    use serde_with::{DisplayFromStr, base64::Base64, serde_as};
256    use solana_program::{
257        instruction::InstructionError, message::CompileError, sanitize::SanitizeError,
258    };
259
260    use solana_rpc_client_api::client_error::Error as ClientError;
261    use solana_signature::Signature;
262    use solana_signer::SignerError;
263    use std::sync::Arc;
264    use thiserror::Error as ThisError;
265
266    use super::signer;
267
268    pub type Svc = TowerClient<Request, Response, Error>;
269
270    #[derive(Deserialize)]
271    #[serde(try_from = "RequestRepr")]
272    pub struct Request {
273        pub instructions: Instructions,
274        pub output: value::Map,
275    }
276
277    impl bincode::Encode for Request {
278        fn encode<E: bincode::enc::Encoder>(
279            &self,
280            encoder: &mut E,
281        ) -> Result<(), bincode::error::EncodeError> {
282            self.instructions.encode(encoder)?;
283            value::bincode_impl::MapBincode::from(&self.output).encode(encoder)?;
284            Ok(())
285        }
286    }
287
288    impl<C> bincode::Decode<C> for Request {
289        fn decode<D: bincode::de::Decoder<Context = C>>(
290            decoder: &mut D,
291        ) -> Result<Self, bincode::error::DecodeError> {
292            Ok(Self {
293                instructions: Instructions::decode(decoder)?,
294                output: value::bincode_impl::MapBincode::decode(decoder)?
295                    .0
296                    .into_owned(),
297            })
298        }
299    }
300
301    #[serde_as]
302    #[derive(Deserialize)]
303    struct RequestRepr {
304        #[serde_as(as = "Base64")]
305        instructions: Vec<u8>,
306        output: value::Map,
307    }
308
309    impl TryFrom<RequestRepr> for Request {
310        type Error = rmp_serde::decode::Error;
311        fn try_from(value: RequestRepr) -> Result<Self, Self::Error> {
312            Ok(Self {
313                instructions: rmp_serde::from_slice(&value.instructions)?,
314                output: value.output,
315            })
316        }
317    }
318
319    #[serde_as]
320    #[derive(Serialize, Clone, Copy)]
321    pub struct Response {
322        #[serde_as(as = "Option<DisplayFromStr>")]
323        pub signature: Option<Signature>,
324    }
325
326    impl bincode::Encode for Response {
327        fn encode<E: bincode::enc::Encoder>(
328            &self,
329            encoder: &mut E,
330        ) -> Result<(), bincode::error::EncodeError> {
331            self.signature.map(|s| *s.as_array()).encode(encoder)?;
332            Ok(())
333        }
334    }
335
336    impl<C> bincode::Decode<C> for Response {
337        fn decode<D: bincode::de::Decoder<Context = C>>(
338            decoder: &mut D,
339        ) -> Result<Self, bincode::error::DecodeError> {
340            let value = Option::<[u8; 64]>::decode(decoder)?;
341            Ok(Self {
342                signature: value.map(Signature::from),
343            })
344        }
345    }
346
347    fn unwrap(s: &Option<String>) -> &str {
348        s.as_ref().map(|v| v.as_str()).unwrap_or_default()
349    }
350
351    #[derive(ThisError, Debug, Clone)]
352    pub enum Error {
353        #[error("canceled {}", unwrap(.0))]
354        Canceled(Option<String>),
355        #[error("some node failed to provide instructions")]
356        TxIncomplete,
357        #[error("time out")]
358        Timeout,
359        #[error("insufficient solana balance, needed={needed}; have={balance};")]
360        InsufficientSolanaBalance { needed: u64, balance: u64 },
361        #[error("transaction simulation failed")]
362        TxSimFailed,
363        #[error("{}", crate::utils::verbose_solana_error(.error))]
364        Solana {
365            #[source]
366            error: Arc<ClientError>,
367            inserted: usize,
368        },
369        #[error(transparent)]
370        Signer(#[from] Arc<SignerError>),
371        #[error(transparent)]
372        CompileError(#[from] Arc<CompileError>),
373        #[error(transparent)]
374        InstructionError(#[from] Arc<InstructionError>),
375        #[error(transparent)]
376        SanitizeError(#[from] Arc<SanitizeError>),
377        #[error(transparent)]
378        ChannelClosed(#[from] Canceled),
379        #[error(transparent)]
380        Common(#[from] CommonError),
381    }
382
383    impl From<actix::MailboxError> for Error {
384        fn from(value: actix::MailboxError) -> Self {
385            CommonError::from(value).into()
386        }
387    }
388
389    impl Error {
390        pub fn solana(error: ClientError, inserted: usize) -> Self {
391            Self::Solana {
392                error: Arc::new(error),
393                inserted,
394            }
395        }
396    }
397
398    impl From<signer::Error> for Error {
399        fn from(value: signer::Error) -> Self {
400            match value {
401                e @ signer::Error::Pubkey(_) => Self::other(e),
402                e @ signer::Error::User => Self::other(e),
403                signer::Error::Timeout => Self::Timeout,
404                signer::Error::Common(error) => Self::Common(error),
405            }
406        }
407    }
408
409    impl From<SignerError> for Error {
410        fn from(value: SignerError) -> Self {
411            Error::Signer(Arc::new(value))
412        }
413    }
414
415    impl From<CompileError> for Error {
416        fn from(value: CompileError) -> Self {
417            Error::CompileError(Arc::new(value))
418        }
419    }
420
421    impl From<InstructionError> for Error {
422        fn from(value: InstructionError) -> Self {
423            Error::InstructionError(Arc::new(value))
424        }
425    }
426
427    impl From<SanitizeError> for Error {
428        fn from(value: SanitizeError) -> Self {
429            Error::SanitizeError(Arc::new(value))
430        }
431    }
432}
433
434#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
435pub struct FlowSetContextData {
436    pub flow_owner: User,
437    pub started_by: User,
438    pub endpoints: Endpoints,
439    pub solana: SolanaClientConfig,
440    pub http: HttpClientConfig,
441}
442
443#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
444pub struct FlowContextData {
445    pub flow_run_id: FlowRunId,
446    pub environment: HashMap<String, String>,
447    pub inputs: ValueSet,
448    pub set: FlowSetContextData,
449}
450
451#[derive(Serialize, Deserialize, Debug, Clone, JsonSchema)]
452pub struct CommandContextData {
453    pub node_id: NodeId,
454    pub times: u32,
455    pub flow: FlowContextData,
456}
457
458#[derive(Clone, bon::Builder)]
459pub struct FlowSetServices {
460    pub http: reqwest::Client,
461    pub solana_client: Arc<SolanaClient>,
462    pub extensions: Arc<Extensions>,
463    pub api_input: api_input::Svc,
464}
465
466#[derive(Clone, bon::Builder)]
467pub struct FlowServices {
468    pub signer: signer::Svc,
469    pub set: FlowSetServices,
470}
471
472#[derive(Clone, bon::Builder)]
473pub struct CommandContext {
474    data: CommandContextData,
475    execute: execute::Svc,
476    get_jwt: get_jwt::Svc,
477    node_log: NodeLogSender,
478    flow: FlowServices,
479}
480
481impl CommandContext {
482    pub fn test_context() -> Self {
483        let config = ContextConfig::default();
484        let solana_client = Arc::new(config.solana_client.build_client(None));
485        let node_id = NodeId::nil();
486        let times = 0;
487        let (tx, _) = flow_run_events::channel();
488        Self {
489            data: CommandContextData {
490                node_id,
491                times,
492                flow: FlowContextData {
493                    flow_run_id: FlowRunId::nil(),
494                    environment: HashMap::new(),
495                    inputs: ValueSet::default(),
496                    set: FlowSetContextData {
497                        flow_owner: User::default(),
498                        started_by: User::default(),
499                        endpoints: Endpoints::default(),
500                        solana: config.solana_client,
501                        http: config.http_client,
502                    },
503                },
504            },
505            execute: unimplemented_svc(),
506            get_jwt: unimplemented_svc(),
507            node_log: NodeLogSender::new(tx, node_id, times),
508            flow: FlowServices {
509                signer: unimplemented_svc(),
510                set: FlowSetServices {
511                    http: reqwest::Client::new(),
512                    solana_client,
513                    extensions: <_>::default(),
514                    api_input: unimplemented_svc(),
515                },
516            },
517        }
518    }
519
520    pub fn log(&self, log: NodeLogContent) -> Result<(), mpsc::TrySendError<Event>> {
521        self.node_log.send(log)
522    }
523
524    pub fn flow_inputs(&self) -> &value::Map {
525        &self.data.flow.inputs
526    }
527
528    pub fn new_interflow_origin(&self) -> FlowRunOrigin {
529        FlowRunOrigin::Interflow {
530            flow_run_id: *self.flow_run_id(),
531            node_id: *self.node_id(),
532            times: *self.times(),
533        }
534    }
535
536    pub fn flow_run_id(&self) -> &FlowRunId {
537        &self.data.flow.flow_run_id
538    }
539
540    pub fn node_id(&self) -> &NodeId {
541        &self.data.node_id
542    }
543
544    pub fn times(&self) -> &u32 {
545        &self.data.times
546    }
547
548    pub fn environment(&self) -> &HashMap<String, String> {
549        &self.data.flow.environment
550    }
551
552    pub fn endpoints(&self) -> &Endpoints {
553        &self.data.flow.set.endpoints
554    }
555
556    pub fn flow_owner(&self) -> &User {
557        &self.data.flow.set.flow_owner
558    }
559
560    pub fn started_by(&self) -> &User {
561        &self.data.flow.set.started_by
562    }
563
564    pub fn solana_config(&self) -> &SolanaClientConfig {
565        &self.data.flow.set.solana
566    }
567
568    pub fn solana_client(&self) -> &Arc<SolanaClient> {
569        &self.flow.set.solana_client
570    }
571
572    pub fn http(&self) -> &reqwest::Client {
573        &self.flow.set.http
574    }
575
576    pub async fn api_input(
577        &mut self,
578        timeout: Option<Duration>,
579    ) -> Result<api_input::Response, api_input::Error> {
580        let req = api_input::Request {
581            flow_run_id: *self.flow_run_id(),
582            node_id: *self.node_id(),
583            times: *self.times(),
584            timeout: timeout.unwrap_or(Duration::MAX),
585        };
586        self.flow.set.api_input.ready().await?.call(req).await
587    }
588
589    /// Call [`get_jwt`] service, the result will have `Bearer ` prefix.
590    pub async fn get_jwt_header(&mut self) -> Result<String, get_jwt::Error> {
591        let user_id = self.flow_owner().id;
592        let access_token = self
593            .get_jwt
594            .ready()
595            .await?
596            .call(get_jwt::Request { user_id })
597            .await?
598            .access_token;
599        Ok(["Bearer ", &access_token].concat())
600    }
601
602    /// Call [`execute`] service.
603    pub async fn execute(
604        &mut self,
605        instructions: Instructions,
606        output: value::Map,
607    ) -> Result<execute::Response, execute::Error> {
608        self.execute
609            .ready()
610            .await?
611            .call(execute::Request {
612                instructions,
613                output,
614            })
615            .await
616    }
617
618    /// Call [`signer`] service.
619    pub async fn request_signature(
620        &mut self,
621        pubkey: Pubkey,
622        message: Bytes,
623        timeout: Duration,
624    ) -> Result<signer::SignatureResponse, signer::Error> {
625        self.flow
626            .signer
627            .ready()
628            .await?
629            .call(signer::SignatureRequest {
630                id: None,
631                time: Utc::now(),
632                pubkey,
633                message,
634                timeout,
635                flow_run_id: Some(self.data.flow.flow_run_id),
636                signatures: None,
637            })
638            .await
639    }
640
641    /// Get an extension by type.
642    pub fn get<T: Any + Send + Sync + 'static>(&self) -> Option<&T> {
643        self.flow.set.extensions.get::<T>()
644    }
645
646    pub fn extensions_mut(&mut self) -> Option<&mut Extensions> {
647        Arc::get_mut(&mut self.flow.set.extensions)
648    }
649
650    pub fn raw(&self) -> RawContext<'_> {
651        RawContext {
652            data: &self.data,
653            services: RawServices {
654                signer: &self.flow.signer,
655                execute: &self.execute,
656                get_jwt: &self.get_jwt,
657            },
658        }
659    }
660}
661
662pub struct RawServices<'a> {
663    pub signer: &'a signer::Svc,
664    pub execute: &'a execute::Svc,
665    pub get_jwt: &'a get_jwt::Svc,
666}
667
668pub struct RawContext<'a> {
669    pub data: &'a CommandContextData,
670    pub services: RawServices<'a>,
671}
672
673impl Default for CommandContext {
674    fn default() -> Self {
675        Self::test_context()
676    }
677}
678
679#[derive(Clone, Copy, Serialize, Deserialize, Debug, JsonSchema)]
680pub struct User {
681    pub id: UserId,
682}
683
684impl User {
685    pub fn new(id: UserId) -> Self {
686        Self { id }
687    }
688}
689
690impl Default for User {
691    /// For testing
692    fn default() -> Self {
693        User {
694            id: uuid::Uuid::nil(),
695        }
696    }
697}