ambi 0.1.4

A flexible, multi-backend, customizable AI agent framework written entirely in Rust.
Documentation
use crate::llm::{LLMEngineTrait, LLMRequest};
use crate::types::message::{ContentPart, Message};
use anyhow::{anyhow, Result};
use async_openai::config::OpenAIConfig;
use async_openai::types::chat::{
    ChatCompletionMessageToolCalls, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestMessage, ChatCompletionRequestSystemMessageArgs,
    ChatCompletionRequestToolMessageArgs, ChatCompletionRequestUserMessageArgs,
    CreateChatCompletionRequest, CreateChatCompletionRequestArgs,
};
use async_openai::Client;
use async_trait::async_trait;
use futures::StreamExt;
use log::{debug, error};
use serde::Deserialize;
use tokio::sync::mpsc::Sender;

/// Configuration settings for OpenAI-compatible cloud APIs.
///
/// This configuration is used to connect the Agent to remote endpoints that follow
/// the standard OpenAI Chat Completions API format. It can be used for OpenAI,
/// DeepSeek, Groq, or local servers that expose an OpenAI-compatible endpoint
/// (for example Ollama or vLLM).
///
/// The `Debug` output never contains the API key itself, so a configuration can be
/// logged without leaking the secret.
///
/// # Examples
///
/// ```rust
/// use ambi::llm::providers::openai_api::OpenAIEngineConfig;
///
/// let config = OpenAIEngineConfig {
///     api_key: std::env::var("OPENAI_API_KEY").unwrap_or_default(),
///     base_url: "https://api.openai.com/v1".to_string(),
///     model_name: "gpt-4o".to_string(),
///     temp: 0.7,
///     top_p: 0.95,
/// };
/// ```
#[derive(Deserialize, Clone)]
pub struct OpenAIEngineConfig {
    /// API key used to authenticate requests against `base_url`.
    pub api_key: String,
    /// Base URL of the API, including any version prefix (e.g. `https://api.openai.com/v1`).
    pub base_url: String,
    /// Model identifier sent as the `model` field of every request.
    pub model_name: String,
    /// Sampling temperature sent as `temperature`; `validate` requires `0.0..=2.0`.
    pub temp: f32,
    /// Nucleus-sampling probability mass sent as `top_p`.
    pub top_p: f32,
}

// Hand-written instead of derived so the API key is never printed in logs or panics.
impl std::fmt::Debug for OpenAIEngineConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Still reveal whether a key is present, which is useful when debugging config.
        let key_marker = if self.api_key.is_empty() {
            "<empty>"
        } else {
            "<redacted>"
        };
        f.debug_struct("OpenAIEngineConfig")
            .field("api_key", &key_marker)
            .field("base_url", &self.base_url)
            .field("model_name", &self.model_name)
            .field("temp", &self.temp)
            .field("top_p", &self.top_p)
            .finish()
    }
}

impl OpenAIEngineConfig {
    /// Checks the configuration for values that would make every request fail.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - `api_key` is empty or whitespace-only,
    /// - `base_url` is empty or whitespace-only,
    /// - `model_name` is empty or whitespace-only,
    /// - `temp` is outside `0.0..=2.0` (NaN is rejected),
    /// - `top_p` is outside `0.0..=1.0` (NaN is rejected).
    pub fn validate(&self) -> Result<()> {
        if self.api_key.trim().is_empty() {
            return Err(anyhow!("OpenAI API Key cannot be empty"));
        }
        if self.base_url.trim().is_empty() {
            return Err(anyhow!("OpenAI base URL cannot be empty"));
        }
        if self.model_name.trim().is_empty() {
            return Err(anyhow!("OpenAI model name cannot be empty"));
        }
        // `RangeInclusive::contains` is false for NaN, whereas the pair of
        // `<`/`>` comparisons it replaces let NaN slip through.
        if !(0.0..=2.0).contains(&self.temp) {
            return Err(anyhow!("Temperature must be between 0.0 and 2.0"));
        }
        if !(0.0..=1.0).contains(&self.top_p) {
            return Err(anyhow!("Top-p must be between 0.0 and 1.0"));
        }
        Ok(())
    }
}

/// LLM engine that talks to an OpenAI-compatible Chat Completions endpoint.
///
/// Cloning is cheap to reason about: the clone gets its own copy of the
/// configuration alongside a clone of the HTTP client handle.
#[derive(Clone)]
pub struct OpenAIEngine {
    /// Settings the engine was loaded with (model, sampling parameters, endpoint).
    cfg: OpenAIEngineConfig,
    /// async-openai client configured with the API key and base URL from `cfg`.
    client: Client<OpenAIConfig>,
}

impl OpenAIEngine {
    pub fn load(openai_cfg: OpenAIEngineConfig) -> Result<Self> {
        let api_key = openai_cfg.api_key.clone();
        let mut config = OpenAIConfig::new().with_api_key(api_key);

        config = config.with_api_base(&openai_cfg.base_url);

        let client = Client::with_config(config);

        Ok(Self {
            client,
            cfg: openai_cfg,
        })
    }

