#[cfg(feature = "cohere")]
mod inner {
    use crate::embedding::Embedder;
    use crate::error::{EngramError, Result};

    /// Connection and model settings for [`CohereEmbedder`].
    ///
    /// `Debug` is implemented by hand (see below) so the API key is never printed.
    #[derive(Clone)]
    pub struct CohereConfig {
        /// Key sent as `Authorization: Bearer <api_key>` on every request.
        pub api_key: String,
        /// Model name passed as the `model` field of the request body.
        pub model: String,
        /// API root; `/embed` is appended to it (a trailing `/` is tolerated).
        pub base_url: String,
        /// Value reported by [`Embedder::dimensions`].
        ///
        /// NOTE(review): this is not checked against the length of the vectors the API
        /// returns — make sure it matches the configured model.
        pub dimensions: usize,
    }

    impl std::fmt::Debug for CohereConfig {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            // Redact the secret but still show whether one has been configured.
            let api_key = if self.api_key.is_empty() {
                ""
            } else {
                "<redacted>"
            };
            f.debug_struct("CohereConfig")
                .field("api_key", &api_key)
                .field("model", &self.model)
                .field("base_url", &self.base_url)
                .field("dimensions", &self.dimensions)
                .finish()
        }
    }

    impl Default for CohereConfig {
        /// `embed-english-v3.0` at `https://api.cohere.ai/v1`, 1024 dimensions, empty API key.
        fn default() -> Self {
            Self {
                api_key: String::new(),
                model: "embed-english-v3.0".to_string(),
                base_url: "https://api.cohere.ai/v1".to_string(),
                dimensions: 1024,
            }
        }
    }

    /// [`Embedder`] backed by the Cohere `/embed` HTTP endpoint.
    pub struct CohereEmbedder {
        config: CohereConfig,
        client: reqwest::Client,
    }

    impl CohereEmbedder {
        /// Maximum number of texts Cohere accepts in a single `/embed` request.
        /// Larger batches are split into several sequential requests.
        const MAX_TEXTS_PER_REQUEST: usize = 96;

        /// Creates an embedder that owns `config` and a new `reqwest::Client`.
        pub fn new(config: CohereConfig) -> Self {
            Self {
                config,
                client: reqwest::Client::new(),
            }
        }

        /// Embeds a single text.
        ///
        /// NOTE(review): every request uses `input_type: "search_document"`, including
        /// texts that are really search queries.
        ///
        /// # Errors
        ///
        /// Returns an error on transport failure, on a non-success HTTP status (the
        /// message includes the status and response body), or when the response is
        /// malformed: missing `embeddings`, wrong number of vectors, or a vector that is
        /// not an array, is empty, or contains non-numeric values.
        pub async fn embed_async(&self, text: &str) -> Result<Vec<f32>> {
            self.request_embeddings(&[text])
                .await?
                .into_iter()
                .next()
                .ok_or_else(|| {
                    EngramError::Embedding(
                        "Cohere response 'embeddings[0]' is missing or not an array".to_string(),
                    )
                })
        }

        /// Embeds `texts`, returning one vector per input in the same order.
        ///
        /// An empty slice returns an empty `Vec` without contacting the API. Inputs longer
        /// than [`Self::MAX_TEXTS_PER_REQUEST`] are sent in several requests; the first
        /// failing request aborts the whole batch.
        ///
        /// # Errors
        ///
        /// Same conditions as [`Self::embed_async`], applied to every returned vector.
        pub async fn embed_batch_async(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            if texts.is_empty() {
                return Ok(vec![]);
            }
            let mut embeddings = Vec::with_capacity(texts.len());
            for chunk in texts.chunks(Self::MAX_TEXTS_PER_REQUEST) {
                embeddings.extend(self.request_embeddings(chunk).await?);
            }
            Ok(embeddings)
        }

        /// Sends one `/embed` request for `texts` and parses exactly `texts.len()` vectors.
        async fn request_embeddings(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let url = format!("{}/embed", self.config.base_url.trim_end_matches('/'));
            let response = self
                .client
                .post(&url)
                .header("Authorization", format!("Bearer {}", self.config.api_key))
                .json(&serde_json::json!({
                    "texts": texts,
                    "model": self.config.model,
                    "input_type": "search_document",
                }))
                .send()
                .await?;
            if !response.status().is_success() {
                let status = response.status();
                let body = response.text().await.unwrap_or_default();
                return Err(EngramError::Embedding(format!(
                    "Cohere API error {status}: {body}"
                )));
            }
            let data: serde_json::Value = response.json().await?;
            parse_embeddings(&data, texts.len())
        }
    }

    /// Extracts exactly `expected` vectors from the `embeddings` array of a Cohere response.
    fn parse_embeddings(data: &serde_json::Value, expected: usize) -> Result<Vec<Vec<f32>>> {
        let raw = data["embeddings"].as_array().ok_or_else(|| {
            EngramError::Embedding("Cohere response missing 'embeddings' field".to_string())
        })?;
        // A count mismatch would silently misalign vectors with their source texts.
        if raw.len() != expected {
            return Err(EngramError::Embedding(format!(
                "Cohere returned {} embeddings for {} input texts",
                raw.len(),
                expected
            )));
        }
        raw.iter()
            .enumerate()
            .map(|(index, entry)| parse_vector(index, entry))
            .collect()
    }

    /// Converts one `embeddings[index]` entry into a non-empty `Vec<f32>`.
    ///
    /// Non-numeric elements are rejected rather than skipped, because skipping them
    /// would shorten the vector and corrupt its dimensionality.
    fn parse_vector(index: usize, entry: &serde_json::Value) -> Result<Vec<f32>> {
        let values = entry.as_array().ok_or_else(|| {
            EngramError::Embedding(format!(
                "Cohere response 'embeddings[{index}]' is missing or not an array"
            ))
        })?;
        if values.is_empty() {
            return Err(EngramError::Embedding(format!(
                "Cohere returned an empty embedding vector at index {index}"
            )));
        }
        values
            .iter()
            .map(|value| {
                value.as_f64().map(|f| f as f32).ok_or_else(|| {
                    EngramError::Embedding(format!(
                        "Cohere response 'embeddings[{index}]' contains a non-numeric value: {value}"
                    ))
                })
            })
            .collect()
    }

    /// Drives `future` to completion from synchronous code running on a Tokio runtime.
    ///
    /// Returns an error instead of panicking when no runtime is active.
    ///
    /// # Panics
    ///
    /// `tokio::task::block_in_place` panics when called on a current-thread runtime.
    fn run_blocking<T>(future: impl std::future::Future<Output = Result<T>>) -> Result<T> {
        let handle = tokio::runtime::Handle::try_current().map_err(|_| {
            EngramError::Embedding(
                "CohereEmbedder's synchronous API must be called from within a Tokio runtime"
                    .to_string(),
            )
        })?;
        tokio::task::block_in_place(|| handle.block_on(future))
    }

    impl Embedder for CohereEmbedder {
        fn embed(&self, text: &str) -> crate::error::Result<Vec<f32>> {
            run_blocking(self.embed_async(text))
        }

        fn embed_batch(&self, texts: &[&str]) -> crate::error::Result<Vec<Vec<f32>>> {
            run_blocking(self.embed_batch_async(texts))
        }

        fn dimensions(&self) -> usize {
            self.config.dimensions
        }

        fn model_name(&self) -> &str {
            &self.config.model
        }
    }
}
#[cfg(feature = "cohere")]
pub use inner::{CohereConfig, CohereEmbedder};
#[cfg(test)]
mod tests {
    // Offline stand-in for the Cohere API: a deterministic, L2-normalised byte hash, so
    // shape and determinism properties can be checked without network access.
    struct StubCohereEmbedder {
        dimensions: usize,
        model: String,
    }

