1use crate::error::{CasperError, Result};
2use crate::models::*;
3use crate::grpc::service::matrix_service::{
4 matrix_service_client::MatrixServiceClient,
5 upload_matrix_request, MatrixData, MatrixHeader, UploadMatrixRequest,
6};
7use reqwest::Client;
8use std::time::Duration;
9use tokio_stream::wrappers::ReceiverStream;
10use tonic::Request;
11use url::Url;
12
13#[derive(Debug, Clone)]
15pub struct CasperClient {
16 client: Client,
17 base_url: Url,
18 grpc_addr: String,
19}
20
21impl CasperClient {
22 pub fn new(host: &str, http_port: u16, grpc_port: u16) -> Result<Self> {
28 let base_url_str = format!("{}:{}", host, http_port);
29 let base_url = Url::parse(&base_url_str)?;
30 let client = Client::builder()
31 .timeout(Duration::from_secs(30))
32 .build()?;
33
34 let grpc_addr = format!("{}:{}", host, grpc_port);
35
36 Ok(Self { client, base_url, grpc_addr })
37 }
38
39 pub fn with_timeout(host: &str, http_port: u16, grpc_port: u16, timeout: Duration) -> Result<Self> {
45 let base_url_str = format!("{}:{}", host, http_port);
46 let base_url = Url::parse(&base_url_str)?;
47 let client = Client::builder()
48 .timeout(timeout)
49 .build()?;
50
51 let grpc_addr = format!("{}:{}", host, grpc_port);
52
53 Ok(Self { client, base_url, grpc_addr })
54 }
55
56 pub fn base_url(&self) -> &str {
58 self.base_url.as_str()
59 }
60
61 pub fn grpc_addr(&self) -> &str {
63 &self.grpc_addr
64 }
65
66 pub async fn list_collections(&self) -> Result<CollectionsListResponse> {
68 let url = self.base_url.join("collections")?;
69 let response = self.client.get(url).send().await?;
70
71 self.handle_response(response).await
72 }
73
74 pub async fn get_collection(&self, collection_name: &str) -> Result<CollectionInfo> {
76 let url = self.base_url.join(&format!("collection/{}", collection_name))?;
77 let response = self.client.get(url).send().await?;
78
79 self.handle_response(response).await
80 }
81
82 pub async fn create_collection(
84 &self,
85 collection_name: &str,
86 request: CreateCollectionRequest,
87 ) -> Result<()> {
88 let url = self.base_url.join(&format!("collection/{}", collection_name))?;
89 let response = self
90 .client
91 .post(url)
92 .query(&request)
93 .header("Content-Type", "application/json")
94 .send()
95 .await?;
96
97 self.handle_empty_response(response).await
98 }
99
100 pub async fn delete_collection(&self, collection_name: &str) -> Result<()> {
102 let url = self.base_url.join(&format!("collection/{}", collection_name))?;
103 let response = self.client.delete(url).send().await?;
104
105 self.handle_empty_response(response).await
106 }
107
108 pub async fn insert_vector(
110 &self,
111 collection_name: &str,
112 request: InsertRequest,
113 ) -> Result<()> {
114 let url = self.base_url.join(&format!("collection/{}/insert", collection_name))?;
115 let response = self
116 .client
117 .post(url)
118 .query(&[("id", request.id.to_string())])
119 .header("Content-Type", "application/json")
120 .json(&InsertVectorBody { vector: request.vector })
121 .send()
122 .await?;
123
124 self.handle_empty_response(response).await
125 }
126
127 pub async fn delete_vector(
129 &self,
130 collection_name: &str,
131 request: DeleteRequest,
132 ) -> Result<()> {
133 let url = self.base_url.join(&format!("collection/{}/delete", collection_name))?;
134 let response = self
135 .client
136 .delete(url)
137 .query(&[("id", request.id.to_string())])
138 .header("Content-Type", "application/json")
139 .send()
140 .await?;
141
142 self.handle_empty_response(response).await
143 }
144
145 pub async fn search(
147 &self,
148 collection_name: &str,
149 limit: usize,
150 request: SearchRequest,
151 ) -> Result<SearchResponse> {
152 let url = self.base_url.join(&format!("collection/{}/search", collection_name))?;
153 let response = self
154 .client
155 .post(url)
156 .query(&[
157 ("limit", limit.to_string()),
158 ("output", "bin".to_string()),
159 ])
160 .header("Content-Type", "application/json")
161 .json(&SearchVectorBody { vector: request.vector })
162 .send()
163 .await?;
164
165 let status = response.status();
166 if !status.is_success() {
167 let text = response.text().await?;
168 return Err(self.parse_error_response(status.as_u16(), &text));
169 }
170
171 let bytes = response.bytes().await?;
172 let buf = bytes.as_ref();
173
174 if buf.len() < 4 {
177 return Err(CasperError::InvalidResponse(
178 "binary search response too short (missing count)".to_string(),
179 ));
180 }
181
182 let mut offset = 0;
183 let mut count_bytes = [0u8; 4];
184 count_bytes.copy_from_slice(&buf[offset..offset + 4]);
185 let count = u32::from_le_bytes(count_bytes) as usize;
186 offset += 4;
187
188 let expected_len = 4 + count * (4 + 4);
189 if buf.len() < expected_len {
190 return Err(CasperError::InvalidResponse(format!(
191 "binary search response truncated: expected at least {} bytes, got {}",
192 expected_len,
193 buf.len()
194 )));
195 }
196
197 let mut results = Vec::with_capacity(count);
198 for _ in 0..count {
199 let mut id_bytes = [0u8; 4];
200 id_bytes.copy_from_slice(&buf[offset..offset + 4]);
201 let id = u32::from_le_bytes(id_bytes);
202 offset += 4;
203
204 let mut score_bytes = [0u8; 4];
205 score_bytes.copy_from_slice(&buf[offset..offset + 4]);
206 let score = f32::from_le_bytes(score_bytes);
207 offset += 4;
208
209 results.push(SearchResult { id, score });
210 }
211
212 Ok(results)
213 }
214
215 pub async fn get_vector(&self, collection_name: &str, id: u32) -> Result<Option<Vec<f32>>> {
217 let url = self.base_url.join(&format!("collection/{}/vector/{}", collection_name, id))?;
218 let response = self.client.get(url).send().await?;
219
220 if response.status() == 404 {
221 return Ok(None);
222 }
223
224 let vector_response: GetVectorResponse = self.handle_response(response).await?;
225 Ok(Some(vector_response.vector))
226 }
227
228 pub async fn batch_update(
230 &self,
231 collection_name: &str,
232 request: BatchUpdateRequest,
233 ) -> Result<()> {
234 let url = self.base_url.join(&format!("collection/{}/update", collection_name))?;
235 let response = self
236 .client
237 .post(url)
238 .header("Content-Type", "application/json")
239 .json(&request)
240 .send()
241 .await?;
242
243 self.handle_empty_response(response).await
244 }
245
246 pub async fn create_hnsw_index(
247 &self,
248 collection_name: &str,
249 request: CreateHNSWIndexRequest,
250 ) -> Result<()> {
251 let url = self.base_url.join(&format!("collection/{}/index", collection_name))?;
252 let response = self
253 .client
254 .post(url)
255 .header("Content-Type", "application/json")
256 .json(&request)
257 .send()
258 .await?;
259
260 self.handle_empty_response(response).await
261 }
262
263 pub async fn delete_index(&self, collection_name: &str) -> Result<()> {
265 let url = self.base_url.join(&format!("collection/{}/index", collection_name))?;
266 let response = self.client.delete(url).send().await?;
267
268 self.handle_empty_response(response).await
269 }
270
271 pub async fn upload_matrix(
278 &self,
279 matrix_name: &str,
280 dimension: usize,
281 vectors: Vec<f32>,
282 chunk_floats: usize,
283 ) -> Result<UploadMatrixResult> {
284 use crate::error::CasperError;
285
286 if dimension == 0 {
287 return Err(CasperError::InvalidResponse(
288 "dimension must be greater than 0".to_string(),
289 ));
290 }
291
292 if vectors.len() % dimension != 0 {
293 return Err(CasperError::InvalidResponse(format!(
294 "vector buffer length {} is not divisible by dimension {}",
295 vectors.len(),
296 dimension
297 )));
298 }
299
300 let chunk_floats = if chunk_floats < dimension {
301 dimension
302 } else {
303 chunk_floats
304 };
305
306 let total_floats = vectors.len();
307 let total_chunks = (total_floats + chunk_floats - 1) / chunk_floats;
308
309 let mut client = MatrixServiceClient::connect(self.grpc_addr.clone())
310 .await
311 .map_err(|e| CasperError::Grpc(e.to_string()))?;
312
313 let (tx, rx) = tokio::sync::mpsc::channel::<UploadMatrixRequest>(4);
314
315 let name = matrix_name.to_string();
317 let vectors_clone = vectors.clone();
318 tokio::spawn(async move {
319 let max_vectors_per_chunk = (chunk_floats / dimension).max(1) as u32;
321 let header = MatrixHeader {
322 name: name.clone(),
323 dimension: dimension as u32,
324 total_chunks: total_chunks as u32,
325 max_vectors_per_chunk,
326 };
327 let header_msg = UploadMatrixRequest {
328 payload: Some(upload_matrix_request::Payload::Header(header)),
329 };
330 if tx.send(header_msg).await.is_err() {
331 return;
332 }
333
334 for chunk_idx in 0..total_chunks {
336 let start = chunk_idx * chunk_floats;
337 let end = (start + chunk_floats).min(total_floats);
338 let slice = &vectors_clone[start..end];
339
340 let data = MatrixData {
341 chunk_index: chunk_idx as u32,
342 vector: slice.to_vec(),
343 };
344 let msg = UploadMatrixRequest {
345 payload: Some(upload_matrix_request::Payload::Data(data)),
346 };
347
348 if tx.send(msg).await.is_err() {
349 break;
350 }
351 }
352 });
353
354 let request = Request::new(ReceiverStream::new(rx));
355 let response = client
356 .upload_matrix(request)
357 .await
358 .map_err(|e| CasperError::Grpc(e.to_string()))?
359 .into_inner();
360
361 Ok(UploadMatrixResult {
362 success: true,
363 message: format!(
364 "Successfully uploaded {} vectors in {} chunks",
365 response.total_vectors, response.total_chunks
366 ),
367 total_vectors: response.total_vectors,
368 total_chunks: response.total_chunks,
369 })
370 }
371
372 pub async fn delete_matrix(&self, name: &str) -> Result<()> {
374 let url = self.base_url.join(&format!("matrix/{}", name))?;
375 let response = self
376 .client
377 .delete(url)
378 .header("Content-Type", "application/json")
379 .send()
380 .await?;
381
382 self.handle_empty_response(response).await
383 }
384
385 pub async fn list_matrices(&self) -> Result<Vec<MatrixInfo>> {
387 let url = self.base_url.join("matrix/list")?;
388 let response = self
389 .client
390 .get(url)
391 .header("Content-Type", "application/json")
392 .send()
393 .await?;
394
395 self.handle_response(response).await
396 }
397
398 pub async fn get_matrix_info(&self, name: &str) -> Result<MatrixInfo> {
400 let url = self.base_url.join(&format!("matrix/{}", name))?;
401 let response = self
402 .client
403 .get(url)
404 .header("Content-Type", "application/json")
405 .send()
406 .await?;
407
408 self.handle_response(response).await
409 }
410
411 pub async fn create_pq(
413 &self,
414 name: &str,
415 request: CreatePqRequest,
416 ) -> Result<()> {
417 let url = self.base_url.join(&format!("pq/{}", name))?;
418 let response = self
419 .client
420 .post(url)
421 .header("Content-Type", "application/json")
422 .json(&request)
423 .send()
424 .await?;
425
426 self.handle_empty_response(response).await
427 }
428
429 pub async fn delete_pq(&self, name: &str) -> Result<()> {
431 let url = self.base_url.join(&format!("pq/{}", name))?;
432 let response = self
433 .client
434 .delete(url)
435 .header("Content-Type", "application/json")
436 .send()
437 .await?;
438
439 self.handle_empty_response(response).await
440 }
441
442 pub async fn list_pqs(&self) -> Result<Vec<PqInfo>> {
444 let url = self.base_url.join("pq/list")?;
445 let response = self
446 .client
447 .get(url)
448 .header("Content-Type", "application/json")
449 .send()
450 .await?;
451
452 self.handle_response(response).await
453 }
454
455 pub async fn get_pq(&self, name: &str) -> Result<PqInfo> {
457 let url = self.base_url.join(&format!("pq/{}", name))?;
458 let response = self
459 .client
460 .get(url)
461 .header("Content-Type", "application/json")
462 .send()
463 .await?;
464
465 self.handle_response(response).await
466 }
467
468 async fn handle_response<T>(&self, response: reqwest::Response) -> Result<T>
470 where
471 T: serde::de::DeserializeOwned,
472 {
473 let status = response.status();
474 let text = response.text().await?;
475
476 if status.is_success() {
477 serde_json::from_str(&text).map_err(|e| CasperError::InvalidResponse(format!(
478 "Failed to parse response: {} - {}", e, text
479 )))
480 } else {
481 Err(self.parse_error_response(status.as_u16(), &text))
482 }
483 }
484
485 async fn handle_empty_response(&self, response: reqwest::Response) -> Result<()> {
487 let status = response.status();
488
489 if status.is_success() {
490 Ok(())
491 } else {
492 let text = response.text().await?;
493 Err(self.parse_error_response(status.as_u16(), &text))
494 }
495 }
496
497
498 fn parse_error_response(&self, status: u16, text: &str) -> CasperError {
500 if let Ok(error_json) = serde_json::from_str::<serde_json::Value>(text) {
502 if let Some(message) = error_json.get("error").and_then(|v| v.as_str()) {
503 return CasperError::from_status(status, message.to_string());
504 }
505 }
506
507 CasperError::from_status(status, text.to_string())
509 }
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` derives both the HTTP base URL and the gRPC address from the host.
    #[test]
    fn test_client_creation() {
        let client = CasperClient::new("http://localhost", 8080, 50051).unwrap();
        assert_eq!(client.base_url(), "http://localhost:8080/");
        assert_eq!(client.grpc_addr(), "http://localhost:50051");
    }

    /// `with_timeout` builds the same addresses as `new`; only the timeout differs.
    #[test]
    fn test_client_with_timeout() {
        let client =
            CasperClient::with_timeout("http://127.0.0.1", 9000, 9001, Duration::from_secs(5))
                .unwrap();
        assert_eq!(client.base_url(), "http://127.0.0.1:9000/");
        assert_eq!(client.grpc_addr(), "http://127.0.0.1:9001");
    }
}