// infernum_server/openai.rs

1//! OpenAI-compatible API types.
2//!
3//! These types mirror the OpenAI API specification for drop-in compatibility.
4
5use serde::{Deserialize, Serialize};
6
7// === Chat Completions ===
8
/// Chat completion request (OpenAI-compatible).
///
/// Request body for `POST /v1/chat/completions`. Every sampling knob is
/// an `Option` with `#[serde(default)]`, so a field absent from the JSON
/// deserializes to `None` and the handler picks the effective default.
/// No range validation happens during deserialization.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionRequest {
    /// Model to use.
    pub model: String,
    /// Messages in the conversation.
    pub messages: Vec<ChatMessage>,
    /// Temperature for sampling (0.0 - 2.0). Not range-checked here;
    /// handlers should clamp or reject out-of-range values.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Top-p (nucleus) sampling.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Number of completions to generate.
    #[serde(default)]
    pub n: Option<u32>,
    /// Whether to stream the response.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences.
    /// NOTE(review): the OpenAI API also accepts a single bare string for
    /// `stop`; this field only accepts a JSON array — confirm clients comply
    /// or widen this to an untagged enum like [`EmbeddingInput`].
    #[serde(default)]
    pub stop: Option<Vec<String>>,
    /// Maximum tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Presence penalty (-2.0 to 2.0).
    #[serde(default)]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty (-2.0 to 2.0).
    #[serde(default)]
    pub frequency_penalty: Option<f32>,
    /// User identifier for abuse monitoring.
    #[serde(default)]
    pub user: Option<String>,
}
44
/// A chat message.
///
/// Used both for incoming request messages and for generated replies
/// (it is the payload of [`ChatChoice`]), hence it derives both
/// `Serialize` and `Deserialize`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Role (system, user, assistant, tool).
    pub role: String,
    /// Message content.
    pub content: String,
    /// Optional name for the sender; omitted from output JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
