seedframe 0.1.2

A clean, macro-driven Rust library for building LLM apps.
Documentation
use crate::embeddings::{model::EmbeddingModel, EmbedderError};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tracing::{debug, error, info, instrument};

/// Name of the environment variable the API key is read from when the JSON
/// config does not name a different variable.
const DEFAULT_API_KEY_VAR_NAME: &str = "OPENAI_EMBEDDING_API_KEY";
/// Endpoint embedding requests are POSTed to when the config has no `api_url`.
const DEFAULT_URL: &str = "https://api.openai.com/v1/embeddings";
/// Model identifier sent in the request body when the config has no `model`.
const DEFAULT_MODEL: &str = "text-embedding-3-small";

/// Optional overrides parsed from the JSON string passed to `OpenAIEmbedding::new`.
///
/// Every field may be omitted, in which case the corresponding default constant
/// is used. Unknown fields are rejected during deserialization
/// (`deny_unknown_fields`).
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
struct ModelConfig {
    /// Name of the environment variable that holds the API key. Despite the
    /// field name this is NOT the key itself: the value is passed to
    /// `std::env::var` by the constructor.
    api_key: Option<String>,
    /// URL that embedding requests are sent to.
    api_url: Option<String>,
    /// Model identifier placed in the request body.
    model: Option<String>,
}

/// Embedding model backed by an HTTP embeddings endpoint (OpenAI's by default).
///
/// Construct it with `OpenAIEmbedding::new`; requests are issued through the
/// `EmbeddingModel::embed` implementation.
pub struct OpenAIEmbedding {
    /// API key value, sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
    /// URL that embedding requests are POSTed to.
    api_url: String,
    /// Model identifier sent in the `"model"` field of the request body.
    model: String,
    /// HTTP client created once at construction and used for every request.
    client: Client,
}

impl OpenAIEmbedding {
    #[instrument]
    #[must_use]
    pub fn new(json_config: Option<&str>) -> Self {
        let (api_key_var, api_url, model) = if let Some(json) = json_config {
            let config = match serde_json::from_str::<ModelConfig>(json) {
                Ok(config) => config,
                Err(e) => {
                    let e = format!("Failed to deserialize json config: {e}");
                    error!(e);
                    panic!("{e}");
                }
            };
            (
                config
                    .api_key
                    .unwrap_or(DEFAULT_API_KEY_VAR_NAME.to_string()),
                config.api_url.unwrap_or(DEFAULT_URL.to_string()),
                config.model.unwrap_or(DEFAULT_MODEL.to_string()),
            )
        } else {
            (
                DEFAULT_API_KEY_VAR_NAME.to_string(),
                DEFAULT_URL.to_string(),
                DEFAULT_MODEL.to_string(),
            )
        };
        let api_key = match std::env::var(&api_key_var) {
            Ok(key) => key,
            Err(e) => {
                let e = format!("Failed to fetch env var `{api_key_var}`!, {e}");
                error!(e);
                panic!("{e}");
            }
        };
        Self {
            api_key,
            api_url,
            client: reqwest::Client::new(),
            model,
        }
    }
}

/// Subset of the embeddings endpoint's JSON response that this module reads.
/// Fields not listed here are ignored during deserialization.
#[derive(Deserialize)]
struct OpenAIEmbeddingResponse {
    /// Entries of the top-level `data` array, each carrying one embedding vector.
    pub data: Vec<OpenAIEmbeddingData>,
}

/// One element of the response's `data` array.
#[derive(Deserialize)]
struct OpenAIEmbeddingData {
    /// The embedding vector, deserialized from the `embedding` JSON array.
    pub embedding: Vec<f64>,
}

#[async_trait]
impl EmbeddingModel for OpenAIEmbedding {
    /// Requests an embedding for `data` from the configured endpoint.
    ///
    /// Sends a JSON POST containing `input` and `model`, authenticated with a
    /// bearer token, and returns the embedding vectors of every entry in the
    /// response's `data` array concatenated into a single `Vec<f64>`.
    ///
    /// # Errors
    ///
    /// * `EmbedderError::RequestError` — the HTTP request could not be sent.
    /// * `EmbedderError::ProviderError` — the endpoint answered with a
    ///   non-success status; carries the response body text.
    /// * `EmbedderError::ParseError` — a success response could not be
    ///   deserialized.
    #[instrument(
        skip(self, data),
        fields(
            model = self.model,
            api_url = self.api_url,
            input_length = data.len()
        )
    )]
    async fn embed(&self, data: &str) -> Result<Vec<f64>, EmbedderError> {
        info!("Preparing embedding request");
        let payload = json!({
            "input": data,
            "model": self.model,
        });
        debug!(
            request_body.input_length = data.len(),
            "Sending embedding request"
        );

        let request = self
            .client
            .post(&self.api_url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&payload);
        let response = match request.send().await {
            Ok(response) => response,
            Err(e) => {
                error!(error = ?e, "Embedding request failed");
                return Err(EmbedderError::RequestError(e.to_string()));
            }
        };

        let status = response.status();
        debug!(%status, "Received embedding response");

        // Guard clause: surface the provider's own error body for any non-2xx status.
        if !status.is_success() {
            let provider_message = match response.text().await {
                Ok(body) => body,
                Err(_) => "Unknown error".to_string(),
            };
            error!(
                status = %status,
                error = %provider_message,
                "Embedding API returned error"
            );
            return Err(EmbedderError::ProviderError(provider_message));
        }

        let parsed = match response.json::<OpenAIEmbeddingResponse>().await {
            Ok(parsed) => parsed,
            Err(e) => {
                error!(error = ?e, "Failed to parse embedding response");
                return Err(EmbedderError::ParseError(e.to_string()));
            }
        };

        // Every entry's vector is appended in order into one flat vector.
        let vector: Vec<f64> = parsed
            .data
            .into_iter()
            .flat_map(|entry| entry.embedding)
            .collect();
        info!(
            embedding_length = vector.len(),
            "Successfully generated embeddings"
        );
        Ok(vector)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Live smoke test: builds the model with default settings and requests an
    // embedding for the string "test", asserting only that the call succeeds.
    // Marked `#[ignore]`, so it is skipped unless ignored tests are requested —
    // presumably because it needs network access and the
    // `OPENAI_EMBEDDING_API_KEY` environment variable (without which
    // `OpenAIEmbedding::new(None)` panics).
    #[tokio::test]
    #[ignore]
    async fn simple_openai_embed_request() {
        let openai_embedding_model = OpenAIEmbedding::new(None);

        let response = openai_embedding_model.embed("test").await;
        assert!(response.is_ok());
    }
}