openclaw_node/providers/
openai.rs1use napi::bindgen_prelude::*;
4use napi_derive::napi;
5use std::sync::Arc;
6
7use openclaw_core::secrets::ApiKey;
8use openclaw_providers::traits::ChunkType;
9use openclaw_providers::{OpenAIProvider as RustOpenAIProvider, Provider};
10
11use super::types::{
12 JsCompletionRequest, JsCompletionResponse, JsStreamChunk, convert_request, convert_response,
13};
14use crate::error::OpenClawError;
15
16#[napi]
21pub struct OpenAIProvider {
22 inner: Arc<RustOpenAIProvider>,
23}
24
25#[napi]
26impl OpenAIProvider {
27 #[napi(constructor)]
33 #[must_use]
34 pub fn new(api_key: String) -> Self {
35 let key = ApiKey::new(api_key);
36 Self {
37 inner: Arc::new(RustOpenAIProvider::new(key)),
38 }
39 }
40
41 #[napi(factory)]
45 #[must_use]
46 pub fn with_base_url(api_key: String, base_url: String) -> Self {
47 let key = ApiKey::new(api_key);
48 Self {
49 inner: Arc::new(RustOpenAIProvider::with_base_url(key, base_url)),
50 }
51 }
52
53 #[napi(factory)]
55 #[must_use]
56 pub fn with_org(api_key: String, org_id: String) -> Self {
57 let key = ApiKey::new(api_key);
58 let provider = RustOpenAIProvider::new(key).with_org_id(org_id);
59 Self {
60 inner: Arc::new(provider),
61 }
62 }
63
64 #[napi(getter)]
66 #[must_use]
67 pub fn name(&self) -> String {
68 self.inner.name().to_string()
69 }
70
71 #[napi]
75 pub async fn list_models(&self) -> Result<Vec<String>> {
76 self.inner
77 .list_models()
78 .await
79 .map_err(|e| OpenClawError::from_provider_error(e).into())
80 }
81
82 #[napi]
92 pub async fn complete(&self, request: JsCompletionRequest) -> Result<JsCompletionResponse> {
93 let rust_request = convert_request(request);
94 let response = self
95 .inner
96 .complete(rust_request)
97 .await
98 .map_err(OpenClawError::from_provider_error)?;
99 Ok(convert_response(response))
100 }
101
102 #[napi]
114 pub fn complete_stream(
115 &self,
116 request: JsCompletionRequest,
117 #[napi(ts_arg_type = "(err: Error | null, chunk: JsStreamChunk | null) => void")]
118 callback: JsFunction,
119 ) -> Result<()> {
120 use futures::StreamExt;
121 use napi::threadsafe_function::{
122 ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode,
123 };
124
125 let tsfn: ThreadsafeFunction<JsStreamChunk, ErrorStrategy::CalleeHandled> =
127 callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?;
128
129 let inner = self.inner.clone();
130 let rust_request = convert_request(request);
131
132 napi::tokio::spawn(async move {
134 match inner.complete_stream(rust_request).await {
135 Ok(mut stream) => {
136 while let Some(chunk_result) = stream.next().await {
137 match chunk_result {
138 Ok(chunk) => {
139 let js_chunk = convert_stream_chunk(
140 &chunk.chunk_type,
141 chunk.delta.as_deref(),
142 chunk.index,
143 );
144 let _ = tsfn
145 .call(Ok(js_chunk), ThreadsafeFunctionCallMode::NonBlocking);
146 }
147 Err(e) => {
148 let err = OpenClawError::from_provider_error(e);
149 let _ = tsfn.call(
150 Err(napi::Error::from_reason(
151 serde_json::to_string(&err).unwrap_or_default(),
152 )),
153 ThreadsafeFunctionCallMode::NonBlocking,
154 );
155 break;
156 }
157 }
158 }
159 }
160 Err(e) => {
161 let err = OpenClawError::from_provider_error(e);
162 let _ = tsfn.call(
163 Err(napi::Error::from_reason(
164 serde_json::to_string(&err).unwrap_or_default(),
165 )),
166 ThreadsafeFunctionCallMode::NonBlocking,
167 );
168 }
169 }
170 });
171
172 Ok(())
173 }
174}
175
176fn convert_stream_chunk(
178 chunk_type: &ChunkType,
179 delta: Option<&str>,
180 index: Option<usize>,
181) -> JsStreamChunk {
182 let (type_str, stop_reason) = match chunk_type {
183 ChunkType::MessageStart => ("message_start", None),
184 ChunkType::ContentBlockStart => ("content_block_start", None),
185 ChunkType::ContentBlockDelta => ("content_block_delta", None),
186 ChunkType::ContentBlockStop => ("content_block_stop", None),
187 ChunkType::MessageDelta => ("message_delta", None),
188 ChunkType::MessageStop => ("message_stop", None),
189 };
190
191 JsStreamChunk {
192 chunk_type: type_str.to_string(),
193 delta: delta.map(std::string::ToString::to_string),
194 index: index.map(|i| i as u32),
195 stop_reason,
196 }
197}