use crate::error::{Result, SdkError};
pub use crate::types::{BackoffStrategy, PriorityFeeConfig, Protocol, WorkerEndpoint};
use std::time::Duration;
/// Builder default for `Config::connection_timeout`, in milliseconds (10 seconds).
const DEFAULT_CONNECTION_TIMEOUT_MS: u64 = 10 * 1_000;
/// Builder default for `Config::max_retries`.
const DEFAULT_MAX_RETRIES: u32 = 3;
/// Default QUIC timeout used by `ProtocolTimeouts::default`, in milliseconds (2 seconds).
const QUIC_TIMEOUT_MS: u64 = 2 * 1_000;
/// Default gRPC timeout used by `ProtocolTimeouts::default`, in milliseconds (3 seconds).
const GRPC_TIMEOUT_MS: u64 = 3 * 1_000;
/// Default WebSocket timeout used by `ProtocolTimeouts::default`, in milliseconds (3 seconds).
const WEBSOCKET_TIMEOUT_MS: u64 = 3 * 1_000;
/// Default HTTP timeout used by `ProtocolTimeouts::default`, in milliseconds (5 seconds).
const HTTP_TIMEOUT_MS: u64 = 5 * 1_000;
/// Resolved SDK client configuration, normally produced by [`Config::builder`].
///
/// All fields are public. The builder fills any field the caller did not set with a
/// default and then runs [`Config::validate`] on the result.
#[derive(Debug, Clone)]
pub struct Config {
/// API key; `validate` requires it to be non-empty and to start with `sk_`.
pub api_key: String,
/// Optional region identifier (builder default: `None`).
pub region: Option<String>,
/// Explicit endpoint override. When set, `get_endpoint` returns it for every protocol
/// unless `selected_worker` supplies an endpoint for the requested protocol.
pub endpoint: Option<String>,
/// Discovery URL (builder default: `crate::discovery::DEFAULT_DISCOVERY_URL`).
pub discovery_url: String,
/// Connection timeout (builder default: 10 seconds).
pub connection_timeout: Duration,
/// Maximum retry count (builder default: 3).
pub max_retries: u32,
/// Leader-hints flag (builder default: `true`).
pub leader_hints: bool,
/// Tip-instruction stream flag (builder default: `false`).
pub stream_tip_instructions: bool,
/// Priority-fee stream flag (builder default: `false`).
pub stream_priority_fees: bool,
/// Latest-blockhash stream flag (builder default: `false`).
pub stream_latest_blockhash: bool,
/// Latest-slot stream flag (builder default: `false`).
pub stream_latest_slot: bool,
/// Per-protocol timeouts (builder default: `ProtocolTimeouts::default()`).
pub protocol_timeouts: ProtocolTimeouts,
/// Optional preferred protocol (builder default: `None`).
pub preferred_protocol: Option<Protocol>,
/// Worker whose per-protocol endpoints take precedence in `get_endpoint`.
pub selected_worker: Option<WorkerEndpoint>,
/// Priority-fee settings (builder default: `PriorityFeeConfig::default()`).
pub priority_fee: PriorityFeeConfig,
/// Retry backoff strategy (builder default: `BackoffStrategy::default()`).
pub retry_backoff: BackoffStrategy,
/// Minimum confidence (builder default: 70).
/// NOTE(review): no range (e.g. 0-100) is enforced in this file — confirm expected scale.
pub min_confidence: u32,
/// Optional idle timeout (builder default: `None`).
pub idle_timeout: Option<Duration>,
/// Service tier name (builder default: `"pro"`).
pub tier: String,
/// Keepalive flag (builder default: `true`).
pub keepalive: bool,
/// Keepalive interval (builder default: 5 seconds).
pub keepalive_interval: Duration,
/// Optional webhook URL (builder default: `None`).
pub webhook_url: Option<String>,
/// Webhook event names (builder default: `["transaction.confirmed"]`).
pub webhook_events: Vec<String>,
/// Webhook notification level (builder default: `"final"`).
pub webhook_notification_level: String,
}
/// Timeout values for each supported transport protocol.
///
/// The stock values come from the `Default` impl. The type derives `PartialEq`/`Eq`
/// so whole timeout sets can be compared directly (e.g. in tests or when diffing configs).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtocolTimeouts {
    /// Timeout for QUIC.
    pub quic: Duration,
    /// Timeout for gRPC.
    pub grpc: Duration,
    /// Timeout for WebSocket.
    pub websocket: Duration,
    /// Timeout for HTTP.
    pub http: Duration,
}
impl Default for ProtocolTimeouts {
    /// Builds the stock timeouts from the module-level `*_TIMEOUT_MS` constants.
    fn default() -> Self {
        // Alias the constructor so each line reads as "protocol = N ms".
        let millis = Duration::from_millis;

        let quic = millis(QUIC_TIMEOUT_MS);
        let grpc = millis(GRPC_TIMEOUT_MS);
        let websocket = millis(WEBSOCKET_TIMEOUT_MS);
        let http = millis(HTTP_TIMEOUT_MS);

        Self {
            quic,
            grpc,
            websocket,
            http,
        }
    }
}
impl Config {
    /// Returns a fresh [`ConfigBuilder`] with every option unset.
    pub fn builder() -> ConfigBuilder {
        <ConfigBuilder as Default>::default()
    }

