use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use crate::embeddings::embed::EmbeddingResult;
/// The part of the embed endpoint's JSON response that this module deserializes.
///
/// Only the `embeddings` field is declared here.
#[derive(Debug, Default, Deserialize)]
pub struct CohereEmbedResponse {
    /// Dense vectors parsed from the response's `embeddings` array.
    pub embeddings: Vec<Vec<f32>>,
}
/// HTTP client wrapper that requests embeddings from the Cohere embed endpoint.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written out when the embedder is formatted with `{:?}`.
pub struct CohereEmbedder {
    /// Endpoint the embed requests are posted to.
    url: String,
    /// Model identifier sent in the `model` field of each request.
    model: String,
    /// Secret sent as the bearer token in the `Authorization` header.
    api_key: String,
    /// HTTP client used to send the requests.
    client: Client,
}

impl std::fmt::Debug for CohereEmbedder {
    /// Formats the embedder like a derived `Debug` would, but with the API key
    /// replaced by a placeholder so the secret cannot leak through logs or
    /// panic messages.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereEmbedder")
            .field("url", &self.url)
            .field("model", &self.model)
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .finish()
    }
}
impl Default for CohereEmbedder {
fn default() -> Self {
Self::new("embed-english-v3.0".to_string(), None)
}
}
impl CohereEmbedder {
    /// Creates an embedder that uses `model` for every request.
    ///
    /// When `api_key` is `None`, the key is read from the `CO_API_KEY`
    /// environment variable instead.
    ///
    /// # Panics
    ///
    /// Panics if `api_key` is `None` and `CO_API_KEY` cannot be read
    /// (unset or not valid Unicode).
    pub fn new(model: String, api_key: Option<String>) -> Self {
        let api_key =
            api_key.unwrap_or_else(|| std::env::var("CO_API_KEY").expect("API key not set"));
        Self {
            model,
            url: "https://api.cohere.com/v1/embed".to_string(),
            api_key,
            client: Client::new(),
        }
    }

    /// Requests embeddings for every string in `text_batch`.
    ///
    /// The texts are posted with the `search_document` input type and each
    /// vector in the response is wrapped in [`EmbeddingResult::DenseVector`].
    /// An empty batch yields an empty result without contacting the service.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be sent, if the service answers
    /// with a non-success HTTP status (the status and response body are
    /// included in the message), or if the response body cannot be
    /// deserialized into [`CohereEmbedResponse`].
    pub async fn embed(
        &self,
        text_batch: &[String],
    ) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
        // Nothing to embed: skip the network round-trip entirely.
        if text_batch.is_empty() {
            return Ok(Vec::new());
        }

        let response = self
            .client
            .post(&self.url)
            .header("Accept", "application/json")
            .header("Content-Type", "application/json")
            .bearer_auth(&self.api_key)
            .json(&json!({
                "texts": text_batch,
                "model": self.model,
                "input_type": "search_document"
            }))
            .send()
            .await?;

        // Without this check an error payload would be fed to the JSON decoder
        // below and reported as a confusing deserialization failure.
        let status = response.status();
        if !status.is_success() {
            // Best effort: the body usually explains the failure, but reading it
            // must not mask the status error itself.
            let body = response.text().await.unwrap_or_default();
            anyhow::bail!(
                "Cohere embed request failed with status {}: {}",
                status,
                body
            );
        }

        let data = response.json::<CohereEmbedResponse>().await?;

        // `data` is owned, so the vectors can be moved instead of cloned.
        Ok(data
            .embeddings
            .into_iter()
            .map(EmbeddingResult::DenseVector)
            .collect())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embeds two sentences and checks that two results come back.
    ///
    /// This exercises the real request path: `CohereEmbedder::default()` needs
    /// the `CO_API_KEY` environment variable and `embed` posts to the
    /// configured endpoint.
    #[tokio::test]
    async fn test_cohere_embed() {
        let embedder = CohereEmbedder::default();
        let sentences = [
            "Once upon a time",
            "The quick brown fox jumps over the lazy dog",
        ];
        let text_batch: Vec<String> = sentences.iter().map(|s| s.to_string()).collect();

        let embeddings = embedder.embed(&text_batch).await.unwrap();

        assert_eq!(embeddings.len(), 2);
    }
}