Skip to main content

openclaw_node/providers/
anthropic.rs

1//! Anthropic Claude provider bindings.
2
3use napi::bindgen_prelude::*;
4use napi_derive::napi;
5use std::sync::Arc;
6
7use openclaw_core::secrets::ApiKey;
8use openclaw_providers::traits::ChunkType;
9use openclaw_providers::{AnthropicProvider as RustAnthropicProvider, Provider};
10
11use super::types::{
12    JsCompletionRequest, JsCompletionResponse, JsStreamChunk, convert_request, convert_response,
13};
14use crate::error::OpenClawError;
15
16/// Anthropic Claude API provider.
17///
18/// Supports Claude 3.5 Sonnet, Claude 3.5 Haiku, and other Claude models.
19#[napi]
20pub struct AnthropicProvider {
21    inner: Arc<RustAnthropicProvider>,
22}
23
24#[napi]
25impl AnthropicProvider {
26    /// Create a new Anthropic provider with API key.
27    ///
28    /// # Arguments
29    ///
30    /// * `api_key` - Your Anthropic API key (starts with "sk-ant-")
31    #[napi(constructor)]
32    #[must_use]
33    pub fn new(api_key: String) -> Self {
34        let key = ApiKey::new(api_key);
35        Self {
36            inner: Arc::new(RustAnthropicProvider::new(key)),
37        }
38    }
39
40    /// Create a provider with custom base URL.
41    ///
42    /// Useful for proxies or custom endpoints.
43    #[napi(factory)]
44    #[must_use]
45    pub fn with_base_url(api_key: String, base_url: String) -> Self {
46        let key = ApiKey::new(api_key);
47        Self {
48            inner: Arc::new(RustAnthropicProvider::with_base_url(key, base_url)),
49        }
50    }
51
52    /// Provider name ("anthropic").
53    #[napi(getter)]
54    #[must_use]
55    pub fn name(&self) -> String {
56        self.inner.name().to_string()
57    }
58
59    /// List available models.
60    ///
61    /// Returns an array of model IDs like "claude-3-5-sonnet-20241022".
62    #[napi]
63    pub async fn list_models(&self) -> Result<Vec<String>> {
64        self.inner
65            .list_models()
66            .await
67            .map_err(|e| OpenClawError::from_provider_error(e).into())
68    }
69
70    /// Create a completion (non-streaming).
71    ///
72    /// # Arguments
73    ///
74    /// * `request` - The completion request with model, messages, etc.
75    ///
76    /// # Returns
77    ///
78    /// The completion response with content, tool calls, and usage.
79    #[napi]
80    pub async fn complete(&self, request: JsCompletionRequest) -> Result<JsCompletionResponse> {
81        let rust_request = convert_request(request);
82        let response = self
83            .inner
84            .complete(rust_request)
85            .await
86            .map_err(OpenClawError::from_provider_error)?;
87        Ok(convert_response(response))
88    }
89
90    /// Create a streaming completion.
91    ///
92    /// The callback is called for each chunk received. Chunks have:
93    /// - `chunk_type`: "`content_block_delta`", "`message_stop`", etc.
94    /// - `delta`: Text content (for delta chunks)
95    /// - `stop_reason`: Why generation stopped (for final chunk)
96    ///
97    /// # Arguments
98    ///
99    /// * `request` - The completion request
100    /// * `callback` - Function called with (error, chunk) for each chunk
101    #[napi]
102    pub fn complete_stream(
103        &self,
104        request: JsCompletionRequest,
105        #[napi(ts_arg_type = "(err: Error | null, chunk: JsStreamChunk | null) => void")]
106        callback: JsFunction,
107    ) -> Result<()> {
108        use futures::StreamExt;
109        use napi::threadsafe_function::{
110            ErrorStrategy, ThreadsafeFunction, ThreadsafeFunctionCallMode,
111        };
112
113        // Create threadsafe callback
114        let tsfn: ThreadsafeFunction<JsStreamChunk, ErrorStrategy::CalleeHandled> =
115            callback.create_threadsafe_function(0, |ctx| Ok(vec![ctx.value]))?;
116
117        let inner = self.inner.clone();
118        let rust_request = convert_request(request);
119
120        // Spawn async streaming task
121        napi::tokio::spawn(async move {
122            match inner.complete_stream(rust_request).await {
123                Ok(mut stream) => {
124                    while let Some(chunk_result) = stream.next().await {
125                        match chunk_result {
126                            Ok(chunk) => {
127                                let js_chunk = convert_stream_chunk(
128                                    &chunk.chunk_type,
129                                    chunk.delta.as_deref(),
130                                    chunk.index,
131                                );
132                                let _ = tsfn
133                                    .call(Ok(js_chunk), ThreadsafeFunctionCallMode::NonBlocking);
134                            }
135                            Err(e) => {
136                                let err = OpenClawError::from_provider_error(e);
137                                let _ = tsfn.call(
138                                    Err(napi::Error::from_reason(
139                                        serde_json::to_string(&err).unwrap_or_default(),
140                                    )),
141                                    ThreadsafeFunctionCallMode::NonBlocking,
142                                );
143                                break;
144                            }
145                        }
146                    }
147                }
148                Err(e) => {
149                    let err = OpenClawError::from_provider_error(e);
150                    let _ = tsfn.call(
151                        Err(napi::Error::from_reason(
152                            serde_json::to_string(&err).unwrap_or_default(),
153                        )),
154                        ThreadsafeFunctionCallMode::NonBlocking,
155                    );
156                }
157            }
158        });
159
160        Ok(())
161    }
162}
163
164/// Convert `ChunkType` to `JsStreamChunk`.
165fn convert_stream_chunk(
166    chunk_type: &ChunkType,
167    delta: Option<&str>,
168    index: Option<usize>,
169) -> JsStreamChunk {
170    let (type_str, stop_reason) = match chunk_type {
171        ChunkType::MessageStart => ("message_start", None),
172        ChunkType::ContentBlockStart => ("content_block_start", None),
173        ChunkType::ContentBlockDelta => ("content_block_delta", None),
174        ChunkType::ContentBlockStop => ("content_block_stop", None),
175        ChunkType::MessageDelta => ("message_delta", None),
176        ChunkType::MessageStop => ("message_stop", None),
177    };
178
179    JsStreamChunk {
180        chunk_type: type_str.to_string(),
181        delta: delta.map(std::string::ToString::to_string),
182        index: index.map(|i| i as u32),
183        stop_reason,
184    }
185}