Skip to main content

shilp_sdk/
collections.rs

1use crate::client::Client;
2use crate::error::Result;
3use crate::models::{
4    AddCollectionRequest, EnableMetadataStoreRequest, EnableMetadataStoreResponse, GenericResponse,
5    GetCollectionDataResponse, GetCollectionModelResponse, GetCollectionSchemaResponse,
6    InsertRecordRequest, InsertRecordResponse, ListCollectionsModelsResponse,
7    ListCollectionsResponse, UpdateModelsEvent,
8};
9use futures_util::StreamExt;
10use reqwest::Response;
11use std::collections::HashMap;
12
13impl Client {
14    /// Lists all collections
15    pub async fn list_collections(&self) -> Result<ListCollectionsResponse> {
16        self.do_request::<ListCollectionsResponse, ()>(
17            reqwest::Method::GET,
18            "/api/collections/v1/",
19            None,
20            None,
21        )
22        .await
23    }
24
25    /// Adds a new collection
26    pub async fn add_collection(&self, req: &AddCollectionRequest) -> Result<GenericResponse> {
27        self.do_request(
28            reqwest::Method::POST,
29            "/api/collections/v1/",
30            Some(req),
31            None,
32        )
33        .await
34    }
35
36    /// Deletes a record from a collection
37    pub async fn delete_record(&self, collection_name: &str, id: &str) -> Result<GenericResponse> {
38        let path = format!("/api/collections/v1/{}/{}", collection_name, id);
39        self.do_request::<GenericResponse, ()>(reqwest::Method::DELETE, &path, None, None)
40            .await
41    }
42
43    /// Performs expiry cleanup on a collection
44    pub async fn expiry_cleanup(&self, collection_name: &str) -> Result<GenericResponse> {
45        let path = format!("/api/collections/v1/{}/expiry-cleanup", collection_name);
46        self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
47            .await
48    }
49
50    /// Drops an existing collection
51    pub async fn drop_collection(&self, name: &str) -> Result<GenericResponse> {
52        let path = format!("/api/collections/v1/{}", name);
53        self.do_request::<GenericResponse, ()>(reqwest::Method::DELETE, &path, None, None)
54            .await
55    }
56
57    /// Flushes a collection to disk
58    pub async fn flush_collection(&self, name: &str) -> Result<GenericResponse> {
59        let path = format!("/api/collections/v1/{}/flush", name);
60        self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
61            .await
62    }
63
64    /// Loads a collection into memory
65    pub async fn load_collection(&self, name: &str) -> Result<GenericResponse> {
66        let path = format!("/api/collections/v1/{}/load", name);
67        self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
68            .await
69    }
70
71    /// Unloads a collection from memory
72    pub async fn unload_collection(&self, name: &str) -> Result<GenericResponse> {
73        let path = format!("/api/collections/v1/{}/unload", name);
74        self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
75            .await
76    }
77
78    /// Exports a collection and returns a Response for downloading the file
79    /// The caller is responsible for processing the response (e.g., saving to a file)
80    pub async fn export_collection(&self, name: &str) -> Result<Response> {
81        let path = format!("/api/collections/v1/{}/export", name);
82        self.do_request_with_file_response(reqwest::Method::POST, &path, None)
83            .await
84    }
85
86    /// Imports a collection from a file
87    pub async fn import_collection(&self, file_path: &std::path::Path) -> Result<()> {
88        self.do_file_request(
89            reqwest::Method::POST,
90            "/api/collections/v1/import",
91            file_path,
92        )
93        .await
94    }
95
96    /// Renames an existing collection
97    pub async fn rename_collection(
98        &self,
99        old_name: &str,
100        new_name: &str,
101    ) -> Result<GenericResponse> {
102        let path = format!("/api/collections/v1/{}/rename/{}", old_name, new_name);
103        self.do_request::<GenericResponse, ()>(reqwest::Method::PUT, &path, None, None)
104            .await
105    }
106
107    /// Re-indexes a collection for debug purposes
108    pub async fn reindex_collection(&self, collection_name: &str) -> Result<GenericResponse> {
109        let path = format!("/api/collections/v1/{}/reindex", collection_name);
110        self.do_request::<GenericResponse, ()>(reqwest::Method::PUT, &path, None, None)
111            .await
112    }
113
114    /// Performs Product Quantization training for an existing collection
115    pub async fn pq_train(&self, collection_name: &str) -> Result<GenericResponse> {
116        let path = format!("/api/collections/v1/{}/pq-train", collection_name);
117        self.do_request::<GenericResponse, ()>(reqwest::Method::POST, &path, None, None)
118            .await
119    }
120
121    /// Inserts a new record into a collection
122    pub async fn insert_record(&self, req: &InsertRecordRequest) -> Result<InsertRecordResponse> {
123        self.do_request(
124            reqwest::Method::POST,
125            "/api/collections/v1/record",
126            Some(req),
127            None,
128        )
129        .await
130    }
131
132    /// Gets paginated data records from a collection
133    pub async fn get_collection_data(
134        &self,
135        collection_name: &str,
136        offset: i32,
137        limit: i32,
138    ) -> Result<GetCollectionDataResponse> {
139        let path = format!(
140            "/api/collections/v1/{}/data?offset={}&limit={}",
141            collection_name, offset, limit
142        );
143        self.do_request::<GetCollectionDataResponse, ()>(reqwest::Method::GET, &path, None, None)
144            .await
145    }
146
147    /// Enables Natural Language Inference for a collection and vertical.
148    /// This is an SSE endpoint that streams the progress of enabling NLI.
149    /// The `vertical` parameter specifies the NLI provider vertical; pass an empty string
150    /// for a custom vertical.
151    /// Returns the raw streaming response; the caller is responsible for reading the SSE events.
152    pub async fn enable_nli(&self, collection: &str, vertical: &str) -> Result<Response> {
153        let mut params = HashMap::new();
154        params.insert("vertical".to_string(), vertical.to_string());
155        let path = format!("/api/collections/v1/{}/nli/enable", collection);
156        self.do_request_with_file_response(reqwest::Method::GET, &path, Some(&params))
157            .await
158    }
159
160    /// Gets the schema for a collection
161    pub async fn get_collection_schema(
162        &self,
163        collection_name: &str,
164    ) -> Result<GetCollectionSchemaResponse> {
165        let path = format!("/api/collections/v1/{}/schema", collection_name);
166        self.do_request::<GetCollectionSchemaResponse, ()>(reqwest::Method::GET, &path, None, None)
167            .await
168    }
169
170    /// Enables metadata store for an existing collection
171    pub async fn enable_metadata_store(
172        &self,
173        collection_name: &str,
174        req: &EnableMetadataStoreRequest,
175    ) -> Result<EnableMetadataStoreResponse> {
176        let path = format!("/api/collections/v1/{}/metadata/enable", collection_name);
177        self.do_request(reqwest::Method::POST, &path, Some(req), None)
178            .await
179    }
180
181    /// Lists all collection models
182    pub async fn list_collection_models(&self) -> Result<ListCollectionsModelsResponse> {
183        self.do_request::<ListCollectionsModelsResponse, ()>(
184            reqwest::Method::GET,
185            "/api/collections/v1/models",
186            None,
187            None,
188        )
189        .await
190    }
191
192    /// Gets information about a specific collection model
193    pub async fn get_collection_model_info(
194        &self,
195        collection_name: &str,
196        model_id: &str,
197    ) -> Result<GetCollectionModelResponse> {
198        let path = format!(
199            "/api/collections/v1/{}/models/{}",
200            collection_name, model_id
201        );
202        self.do_request::<GetCollectionModelResponse, ()>(reqwest::Method::GET, &path, None, None)
203            .await
204    }
205
206    /// Updates collection models with streaming progress updates.
207    /// Returns a stream of UpdateModelsEvent.
208    /// The stream sends events line by line until completion or error.
209    /// Each event is a newline-delimited JSON object.
210    pub async fn update_collection_model(
211        &self,
212        collection_name: &str,
213    ) -> Result<impl futures_util::Stream<Item = Result<UpdateModelsEvent>>> {
214        let url = format!(
215            "{}/api/collections/v1/{}/models/update",
216            self.base_url, collection_name
217        );
218        let mut request = self.http_client.request(reqwest::Method::POST, &url);
219
220        if let Some(token) = &self.auth_token {
221            request = request.bearer_auth(token);
222        }
223
224        let response = request.send().await?;
225
226        if response.status().is_client_error() || response.status().is_server_error() {
227            let status = response.status().as_u16();
228            let message = response.text().await.unwrap_or_default();
229            return Err(crate::error::ShilpError::ApiError { message, status });
230        }
231
232        // Use a codec to split the stream by newlines and parse each line as JSON
233        use tokio_util::codec::{FramedRead, LinesCodec};
234
235        let stream_reader =
236            tokio_util::io::StreamReader::new(response.bytes_stream().map(|result| {
237                result.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
238            }));
239
240        let stream = FramedRead::new(stream_reader, LinesCodec::new()).map(|result| {
241            result
242                .map_err(|e| {
243                    crate::error::ShilpError::IoError(std::io::Error::new(
244                        std::io::ErrorKind::Other,
245                        e,
246                    ))
247                })
248                .and_then(|line| {
249                    serde_json::from_str::<UpdateModelsEvent>(&line)
250                        .map_err(|e| crate::error::ShilpError::from(e))
251                })
252        });
253
254        Ok(stream)
255    }
256}