1use std::pin::Pin;
2use std::sync::Arc;
3
4use serde::{Deserialize, Serialize};
5
6use crate::error::{Error, Result};
7
8use super::backend::EmbeddingBackend;
9use super::config::VoyageConfig;
10use super::convert::to_f32_blob;
11
12struct Inner {
13 client: reqwest::Client,
14 api_key: String,
15 model: String,
16 dimensions: usize,
17}
18
19pub struct VoyageEmbedding(Arc<Inner>);
32
33impl Clone for VoyageEmbedding {
34 fn clone(&self) -> Self {
35 Self(Arc::clone(&self.0))
36 }
37}
38
39impl VoyageEmbedding {
40 pub fn new(client: reqwest::Client, config: &VoyageConfig) -> Result<Self> {
46 config.validate()?;
47 Ok(Self(Arc::new(Inner {
48 client,
49 api_key: config.api_key.clone(),
50 model: config.model.clone(),
51 dimensions: config.dimensions,
52 })))
53 }
54}
55
56impl EmbeddingBackend for VoyageEmbedding {
57 fn embed(&self, input: &str) -> Pin<Box<dyn Future<Output = Result<Vec<u8>>> + Send + '_>> {
58 let input = input.to_owned();
59 Box::pin(async move {
60 const URL: &str = concat!("https://api.voyageai.com", "/v1/embeddings");
61 let body = Request {
62 input: &input,
63 model: &self.0.model,
64 output_dimension: self.0.dimensions,
65 };
66
67 let resp = self
68 .0
69 .client
70 .post(URL)
71 .bearer_auth(&self.0.api_key)
72 .json(&body)
73 .send()
74 .await
75 .map_err(|e| Error::internal("voyage embeddings request failed").chain(e))?;
76
77 if !resp.status().is_success() {
78 let status = resp.status();
79 let text = resp.text().await.unwrap_or_default();
80 return Err(Error::internal(format!(
81 "voyage embedding error: {status}: {text}"
82 )));
83 }
84
85 let parsed: Response = resp.json().await.map_err(|e| {
86 Error::internal("failed to parse voyage embedding response").chain(e)
87 })?;
88
89 let values = parsed
90 .data
91 .into_iter()
92 .next()
93 .ok_or_else(|| Error::internal("voyage returned empty embedding data"))?
94 .embedding;
95
96 Ok(to_f32_blob(&values))
97 })
98 }
99
100 fn dimensions(&self) -> usize {
101 self.0.dimensions
102 }
103
104 fn model_name(&self) -> &str {
105 &self.0.model
106 }
107}
108
109#[derive(Serialize)]
110struct Request<'a> {
111 input: &'a str,
112 model: &'a str,
113 output_dimension: usize,
114}
115
116#[derive(Deserialize)]
117struct Response {
118 data: Vec<EmbeddingData>,
119}
120
121#[derive(Deserialize)]
122struct EmbeddingData {
123 embedding: Vec<f32>,
124}