// poe2-agent 0.2.0
//
// AI agent for Path of Exile 2 build analysis.
// Documentation
//! OpenAI chat completion API -- non-streaming and streaming, with tool calling.

use anyhow::{Context, Result};
use futures_core::Stream;
use reqwest::header;
use serde::{Deserialize, Serialize};

/// OpenAI chat-completions endpoint; every request in this module is POSTed here.
const API_URL: &str = "https://api.openai.com/v1/chat/completions";

/// OpenAI chat completion client.
///
/// Cheap to clone: `reqwest::Client` keeps its connection pool behind an
/// `Arc`, so clones share the same pool and default headers.
///
/// `Debug` is derived so the client can be embedded in other `Debug` types;
/// the bearer token is stored as a header value marked sensitive (see
/// [`ChatGptClient::new`]).
#[derive(Clone, Debug)]
pub struct ChatGptClient {
    /// HTTP client with the `Authorization` header installed as a default header.
    client: reqwest::Client,
    /// Model identifier sent in every request body.
    model: String,
}

/// One chat message in the OpenAI chat-completions wire format.
///
/// Serialized into request bodies and deserialized from non-streaming
/// responses. Fields that are `None` are omitted when serializing.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    /// Author role. The constructors on this type produce `"system"`,
    /// `"user"`, `"assistant"`, or `"tool"`.
    pub role: String,
    /// Text content; `None` e.g. for an assistant turn that only requests tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool invocations requested by the assistant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `"tool"` messages: the `ToolCall::id` this result answers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}

impl Message {
    /// Internal builder that every public constructor below funnels through.
    fn from_parts(
        role: &str,
        content: Option<String>,
        tool_calls: Option<Vec<ToolCall>>,
        tool_call_id: Option<String>,
    ) -> Self {
        Self {
            role: String::from(role),
            content,
            tool_calls,
            tool_call_id,
        }
    }

    /// Build a `"system"` message carrying instructions for the model.
    pub fn system(content: impl Into<String>) -> Self {
        Self::from_parts("system", Some(content.into()), None, None)
    }

    /// Build a `"user"` message with the given text.
    pub fn user(content: impl Into<String>) -> Self {
        Self::from_parts("user", Some(content.into()), None, None)
    }

    /// Build a plain-text `"assistant"` message (no tool calls).
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::from_parts("assistant", Some(content.into()), None, None)
    }

    /// Build an `"assistant"` message that only requests tool calls and has
    /// no text content.
    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
        Self::from_parts("assistant", None, Some(tool_calls), None)
    }

    /// Build a `"tool"` message holding the output of the tool call
    /// identified by `tool_call_id`.
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self::from_parts("tool", Some(content.into()), None, Some(tool_call_id.into()))
    }
}

// -- Tool-calling types ------------------------------------------------------

/// A tool invocation requested by the assistant.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Call identifier; pass it back as the `tool_call_id` of the matching
    /// `Message::tool_result`.
    pub id: String,
    /// Serialized as `"type"`. NOTE(review): presumably always `"function"`
    /// with current OpenAI models — confirm.
    #[serde(rename = "type")]
    pub call_type: String,
    /// Which function to call and with what arguments.
    pub function: FunctionCall,
}

/// The function the assistant asked to invoke.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct FunctionCall {
    /// Name of the function to invoke.
    pub name: String,
    /// Arguments exactly as received from the API, kept as an unparsed string.
    /// OpenAI documents this as JSON, but it may be malformed — callers must
    /// parse and validate it themselves.
    pub arguments: String,
}

/// A tool the model is allowed to call; sent in the request's `tools` array.
#[derive(Debug, Serialize, Clone)]
pub struct ToolDefinition {
    /// Serialized as `"type"`. NOTE(review): presumably `"function"` — confirm
    /// against the callers that build these.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name, description, and parameter schema of the function.
    pub function: FunctionDefinition,
}

/// Describes one callable function to the model.
#[derive(Debug, Serialize, Clone)]
pub struct FunctionDefinition {
    /// Function name; the model refers to it in `FunctionCall::name`.
    pub name: String,
    /// Natural-language description shown to the model.
    pub description: String,
    /// Parameter schema, passed through verbatim. OpenAI expects a JSON Schema
    /// object here.
    pub parameters: serde_json::Value,
}

// -- Request types -----------------------------------------------------------

/// JSON body POSTed to `API_URL`.
#[derive(Debug, Serialize)]
struct ChatRequest {
    /// Model identifier.
    model: String,
    /// Conversation so far, in order.
    messages: Vec<Message>,
    /// Sampling temperature; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Ask for a server-sent-event stream; the field is omitted when `false`
    /// (`Not::not` skips serialization when `!stream` is true).
    #[serde(skip_serializing_if = "std::ops::Not::not")]
    stream: bool,
    /// Tools the model may call; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<ToolDefinition>>,
}

// -- Non-streaming response types --------------------------------------------

/// Subset of a non-streaming completion response that this module reads.
/// Other response fields are ignored (serde's default for unknown fields).
#[derive(Debug, Deserialize)]
struct ChatResponse {
    /// Candidate completions; only the first one is used by the callers here.
    choices: Vec<Choice>,
}

