use async_trait::async_trait;
use pinecone_sdk::{
models::{Kind, Metadata, Namespace, QueryResponse, Value, Vector},
pinecone::{data::Index, PineconeClientConfig},
utils::errors::PineconeError,
};
use serde::{de::Error, Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use seedframe::embeddings::embedding::Embedding;
use seedframe::vector_store::{VectorStore, VectorStoreError};
use tokio::sync::Mutex;
/// JSON configuration parsed by `PineconeVectorStore::new`.
///
/// Keys not listed here are rejected during deserialization (`deny_unknown_fields`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
struct Config {
/// Name of the environment variable the API key is read from, if given.
api_key_var: Option<String>,
/// Host of the Pinecone index to connect to; the only non-optional key.
index_host: String,
/// Optional source tag copied into the Pinecone client configuration.
source_tag: Option<String>,
/// Namespace name used by the store; the empty name is used when omitted.
namespace: Option<String>,
}
/// Vector store backed by a single Pinecone index.
pub struct PineconeVectorStore {
/// Handle to the index. Kept behind an async mutex: the store's methods take `&self`
/// and lock it to get the mutable guard they call the index through.
index: Mutex<Index>,
/// Namespace passed along with every query, upsert and delete.
namespace: Namespace,
}
/// Value of the `X-Pinecone-API-Version` header added to the client configuration.
const PINECONE_API_VERSION: &str = "2025-01";
impl PineconeVectorStore {
pub async fn new(config_json: Option<&str>) -> Result<Self, VectorStoreError> {
assert!(
config_json.is_some(),
"{:?}",
serde_json::Error::custom("A config json with the required `index_host` expected!")
);
let json_config: Config = serde_json::from_str(config_json.unwrap()).unwrap();
let mut api_key = None;
if let Some(var_name) = json_config.api_key_var {
api_key = std::env::var(var_name).ok();
};
let config = PineconeClientConfig {
api_key,
control_plane_host: None,
additional_headers: Some(HashMap::from([(
"X-Pinecone-API-Version".to_string(),
PINECONE_API_VERSION.to_string(),
)])),
source_tag: json_config.source_tag,
};
let client = config.client().expect("Failed to create pinecone instance");
let index = Mutex::new(
client
.index(&json_config.index_host)
.await
.map_err(into_vec_store_error)?,
);
let name = json_config.namespace.unwrap_or_default();
let namespace = Namespace { name };
Ok(Self { index, namespace })
}
}
#[async_trait]
impl VectorStore for PineconeVectorStore {
    /// Looks up the embedding stored under `id`.
    ///
    /// Issues an id-based query limited to one match, with values and metadata
    /// requested, and returns the first match converted to an [`Embedding`].
    ///
    /// NOTE(review): `query_by_id` looks like a similarity query seeded by the vector
    /// with this id, so the single match is presumably — but not provably from this
    /// file — the vector itself; confirm for the index's metric.
    ///
    /// # Errors
    ///
    /// - [`VectorStoreError::Provider`] when the query fails or a match has no metadata.
    /// - [`VectorStoreError::EmbeddingNotFound`] when the query returns no match.
    async fn get_by_id(&self, id: String) -> Result<Embedding, VectorStoreError> {
        let mut index_guard = self.index.lock().await;
        let resp = index_guard
            .query_by_id(&id, 1, &self.namespace, None, Some(true), Some(true))
            .await
            .map_err(into_vec_store_error)?;
        // Take the first embedding by value; no need to clone out of the vector.
        Embeddings::try_from(resp)?
            .0
            .into_iter()
            .next()
            .ok_or(VectorStoreError::EmbeddingNotFound)
    }

    /// Stores `embedding` in the configured namespace.
    ///
    /// An embedding whose `raw_data` is empty is treated as a removal: the vector
    /// with that id is deleted. Any other embedding is upserted.
    ///
    /// # Errors
    ///
    /// Returns [`VectorStoreError::Provider`] when the delete or upsert call fails.
    async fn store(&self, embedding: Embedding) -> Result<(), VectorStoreError> {
        let mut index_guard = self.index.lock().await;
        if embedding.raw_data.is_empty() {
            () = index_guard
                .delete_by_id(&[&embedding.id], &self.namespace)
                .await
                .map_err(into_vec_store_error)?;
        } else {
            // The upsert response is not needed by callers.
            _ = index_guard
                .upsert(&[vector_from_embedding(embedding)], &self.namespace)
                .await
                .map_err(into_vec_store_error)?;
        }
        Ok(())
    }

    /// Returns up to `n` embeddings closest to `query`, as ranked by the index.
    ///
    /// The query components are narrowed from `f64` to `f32` for the request. An `n`
    /// that does not fit in `u32` is clamped to `u32::MAX` rather than wrapping around.
    ///
    /// # Errors
    ///
    /// Returns [`VectorStoreError::Provider`] when the query fails or a match has
    /// no metadata.
    // The f64 -> f32 narrowing of the query vector is intentional: the index takes f32.
    #[allow(clippy::cast_possible_truncation)]
    async fn top_n(&self, query: &[f64], n: usize) -> Result<Vec<Embedding>, VectorStoreError> {
        // `n as u32` would silently wrap (e.g. 2^32 -> 0) on 64-bit targets.
        let top_k = u32::try_from(n).unwrap_or(u32::MAX);
        let mut index_guard = self.index.lock().await;
        let resp = index_guard
            .query_by_value(
                query.iter().map(|&v| v as f32).collect::<Vec<f32>>(),
                None,
                top_k,
                &self.namespace,
                None,
                Some(true),
                Some(true),
            )
            .await
            .map_err(into_vec_store_error)?;
        Ok(Embeddings::try_from(resp)?.0)
    }
}
/// Wraps `value` in a metadata [`Value`] of the string kind.
fn value_from_str(value: String) -> Value {
    Value {
        kind: Some(Kind::StringValue(value)),
    }
}
#[allow(clippy::cast_possible_truncation)]
fn vector_from_embedding(embedding: Embedding) -> Vector {
let id = embedding.id;
let values = embedding.embedded_data.iter().map(|&v| v as f32).collect();
let mut fields = BTreeMap::new();
fields.insert("text".to_string(), value_from_str(embedding.raw_data));
let metadata = Some(Metadata { fields });
Vector {
id,
values,
sparse_values: None,
metadata,
}
}
/// Local wrapper around `Vec<Embedding>`; the type `TryFrom<QueryResponse>` is implemented on.
struct Embeddings(Vec<Embedding>);
impl TryFrom<QueryResponse> for Embeddings {
type Error = VectorStoreError;
fn try_from(value: QueryResponse) -> Result<Self, Self::Error> {
let mut embeddings: Vec<Embedding> = vec![];
for m in value.matches {
let mut raw_data = String::new();
let metadata = m.metadata.ok_or(VectorStoreError::Provider(
"Query response without raw data".to_string(),
))?;
metadata
.fields
.iter()
.for_each(|(k, v)| raw_data.push_str(&format!("\"{k}\":\"{v:?}\"")));
let embedded_data = m.values.iter().map(|&v| f64::from(v)).collect();
embeddings.push(Embedding {
id: m.id,
embedded_data,
raw_data,
});
}
Ok(Embeddings(embeddings))
}
}
/// Maps a Pinecone SDK error to [`VectorStoreError::Provider`], keeping its display text.
// Taken by value so the function can be handed straight to `map_err`.
#[allow(clippy::needless_pass_by_value)]
fn into_vec_store_error(e: PineconeError) -> VectorStoreError {
    let message = e.to_string();
    VectorStoreError::Provider(message)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store for the index host named by the `PINECONE_IDX_HOST` variable.
    ///
    /// Panics when the variable is not set.
    async fn store_from_env() -> Result<PineconeVectorStore, VectorStoreError> {
        let host = std::env::var("PINECONE_IDX_HOST").unwrap();
        let config = format!(r#"{{"index_host": "{host}"}}"#);
        PineconeVectorStore::new(Some(&config)).await
    }

    // Ignored by default — presumably because it needs a reachable Pinecone index.
    #[tokio::test]
    #[ignore]
    async fn test_new_pinecone_vec_store() {
        let pcvs = store_from_env().await;
        assert!(pcvs.is_ok());

        let store = pcvs.unwrap();
        let resp = store.index.lock().await.describe_index_stats(None).await;
        assert!(resp.is_ok());
    }

    // Ignored by default — presumably because it needs a reachable Pinecone index;
    // it also expects a vector with id "1" to be found.
    #[tokio::test]
    #[ignore]
    async fn test_pinecone_get_by_id() {
        let pcvs = store_from_env().await;
        assert!(pcvs.is_ok());

        let store = pcvs.unwrap();
        let resp = store.get_by_id("1".to_string()).await;
        assert!(resp.is_ok());
    }
}