agents_runtime/providers/
openai.rs

1use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
/// Settings used to reach an OpenAI-compatible Chat Completions endpoint.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret sent as a bearer token with every request.
    pub api_key: String,
    /// Model identifier placed in the request body.
    pub model: String,
    /// Full endpoint URL override; `None` selects the public OpenAI chat-completions URL.
    pub api_url: Option<String>,
}

impl OpenAiConfig {
    /// Builds a config for `model` authenticated with `api_key`, with no endpoint override.
    pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
        let api_key: String = api_key.into();
        let model: String = model.into();
        Self {
            api_key,
            model,
            api_url: None,
        }
    }

    /// Returns this config with its endpoint override replaced by `api_url`.
    ///
    /// Passing `None` clears any override set earlier.
    pub fn with_api_url(self, api_url: Option<String>) -> Self {
        Self { api_url, ..self }
    }
}
28
29pub struct OpenAiChatModel {
30    client: Client,
31    config: OpenAiConfig,
32}
33
34impl OpenAiChatModel {
35    pub fn new(config: OpenAiConfig) -> anyhow::Result<Self> {
36        Ok(Self {
37            client: Client::builder()
38                .user_agent("rust-deep-agents-sdk/0.1")
39                .build()?,
40            config,
41        })
42    }
43}
44
45#[derive(Serialize)]
46struct ChatRequest<'a> {
47    model: &'a str,
48    messages: &'a [OpenAiMessage],
49}
50
51#[derive(Serialize)]
52struct OpenAiMessage {
53    role: &'static str,
54    content: String,
55}
56
57#[derive(Deserialize)]
58struct ChatResponse {
59    choices: Vec<Choice>,
60}
61
62#[derive(Deserialize)]
63struct Choice {
64    message: ChoiceMessage,
65}
66
67#[derive(Deserialize)]
68struct ChoiceMessage {
69    content: String,
70}
71
72fn to_openai_messages(request: &LlmRequest) -> Vec<OpenAiMessage> {
73    let mut messages = Vec::with_capacity(request.messages.len() + 1);
74    messages.push(OpenAiMessage {
75        role: "system",
76        content: request.system_prompt.clone(),
77    });
78
79    // Filter and validate message sequence for OpenAI compatibility
80    let mut last_was_tool_call = false;
81
82    for msg in &request.messages {
83        let role = match msg.role {
84            MessageRole::User => "user",
85            MessageRole::Agent => "assistant",
86            MessageRole::Tool => {
87                // Only include tool messages if they follow a tool call
88                if !last_was_tool_call {
89                    tracing::warn!("Skipping tool message without preceding tool_calls");
90                    continue;
91                }
92                "tool"
93            }
94            MessageRole::System => "system",
95        };
96
97        let content = match &msg.content {
98            MessageContent::Text(text) => text.clone(),
99            MessageContent::Json(value) => value.to_string(),
100        };
101
102        // Check if this assistant message contains tool calls
103        last_was_tool_call =
104            matches!(msg.role, MessageRole::Agent) && content.contains("tool_calls");
105
106        messages.push(OpenAiMessage { role, content });
107    }
108    messages
109}
110
111#[async_trait]
112impl LanguageModel for OpenAiChatModel {
113    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
114        let messages = to_openai_messages(&request);
115        let body = ChatRequest {
116            model: &self.config.model,
117            messages: &messages,
118        };
119        let url = self
120            .config
121            .api_url
122            .as_deref()
123            .unwrap_or("https://api.openai.com/v1/chat/completions");
124
125        // Debug logging
126        tracing::debug!(
127            "OpenAI request: model={}, messages={}",
128            self.config.model,
129            messages.len()
130        );
131        for (i, msg) in messages.iter().enumerate() {
132            tracing::debug!(
133                "Message {}: role={}, content_len={}",
134                i,
135                msg.role,
136                msg.content.len()
137            );
138            if msg.content.len() < 500 {
139                tracing::debug!("Message {} content: {}", i, msg.content);
140            }
141        }
142
143        let response = self
144            .client
145            .post(url)
146            .bearer_auth(&self.config.api_key)
147            .json(&body)
148            .send()
149            .await?;
150
151        if !response.status().is_success() {
152            let status = response.status();
153            let error_text = response.text().await.unwrap_or_default();
154            tracing::error!("OpenAI API error: status={}, body={}", status, error_text);
155            return Err(anyhow::anyhow!(
156                "OpenAI API error: {} - {}",
157                status,
158                error_text
159            ));
160        }
161
162        let data: ChatResponse = response.json().await?;
163        let choice = data
164            .choices
165            .into_iter()
166            .next()
167            .ok_or_else(|| anyhow::anyhow!("OpenAI response missing choices"))?;
168
169        Ok(LlmResponse {
170            message: AgentMessage {
171                role: MessageRole::Agent,
172                content: MessageContent::Text(choice.message.content),
173                metadata: None,
174            },
175        })
176    }
177}