/// One candidate completion in a non-streaming response.
#[derive(Debug, Deserialize)]
struct Choice {
    /// The assistant message, possibly containing tool calls instead of text.
    message: Message,
    /// Why generation stopped (OpenAI values such as `"stop"` or
    /// `"tool_calls"`); returned to the agent loop by `chat_with_tools`.
    finish_reason: Option<String>,
}

// -- Streaming response types ------------------------------------------------

/// The JSON payload of one SSE `data:` line in a streaming response.
#[derive(Debug, Deserialize)]
struct StreamChunk {
    /// Per-choice deltas; may be empty for chunks that carry no choices.
    choices: Vec<StreamChoice>,
}

/// One choice inside a streaming chunk.
#[derive(Debug, Deserialize)]
struct StreamChoice {
    /// The incremental change to this choice's message.
    delta: Delta,
}

/// Incremental message update. Only text content is read; any other delta
/// fields (e.g. role or tool-call fragments) are ignored by deserialization.
#[derive(Debug, Deserialize)]
struct Delta {
    /// Next fragment of text, or `None` when this delta carries no text.
    content: Option<String>,
}

// -- Errors ------------------------------------------------------------------

/// Errors produced by [`ChatGptClient`] calls.
#[derive(Debug, thiserror::Error)]
pub enum LlmError {
    /// The API replied with a non-success HTTP status; `body` is the response
    /// body text.
    #[error("OpenAI API error (HTTP {status}): {body}")]
    Api { status: u16, body: String },

    /// Failure reported by `reqwest`: sending the request, reading the body,
    /// or decoding the JSON response.
    #[error(transparent)]
    Transport(#[from] reqwest::Error),

    /// Any other failure, e.g. a response that contains no choices.
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}

impl ChatGptClient {
    /// Create a new client. The API key is baked into the underlying
    /// `reqwest::Client` as a default header so it doesn't need to be
    /// cloned per-request.
    pub fn new(api_key: &str, model: &str) -> Result<Self> {
        let mut headers = header::HeaderMap::new();
        let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
            .context("invalid API key characters")?;
        auth.set_sensitive(true);
        headers.insert(header::AUTHORIZATION, auth);

        let client = reqwest::Client::builder()
            .default_headers(headers)
            .build()
            .context("failed to build HTTP client")?;

        Ok(Self {
            client,
            model: model.to_owned(),
        })
    }

    /// Send a blocking chat completion request, returning the full response.
    pub async fn chat(&self, messages: Vec<Message>) -> Result<String, LlmError> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages,
            temperature: None,
            stream: false,
            tools: None,
        };

        let response = self.client.post(API_URL).json(&request).send().await?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                body,
            });
        }

        let parsed: ChatResponse = response.json().await?;
        Ok(parsed
            .choices
            .into_iter()
            .next()
            .and_then(|c| c.message.content)
            .unwrap_or_default())
    }

    /// Non-streaming chat completion with tool support.
    ///
    /// Returns the full assistant `Message` and the `finish_reason`.
    /// The agent loop inspects these to decide whether to execute tools
    /// or return the final answer.
    pub async fn chat_with_tools(
        &self,
        messages: Vec<Message>,
        tools: Option<&[ToolDefinition]>,
    ) -> Result<(Message, Option<String>), LlmError> {
        let request = ChatRequest {
            model: self.model.clone(),
            messages,
            temperature: None,
            stream: false,
            tools: tools.map(|t| t.to_vec()),
        };

        let response = self.client.post(API_URL).json(&request).send().await?;
        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(LlmError::Api {
                status: status.as_u16(),
                body,
            });
        }

        let parsed: ChatResponse = response.json().await?;
        let choice = parsed
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| LlmError::Other(anyhow::anyhow!("no choices in response")))?;

        Ok((choice.message, choice.finish_reason))
    }

    /// Stream a chat completion, yielding content tokens as they arrive.
    ///
    /// The returned stream is `'static` -- it clones the HTTP client and model
    /// name so callers don't need to worry about lifetimes.
    pub fn chat_stream(
        &self,
        messages: Vec<Message>,
    ) -> impl Stream<Item = Result<String, LlmError>> + Send {
        let client = self.client.clone();
        let model = self.model.clone();

        async_stream::try_stream! {
            let request = ChatRequest {
                model,
                messages,
                temperature: None,
                stream: true,
                tools: None,
            };

            let mut response = client.post(API_URL).json(&request).send().await?;
            if !response.status().is_success() {
                let status = response.status().as_u16();
                // Read error body via chunk() to avoid .text() consuming by value.
                let mut body = String::new();
                while let Some(chunk) = response.chunk().await? {
                    body.push_str(&String::from_utf8_lossy(&chunk));
                }
                Err(LlmError::Api { status, body })?;
            }
            let mut buffer = String::new();

            while let Some(chunk) = response.chunk().await? {
                buffer.push_str(&String::from_utf8_lossy(&chunk));

                // Process complete SSE events (delimited by double newline).
                while let Some(pos) = buffer.find("\n\n") {
                    let event = buffer[..pos].to_owned();
                    buffer = buffer[pos + 2..].to_owned();

                    for line in event.lines() {
                        let data = match line.strip_prefix("data: ") {
                            Some(d) => d.trim(),
                            None => continue,
                        };

                        if data == "[DONE]" {
                            return;
                        }

                        if let Ok(parsed) = serde_json::from_str::<StreamChunk>(data) {
                            for choice in parsed.choices {
                                if let Some(content) = choice.delta.content {
                                    yield content;
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}