// openclaw_node/providers/openai.rs
1//! `OpenAI` GPT provider bindings.
2
3use napi::bindgen_prelude::*;
4use napi_derive::napi;
5use std::sync::Arc;
6
7use openclaw_core::secrets::ApiKey;
8use openclaw_providers::traits::ChunkType;
9use openclaw_providers::{OpenAIProvider as RustOpenAIProvider, Provider};
10
11use super::types::{
12    JsCompletionRequest, JsCompletionResponse, JsStreamChunk, convert_request, convert_response,
13};
14use crate::error::OpenClawError;
15
/// `OpenAI` GPT API provider.
///
/// Supports GPT-4o, GPT-4, GPT-3.5, and other `OpenAI` models.
/// Also works with Azure `OpenAI` and compatible APIs.
#[napi]
pub struct OpenAIProvider {
    // Shared handle to the core Rust provider. `Arc` (rather than direct
    // ownership) lets the streaming path hand an owned clone of the handle
    // to a spawned tokio task while `&self` stays borrowed by the caller.
    inner: Arc<RustOpenAIProvider>,
}
24
#[napi]
impl OpenAIProvider {
    /// Create a new `OpenAI` provider with API key.
    ///
    /// # Arguments
    ///
    /// * `api_key` - Your `OpenAI` API key (starts with "sk-")
    #[napi(constructor)]
    #[must_use]
    pub fn new(api_key: String) -> Self {
        // Wrap the raw key in the core crate's secret type instead of
        // passing a bare String around.
        let key = ApiKey::new(api_key);
        Self {
            inner: Arc::new(RustOpenAIProvider::new(key)),
        }
    }

    /// Create a provider with custom base URL.
    ///
    /// Useful for Azure `OpenAI`, `LocalAI`, or other compatible APIs.
    #[napi(factory)]
    #[must_use]
    pub fn with_base_url(api_key: String, base_url: String) -> Self {
        let key = ApiKey::new(api_key);
        Self {
            inner: Arc::new(RustOpenAIProvider::with_base_url(key, base_url)),
        }
    }

    /// Create a provider with organization ID.
    ///
    /// The organization ID is attached to the underlying provider via its
    /// builder-style `with_org_id` method.
    #[napi(factory)]
    #[must_use]
    pub fn with_org(api_key: String, org_id: String) -> Self {
        let key = ApiKey::new(api_key);
        let provider = RustOpenAIProvider::new(key).with_org_id(org_id);
        Self {
            inner: Arc::new(provider),
        }
    }

    /// Provider name ("openai").
    #[napi(getter)]
    #[must_use]
    pub fn name(&self) -> String {
        // The trait returns a &str; napi getters need an owned String.
        self.inner.name().to_string()
    }

    /// List available models.
    ///
    /// Returns an array of model IDs like "gpt-4o", "gpt-4-turbo".
    ///
    /// # Errors
    ///
    /// Returns a JS error (via `OpenClawError`) if the underlying API call
    /// fails.
    #[napi]
    pub async fn list_models(&self) -> Result<Vec<String>> {
        self.inner
            .list_models()
            .await
            // Map the provider error into the crate's JS-facing error type.
            .map_err(|e| OpenClawError::from_provider_error(e).into())
    }

    /// Create a completion (non-streaming).
    ///
    /// # Arguments
    ///
    /// * `request` - The completion request with model, messages, etc.
    ///
    /// # Returns
    ///
    /// The completion response with content, tool calls, and usage.
    ///
    /// # Errors
    ///
    /// Returns a JS error if the underlying provider call fails.
    #[napi]
    pub async fn complete(&self, request: JsCompletionRequest) -> Result<JsCompletionResponse> {
        // Convert the JS-shaped request into the core request type, run the
        // completion, then convert the response back for JS consumers.
        let rust_request = convert_request(request);
        let response = self
            .inner
            .complete(rust_request)
            .await
            .map_err(OpenClawError::from_provider_error)?;
        Ok(convert_response(response))
    }

