//! agent-base 0.1.0
//!
//! A lightweight Agent Runtime Kernel for building AI agents in Rust.
//! See the crate-level documentation for details.
use async_trait::async_trait;
use eventsource_stream::Eventsource;
use futures_core::Stream;
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::{json, Value};
use std::pin::Pin;

use crate::types::{AgentResult, AgentError, ChatMessage, ImageAttachment, ImageDetail, ResponseFormat, ToolCallMessage};
use super::{LlmCapabilities, LlmClient, StreamChunk, UsageInfo};


/// Client for OpenAI-compatible `/chat/completions` endpoints.
///
/// Defaults to the official OpenAI API root but accepts any base URL,
/// so it also works against compatible third-party servers.
pub struct OpenAiClient {
    api_key: String,   // bearer token sent in the `Authorization` header
    model: String,     // model identifier included in every request body
    base_url: String,  // API root, e.g. "https://api.openai.com/v1" (no trailing slash)
    client: Client,    // reusable reqwest HTTP client (connection pooling)
}

impl OpenAiClient {
    pub fn new(api_key: String, model: String, base_url: Option<String>) -> Self {
        Self {
            api_key,
            model,
            base_url: base_url
                .unwrap_or_else(|| "https://api.openai.com/v1".to_string()),
            client: Client::new(),
        }
    }

    fn chat_message_to_json(msg: &ChatMessage) -> Value {
        match msg {
            ChatMessage::System { content } => json!({
                "role": "system",
                "content": content,
            }),
            ChatMessage::User { content, images } => {
                if images.is_empty() {
                    json!({
                        "role": "user",
                        "content": content,
                    })
                } else {
                    let mut content_parts: Vec<Value> = Vec::new();
                    content_parts.push(json!({"type": "text", "text": content}));
                    for img in images {
                        content_parts.push(Self::image_to_json(img));
                    }
                    json!({
                        "role": "user",
                        "content": content_parts,
                    })
                }
            }
            ChatMessage::Assistant { content, reasoning_content, tool_calls } => {
                let mut obj = serde_json::Map::new();
                obj.insert("role".to_string(), json!("assistant"));
                obj.insert("content".to_string(), json!(content));
                if let Some(reasoning) = reasoning_content {
                    obj.insert("reasoning_content".to_string(), json!(reasoning));
                }
                if let Some(tc) = tool_calls {
                    let tool_calls_json: Vec<Value> = tc
                        .iter()
                        .map(|t| Self::tool_call_to_json(t))
                        .collect();
                    obj.insert("tool_calls".to_string(), json!(tool_calls_json));
                }
                Value::Object(obj)
            }
            ChatMessage::Tool { tool_call_id, content } => json!({
                "role": "tool",
                "tool_call_id": tool_call_id,
                "content": content,
            }),
        }
    }

    fn tool_call_to_json(tc: &ToolCallMessage) -> Value {
        json!({
            "id": tc.id,
            "type": "function",
            "function": {
                "name": tc.name,
                "arguments": tc.arguments,
            }
        })
    }

    fn image_to_json(img: &ImageAttachment) -> Value {
        match img {
            ImageAttachment::Url { url, detail } => {
                let mut obj = serde_json::Map::new();
                obj.insert("url".to_string(), json!(url));
                if let Some(d) = detail {
                    let detail_str = match d {
                        ImageDetail::Low => "low",
                        ImageDetail::High => "high",
                        ImageDetail::Auto => "auto",
                    };
                    obj.insert("detail".to_string(), json!(detail_str));
                }
                json!({
                    "type": "image_url",
                    "image_url": Value::Object(obj),
                })
            }
            ImageAttachment::Base64 { data, media_type, detail } => {
                let mime = media_type.as_deref().unwrap_or("image/jpeg");
                let data_url = format!("data:{mime};base64,{data}");
                let mut obj = serde_json::Map::new();
                obj.insert("url".to_string(), json!(data_url));
                if let Some(d) = detail {
                    let detail_str = match d {
                        ImageDetail::Low => "low",
                        ImageDetail::High => "high",
                        ImageDetail::Auto => "auto",
                    };
                    obj.insert("detail".to_string(), json!(detail_str));
                }
                json!({
                    "type": "image_url",
                    "image_url": Value::Object(obj),
                })
            }
        }
    }

    fn messages_to_json(messages: &[ChatMessage]) -> Vec<Value> {
        messages.iter().map(Self::chat_message_to_json).collect()
    }
}

