roder-ext-openai-embeddings 0.1.1

Agentic software development tools and SDKs for Roder.
Documentation
use roder_api::embeddings::{
    EmbeddingModelDescriptor, EmbeddingProvider, EmbeddingProviderDescriptor, EmbeddingRequest,
    EmbeddingResponse, EmbeddingVector,
};
use serde::{Deserialize, Serialize};

pub const DEFAULT_MODEL: &str = "text-embedding-3-large";
pub const DEFAULT_DIMENSIONS: usize = 3072;

#[derive(Clone)]
pub struct OpenAiEmbeddingProvider {
    api_key: Option<String>,
    base_url: String,
}

impl OpenAiEmbeddingProvider {
    pub fn new(api_key: Option<String>) -> Self {
        Self {
            api_key,
            base_url: "https://api.openai.com/v1".to_string(),
        }
    }

    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }
}

#[async_trait::async_trait]
impl EmbeddingProvider for OpenAiEmbeddingProvider {
    fn descriptor(&self) -> EmbeddingProviderDescriptor {
        EmbeddingProviderDescriptor {
            id: "openai".to_string(),
            name: "OpenAI".to_string(),
            default_model: DEFAULT_MODEL.to_string(),
            models: vec![EmbeddingModelDescriptor {
                id: DEFAULT_MODEL.to_string(),
                dimensions: DEFAULT_DIMENSIONS,
                default: true,
            }],
        }
    }

    async fn embed(&self, request: EmbeddingRequest) -> anyhow::Result<EmbeddingResponse> {
        let api_key = self
            .api_key
            .clone()
            .or_else(|| std::env::var("OPENAI_API_KEY").ok())
            .filter(|key| !key.trim().is_empty())
            .ok_or_else(|| anyhow::anyhow!("OPENAI_API_KEY is required for OpenAI embeddings"))?;
        if request.inputs.is_empty() {
            anyhow::bail!("embedding request must include at least one input");
        }

        let body = OpenAiEmbeddingRequest {
            model: request.model.clone(),
            input: request.inputs,
            dimensions: request.dimensions,
        };
        let response = reqwest::Client::new()
            .post(format!(
                "{}/embeddings",
                self.base_url.trim_end_matches('/')
            ))
            .bearer_auth(api_key)
            .json(&body)
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            let message = response.text().await.unwrap_or_default();
            anyhow::bail!("OpenAI embeddings request failed with {status}: {message}");
        }
        let payload: OpenAiEmbeddingResponse = response.json().await?;
        let embeddings = payload
            .data
            .into_iter()
            .map(|item| EmbeddingVector {
                index: item.index,
                values: item.embedding,
            })
            .collect::<Vec<_>>();
        if embeddings
            .iter()
            .any(|embedding| embedding.values.is_empty())
        {
            anyhow::bail!("OpenAI returned an empty embedding vector");
        }
        Ok(EmbeddingResponse {
            provider_id: "openai".to_string(),
            model: payload.model.unwrap_or(request.model),
            embeddings,
        })
    }
}

#[derive(Debug, Serialize)]
struct OpenAiEmbeddingRequest {
    model: String,
    input: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    dimensions: Option<usize>,
}

#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingResponse {
    #[serde(default)]
    model: Option<String>,
    data: Vec<OpenAiEmbeddingItem>,
}

#[derive(Debug, Deserialize)]
struct OpenAiEmbeddingItem {
    index: usize,
    embedding: Vec<f32>,
}

#[cfg(test)]
mod tests {
    use roder_api::embeddings::EmbeddingInputType;

    use super::*;

    #[test]
    fn registers_openai_embedding_descriptor() {
        let descriptor = OpenAiEmbeddingProvider::new(None).descriptor();
        assert_eq!(descriptor.id, "openai");
        assert_eq!(descriptor.default_model, DEFAULT_MODEL);
        assert_eq!(descriptor.models[0].dimensions, DEFAULT_DIMENSIONS);
    }

    #[tokio::test]
    async fn missing_key_returns_concise_error() {
        let err = OpenAiEmbeddingProvider::new(Some(String::new()))
            .with_base_url("http://127.0.0.1:1")
            .embed(EmbeddingRequest {
                model: DEFAULT_MODEL.to_string(),
                inputs: vec!["hello".to_string()],
                input_type: EmbeddingInputType::Document,
                dimensions: None,
            })
            .await
            .unwrap_err();
        assert!(err.to_string().contains("OPENAI_API_KEY"));
    }
}