    /// Create a streaming completion.
    ///
    /// The callback is called for each chunk received. Chunks have:
    /// - `chunk_type`: "`content_block_delta`", "`message_stop`", etc.
    /// - `delta`: Text content (for delta chunks)
    /// - `stop_reason`: Why generation stopped (for final chunk)
    ///
    /// # Arguments
    ///
    /// * `request` - The completion request
    /// * `callback` - Function called with (error, chunk) for each chunk
    ///
    /// # Errors
    ///
    /// Returns an error if the threadsafe function cannot be created.
    /// Stream errors are delivered through the callback's error argument,
    /// not as a returned `Err`.
    #[napi]
    pub fn complete_stream(
        &self,
        request: JsCompletionRequest,
        #[napi(ts_arg_type = "(err: Error | null, chunk: JsStreamChunk | null) => void")]
        callback: JsFunction,
    ) -> Result<()> {
        use futures::StreamExt;
        use napi::threadsafe_function::{
            ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode,
        };

        // Create threadsafe callback. This must happen here, on the JS
        // thread, before spawning the async task. `CalleeHandled` means the
        // JS callback receives Node-style (err, value) arguments.
        let tsfn: ThreadsafeFunction<JsStreamChunk, ErrorStrategy::CalleeHandled> =
            callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?;

        // Clone the Arc so the spawned task owns its provider handle
        // independently of `&self`.
        let inner = self.inner.clone();
        let rust_request = convert_request(request);

        // Spawn async streaming task. The JoinHandle is intentionally
        // dropped: results flow back to JS solely through `tsfn`.
        napi::tokio::spawn(async move {
            match inner.complete_stream(rust_request).await {
                Ok(mut stream) => {
                    while let Some(chunk_result) = stream.next().await {
                        match chunk_result {
                            Ok(chunk) => {
                                // NOTE(review): only chunk_type/delta/index are
                                // forwarded; convert_stream_chunk always yields
                                // stop_reason = None, so the documented
                                // `stop_reason` never reaches JS here — confirm
                                // whether the core chunk carries one to forward.
                                let js_chunk = convert_stream_chunk(
                                    &chunk.chunk_type,
                                    chunk.delta.as_deref(),
                                    chunk.index,
                                );
                                // NonBlocking: drop the call rather than block
                                // the tokio worker if the JS queue is full.
                                let _ = tsfn
                                    .call(Ok(js_chunk), ThreadsafeFunctionCallMode::NonBlocking);
                            }
                            Err(e) => {
                                // Mid-stream failure: serialize the error as
                                // JSON into the napi error reason, deliver it
                                // via the callback's error slot, and stop.
                                let err = OpenClawError::from_provider_error(e);
                                let _ = tsfn.call(
                                    Err(napi::Error::from_reason(
                                        serde_json::to_string(&err).unwrap_or_default(),
                                    )),
                                    ThreadsafeFunctionCallMode::NonBlocking,
                                );
                                break;
                            }
                        }
                    }
                }
                Err(e) => {
                    // Failure to open the stream at all — same JSON-in-reason
                    // error delivery as the mid-stream case.
                    let err = OpenClawError::from_provider_error(e);
                    let _ = tsfn.call(
                        Err(napi::Error::from_reason(
                            serde_json::to_string(&err).unwrap_or_default(),
                        )),
                        ThreadsafeFunctionCallMode::NonBlocking,
                    );
                }
            }
        });

        Ok(())
    }
}
175
176/// Convert `ChunkType` to `JsStreamChunk`.
177fn convert_stream_chunk(
178    chunk_type: &ChunkType,
179    delta: Option<&str>,
180    index: Option<usize>,
181) -> JsStreamChunk {
182    let (type_str, stop_reason) = match chunk_type {
183        ChunkType::MessageStart => ("message_start", None),
184        ChunkType::ContentBlockStart => ("content_block_start", None),
185        ChunkType::ContentBlockDelta => ("content_block_delta", None),
186        ChunkType::ContentBlockStop => ("content_block_stop", None),
187        ChunkType::MessageDelta => ("message_delta", None),
188        ChunkType::MessageStop => ("message_stop", None),
189    };
190
191    JsStreamChunk {
192        chunk_type: type_str.to_string(),
193        delta: delta.map(std::string::ToString::to_string),
194        index: index.map(|i| i as u32),
195        stop_reason,
196    }
197}