1use crate::{
9 errors::{AgentDBError, Result},
10 queries::{MetadataQuery, VectorQuery},
11};
12use reqwest::{Client, StatusCode};
13use serde::{de::DeserializeOwned, Deserialize, Serialize};
14use std::time::Duration;
15use tracing::{debug, error, info};
16
17pub struct AgentDBClient {
18 client: Client,
19 base_url: String,
20 api_key: Option<String>,
21}
22
23impl AgentDBClient {
24 pub fn new(base_url: String) -> Self {
25 let client = Client::builder()
26 .timeout(Duration::from_secs(30))
27 .pool_max_idle_per_host(10) .build()
29 .expect("Failed to build HTTP client");
30
31 Self {
32 client,
33 base_url,
34 api_key: None,
35 }
36 }
37
38 pub fn with_api_key(mut self, api_key: String) -> Self {
39 self.api_key = Some(api_key);
40 self
41 }
42
43 pub fn with_timeout(mut self, timeout: Duration) -> Self {
44 self.client = Client::builder()
45 .timeout(timeout)
46 .pool_max_idle_per_host(10)
47 .build()
48 .expect("Failed to build HTTP client");
49 self
50 }
51
52 pub async fn insert<T: Serialize>(
54 &self,
55 collection: &str,
56 id: &[u8],
57 embedding: &[f32],
58 metadata: Option<&T>,
59 ) -> Result<InsertResponse> {
60 let url = format!("{}/collections/{}/insert", self.base_url, collection);
61
62 let body = InsertRequest {
63 id: hex::encode(id),
64 embedding: embedding.to_vec(),
65 metadata: metadata.map(|m| serde_json::to_value(m).unwrap()),
66 };
67
68 let response = self
69 .client
70 .post(&url)
71 .json(&body)
72 .send()
73 .await
74 .map_err(|e| AgentDBError::Network(e.to_string()))?;
75
76 self.handle_response(response).await
77 }
78
79 pub async fn batch_insert<T: Serialize>(
81 &self,
82 collection: &str,
83 documents: Vec<BatchDocument<T>>,
84 ) -> Result<BatchInsertResponse> {
85 let url = format!("{}/collections/{}/batch_insert", self.base_url, collection);
86
87 let body = BatchInsertRequest {
88 documents: documents
89 .into_iter()
90 .map(|doc| InsertRequest {
91 id: hex::encode(&doc.id),
92 embedding: doc.embedding,
93 metadata: doc.metadata.map(|m| serde_json::to_value(m).unwrap()),
94 })
95 .collect(),
96 };
97
98 debug!("Batch inserting {} documents", body.documents.len());
99
100 let response = self
101 .client
102 .post(&url)
103 .json(&body)
104 .send()
105 .await
106 .map_err(|e| AgentDBError::Network(e.to_string()))?;
107
108 self.handle_response(response).await
109 }
110
111 pub async fn vector_search<T: DeserializeOwned>(&self, query: VectorQuery) -> Result<Vec<T>> {
113 let url = format!("{}/collections/{}/search", self.base_url, query.collection);
114
115 let response = self
116 .client
117 .post(&url)
118 .json(&query)
119 .send()
120 .await
121 .map_err(|e| AgentDBError::Network(e.to_string()))?;
122
123 let search_result: SearchResponse<T> = self.handle_response(response).await?;
124 Ok(search_result
125 .results
126 .into_iter()
127 .map(|r| r.document)
128 .collect())
129 }
130
131 pub async fn metadata_search<T: DeserializeOwned>(
133 &self,
134 query: MetadataQuery,
135 ) -> Result<Vec<T>> {
136 let url = format!("{}/collections/{}/query", self.base_url, query.collection);
137
138 let response = self
139 .client
140 .post(&url)
141 .json(&query)
142 .send()
143 .await
144 .map_err(|e| AgentDBError::Network(e.to_string()))?;
145
146 let query_result: QueryResponse<T> = self.handle_response(response).await?;
147 Ok(query_result.documents)
148 }
149
150 pub async fn get<T: DeserializeOwned>(&self, collection: &str, id: &[u8]) -> Result<Option<T>> {
152 let url = format!(
153 "{}/collections/{}/documents/{}",
154 self.base_url,
155 collection,
156 hex::encode(id)
157 );
158
159 let response = self
160 .client
161 .get(&url)
162 .send()
163 .await
164 .map_err(|e| AgentDBError::Network(e.to_string()))?;
165
166 match response.status() {
167 StatusCode::OK => {
168 let doc = self.handle_response(response).await?;
169 Ok(Some(doc))
170 }
171 StatusCode::NOT_FOUND => Ok(None),
172 _ => Err(AgentDBError::Network("Failed to get document".to_string())),
173 }
174 }
175
176 pub async fn delete(&self, collection: &str, id: &[u8]) -> Result<()> {
178 let url = format!(
179 "{}/collections/{}/documents/{}",
180 self.base_url,
181 collection,
182 hex::encode(id)
183 );
184
185 let response = self
186 .client
187 .delete(&url)
188 .send()
189 .await
190 .map_err(|e| AgentDBError::Network(e.to_string()))?;
191
192 match response.status() {
193 StatusCode::OK | StatusCode::NO_CONTENT => Ok(()),
194 StatusCode::NOT_FOUND => Err(AgentDBError::NotFound("Document not found".to_string())),
195 _ => Err(AgentDBError::Network(
196 "Failed to delete document".to_string(),
197 )),
198 }
199 }
200
201 pub async fn create_collection(&self, config: CollectionConfig) -> Result<()> {
203 let url = format!("{}/collections", self.base_url);
204
205 let response = self
206 .client
207 .post(&url)
208 .json(&config)
209 .send()
210 .await
211 .map_err(|e| AgentDBError::Network(e.to_string()))?;
212
213 match response.status() {
214 StatusCode::OK | StatusCode::CREATED => {
215 info!("Collection '{}' created successfully", config.name);
216 Ok(())
217 }
218 StatusCode::CONFLICT => {
219 info!("Collection '{}' already exists", config.name);
220 Ok(())
221 }
222 _ => Err(AgentDBError::Network(
223 "Failed to create collection".to_string(),
224 )),
225 }
226 }
227
228 pub async fn health_check(&self) -> Result<HealthResponse> {
230 let url = format!("{}/health", self.base_url);
231
232 let response = self
233 .client
234 .get(&url)
235 .send()
236 .await
237 .map_err(|e| AgentDBError::Network(e.to_string()))?;
238
239 self.handle_response(response).await
240 }
241
242 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
243 match response.status() {
244 StatusCode::OK | StatusCode::CREATED => response
245 .json()
246 .await
247 .map_err(|e| AgentDBError::Serialization(e.to_string())),
248 StatusCode::BAD_REQUEST => {
249 let error_text = response.text().await.unwrap_or_default();
250 Err(AgentDBError::InvalidQuery(error_text))
251 }
252 StatusCode::NOT_FOUND => {
253 let error_text = response.text().await.unwrap_or_default();
254 Err(AgentDBError::NotFound(error_text))
255 }
256 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
257 Err(AgentDBError::Auth("Authentication failed".to_string()))
258 }
259 _ => {
260 let status = response.status();
261 let error_text = response.text().await.unwrap_or_default();
262 error!("HTTP {}: {}", status, error_text);
263 Err(AgentDBError::Network(format!(
264 "HTTP {}: {}",
265 status, error_text
266 )))
267 }
268 }
269 }
270}
271
272#[derive(Debug, Serialize)]
275struct InsertRequest {
276 id: String,
277 embedding: Vec<f32>,
278
279 #[serde(skip_serializing_if = "Option::is_none")]
280 metadata: Option<serde_json::Value>,
281}
282
283#[derive(Debug, Deserialize)]
284pub struct InsertResponse {
285 pub id: String,
286 pub success: bool,
287}
288
289#[derive(Debug, Serialize)]
290struct BatchInsertRequest {
291 documents: Vec<InsertRequest>,
292}
293
294#[derive(Debug, Deserialize)]
295pub struct BatchInsertResponse {
296 pub inserted: usize,
297 pub failed: usize,
298}
299
/// One document passed to the client's batch insert.
///
/// `id` holds raw bytes (the client hex-encodes them before sending), and
/// `metadata`, when present, is converted to JSON.
#[derive(Debug, Clone, PartialEq)]
pub struct BatchDocument<T> {
    /// Raw document identifier.
    pub id: Vec<u8>,
    /// Embedding vector, sent as-is.
    pub embedding: Vec<f32>,
    /// Optional caller-defined metadata.
    pub metadata: Option<T>,
}
305
306#[derive(Debug, Deserialize)]
307struct SearchResponse<T> {
308 results: Vec<SearchResult<T>>,
309}
310
311#[derive(Debug, Deserialize)]
312#[allow(dead_code)]
313struct SearchResult<T> {
314 document: T,
315 score: f32,
316}
317
318#[derive(Debug, Deserialize)]
319struct QueryResponse<T> {
320 documents: Vec<T>,
321}
322
323#[derive(Debug, Serialize)]
324pub struct CollectionConfig {
325 pub name: String,
326 pub dimension: usize,
327 pub index_type: String,
328
329 #[serde(skip_serializing_if = "Option::is_none")]
330 pub metadata_schema: Option<serde_json::Value>,
331}
332
333#[derive(Debug, Deserialize)]
334pub struct HealthResponse {
335 pub status: String,
336 pub version: String,
337}
338
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_client_creation() {
        let client = AgentDBClient::new("http://localhost:8080".to_string());
        assert_eq!(client.base_url, "http://localhost:8080");
        // A fresh client has no credentials configured.
        assert_eq!(client.api_key, None);
    }

    #[test]
    fn test_client_with_api_key() {
        let client = AgentDBClient::new("http://localhost:8080".to_string())
            .with_api_key("test_key".to_string());

        assert_eq!(client.api_key, Some("test_key".to_string()));
    }

    #[test]
    fn test_with_timeout_preserves_configuration() {
        // Rebuilding the HTTP client must not drop the base URL or the API key.
        let client = AgentDBClient::new("http://localhost:8080".to_string())
            .with_api_key("test_key".to_string())
            .with_timeout(Duration::from_secs(5));

        assert_eq!(client.base_url, "http://localhost:8080");
        assert_eq!(client.api_key.as_deref(), Some("test_key"));
    }

    // Purely synchronous, so a plain #[test] is enough; no async runtime is needed.
    #[test]
    fn test_hex_encoding() {
        let id = vec![0x01, 0x02, 0x03, 0x04];
        let hex = hex::encode(&id);
        assert_eq!(hex, "01020304");
    }
}