Skip to main content

poe2_agent/
llm.rs

1//! OpenAI chat completion API client -- non-streaming and streaming requests, with tool calling.
2
3use anyhow::{Context, Result};
4use futures_core::Stream;
5use reqwest::header;
6use serde::{Deserialize, Serialize};
7
/// Chat completions endpoint; every request made by `ChatGptClient` is
/// POSTed here.
8const API_URL: &str = "https://api.openai.com/v1/chat/completions";
9
10/// OpenAI chat completion client.
///
/// Pairs an HTTP client (with the API key pre-installed as a default header by
/// `ChatGptClient::new`) with the model name sent in every request.
11#[derive(Clone)]
12pub struct ChatGptClient {
    /// HTTP client whose default headers carry `Authorization: Bearer <key>`.
13    client: reqwest::Client,
    /// Model identifier copied into each request's `model` field.
14    model: String,
15}
16
/// One message in a chat completion conversation.
///
/// Serialized into requests and deserialized from non-streaming responses.
/// `Option` fields are omitted from the serialized JSON when `None`. Prefer
/// the constructors (`Message::system`, `Message::user`, ...) over filling
/// fields by hand.
17#[derive(Debug, Serialize, Deserialize, Clone)]
18pub struct Message {
    /// Sender role; the constructors in this file use `"system"`, `"user"`,
    /// `"assistant"`, and `"tool"`.
19    pub role: String,
    /// Text content. `None` for an assistant message that only carries tool
    /// calls (see `Message::assistant_tool_calls`).
20    #[serde(skip_serializing_if = "Option::is_none")]
21    pub content: Option<String>,
    /// Tool invocations requested by the assistant, if any.
22    #[serde(skip_serializing_if = "Option::is_none")]
23    pub tool_calls: Option<Vec<ToolCall>>,
    /// For `"tool"` messages: the id of the tool call this message is the
    /// result for (see `Message::tool_result`).
24    #[serde(skip_serializing_if = "Option::is_none")]
25    pub tool_call_id: Option<String>,
26}
27
28impl Message {
29    pub fn system(content: impl Into<String>) -> Self {
30        Self {
31            role: "system".to_owned(),
32            content: Some(content.into()),
33            tool_calls: None,
34            tool_call_id: None,
35        }
36    }
37
38    pub fn user(content: impl Into<String>) -> Self {
39        Self {
40            role: "user".to_owned(),
41            content: Some(content.into()),
42            tool_calls: None,
43            tool_call_id: None,
44        }
45    }
46
47    pub fn assistant(content: impl Into<String>) -> Self {
48        Self {
49            role: "assistant".to_owned(),
50            content: Some(content.into()),
51            tool_calls: None,
52            tool_call_id: None,
53        }
54    }
55
56    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
57        Self {
58            role: "assistant".to_owned(),
59            content: None,
60            tool_calls: Some(tool_calls),
61            tool_call_id: None,
62        }
63    }
64
65    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
66        Self {
67            role: "tool".to_owned(),
68            content: Some(content.into()),
69            tool_calls: None,
70            tool_call_id: Some(tool_call_id.into()),
71        }
72    }
73}
74
75// -- Tool-calling types ------------------------------------------------------
76
/// A tool invocation requested by the assistant.
///
/// Deserialized from responses and serialized back inside `Message::tool_calls`
/// when the conversation is sent in a later request.
77#[derive(Debug, Serialize, Deserialize, Clone)]
78pub struct ToolCall {
    /// Call identifier; presumably what a later `"tool"` message's
    /// `tool_call_id` refers to (see `Message::tool_result`).
79    pub id: String,
    /// Call kind, stored under the JSON key `"type"`.
80    #[serde(rename = "type")]
81    pub call_type: String,
    /// The function to invoke and its arguments.
82    pub function: FunctionCall,
83}
84
/// Name and arguments of a function call requested by the assistant.
85#[derive(Debug, Serialize, Deserialize, Clone)]
86pub struct FunctionCall {
    /// Name of the function to invoke.
87    pub name: String,
    /// Arguments as an unparsed string. NOTE(review): per the OpenAI API this
    /// is expected to be JSON, but nothing here validates it -- callers must
    /// parse it themselves.
88    pub arguments: String,
89}
90
/// A tool the model may call, sent in the request's `tools` list.
///
/// Only `Serialize` is derived: definitions are sent, never read back.
91#[derive(Debug, Serialize, Clone)]
92pub struct ToolDefinition {
    /// Tool kind, stored under the JSON key `"type"`.
    /// NOTE(review): presumably `"function"` -- confirm against the API docs.
93    #[serde(rename = "type")]
94    pub tool_type: String,
    /// Signature of the callable function.
95    pub function: FunctionDefinition,
96}
97
/// Name, description, and parameter schema of a function offered to the model.
98#[derive(Debug, Serialize, Clone)]
99pub struct FunctionDefinition {
    /// Function name.
100    pub name: String,
    /// Human-readable description sent to the model.
101    pub description: String,
    /// Parameter description as arbitrary JSON.
    /// NOTE(review): presumably a JSON Schema object, as the OpenAI API expects.
102    pub parameters: serde_json::Value,
103}
104
105// -- Request types -----------------------------------------------------------
106
/// JSON body POSTed to `API_URL` for every completion request.
107#[derive(Debug, Serialize)]
108struct ChatRequest {
    /// Model identifier.
109    model: String,
    /// Conversation history to complete.
110    messages: Vec<Message>,
    /// Sampling temperature; omitted from the JSON when `None`.
111    #[serde(skip_serializing_if = "Option::is_none")]
112    temperature: Option<f32>,
    /// Whether to stream the response. Serialized only when `true`:
    /// `Not::not(&false)` is `true`, which tells serde to skip the field.
113    #[serde(skip_serializing_if = "std::ops::Not::not")]
114    stream: bool,
    /// Tools the model may call; omitted from the JSON when `None`.
115    #[serde(skip_serializing_if = "Option::is_none")]
116    tools: Option<Vec<ToolDefinition>>,
117}
118
119// -- Non-streaming response types --------------------------------------------
120
/// Body of a non-streaming chat completion response.
///
/// Only `choices` is read; serde ignores any other fields in the JSON.
121#[derive(Debug, Deserialize)]
122struct ChatResponse {
    /// Completion candidates; the callers in this file use only the first.
123    choices: Vec<Choice>,
124}
125
/// One completion candidate in a `ChatResponse`.
126#[derive(Debug, Deserialize)]
127struct Choice {
    /// The assistant's reply (text and/or tool calls).
128    message: Message,
    /// Why generation stopped, as reported by the API; `None` if absent or null.
129    finish_reason: Option<String>,
130}
131
132// -- Streaming response types ------------------------------------------------
133
/// The JSON payload of one `data:` line in a streaming response.
134#[derive(Debug, Deserialize)]
135struct StreamChunk {
    /// Per-choice incremental updates.
136    choices: Vec<StreamChoice>,
137}
138
/// A single choice's update within a `StreamChunk`.
139#[derive(Debug, Deserialize)]
140struct StreamChoice {
    /// The newly generated piece of the message.
141    delta: Delta,
142}
143
/// Incremental message content carried by a stream chunk.
///
/// Only the text `content` is read; serde ignores any other delta fields.
144#[derive(Debug, Deserialize)]
145struct Delta {
    /// Newly generated text, if this delta carries any.
146    content: Option<String>,
147}
148
149// -- Errors ------------------------------------------------------------------
150
/// Errors returned by the `ChatGptClient` request methods.
151#[derive(Debug, thiserror::Error)]
152pub enum LlmError {
    /// The API answered with a non-success HTTP status. `body` is the
    /// response body, read best-effort (it may be empty).
153    #[error("OpenAI API error (HTTP {status}): {body}")]
154    Api { status: u16, body: String },
155
    /// Failure inside `reqwest`: connecting, sending, reading, or decoding
    /// the response body as JSON.
156    #[error(transparent)]
157    Transport(#[from] reqwest::Error),
158
    /// Any other failure, e.g. a response that contains no choices.
159    #[error(transparent)]
160    Other(#[from] anyhow::Error),
161}
162
163impl ChatGptClient {
164    /// Create a new client. The API key is baked into the underlying
165    /// `reqwest::Client` as a default header so it doesn't need to be
166    /// cloned per-request.
167    pub fn new(api_key: &str, model: &str) -> Result<Self> {
168        let mut headers = header::HeaderMap::new();
169        let mut auth = header::HeaderValue::from_str(&format!("Bearer {api_key}"))
170            .context("invalid API key characters")?;
171        auth.set_sensitive(true);
172        headers.insert(header::AUTHORIZATION, auth);
173
174        let client = reqwest::Client::builder()
175            .default_headers(headers)
176            .build()
177            .context("failed to build HTTP client")?;
178
179        Ok(Self {
180            client,
181            model: model.to_owned(),
182        })
183    }
184
185    /// Send a blocking chat completion request, returning the full response.
186    pub async fn chat(&self, messages: Vec<Message>) -> Result<String, LlmError> {
187        let request = ChatRequest {
188            model: self.model.clone(),
189            messages,
190            temperature: None,
191            stream: false,
192            tools: None,
193        };
194
195        let response = self.client.post(API_URL).json(&request).send().await?;
196        let status = response.status();
197        if !status.is_success() {
198            let body = response.text().await.unwrap_or_default();
199            return Err(LlmError::Api {
200                status: status.as_u16(),
201                body,
202            });
203        }
204
205        let parsed: ChatResponse = response.json().await?;
206        Ok(parsed
207            .choices
208            .into_iter()
209            .next()
210            .and_then(|c| c.message.content)
211            .unwrap_or_default())
212    }
213
214    /// Non-streaming chat completion with tool support.
215    ///
216    /// Returns the full assistant `Message` and the `finish_reason`.
217    /// The agent loop inspects these to decide whether to execute tools
218    /// or return the final answer.
219    pub async fn chat_with_tools(
220        &self,
221        messages: Vec<Message>,
222        tools: Option<&[ToolDefinition]>,
223    ) -> Result<(Message, Option<String>), LlmError> {
224        let request = ChatRequest {
225            model: self.model.clone(),
226            messages,
227            temperature: None,
228            stream: false,
229            tools: tools.map(|t| t.to_vec()),
230        };
231
232        let response = self.client.post(API_URL).json(&request).send().await?;
233        let status = response.status();
234        if !status.is_success() {
235            let body = response.text().await.unwrap_or_default();
236            return Err(LlmError::Api {
237                status: status.as_u16(),
238                body,
239            });
240        }
241
242        let parsed: ChatResponse = response.json().await?;
243        let choice = parsed
244            .choices
245            .into_iter()
246            .next()
247            .ok_or_else(|| LlmError::Other(anyhow::anyhow!("no choices in response")))?;
248
249        Ok((choice.message, choice.finish_reason))
250    }
251
252    /// Stream a chat completion, yielding content tokens as they arrive.
253    ///
254    /// The returned stream is `'static` -- it clones the HTTP client and model
255    /// name so callers don't need to worry about lifetimes.
256    pub fn chat_stream(
257        &self,
258        messages: Vec<Message>,
259    ) -> impl Stream<Item = Result<String, LlmError>> + Send {
260        let client = self.client.clone();
261        let model = self.model.clone();
262
263        async_stream::try_stream! {
264            let request = ChatRequest {
265                model,
266                messages,
267                temperature: None,
268                stream: true,
269                tools: None,
270            };
271
272            let mut response = client.post(API_URL).json(&request).send().await?;
273            if !response.status().is_success() {
274                let status = response.status().as_u16();
275                // Read error body via chunk() to avoid .text() consuming by value.
276                let mut body = String::new();
277                while let Some(chunk) = response.chunk().await? {
278                    body.push_str(&String::from_utf8_lossy(&chunk));
279                }
280                Err(LlmError::Api { status, body })?;
281            }
282            let mut buffer = String::new();
283
284            while let Some(chunk) = response.chunk().await? {
285                buffer.push_str(&String::from_utf8_lossy(&chunk));
286
287                // Process complete SSE events (delimited by double newline).
288                while let Some(pos) = buffer.find("\n\n") {
289                    let event = buffer[..pos].to_owned();
290                    buffer = buffer[pos + 2..].to_owned();
291
292                    for line in event.lines() {
293                        let data = match line.strip_prefix("data: ") {
294                            Some(d) => d.trim(),
295                            None => continue,
296                        };
297
298                        if data == "[DONE]" {
299                            return;
300                        }
301
302                        if let Ok(parsed) = serde_json::from_str::<StreamChunk>(data) {
303                            for choice in parsed.choices {
304                                if let Some(content) = choice.delta.content {
305                                    yield content;
306                                }
307                            }
308                        }
309                    }
310                }
311            }
312        }
313    }
314}