    impl StubCohereEmbedder {
        fn new(dimensions: usize) -> Self {
            // `embed_stub` indexes modulo `dimensions`; zero would divide by zero.
            assert!(dimensions > 0, "stub embedder needs at least one dimension");
            Self {
                dimensions,
                model: "embed-english-v3.0".to_string(),
            }
        }

        fn embed_stub(&self, text: &str) -> Vec<f32> {
            let mut embedding = vec![0.0_f32; self.dimensions];
            for (i, byte) in text.bytes().enumerate() {
                embedding[i % self.dimensions] += byte as f32;
                embedding[(i * 7 + 13) % self.dimensions] -= byte as f32 * 0.1;
            }
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for x in &mut embedding {
                    *x /= norm;
                }
            }
            embedding
        }

        fn embed_batch_stub(&self, texts: &[&str]) -> Vec<Vec<f32>> {
            texts.iter().map(|t| self.embed_stub(t)).collect()
        }
    }

    #[test]
    fn test_stub_embed_returns_correct_dimensions() {
        let embedder = StubCohereEmbedder::new(1024);
        let result = embedder.embed_stub("hello world");
        assert_eq!(result.len(), 1024, "embedding must have 1024 dimensions");
    }

    #[test]
    fn test_stub_embed_is_deterministic() {
        let embedder = StubCohereEmbedder::new(1024);
        let e1 = embedder.embed_stub("same text input");
        let e2 = embedder.embed_stub("same text input");
        assert_eq!(e1, e2, "same input must produce identical vectors");
    }

    #[test]
    fn test_stub_embed_different_inputs_differ() {
        let embedder = StubCohereEmbedder::new(1024);
        let e1 = embedder.embed_stub("first sentence about AI");
        let e2 = embedder.embed_stub("totally unrelated content xyz");
        assert_ne!(e1, e2, "different inputs should produce different vectors");
    }

    #[test]
    fn test_stub_embed_is_unit_normalized() {
        let embedder = StubCohereEmbedder::new(1024);
        let result = embedder.embed_stub("normalise me");
        let norm: f32 = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "non-empty input must have unit norm, got {norm}");
    }

    #[test]
    fn test_stub_embed_batch_length_matches_input() {
        let embedder = StubCohereEmbedder::new(1024);
        let texts = ["alpha", "beta", "gamma"];
        let results = embedder.embed_batch_stub(&texts);
        assert_eq!(
            results.len(),
            3,
            "batch result count must match input count"
        );
        for r in &results {
            assert_eq!(r.len(), 1024);
        }
    }

    #[test]
    fn test_stub_embed_empty_input_returns_zero_vector() {
        let embedder = StubCohereEmbedder::new(1024);
        let result = embedder.embed_stub("");
        assert_eq!(result.len(), 1024);
        assert!(
            result.iter().all(|&x| x == 0.0),
            "empty input should yield zero vector"
        );
    }

    #[cfg(feature = "cohere")]
    #[test]
    fn test_cohere_config_defaults() {
        use super::inner::CohereConfig;
        let cfg = CohereConfig::default();
        assert_eq!(cfg.model, "embed-english-v3.0");
        assert_eq!(cfg.base_url, "https://api.cohere.ai/v1");
        assert_eq!(cfg.dimensions, 1024);
        assert!(cfg.api_key.is_empty(), "default api_key must be empty");
    }

    #[cfg(feature = "cohere")]
    #[test]
    fn test_cohere_config_custom() {
        use super::inner::CohereConfig;
        let cfg = CohereConfig {
            api_key: "co-test-key".to_string(),
            model: "embed-multilingual-v3.0".to_string(),
            base_url: "https://api.cohere.ai/v1".to_string(),
            dimensions: 1024,
        };
        assert_eq!(cfg.api_key, "co-test-key");
        assert_eq!(cfg.model, "embed-multilingual-v3.0");
    }

    // Exercises the production `Embedder` impl (no network call is made) rather than
    // the stub above.
    #[cfg(feature = "cohere")]
    #[test]
    fn test_cohere_embedder_reports_configured_dimensions_and_model() {
        use super::inner::{CohereConfig, CohereEmbedder};
        use crate::embedding::Embedder;
        let embedder = CohereEmbedder::new(CohereConfig {
            model: "embed-multilingual-v3.0".to_string(),
            dimensions: 768,
            ..CohereConfig::default()
        });
        assert_eq!(embedder.dimensions(), 768);
        assert_eq!(embedder.model_name(), "embed-multilingual-v3.0");
    }
}