use std::net::IpAddr;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;
use crate::connection::info::{ConnectionInfo, IntoConnectionInfo};
use crate::connection::runtime;
use crate::connection::{DisconnectNotifier, MultiplexedConnection, ValkeyRuntime, connect_simple};
use crate::pubsub::push_manager::PushInfo;
use crate::pubsub::synchronizer_trait::PubSubSynchronizer;
use crate::retry_strategies::RetryStrategy;
use crate::value::{ProtocolVersion, Result};
/// Connection factory that remembers the parsed [`ConnectionInfo`] and hands
/// a copy of it to every connection it opens.
#[derive(Debug, Clone)]
pub struct Client {
    /// Parsed connection parameters. Its `valkey` sub-struct carries the
    /// password, username, database index, client name and protocol that the
    /// `update_*` methods modify.
    pub(crate) connection_info: ConnectionInfo,
}
impl Client {
    /// Builds a [`Client`] from anything implementing [`IntoConnectionInfo`].
    ///
    /// Only `params` is converted here; no connection is opened.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `into_connection_info` when `params`
    /// cannot be converted.
    pub fn open<T: IntoConnectionInfo>(params: T) -> Result<Client> {
        let connection_info = params.into_connection_info()?;
        Ok(Self { connection_info })
    }

    /// Borrows the connection parameters currently held by this client,
    /// including any changes made through the `update_*` methods.
    pub fn get_connection_info(&self) -> &ConnectionInfo {
        let info = &self.connection_info;
        info
    }
}
/// Optional settings applied when a [`Client`] opens a connection.
///
/// The derived `Default` leaves every `Option` field as `None` and every
/// `bool` field as `false`.
#[derive(Clone, Default)]
pub struct FerrisKeyConnectionOptions {
    /// Unbounded channel sender for [`PushInfo`] values, if the caller wants them.
    pub push_sender: Option<mpsc::UnboundedSender<PushInfo>>,
    /// Optional [`DisconnectNotifier`] hook (presumably invoked on disconnect;
    /// see that trait for the exact contract).
    pub disconnect_notifier: Option<Box<dyn DisconnectNotifier>>,
    /// Availability-zone discovery flag. NOTE(review): the meaning is inferred
    /// from the name; confirm in the consumer of this struct.
    pub discover_az: bool,
    /// Optional connection timeout. The `Client` helpers in this module take
    /// their timeouts as explicit arguments and do not read this field.
    pub connection_timeout: Option<Duration>,
    /// Optional retry policy; it is consumed outside this module (presumably
    /// when (re)establishing connections).
    pub connection_retry_strategy: Option<RetryStrategy>,
    /// Forwarded by `Client` to `connect_simple` as its `tcp_nodelay` flag.
    pub tcp_nodelay: bool,
    /// Optional pub/sub synchronizer, shared through an `Arc`.
    pub pubsub_synchronizer: Option<Arc<dyn PubSubSynchronizer>>,
    /// Optional source of IAM tokens (see [`IAMTokenProvider`]).
    pub iam_token_provider: Option<Arc<dyn IAMTokenProvider>>,
}
/// Asynchronous provider of IAM tokens (presumably used for authentication).
///
/// The `Send + Sync` bound lets a single provider be shared across tasks, for
/// example behind the `Arc` stored in
/// `FerrisKeyConnectionOptions::iam_token_provider`.
#[async_trait::async_trait]
pub trait IAMTokenProvider: Send + Sync {
    /// Yields a token that the provider considers valid, or `None` when it has
    /// no token to offer.
    async fn get_valid_token(&self) -> Option<String>;
}
/// Value used as "no timeout": the largest representable [`Duration`].
/// `Client::get_multiplexed_async_connection` passes it for both the response
/// and the connection timeout.
pub(crate) const NO_TIMEOUT: Duration = Duration::MAX;
impl Client {
pub async fn get_multiplexed_async_connection(
&self,
ferriskey_connection_options: FerrisKeyConnectionOptions,
) -> Result<MultiplexedConnection> {
self.get_multiplexed_async_connection_with_timeouts(
NO_TIMEOUT,
NO_TIMEOUT,
ferriskey_connection_options,
)
.await
}
pub(crate) async fn get_multiplexed_async_connection_with_timeouts(
&self,
response_timeout: std::time::Duration,
connection_timeout: std::time::Duration,
ferriskey_connection_options: FerrisKeyConnectionOptions,
) -> Result<MultiplexedConnection> {
let result = runtime::timeout(
connection_timeout,
self.get_multiplexed_async_connection_inner::<crate::connection::tokio::Tokio>(
response_timeout,
None,
ferriskey_connection_options,
),
)
.await;
match result {
Ok(Ok(connection)) => Ok(connection),
Ok(Err(e)) => Err(e),
Err(elapsed) => Err(elapsed.into()),
}
.map(|(conn, _ip)| conn)
}
pub(crate) async fn get_multiplexed_async_connection_inner<T>(
&self,
response_timeout: std::time::Duration,
socket_addr: Option<SocketAddr>,
ferriskey_connection_options: FerrisKeyConnectionOptions,
) -> Result<(MultiplexedConnection, Option<IpAddr>)>
where
T: ValkeyRuntime,
{
let conn_info = self.connection_info.clone();
let (con, ip) = connect_simple::<T>(
&conn_info,
socket_addr,
ferriskey_connection_options.tcp_nodelay,
)
.await?;
let (connection, driver) = MultiplexedConnection::new_with_response_timeout(
conn_info,
con.boxed(),
response_timeout,
ferriskey_connection_options,
)
.await?;
T::spawn(driver);
Ok((connection, ip))
}
pub fn update_password(&mut self, password: Option<String>) {
self.connection_info.valkey.password = password;
}
pub fn update_database(&mut self, database_id: i64) {
self.connection_info.valkey.db = database_id;
}
pub fn update_client_name(&mut self, client_name: Option<String>) {
self.connection_info.valkey.client_name = client_name;
}
pub fn update_username(&mut self, username: Option<String>) {
self.connection_info.valkey.username = username;
}
pub fn update_protocol(&mut self, protocol: ProtocolVersion) {
self.connection_info.valkey.protocol = protocol;
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Regression check (presumably upstream issue #293): an IPv6 host with a
    /// `%interface` zone suffix must still produce a `Client`.
    #[test]
    fn regression_293_parse_ipv6_with_interface() {
        let host_and_port = ("fe80::cafe:beef%eno1", 6379);
        let opened = Client::open(host_and_port);
        assert!(opened.is_ok());
    }

    /// `update_database` overwrites the database index parsed from the URL.
    #[test]
    fn test_update_database() {
        let db_of = |c: &Client| c.connection_info.valkey.db;
        let mut client = Client::open("redis://127.0.0.1/0").unwrap();
        assert_eq!(db_of(&client), 0);
        client.update_database(1);
        assert_eq!(db_of(&client), 1);
    }
}