semantic-search 0.1.7

🔎 Semantic search library.
Documentation
//! # Silicon Flow module
//!
//! This module contains logic for the Silicon Flow API.

use std::fmt::Display;

use super::{SenseError, embedding::EmbeddingBytes};
use base64::{Engine as _, engine::general_purpose::STANDARD as DECODER};
use doc_for::{DocDyn, doc_impl};
use reqwest::{Client, ClientBuilder, Url, header::HeaderMap};
use serde::{Deserialize, Serialize};

// == API key validation and model definitions ==

// NOTE: the `///` comments on the variants below are load-bearing, not just docs.
// `doc_impl` derives two things from each variant's doc comment after stripping one
// leading character (`strip = 1`, presumably the space after `///`):
// - the string returned by `doc_dyn()`, which `Display` prints;
// - via `gen_attr`, a `#[serde(rename = ...)]` attribute.
// The `Display` output is what gets sent as the `model` field of API requests, so
// these comments must be the exact Silicon Flow model identifiers. Edit with care.
/// Available models.
#[doc_impl(
    strip = 1,
    doc_for = false,
    doc_dyn = true,
    gen_attr = "serde(rename = {doc})"
)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Model {
    /// BAAI/bge-large-zh-v1.5
    BgeLargeZhV1_5,
    /// BAAI/bge-large-en-v1.5
    BgeLargeEnV1_5,
    /// netease-youdao/bce-embedding-base_v1
    BceEmbeddingBaseV1,
    /// BAAI/bge-m3
    BgeM3,
    /// Pro/BAAI/bge-m3
    ProBgeM3,
}

impl Default for Model {
    fn default() -> Self {
        Self::BgeLargeZhV1_5
    }
}

impl Display for Model {
    /// Writes the model identifier taken from the variant's doc comment
    /// (e.g. `BAAI/bge-large-zh-v1.5`).
    ///
    /// Uses [`std::fmt::Formatter::pad`], so width, fill, alignment and precision
    /// flags (`{:>30}`, `{:.4}`, ...) are honoured exactly as they are for `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let id = self
            .doc_dyn()
            .expect("every `Model` variant carries a doc comment naming its model");
        f.pad(id)
    }
}

/// Validate that the API key is well-formed.
///
/// A well-formed key is the literal prefix `sk-` followed by exactly 48 ASCII
/// alphanumeric characters, i.e. 51 bytes in total.
///
/// # Errors
///
/// Returns [`SenseError::MalformedApiKey`] if the prefix is missing, the key has the
/// wrong length, or any character after the prefix is not ASCII alphanumeric.
fn validate_api_key(key: &str) -> Result<(), SenseError> {
    // Number of characters that must follow the `sk-` prefix.
    const BODY_LEN: usize = 48;

    // Check the prefix explicitly instead of skipping three arbitrary characters.
    // Because the body must be pure ASCII, its byte length equals its char count.
    let well_formed = key.strip_prefix("sk-").is_some_and(|body| {
        body.len() == BODY_LEN && body.bytes().all(|b| b.is_ascii_alphanumeric())
    });

    if well_formed {
        Ok(())
    } else {
        Err(SenseError::MalformedApiKey)
    }
}

// == Request and response definitions ==

/// JSON payload sent to the Silicon Flow embeddings endpoint.
#[derive(Serialize)]
struct RequestBody<'req> {
    /// Identifier of the model that should produce the embedding.
    model: &'req str,
    /// Text to embed.
    input: &'req str,
    /// Requested encoding of the returned vector: `"float"` or `"base64"`.
    encoding_format: &'req str,
}

/// ResponseBody.data: The list of embeddings generated by the model.
#[derive(Deserialize)]
struct Data {
    /// Fixed string "embedding".
    #[serde(rename = "object")]
    _object: String,
    /// Base64-encoded embedding.
    embedding: String,
    /// Unused.
    #[serde(rename = "index")]
    _index: i32,
}

