casper_client/
client.rs

1use crate::error::{CasperError, Result};
2use crate::models::*;
3use crate::grpc::service::matrix_service::{
4    matrix_service_client::MatrixServiceClient,
5    upload_matrix_request, MatrixData, MatrixHeader, UploadMatrixRequest,
6};
7use reqwest::Client;
8use std::time::Duration;
9use tokio_stream::wrappers::ReceiverStream;
10use tonic::Request;
11use url::Url;
12
/// Casper vector database client.
///
/// Holds a `reqwest` HTTP client for the REST API plus the address used to open
/// gRPC connections (see `upload_matrix`).
#[derive(Debug, Clone)]
pub struct CasperClient {
    /// HTTP client; its request timeout is fixed when the client is constructed.
    client: Client,
    /// Root URL of the HTTP API; endpoint URLs are built from it.
    base_url: Url,
    /// `host:port` string handed to the gRPC client when connecting.
    grpc_addr: String,
}
20
21impl CasperClient {
22    /// Create a new Casper client
23    ///
24    /// - `host`: hostname or IP of the Casper server (e.g. "127.0.0.1")
25    /// - `http_port`: HTTP API port (e.g. 8080)
26    /// - `grpc_port`: gRPC API port (e.g. 50051)
27    pub fn new(host: &str, http_port: u16, grpc_port: u16) -> Result<Self> {
28        let base_url_str = format!("{}:{}", host, http_port);
29        let base_url = Url::parse(&base_url_str)?;
30        let client = Client::builder()
31            .timeout(Duration::from_secs(30))
32            .build()?;
33        
34        let grpc_addr = format!("{}:{}", host, grpc_port);
35
36        Ok(Self { client, base_url, grpc_addr })
37    }
38
39    /// Create a new Casper client with custom timeout
40    ///
41    /// - `host`: hostname or IP of the Casper server (e.g. "127.0.0.1")
42    /// - `http_port`: HTTP API port (e.g. 8080)
43    /// - `grpc_port`: gRPC API port (e.g. 50051)
44    pub fn with_timeout(host: &str, http_port: u16, grpc_port: u16, timeout: Duration) -> Result<Self> {
45        let base_url_str = format!("{}:{}", host, http_port);
46        let base_url = Url::parse(&base_url_str)?;
47        let client = Client::builder()
48            .timeout(timeout)
49            .build()?;
50        
51        let grpc_addr = format!("{}:{}", host, grpc_port);
52
53        Ok(Self { client, base_url, grpc_addr })
54    }
55
56    /// Get the base URL
57    pub fn base_url(&self) -> &str {
58        self.base_url.as_str()
59    }
60
61    /// Get the gRPC address
62    pub fn grpc_addr(&self) -> &str {
63        &self.grpc_addr
64    }
65
66    /// List all collections
67    pub async fn list_collections(&self) -> Result<CollectionsListResponse> {
68        let url = self.base_url.join("collections")?;
69        let response = self.client.get(url).send().await?;
70        
71        self.handle_response(response).await
72    }
73
74    /// Get collection information
75    pub async fn get_collection(&self, collection_name: &str) -> Result<CollectionInfo> {
76        let url = self.base_url.join(&format!("collection/{}", collection_name))?;
77        let response = self.client.get(url).send().await?;
78        
79        self.handle_response(response).await
80    }
81
82    /// Create a new collection
83    pub async fn create_collection(
84        &self,
85        collection_name: &str,
86        request: CreateCollectionRequest,
87    ) -> Result<()> {
88        let url = self.base_url.join(&format!("collection/{}", collection_name))?;
89        let response = self
90            .client
91            .post(url)
92            .query(&request)
93            .header("Content-Type", "application/json")
94            .send()
95            .await?;
96        
97        self.handle_empty_response(response).await
98    }
99
100    /// Delete a collection
101    pub async fn delete_collection(&self, collection_name: &str) -> Result<()> {
102        let url = self.base_url.join(&format!("collection/{}", collection_name))?;
103        let response = self.client.delete(url).send().await?;
104        
105        self.handle_empty_response(response).await
106    }
107
108    /// Insert a vector into a collection
109    pub async fn insert_vector(
110        &self,
111        collection_name: &str,
112        request: InsertRequest,
113    ) -> Result<()> {
114        let url = self.base_url.join(&format!("collection/{}/insert", collection_name))?;
115        let response = self
116            .client
117            .post(url)
118            .query(&[("id", request.id.to_string())])
119            .header("Content-Type", "application/json")
120            .json(&InsertVectorBody { vector: request.vector })
121            .send()
122            .await?;
123        
124        self.handle_empty_response(response).await
125    }
126
127    /// Delete a vector from a collection
128    pub async fn delete_vector(
129        &self,
130        collection_name: &str,
131        request: DeleteRequest,
132    ) -> Result<()> {
133        let url = self.base_url.join(&format!("collection/{}/delete", collection_name))?;
134        let response = self
135            .client
136            .delete(url)
137            .query(&[("id", request.id.to_string())])
138            .header("Content-Type", "application/json")
139            .send()
140            .await?;
141        
142        self.handle_empty_response(response).await
143    }
144
145    /// Search for similar vectors
146    pub async fn search(
147        &self,
148        collection_name: &str,
149        limit: usize,
150        request: SearchRequest,
151    ) -> Result<SearchResponse> {
152        let url = self.base_url.join(&format!("collection/{}/search", collection_name))?;
153        let response = self
154            .client
155            .post(url)
156            .query(&[
157                ("limit", limit.to_string()),
158                ("output", "bin".to_string()),
159            ])
160            .header("Content-Type", "application/json")
161            .json(&SearchVectorBody { vector: request.vector })
162            .send()
163            .await?;
164
165        let status = response.status();
166        if !status.is_success() {
167            let text = response.text().await?;
168            return Err(self.parse_error_response(status.as_u16(), &text));
169        }
170
171        let bytes = response.bytes().await?;
172        let buf = bytes.as_ref();
173
174        // Binary format:
175        // [u32 LE: count] followed by `count` * (u32 LE id, f32 LE score)
176        if buf.len() < 4 {
177            return Err(CasperError::InvalidResponse(
178                "binary search response too short (missing count)".to_string(),
179            ));
180        }
181
182        let mut offset = 0;
183        let mut count_bytes = [0u8; 4];
184        count_bytes.copy_from_slice(&buf[offset..offset + 4]);
185        let count = u32::from_le_bytes(count_bytes) as usize;
186        offset += 4;
187
188        let expected_len = 4 + count * (4 + 4);
189        if buf.len() < expected_len {
190            return Err(CasperError::InvalidResponse(format!(
191                "binary search response truncated: expected at least {} bytes, got {}",
192                expected_len,
193                buf.len()
194            )));
195        }
196
197        let mut results = Vec::with_capacity(count);
198        for _ in 0..count {
199            let mut id_bytes = [0u8; 4];
200            id_bytes.copy_from_slice(&buf[offset..offset + 4]);
201            let id = u32::from_le_bytes(id_bytes);
202            offset += 4;
203
204            let mut score_bytes = [0u8; 4];
205            score_bytes.copy_from_slice(&buf[offset..offset + 4]);
206            let score = f32::from_le_bytes(score_bytes);
207            offset += 4;
208
209            results.push(SearchResult { id, score });
210        }
211
212        Ok(results)
213    }
214
215    /// Get vector by ID
216    pub async fn get_vector(&self, collection_name: &str, id: u32) -> Result<Option<Vec<f32>>> {
217        let url = self.base_url.join(&format!("collection/{}/vector/{}", collection_name, id))?;
218        let response = self.client.get(url).send().await?;
219        
220        if response.status() == 404 {
221            return Ok(None);
222        }
223        
224        let vector_response: GetVectorResponse = self.handle_response(response).await?;
225        Ok(Some(vector_response.vector))
226    }
227
228    /// Batch update operations
229    pub async fn batch_update(
230        &self,
231        collection_name: &str,
232        request: BatchUpdateRequest,
233    ) -> Result<()> {
234        let url = self.base_url.join(&format!("collection/{}/update", collection_name))?;
235        let response = self
236            .client
237            .post(url)
238            .header("Content-Type", "application/json")
239            .json(&request)
240            .send()
241            .await?;
242        
243        self.handle_empty_response(response).await
244    }
245
246    pub async fn create_hnsw_index(
247        &self,
248        collection_name: &str,
249        request: CreateHNSWIndexRequest,
250    ) -> Result<()> {
251        let url = self.base_url.join(&format!("collection/{}/index", collection_name))?;
252        let response = self
253            .client
254            .post(url)
255            .header("Content-Type", "application/json")
256            .json(&request)
257            .send()
258            .await?;
259        
260        self.handle_empty_response(response).await
261    }
262
263    /// Delete index from collection
264    pub async fn delete_index(&self, collection_name: &str) -> Result<()> {
265        let url = self.base_url.join(&format!("collection/{}/index", collection_name))?;
266        let response = self.client.delete(url).send().await?;
267        
268        self.handle_empty_response(response).await
269    }
270
271    /// Upload a matrix via gRPC streaming using the configured gRPC address.
272    ///
273    /// - `matrix_name`: name of the matrix to create/overwrite
274    /// - `dimension`: vector dimensionality
275    /// - `vectors`: flat list of all vectors, concatenated row-wise
276    /// - `chunk_floats`: number of f32 values per chunk (must be >= dimension)
277    pub async fn upload_matrix(
278        &self,
279        matrix_name: &str,
280        dimension: usize,
281        vectors: Vec<f32>,
282        chunk_floats: usize,
283    ) -> Result<UploadMatrixResult> {
284        use crate::error::CasperError;
285
286        if dimension == 0 {
287            return Err(CasperError::InvalidResponse(
288                "dimension must be greater than 0".to_string(),
289            ));
290        }
291
292        if vectors.len() % dimension != 0 {
293            return Err(CasperError::InvalidResponse(format!(
294                "vector buffer length {} is not divisible by dimension {}",
295                vectors.len(),
296                dimension
297            )));
298        }
299
300        let chunk_floats = if chunk_floats < dimension {
301            dimension
302        } else {
303            chunk_floats
304        };
305
306        let total_floats = vectors.len();
307        let total_chunks = (total_floats + chunk_floats - 1) / chunk_floats;
308
309        let mut client = MatrixServiceClient::connect(self.grpc_addr.clone())
310            .await
311            .map_err(|e| CasperError::Grpc(e.to_string()))?;
312
313        let (tx, rx) = tokio::sync::mpsc::channel::<UploadMatrixRequest>(4);
314
315        // Spawn producer task to send header + chunks
316        let name = matrix_name.to_string();
317        let vectors_clone = vectors.clone();
318        tokio::spawn(async move {
319            // Header first
320            let max_vectors_per_chunk = (chunk_floats / dimension).max(1) as u32;
321            let header = MatrixHeader {
322                name: name.clone(),
323                dimension: dimension as u32,
324                total_chunks: total_chunks as u32,
325                max_vectors_per_chunk,
326            };
327            let header_msg = UploadMatrixRequest {
328                payload: Some(upload_matrix_request::Payload::Header(header)),
329            };
330            if tx.send(header_msg).await.is_err() {
331                return;
332            }
333
334            // Then data chunks
335            for chunk_idx in 0..total_chunks {
336                let start = chunk_idx * chunk_floats;
337                let end = (start + chunk_floats).min(total_floats);
338                let slice = &vectors_clone[start..end];
339
340                let data = MatrixData {
341                    chunk_index: chunk_idx as u32,
342                    vector: slice.to_vec(),
343                };
344                let msg = UploadMatrixRequest {
345                    payload: Some(upload_matrix_request::Payload::Data(data)),
346                };
347
348                if tx.send(msg).await.is_err() {
349                    break;
350                }
351            }
352        });
353
354        let request = Request::new(ReceiverStream::new(rx));
355        let response = client
356            .upload_matrix(request)
357            .await
358            .map_err(|e| CasperError::Grpc(e.to_string()))?
359            .into_inner();
360
361        Ok(UploadMatrixResult {
362            success: true,
363            message: format!(
364                "Successfully uploaded {} vectors in {} chunks",
365                response.total_vectors, response.total_chunks
366            ),
367            total_vectors: response.total_vectors,
368            total_chunks: response.total_chunks,
369        })
370    }
371
372    /// Delete a matrix by name (HTTP)
373    pub async fn delete_matrix(&self, name: &str) -> Result<()> {
374        let url = self.base_url.join(&format!("matrix/{}", name))?;
375        let response = self
376            .client
377            .delete(url)
378            .header("Content-Type", "application/json")
379            .send()
380            .await?;
381
382        self.handle_empty_response(response).await
383    }
384
385    /// List all matrices (HTTP)
386    pub async fn list_matrices(&self) -> Result<Vec<MatrixInfo>> {
387        let url = self.base_url.join("matrix/list")?;
388        let response = self
389            .client
390            .get(url)
391            .header("Content-Type", "application/json")
392            .send()
393            .await?;
394
395        self.handle_response(response).await
396    }
397
398    /// Get matrix info by name (HTTP)
399    pub async fn get_matrix_info(&self, name: &str) -> Result<MatrixInfo> {
400        let url = self.base_url.join(&format!("matrix/{}", name))?;
401        let response = self
402            .client
403            .get(url)
404            .header("Content-Type", "application/json")
405            .send()
406            .await?;
407
408        self.handle_response(response).await
409    }
410
411    /// Create a PQ entry
412    pub async fn create_pq(
413        &self,
414        name: &str,
415        request: CreatePqRequest,
416    ) -> Result<()> {
417        let url = self.base_url.join(&format!("pq/{}", name))?;
418        let response = self
419            .client
420            .post(url)
421            .header("Content-Type", "application/json")
422            .json(&request)
423            .send()
424            .await?;
425
426        self.handle_empty_response(response).await
427    }
428
429    /// Delete a PQ entry
430    pub async fn delete_pq(&self, name: &str) -> Result<()> {
431        let url = self.base_url.join(&format!("pq/{}", name))?;
432        let response = self
433            .client
434            .delete(url)
435            .header("Content-Type", "application/json")
436            .send()
437            .await?;
438
439        self.handle_empty_response(response).await
440    }
441
442    /// List all PQs
443    pub async fn list_pqs(&self) -> Result<Vec<PqInfo>> {
444        let url = self.base_url.join("pq/list")?;
445        let response = self
446            .client
447            .get(url)
448            .header("Content-Type", "application/json")
449            .send()
450            .await?;
451
452        self.handle_response(response).await
453    }
454
455    /// Get PQ info by name
456    pub async fn get_pq(&self, name: &str) -> Result<PqInfo> {
457        let url = self.base_url.join(&format!("pq/{}", name))?;
458        let response = self
459            .client
460            .get(url)
461            .header("Content-Type", "application/json")
462            .send()
463            .await?;
464
465        self.handle_response(response).await
466    }
467
468    /// Handle JSON response
469    async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
470    where
471        T: serde::de::DeserializeOwned,
472    {
473        let status = response.status();
474        let text = response.text().await?;
475        
476        if status.is_success() {
477            serde_json::from_str(&text).map_err(|e| CasperError::InvalidResponse(format!(
478                "Failed to parse response: {} - {}", e, text
479            )))
480        } else {
481            Err(self.parse_error_response(status.as_u16(), &text))
482        }
483    }
484
485    /// Handle empty response (204 No Content)
486    async fn handle_empty_response(&self, response: reqwest::Response) -> Result<()> {
487        let status = response.status();
488        
489        if status.is_success() {
490            Ok(())
491        } else {
492            let text = response.text().await?;
493            Err(self.parse_error_response(status.as_u16(), &text))
494        }
495    }
496
497
498    /// Parse error response
499    fn parse_error_response(&self, status: u16, text: &str) -> CasperError {
500        // Try to parse as JSON error response
501        if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(text) {
502            if let Some(message) = error_json.get("error").and_then(|v| v.as_str()) {
503                return CasperError::from_status(status, message.to_string());
504            }
505        }
506        
507        // Fallback to status-based error
508        CasperError::from_status(status, text.to_string())
509    }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// A host given with an explicit scheme yields a base URL carrying the HTTP port
    /// and a trailing root slash.
    #[test]
    fn test_client_creation() {
        let (host, http_port, grpc_port) = ("http://localhost", 8080, 50051);
        let client = CasperClient::new(host, http_port, grpc_port).unwrap();

        let expected_base = "http://localhost:8080/";
        assert_eq!(client.base_url(), expected_base);
    }
}