1use async_trait::async_trait;
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
6use bridge_embex_core::db::VectorDatabase;
7use bridge_embex_core::error::{EmbexError, Result};
8use bridge_embex_core::types::{
9 CollectionSchema, DistanceMetric, MetadataUpdate, Point, SearchResponse, SearchResult,
10 VectorQuery,
11};
12
/// Value sent in the `X-Pinecone-API-Version` header of every request.
const PINECONE_API_VERSION: &str = "2024-10";
/// Base URL of the Pinecone control plane (index create/describe/delete).
const PINECONE_CONTROL_URL: &str = "https://api.pinecone.io";
15
16pub struct PineconeAdapter {
17 http: Client,
18 api_key: String,
19 namespace: String,
20 cloud: String,
21 region: String,
22}
23
24impl PineconeAdapter {
25 pub fn new(
26 api_key: &str,
27 cloud: Option<&str>,
28 region: Option<&str>,
29 namespace: Option<&str>,
30 ) -> Result<Self> {
31 Self::new_with_pool_size(api_key, cloud, region, namespace, None)
32 }
33
34 pub fn new_with_pool_size(
35 api_key: &str,
36 cloud: Option<&str>,
37 region: Option<&str>,
38 namespace: Option<&str>,
39 pool_size: Option<u32>,
40 ) -> Result<Self> {
41 let builder = Client::builder()
42 .timeout(std::time::Duration::from_secs(30))
43 .pool_max_idle_per_host(pool_size.unwrap_or(10) as usize)
44 .pool_idle_timeout(std::time::Duration::from_secs(90));
45
46 let http = builder
47 .build()
48 .map_err(|e| EmbexError::Connection(format!("Failed to create HTTP client: {}", e)))?;
49
50 Ok(Self {
51 http,
52 api_key: api_key.to_string(),
53 namespace: namespace.unwrap_or("").to_string(),
54 cloud: cloud.unwrap_or("aws").to_string(),
55 region: region.unwrap_or("us-east-1").to_string(),
56 })
57 }
58
59 fn control_headers(&self) -> reqwest::header::HeaderMap {
60 let mut headers = reqwest::header::HeaderMap::new();
61 headers.insert("Api-Key", self.api_key.parse().unwrap());
62 headers.insert(
63 "X-Pinecone-API-Version",
64 PINECONE_API_VERSION.parse().unwrap(),
65 );
66 headers.insert("Content-Type", "application/json".parse().unwrap());
67 headers
68 }
69
70 fn data_headers(&self) -> reqwest::header::HeaderMap {
71 self.control_headers()
72 }
73
74 async fn get_index_host(&self, index_name: &str) -> Result<String> {
75 let url = format!("{}/indexes/{}", PINECONE_CONTROL_URL, index_name);
76
77 let response = self
78 .http
79 .get(&url)
80 .headers(self.control_headers())
81 .send()
82 .await
83 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
84
85 if !response.status().is_success() {
86 let status = response.status();
87 let body = response.text().await.unwrap_or_default();
88 return Err(EmbexError::Database(format!(
89 "Describe index failed ({}): {}",
90 status, body
91 )));
92 }
93
94 let info: DescribeIndexResponse = response
95 .json()
96 .await
97 .map_err(|e| EmbexError::Database(format!("Parse error: {}", e)))?;
98
99 Ok(info.host)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads header `name` as text; `None` when it is missing or not visible ASCII.
    fn header_str<'a>(headers: &'a reqwest::header::HeaderMap, name: &str) -> Option<&'a str> {
        headers.get(name).and_then(|value| value.to_str().ok())
    }

    #[test]
    fn test_pinecone_adapter_new() {
        let adapter = PineconeAdapter::new("test-key", None, None, None)
            .expect("adapter with default options should build");

        assert_eq!(adapter.api_key, "test-key");
        assert_eq!(adapter.namespace, "");
        assert_eq!(adapter.cloud, "aws");
        assert_eq!(adapter.region, "us-east-1");
    }

    #[test]
    fn test_pinecone_adapter_new_with_options() {
        let adapter = PineconeAdapter::new(
            "test-key",
            Some("gcp"),
            Some("us-west1"),
            Some("my-namespace"),
        )
        .expect("adapter with explicit options should build");

        assert_eq!(adapter.cloud, "gcp");
        assert_eq!(adapter.region, "us-west1");
        assert_eq!(adapter.namespace, "my-namespace");
    }

    #[test]
    fn test_pinecone_adapter_new_with_pool_size() {
        let adapter =
            PineconeAdapter::new_with_pool_size("test-key", None, Some("eu-west-1"), None, Some(4))
                .expect("adapter with a custom pool size should build");

        assert_eq!(adapter.api_key, "test-key");
        assert_eq!(adapter.cloud, "aws");
        assert_eq!(adapter.region, "eu-west-1");
        assert_eq!(adapter.namespace, "");
    }

    #[test]
    fn test_control_headers() {
        let adapter = PineconeAdapter::new("test-key", None, None, None)
            .expect("adapter with default options should build");
        let headers = adapter.control_headers();

        assert_eq!(header_str(&headers, "Api-Key"), Some("test-key"));
        assert_eq!(
            header_str(&headers, "X-Pinecone-API-Version"),
            Some(PINECONE_API_VERSION)
        );
        assert_eq!(header_str(&headers, "Content-Type"), Some("application/json"));
    }

    #[test]
    fn test_data_headers() {
        let adapter = PineconeAdapter::new("test-key", None, None, None)
            .expect("adapter with default options should build");
        let headers = adapter.data_headers();

        assert_eq!(header_str(&headers, "Api-Key"), Some("test-key"));
        assert_eq!(
            header_str(&headers, "X-Pinecone-API-Version"),
            Some(PINECONE_API_VERSION)
        );
        assert_eq!(header_str(&headers, "Content-Type"), Some("application/json"));
    }
}
154
155#[derive(Serialize)]
156struct CreateIndexRequest {
157 name: String,
158 dimension: usize,
159 metric: String,
160 spec: IndexSpec,
161}
162
163#[derive(Serialize)]
164struct IndexSpec {
165 serverless: ServerlessSpec,
166}
167
168#[derive(Serialize)]
169struct ServerlessSpec {
170 cloud: String,
171 region: String,
172}
173
174#[derive(Deserialize)]
175struct DescribeIndexResponse {
176 host: String,
177}
178
179#[derive(Serialize)]
180struct UpsertRequest {
181 vectors: Vec<PineconeVector>,
182 namespace: String,
183}
184
185#[derive(Serialize)]
186struct PineconeVector {
187 id: String,
188 values: Vec<f32>,
189 #[serde(skip_serializing_if = "Option::is_none")]
190 metadata: Option<serde_json::Value>,
191}
192
193#[derive(Serialize)]
194struct QueryRequest {
195 namespace: String,
196 vector: Vec<f32>,
197 #[serde(rename = "topK")]
198 top_k: usize,
199 #[serde(rename = "includeValues")]
200 include_values: bool,
201 #[serde(rename = "includeMetadata")]
202 include_metadata: bool,
203 #[serde(skip_serializing_if = "Option::is_none")]
204 filter: Option<serde_json::Value>,
205}
206
207#[derive(Deserialize)]
208struct QueryResponse {
209 matches: Vec<PineconeMatch>,
210}
211
212#[derive(Deserialize)]
213struct PineconeMatch {
214 id: String,
215 score: f32,
216 values: Option<Vec<f32>>,
217 metadata: Option<serde_json::Value>,
218}
219
220#[derive(Serialize)]
221struct UpdateRequest {
222 id: String,
223 #[serde(rename = "setMetadata")]
224 #[serde(skip_serializing_if = "Option::is_none")]
225 set_metadata: Option<serde_json::Value>,
226 namespace: String,
227}
228
229#[derive(Serialize)]
230struct DeleteRequest {
231 ids: Vec<String>,
232 namespace: String,
233}
234
235#[async_trait]
236impl VectorDatabase for PineconeAdapter {
237 #[tracing::instrument(skip(self, schema), fields(collection = %schema.name, dimension = schema.dimension, provider = "pinecone"))]
238 async fn create_collection(&self, schema: &CollectionSchema) -> Result<()> {
239 let metric = match schema.metric {
240 DistanceMetric::Cosine => "cosine",
241 DistanceMetric::Euclidean => "euclidean",
242 DistanceMetric::Dot => "dotproduct",
243 };
244
245 let request = CreateIndexRequest {
246 name: schema.name.clone(),
247 dimension: schema.dimension,
248 metric: metric.to_string(),
249 spec: IndexSpec {
250 serverless: ServerlessSpec {
251 cloud: self.cloud.clone(),
252 region: self.region.clone(),
253 },
254 },
255 };
256
257 let url = format!("{}/indexes", PINECONE_CONTROL_URL);
258
259 let response = self
260 .http
261 .post(&url)
262 .headers(self.control_headers())
263 .json(&request)
264 .send()
265 .await
266 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
267
268 if !response.status().is_success() {
269 let status = response.status();
270 let body = response.text().await.unwrap_or_default();
271 return Err(EmbexError::Database(format!(
272 "Create index failed ({}): {}",
273 status, body
274 )));
275 }
276
277 Ok(())
278 }
279
280 #[tracing::instrument(skip(self), fields(collection = %name, provider = "pinecone"))]
281 async fn delete_collection(&self, name: &str) -> Result<()> {
282 let url = format!("{}/indexes/{}", PINECONE_CONTROL_URL, name);
283
284 let response = self
285 .http
286 .delete(&url)
287 .headers(self.control_headers())
288 .send()
289 .await
290 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
291
292 if !response.status().is_success() {
293 let status = response.status();
294 let body = response.text().await.unwrap_or_default();
295 return Err(EmbexError::Database(format!(
296 "Delete index failed ({}): {}",
297 status, body
298 )));
299 }
300
301 Ok(())
302 }
303
304 #[tracing::instrument(skip(self, points), fields(collection = %collection, count = points.len(), provider = "pinecone"))]
305 async fn insert(&self, collection: &str, points: Vec<Point>) -> Result<()> {
306 let host = self.get_index_host(collection).await?;
307
308 let vectors: Vec<PineconeVector> = points
309 .into_iter()
310 .map(|p| PineconeVector {
311 id: p.id,
312 values: p.vector,
313 metadata: p
314 .metadata
315 .map(|m| serde_json::to_value(m).unwrap_or_default()),
316 })
317 .collect();
318
319 let request = UpsertRequest {
320 vectors,
321 namespace: self.namespace.clone(),
322 };
323
324 let url = format!("https://{}/vectors/upsert", host);
325
326 let response = self
327 .http
328 .post(&url)
329 .headers(self.data_headers())
330 .json(&request)
331 .send()
332 .await
333 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
334
335 if !response.status().is_success() {
336 let status = response.status();
337 let body = response.text().await.unwrap_or_default();
338 return Err(EmbexError::Database(format!(
339 "Upsert failed ({}): {}",
340 status, body
341 )));
342 }
343
344 Ok(())
345 }
346
347 #[tracing::instrument(skip(self, query), fields(collection = %query.collection, top_k = query.top_k, provider = "pinecone"))]
348 async fn search(&self, query: &VectorQuery) -> Result<SearchResponse> {
349 let host = self.get_index_host(&query.collection).await?;
350
351 let vector = query.vector.clone().ok_or_else(|| {
352 EmbexError::Unsupported("Pinecone adapter requires a vector for search queries.".into())
353 })?;
354
355 let request = QueryRequest {
357 namespace: self.namespace.clone(),
358 vector,
359 top_k: query.top_k,
360 include_values: query.include_vector,
361 include_metadata: query.include_metadata,
362 filter: query.filter.as_ref().map(convert_filter),
363 };
364
365 let url = format!("https://{}/query", host);
366
367 let response = self
368 .http
369 .post(&url)
370 .headers(self.data_headers())
371 .json(&request)
372 .send()
373 .await
374 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
375
376 if !response.status().is_success() {
377 let status = response.status();
378 let body = response.text().await.unwrap_or_default();
379 return Err(EmbexError::Database(format!(
380 "Query failed ({}): {}",
381 status, body
382 )));
383 }
384
385 let result: QueryResponse = response
386 .json()
387 .await
388 .map_err(|e| EmbexError::Database(format!("Parse error: {}", e)))?;
389
390 let mut aggregations = HashMap::new();
391 for agg in &query.aggregations {
392 match agg {
393 bridge_embex_core::types::Aggregation::Count => {
394 aggregations.insert(
399 "count".to_string(),
400 serde_json::Value::Number(result.matches.len().into()),
401 );
402 }
403 }
404 }
405
406 Ok(SearchResponse {
407 results: result
408 .matches
409 .into_iter()
410 .map(|m| SearchResult {
411 id: m.id,
412 score: m.score,
413 vector: m.values,
414 metadata: m.metadata.and_then(|v| {
415 serde_json::from_value::<HashMap<String, serde_json::Value>>(v).ok()
416 }),
417 })
418 .collect(),
419 aggregations,
420 })
421 }
422
423 #[tracing::instrument(skip(self), fields(collection = %collection, count = ids.len(), provider = "pinecone"))]
424 async fn delete(&self, collection: &str, ids: Vec<String>) -> Result<()> {
425 let host = self.get_index_host(collection).await?;
426
427 let request = DeleteRequest {
428 ids,
429 namespace: self.namespace.clone(),
430 };
431
432 let url = format!("https://{}/vectors/delete", host);
433
434 let response = self
435 .http
436 .post(&url)
437 .headers(self.data_headers())
438 .json(&request)
439 .send()
440 .await
441 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
442
443 if !response.status().is_success() {
444 let status = response.status();
445 let body = response.text().await.unwrap_or_default();
446 return Err(EmbexError::Database(format!(
447 "Delete failed ({}): {}",
448 status, body
449 )));
450 }
451
452 Ok(())
453 }
454
455 #[tracing::instrument(skip(self, updates), fields(collection = %collection, count = updates.len(), provider = "pinecone"))]
456 async fn update_metadata(&self, collection: &str, updates: Vec<MetadataUpdate>) -> Result<()> {
457 let host = self.get_index_host(collection).await?;
458 let url = format!("https://{}/vectors/update", host);
459
460 for update in updates {
461 let request = UpdateRequest {
462 id: update.id,
463 set_metadata: Some(serde_json::to_value(update.updates).unwrap_or_default()),
464 namespace: self.namespace.clone(),
465 };
466
467 let response = self
468 .http
469 .post(&url)
470 .headers(self.data_headers())
471 .json(&request)
472 .send()
473 .await
474 .map_err(|e| EmbexError::Database(format!("HTTP error: {}", e)))?;
475
476 if !response.status().is_success() {
477 let status = response.status();
478 let body = response.text().await.unwrap_or_default();
479 return Err(EmbexError::Database(format!(
480 "Update metadata failed ({}): {}",
481 status, body
482 )));
483 }
484 }
485
486 Ok(())
487 }
488}
489
490fn convert_filter(filter: &bridge_embex_core::types::Filter) -> serde_json::Value {
491 use bridge_embex_core::types::Filter;
492 use serde_json::json;
493
494 match filter {
495 Filter::Must(filters) => {
496 json!({ "$and": filters.iter().map(convert_filter).collect::<Vec<_>>() })
497 }
498 Filter::MustNot(filters) => {
499 json!({ "$and": filters.iter().map(convert_filter).collect::<Vec<_>>() })
502 }
507 Filter::Should(filters) => {
508 json!({ "$or": filters.iter().map(convert_filter).collect::<Vec<_>>() })
509 }
510 Filter::Key(key, condition) => {
511 json!({ key: convert_condition(condition) })
512 }
513 }
514}
515
516fn convert_condition(condition: &bridge_embex_core::types::Condition) -> serde_json::Value {
517 use bridge_embex_core::types::Condition;
518 use serde_json::json;
519
520 match condition {
521 Condition::Eq(v) => json!({ "$eq": v }),
522 Condition::Ne(v) => json!({ "$ne": v }),
523 Condition::Gt(v) => json!({ "$gt": v }),
524 Condition::Gte(v) => json!({ "$gte": v }),
525 Condition::Lt(v) => json!({ "$lt": v }),
526 Condition::Lte(v) => json!({ "$lte": v }),
527 Condition::In(v) => json!({ "$in": v }),
528 Condition::NotIn(v) => json!({ "$nin": v }),
529 }
530}