use reqwest::{Client, Response};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::time::Duration;
use url::Url;
use crate::database::Database;
use crate::enums::ReadConsistency;
use crate::error::{Result, VectorDBError};
/// Async HTTP client for the vector database service.
///
/// Holds the pooled [`reqwest::Client`], the service base URL, the credentials that
/// are attached to every request, and the read-consistency setting exposed through
/// [`VectorDBClient::read_consistency`].
///
/// `Debug` is implemented by hand (instead of derived) so that the API key is never
/// written to logs or panic messages.
#[derive(Clone)]
pub struct VectorDBClient {
    client: Client,
    base_url: Url,
    username: String,
    key: String,
    read_consistency: ReadConsistency,
}

impl std::fmt::Debug for VectorDBClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VectorDBClient")
            .field("client", &self.client)
            .field("base_url", &self.base_url)
            .field("username", &self.username)
            // Secret: never print the real value.
            .field("key", &"<redacted>")
            .field("read_consistency", &self.read_consistency)
            .finish()
    }
}
/// Credential payload (`username` + `key`) serialized for an authentication request.
///
/// `Debug` is implemented by hand so the key is redacted.
// NOTE(review): not constructed anywhere in this file, hence `dead_code` is allowed;
// presumably reserved for a token-based auth flow — confirm or remove.
#[derive(Serialize)]
#[allow(dead_code)]
struct AuthRequest {
    username: String,
    key: String,
}

impl std::fmt::Debug for AuthRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthRequest")
            .field("username", &self.username)
            // Secret: never print the real value.
            .field("key", &"<redacted>")
            .finish()
    }
}
/// Reply to an authentication request: status `code`, message, and optional token.
///
/// `token` deserializes to `None` when the field is absent (`#[serde(default)]`).
/// `Debug` is implemented by hand so that a returned token is redacted while its
/// presence (`Some`/`None`) stays visible.
// NOTE(review): not used anywhere in this file, hence `dead_code` is allowed;
// confirm whether it is still needed.
#[derive(Deserialize)]
#[allow(dead_code)]
struct AuthResponse {
    code: i32,
    msg: String,
    #[serde(default)]
    token: Option<String>,
}

impl std::fmt::Debug for AuthResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthResponse")
            .field("code", &self.code)
            .field("msg", &self.msg)
            // Secret: show only whether a token was returned.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Generic JSON response envelope: a status `code`, a message `msg`, and an optional
/// payload `data` of type `T`.
///
/// `data` deserializes to `None` when the field is absent (`#[serde(default)]`).
// NOTE(review): presumably `code == 0` means success, as in the check performed by
// `list_databases`; this type is not referenced in the code visible here — confirm usage.
#[derive(Debug, Deserialize)]
struct ApiResponse<T> {
code: i32,
msg: String,
#[serde(default)]
data: Option<T>,
}
impl VectorDBClient {
pub fn new(
url: impl AsRef<str>,
username: impl Into<String>,
key: impl Into<String>,
read_consistency: ReadConsistency,
timeout_secs: u64,
) -> Result<Self> {
let base_url = Url::parse(url.as_ref())?;
let client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.connect_timeout(Duration::from_secs(10))
.tcp_keepalive(Duration::from_secs(60))
.tcp_nodelay(true)
.http2_prior_knowledge() .pool_idle_timeout(Duration::from_secs(30))
.pool_max_idle_per_host(10)
.user_agent("tcvectordb-rust/0.1.0")
.build()?;
Ok(Self {
client,
base_url,
username: username.into(),
key: key.into(),
read_consistency,
})
}
pub fn with_client(
client: Client,
url: impl AsRef<str>,
username: impl Into<String>,
key: impl Into<String>,
read_consistency: ReadConsistency,
) -> Result<Self> {
let base_url = Url::parse(url.as_ref())?;
Ok(Self {
client,
base_url,
username: username.into(),
key: key.into(),
read_consistency,
})
}
pub fn http_client(&self) -> &Client {
&self.client
}
pub fn base_url(&self) -> &Url {
&self.base_url
}
pub fn username(&self) -> &str {
&self.username
}
pub fn key(&self) -> &str {
&self.key
}
pub fn read_consistency(&self) -> &ReadConsistency {
&self.read_consistency
}
pub async fn exists_database(&self, database_name: &str) -> Result<bool> {
let databases = self.list_databases().await?;
Ok(databases.iter().any(|db| db.name() == database_name))
}
pub async fn create_database(&self, database_name: impl Into<String>) -> Result<Database> {
let database_name = database_name.into();
let db = Database::new(self.clone(), database_name.clone());
db.create().await?;
Ok(db)
}
pub async fn create_database_if_not_exists(&self, database_name: impl Into<String>) -> Result<Database> {
let database_name = database_name.into();
if self.exists_database(&database_name).await? {
Ok(Database::new(self.clone(), database_name))
} else {
self.create_database(database_name).await
}
}
pub async fn drop_database(&self, database_name: impl Into<String>) -> Result<Value> {
let database_name = database_name.into();
let db = Database::new(self.clone(), database_name);
db.drop().await
}
pub async fn list_databases(&self) -> Result<Vec<Database>> {
let response = self.get("/database/list")
.await?
.send()
.await?;
let response_text = response.text().await?;
let response_json: serde_json::Value = serde_json::from_str(&response_text)?;
if let Some(code) = response_json.get("code").and_then(|v| v.as_i64()) {
if code != 0 {
let msg = response_json.get("msg")
.and_then(|v| v.as_str())
.unwrap_or("Unknown error");
return Err(VectorDBError::server_error(code as i32, msg));
}
}
let database_names: Vec<String> = if let Some(databases) = response_json.get("databases") {
serde_json::from_value(databases.clone())?
} else if let Some(data) = response_json.get("data") {
serde_json::from_value(data.clone())?
} else {
vec![]
};
Ok(database_names
.into_iter()
.map(|name| Database::new(self.clone(), name))
.collect())
}
pub async fn database(&self, database_name: impl Into<String>) -> Result<Database> {
let database_name = database_name.into();
if self.exists_database(&database_name).await? {
Ok(Database::new(self.clone(), database_name))
} else {
Err(VectorDBError::param_error(
14100,
format!("Database not exist: {}", database_name),
))
}
}
pub(crate) async fn handle_response<T>(&self, response: Response) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
let status = response.status();
let text = response.text().await?;
if !status.is_success() {
return Err(VectorDBError::connect_error(
status.as_u16() as i32,
format!("HTTP error {}: {}", status, text),
));
}
serde_json::from_str(&text).map_err(|e| {
VectorDBError::unexpected_error(format!("Failed to parse response: {}", e))
})
}
pub(crate) async fn request(&self, method: reqwest::Method, path: &str) -> Result<reqwest::RequestBuilder> {
let url = self.base_url.join(path)?;
Ok(self.client
.request(method, url)
.header("Authorization", format!("Bearer account={}&api_key={}", self.username, self.key))
.header("Content-Type", "application/json"))
}
pub(crate) async fn get(&self, path: &str) -> Result<reqwest::RequestBuilder> {
self.request(reqwest::Method::GET, path).await
}
pub(crate) async fn post(&self, path: &str) -> Result<reqwest::RequestBuilder> {
self.request(reqwest::Method::POST, path).await
}
#[allow(dead_code)]
pub(crate) async fn put(&self, path: &str) -> Result<reqwest::RequestBuilder> {
self.request(reqwest::Method::PUT, path).await
}
pub(crate) async fn delete(&self, path: &str) -> Result<reqwest::RequestBuilder> {
self.request(reqwest::Method::DELETE, path).await
}
}