1use std::fmt::Display;
6
7use super::{SenseError, embedding::EmbeddingBytes};
8use base64::{Engine as _, engine::general_purpose::STANDARD as DECODER};
9use doc_for::{DocDyn, doc_impl};
10use reqwest::{Client, ClientBuilder, Url, header::HeaderMap};
11use serde::{Deserialize, Serialize};
12
// Embedding models that can be selected when constructing an `ApiClient`.
//
// The `Display` impl for this enum prints whatever `doc_dyn()` returns for the
// variant, and the unit test in this file expects `BgeLargeZhV1_5` to print as
// "BAAI/bge-large-zh-v1.5".
//
// NOTE(review): `doc_impl` with `doc_dyn = true` presumably derives that string
// from each variant's doc comment, and `gen_attr` presumably reuses it as the
// serde rename — confirm against the `doc_for` crate documentation. Plain `//`
// comments are used here on purpose so the macro input is not altered.
#[doc_impl(
    strip = 1,
    doc_for = false,
    doc_dyn = true,
    gen_attr = "serde(rename = {doc})"
)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Model {
    // Also the value returned by `Model::default()`.
    BgeLargeZhV1_5,
    BgeLargeEnV1_5,
    BceEmbeddingBaseV1,
    BgeM3,
    ProBgeM3,
}
35
36impl Default for Model {
37 fn default() -> Self {
38 Self::BgeLargeZhV1_5
39 }
40}
41
42impl Display for Model {
43 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44 write!(f, "{}", self.doc_dyn().unwrap())
45 }
46}
47
48fn validate_api_key(key: &str) -> Result<(), SenseError> {
50 if key.len() != 51 {
51 return Err(SenseError::MalformedApiKey);
52 }
53 for c in key.chars().skip(3) {
54 if !c.is_ascii_alphanumeric() {
55 return Err(SenseError::MalformedApiKey);
56 }
57 }
58 Ok(())
59}
60
/// JSON payload posted to the embeddings endpoint by `ApiClient::embed`.
///
/// All fields borrow from the caller, so building a request allocates nothing
/// beyond the serialized body.
#[derive(Serialize)]
struct RequestBody<'a> {
    /// Model name to request; `embed` passes the client's stored model string.
    model: &'a str,
    /// The text to embed.
    input: &'a str,
    /// Requested encoding of the returned embedding; `embed` always sends
    /// `"base64"`.
    encoding_format: &'a str,
}
73
/// One entry of the `data` array in the embeddings response.
#[derive(Deserialize)]
struct Data {
    /// Deserialized from the `object` key; not read anywhere in this file.
    #[serde(rename = "object")]
    _object: String,
    /// The embedding as text; `ApiClient::embed` base64-decodes this value.
    embedding: String,
    /// Deserialized from the `index` key; not read anywhere in this file.
    #[serde(rename = "index")]
    _index: i32,
}
86
/// Mirror of the `usage` object in the embeddings response.
///
/// The fields exist only so the response deserializes; none of them is read
/// in this file. Because no field is optional, deserialization fails if the
/// response omits any of them.
#[derive(Deserialize)]
#[allow(dead_code, reason = "For deserialization only")]
#[allow(clippy::struct_field_names, reason = "Consistency with API response")]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: u32,
}
99
/// Top-level JSON body returned by the embeddings endpoint.
#[derive(Deserialize)]
struct ResponseBody {
    /// Model name echoed by the API; `ApiClient::embed` compares it with the
    /// requested model in debug builds.
    model: String,
    /// Returned entries; `ApiClient::embed` uses only the first one.
    data: Vec<Data>,
    /// Deserialized from the `usage` key; not read anywhere in this file.
    #[serde(rename = "usage")]
    _usage: Usage,
}
111
/// Client for the embeddings HTTP endpoint.
///
/// Built with `ApiClient::new`; `ApiClient::embed` sends one text and returns
/// its embedding.
#[derive(Clone)]
pub struct ApiClient {
    /// String form of the selected [`Model`], sent as the request's `model`.
    model: String,
    /// URL every embedding request is posted to.
    endpoint: Url,
    /// HTTP client whose default headers carry the `Authorization` bearer
    /// token set up in `ApiClient::new`.
    client: Client,
}
124
125impl ApiClient {
126 #[allow(clippy::missing_panics_doc, reason = "URL is hardcoded")]
132 pub fn new(key: &str, model: Model) -> Result<Self, SenseError> {
133 validate_api_key(key)?;
134 let mut headers = HeaderMap::new();
135 headers.insert("Authorization", format!("Bearer {key}").parse()?);
136 let client = ClientBuilder::new().default_headers(headers).build()?;
137
138 Ok(Self {
139 model: model.to_string(),
140 endpoint: Url::parse("https://api.siliconflow.cn/v1/embeddings").unwrap(),
141 client,
142 })
143 }
144
145 pub async fn embed(&self, text: &str) -> Result<EmbeddingBytes, SenseError> {
155 let request_body = RequestBody {
156 model: &self.model,
157 input: text,
158 encoding_format: "base64",
159 };
160 let request = self.client.post(self.endpoint.clone()).json(&request_body);
161
162 let response: ResponseBody = request.send().await?.json().await?;
163 debug_assert_eq!(response.model, self.model);
164
165 let embedding = DECODER.decode(response.data[0].embedding.as_bytes())?;
166 Ok(embedding.try_into()?)
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    // Placeholder key with the accepted shape: a 3-character prefix followed
    // by 48 ASCII alphanumerics, 51 bytes in total.
    const KEY: &str = "sk-1234567890abcdef1234567890abcdef1234567890abcdef";

    #[test]
    fn test_api_key_ok() {
        validate_api_key(KEY).unwrap();
    }

    // Wrong length: one character short.
    #[test]
    fn test_api_key_malformed() {
        let malformed = &KEY[..KEY.len() - 1];
        let err = validate_api_key(malformed).unwrap_err();
        assert!(matches!(err, SenseError::MalformedApiKey));
    }

    // Right length, but a non-alphanumeric character after the prefix.
    #[test]
    fn test_api_key_non_alphanumeric() {
        let malformed = format!("{}!", &KEY[..KEY.len() - 1]);
        assert_eq!(malformed.len(), KEY.len());
        let err = validate_api_key(&malformed).unwrap_err();
        assert!(matches!(err, SenseError::MalformedApiKey));
    }

    #[test]
    fn test_model_string() {
        let model = Model::BgeLargeZhV1_5;
        assert_eq!(model.to_string(), "BAAI/bge-large-zh-v1.5");
    }

    // Talks to the real service, so it only runs when explicitly requested.
    #[tokio::test]
    #[ignore = "requires API key in `SILICONFLOW_API_KEY` env var"]
    async fn test_embed() {
        let key = std::env::var("SILICONFLOW_API_KEY")
            .expect("SILICONFLOW_API_KEY must be set to run this test");
        let client = ApiClient::new(&key, Model::BgeLargeZhV1_5).unwrap();
        let embedding = client.embed("Hello, world!").await;
        let _ = embedding.unwrap();
    }
}