/// `ResponseBody.usage`: token accounting reported for a request.
#[derive(Deserialize)]
#[allow(clippy::struct_field_names, reason = "Consistency with API response")]
#[allow(dead_code, reason = "For deserialization only")]
struct Usage {
    /// Tokens consumed by the prompt.
    prompt_tokens: u32,
    /// Tokens attributed to the completion.
    completion_tokens: u32,
    /// Total tokens reported for the whole request.
    total_tokens: u32,
}

/// Top-level JSON object returned by the Silicon Flow embeddings endpoint.
#[derive(Deserialize)]
struct ResponseBody {
    /// Name of the model that produced the embeddings, as reported by the server.
    model: String,
    /// The embeddings produced for the request.
    data: Vec<Data>,
    /// Token accounting. It must be present in the response, but it is never read.
    #[allow(dead_code, reason = "For deserialization only")]
    usage: Usage,
}

// == API client ==

/// A client for the Silicon Flow API.
///
/// Bundles a preconfigured HTTP client with the endpoint URL and the model name
/// used for every embedding request.
#[derive(Clone)]
pub struct ApiClient {
    /// HTTP client whose default headers carry the API key.
    client: Client,
    /// URL of the embeddings endpoint.
    endpoint: Url,
    /// Model identifier sent with each request.
    model: String,
}

impl ApiClient {
    /// Create a new API client.
    ///
    /// # Errors
    ///
    /// Returns an error if the API key is malformed or the HTTP client cannot be created.
    #[allow(clippy::missing_panics_doc, reason = "URL is hardcoded")]
    pub fn new(key: &str, model: Model) -> Result<Self, SenseError> {
        validate_api_key(key)?;
        let mut headers = HeaderMap::new();
        headers.insert("Authorization", format!("Bearer {key}").parse()?);
        let client = ClientBuilder::new().default_headers(headers).build()?;

        Ok(Self {
            model: model.to_string(),
            endpoint: Url::parse("https://api.siliconflow.cn/v1/embeddings").unwrap(),
            client,
        })
    }

    /// Embed a text.
    ///
    /// # Errors
    ///
    /// Returns:
    ///
    /// - [`SenseError::RequestFailed`] if the request fails
    /// - [`SenseError::Base64DecodingFailed`] if base64 decoding fails
    /// - [`SenseError::DimensionMismatch`] if the embedding is not 1024-dimensional.
    pub async fn embed(&self, text: &str) -> Result<EmbeddingBytes, SenseError> {
        let request_body = RequestBody {
            model: &self.model,
            input: text,
            encoding_format: "base64",
        };
        let request = self.client.post(self.endpoint.clone()).json(&request_body);

        let response: ResponseBody = request.send().await?.json().await?;
        debug_assert_eq!(response.model, self.model);

        let embedding = DECODER.decode(response.data[0].embedding.as_bytes())?;
        Ok(embedding.try_into()?)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Fake but well-formed key: `sk-` followed by 48 ASCII alphanumerics.
    const KEY: &str = "sk-1234567890abcdef1234567890abcdef1234567890abcdef";

    /// A correctly shaped key is accepted.
    #[test]
    fn test_api_key_ok() {
        let outcome = validate_api_key(KEY);
        outcome.unwrap();
    }

    /// A key one character short is rejected with `MalformedApiKey`.
    #[test]
    fn test_api_key_malformed() {
        let (truncated, _) = KEY.split_at(KEY.len() - 1);
        assert!(matches!(
            validate_api_key(truncated),
            Err(SenseError::MalformedApiKey)
        ));
    }

    /// `Display` yields the model identifier from the variant's doc comment.
    #[test]
    fn test_model_string() {
        assert_eq!(Model::BgeLargeZhV1_5.to_string(), "BAAI/bge-large-zh-v1.5");
    }

    /// End-to-end request against the live service; needs a real key.
    #[tokio::test]
    #[ignore = "requires API key in `SILICONFLOW_API_KEY` env var"]
    async fn test_embed() {
        let key = std::env::var("SILICONFLOW_API_KEY").unwrap();
        let client = ApiClient::new(&key, Model::BgeLargeZhV1_5).unwrap();
        let _embedding = client.embed("Hello, world!").await.unwrap();
    }
}