semantic_search/
api.rs

1//! # Silicon Flow module
2//!
3//! This module contains logic for the Silicon Flow API.
4
5use std::fmt::Display;
6
7use super::{SenseError, embedding::EmbeddingBytes};
8use base64::{Engine as _, engine::general_purpose::STANDARD as DECODER};
9use doc_for::{DocDyn, doc_impl};
10use reqwest::{Client, ClientBuilder, Url, header::HeaderMap};
11use serde::{Deserialize, Serialize};
12
13// == API key validation and model definitions ==
14
15/// Available models.
16#[doc_impl(
17    strip = 1,
18    doc_for = false,
19    doc_dyn = true,
20    gen_attr = "serde(rename = {doc})"
21)]
22#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
23pub enum Model {
24    /// BAAI/bge-large-zh-v1.5
25    BgeLargeZhV1_5,
26    /// BAAI/bge-large-en-v1.5
27    BgeLargeEnV1_5,
28    /// netease-youdao/bce-embedding-base_v1
29    BceEmbeddingBaseV1,
30    /// BAAI/bge-m3
31    BgeM3,
32    /// Pro/BAAI/bge-m3
33    ProBgeM3,
34}
35
36impl Default for Model {
37    fn default() -> Self {
38        Self::BgeLargeZhV1_5
39    }
40}
41
42impl Display for Model {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        write!(f, "{}", self.doc_dyn().unwrap())
45    }
46}
47
48/// Validate that the API key is well-formed.
49fn validate_api_key(key: &str) -> Result<(), SenseError> {
50    if key.len() != 51 {
51        return Err(SenseError::MalformedApiKey);
52    }
53    for c in key.chars().skip(3) {
54        if !c.is_ascii_alphanumeric() {
55            return Err(SenseError::MalformedApiKey);
56        }
57    }
58    Ok(())
59}
60
61// == Request and response definitions ==
62
63/// The request body for the Silicon Flow API.
64#[derive(Serialize)]
65struct RequestBody<'a> {
66    /// The model to use.
67    model: &'a str,
68    /// The input text.
69    input: &'a str,
70    /// The encoding format, either "float" or "base64".
71    encoding_format: &'a str,
72}
73
74/// ResponseBody.data: The list of embeddings generated by the model.
75#[derive(Deserialize)]
76struct Data {
77    /// Fixed string "embedding".
78    #[serde(rename = "object")]
79    _object: String,
80    /// Base64-encoded embedding.
81    embedding: String,
82    /// Unused.
83    #[serde(rename = "index")]
84    _index: i32,
85}
86
87/// ResponseBody.usage: The usage information for the request.
88#[derive(Deserialize)]
89#[allow(dead_code, reason = "For deserialization only")]
90#[allow(clippy::struct_field_names, reason = "Consistency with API response")]
91struct Usage {
92    /// The number of tokens used by the prompt.
93    prompt_tokens: u32,
94    /// The number of tokens used by the completion.
95    completion_tokens: u32,
96    /// The total number of tokens used by the request.
97    total_tokens: u32,
98}
99
100/// The response body for the Silicon Flow API.
101#[derive(Deserialize)]
102struct ResponseBody {
103    /// The name of the model used to generate the embedding.
104    model: String,
105    /// The list of embeddings generated by the model.
106    data: Vec<Data>,
107    /// The usage information for the request.
108    #[serde(rename = "usage")]
109    _usage: Usage,
110}
111
112// == API client ==
113
114/// A client for the Silicon Flow API.
115#[derive(Clone)]
116pub struct ApiClient {
117    /// The model to use.
118    model: String,
119    /// API endpoint.
120    endpoint: Url,
121    /// HTTP client.
122    client: Client,
123}
124
125impl ApiClient {
126    /// Create a new API client.
127    ///
128    /// # Errors
129    ///
130    /// Returns an error if the API key is malformed or the HTTP client cannot be created.
131    #[allow(clippy::missing_panics_doc, reason = "URL is hardcoded")]
132    pub fn new(key: &str, model: Model) -> Result<Self, SenseError> {
133        validate_api_key(key)?;
134        let mut headers = HeaderMap::new();
135        headers.insert("Authorization", format!("Bearer {key}").parse()?);
136        let client = ClientBuilder::new().default_headers(headers).build()?;
137
138        Ok(Self {
139            model: model.to_string(),
140            endpoint: Url::parse("https://api.siliconflow.cn/v1/embeddings").unwrap(),
141            client,
142        })
143    }
144
145    /// Embed a text.
146    ///
147    /// # Errors
148    ///
149    /// Returns:
150    ///
151    /// - [`SenseError::RequestFailed`] if the request fails
152    /// - [`SenseError::Base64DecodingFailed`] if base64 decoding fails
153    /// - [`SenseError::DimensionMismatch`] if the embedding is not 1024-dimensional.
154    pub async fn embed(&self, text: &str) -> Result<EmbeddingBytes, SenseError> {
155        let request_body = RequestBody {
156            model: &self.model,
157            input: text,
158            encoding_format: "base64",
159        };
160        let request = self.client.post(self.endpoint.clone()).json(&request_body);
161
162        let response: ResponseBody = request.send().await?.json().await?;
163        debug_assert_eq!(response.model, self.model);
164
165        let embedding = DECODER.decode(response.data[0].embedding.as_bytes())?;
166        Ok(embedding.try_into()?)
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    /// Dummy key with the expected shape: 51 characters, alphanumeric after the first three.
    const KEY: &str = "sk-1234567890abcdef1234567890abcdef1234567890abcdef";

    /// A key of the expected shape passes validation.
    #[test]
    fn test_api_key_ok() {
        assert!(validate_api_key(KEY).is_ok());
    }

    /// A key that is one character too short is reported as malformed.
    #[test]
    fn test_api_key_malformed() {
        let (truncated, _) = KEY.split_at(KEY.len() - 1);
        let outcome = validate_api_key(truncated);
        assert!(matches!(outcome, Err(SenseError::MalformedApiKey)));
    }

    /// `Display` for a model yields the identifier from its doc comment.
    #[test]
    fn test_model_string() {
        assert_eq!(Model::BgeLargeZhV1_5.to_string(), "BAAI/bge-large-zh-v1.5");
    }

    /// End-to-end request against the live API; needs a real key.
    #[tokio::test]
    #[ignore = "requires API key in `SILICONFLOW_API_KEY` env var"]
    async fn test_embed() {
        // The key comes from the environment so that it never lives in the repository.
        let key = std::env::var("SILICONFLOW_API_KEY").unwrap();
        let client = ApiClient::new(&key, Model::BgeLargeZhV1_5).unwrap();
        let _ = client.embed("Hello, world!").await.unwrap();
    }
}