56
/// Chat completion response.
///
/// Response body for a non-streaming `POST /v1/chat/completions` call.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionResponse {
    /// Response ID.
    pub id: String,
    /// Object type; by convention always the literal "chat.completion".
    pub object: String,
    /// Creation timestamp (Unix epoch seconds).
    pub created: i64,
    /// Model used.
    pub model: String,
    /// Generated choices, one per requested completion.
    pub choices: Vec<ChatChoice>,
    /// Token usage statistics.
    pub usage: Usage,
}
73
/// A chat completion choice.
#[derive(Debug, Clone, Serialize)]
pub struct ChatChoice {
    /// Choice index, matching its position in the response `choices` array.
    pub index: u32,
    /// Generated message.
    pub message: ChatMessage,
    /// Finish reason (stop, length, tool_calls, content_filter).
    pub finish_reason: String,
}
84
/// Streaming chat completion chunk.
///
/// One incremental event in a streamed chat completion. Unlike
/// [`ChatCompletionResponse`], chunks carry no `usage` field.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionChunk {
    /// Response ID (shared by every chunk of one completion).
    pub id: String,
    /// Object type; by convention always "chat.completion.chunk".
    pub object: String,
    /// Creation timestamp (Unix epoch seconds).
    pub created: i64,
    /// Model used.
    pub model: String,
    /// Streaming choices.
    pub choices: Vec<ChatChunkChoice>,
}
99
/// A streaming chat choice.
#[derive(Debug, Clone, Serialize)]
pub struct ChatChunkChoice {
    /// Choice index.
    pub index: u32,
    /// Incremental content.
    pub delta: ChatDelta,
    /// Finish reason; `None` (and omitted from JSON) on every chunk
    /// except the final one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
111
/// Incremental chat content.
///
/// Both fields are skipped when `None`, so an empty delta serializes as
/// `{}`. `Default` gives that all-`None` empty delta.
#[derive(Debug, Clone, Serialize, Default)]
pub struct ChatDelta {
    /// Role; present only on the first chunk of a stream.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    /// Content fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
122
123// === Text Completions ===
124
/// Text completion request (OpenAI-compatible).
///
/// Request body for the legacy `POST /v1/completions` endpoint. All
/// optional fields use `#[serde(default)]`, so absent JSON fields
/// deserialize to `None`.
#[derive(Debug, Clone, Deserialize)]
pub struct CompletionRequest {
    /// Model to use.
    pub model: String,
    /// The prompt to complete.
    /// NOTE(review): the OpenAI API also accepts an array of prompts (or
    /// token arrays) here; this type only accepts a single string —
    /// confirm clients never send an array.
    pub prompt: String,
    /// Temperature for sampling.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Top-p sampling.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Number of completions.
    #[serde(default)]
    pub n: Option<u32>,
    /// Whether to stream.
    #[serde(default)]
    pub stream: Option<bool>,
    /// Stop sequences.
    #[serde(default)]
    pub stop: Option<Vec<String>>,
    /// Maximum tokens.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Include log probabilities for this many top tokens.
    #[serde(default)]
    pub logprobs: Option<u32>,
    /// Echo the prompt back in the completion.
    #[serde(default)]
    pub echo: Option<bool>,
    /// Suffix to append.
    #[serde(default)]
    pub suffix: Option<String>,
}
160
/// Text completion response.
///
/// Response body for a non-streaming `POST /v1/completions` call.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionResponse {
    /// Response ID.
    pub id: String,
    /// Object type; by convention always "text_completion".
    pub object: String,
    /// Creation timestamp (Unix epoch seconds).
    pub created: i64,
    /// Model used.
    pub model: String,
    /// Generated choices.
    pub choices: Vec<CompletionChoice>,
    /// Token usage.
    pub usage: Usage,
}
177
/// A text completion choice.
#[derive(Debug, Clone, Serialize)]
pub struct CompletionChoice {
    /// Generated text.
    pub text: String,
    /// Choice index.
    pub index: u32,
    /// Finish reason.
    pub finish_reason: String,
    /// Log probabilities; only present when the request set `logprobs`,
    /// otherwise omitted from the JSON entirely.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
}
191
/// Log probability information.
///
/// The four vectors are parallel: index `i` in each describes token `i`.
/// NOTE(review): the upstream API emits `null` for the first token's
/// logprob in some cases; plain `f32` here cannot represent that —
/// confirm whether `Vec<Option<f32>>` is needed.
#[derive(Debug, Clone, Serialize)]
pub struct LogProbs {
    /// Token strings.
    pub tokens: Vec<String>,
    /// Token log probabilities.
    pub token_logprobs: Vec<f32>,
    /// Top log probabilities, one map (token -> logprob) per position.
    pub top_logprobs: Vec<std::collections::HashMap<String, f32>>,
    /// Byte/character offsets of each token in the generated text.
    pub text_offset: Vec<u32>,
}
204
205// === Embeddings ===
206
/// Embedding request (OpenAI-compatible).
///
/// Request body for `POST /v1/embeddings`. `input` accepts either a
/// single string or an array of strings (see [`EmbeddingInput`]).
#[derive(Debug, Clone, Deserialize)]
pub struct EmbeddingRequest {
    /// Model to use.
    pub model: String,
    /// Input text(s) to embed.
    pub input: EmbeddingInput,
    /// Encoding format ("float" or "base64"); `None` when absent.
    #[serde(default)]
    pub encoding_format: Option<String>,
    /// Dimensions to truncate the embedding vector to.
    #[serde(default)]
    pub dimensions: Option<u32>,
}
221
/// Embedding input - single string or array.
///
/// `#[serde(untagged)]` tries each variant in declaration order: a JSON
/// string becomes `Single`, a JSON array of strings becomes `Multiple`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum EmbeddingInput {
    /// Single text input.
    Single(String),
    /// Multiple text inputs.
    Multiple(Vec<String>),
}
231
/// Embedding response.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingResponse {
    /// Object type; by convention always "list".
    pub object: String,
    /// Embedding data, one entry per input text.
    pub data: Vec<EmbeddingData>,
    /// Model used.
    pub model: String,
    /// Usage statistics.
    pub usage: EmbeddingUsage,
}
244
/// A single embedding result.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingData {
    /// Object type; by convention always "embedding".
    pub object: String,
    /// Index of the corresponding text in the request's input array.
    pub index: u32,
    /// The embedding vector.
    pub embedding: Vec<f32>,
}
255
/// Embedding usage statistics.
///
/// Embeddings generate no completion tokens, so unlike [`Usage`] this
/// only reports prompt and total counts.
#[derive(Debug, Clone, Serialize)]
pub struct EmbeddingUsage {
    /// Prompt tokens used.
    pub prompt_tokens: u32,
    /// Total tokens used.
    pub total_tokens: u32,
}
264
265// === Models ===
266
/// Models list response.
///
/// Response body for `GET /v1/models`.
#[derive(Debug, Clone, Serialize)]
pub struct ModelsResponse {
    /// Object type; by convention always "list".
    pub object: String,
    /// Available models.
    pub data: Vec<ModelObject>,
}
275
/// Model information.
#[derive(Debug, Clone, Serialize)]
pub struct ModelObject {
    /// Model ID.
    pub id: String,
    /// Object type; by convention always "model".
    pub object: String,
    /// Creation timestamp (Unix epoch seconds).
    pub created: i64,
    /// Owner (e.g., "openai", "infernum").
    pub owned_by: String,
}
288
289// === Common ===
290
/// Token usage statistics.
///
/// Shared by chat and text completion responses. Construct via
/// [`Usage::new`] so `total_tokens` stays consistent with the parts;
/// `Default` yields all-zero counts.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Usage {
    /// Tokens in the prompt.
    pub prompt_tokens: u32,
    /// Tokens generated.
    pub completion_tokens: u32,
    /// Total tokens (prompt + completion).
    pub total_tokens: u32,
}
301
302impl Usage {
303    /// Creates new usage statistics.
304    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
305        Self {
306            prompt_tokens,
307            completion_tokens,
308            total_tokens: prompt_tokens + completion_tokens,
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_request_deserialization() {
        let json = r#"{
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ],
            "temperature": 0.7,
            "max_tokens": 100
        }"#;

        let req: ChatCompletionRequest = serde_json::from_str(json).unwrap();
        assert_eq!(req.model, "gpt-4");
        assert_eq!(req.messages.len(), 2);
        assert_eq!(req.temperature, Some(0.7));
        assert_eq!(req.max_tokens, Some(100));
        // Fields absent from the JSON must default to None.
        assert_eq!(req.n, None);
        assert_eq!(req.stream, None);
        assert_eq!(req.user, None);
    }

    #[test]
    fn test_chat_response_serialization() {
        let response = ChatCompletionResponse {
            id: "chatcmpl-123".to_string(),
            object: "chat.completion".to_string(),
            created: 1677652288,
            model: "gpt-4".to_string(),
            choices: vec![ChatChoice {
                index: 0,
                message: ChatMessage {
                    role: "assistant".to_string(),
                    content: "Hello!".to_string(),
                    name: None,
                },
                finish_reason: "stop".to_string(),
            }],
            usage: Usage::new(10, 5),
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("chatcmpl-123"));
        assert!(json.contains("Hello!"));
        // `name: None` must be skipped, not serialized as null.
        assert!(!json.contains("\"name\""));
    }

    #[test]
    fn test_embedding_input_variants() {
        // Fix: `matches!` returns a bool that the original discarded, so
        // these checks never asserted anything. Wrap them in `assert!`.
        let json = r#"{"model": "text-embedding-3-small", "input": "Hello"}"#;
        let req: EmbeddingRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.input, EmbeddingInput::Single(_)));

        let json = r#"{"model": "text-embedding-3-small", "input": ["Hello", "World"]}"#;
        let req: EmbeddingRequest = serde_json::from_str(json).unwrap();
        assert!(matches!(req.input, EmbeddingInput::Multiple(_)));
    }

    #[test]
    fn test_usage() {
        let usage = Usage::new(100, 50);
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }
}