use crate::{
client::{FalkorClientProvider, ProvidesSyncConnections},
connection::{
asynchronous::{BorrowedAsyncConnection, FalkorAsyncConnection},
blocking::FalkorSyncConnection,
},
parser::{parse_config_hashmap, redis_value_as_untyped_string_vec},
AsyncGraph, ConfigValue, FalkorConnectionInfo, FalkorDBError, FalkorResult,
};
use std::{collections::HashMap, sync::Arc};
use tokio::{
runtime::{Handle, RuntimeFlavor},
sync::{mpsc, Mutex},
task,
};
/// Shared state behind a [`FalkorAsyncClient`]: the client provider used to open new
/// connections, plus a bounded channel that serves as the connection pool.
pub struct FalkorAsyncClientInner {
    /// Provider used to open fresh connections. Guarded by an async mutex so it can be
    /// reached both from async code and from the blocking `get_connection` path.
    _inner: Mutex<FalkorClientProvider>,
    /// Number of connections the pool was created with.
    connection_pool_size: u8,
    /// Sending half of the pool channel. A clone is handed to every
    /// `BorrowedAsyncConnection` (presumably so it can give its connection back).
    connection_pool_tx: mpsc::Sender<FalkorAsyncConnection>,
    /// Receiving half of the pool channel. Wrapped in a mutex because
    /// `mpsc::Receiver::recv` requires `&mut self`.
    connection_pool_rx: Mutex<mpsc::Receiver<FalkorAsyncConnection>>,
}
impl FalkorAsyncClientInner {
    /// Takes one connection out of the pool, waiting until one becomes available.
    ///
    /// The resulting [`BorrowedAsyncConnection`] receives a clone of the pool's sender and
    /// `pool_owner`, which callers are expected to pass as an `Arc` to this same state.
    ///
    /// # Errors
    /// Returns [`FalkorDBError::EmptyConnection`] when the pool channel yields no
    /// connection (`recv` returned `None`).
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            name = "Borrow Connection From Connection Pool",
            skip_all,
            level = "debug"
        )
    )]
    pub(crate) async fn borrow_connection(
        &self,
        pool_owner: Arc<Self>,
    ) -> FalkorResult<BorrowedAsyncConnection> {
        // Only one task can wait on the receiver at a time.
        let mut receiver = self.connection_pool_rx.lock().await;
        let received = receiver.recv().await;
        // Release the receiver as soon as we have (or failed to get) a connection.
        drop(receiver);

        let connection = received.ok_or(FalkorDBError::EmptyConnection)?;
        let return_tx = self.connection_pool_tx.clone();
        Ok(BorrowedAsyncConnection::new(connection, return_tx, pool_owner))
    }

    /// Opens a brand-new async connection through the shared provider.
    ///
    /// The pool is neither consulted nor modified.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            name = "Get New Async Connection From Client",
            skip_all,
            level = "info"
        )
    )]
    pub(crate) async fn get_async_connection(&self) -> FalkorResult<FalkorAsyncConnection> {
        let mut provider = self._inner.lock().await;
        provider.get_async_connection().await
    }
}
impl ProvidesSyncConnections for FalkorAsyncClientInner {
    /// Opens a new blocking connection while running inside a multi-threaded Tokio runtime.
    ///
    /// # Errors
    /// - [`FalkorDBError::NoRuntime`] when no Tokio runtime is active on this thread.
    /// - [`FalkorDBError::SingleThreadedRuntime`] on a current-thread runtime, where
    ///   `task::block_in_place` cannot be used.
    /// - Any error produced by the provider's own `get_connection`.
    #[cfg_attr(
        feature = "tracing",
        tracing::instrument(
            name = "Get New Sync Connection From Client",
            skip_all,
            level = "info"
        )
    )]
    fn get_connection(&self) -> FalkorResult<FalkorSyncConnection> {
        let Ok(runtime) = Handle::try_current() else {
            return Err(FalkorDBError::NoRuntime);
        };
        if matches!(runtime.runtime_flavor(), RuntimeFlavor::CurrentThread) {
            return Err(FalkorDBError::SingleThreadedRuntime);
        }
        // The provider sits behind an async mutex, so acquire it by blocking this worker.
        // `block_in_place` tells Tokio to shift other tasks off this thread meanwhile.
        task::block_in_place(|| runtime.block_on(self._inner.lock())).get_connection()
    }
}
/// Asynchronous FalkorDB client backed by a pool of connections opened up front.
pub struct FalkorAsyncClient {
    /// Shared provider and pool state; a clone is handed to every `AsyncGraph` created via
    /// `select_graph`.
    inner: Arc<FalkorAsyncClientInner>,
    /// Connection info the client was created with; not read by the methods shown here.
    _connection_info: FalkorConnectionInfo,
}
impl FalkorAsyncClient {
pub(crate) async fn create(
mut client: FalkorClientProvider,
connection_info: FalkorConnectionInfo,
num_connections: u8,
) -> FalkorResult<Self> {
let (connection_pool_tx, connection_pool_rx) = mpsc::channel(num_connections as usize);
for _ in 0..num_connections {
let new_conn = client
.get_async_connection()
.await
.map_err(|err| FalkorDBError::RedisError(err.to_string()))?;
connection_pool_tx
.send(new_conn)
.await
.map_err(|_| FalkorDBError::EmptyConnection)?;
}
Ok(Self {
inner: Arc::new(FalkorAsyncClientInner {
_inner: client.into(),
connection_pool_size: num_connections,
connection_pool_tx,
connection_pool_rx: Mutex::new(connection_pool_rx),
}),
_connection_info: connection_info,
})
}
pub fn connection_pool_size(&self) -> u8 {
self.inner.connection_pool_size
}
pub(crate) async fn borrow_connection(&self) -> FalkorResult<BorrowedAsyncConnection> {
self.inner.borrow_connection(self.inner.clone()).await
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "List Graphs", skip_all, level = "info")
)]
pub async fn list_graphs(&self) -> FalkorResult<Vec<String>> {
self.borrow_connection()
.await?
.execute_command(None, "GRAPH.LIST", None, None)
.await
.and_then(redis_value_as_untyped_string_vec)
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Get Config Value", skip_all, level = "info")
)]
pub async fn config_get(
&self,
config_key: &str,
) -> FalkorResult<HashMap<String, ConfigValue>> {
self.borrow_connection()
.await?
.execute_command(None, "GRAPH.CONFIG", Some("GET"), Some(&[config_key]))
.await
.and_then(parse_config_hashmap)
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Set Config Value", skip_all, level = "info")
)]
pub async fn config_set<C: Into<ConfigValue>>(
&self,
config_key: &str,
value: C,
) -> FalkorResult<redis::Value> {
self.borrow_connection()
.await?
.execute_command(
None,
"GRAPH.CONFIG",
Some("SET"),
Some(&[config_key, value.into().to_string().as_str()]),
)
.await
}
pub fn select_graph<T: ToString>(
&self,
graph_name: T,
) -> AsyncGraph {
AsyncGraph::new(self.inner.clone(), graph_name)
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Copy Graph", skip_all, level = "info")
)]
pub async fn copy_graph(
&self,
graph_to_clone: &str,
new_graph_name: &str,
) -> FalkorResult<AsyncGraph> {
self.borrow_connection()
.await?
.execute_command(
Some(graph_to_clone),
"GRAPH.COPY",
None,
Some(&[new_graph_name]),
)
.await?;
Ok(self.select_graph(new_graph_name))
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Client Get Redis Info", skip_all, level = "info")
)]
pub async fn redis_info(
&self,
section: Option<&str>,
) -> FalkorResult<HashMap<String, String>> {
let mut conn = self.borrow_connection().await?;
let redis_info = conn.as_inner()?.get_redis_info(section).await;
conn.return_to_pool().await;
redis_info
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Load UDF Library", skip_all, level = "info")
)]
pub async fn udf_load(
&self,
name: &str,
script: &str,
replace: bool,
) -> FalkorResult<redis::Value> {
let params = if replace {
vec!["REPLACE", name, script]
} else {
vec![name, script]
};
self.borrow_connection()
.await?
.execute_command(None, "GRAPH.UDF", Some("LOAD"), Some(¶ms))
.await
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "List UDF Libraries", skip_all, level = "info")
)]
pub async fn udf_list(
&self,
lib: Option<&str>,
with_code: bool,
) -> FalkorResult<redis::Value> {
let mut params = Vec::new();
if let Some(library) = lib {
params.push(library);
}
if with_code {
params.push("WITHCODE");
}
let params_slice = if params.is_empty() {
None
} else {
Some(params.as_slice())
};
self.borrow_connection()
.await?
.execute_command(None, "GRAPH.UDF", Some("LIST"), params_slice)
.await
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Flush UDF Libraries", skip_all, level = "info")
)]
pub async fn udf_flush(&self) -> FalkorResult<redis::Value> {
self.borrow_connection()
.await?
.execute_command(None, "GRAPH.UDF", Some("FLUSH"), None)
.await
}
#[cfg_attr(
feature = "tracing",
tracing::instrument(name = "Delete UDF Library", skip_all, level = "info")
)]
pub async fn udf_delete(
&self,
lib: &str,
) -> FalkorResult<redis::Value> {
self.borrow_connection()
.await?
.execute_command(None, "GRAPH.UDF", Some("DELETE"), Some(&[lib]))
.await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        test_utils::{create_async_test_client, TestAsyncGraphHandle},
        FalkorClientBuilder,
    };
    use std::{mem, num::NonZeroU8, thread};
    use tokio::sync::mpsc::error::TryRecvError;

    // NOTE(review): these tests talk to a live server and presumably expect the `imdb`
    // graph to be preloaded — confirm against `create_async_test_client`.

    #[tokio::test(flavor = "multi_thread")]
    async fn test_borrow_connection() {
        let client = FalkorClientBuilder::new_async()
            .with_num_connections(NonZeroU8::new(6).expect("Could not create a perfectly valid u8"))
            .build()
            .await
            .expect("Could not create client for this test");

        // Hold every borrowed connection so the pool is fully drained.
        let mut conn_vec = Vec::with_capacity(6);
        for _ in 0..6 {
            let conn = client.borrow_connection().await;
            assert!(conn.is_ok());
            conn_vec.push(conn);
        }

        let non_existing_conn = client.inner.connection_pool_rx.lock().await.try_recv();
        assert!(
            matches!(non_existing_conn, Err(TryRecvError::Empty)),
            "Got error, but not a TryRecvError::Empty, as expected"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_list_graphs() {
        let client = create_async_test_client().await;
        let graphs = client.list_graphs().await.expect("Could not list graphs");
        assert!(graphs.contains(&"imdb".to_string()));
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_select_graph_and_query() {
        let client = create_async_test_client().await;
        let mut graph = client.select_graph("imdb");
        assert_eq!(graph.graph_name(), "imdb".to_string());
        let res = graph
            .query("MATCH (a:actor) return a")
            .execute()
            .await
            .expect("Could not get actors from unmodified graph");
        assert_eq!(res.data.collect::<Vec<_>>().len(), 1317);
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_copy_graph() {
        let client = create_async_test_client().await;
        // Best-effort cleanup of a copy left behind by a previous run.
        client
            .select_graph("imdb_ro_copy_async")
            .delete()
            .await
            .ok();
        let graph = client.copy_graph("imdb", "imdb_ro_copy_async").await;
        assert!(graph.is_ok());
        let mut graph = TestAsyncGraphHandle {
            inner: graph.unwrap(),
        };
        let mut original_graph = client.select_graph("imdb");
        assert_eq!(
            graph
                .inner
                .query("MATCH (a:actor) RETURN a")
                .execute()
                .await
                .expect("Could not get actors from unmodified graph")
                .data
                .collect::<Vec<_>>(),
            original_graph
                .query("MATCH (a:actor) RETURN a")
                .execute()
                .await
                .expect("Could not get actors from unmodified graph")
                .data
                .collect::<Vec<_>>()
        )
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_config() {
        let client = create_async_test_client().await;
        let config = client
            .config_get("QUERY_MEM_CAPACITY")
            .await
            .expect("Could not get configuration");
        assert_eq!(config.len(), 1);
        assert!(config.contains_key("QUERY_MEM_CAPACITY"));
        assert_eq!(
            mem::discriminant(config.get("QUERY_MEM_CAPACITY").unwrap()),
            mem::discriminant(&ConfigValue::Int64(0))
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_get_config_all() {
        let client = create_async_test_client().await;
        let configuration = client
            .config_get("*")
            .await
            .expect("Could not get configuration");
        assert_eq!(
            configuration.get("THREAD_COUNT").cloned().unwrap(),
            ConfigValue::Int64(thread::available_parallelism().unwrap().get() as i64)
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_set_config() {
        let client = create_async_test_client().await;
        let config = client
            .config_get("MAX_QUEUED_QUERIES")
            .await
            .expect("Could not get configuration");
        let current_val = config
            .get("MAX_QUEUED_QUERIES")
            .cloned()
            .unwrap()
            .as_i64()
            .unwrap();
        let desired_val = if current_val == 4294967295 {
            4294967295 / 2
        } else {
            4294967295
        };
        client
            .config_set("MAX_QUEUED_QUERIES", desired_val)
            .await
            .expect("Could not set config value");
        let new_config = client
            .config_get("MAX_QUEUED_QUERIES")
            .await
            .expect("Could not get configuration");
        assert_eq!(
            new_config
                .get("MAX_QUEUED_QUERIES")
                .cloned()
                .unwrap()
                .as_i64()
                .unwrap(),
            desired_val
        );
        // Restore the original value (best effort).
        client
            .config_set("MAX_QUEUED_QUERIES", current_val)
            .await
            .ok();
    }

    // UDF scripts are plain JavaScript that register functions via `falkor.register`.
    // A `#!` line is only legal at the very start of a JS source, and the `redis.*`
    // Triggers & Functions API is not FalkorDB's UDF API.

    #[tokio::test(flavor = "multi_thread")]
    async fn test_udf_operations() {
        let client = create_async_test_client().await;
        let script = r#"
function my_func(a, b) {
    return a + b;
}
falkor.register('my_func', my_func);
"#;
        let result = client.udf_load("mylib_async", script, false).await;
        assert!(result.is_ok(), "Failed to load UDF library: {:?}", result);
        let list_result = client.udf_list(None, false).await;
        assert!(list_result.is_ok(), "Failed to list UDF libraries");
        let list_with_code = client.udf_list(Some("mylib_async"), true).await;
        assert!(
            list_with_code.is_ok(),
            "Failed to list UDF library with code"
        );
        let delete_result = client.udf_delete("mylib_async").await;
        assert!(delete_result.is_ok(), "Failed to delete UDF library");
        let list_after_delete = client.udf_list(None, false).await;
        assert!(
            list_after_delete.is_ok(),
            "Failed to list UDF libraries after delete"
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_udf_load_replace() {
        let client = create_async_test_client().await;
        let script = r#"
function func1(x) {
    return x * 2;
}
falkor.register('func1', func1);
"#;
        let result = client.udf_load("replacelib_async", script, false).await;
        assert!(result.is_ok(), "Failed to load UDF library");
        let updated_script = r#"
function func1(x) {
    return x * 3;
}
falkor.register('func1', func1);
"#;
        let replace_result = client
            .udf_load("replacelib_async", updated_script, true)
            .await;
        assert!(replace_result.is_ok(), "Failed to replace UDF library");
        client.udf_delete("replacelib_async").await.ok();
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_udf_flush() {
        // NOTE(review): FLUSH removes every UDF library on the server, so it can race with
        // the other UDF tests when they run in parallel.
        let client = create_async_test_client().await;
        let script = r#"
function test_func() {
    return 42;
}
falkor.register('test_func', test_func);
"#;
        client.udf_load("flushlib_async", script, false).await.ok();
        let flush_result = client.udf_flush().await;
        assert!(flush_result.is_ok(), "Failed to flush UDF libraries");
        let list_after_flush = client.udf_list(None, false).await;
        assert!(
            list_after_flush.is_ok(),
            "Failed to list UDF libraries after flush"
        );
    }
}