use std::collections::HashMap;
use std::time::Duration;
use super::proto::hyper_service::query_param::TransferMode;
/// Default limit used by [`GrpcConfig::new`] for both the decoding and the encoding
/// message size: `64 * 1024 * 1024`, i.e. 64 MiB if the limit is counted in bytes.
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
/// Connection and request options for a gRPC client, assembled with a consuming builder.
///
/// Start from [`GrpcConfig::new`] (or `Default`) and chain the builder methods; each one
/// takes `self` by value and hands back the updated configuration.
#[derive(Debug, Clone)]
#[must_use = "GrpcConfig uses a consuming builder pattern - each method takes ownership and returns a new instance. You must use the returned value or your configuration changes will be lost"]
pub struct GrpcConfig {
    /// Endpoint URL exactly as supplied to [`GrpcConfig::new`].
    pub(crate) endpoint: String,
    /// Database chosen through [`GrpcConfig::database`]; `None` until that is called.
    pub(crate) database: Option<String>,
    /// Connect timeout; `new` initializes it to 30 seconds.
    pub(crate) connect_timeout: Duration,
    /// Request timeout; `new` initializes it to 100 seconds.
    pub(crate) request_timeout: Duration,
    /// Transfer mode; `new` initializes it to `TransferMode::Adaptive`.
    pub(crate) transfer_mode: TransferMode,
    /// TLS flag, computed in `new` from the endpoint's `https://` scheme prefix.
    pub(crate) use_tls: bool,
    /// Key/value pairs collected by `header`, `headers` and the auth helpers.
    // NOTE(review): presumably sent as request metadata by the client - confirm at the use site.
    pub(crate) headers: HashMap<String, String>,
    /// Key/value pairs collected by `setting`.
    pub(crate) settings: HashMap<String, String>,
    /// Decoding message-size limit; `new` initializes it to [`DEFAULT_MAX_MESSAGE_SIZE`].
    pub(crate) max_decoding_message_size: usize,
    /// Encoding message-size limit; `new` initializes it to [`DEFAULT_MAX_MESSAGE_SIZE`].
    pub(crate) max_encoding_message_size: usize,
}
impl GrpcConfig {
    /// Creates a configuration for `endpoint`, with defaults for every other option.
    ///
    /// TLS is enabled when the endpoint uses the `https` scheme. The scheme prefix is
    /// compared ASCII case-insensitively because URI schemes are case-insensitive
    /// (RFC 3986, section 3.1), so `HTTPS://host` enables TLS just like `https://host`.
    ///
    /// Defaults: no database, 30 s connect timeout, 100 s request timeout,
    /// [`TransferMode::Adaptive`], no headers, no settings, and
    /// [`DEFAULT_MAX_MESSAGE_SIZE`] for both message-size limits.
    pub fn new(endpoint: impl Into<String>) -> Self {
        const TLS_SCHEME_PREFIX: &str = "https://";

        let endpoint = endpoint.into();
        // `str::get` yields `None` when the endpoint is shorter than the prefix or the cut
        // would split a multi-byte character, so this check can never panic.
        let use_tls = endpoint
            .get(..TLS_SCHEME_PREFIX.len())
            .is_some_and(|scheme| scheme.eq_ignore_ascii_case(TLS_SCHEME_PREFIX));

        GrpcConfig {
            endpoint,
            database: None,
            connect_timeout: Duration::from_secs(30),
            request_timeout: Duration::from_secs(100),
            transfer_mode: TransferMode::Adaptive,
            use_tls,
            headers: HashMap::new(),
            settings: HashMap::new(),
            max_decoding_message_size: DEFAULT_MAX_MESSAGE_SIZE,
            max_encoding_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Sets the database, replacing any previously configured one.
    pub fn database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Sets the connect timeout.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the request timeout.
    pub fn request_timeout(mut self, timeout: Duration) -> Self {
        self.request_timeout = timeout;
        self
    }

    /// Sets the transfer mode.
    pub fn transfer_mode(mut self, mode: TransferMode) -> Self {
        self.transfer_mode = mode;
        self
    }

    /// Sets the decoding and the encoding message-size limits to the same `size`.
    pub fn max_message_size(mut self, size: usize) -> Self {
        self.max_decoding_message_size = size;
        self.max_encoding_message_size = size;
        self
    }

    /// Sets only the decoding message-size limit.
    pub fn max_decoding_message_size(mut self, size: usize) -> Self {
        self.max_decoding_message_size = size;
        self
    }

    /// Sets only the encoding message-size limit.
    pub fn max_encoding_message_size(mut self, size: usize) -> Self {
        self.max_encoding_message_size = size;
        self
    }

    /// Adds a single header; an existing value stored under the same `key` is replaced.
    pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.headers.insert(key.into(), value.into());
        self
    }

    /// Adds every pair from `headers`; pairs whose key already exists replace the old value.
    pub fn headers(mut self, headers: impl IntoIterator<Item = (String, String)>) -> Self {
        self.headers.extend(headers);
        self
    }

    /// Adds a single setting; an existing value stored under the same `key` is replaced.
    pub fn setting(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.settings.insert(key.into(), value.into());
        self
    }

    /// Returns the endpoint exactly as it was passed to [`GrpcConfig::new`].
    #[must_use]
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }

