abu-provider 0.2.0

API integration supporting multiple vendors.
Documentation
#[allow(unused)]
pub(crate) mod dto;
use dto::*;
use crate::{ProvideError, ProvideResult};
use abu_base::{
    chat::{ChatRequest, ChatResponse}, 
    embed::{EmbedRequest, EmbedResponse},
};
use reqwest::Client;
use super::{ChatProvide, EmbedProvide};

/// Default API root: used by `OpenAi::new`, and by `OpenAi::from_env` when the
/// `OPENAI_BASE_URL` environment variable cannot be read.
const OPENAI_BASE_URL: &str = "https://api.openai.com/v1";

/// Client for an OpenAI-style HTTP API.
///
/// Implements [`ChatProvide`] and [`EmbedProvide`] by POSTing JSON bodies to
/// endpoints underneath `base_url`.
#[derive(Clone)]
pub struct OpenAi {
    /// HTTP client used for every request.
    pub(crate) client: Client,
    /// API root that endpoint paths are appended to; trailing `/` characters
    /// are stripped when the request URL is built.
    pub(crate) base_url: String,
    /// Secret sent as `Bearer <api_key>` in the `Authorization` header.
    pub(crate) api_key: String,
}

impl OpenAi {
    /// Creates a client for the default endpoint (`OPENAI_BASE_URL`) that
    /// authenticates with `api_key`.
    pub fn new(api_key: impl Into<String>) -> Self {
        Self::new_with_base_url(api_key, OPENAI_BASE_URL)
    }

    /// Creates a client that authenticates with `api_key` and sends its
    /// requests underneath `base_url` instead of the default endpoint.
    pub(crate) fn new_with_base_url(api_key: impl Into<String>, base_url: impl Into<String>) -> Self {
        Self {
            client: Client::new(),
            base_url: base_url.into(),
            api_key: api_key.into(),
        }
    }

    /// Builds a client from the process environment.
    ///
    /// * `OPENAI_API_KEY` (required) - the API key.
    /// * `OPENAI_BASE_URL` (optional) - the API root; when it is unset, not
    ///   valid Unicode, or blank, the default endpoint is used.
    ///
    /// # Errors
    ///
    /// Returns the error that `?` converts from [`std::env::VarError`] when
    /// `OPENAI_API_KEY` is missing or not valid Unicode.
    pub fn from_env() -> ProvideResult<Self> {
        // A blank value (e.g. `OPENAI_BASE_URL=` in an env file) is treated as
        // "not configured": an empty base would produce a relative URL that
        // cannot be requested.
        let base_url = std::env::var("OPENAI_BASE_URL")
            .ok()
            .filter(|url| !url.trim().is_empty())
            .unwrap_or_else(|| OPENAI_BASE_URL.to_string());
        let api_key = std::env::var("OPENAI_API_KEY")?;
        Ok(Self::new_with_base_url(api_key, base_url))
    }

    /// Returns the client with its API root replaced by `base_url`.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        self.base_url = base_url.into();
        self
    }

    /// POSTs `body` as JSON to `endpoint` (a path relative to the base URL)
    /// and deserializes a successful response body into `Res`.
    ///
    /// # Errors
    ///
    /// * [`ProvideError::Network`] when the request cannot be sent.
    /// * [`ProvideError::Api`] when the server answers with a non-success
    ///   status; the message carries the status and the response body.
    /// * The error converted by `?` when the response body cannot be decoded
    ///   into `Res`.
    async fn send_request<Req, Res>(&self, endpoint: &str, body: &Req) -> ProvideResult<Res>
    where
        Req: serde::Serialize,
        Res: serde::de::DeserializeOwned,
    {
        // Join with exactly one `/`, whichever side already carries slashes.
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            endpoint.trim_start_matches('/'),
        );

        let resp = self
            .client
            .post(&url)
            // Same `Authorization: Bearer <key>` header as before, but built
            // by reqwest, which also flags the value as sensitive.
            .bearer_auth(&self.api_key)
            .header("Content-Type", "application/json")
            .json(body)
            .send()
            .await
            .map_err(|e| ProvideError::Network(e.to_string()))?;

        let status = resp.status();
        if !status.is_success() {
            // Best effort: a body that cannot be read is reported as empty so
            // the status code still reaches the caller.
            let text = resp.text().await.unwrap_or_default();
            return Err(ProvideError::Api(format!("Status: {}, Body: {}", status, text)));
        }

        Ok(resp.json::<Res>().await?)
    }
}