    /// Checks the invariants the builder enforces before handing out a `Config`.
    ///
    /// # Errors
    /// Returns a config error if `api_key` is empty or does not start with `sk_`.
    pub fn validate(&self) -> Result<()> {
        let key = self.api_key.as_str();

        // The first failed rule wins; the empty-key message takes priority.
        let complaint = if key.is_empty() {
            Some("api_key is required")
        } else if !key.starts_with("sk_") {
            Some("api_key must start with 'sk_'")
        } else {
            None
        };

        match complaint {
            Some(msg) => Err(SdkError::config(msg)),
            None => Ok(()),
        }
    }

    /// Resolves the endpoint address for `protocol`.
    ///
    /// The sources are tried in this order:
    /// 1. the selected worker's endpoint for that protocol, if it has one;
    /// 2. the explicit `endpoint` override, which applies to every protocol;
    /// 3. a hard-coded localhost default for the protocol.
    pub fn get_endpoint(&self, protocol: Protocol) -> String {
        if let Some(worker_endpoint) = self
            .selected_worker
            .as_ref()
            .and_then(|worker| worker.get_endpoint(protocol))
        {
            return worker_endpoint.to_string();
        }

        if let Some(custom) = self.endpoint.as_deref() {
            return custom.to_owned();
        }

        let fallback = match protocol {
            Protocol::Quic => "quic://localhost:4433",
            Protocol::Grpc => "http://localhost:10000",
            Protocol::WebSocket => "ws://localhost:9000/ws",
            Protocol::Http => "http://localhost:9091",
        };
        fallback.to_string()
    }
}
/// Fluent builder for [`Config`]; obtain one with [`Config::builder`].
///
/// Every option is optional here. [`ConfigBuilder::build`] substitutes defaults for
/// unset options and validates the result. Setter methods consume and return the
/// builder, so discarding a returned builder silently loses the setting. `#[must_use]`
/// turns that mistake into a compiler warning.
#[derive(Debug, Default)]
#[must_use = "a ConfigBuilder does nothing until `build()` is called"]
pub struct ConfigBuilder {
    api_key: Option<String>,
    region: Option<String>,
    endpoint: Option<String>,
    discovery_url: Option<String>,
    connection_timeout: Option<Duration>,
    max_retries: Option<u32>,
    leader_hints: Option<bool>,
    stream_tip_instructions: Option<bool>,
    stream_priority_fees: Option<bool>,
    stream_latest_blockhash: Option<bool>,
    stream_latest_slot: Option<bool>,
    protocol_timeouts: Option<ProtocolTimeouts>,
    preferred_protocol: Option<Protocol>,
    selected_worker: Option<WorkerEndpoint>,
    priority_fee: Option<PriorityFeeConfig>,
    retry_backoff: Option<BackoffStrategy>,
    min_confidence: Option<u32>,
    idle_timeout: Option<Duration>,
    tier: Option<String>,
    keepalive: Option<bool>,
    keepalive_interval: Option<Duration>,
    webhook_url: Option<String>,
    webhook_events: Option<Vec<String>>,
    webhook_notification_level: Option<String>,
}
impl ConfigBuilder {
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
pub fn region(mut self, region: impl Into<String>) -> Self {
self.region = Some(region.into());
self
}
pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
self.endpoint = Some(endpoint.into());
self
}
pub fn discovery_url(mut self, url: impl Into<String>) -> Self {
self.discovery_url = Some(url.into());
self
}
pub fn connection_timeout(mut self, timeout: Duration) -> Self {
self.connection_timeout = Some(timeout);
self
}
pub fn max_retries(mut self, retries: u32) -> Self {
self.max_retries = Some(retries);
self
}
pub fn leader_hints(mut self, enabled: bool) -> Self {
self.leader_hints = Some(enabled);
self
}
pub fn stream_tip_instructions(mut self, enabled: bool) -> Self {
self.stream_tip_instructions = Some(enabled);
self
}
pub fn stream_priority_fees(mut self, enabled: bool) -> Self {
self.stream_priority_fees = Some(enabled);
self
}
pub fn stream_latest_blockhash(mut self, enabled: bool) -> Self {
self.stream_latest_blockhash = Some(enabled);
self
}
pub fn stream_latest_slot(mut self, enabled: bool) -> Self {
self.stream_latest_slot = Some(enabled);
self
}
pub fn protocol_timeouts(mut self, timeouts: ProtocolTimeouts) -> Self {
self.protocol_timeouts = Some(timeouts);
self
}
pub fn preferred_protocol(mut self, protocol: Protocol) -> Self {
self.preferred_protocol = Some(protocol);
self
}
pub fn priority_fee(mut self, config: PriorityFeeConfig) -> Self {
self.priority_fee = Some(config);
self
}
pub fn retry_backoff(mut self, strategy: BackoffStrategy) -> Self {
self.retry_backoff = Some(strategy);
self
}
pub fn min_confidence(mut self, confidence: u32) -> Self {
self.min_confidence = Some(confidence);
self
}
pub fn idle_timeout(mut self, timeout: Duration) -> Self {
self.idle_timeout = Some(timeout);
self
}
pub fn tier(mut self, tier: impl Into<String>) -> Self {
self.tier = Some(tier.into());
self
}
pub fn keepalive(mut self, enabled: bool) -> Self {
self.keepalive = Some(enabled);
self
}
pub fn keepalive_interval(mut self, secs: u64) -> Self {
self.keepalive_interval = Some(Duration::from_secs(secs));
self
}
pub fn webhook_url(mut self, url: impl Into<String>) -> Self {
self.webhook_url = Some(url.into());
self
}
pub fn webhook_events(mut self, events: Vec<String>) -> Self {
self.webhook_events = Some(events);
self
}
pub fn webhook_notification_level(mut self, level: impl Into<String>) -> Self {
self.webhook_notification_level = Some(level.into());
self
}
pub fn build(self) -> Result<Config> {
let config = Config {
api_key: self.api_key.ok_or_else(|| SdkError::config("api_key is required"))?,
region: self.region,
endpoint: self.endpoint,
discovery_url: self.discovery_url.unwrap_or_else(|| {
crate::discovery::DEFAULT_DISCOVERY_URL.to_string()
}),
connection_timeout: self
.connection_timeout
.unwrap_or_else(|| Duration::from_millis(DEFAULT_CONNECTION_TIMEOUT_MS)),
max_retries: self.max_retries.unwrap_or(DEFAULT_MAX_RETRIES),
leader_hints: self.leader_hints.unwrap_or(true),
stream_tip_instructions: self.stream_tip_instructions.unwrap_or(false),
stream_priority_fees: self.stream_priority_fees.unwrap_or(false),
stream_latest_blockhash: self.stream_latest_blockhash.unwrap_or(false),
stream_latest_slot: self.stream_latest_slot.unwrap_or(false),
protocol_timeouts: self.protocol_timeouts.unwrap_or_default(),
preferred_protocol: self.preferred_protocol,
selected_worker: self.selected_worker,
priority_fee: self.priority_fee.unwrap_or_default(),
retry_backoff: self.retry_backoff.unwrap_or_default(),
min_confidence: self.min_confidence.unwrap_or(70),
idle_timeout: self.idle_timeout,
tier: self.tier.unwrap_or_else(|| "pro".to_string()),
keepalive: self.keepalive.unwrap_or(true),
keepalive_interval: self.keepalive_interval.unwrap_or(Duration::from_secs(5)),
webhook_url: self.webhook_url,
webhook_events: self.webhook_events.unwrap_or_else(|| {
vec!["transaction.confirmed".to_string()]
}),
webhook_notification_level: self
.webhook_notification_level
.unwrap_or_else(|| "final".to_string()),
};
config.validate()?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A key that passes `Config::validate`.
    const VALID_KEY: &str = "sk_test_12345678";

    #[test]
    fn test_config_builder() {
        let config = Config::builder()
            .api_key(VALID_KEY)
            .region("us-west")
            .build()
            .unwrap();
        assert_eq!(config.api_key, VALID_KEY);
        assert_eq!(config.region, Some("us-west".to_string()));
        assert!(config.leader_hints);
        assert!(!config.stream_tip_instructions);
    }

    /// Pins every default that `ConfigBuilder::build` fills in.
    #[test]
    fn test_config_builder_defaults() {
        let config = Config::builder().api_key(VALID_KEY).build().unwrap();
        assert_eq!(config.region, None);
        assert_eq!(config.endpoint, None);
        assert_eq!(
            config.connection_timeout,
            Duration::from_millis(DEFAULT_CONNECTION_TIMEOUT_MS)
        );
        assert_eq!(config.max_retries, DEFAULT_MAX_RETRIES);
        assert!(config.leader_hints);
        assert!(!config.stream_tip_instructions);
        assert!(!config.stream_priority_fees);
        assert!(!config.stream_latest_blockhash);
        assert!(!config.stream_latest_slot);
        assert!(config.preferred_protocol.is_none());
        assert!(config.selected_worker.is_none());
        assert_eq!(config.min_confidence, 70);
        assert_eq!(config.idle_timeout, None);
        assert_eq!(config.tier, "pro");
        assert!(config.keepalive);
        assert_eq!(config.keepalive_interval, Duration::from_secs(5));
        assert_eq!(config.webhook_url, None);
        assert_eq!(
            config.webhook_events,
            vec!["transaction.confirmed".to_string()]
        );
        assert_eq!(config.webhook_notification_level, "final");
    }

    #[test]
    fn test_protocol_timeouts_default() {
        let timeouts = ProtocolTimeouts::default();
        assert_eq!(timeouts.quic, Duration::from_millis(QUIC_TIMEOUT_MS));
        assert_eq!(timeouts.grpc, Duration::from_millis(GRPC_TIMEOUT_MS));
        assert_eq!(
            timeouts.websocket,
            Duration::from_millis(WEBSOCKET_TIMEOUT_MS)
        );
        assert_eq!(timeouts.http, Duration::from_millis(HTTP_TIMEOUT_MS));
    }

    /// `keepalive_interval` takes whole seconds, not milliseconds.
    #[test]
    fn test_keepalive_interval_is_seconds() {
        let config = Config::builder()
            .api_key(VALID_KEY)
            .keepalive_interval(30)
            .build()
            .unwrap();
        assert_eq!(config.keepalive_interval, Duration::from_secs(30));
    }

    #[test]
    fn test_config_builder_missing_api_key() {
        let result = Config::builder().build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("api_key"));
    }

    #[test]
    fn test_config_builder_empty_api_key() {
        let result = Config::builder().api_key("").build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("api_key"));
    }

    #[test]
    fn test_config_builder_invalid_api_key() {
        let result = Config::builder().api_key("invalid_key").build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("sk_"));
    }

    #[test]
    fn test_protocol_fallback_order() {
        let order = Protocol::fallback_order();
        assert_eq!(order[0], Protocol::Quic);
        assert_eq!(order[1], Protocol::Grpc);
        assert_eq!(order[2], Protocol::WebSocket);
        assert_eq!(order[3], Protocol::Http);
    }

    #[test]
    fn test_config_get_endpoint() {
        let config = Config::builder().api_key(VALID_KEY).build().unwrap();
        assert!(config.get_endpoint(Protocol::Quic).contains("4433"));
        assert!(config.get_endpoint(Protocol::Grpc).contains("10000"));
        assert!(config.get_endpoint(Protocol::WebSocket).contains("ws://"));
        assert!(config.get_endpoint(Protocol::Http).contains("http://"));
    }

    #[test]
    fn test_config_custom_endpoint() {
        let config = Config::builder()
            .api_key(VALID_KEY)
            .endpoint("https://custom.example.com")
            .build()
            .unwrap();
        assert_eq!(
            config.get_endpoint(Protocol::Http),
            "https://custom.example.com"
        );
    }

    /// Without a selected worker, the explicit endpoint applies to every protocol.
    #[test]
    fn test_config_custom_endpoint_applies_to_all_protocols() {
        let config = Config::builder()
            .api_key(VALID_KEY)
            .endpoint("https://custom.example.com")
            .build()
            .unwrap();
        for protocol in [
            Protocol::Quic,
            Protocol::Grpc,
            Protocol::WebSocket,
            Protocol::Http,
        ] {
            assert_eq!(config.get_endpoint(protocol), "https://custom.example.com");
        }
    }
}