Skip to main content

elizaos_plugin_openai/
lib.rs

1//! OpenAI model provider plugin for elizaOS.
2//!
3//! This crate provides:
4//! - A typed OpenAI HTTP client (`OpenAIClient`)
5//! - Convenience wrappers (`OpenAIPlugin`)
6//! - A helper to construct an elizaOS plugin definition
7#![warn(missing_docs)]
8
9/// Audio helpers and endpoints.
10pub mod audio;
11/// OpenAI API client implementation.
12pub mod client;
13/// Error types and result aliases.
14pub mod error;
15/// Tokenization helpers.
16pub mod tokenization;
17/// Typed request/response models.
18pub mod types;
19
20pub use audio::{detect_audio_mime_type, get_filename_for_data, AudioMimeType};
21pub use client::OpenAIClient;
22pub use error::{OpenAIError, Result};
23pub use tokenization::{count_tokens, detokenize, tokenize, truncate_to_token_limit};
24pub use types::{
25    ChatCompletionChoice, ChatCompletionResponse, ChatMessage, EmbeddingData, EmbeddingParams,
26    EmbeddingResponse, ImageData, ImageDescriptionParams, ImageDescriptionResult,
27    ImageGenerationParams, ImageGenerationResponse, ImageGenerationResult, ImageQuality, ImageSize,
28    ImageStyle, ModelInfo, ModelsResponse, OpenAIConfig, ResearchAnnotation, ResearchParams,
29    ResearchResult, ResponsesApiError, ResponsesApiResponse, TTSOutputFormat, TTSVoice,
30    TextGenerationParams, TextToSpeechParams, TranscriptionParams, TranscriptionResponse,
31    TranscriptionResponseFormat,
32};
33
34use anyhow::Result as AnyhowResult;
35use std::sync::Arc;
36
/// High-level OpenAI plugin wrapper around an [`OpenAIClient`].
///
/// Construct with [`OpenAIPlugin::new`] from an [`OpenAIConfig`], or from
/// environment variables via [`get_openai_plugin`].
pub struct OpenAIPlugin {
    // Underlying typed HTTP client; exposed read-only via `client()`.
    client: OpenAIClient,
}
41
42impl OpenAIPlugin {
43    /// Create a new [`OpenAIPlugin`] from an [`OpenAIConfig`].
44    pub fn new(config: OpenAIConfig) -> Result<Self> {
45        let client = OpenAIClient::new(config)?;
46        Ok(Self { client })
47    }
48
49    /// Generate text from a user prompt using the default parameters.
50    pub async fn generate_text(&self, prompt: &str) -> Result<String> {
51        let params = TextGenerationParams::new(prompt);
52        self.client.generate_text(&params).await
53    }
54
55    /// Generate text from a user prompt with a system message.
56    pub async fn generate_text_with_system(&self, prompt: &str, system: &str) -> Result<String> {
57        let params = TextGenerationParams::new(prompt).system(system);
58        self.client.generate_text(&params).await
59    }
60
61    /// Generate text from explicitly provided generation parameters.
62    pub async fn generate_text_with_params(&self, params: &TextGenerationParams) -> Result<String> {
63        self.client.generate_text(params).await
64    }
65
66    /// Create an embedding vector for the provided text.
67    pub async fn create_embedding(&self, text: &str) -> Result<Vec<f32>> {
68        let params = EmbeddingParams::new(text);
69        self.client.create_embedding(&params).await
70    }
71
72    /// Get a reference to the underlying [`OpenAIClient`].
73    pub fn client(&self) -> &OpenAIClient {
74        &self.client
75    }
76}
77
78/// Construct an [`OpenAIPlugin`] from environment variables.
79///
80/// Required:
81/// - `OPENAI_API_KEY`
82///
83/// Optional:
84/// - `OPENAI_BASE_URL`
85/// - `OPENAI_SMALL_MODEL`
86/// - `OPENAI_LARGE_MODEL`
87pub fn get_openai_plugin() -> AnyhowResult<OpenAIPlugin> {
88    let api_key = std::env::var("OPENAI_API_KEY")
89        .map_err(|_| anyhow::anyhow!("OPENAI_API_KEY environment variable is required"))?;
90
91    let mut config = OpenAIConfig::new(&api_key);
92
93    if let Ok(base_url) = std::env::var("OPENAI_BASE_URL") {
94        config = config.base_url(&base_url);
95    }
96
97    if let Ok(model) = std::env::var("OPENAI_SMALL_MODEL") {
98        config = config.small_model(&model);
99    }
100
101    if let Ok(model) = std::env::var("OPENAI_LARGE_MODEL") {
102        config = config.large_model(&model);
103    }
104
105    if let Ok(model) = std::env::var("OPENAI_RESEARCH_MODEL") {
106        config = config.research_model(&model);
107    }
108
109    if let Ok(timeout) = std::env::var("OPENAI_RESEARCH_TIMEOUT") {
110        if let Ok(timeout_secs) = timeout.parse::<u64>() {
111            config = config.research_timeout_secs(timeout_secs);
112        }
113    }
114
115    OpenAIPlugin::new(config).map_err(|e| anyhow::anyhow!("Failed to create OpenAI plugin: {}", e))
116}
117
/// Create an elizaOS [`elizaos::types::Plugin`] wired to OpenAI model handlers.
///
/// Registers three model handlers, keyed by model-type name:
/// - `"TEXT_LARGE"` / `"TEXT_SMALL"`: text generation from JSON params
///   `prompt` (string), optional `system` (string), optional `temperature`
///   (number, defaults to 0.7).
/// - `"RESEARCH"`: deep research via the client's `deep_research`, returning
///   a JSON string with `id`, `text`, `annotations`, `outputItems`, `status`.
///
/// # Errors
/// Fails when [`get_openai_plugin`] fails (e.g. `OPENAI_API_KEY` is unset)
/// or the plugin cannot be built from the environment configuration.
pub fn create_openai_elizaos_plugin() -> AnyhowResult<elizaos::types::Plugin> {
    use elizaos::types::{Plugin, PluginDefinition};
    use std::collections::HashMap;

    // One shared plugin instance; each handler closure captures its own
    // `Arc` clone so the handlers can be `'static` and independently owned.
    let openai = Arc::new(get_openai_plugin()?);

    let mut model_handlers: HashMap<String, elizaos::types::ModelHandlerFn> = HashMap::new();

    // NOTE(review): the TEXT_LARGE and TEXT_SMALL handlers below build
    // byte-identical TextGenerationParams — nothing here selects a small vs.
    // large model. Presumably OpenAIClient picks the model internally, but
    // from this file it looks like both keys hit the same default model —
    // TODO confirm against OpenAIClient::generate_text.
    let openai_large = openai.clone();
    model_handlers.insert(
        "TEXT_LARGE".to_string(),
        Box::new(move |params: serde_json::Value| {
            // Clone again per invocation so the returned future is 'static.
            let openai = openai_large.clone();
            Box::pin(async move {
                // Missing/ill-typed fields fall back to "" / no system / 0.7.
                let prompt = params.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
                let system = params.get("system").and_then(|v| v.as_str());
                let temperature = params
                    .get("temperature")
                    .and_then(|v| v.as_f64())
                    .unwrap_or(0.7) as f32;

                let mut text_params = TextGenerationParams::new(prompt).temperature(temperature);

                if let Some(sys) = system {
                    text_params = text_params.system(sys);
                }

                openai
                    .generate_text_with_params(&text_params)
                    .await
                    .map_err(|e| anyhow::anyhow!("OpenAI error: {}", e))
            })
        }),
    );

    let openai_small = openai.clone();
    model_handlers.insert(
        "TEXT_SMALL".to_string(),
        Box::new(move |params: serde_json::Value| {
            let openai = openai_small.clone();
            Box::pin(async move {
                let prompt = params.get("prompt").and_then(|v| v.as_str()).unwrap_or("");
                let system = params.get("system").and_then(|v| v.as_str());
                let temperature = params
                    .get("temperature")
                    .and_then(|v| v.as_f64())
                    .unwrap_or(0.7) as f32;

                let mut text_params = TextGenerationParams::new(prompt).temperature(temperature);

                if let Some(sys) = system {
                    text_params = text_params.system(sys);
                }

                openai
                    .generate_text_with_params(&text_params)
                    .await
                    .map_err(|e| anyhow::anyhow!("OpenAI error: {}", e))
            })
        }),
    );

    let openai_research = openai.clone();
    model_handlers.insert(
        "RESEARCH".to_string(),
        Box::new(move |params: serde_json::Value| {
            let openai = openai_research.clone();
            Box::pin(async move {
                // Optional research knobs; each is applied only when present
                // and of the expected JSON type.
                let input = params.get("input").and_then(|v| v.as_str()).unwrap_or("");
                let instructions = params.get("instructions").and_then(|v| v.as_str());
                let background = params.get("background").and_then(|v| v.as_bool());
                let tools = params.get("tools").and_then(|v| v.as_array()).cloned();
                // NOTE(review): i64 -> i32 `as` cast truncates out-of-range
                // values; acceptable for a tool-call budget but worth a check.
                let max_tool_calls = params.get("maxToolCalls").and_then(|v| v.as_i64()).map(|v| v as i32);
                let model = params.get("model").and_then(|v| v.as_str());

                let mut research_params = ResearchParams::new(input);

                if let Some(inst) = instructions {
                    research_params = research_params.instructions(inst);
                }
                if let Some(bg) = background {
                    research_params = research_params.background(bg);
                }
                if let Some(t) = tools {
                    research_params = research_params.tools(t);
                }
                if let Some(max) = max_tool_calls {
                    research_params = research_params.max_tool_calls(max);
                }
                if let Some(m) = model {
                    research_params = research_params.model(m);
                }

                let result = openai
                    .client()
                    .deep_research(&research_params)
                    .await
                    .map_err(|e| anyhow::anyhow!("OpenAI error: {}", e))?;

                // Convert result to JSON string (annotation fields are
                // renamed to camelCase for the elizaOS side).
                let result_json = serde_json::json!({
                    "id": result.id,
                    "text": result.text,
                    "annotations": result.annotations.iter().map(|a| serde_json::json!({
                        "url": a.url,
                        "title": a.title,
                        "startIndex": a.start_index,
                        "endIndex": a.end_index,
                    })).collect::<Vec<_>>(),
                    "outputItems": result.output_items,
                    "status": result.status,
                });

                Ok(serde_json::to_string(&result_json).unwrap_or_default())
            })
        }),
    );

    Ok(Plugin {
        definition: PluginDefinition {
            name: "openai".to_string(),
            description: "OpenAI model provider for elizaOS".to_string(),
            ..Default::default()
        },
        model_handlers,
        ..Default::default()
    })
}