Skip to main content

poe2_agent/
llm.rs

//! OpenAI chat completion API -- non-streaming and streaming, with tool calling.
2
3use anyhow::{Context, Result};
4use futures_core::Stream;
5use reqwest::header;
6use serde::{Deserialize, Serialize};
7
/// Endpoint that every chat completion request in this module is POSTed to.
const API_URL: &str = "https://api.openai.com/v1/chat/completions";
9
10/// OpenAI chat completion client.
11#[derive(Clone)]
12pub struct ChatGptClient {
13    client: reqwest::Client,
14    model: String,
15    reasoning_effort: Option<String>,
16}
17
18#[derive(Debug, Serialize, Deserialize, Clone)]
19pub struct Message {
20    pub role: String,
21    #[serde(skip_serializing_if = "Option::is_none")]
22    pub content: Option<String>,
23    #[serde(skip_serializing_if = "Option::is_none")]
24    pub tool_calls: Option<Vec<ToolCall>>,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub tool_call_id: Option<String>,
27}
28
29impl Message {
30    pub fn system(content: impl Into<String>) -> Self {
31        Self {
32            role: "system".to_owned(),
33            content: Some(content.into()),
34            tool_calls: None,
35            tool_call_id: None,
36        }
37    }
38
39    pub fn user(content: impl Into<String>) -> Self {
40        Self {
41            role: "user".to_owned(),
42            content: Some(content.into()),
43            tool_calls: None,
44            tool_call_id: None,
45        }
46    }
47
48    pub fn assistant(content: impl Into<String>) -> Self {
49        Self {
50            role: "assistant".to_owned(),
51            content: Some(content.into()),
52            tool_calls: None,
53            tool_call_id: None,
54        }
55    }
56
57    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
58        Self {
59            role: "assistant".to_owned(),
60            content: None,
61            tool_calls: Some(tool_calls),
62            tool_call_id: None,
63        }
64    }
65
66    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
67        Self {
68            role: "tool".to_owned(),
69            content: Some(content.into()),
70            tool_calls: None,
71            tool_call_id: Some(tool_call_id.into()),
72        }
73    }
74}
75
76// -- Tool-calling types ------------------------------------------------------
77
78#[derive(Debug, Serialize, Deserialize, Clone)]
79pub struct ToolCall {
80    pub id: String,
81    #[serde(rename = "type")]
82    pub call_type: String,
83    pub function: FunctionCall,
84}
85
86#[derive(Debug, Serialize, Deserialize, Clone)]
87pub struct FunctionCall {
88    pub name: String,
89    pub arguments: String,
90}
91
92#[derive(Debug, Serialize, Clone)]
93pub struct ToolDefinition {
94    #[serde(rename = "type")]
95    pub tool_type: String,
96    pub function: FunctionDefinition,
97}
98
99#[derive(Debug, Serialize, Clone)]
100pub struct FunctionDefinition {
101    pub name: String,
102    pub description: String,
103    pub parameters: serde_json::Value,
104}
105
106// -- Request types -----------------------------------------------------------
107
108#[derive(Debug, Serialize)]
109struct ChatRequest {
110    model: String,
111    messages: Vec<Message>,
112    #[serde(skip_serializing_if = "Option::is_none")]
113    temperature: Option<f32>,
114    #[serde(skip_serializing_if = "std::ops::Not::not")]
115    stream: bool,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    tools: Option<Vec<ToolDefinition>>,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    reasoning_effort: Option<String>,
120}
121
122// -- Non-streaming response types --------------------------------------------
123
124#[derive(Debug, Deserialize)]
125struct ChatResponse {
126    choices: Vec<Choice>,
127}
128
129#[derive(Debug, Deserialize)]
130struct Choice {
131    message: Message,
132    finish_reason: Option<String>,
133}
134
135// -- Streaming response types ------------------------------------------------
136
137#[derive(Debug, Deserialize)]
138struct StreamChunk {
139    choices: Vec<StreamChoice>,
140}
141
142#[derive(Debug, Deserialize)]
143struct StreamChoice {
144    delta: Delta,
145}
146
147#[derive(Debug, Deserialize)]
148struct Delta {
149    content: Option<String>,
150}
151
152// -- Errors ------------------------------------------------------------------
153
154#[derive(Debug, thiserror::Error)]
155pub enum LlmError {
156    #[error("OpenAI API error (HTTP {status}): {body}")]
157    Api { status: u16, body: String },
158
159    #[error(transparent)]
160    Transport(#[from] reqwest::Error),
161
162    #[error(transparent)]
163    Other(#[from] anyhow::Error),
164}
165
166impl ChatGptClient {
167    /// Create a new client. The API key is baked into the underlying
168    /// `reqwest::Client` as a default header so it doesn't need to be
169    /// cloned per-request.
170    pub fn new(api_key: &str, model: &str) -> Result<Self> {
171        let mut headers = header::HeaderMap::new();
172        let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
173            .context("invalid API key characters")?;
174        auth.set_sensitive(true);
175        headers.insert(header::AUTHORIZATION, auth);
176
177        let client = reqwest::Client::builder()
178            .default_headers(headers)
179            .build()
180            .context("failed to build HTTP client")?;
181
182        // GPT-5 models default to "medium" reasoning effort, which generates
183        // hidden reasoning tokens and adds significant latency. Use "minimal"
184        // for fast tool-calling agentic flows.
185        let reasoning_effort = if model.starts_with("gpt-5") {
186            Some("minimal".to_owned())
187        } else {
188            None
189        };
190
191        Ok(Self {
192            client,
193            model: model.to_owned(),
194            reasoning_effort,
195        })
196    }
197
198    /// Send a blocking chat completion request, returning the full response.
199    pub async fn chat(&self, messages: Vec<Message>) -> Result<String, LlmError> {
200        let request = ChatRequest {
201            model: self.model.clone(),
202            messages,
203            temperature: None,
204            stream: false,
205            tools: None,
206            reasoning_effort: self.reasoning_effort.clone(),
207        };
208
209        let response = self.client.post(API_URL).json(&request).send().await?;
210        let status = response.status();
211        if !status.is_success() {
212            let body = response.text().await.unwrap_or_default();
213            return Err(LlmError::Api {
214                status: status.as_u16(),
215                body,
216            });
217        }
218
219        let parsed: ChatResponse = response.json().await?;
220        Ok(parsed
221            .choices
222            .into_iter()
223            .next()
224            .and_then(|c| c.message.content)
225            .unwrap_or_default())
226    }
227
228    /// Non-streaming chat completion with tool support.
229    ///
230    /// Returns the full assistant `Message` and the `finish_reason`.
231    /// The agent loop inspects these to decide whether to execute tools
232    /// or return the final answer.
233    pub async fn chat_with_tools(
234        &self,
235        messages: Vec<Message>,
236        tools: Option<&[ToolDefinition]>,
237    ) -> Result<(Message, Option<String>), LlmError> {
238        let request = ChatRequest {
239            model: self.model.clone(),
240            messages,
241            temperature: None,
242            stream: false,
243            tools: tools.map(|t| t.to_vec()),
244            reasoning_effort: self.reasoning_effort.clone(),
245        };
246
247        let response = self.client.post(API_URL).json(&request).send().await?;
248        let status = response.status();
249        if !status.is_success() {
250            let body = response.text().await.unwrap_or_default();
251            return Err(LlmError::Api {
252                status: status.as_u16(),
253                body,
254            });
255        }
256
257        let parsed: ChatResponse = response.json().await?;
258        let choice = parsed
259            .choices
260            .into_iter()
261            .next()
262            .ok_or_else(|| LlmError::Other(anyhow::anyhow!("no choices in response")))?;
263
264        Ok((choice.message, choice.finish_reason))
265    }
266
267    /// Stream a chat completion, yielding content tokens as they arrive.
268    ///
269    /// The returned stream is `'static` -- it clones the HTTP client and model
270    /// name so callers don't need to worry about lifetimes.
271    pub fn chat_stream(
272        &self,
273        messages: Vec<Message>,
274    ) -> impl Stream<Item = Result<String, LlmError>> + Send {
275        let client = self.client.clone();
276        let model = self.model.clone();
277        let reasoning_effort = self.reasoning_effort.clone();
278
279        async_stream::try_stream! {
280            let request = ChatRequest {
281                model,
282                messages,
283                temperature: None,
284                stream: true,
285                tools: None,
286                reasoning_effort,
287            };
288
289            let mut response = client.post(API_URL).json(&request).send().await?;
290            if !response.status().is_success() {
291                let status = response.status().as_u16();
292                // Read error body via chunk() to avoid .text() consuming by value.
293                let mut body = String::new();
294                while let Some(chunk) = response.chunk().await? {
295                    body.push_str(&String::from_utf8_lossy(&chunk));
296                }
297                Err(LlmError::Api { status, body })?;
298            }
299            let mut buffer = String::new();
300
301            while let Some(chunk) = response.chunk().await? {
302                buffer.push_str(&String::from_utf8_lossy(&chunk));
303
304                // Process complete SSE events (delimited by double newline).
305                while let Some(pos) = buffer.find("\n\n") {
306                    let event = buffer[..pos].to_owned();
307                    buffer = buffer[pos + 2..].to_owned();
308
309                    for line in event.lines() {
310                        let data = match line.strip_prefix("data: ") {
311                            Some(d) => d.trim(),
312                            None => continue,
313                        };
314
315                        if data == "[DONE]" {
316                            return;
317                        }
318
319                        if let Ok(parsed) = serde_json::from_str::<StreamChunk>(data) {
320                            for choice in parsed.choices {
321                                if let Some(content) = choice.delta.content {
322                                    yield content;
323                                }
324                            }
325                        }
326                    }
327                }
328            }
329        }
330    }
331}