#[async_trait]
impl LlmClient for OpenAiClient {
    async fn chat(
        &self,
        messages: &[ChatMessage],
        tools: &[Value],
        enable_thinking: Option<bool>,
        response_format: Option<&ResponseFormat>,
    ) -> AgentResult<Value> {
        let url = format!("{}/chat/completions", self.base_url);
        let raw_messages = Self::messages_to_json(messages);
        let mut request_body = json!({
            "model": self.model,
            "messages": raw_messages,
            "tools": tools,
        });

        if let Some(thinking) = enable_thinking {
            if let Some(obj) = request_body.as_object_mut() {
                obj.insert("enable_thinking".to_string(), json!(thinking));
            }
        }

        if let Some(rf) = response_format {
            if let Some(obj) = request_body.as_object_mut() {
                obj.insert("response_format".to_string(), rf.to_api_value());
            }
        }

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;

        let res_json: Value = response.json().await
            .map_err(|e| AgentError::json(format!("Response JSON parse failed: {e}")))?;

        if let Some(error) = res_json.get("error") {
            return Err(AgentError::LlmApi {
                message: format!("{error:#?}"),
            });
        }

        Ok(res_json)
    }

    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[Value],
        enable_thinking: Option<bool>,
        response_format: Option<&ResponseFormat>,
    ) -> AgentResult<Pin<Box<dyn Stream<Item = AgentResult<StreamChunk>> + Send>>> {
        let url = format!("{}/chat/completions", self.base_url);
        let raw_messages = Self::messages_to_json(messages);
        let mut request_body = json!({
            "model": self.model,
            "messages": raw_messages,
            "tools": tools,
            "stream": true,
            "stream_options": { "include_usage": true },
        });

        if let Some(thinking) = enable_thinking {
            if let Some(obj) = request_body.as_object_mut() {
                obj.insert("enable_thinking".to_string(), json!(thinking));
            }
        }

        if let Some(rf) = response_format {
            if let Some(obj) = request_body.as_object_mut() {
                obj.insert("response_format".to_string(), rf.to_api_value());
            }
        }

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| AgentError::llm(format!("HTTP request failed: {e}")))?;

        if !response.status().is_success() {
            let err_text = response.text().await
                .map_err(|e| AgentError::llm(format!("Failed to read error response: {e}")))?;
            return Err(AgentError::LlmApi { message: err_text });
        }

        let stream = response.bytes_stream().eventsource().map(|event| match event {
            Ok(event) => {
                if event.data == "[DONE]" {
                    return Ok(StreamChunk::Stop);
                }

                let data: Value = serde_json::from_str(&event.data)
                    .map_err(|e| AgentError::json(format!("JSON Parse error: {e}")))?;

                let choices = data.get("choices").and_then(Value::as_array);

                if choices.is_none() || choices.map_or(true, |c| c.is_empty()) {
                    if let Some(usage) = data.get("usage") {
                        return Ok(StreamChunk::Usage(UsageInfo {
                            prompt_tokens: usage.get("prompt_tokens").and_then(Value::as_u64).map(|v| v as u32),
                            completion_tokens: usage.get("completion_tokens").and_then(Value::as_u64).map(|v| v as u32),
                            total_tokens: usage.get("total_tokens").and_then(Value::as_u64).map(|v| v as u32),
                        }));
                    }
                    return Ok(StreamChunk::Text(String::new()));
                }

                let choice = &choices.unwrap()[0];
                let delta = &choice["delta"];
                let finish_reason = choice["finish_reason"].as_str().unwrap_or("");

                if finish_reason == "tool_calls" || delta.get("tool_calls").is_some() {
                    return Ok(StreamChunk::ToolCall(choice.clone()));
                }

                if let Some(reasoning) = delta.get("reasoning_content") {
                    if let Some(text) = reasoning.as_str() {
                        return Ok(StreamChunk::Thought(text.to_string()));
                    }
                }

                if let Some(content) = delta.get("content") {
                    if let Some(text) = content.as_str() {
                        return Ok(StreamChunk::Text(text.to_string()));
                    }
                }

                if finish_reason == "stop" {
                    return Ok(StreamChunk::Stop);
                }

                Ok(StreamChunk::Text(String::new()))
            }
            Err(e) => Err(AgentError::LlmStream(format!("SSE Stream error: {e}"))),
        });

        Ok(Box::pin(stream))
    }

    fn capabilities(&self) -> LlmCapabilities {
        LlmCapabilities {
            supports_streaming: true,
            supports_tools: true,
            supports_vision: true,
            supports_thinking: false,
            max_context_tokens: Some(128_000),
            max_output_tokens: Some(16_384),
        }
    }
}