// ferrox_openai_api/models.rs

1use serde::{Deserialize, Serialize};
2
/// A single chat message in a conversation.
///
/// Used on both the request and response path: assistant messages may
/// carry `tool_calls`, and tool-result messages reference the originating
/// call via `tool_call_id`. Optional fields are omitted from the JSON
/// payload entirely when `None`.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Message {
    // Speaker role, e.g. "system", "user", "assistant", or "tool".
    pub role: String,
    // Text content; `None` when the message only carries tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    // Tool invocations requested by the assistant, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    // For tool-result messages: the id of the `ToolCall` being answered.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
}
13
/// Request body for a chat-completions call.
///
/// Only `model` and `messages` are required; every optional field is
/// skipped during serialization when `None`, so the wire payload stays
/// minimal. Construct via [`CompletionRequest::default`] and override
/// the fields you need.
#[derive(Debug, Serialize)]
pub struct CompletionRequest {
    /// ID of the model to use
    pub model: String,
    /// A list of messages comprising the conversation so far
    pub messages: Vec<Message>,
    /// What sampling temperature to use, between 0 and 2
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Tool choice - can be "none", "auto" or a specific tool
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<String>,
    /// An alternative to sampling with temperature
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// How many chat completion choices to generate for each input message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<i32>,
    /// Whether to stream back partial progress
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
    /// Up to 4 sequences where the API will stop generating further tokens
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// The maximum number of tokens to generate in the chat completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<i32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on
    /// whether they appear in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their
    /// existing frequency in the text so far
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Modify the likelihood of specified tokens appearing in the completion
    // NOTE(review): keys are token ids as strings; OpenAI documents bias
    // values as integers in [-100, 100] — confirm f32 is intended here.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logit_bias: Option<std::collections::HashMap<String, f32>>,
    /// A unique identifier representing your end-user
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Available tools/functions that the model can use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<Tool>>,
    /// Enable parallel tool calls
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
}
62
63impl Default for CompletionRequest {
64    fn default() -> Self {
65        Self {
66            model: Model::OpenAI(OpenAIModel::GPT35Turbo).as_str().to_string(),
67            messages: Vec::new(),
68            temperature: None,
69            tool_choice: None,
70            top_p: None,
71            n: None,
72            stream: None,
73            stop: None,
74            max_tokens: None,
75            presence_penalty: None,
76            frequency_penalty: None,
77            logit_bias: None,
78            user: None,
79            tools: None,
80            parallel_tool_calls: None,
81        }
82    }
83}
84
/// Top-level response body returned by the chat-completions endpoint.
///
/// NOTE(review): only a subset of the API response is modeled here
/// (e.g. `usage`, `created`, `model` are not captured); unknown JSON
/// fields are ignored during deserialization.
#[derive(Debug, Deserialize, Serialize)]
pub struct CompletionResponse {
    // Server-assigned identifier for this completion.
    pub id: String,
    // One entry per requested choice (see `CompletionRequest::n`).
    pub choices: Vec<Choice>,
}
90
/// One generated completion choice within a [`CompletionResponse`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Choice {
    // The assistant message produced for this choice.
    pub message: Message,
    // Why generation stopped, e.g. "stop", "length", or "tool_calls".
    pub finish_reason: String,
    // Position of this choice in the response's `choices` array.
    pub index: i32,
}
97
/// A tool invocation requested by the model in an assistant message.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    // Call id; echoed back via `Message::tool_call_id` when replying.
    pub id: String,
    // Serialized as "type" on the wire ("type" is a Rust keyword).
    #[serde(rename = "type")]
    pub tool_type: String,
    // The function name and raw JSON arguments the model chose.
    pub function: ToolDefinition,
}
105
/// The function portion of a [`ToolCall`].
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolDefinition {
    // Name of the function the model wants invoked.
    pub name: String,
    // Arguments as a raw JSON string — parse with serde_json before use;
    // the model may emit invalid JSON, so parsing can fail.
    pub arguments: String,
}
111
/// A model identifier, namespaced by provider.
#[derive(Debug, Clone)]
pub enum Model {
    // An OpenAI-hosted model.
    OpenAI(OpenAIModel),
    // An Anthropic-hosted model.
    Anthropic(AnthropicModel),
}
117
/// Known OpenAI model families. See [`OpenAIModel::as_str`] for the
/// wire-format identifier each variant maps to.
#[derive(Debug, Clone)]
pub enum OpenAIModel {
    GPT4,
    GPT4Turbo,
    // NOTE(review): maps to a "mini"-tier id — confirm against OpenAI's
    // current published model list.
    GPT4Mini,
    GPT4RealTimePreview,
    // "GPT40" means GPT-4o (the letter o), not "GPT-40" — maps to "gpt-4o".
    GPT40,
    GPT35Turbo,
}
127
/// Known Anthropic model families. See [`AnthropicModel::as_str`] for
/// the wire-format identifier each variant maps to.
#[derive(Debug, Clone)]
pub enum AnthropicModel {
    Claude3Opus,
    Claude3Sonnet,
}
133
134impl Model {
135    pub fn as_str(&self) -> &'static str {
136        match self {
137            Model::OpenAI(model) => model.as_str(),
138            Model::Anthropic(model) => model.as_str(),
139        }
140    }
141}
142
143impl OpenAIModel {
144    pub fn as_str(&self) -> &'static str {
145        match self {
146            OpenAIModel::GPT4 => "gpt-4",
147            OpenAIModel::GPT4Turbo => "gpt-4-turbo",
148            OpenAIModel::GPT4Mini => "gpt-4-mini",
149            OpenAIModel::GPT4RealTimePreview => "gpt-4-realtime-preview",
150            OpenAIModel::GPT40 => "gpt-4o",
151            OpenAIModel::GPT35Turbo => "gpt-3.5-turbo",
152        }
153    }
154}
155
156impl AnthropicModel {
157    pub fn as_str(&self) -> &'static str {
158        match self {
159            AnthropicModel::Claude3Opus => "claude-3-opus",
160            AnthropicModel::Claude3Sonnet => "claude-3-sonnet",
161        }
162    }
163}
164
/// Declares a callable function to the model (the request-side schema,
/// as opposed to [`ToolDefinition`], which is the model's invocation).
#[derive(Debug, Serialize, Clone)]
pub struct FunctionDefinition {
    // Function name the model will reference in tool calls.
    pub name: String,
    // Human-readable description the model uses to decide when to call it.
    pub description: String,
    pub parameters: serde_json::Value, // Using Value for flexibility with JSON Schema
}
171
/// A tool made available to the model via `CompletionRequest::tools`.
#[derive(Debug, Serialize, Clone)]
pub struct Tool {
    // Serialized as "type" on the wire ("type" is a Rust keyword).
    #[serde(rename = "type")]
    pub tool_type: String, // Usually "function"
    pub function: FunctionDefinition,
}