use crate::client::Client;
use crate::error::Result;
use crate::models::{
AddCollectionRequest, EnableMetadataStoreRequest, EnableMetadataStoreResponse, GenericResponse,
GetCollectionDataResponse, GetCollectionModelResponse, GetCollectionSchemaResponse,
InsertRecordRequest, InsertRecordResponse, ListCollectionsModelsResponse,
ListCollectionsResponse, UpdateModelsEvent,
};
use futures_util::StreamExt;
use reqwest::Response;
use std::collections::HashMap;
impl Client {
pub async fn list_collections(&self) -> Result<ListCollectionsResponse> {
self.do_request::<ListCollectionsResponse, ()>(
reqwest::Method::GET,
"/api/collections/v1/",
None,
None,
)
.await
}
pub async fn add_collection(&self, req: &AddCollectionRequest) -> Result<GenericResponse> {
self.do_request(
reqwest::Method::POST,
"/api/collections/v1/",
Some(req),
None,
)
.await
}
pub async fn delete_record(&self, collection_name: &str, id: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/{}", collection_name, id);
self.do_request::<GenericResponse, ()>(reqwest::Method::DELETE, &path, None, None)
.await
}
pub async fn expiry_cleanup(&self, collection_name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/expiry-cleanup", collection_name);
self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
.await
}
pub async fn drop_collection(&self, name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}", name);
self.do_request::<GenericResponse, ()>(reqwest::Method::DELETE, &path, None, None)
.await
}
pub async fn flush_collection(&self, name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/flush", name);
self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
.await
}
pub async fn load_collection(&self, name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/load", name);
self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
.await
}
pub async fn unload_collection(&self, name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/unload", name);
self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
.await
}
pub async fn export_collection(&self, name: &str) -> Result<Response> {
let path = format!("/api/collections/v1/{}/export", name);
self.do_request_with_file_response(reqwest::Method::POST, &path, None)
.await
}
pub async fn import_collection(&self, file_path: &std::path::Path) -> Result<()> {
self.do_file_request(
reqwest::Method::POST,
"/api/collections/v1/import",
file_path,
)
.await
}
pub async fn rename_collection(
&self,
old_name: &str,
new_name: &str,
) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/rename/{}", old_name, new_name);
self.do_request::<GenericResponse, ()>(reqwest::Method::PUT, &path, None, None)
.await
}
pub async fn reindex_collection(&self, collection_name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/reindex", collection_name);
self.do_request::<GenericResponse, ()>(reqwest::Method::PUT, &path, None, None)
.await
}
pub async fn pq_train(&self, collection_name: &str) -> Result<GenericResponse> {
let path = format!("/api/collections/v1/{}/pq-train", collection_name);
self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
.await
}
pub async fn insert_record(&self, req: &InsertRecordRequest) -> Result<InsertRecordResponse> {
self.do_request(
reqwest::Method::POST,
"/api/collections/v1/record",
Some(req),
None,
)
.await
}
pub async fn get_collection_data(
&self,
collection_name: &str,
offset: i32,
limit: i32,
) -> Result<GetCollectionDataResponse> {
let path = format!(
"/api/collections/v1/{}/data?offset={}&limit={}",
collection_name, offset, limit
);
self.do_request::<GetCollectionDataResponse, ()>(reqwest::Method::GET, &path, None, None)
.await
}
pub async fn enable_nli(&self, collection: &str, vertical: &str) -> Result<Response> {
let mut params = HashMap::new();
params.insert("vertical".to_string(), vertical.to_string());
let path = format!("/api/collections/v1/{}/nli/enable", collection);
self.do_request_with_file_response(reqwest::Method::GET, &path, Some(¶ms))
.await
}
pub async fn get_collection_schema(
&self,
collection_name: &str,
) -> Result<GetCollectionSchemaResponse> {
let path = format!("/api/collections/v1/{}/schema", collection_name);
self.do_request::<GetCollectionSchemaResponse, ()>(reqwest::Method::GET, &path, None, None)
.await
}
pub async fn enable_metadata_store(
&self,
collection_name: &str,
req: &EnableMetadataStoreRequest,
) -> Result<EnableMetadataStoreResponse> {
let path = format!("/api/collections/v1/{}/metadata/enable", collection_name);
self.do_request(reqwest::Method::POST, &path, Some(req), None)
.await
}
pub async fn list_collection_models(&self) -> Result<ListCollectionsModelsResponse> {
self.do_request::<ListCollectionsModelsResponse, ()>(
reqwest::Method::GET,
"/api/collections/v1/models",
None,
None,
)
.await
}
pub async fn get_collection_model_info(
&self,
collection_name: &str,
model_id: &str,
) -> Result<GetCollectionModelResponse> {
let path = format!(
"/api/collections/v1/{}/models/{}",
collection_name, model_id
);
self.do_request::<GetCollectionModelResponse, ()>(reqwest::Method::GET, &path, None, None)
.await
}
pub async fn update_collection_model(
&self,
collection_name: &str,
) -> Result<impl futures_util::Stream<Item = Result<UpdateModelsEvent>>> {
let url = format!(
"{}/api/collections/v1/{}/models/update",
self.base_url, collection_name
);
let mut request = self.http_client.request(reqwest::Method::POST, &url);
if let Some(token) = &self.auth_token {
request = request.bearer_auth(token);
}
let response = request.send().await?;
if response.status().is_client_error() || response.status().is_server_error() {
let status = response.status().as_u16();
let message = response.text().await.unwrap_or_default();
return Err(crate::error::ShilpError::ApiError { message, status });
}
use tokio_util::codec::{FramedRead, LinesCodec};
let stream_reader =
tokio_util::io::StreamReader::new(response.bytes_stream().map(|result| {
result.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
}));
let stream = FramedRead::new(stream_reader, LinesCodec::new()).map(|result| {
result
.map_err(|e| {
crate::error::ShilpError::IoError(std::io::Error::new(
std::io::ErrorKind::Other,
e,
))
})
.and_then(|line| {
serde_json::from_str::<UpdateModelsEvent>(&line)
.map_err(|e| crate::error::ShilpError::from(e))
})
});
Ok(stream)
}
}