use std::fmt::Display;
use super::{SenseError, embedding::EmbeddingBytes};
use base64::{Engine as _, engine::general_purpose::STANDARD as DECODER};
use doc_for::{DocDyn, doc_impl};
use reqwest::{Client, ClientBuilder, Url, header::HeaderMap};
use serde::{Deserialize, Serialize};
// Embedding models that can be requested from the API.
//
// NOTE: the `///` line on each variant is data, not just documentation. `doc_impl` is configured
// with `doc_dyn = true` and `gen_attr = "serde(rename = {doc})"`, so the variant's doc string is
// what `doc_dyn()` hands to the `Display` impl (which unwraps it) and what is substituted for
// `{doc}` in the generated serde attribute. Each doc line must therefore be exactly the model
// identifier the API expects; `strip = 1` presumably drops the single leading space.
// A plain `//` comment is used here on purpose so that this note is invisible to the macro.
#[doc_impl(
    strip = 1,
    doc_for = false,
    doc_dyn = true,
    gen_attr = "serde(rename = {doc})"
)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Model {
    /// BAAI/bge-large-zh-v1.5
    BgeLargeZhV1_5,
    /// BAAI/bge-large-en-v1.5
    BgeLargeEnV1_5,
    /// netease-youdao/bce-embedding-base_v1
    BceEmbeddingBaseV1,
    /// BAAI/bge-m3
    BgeM3,
    /// Pro/BAAI/bge-m3
    ProBgeM3,
}
impl Default for Model {
fn default() -> Self {
Self::BgeLargeZhV1_5
}
}
impl Display for Model {
    /// Writes the string obtained from `doc_dyn()` for this variant.
    ///
    /// Formatter flags such as width or fill are not applied to the output.
    ///
    /// # Panics
    ///
    /// Panics if `doc_dyn()` yields no value for the variant, because the result is unwrapped.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.doc_dyn().unwrap();
        write!(f, "{name}")
    }
}
/// Checks that `key` has the shape of an API key: the literal prefix `sk-` followed by exactly
/// 48 ASCII alphanumeric characters (51 bytes in total).
///
/// The prefix is matched explicitly. Merely skipping the first three characters would let any
/// prefix through, and would mis-align the check when the key starts with a multi-byte character
/// (byte length and character count then disagree).
///
/// # Errors
///
/// Returns [`SenseError::MalformedApiKey`] if the prefix is missing, the remainder is not 48
/// bytes long, or the remainder contains a character that is not ASCII alphanumeric.
fn validate_api_key(key: &str) -> Result<(), SenseError> {
    const PREFIX: &str = "sk-";
    const SECRET_LEN: usize = 48;

    let Some(secret) = key.strip_prefix(PREFIX) else {
        return Err(SenseError::MalformedApiKey);
    };
    // Every accepted byte is ASCII, so the byte length equals the character count.
    if secret.len() == SECRET_LEN && secret.bytes().all(|b| b.is_ascii_alphanumeric()) {
        Ok(())
    } else {
        Err(SenseError::MalformedApiKey)
    }
}
/// Payload that `ApiClient::embed` serializes as the JSON body of an embeddings request.
#[derive(Serialize)]
struct RequestBody<'a> {
/// Model identifier; `embed` passes the model name stored in the client.
model: &'a str,
/// Text to embed.
input: &'a str,
/// Encoding requested for the returned embedding; `embed` always sends `"base64"`.
encoding_format: &'a str,
}
/// One element of the `data` array in an embeddings response.
///
/// None of the fields has a serde default, so all three keys must be present for
/// deserialization to succeed.
#[derive(Deserialize)]
struct Data {
/// Value of the `object` key; deserialized but not read by this module.
#[serde(rename = "object")]
_object: String,
/// Embedding payload as text; `ApiClient::embed` base64-decodes it.
embedding: String,
/// Value of the `index` key; deserialized but not read by this module.
#[serde(rename = "index")]
_index: i32,
}
/// The `usage` object of an embeddings response.
///
/// The fields are never read by this module (hence the `dead_code` allowance), but they are
/// still required for the response to deserialize. Their names deliberately mirror the keys of
/// the API response.
#[derive(Deserialize)]
#[allow(dead_code, reason = "For deserialization only")]
#[allow(clippy::struct_field_names, reason = "Consistency with API response")]
struct Usage {
prompt_tokens: u32,
completion_tokens: u32,
total_tokens: u32,
}
/// Top-level JSON body of an embeddings response.
#[derive(Deserialize)]
struct ResponseBody {
/// Model name reported by the server; `ApiClient::embed` compares it with the requested model
/// in a debug assertion.
model: String,
/// Returned entries; `ApiClient::embed` reads only the first one.
data: Vec<Data>,
/// Value of the `usage` key; deserialized but not read by this module.
#[serde(rename = "usage")]
_usage: Usage,
}
/// Client for the embeddings endpoint, bound to one model and one API key.
#[derive(Clone)]
pub struct ApiClient {
/// String form of the selected `Model`, sent as the `model` field of every request.
model: String,
/// URL that embedding requests are POSTed to.
endpoint: Url,
/// HTTP client built with the `Authorization` header as a default header.
client: Client,
}
impl ApiClient {
#[allow(clippy::missing_panics_doc, reason = "URL is hardcoded")]
pub fn new(key: &str, model: Model) -> Result<Self, SenseError> {
validate_api_key(key)?;
let mut headers = HeaderMap::new();
headers.insert("Authorization", format!("Bearer {key}").parse()?);
let client = ClientBuilder::new().default_headers(headers).build()?;
Ok(Self {
model: model.to_string(),
endpoint: Url::parse("https://api.siliconflow.cn/v1/embeddings").unwrap(),
client,
})
}
pub async fn embed(&self, text: &str) -> Result<EmbeddingBytes, SenseError> {
let request_body = RequestBody {
model: &self.model,
input: text,
encoding_format: "base64",
};
let request = self.client.post(self.endpoint.clone()).json(&request_body);
let response: ResponseBody = request.send().await?.json().await?;
debug_assert_eq!(response.model, self.model);
let embedding = DECODER.decode(response.data[0].embedding.as_bytes())?;
Ok(embedding.try_into()?)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy key with a valid shape: 51 characters, alphanumeric after the first three.
    const KEY: &str = "sk-1234567890abcdef1234567890abcdef1234567890abcdef";

    /// A key of the expected shape passes validation.
    #[test]
    fn test_api_key_ok() {
        assert!(validate_api_key(KEY).is_ok());
    }

    /// Dropping the last character makes the key too short, so it is reported as malformed.
    #[test]
    fn test_api_key_malformed() {
        let (truncated, _) = KEY.split_at(KEY.len() - 1);
        let result = validate_api_key(truncated);
        assert!(matches!(result, Err(SenseError::MalformedApiKey)));
    }

    /// `Display` renders the model as the identifier sent to the API.
    #[test]
    fn test_model_string() {
        assert_eq!(Model::BgeLargeZhV1_5.to_string(), "BAAI/bge-large-zh-v1.5");
    }

    /// Round trip against the live service; only runs when explicitly requested.
    #[tokio::test]
    #[ignore = "requires API key in `SILICONFLOW_API_KEY` env var"]
    async fn test_embed() {
        let key = std::env::var("SILICONFLOW_API_KEY").unwrap();
        let client = ApiClient::new(&key, Model::BgeLargeZhV1_5).unwrap();
        let _ = client.embed("Hello, world!").await.unwrap();
    }
}