impl ChatProvide for OpenAi {
    type Error = ProvideError;

    /// Sends `request` to the `chat/completions` endpoint.
    ///
    /// The request is first converted into the OpenAI wire format; the decoded
    /// reply is then converted back into a [`ChatResponse`].
    async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, ProvideError> {
        let payload = OpenAiChatRequestDTO::from_request(request);
        let reply = self
            .send_request::<_, OpenAiChatResponseDTO>("chat/completions", &payload)
            .await?;
        reply.to_chat_response()
    }
}

impl EmbedProvide for OpenAi {
    type Error = ProvideError;

    /// Sends `request` to the `embeddings` endpoint.
    ///
    /// Returns the `embedding` of every entry in the reply's `data` list, in
    /// the order received, together with the reported usage.
    async fn embed(&self, request: &EmbedRequest) -> Result<EmbedResponse, ProvideError> {
        let reply = self
            .send_request::<_, OpenAiEmbedResponseDTO>("embeddings", &request)
            .await?;

        let usage = reply.usage;
        let embeddings = reply.data.into_iter().map(|item| item.embedding).collect();
        Ok(EmbedResponse { embeddings, usage })
    }
}

/// Live integration tests: each one loads credentials from an env file and
/// sends a real request, so they need network access and valid `./env/*.env`
/// files.
#[cfg(test)]
mod test {
    use abu_base::{chat::{ChatMessage, ChatRequestBuilder, ToolDefinition}, embed::EmbedRequest};
    use serde_json::json;

    use crate::{openai::OpenAi, ChatProvide, EmbedProvide};

    /// Loads `env_file` into the process environment and builds a client from it.
    ///
    /// NOTE(review): `dotenv` appears not to override variables that are already
    /// set, so tests sharing one process may see values loaded by whichever
    /// test ran first - confirm against the crate documentation.
    fn client_from_env_file(env_file: &str) -> OpenAi {
        dotenv::from_filename(env_file).unwrap();
        OpenAi::from_env().expect("new client")
    }

    /// A plain single-message chat request completes successfully.
    #[tokio::test]
    async fn test_simple_chat() {
        let openai = client_from_env_file("./env/openai.env");
        let request = ChatRequestBuilder::default()
            .model(std::env::var("MODEL_ID").expect("No MODEL_ID"))
            .messages([
                ChatMessage::user("hi!"),
            ])
            .build()
            .expect("build request");

        let response = openai.chat(&request).await.expect("chat");
        println!("{:#?}", response);
    }

    /// A chat request that offers an `echo` tool completes successfully.
    #[tokio::test]
    async fn test_chat_with_tool() {
        let openai = client_from_env_file("./env/openai.env");
        // JSON schema for the tool's single required string argument.
        let schema = json!({
            "type": "object",
            "properties": {
                "content": {
                    "type": "string",
                    "description": "The text to echo"
                }
            },
            "required": ["content"],
        });
        let echo_tool = ToolDefinition::new("echo", "echo something", schema);
        let request = ChatRequestBuilder::default()
            .model(std::env::var("MODEL_ID").expect("No MODEL_ID"))
            .messages([
                ChatMessage::user("请调用 echo 工具打印一些东西!"),
            ])
            .tools(vec![echo_tool])
            .build()
            .expect("build request");

        let response = openai.chat(&request).await.expect("chat");
        println!("{:#?}", response);
    }

    /// A single-input embedding request completes successfully.
    #[tokio::test]
    async fn test_embed() {
        let openai = client_from_env_file("./env/embed.env");
        let model = std::env::var("EMBED_MODEL_ID").expect("No EMBED_MODEL_ID");
        let request = EmbedRequest::single("I am yc", model);
        // Panic message names the operation that actually failed.
        let response = openai.embed(&request).await.expect("embed");
        println!("{:#?}", response);
    }
}