    pub async fn generate_response_stream(
        &self,
        request: LLMRequest,
        tx: Sender<Result<String, anyhow::Error>>,
    ) -> Result<()> {
        let mut new_prompt = String::new();

        if let Some(msg) = request.history.last() {
            if let Message::User { content } = &**msg {
                for part in content {
                    if let ContentPart::Text { text } = part {
                        new_prompt.push_str(text);
                    }
                }
            }
        }

        debug!(
            "\n[OpenAI API] Request\n========================================\n{}",
            new_prompt
        );

        let model_name = self.cfg.model_name.clone();

        let api_request = self.get_request(model_name, request, true)?;

        let mut stream = self.client.chat().create_stream(api_request).await?;

        while let Some(result) = stream.next().await {
            match result {
                Ok(response) => {
                    for choice in response.choices {
                        if let Some(content) = choice.delta.content {
                            if tx.send(Ok(content)).await.is_err() {
                                debug!("Output channel closed, terminating OpenAI stream.");
                                return Ok(());
                            }
                        }
                    }
                }
                Err(e) => {
                    error!("OpenAI Stream Error: {}", e);
                    let _ = tx.send(Err(anyhow!("Stream interrupted: {}", e))).await;
                    return Err(e.into());
                }
            }
        }
        Ok(())
    }

    pub async fn generate_response_sync(&self, request: LLMRequest) -> Result<String> {
        let model_name = self.cfg.model_name.clone();

        let api_request = self.get_request(model_name, request, false)?;

        let response = self.client.chat().create(api_request).await?;

        let content = response
            .choices
            .into_iter()
            .next()
            .and_then(|c| c.message.content)
            .unwrap_or_default();

        Ok(content)
    }

    pub fn reset_context(&mut self) {}

    fn get_request(
        &self,
        model_name: String,
        request: LLMRequest,
        stream: bool,
    ) -> Result<CreateChatCompletionRequest> {
        let mut messages: Vec<ChatCompletionRequestMessage> = Vec::new();

        let mut sys_content = request.system_prompt;
        if !request.tool_prompt.is_empty() {
            if !sys_content.is_empty() {
                sys_content.push_str("\n\n");
            }
            sys_content.push_str(&request.tool_prompt);
        }

        if !sys_content.is_empty() {
            messages.push(
                ChatCompletionRequestSystemMessageArgs::default()
                    .content(sys_content)
                    .build()?
                    .into(),
            );
        }

        let mut i = 0;

        while i < request.history.len() {
            let msg = &request.history[i];

            match &**msg {
                Message::System { .. } => {
                    i += 1;
                    continue;
                }
                Message::User { .. } => {
                    messages.push(
                        ChatCompletionRequestUserMessageArgs::default()
                            .content(msg.get_text_content())
                            .build()?
                            .into(),
                    );
                }
                Message::Assistant { content } => {
                    let mut tool_calls = Vec::new();
                    let mut peek_i = i + 1;

                    while peek_i < request.history.len()
                        && matches!(&*request.history[peek_i], Message::Tool { .. })
                    {
                        let tool_call: ChatCompletionMessageToolCalls =
                            serde_json::from_value(serde_json::json!({
                                "id": format!("call_native_{}", peek_i),
                                "type": "function",
                                "function": {
                                    "name": "agent_tool",
                                    "arguments": "{}"
                                }
                            }))
                            .map_err(|e| anyhow!("Failed to build ToolCall: {}", e))?;

                        tool_calls.push(tool_call);
                        peek_i += 1;
                    }

                    if !tool_calls.is_empty() {
                        messages.push(
                            ChatCompletionRequestAssistantMessageArgs::default()
                                .content(content.clone())
                                .tool_calls(tool_calls)
                                .build()?
                                .into(),
                        );
                    } else {
                        messages.push(
                            ChatCompletionRequestAssistantMessageArgs::default()
                                .content(content.clone())
                                .build()?
                                .into(),
                        );
                    }
                }
                Message::Tool { content } => {
                    let tool_call_id = format!("call_native_{}", i);

                    messages.push(
                        ChatCompletionRequestToolMessageArgs::default()
                            .content(content.clone())
                            .tool_call_id(tool_call_id)
                            .build()?
                            .into(),
                    );
                }
            }
            i += 1;
        }

        let request = CreateChatCompletionRequestArgs::default()
            .model(model_name)
            .messages(messages)
            .temperature(self.cfg.temp)
            .top_p(self.cfg.top_p)
            .stream(stream)
            .build()?;

        Ok(request)
    }
}

#[async_trait]
impl LLMEngineTrait for OpenAIEngine {
    async fn chat(&mut self, request: LLMRequest) -> Result<String> {
        self.generate_response_sync(request).await.map_err(|e| {
            error!("OpenAI model generation error: {}", e);
            anyhow!("OpenAI error: {}", e)
        })
    }

    async fn chat_stream(
        &mut self,
        request: LLMRequest,
        tx: Sender<Result<String, anyhow::Error>>,
    ) {
        if let Err(e) = self.generate_response_stream(request, tx.clone()).await {
            error!("OpenAI stream generation error: {}", e);
            let _ = tx.send(Err(anyhow!("OpenAI API Error: {}", e))).await;
        }
    }

    fn reset_context(&mut self) {
        self.reset_context();
    }
}