use tonic::transport::{Channel, ClientTlsConfig};
use tonic::metadata::MetadataValue;
use std::convert::TryFrom;
use tonic::Request;
use crate::config::ClientConfig;
use crate::error::{ClientError, Result};
use vectordb_proto::vector_db_service_client::VectorDbServiceClient;
use vectordb_proto::vectordb::{
Empty, ServerInfo, StatusResponse, CreateCollectionRequest,
CollectionRequest, ListCollectionsResponse, QueryRequest,
QueryResponse, IndexFilesRequest, IndexResponse,
AddRepositoryRequest, RepositoryRequest, RemoveRepositoryRequest,
SyncRepositoryRequest, UseBranchRequest, ListRepositoriesResponse,
};
pub struct VectorDBClient {
client: VectorDbServiceClient<Channel>,
config: ClientConfig,
}
impl VectorDBClient {
pub async fn new(config: ClientConfig) -> Result<Self> {
let client = Self::create_client(&config).await?;
Ok(Self { client, config })
}
pub async fn default() -> Result<Self> {
Self::new(ClientConfig::default()).await
}
pub async fn connect<S: Into<String>>(address: S) -> Result<Self> {
let config = ClientConfig::new(address);
Self::new(config).await
}
async fn create_client(config: &ClientConfig) -> Result<VectorDbServiceClient<Channel>> {
let channel = if config.use_tls {
let tls_config = if let Some(ca_cert_path) = &config.ca_cert_path {
let ca_cert = tokio::fs::read(ca_cert_path).await
.map_err(|e| ClientError::Configuration(format!("Failed to read CA certificate: {}", e)))?;
ClientTlsConfig::new()
.ca_certificate(tonic::transport::Certificate::from_pem(ca_cert))
.domain_name(Self::extract_domain(&config.server_address)?)
} else {
ClientTlsConfig::new()
.domain_name(Self::extract_domain(&config.server_address)?)
};
Channel::from_shared(config.server_address.clone())
.map_err(|e| ClientError::Configuration(format!("Invalid server address: {}", e)))?
.tls_config(tls_config)
.map_err(|e| ClientError::Configuration(format!("TLS configuration error: {}", e)))?
.connect()
.await?
} else {
Channel::from_shared(config.server_address.clone())
.map_err(|e| ClientError::Configuration(format!("Invalid server address: {}", e)))?
.connect()
.await?
};
let client = VectorDbServiceClient::new(channel);
Ok(client)
}
fn extract_domain(address: &str) -> Result<String> {
let parts: Vec<&str> = address.split("://").collect();
let host_part = if parts.len() > 1 {
parts[1]
} else {
parts[0]
};
let host = host_part.split(':').next().unwrap_or(host_part);
Ok(host.to_string())
}
fn prepare_request<T>(&self, request: Request<T>) -> Request<T> {
if let Some(api_key) = &self.config.api_key {
if let Ok(value) = MetadataValue::try_from(api_key.as_str()) {
let mut req = request;
req.metadata_mut().insert("x-api-key", value);
return req;
}
}
request
}
pub async fn get_server_info(&mut self) -> Result<ServerInfo> {
let request = self.prepare_request(Request::new(Empty {}));
let response = self.client.get_server_info(request).await?;
Ok(response.into_inner())
}
pub async fn create_collection(
&mut self,
name: String,
vector_size: i32,
distance: String
) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(CreateCollectionRequest {
name,
vector_size,
distance,
}));
let response = self.client.create_collection(request).await?;
Ok(response.into_inner())
}
pub async fn list_collections(&mut self) -> Result<ListCollectionsResponse> {
let request = self.prepare_request(Request::new(Empty {}));
let response = self.client.list_collections(request).await?;
Ok(response.into_inner())
}
pub async fn delete_collection(&mut self, name: String) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(CollectionRequest {
name,
}));
let response = self.client.delete_collection(request).await?;
Ok(response.into_inner())
}
pub async fn clear_collection(&mut self, name: String) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(CollectionRequest {
name,
}));
let response = self.client.clear_collection(request).await?;
Ok(response.into_inner())
}
pub async fn index_files(
&mut self,
collection_name: String,
paths: Vec<String>,
extensions: Vec<String>,
) -> Result<IndexResponse> {
let request = self.prepare_request(Request::new(IndexFilesRequest {
collection_name,
paths,
extensions,
}));
let response = self.client.index_files(request).await?;
Ok(response.into_inner())
}
pub async fn query_collection(
&mut self,
collection_name: String,
query_text: String,
limit: i32,
language: Option<String>,
element_type: Option<String>,
) -> Result<QueryResponse> {
let request = self.prepare_request(Request::new(QueryRequest {
collection_name,
query_text,
limit,
language,
element_type,
}));
let response = self.client.query_collection(request).await?;
Ok(response.into_inner())
}
pub async fn add_repository(
&mut self,
url: String,
local_path: Option<String>,
name: Option<String>,
branch: Option<String>,
remote: Option<String>,
ssh_key_path: Option<String>,
ssh_passphrase: Option<String>,
) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(AddRepositoryRequest {
url,
local_path,
name,
branch,
remote,
ssh_key_path,
ssh_passphrase,
}));
let response = self.client.add_repository(request).await?;
Ok(response.into_inner())
}
pub async fn list_repositories(&mut self) -> Result<ListRepositoriesResponse> {
let request = self.prepare_request(Request::new(Empty {}));
let response = self.client.list_repositories(request).await?;
Ok(response.into_inner())
}
pub async fn use_repository(&mut self, name: String) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(RepositoryRequest {
name,
}));
let response = self.client.use_repository(request).await?;
Ok(response.into_inner())
}
pub async fn remove_repository(
&mut self,
name: String,
skip_confirmation: bool,
) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(RemoveRepositoryRequest {
name,
skip_confirmation,
}));
let response = self.client.remove_repository(request).await?;
Ok(response.into_inner())
}
pub async fn sync_repository(
&mut self,
name: Option<String>,
extensions: Vec<String>,
force: bool,
) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(SyncRepositoryRequest {
name,
extensions,
force,
}));
let response = self.client.sync_repository(request).await?;
Ok(response.into_inner())
}
pub async fn use_branch(
&mut self,
branch_name: String,
repository_name: Option<String>,
) -> Result<StatusResponse> {
let request = self.prepare_request(Request::new(UseBranchRequest {
branch_name,
repository_name,
}));
let response = self.client.use_branch(request).await?;
Ok(response.into_inner())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tonic::metadata::MetadataMap;

    /// Builds a client without touching the network, so `prepare_request` can be tested
    /// in isolation.
    ///
    /// NOTE: creating a tonic `Channel`, even lazily, spawns the channel's worker task.
    /// Tests using this helper must therefore run inside a Tokio runtime, hence
    /// `#[tokio::test]`.
    fn client_with_api_key(api_key: Option<&str>) -> VectorDBClient {
        VectorDBClient {
            // `Channel::from_static` returns an `Endpoint`. `connect_lazy` turns it into
            // the `Channel` the client field is typed over, without connecting.
            client: VectorDbServiceClient::new(
                Channel::from_static("http://[::1]:50051").connect_lazy(),
            ),
            config: ClientConfig {
                server_address: "http://localhost:50051".to_string(),
                use_tls: false,
                api_key: api_key.map(str::to_string),
                ca_cert_path: None,
            },
        }
    }

    #[test]
    fn test_extract_domain_http() {
        let result = VectorDBClient::extract_domain("http://localhost:50051").unwrap();
        assert_eq!(result, "localhost");
    }

    #[test]
    fn test_extract_domain_https() {
        let result = VectorDBClient::extract_domain("https://example.com:8080").unwrap();
        assert_eq!(result, "example.com");
    }

    #[test]
    fn test_extract_domain_no_protocol() {
        let result = VectorDBClient::extract_domain("127.0.0.1:50051").unwrap();
        assert_eq!(result, "127.0.0.1");
    }

    #[test]
    fn test_extract_domain_no_port() {
        let result = VectorDBClient::extract_domain("https://api.example.com").unwrap();
        assert_eq!(result, "api.example.com");
    }

    #[tokio::test]
    async fn test_prepare_request_with_api_key() {
        let client = client_with_api_key(Some("test-api-key"));
        let prepared = client.prepare_request(Request::new(Empty {}));
        let metadata: &MetadataMap = prepared.metadata();
        assert!(metadata.contains_key("x-api-key"));
        assert_eq!(
            metadata.get("x-api-key").unwrap().to_str().unwrap(),
            "test-api-key"
        );
    }

    #[tokio::test]
    async fn test_prepare_request_without_api_key() {
        let client = client_with_api_key(None);
        let prepared = client.prepare_request(Request::new(Empty {}));
        let metadata: &MetadataMap = prepared.metadata();
        assert!(!metadata.contains_key("x-api-key"));
    }

    #[tokio::test]
    async fn test_prepare_request_skips_invalid_api_key() {
        // A newline is not allowed in an ASCII metadata value, so the header is omitted
        // rather than sent malformed.
        let client = client_with_api_key(Some("bad\nkey"));
        let prepared = client.prepare_request(Request::new(Empty {}));
        assert!(!prepared.metadata().contains_key("x-api-key"));
    }
}