Skip to main content

oai_sdk/
openai.rs

1// Copyright 2026 Cloudflavor GmbH
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! OpenAI-compatible API endpoints for Ollama
16//!
17//! This module provides OpenAI-compatible endpoints that work with
18//! standard OpenAI client libraries.
19
20use crate::chat::Tool;
21use crate::client::ModelClient;
22use crate::error::{OllamaError, Result};
23use serde::{Deserialize, Serialize};
24use std::collections::HashMap;
25
26// Chat Completions Types
27
28/// A chat message
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ChatMessage {
31    pub role: String,
32    pub content: serde_json::Value,
33}
34
35impl ChatMessage {
36    pub fn new(role: impl Into<String>, content: impl Into<serde_json::Value>) -> Self {
37        Self {
38            role: role.into(),
39            content: content.into(),
40        }
41    }
42
43    pub fn user(content: impl Into<serde_json::Value>) -> Self {
44        Self::new("user", content)
45    }
46
47    pub fn assistant(content: impl Into<serde_json::Value>) -> Self {
48        Self::new("assistant", content)
49    }
50
51    pub fn system(content: impl Into<serde_json::Value>) -> Self {
52        Self::new("system", content)
53    }
54}
55
56/// Request for chat completions
57#[derive(Debug, Clone, Serialize, Deserialize, Default)]
58pub struct ChatCompletionsRequest {
59    pub model: String,
60    pub messages: Vec<ChatMessage>,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub frequency_penalty: Option<f32>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub presence_penalty: Option<f32>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub response_format: Option<serde_json::Value>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub seed: Option<i32>,
69    #[serde(skip_serializing_if = "Option::is_none")]
70    pub stop: Option<Vec<String>>,
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub stream: Option<bool>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub stream_options: Option<StreamOptions>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub temperature: Option<f32>,
77    #[serde(skip_serializing_if = "Option::is_none")]
78    pub top_p: Option<f32>,
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub max_tokens: Option<u32>,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub tools: Option<Vec<Tool>>,
83    #[serde(skip_serializing_if = "Option::is_none")]
84    pub reasoning_effort: Option<String>,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub reasoning: Option<serde_json::Value>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub tool_choice: Option<serde_json::Value>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub logit_bias: Option<HashMap<String, f32>>,
91    #[serde(skip_serializing_if = "Option::is_none")]
92    pub user: Option<String>,
93    #[serde(skip_serializing_if = "Option::is_none")]
94    pub n: Option<u32>,
95}
96
97/// Stream options for chat completions
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct StreamOptions {
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub include_usage: Option<bool>,
102}
103
104/// Choice in chat completion response
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct Choice {
107    pub index: u32,
108    pub message: ChatMessage,
109    pub finish_reason: String,
110}
111
112/// Usage information
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct Usage {
115    pub prompt_tokens: u32,
116    pub completion_tokens: u32,
117    pub total_tokens: u32,
118}
119
120/// Response for chat completions
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct ChatCompletionsResponse {
123    pub id: String,
124    pub choices: Vec<Choice>,
125    pub created: u64,
126    pub model: String,
127    pub usage: Usage,
128}
129
130// Embeddings Types
131
132/// Input for embeddings
133#[derive(Debug, Clone, Serialize, Deserialize)]
134#[serde(untagged)]
135pub enum OpenAIEmbeddingsInput {
136    Single(String),
137    Multiple(Vec<String>),
138    Tokens(Vec<u32>),
139    TokenArrays(Vec<Vec<u32>>),
140}
141
142impl Default for OpenAIEmbeddingsInput {
143    fn default() -> Self {
144        Self::Single(String::new())
145    }
146}
147
148/// Request for embeddings
149#[derive(Debug, Clone, Serialize, Deserialize, Default)]
150pub struct OpenAIEmbeddingsRequest {
151    pub model: String,
152    pub input: OpenAIEmbeddingsInput,
153    #[serde(skip_serializing_if = "Option::is_none")]
154    pub encoding_format: Option<String>,
155    #[serde(skip_serializing_if = "Option::is_none")]
156    pub dimensions: Option<u32>,
157    #[serde(skip_serializing_if = "Option::is_none")]
158    pub user: Option<String>,
159}
160
161/// Embedding vector
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct OpenAIEmbedding {
164    pub embedding: Vec<f32>,
165    pub index: u32,
166    pub object: String,
167}
168
169/// Response for embeddings
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct OpenAIEmbeddingsResponse {
172    pub data: Vec<OpenAIEmbedding>,
173    pub model: String,
174    pub usage: Usage,
175}
176
177// Responses Types
178
179/// Request for responses endpoint
180#[derive(Debug, Clone, Serialize, Deserialize, Default)]
181pub struct ResponsesRequest {
182    pub model: String,
183    #[serde(skip_serializing_if = "Option::is_none")]
184    pub input: Option<String>,
185    #[serde(skip_serializing_if = "Option::is_none")]
186    pub instructions: Option<String>,
187    #[serde(skip_serializing_if = "Option::is_none")]
188    pub tools: Option<Vec<serde_json::Value>>,
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub stream: Option<bool>,
191    #[serde(skip_serializing_if = "Option::is_none")]
192    pub temperature: Option<f32>,
193    #[serde(skip_serializing_if = "Option::is_none")]
194    pub top_p: Option<f32>,
195    #[serde(skip_serializing_if = "Option::is_none")]
196    pub max_output_tokens: Option<u32>,
197    #[serde(skip_serializing_if = "Option::is_none")]
198    pub previous_response_id: Option<String>,
199    #[serde(skip_serializing_if = "Option::is_none")]
200    pub conversation: Option<Vec<serde_json::Value>>,
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub truncation: Option<String>,
203}
204
205/// Response for responses endpoint
206#[derive(Debug, Clone, Serialize, Deserialize)]
207pub struct ResponsesResponse {
208    pub output: String,
209    pub done: bool,
210    pub model: String,
211    pub done_reason: String,
212    pub tool_calls: Vec<serde_json::Value>,
213    pub prompt_evals: u32,
214    pub eval_count: u32,
215    pub total_duration: u64,
216    pub load_duration: u64,
217    pub prompt_eval_duration: u64,
218    pub eval_duration: u64,
219    pub output_eval_count: u32,
220    pub output_eval_duration: u64,
221}
222
223impl ModelClient {
224    /// Create chat completions using OpenAI-compatible API
225    ///
226    /// This endpoint is compatible with OpenAI client libraries.
227    /// Use base URL `http://localhost:11434/v1/` with any API key.
228    ///
229    /// # Example
230    ///
231    /// ```no_run
232    /// use oai_sdk::{ModelClient, ChatCompletionsRequest, ChatMessage};
233    ///
234    /// #[tokio::main]
235    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
236    ///     let client = ModelClient::builder()
237    ///         .base_url("http://localhost:11434")
238    ///         .build()?;
239    ///
240    ///     let request = ChatCompletionsRequest {
241    ///         model: "llama3.1:8b".to_string(),
242    ///         messages: vec![ChatMessage::user("Why is the sky blue?")],
243    ///         stream: Some(false),
244    ///         ..Default::default()
245    ///     };
246    ///
247    ///     let response = client.chat_completions(request).await?;
248    ///     println!("{}", response.choices[0].message.content);
249    ///
250    ///     Ok(())
251    /// }
252    /// ```
253    pub async fn chat_completions(
254        &self,
255        request: ChatCompletionsRequest,
256    ) -> Result<ChatCompletionsResponse> {
257        let url = self
258            .base_url
259            .join("v1/chat/completions")
260            .map_err(OllamaError::UrlError)?;
261        let response = self
262            .client
263            .post(url)
264            .json(&request)
265            .send()
266            .await
267            .map_err(OllamaError::RequestError)?;
268
269        self.handle_response(response, Some(&request.model)).await
270    }
271
272    /// Generate embeddings using OpenAI-compatible API
273    ///
274    /// This endpoint is compatible with OpenAI client libraries.
275    /// Use base URL `http://localhost:11434/v1/` with any API key.
276    ///
277    /// # Example
278    ///
279    /// ```no_run
280    /// use oai_sdk::{ModelClient, OpenAIEmbeddingsRequest, OpenAIEmbeddingsInput};
281    ///
282    /// #[tokio::main]
283    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
284    ///     let client = ModelClient::builder()
285    ///         .base_url("http://localhost:11434")
286    ///         .build()?;
287    ///
288    ///     let request = OpenAIEmbeddingsRequest {
289    ///         model: "llama3.1:8b".to_string(),
290    ///         input: OpenAIEmbeddingsInput::Single("Why is the sky blue?".to_string()),
291    ///         encoding_format: Some("float".to_string()),
292    ///         ..Default::default()
293    ///     };
294    ///
295    ///     let response = client.openai_embeddings(request).await?;
296    ///     println!("Embeddings: {:?}", response.data[0].embedding);
297    ///
298    ///     Ok(())
299    /// }
300    /// ```
301    pub async fn openai_embeddings(
302        &self,
303        request: OpenAIEmbeddingsRequest,
304    ) -> Result<OpenAIEmbeddingsResponse> {
305        let url = self
306            .base_url
307            .join("v1/embeddings")
308            .map_err(OllamaError::UrlError)?;
309        let response = self
310            .client
311            .post(url)
312            .json(&request)
313            .send()
314            .await
315            .map_err(OllamaError::RequestError)?;
316
317        self.handle_response(response, Some(&request.model)).await
318    }
319
320    /// Generate responses using OpenAI-compatible API
321    ///
322    /// This endpoint is compatible with OpenAI client libraries.
323    /// Use base URL `http://localhost:11434/v1/` with any API key.
324    ///
325    /// # Example
326    ///
327    /// ```no_run
328    /// use oai_sdk::{ModelClient, ResponsesRequest};
329    ///
330    /// #[tokio::main]
331    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
332    ///     let client = ModelClient::builder()
333    ///         .base_url("http://localhost:11434")
334    ///         .build()?;
335    ///
336    ///     let request = ResponsesRequest {
337    ///         model: "llama3.1:8b".to_string(),
338    ///         input: Some("Why is the sky blue?".to_string()),
339    ///         stream: Some(false),
340    ///         ..Default::default()
341    ///     };
342    ///
343    ///     let response = client.responses(request).await?;
344    ///     println!("{}", response.output);
345    ///
346    ///     Ok(())
347    /// }
348    /// ```
349    pub async fn responses(&self, request: ResponsesRequest) -> Result<ResponsesResponse> {
350        let url = self
351            .base_url
352            .join("v1/responses")
353            .map_err(OllamaError::UrlError)?;
354        let response = self
355            .client
356            .post(url)
357            .json(&request)
358            .send()
359            .await
360            .map_err(OllamaError::RequestError)?;
361
362        self.handle_response(response, Some(&request.model)).await
363    }
364}