use std::{collections::HashMap, fs};
use base64::Engine;
use reqwest::Client;
use serde::Deserialize;
use serde_json::json;
use crate::embeddings::embed::{EmbedData, EmbeddingResult};
use crate::embeddings::local::colpali::get_images_from_pdf;
/// The `embeddings` object of a Cohere embed response, carrying the `float` vectors.
#[derive(Debug, Default, Deserialize)]
pub struct FloatResponse {
    /// Dense embedding vectors, one per input. Callers in this module zip them
    /// with their inputs, so they are assumed to be in request order.
    pub float: Vec<Vec<f32>>,
}
/// Top-level JSON body returned by the Cohere embed endpoint.
#[derive(Debug, Default, Deserialize)]
pub struct CohereEmbedResponse {
    /// Response identifier reported by the API (not read by the code in this module).
    pub id: String,
    /// The embeddings, keyed by embedding type.
    pub embeddings: FloatResponse,
}
/// Client for Cohere's embed API.
///
/// `Debug` is implemented by hand (rather than derived) so that formatting an
/// embedder with `{:?}` — e.g. in logs or panic messages — never reveals the
/// API key.
pub struct CohereEmbedder {
    url: String,
    model: String,
    api_key: String,
    client: Client,
}

impl std::fmt::Debug for CohereEmbedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CohereEmbedder")
            .field("url", &self.url)
            .field("model", &self.model)
            // Secret: print a placeholder instead of the key itself.
            .field("api_key", &"<redacted>")
            .field("client", &self.client)
            .finish()
    }
}
impl Default for CohereEmbedder {
    /// Builds an embedder for the `embed-english-v3.0` model, taking the API
    /// key from the `CO_API_KEY` environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `CO_API_KEY` is not set (see [`CohereEmbedder::new`]).
    fn default() -> Self {
        let model = String::from("embed-english-v3.0");
        Self::new(model, None)
    }
}
impl CohereEmbedder {
pub fn new(model: String, api_key: Option<String>) -> Self {
let api_key =
api_key.unwrap_or_else(|| std::env::var("CO_API_KEY").expect("API key not set"));
Self {
model,
url: "https://api.cohere.com/v2/embed".to_string(),
api_key,
client: Client::new(),
}
}
fn load_image<T: AsRef<std::path::Path>>(&self, path: T) -> Result<String, anyhow::Error> {
let img = image::ImageReader::open(path)?.decode()?;
let img = img.to_rgb8();
let mut buffer = Vec::new();
img.write_to(
&mut std::io::Cursor::new(&mut buffer),
image::ImageFormat::Png,
)?;
let engine = base64::engine::general_purpose::STANDARD;
let img = engine.encode(buffer);
Ok(format!("data:image/png;base64,{}", img))
}
fn load_image_batch<T: AsRef<std::path::Path>>(
&self,
paths: &[T],
) -> Result<Vec<String>, anyhow::Error> {
paths.iter().map(|path| self.load_image(path)).collect()
}
pub async fn embed(&self, text_batch: &[&str]) -> Result<Vec<EmbeddingResult>, anyhow::Error> {
let response = self
.client
.post(&self.url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"texts": text_batch,
"model": self.model,
"input_type": "search_document",
"embedding_types": ["float"]
}))
.send()
.await?;
let data = match response.error_for_status() {
Ok(resp) => resp.json::<CohereEmbedResponse>().await?,
Err(e) => {
println!("❌ API Error: {}", e);
return Err(anyhow::anyhow!("API request failed: {}", e));
}
};
let encodings = data.embeddings;
let encodings = encodings
.float
.iter()
.map(|embedding| EmbeddingResult::DenseVector(embedding.clone()))
.collect::<Vec<_>>();
Ok(encodings)
}
pub async fn embed_image(
&self,
image_path: impl AsRef<std::path::Path>,
metadata: Option<HashMap<String, String>>,
) -> Result<EmbedData, anyhow::Error> {
let img = self.load_image(image_path)?;
let response = self
.client
.post(&self.url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"images": vec![img],
"model": self.model,
"input_type": "image"
}))
.send()
.await?;
let data = match response.error_for_status() {
Ok(resp) => resp.json::<CohereEmbedResponse>().await?,
Err(e) => {
println!("❌ API Error: {}", e);
return Err(anyhow::anyhow!("API request failed: {}", e));
}
};
let encodings = data.embeddings;
let embedding = encodings
.float
.iter()
.map(|embedding| EmbeddingResult::DenseVector(embedding.clone()))
.collect::<Vec<_>>();
Ok(EmbedData::new(embedding[0].clone(), None, metadata))
}
pub async fn embed_image_batch(
&self,
image_paths: &[impl AsRef<std::path::Path>],
batch_size: Option<usize>,
) -> Result<Vec<EmbedData>, anyhow::Error> {
let mut embeddings = Vec::new();
for image_path in image_paths.chunks(batch_size.unwrap_or(32)) {
let imgs = self.load_image_batch(image_path)?;
let response = self
.client
.post(&self.url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"images": imgs,
"model": self.model,
"input_type": "image"
}))
.send()
.await?;
let data = match response.error_for_status() {
Ok(resp) => resp.json::<CohereEmbedResponse>().await?,
Err(e) => {
println!("❌ API Error: {}", e);
return Err(anyhow::anyhow!("API request failed: {}", e));
}
};
let encodings = data.embeddings;
let embedding = encodings.float.iter().cloned();
embeddings.extend(embedding);
}
let embeddings = embeddings
.iter()
.zip(image_paths)
.map(|(data, path)| {
let mut metadata = HashMap::new();
metadata.insert(
"file_name".to_string(),
fs::canonicalize(path)
.unwrap()
.to_str()
.unwrap()
.to_string(),
);
EmbedData::new(
EmbeddingResult::DenseVector(data.clone()),
Some(path.as_ref().to_str().unwrap().to_string()),
Some(metadata),
)
})
.collect::<Vec<_>>();
Ok(embeddings)
}
pub async fn embed_pdf(
&self,
file_path: impl AsRef<std::path::Path>,
batch_size: Option<usize>,
) -> Result<Vec<EmbedData>, anyhow::Error> {
let pages = get_images_from_pdf(&file_path)?;
let mut embed_data = Vec::new();
let batch_size = batch_size.unwrap_or(8);
let pages_base64 = pages
.iter()
.map(|page| -> Result<String, anyhow::Error> {
let img = page.to_rgb8();
let mut buffer = Vec::new();
img.write_to(
&mut std::io::Cursor::new(&mut buffer),
image::ImageFormat::Png,
)?;
let engine = base64::engine::general_purpose::STANDARD;
let img = engine.encode(buffer);
Ok(format!("data:image/png;base64,{}", img))
})
.collect::<Result<Vec<_>, anyhow::Error>>()?;
for (index, batch) in pages_base64.chunks(batch_size).enumerate() {
let start_page = index * batch_size + 1;
let end_page = start_page + batch.len();
let page_numbers = (start_page..=end_page).collect::<Vec<_>>();
let response = self
.client
.post(&self.url)
.header("Accept", "application/json")
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", self.api_key))
.json(&json!({
"images": batch,
"model": self.model,
"input_type": "image",
"embedding_types": ["float"]
}))
.send()
.await?;
let data = match response.error_for_status() {
Ok(resp) => resp.json::<CohereEmbedResponse>().await?,
Err(e) => {
println!("❌ API Error: {}", e);
return Err(anyhow::anyhow!("API request failed: {}", e));
}
};
let encodings = data.embeddings;
let image_embeddings = encodings
.float
.iter()
.map(|embedding| EmbeddingResult::DenseVector(embedding.clone()));
let embed_data_batch = image_embeddings
.zip(page_numbers.into_iter())
.zip(batch.iter())
.map(|((embedding, page_number), page_image)| {
let mut metadata = HashMap::new();
metadata.insert("page_number".to_string(), page_number.to_string());
metadata.insert(
"file_path".to_string(),
file_path.as_ref().to_str().unwrap_or("").to_string(),
);
metadata.insert("image".to_string(), page_image.clone());
EmbedData::new(embedding, None, Some(metadata))
});
embed_data.extend(embed_data_batch);
}
Ok(embed_data)
}
}
#[cfg(test)]
mod tests {
    // These are live tests: they send requests to api.cohere.com, and
    // `CohereEmbedder::new`/`default` panic unless `CO_API_KEY` is set.
    use super::*;

    #[tokio::test]
    async fn test_cohere_embed() {
        let embedder = CohereEmbedder::default();
        let texts = [
            "Once upon a time",
            "The quick brown fox jumps over the lazy dog",
        ];
        let results = embedder.embed(&texts).await.unwrap();
        assert_eq!(results.len(), texts.len());
    }

    #[tokio::test]
    async fn test_cohere_embed_pdf() {
        let embedder = CohereEmbedder::new(String::from("embed-v4.0"), None);
        let pdf = "../test_files/colpali.pdf";
        let pages = embedder.embed_pdf(pdf, None).await.unwrap();
        // One embedding per page of the fixture PDF.
        assert_eq!(pages.len(), 26);
    }
}