hyw_embed/
lib.rs

1//! Helper for creating embeddings using the Silicon Flow API.
2
3#![deny(missing_docs)]
4#![warn(clippy::all, clippy::nursery, clippy::pedantic, clippy::cargo)]
5#![allow(clippy::multiple_crate_versions, reason = "Fucking windows.")]
6
7mod error;
8mod json;
9
10use base64::{Engine as _, engine::general_purpose::STANDARD as DECODER};
11use cyper::{Client, Error as CyperError};
12pub use error::EmbedError;
13use http::{HeaderMap, StatusCode, header::InvalidHeaderValue};
14use instant_distance::Point;
15use json::{RequestBody, ResponseBody};
16use serde::{Deserialize, Serialize};
17use serde_big_array::BigArray;
18use std::ops::Deref;
19
// Alternative endpoint on the `.com` domain, kept here for quick switching:
// const API_ENDPOINT: &str = "https://api.siliconflow.com/v1/embeddings";

/// URL that embedding requests are POSTed to.
const API_ENDPOINT: &str = "https://api.siliconflow.cn/v1/embeddings";
// Useful documentation:
// - API for embedding: https://docs.siliconflow.cn/cn/api-reference/embeddings/create-embeddings
// - Model description & rate limits: https://cloud.siliconflow.cn/open/models?target=BAAI%2Fbge-large-zh-v1.5&types=embedding
// - Rate limit metrics: https://docs.siliconflow.cn/cn/userguide/rate-limits/rate-limit-and-upgradation#1-3-rate-limits-%E6%8C%87%E6%A0%87
26
27/// A client for the Silicon Flow API.
28#[derive(Debug, Clone)]
29pub struct ApiClient {
30    /// HTTP client.
31    client: Client,
32}
33
34/// The embedding type. Just a wrapper around a fixed-size array of 1024 f32 values.
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36#[serde(transparent)]
37pub struct Embedding(#[serde(with = "BigArray")] [f32; 1024]);
38
39impl ApiClient {
40    /// Create a new API client.
41    ///
42    /// # Errors
43    ///
44    /// Returns a [`InvalidHeaderValue`] if the API key does not make a valid header value.
45    pub fn new(api_key: &str) -> Result<Self, InvalidHeaderValue> {
46        let mut headers = HeaderMap::new();
47        headers.insert("Authorization", format!("Bearer {api_key}").parse()?);
48        let client = Client::builder().default_headers(headers).build();
49
50        Ok(Self { client })
51    }
52
53    /// Embed text using the Silicon Flow API.
54    ///
55    /// # Errors
56    ///
57    /// Returns an [`EmbedError`] if the request fails or the response cannot be parsed.
58    ///
59    /// # Panics
60    ///
61    /// This function should not panic under normal circumstances. If it does, then `chunks_exact` does not properly return a chunk of 4 bytes.
62    pub async fn embed_text(&self, input: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
63        let body = RequestBody {
64            model: "BAAI/bge-large-zh-v1.5",
65            input,
66            encoding_format: "base64",
67        };
68        let response = self
69            .client
70            .post(API_ENDPOINT)
71            .map_err(EmbedError::RequestPreparation)?
72            .json(&body)
73            .map_err(EmbedError::RequestSerialization)?
74            .send()
75            .await
76            .map_err(EmbedError::RequestSend)?;
77
78        match response.status() {
79            StatusCode::UNAUTHORIZED => {
80                let message = response.text().await.unwrap_or_default();
81                return Err(EmbedError::InvalidApiKey(message));
82            }
83            StatusCode::TOO_MANY_REQUESTS => {
84                let message = response.text().await.unwrap_or_default();
85                return Err(EmbedError::RateLimitExceeded(message));
86            }
87            StatusCode::OK => {}
88            code => {
89                let message = response.text().await.unwrap_or_default();
90                if code == StatusCode::FORBIDDEN && message.starts_with("\"RPM limit exceeded.") {
91                    // They actually return 403 Forbidden for rate limit exceeded if you've not completed identity verification, instead of 429 Too Many Requests.
92                    return Err(EmbedError::RateLimitExceeded(message));
93                }
94                return Err(EmbedError::UnknownApiError {
95                    code,
96                    message,
97                });
98            }
99        }
100
101        let response_body: ResponseBody =
102            response.json().await.map_err(EmbedError::ResponseParse)?;
103
104        let result = response_body
105            .data
106            .into_iter()
107            .map(|data| -> Result<Embedding, EmbedError> {
108                let bytes = DECODER.decode(data.embedding.as_bytes())?;
109                let mut embedding = [0.0; 1024];
110                bytes.chunks_exact(4).enumerate().for_each(|(i, chunk)| {
111                    embedding[i] = f32::from_le_bytes(
112                        chunk
113                            .try_into()
114                            .expect("The chunk length should be 4 bytes"),
115                    );
116                });
117                Ok(embedding.into())
118            })
119            .collect::<Result<_, _>>()?;
120
121        Ok(result)
122    }
123}
124
125impl From<[f32; 1024]> for Embedding {
126    fn from(value: [f32; 1024]) -> Self {
127        Embedding(value)
128    }
129}
130
131impl Deref for Embedding {
132    type Target = [f32; 1024];
133
134    fn deref(&self) -> &Self::Target {
135        &self.0
136    }
137}
138
139impl Point for Embedding {
140    fn distance(&self, other: &Self) -> f32 {
141        self.0
142            .iter()
143            .zip(other.0.iter())
144            .map(|(a, b)| (a - b).powi(2))
145            .sum::<f32>()
146            .sqrt()
147    }
148}