// elizaos_plugin_copilot_proxy/types.rs

//! Type definitions for the Copilot Proxy plugin.

use serde::{Deserialize, Serialize};

5/// OpenAI-compatible chat message role.
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "lowercase")]
8pub enum ChatRole {
9    /// System message.
10    System,
11    /// User message.
12    User,
13    /// Assistant message.
14    Assistant,
15}
16
/// OpenAI-compatible chat message.
///
/// Used both in requests (see [`ChatCompletionRequest`]) and in responses
/// (see [`ChatCompletionChoice`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// The role of the message author.
    pub role: ChatRole,
    /// The content of the message.
    ///
    /// `Option` so a response with an absent or `null` `content` field still
    /// deserializes; when `None`, this serializes as JSON `null` (there is no
    /// skip attribute on this field).
    pub content: Option<String>,
}

26impl ChatMessage {
27    /// Create a new system message.
28    pub fn system(content: impl Into<String>) -> Self {
29        Self {
30            role: ChatRole::System,
31            content: Some(content.into()),
32        }
33    }
34
35    /// Create a new user message.
36    pub fn user(content: impl Into<String>) -> Self {
37        Self {
38            role: ChatRole::User,
39            content: Some(content.into()),
40        }
41    }
42
43    /// Create a new assistant message.
44    pub fn assistant(content: impl Into<String>) -> Self {
45        Self {
46            role: ChatRole::Assistant,
47            content: Some(content.into()),
48        }
49    }
50}
51
/// OpenAI-compatible chat completion request.
///
/// Serialize-only: this type is sent to the upstream API, never parsed from
/// it. All optional fields are omitted from the JSON body when `None`
/// (via `skip_serializing_if`), so the upstream defaults apply.
#[derive(Debug, Clone, Serialize)]
pub struct ChatCompletionRequest {
    /// The model to use.
    pub model: String,
    /// The messages to generate a completion for.
    pub messages: Vec<ChatMessage>,
    /// Maximum tokens to generate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Top-p sampling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Frequency penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    /// Presence penalty.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    /// Stop sequences.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Whether to stream the response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stream: Option<bool>,
}

82impl ChatCompletionRequest {
83    /// Create a new chat completion request.
84    pub fn new(model: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
85        Self {
86            model: model.into(),
87            messages,
88            max_tokens: None,
89            temperature: None,
90            top_p: None,
91            frequency_penalty: None,
92            presence_penalty: None,
93            stop: None,
94            stream: None,
95        }
96    }
97
98    /// Set the maximum tokens.
99    pub fn max_tokens(mut self, tokens: u32) -> Self {
100        self.max_tokens = Some(tokens);
101        self
102    }
103
104    /// Set the temperature.
105    pub fn temperature(mut self, temp: f32) -> Self {
106        self.temperature = Some(temp);
107        self
108    }
109
110    /// Set top-p sampling.
111    pub fn top_p(mut self, p: f32) -> Self {
112        self.top_p = Some(p);
113        self
114    }
115
116    /// Set frequency penalty.
117    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
118        self.frequency_penalty = Some(penalty);
119        self
120    }
121
122    /// Set presence penalty.
123    pub fn presence_penalty(mut self, penalty: f32) -> Self {
124        self.presence_penalty = Some(penalty);
125        self
126    }
127
128    /// Set stop sequences.
129    pub fn stop(mut self, sequences: Vec<String>) -> Self {
130        self.stop = Some(sequences);
131        self
132    }
133}
134
/// OpenAI-compatible chat completion choice.
///
/// One candidate completion within a [`ChatCompletionResponse`].
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChoice {
    /// The index of this choice.
    pub index: u32,
    /// The generated message.
    pub message: ChatMessage,
    /// The reason the generation stopped.
    ///
    /// Kept as a free-form `String` rather than an enum; presumably values
    /// like `"stop"` or `"length"` per the OpenAI API — not validated here.
    pub finish_reason: Option<String>,
}

