1#![deny(missing_docs)]
4#![warn(clippy::all, clippy::nursery, clippy::pedantic, clippy::cargo)]
5#![allow(clippy::multiple_crate_versions, reason = "Fucking windows.")]
6
7mod error;
8mod json;
9
10use base64::{Engine as _, engine::general_purpose::STANDARD as DECODER};
11use cyper::{Client, Error as CyperError};
12pub use error::EmbedError;
13use http::{HeaderMap, StatusCode, header::InvalidHeaderValue};
14use instant_distance::Point;
15use json::{RequestBody, ResponseBody};
16use serde::{Deserialize, Serialize};
17use serde_big_array::BigArray;
18use std::ops::Deref;
19
/// SiliconFlow embeddings endpoint that [`ApiClient::embed_text`] sends its POST requests to.
const API_ENDPOINT: &str = "https://api.siliconflow.cn/v1/embeddings";
22#[derive(Debug, Clone)]
29pub struct ApiClient {
30 client: Client,
32}
33
34#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36#[serde(transparent)]
37pub struct Embedding(#[serde(with = "BigArray")] [f32; 1024]);
38
39impl ApiClient {
40 pub fn new(api_key: &str) -> Result<Self, InvalidHeaderValue> {
46 let mut headers = HeaderMap::new();
47 headers.insert("Authorization", format!("Bearer {api_key}").parse()?);
48 let client = Client::builder().default_headers(headers).build();
49
50 Ok(Self { client })
51 }
52
53 pub async fn embed_text(&self, input: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
63 let body = RequestBody {
64 model: "BAAI/bge-large-zh-v1.5",
65 input,
66 encoding_format: "base64",
67 };
68 let response = self
69 .client
70 .post(API_ENDPOINT)
71 .map_err(EmbedError::RequestPreparation)?
72 .json(&body)
73 .map_err(EmbedError::RequestSerialization)?
74 .send()
75 .await
76 .map_err(EmbedError::RequestSend)?;
77
78 match response.status() {
79 StatusCode::UNAUTHORIZED => {
80 let message = response.text().await.unwrap_or_default();
81 return Err(EmbedError::InvalidApiKey(message));
82 }
83 StatusCode::TOO_MANY_REQUESTS => {
84 let message = response.text().await.unwrap_or_default();
85 return Err(EmbedError::RateLimitExceeded(message));
86 }
87 StatusCode::OK => {}
88 code => {
89 let message = response.text().await.unwrap_or_default();
90 if code == StatusCode::FORBIDDEN && message.starts_with("\"RPM limit exceeded.") {
91 return Err(EmbedError::RateLimitExceeded(message));
93 }
94 return Err(EmbedError::UnknownApiError {
95 code,
96 message,
97 });
98 }
99 }
100
101 let response_body: ResponseBody =
102 response.json().await.map_err(EmbedError::ResponseParse)?;
103
104 let result = response_body
105 .data
106 .into_iter()
107 .map(|data| -> Result<Embedding, EmbedError> {
108 let bytes = DECODER.decode(data.embedding.as_bytes())?;
109 let mut embedding = [0.0; 1024];
110 bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
111 embedding[i] = f32::from_le_bytes(
112 chunk
113 .try_into()
114 .expect("The chunk length should be 4 bytes"),
115 );
116 });
117 Ok(embedding.into())
118 })
119 .collect::<Result<_, _>>()?;
120
121 Ok(result)
122 }
123}
124
125impl From<[f32; 1024]> for Embedding {
126 fn from(value: [f32; 1024]) -> Self {
127 Embedding(value)
128 }
129}
130
131impl Deref for Embedding {
132 type Target = [f32; 1024];
133
134 fn deref(&self) -> &Self::Target {
135 &self.0
136 }
137}
138
139impl Point for Embedding {
140 fn distance(&self, other: &Self) -> f32 {
141 self.0
142 .iter()
143 .zip(other.0.iter())
144 .map(|(a, b)| (a - b).powi(2))
145 .sum::<f32>()
146 .sqrt()
147 }
148}