    /// Returns the configured database, or `None` if [`GrpcConfig::database`] was never called.
    #[must_use]
    pub fn database_path(&self) -> Option<&str> {
        self.database.as_deref()
    }

    /// Returns whether TLS was enabled for this configuration.
    #[must_use]
    pub fn is_tls(&self) -> bool {
        self.use_tls
    }

    /// Returns the decoding message-size limit.
    #[must_use]
    pub fn get_max_decoding_message_size(&self) -> usize {
        self.max_decoding_message_size
    }

    /// Returns the encoding message-size limit.
    #[must_use]
    pub fn get_max_encoding_message_size(&self) -> usize {
        self.max_encoding_message_size
    }

    /// Sets the `Authorization` header to `token.bearer_token()` and the `audience` header
    /// to `token.tenant_url_str()`, replacing any previous values of those two headers.
    #[cfg(feature = "salesforce-auth")]
    pub fn with_data_cloud_token(self, token: &hyperdb_api_salesforce::DataCloudToken) -> Self {
        self.header("Authorization", token.bearer_token())
            .header("audience", token.tenant_url_str())
    }

    /// Sets the `Authorization` header to `bearer_token` and the `audience` header to
    /// `audience`, replacing any previous values of those two headers.
    ///
    /// `bearer_token` is stored verbatim; this method does not add a `Bearer ` prefix.
    pub fn with_bearer_auth(
        self,
        bearer_token: impl Into<String>,
        audience: impl Into<String>,
    ) -> Self {
        self.header("Authorization", bearer_token)
            .header("audience", audience)
    }
}
impl Default for GrpcConfig {
fn default() -> Self {
GrpcConfig::new("http://localhost:7484")
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plaintext endpoint shared by the test cases.
    const LOCAL_ENDPOINT: &str = "http://localhost:7484";
    /// Multiplier that keeps the message-size expectations readable.
    const MIB: usize = 1024 * 1024;

    /// Every builder call must be reflected in the corresponding field.
    #[test]
    fn test_config_builder() {
        let config = GrpcConfig::new(LOCAL_ENDPOINT)
            .database("test.hyper")
            .connect_timeout(Duration::from_secs(10))
            .request_timeout(Duration::from_secs(30))
            .header("x-custom", "value")
            .setting("log_level", "debug");

        assert_eq!(config.endpoint, LOCAL_ENDPOINT);
        assert_eq!(config.database.as_deref(), Some("test.hyper"));
        assert_eq!(config.connect_timeout, Duration::from_secs(10));
        assert_eq!(config.request_timeout, Duration::from_secs(30));
        assert!(!config.use_tls);
        assert_eq!(
            config.headers.get("x-custom").map(String::as_str),
            Some("value")
        );
        assert_eq!(
            config.settings.get("log_level").map(String::as_str),
            Some("debug")
        );
    }

    /// The TLS flag follows the endpoint scheme.
    #[expect(
        clippy::similar_names,
        reason = "paired bindings (request/response, reader/writer, etc.) are more readable with symmetric names than artificially distinct ones"
    )]
    #[test]
    fn test_tls_detection() {
        let http_config = GrpcConfig::new(LOCAL_ENDPOINT);
        let https_config = GrpcConfig::new("https://hyper.example.com:443");

        assert!(!http_config.use_tls);
        assert!(https_config.use_tls);
    }

    /// `Default` yields the documented defaults.
    #[test]
    fn test_default_values() {
        let config = GrpcConfig::default();

        assert_eq!(config.endpoint, LOCAL_ENDPOINT);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.request_timeout, Duration::from_secs(100));
        assert!(matches!(config.transfer_mode, TransferMode::Adaptive));
        assert_eq!(config.max_decoding_message_size, DEFAULT_MAX_MESSAGE_SIZE);
        assert_eq!(config.max_encoding_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }

    /// The combined setter writes both limits; the dedicated setters write one each.
    #[test]
    fn test_message_size_configuration() {
        let uniform = GrpcConfig::new(LOCAL_ENDPOINT).max_message_size(128 * MIB);
        assert_eq!(uniform.max_decoding_message_size, 128 * MIB);
        assert_eq!(uniform.max_encoding_message_size, 128 * MIB);

        let split = GrpcConfig::new(LOCAL_ENDPOINT)
            .max_decoding_message_size(256 * MIB)
            .max_encoding_message_size(32 * MIB);
        assert_eq!(split.max_decoding_message_size, 256 * MIB);
        assert_eq!(split.max_encoding_message_size, 32 * MIB);
        assert_eq!(split.get_max_decoding_message_size(), 256 * MIB);
        assert_eq!(split.get_max_encoding_message_size(), 32 * MIB);
    }

    /// Sync transfer mode combined with an enlarged message-size limit.
    #[test]
    fn test_sync_mode_with_large_message_size() {
        let config = GrpcConfig::new(LOCAL_ENDPOINT)
            .transfer_mode(TransferMode::Sync)
            .max_message_size(256 * MIB);

        assert!(matches!(config.transfer_mode, TransferMode::Sync));
        assert_eq!(config.max_decoding_message_size, 256 * MIB);
    }
}