kora_lib/rpc_server/
server.rs

1use crate::{
2    constant::{X_API_KEY, X_HMAC_SIGNATURE, X_TIMESTAMP},
3    metrics::run_metrics_server_if_required,
4    rpc_server::{
5        auth::{ApiKeyAuthLayer, HmacAuthLayer},
6        middleware_utils::MethodValidationLayer,
7        rpc::KoraRpc,
8    },
9    usage_limit::UsageTracker,
10};
11
12#[cfg(not(test))]
13use crate::state::get_config;
14
15#[cfg(test)]
16use crate::tests::config_mock::mock_state::get_config;
17use http::{header, Method};
18use jsonrpsee::{
19    server::{middleware::proxy_get_request::ProxyGetRequestLayer, ServerBuilder, ServerHandle},
20    RpcModule,
21};
22use std::{net::SocketAddr, time::Duration};
23use tokio::task::JoinHandle;
24use tower::limit::RateLimitLayer;
25use tower_http::cors::CorsLayer;
26
27pub struct ServerHandles {
28    pub rpc_handle: ServerHandle,
29    pub metrics_handle: Option<ServerHandle>,
30    pub balance_tracker_handle: Option<JoinHandle<()>>,
31}
32
/// Resolves a setting, preferring the environment variable `env_var` over `config_value`.
///
/// The environment variable is treated as absent, and `config_value` is returned, when it is:
/// - unset,
/// - not valid Unicode, or
/// - set to the empty string.
///
/// The empty-string case matters for secrets. Exporting e.g. `KORA_API_KEY=` must not
/// replace a configured key with an empty one, which would enable authentication with an
/// empty secret. `config_value` itself is returned unchanged.
fn get_value_by_priority(env_var: &str, config_value: Option<String>) -> Option<String> {
    std::env::var(env_var).ok().filter(|value| !value.is_empty()).or(config_value)
}
37
38pub async fn run_rpc_server(rpc: KoraRpc, port: u16) -> Result<ServerHandles, anyhow::Error> {
39    let addr = SocketAddr::from(([0, 0, 0, 0], port));
40    log::info!("RPC server started on {addr}, port {port}");
41
42    // Initialize usage limiter
43    if let Err(e) = UsageTracker::init_usage_limiter().await {
44        log::error!("Failed to initialize usage limiter: {e}");
45        return Err(anyhow::anyhow!("Usage limiter initialization failed: {e}"));
46    }
47
48    // Build middleware stack with tracing and CORS
49    let cors = CorsLayer::new()
50        .allow_origin(tower_http::cors::Any)
51        .allow_methods([Method::POST, Method::GET])
52        .allow_headers([
53            header::CONTENT_TYPE,
54            header::HeaderName::from_static(X_API_KEY),
55            header::HeaderName::from_static(X_HMAC_SIGNATURE),
56            header::HeaderName::from_static(X_TIMESTAMP),
57        ])
58        .max_age(Duration::from_secs(3600));
59
60    let config = get_config()?;
61
62    // Get the RPC client from KoraRpc to pass to metrics initialization
63    let rpc_client = rpc.get_rpc_client().clone();
64
65    let (metrics_handle, metrics_layers, balance_tracker_handle) =
66        run_metrics_server_if_required(port, rpc_client).await?;
67
68    // Build whitelist of allowed methods from enabled_methods config
69    let allowed_methods = config.kora.enabled_methods.get_enabled_method_names();
70
71    let middleware = tower::ServiceBuilder::new()
72        // Add metrics handler first (before other layers) so it can intercept /metrics
73        .layer(ProxyGetRequestLayer::new("/liveness", "liveness")?)
74        .layer(RateLimitLayer::new(config.kora.rate_limit, Duration::from_secs(1)))
75        // Add metrics handler layer for Prometheus metrics
76        .option_layer(
77            metrics_layers.as_ref().and_then(|layers| layers.metrics_handler_layer.clone()),
78        )
79        .layer(cors)
80        // Method validation layer -  to fail fast
81        .layer(MethodValidationLayer::new(allowed_methods.clone()))
82        // Add metrics collection layer
83        .option_layer(metrics_layers.as_ref().and_then(|layers| layers.http_metrics_layer.clone()))
84        // Add authentication layer for API key if configured
85        .option_layer(
86            (get_value_by_priority("KORA_API_KEY", config.kora.auth.api_key.clone()))
87                .map(ApiKeyAuthLayer::new),
88        )
89        // Add authentication layer for HMAC if configured
90        .option_layer(
91            (get_value_by_priority("KORA_HMAC_SECRET", config.kora.auth.hmac_secret.clone()))
92                .map(|secret| HmacAuthLayer::new(secret, config.kora.auth.max_timestamp_age)),
93        );
94
95    // Configure and build the server with HTTP support
96    let server = ServerBuilder::default()
97        .max_request_body_size(config.kora.max_request_body_size as u32)
98        .set_middleware(middleware)
99        .http_only() // Explicitly enable HTTP
100        .build(addr)
101        .await?;
102
103    let rpc_module = build_rpc_module(rpc)?;
104
105    // Start the RPC server
106    let rpc_handle = server
107        .start(rpc_module)
108        .map_err(|e| anyhow::anyhow!("Failed to start RPC server: {}", e))?;
109
110    Ok(ServerHandles { rpc_handle, metrics_handle, balance_tracker_handle })
111}
112
/// Registers `$rpc_method` on `$module` under the JSON-RPC name `$method_name`, but only
/// when `$enabled_methods.$field` is `true`.
///
/// The macro has two arms:
/// - without `with_params`, the handler takes no RPC parameters;
/// - with `with_params`, the request params are parsed into the handler's argument type
///   first, and a parse failure is returned as the handler's error.
///
/// A failed registration (for example a duplicate method name) is propagated with `?`
/// instead of being discarded. The macro must therefore be expanded inside a function whose
/// error type can be built from jsonrpsee's registration error, such as `anyhow::Error`.
macro_rules! register_method_if_enabled {
    // For methods without parameters
    ($module:expr, $enabled_methods:expr, $field:ident, $method_name:expr, $rpc_method:ident) => {
        if $enabled_methods.$field {
            // Propagate registration failures; the returned callback handle is not needed.
            let _ = $module.register_async_method(
                $method_name,
                |_rpc_params, rpc_context| async move {
                    let rpc = rpc_context.as_ref();
                    rpc.$rpc_method().await.map_err(Into::into)
                },
            )?;
        }
    };

    // For methods with parameters
    ($module:expr, $enabled_methods:expr, $field:ident, $method_name:expr, $rpc_method:ident, with_params) => {
        if $enabled_methods.$field {
            // Some parameterized handlers may be marked `#[deprecated]`. They stay callable
            // while enabled, so silence the lint at the registration site.
            #[allow(deprecated)]
            let _ =
                $module.register_async_method($method_name, |rpc_params, rpc_context| async move {
                    let rpc = rpc_context.as_ref();
                    let params = rpc_params.parse()?;
                    #[allow(deprecated)]
                    rpc.$rpc_method(params).await.map_err(Into::into)
                })?;
        }
    };
}
141
142fn build_rpc_module(rpc: KoraRpc) -> Result<RpcModule<KoraRpc>, anyhow::Error> {
143    let mut module = RpcModule::new(rpc.clone());
144    let enabled_methods = &get_config()?.kora.enabled_methods;
145
146    register_method_if_enabled!(module, enabled_methods, liveness, "liveness", liveness);
147
148    register_method_if_enabled!(
149        module,
150        enabled_methods,
151        estimate_transaction_fee,
152        "estimateTransactionFee",
153        estimate_transaction_fee,
154        with_params
155    );
156    register_method_if_enabled!(
157        module,
158        enabled_methods,
159        estimate_bundle_fee,
160        "estimateBundleFee",
161        estimate_bundle_fee,
162        with_params
163    );
164    register_method_if_enabled!(
165        module,
166        enabled_methods,
167        get_supported_tokens,
168        "getSupportedTokens",
169        get_supported_tokens
170    );
171    register_method_if_enabled!(
172        module,
173        enabled_methods,
174        get_payer_signer,
175        "getPayerSigner",
176        get_payer_signer
177    );
178    register_method_if_enabled!(
179        module,
180        enabled_methods,
181        sign_transaction,
182        "signTransaction",
183        sign_transaction,
184        with_params
185    );
186    register_method_if_enabled!(
187        module,
188        enabled_methods,
189        sign_and_send_transaction,
190        "signAndSendTransaction",
191        sign_and_send_transaction,
192        with_params
193    );
194    register_method_if_enabled!(
195        module,
196        enabled_methods,
197        transfer_transaction,
198        "transferTransaction",
199        transfer_transaction,
200        with_params
201    );
202    register_method_if_enabled!(
203        module,
204        enabled_methods,
205        get_blockhash,
206        "getBlockhash",
207        get_blockhash
208    );
209    register_method_if_enabled!(module, enabled_methods, get_config, "getConfig", get_config);
210    register_method_if_enabled!(module, enabled_methods, get_version, "getVersion", get_version);
211    register_method_if_enabled!(
212        module,
213        enabled_methods,
214        sign_bundle,
215        "signBundle",
216        sign_bundle,
217        with_params
218    );
219    register_method_if_enabled!(
220        module,
221        enabled_methods,
222        sign_and_send_bundle,
223        "signAndSendBundle",
224        sign_and_send_bundle,
225        with_params
226    );
227
228    Ok(module)
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        config::EnabledMethods,
        tests::{
            common::setup_or_get_test_signer,
            config_mock::{ConfigMockBuilder, KoraConfigBuilder},
            rpc_mock::RpcMockBuilder,
        },
    };
    use std::env;

    /// An `EnabledMethods` with every flag off. Tests switch on what they need with `..`.
    fn no_methods_enabled() -> EnabledMethods {
        EnabledMethods {
            liveness: false,
            estimate_transaction_fee: false,
            estimate_bundle_fee: false,
            get_supported_tokens: false,
            get_payer_signer: false,
            sign_transaction: false,
            sign_and_send_transaction: false,
            transfer_transaction: false,
            get_blockhash: false,
            get_config: false,
            get_version: false,
            sign_bundle: false,
            sign_and_send_bundle: false,
        }
    }

    /// Installs a mocked config exposing `enabled_methods`, makes sure a test signer exists,
    /// and builds the RPC module from a mocked RPC client.
    ///
    /// The value returned by `build_and_setup()` is handed back so the caller can keep it
    /// alive until the end of the test.
    fn build_module_for(
        enabled_methods: EnabledMethods,
    ) -> (impl Sized, Result<RpcModule<KoraRpc>, anyhow::Error>) {
        let kora_config = KoraConfigBuilder::new().with_enabled_methods(enabled_methods).build();
        let mock_guard = ConfigMockBuilder::new().with_kora(kora_config).build_and_setup();
        let _ = setup_or_get_test_signer();

        let kora_rpc = KoraRpc::new(RpcMockBuilder::new().build());
        (mock_guard, build_rpc_module(kora_rpc))
    }

    #[test]
    fn test_get_value_by_priority_env_var_takes_precedence() {
        let var_name = "TEST_ENV_VAR_PRECEDENCE_UNIQUE";
        env::set_var(var_name, "env_value");

        let resolved = get_value_by_priority(var_name, Some("config_value".to_string()));
        assert_eq!(resolved, Some("env_value".to_string()));

        env::remove_var(var_name);
    }

    #[test]
    fn test_get_value_by_priority_config_fallback() {
        let resolved = get_value_by_priority(
            "TEST_ENV_VAR_FALLBACK_UNIQUE_XYZ123",
            Some("config_value".to_string()),
        );
        assert_eq!(resolved, Some("config_value".to_string()));
    }

    #[test]
    fn test_get_value_by_priority_none_when_both_missing() {
        assert_eq!(get_value_by_priority("TEST_ENV_VAR_MISSING_UNIQUE_ABC789", None), None);
    }

    #[test]
    fn test_build_rpc_module_all_methods_enabled() {
        // The default config exposes exactly these ten methods. The bundle methods
        // (estimateBundleFee, signBundle, signAndSendBundle) are opt-in and therefore absent.
        const EXPECTED: [&str; 10] = [
            "liveness",
            "estimateTransactionFee",
            "getSupportedTokens",
            "getPayerSigner",
            "signTransaction",
            "signAndSendTransaction",
            "transferTransaction",
            "getBlockhash",
            "getConfig",
            "getVersion",
        ];

        let (_mock_guard, result) = build_module_for(EnabledMethods::default());
        assert!(result.is_ok(), "Failed to build RPC module with all methods enabled");

        let module = result.unwrap();
        let registered: Vec<&str> = module.method_names().collect();
        assert_eq!(registered.len(), EXPECTED.len());
        for name in &EXPECTED {
            assert!(registered.contains(name));
        }
    }

    #[test]
    fn test_build_rpc_module_all_methods_disabled() {
        // Building must still succeed when nothing is exposed.
        let (_mock_guard, result) = build_module_for(no_methods_enabled());
        assert!(result.is_ok(), "Failed to build RPC module with all methods disabled");

        assert_eq!(result.unwrap().method_names().count(), 0);
    }

    #[test]
    fn test_build_rpc_module_selective_methods() {
        let enabled_methods = EnabledMethods {
            liveness: true,
            get_config: true,
            get_supported_tokens: true,
            ..no_methods_enabled()
        };

        let (_mock_guard, result) = build_module_for(enabled_methods);
        assert!(result.is_ok(), "Failed to build RPC module with selective methods");

        let module = result.unwrap();
        let registered: Vec<&str> = module.method_names().collect();
        assert_eq!(registered.len(), 3);
        for name in ["liveness", "getConfig", "getSupportedTokens"].iter() {
            assert!(registered.contains(name));
        }
    }
}