/// Token usage statistics.
///
/// Reported by the upstream API alongside a completion; counts are not
/// recomputed locally.
#[derive(Debug, Clone, Deserialize)]
pub struct TokenUsage {
    /// Tokens used in the prompt.
    pub prompt_tokens: u32,
    /// Tokens used in the completion.
    pub completion_tokens: u32,
    /// Total tokens used.
    pub total_tokens: u32,
}

/// OpenAI-compatible chat completion response.
///
/// Deserialize-only: parsed from the upstream API reply, never sent.
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
    /// Unique identifier for this completion.
    pub id: String,
    /// The object type (always "chat.completion").
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    /// The model used.
    pub model: String,
    /// The generated choices.
    pub choices: Vec<ChatCompletionChoice>,
    /// Token usage statistics.
    ///
    /// `Option` because the upstream may omit the `usage` object.
    pub usage: Option<TokenUsage>,
}

/// Parameters for text generation.
///
/// Plugin-level input that callers build (see the builder methods on this
/// type) before it is translated into a [`ChatCompletionRequest`].
// NOTE(review): unlike `ChatCompletionRequest`, this has no `top_p` field.
// Adding one would break existing struct-literal construction; confirm
// whether top-p should be exposed at this layer before changing.
#[derive(Debug, Clone)]
pub struct TextGenerationParams {
    /// The prompt to generate text for.
    pub prompt: String,
    /// Optional system message.
    pub system: Option<String>,
    /// Optional model override.
    pub model: Option<String>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Maximum tokens to generate.
    pub max_tokens: Option<u32>,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Stop sequences.
    pub stop: Option<Vec<String>>,
}

195impl TextGenerationParams {
196    /// Create new text generation parameters.
197    pub fn new(prompt: impl Into<String>) -> Self {
198        Self {
199            prompt: prompt.into(),
200            system: None,
201            model: None,
202            temperature: None,
203            max_tokens: None,
204            frequency_penalty: None,
205            presence_penalty: None,
206            stop: None,
207        }
208    }
209
210    /// Set the system message.
211    pub fn system(mut self, system: impl Into<String>) -> Self {
212        self.system = Some(system.into());
213        self
214    }
215
216    /// Set the model.
217    pub fn model(mut self, model: impl Into<String>) -> Self {
218        self.model = Some(model.into());
219        self
220    }
221
222    /// Set the temperature.
223    pub fn temperature(mut self, temp: f32) -> Self {
224        self.temperature = Some(temp);
225        self
226    }
227
228    /// Set the maximum tokens.
229    pub fn max_tokens(mut self, tokens: u32) -> Self {
230        self.max_tokens = Some(tokens);
231        self
232    }
233
234    /// Set the frequency penalty.
235    pub fn frequency_penalty(mut self, penalty: f32) -> Self {
236        self.frequency_penalty = Some(penalty);
237        self
238    }
239
240    /// Set the presence penalty.
241    pub fn presence_penalty(mut self, penalty: f32) -> Self {
242        self.presence_penalty = Some(penalty);
243        self
244    }
245
246    /// Set stop sequences.
247    pub fn stop(mut self, sequences: Vec<String>) -> Self {
248        self.stop = Some(sequences);
249        self
250    }
251}
252
/// Result of text generation.
///
/// Plugin-level output extracted from a [`ChatCompletionResponse`].
#[derive(Debug, Clone)]
pub struct TextGenerationResult {
    /// The generated text.
    pub text: String,
    /// Token usage statistics.
    ///
    /// `None` when the upstream response carried no usage object.
    pub usage: Option<TokenUsage>,
}

/// Model information from the API.
///
/// One entry in a [`ModelsResponse`] listing.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelInfo {
    /// The model ID.
    pub id: String,
    /// The object type.
    pub object: String,
    /// Unix timestamp of creation.
    pub created: u64,
    /// The owner of the model.
    pub owned_by: String,
}

/// Response from listing models.
///
/// Deserialize-only wrapper around the upstream `/models` listing.
#[derive(Debug, Clone, Deserialize)]
pub struct ModelsResponse {
    /// The object type.
    pub object: String,
    /// The list of models.
    pub data: Vec<ModelInfo>,
}