1use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
18pub struct CompleteRequest {
19 pub argument: CompletionArgument,
21}
22
23#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
25pub struct CompletionArgument {
26 pub messages: Vec<SamplingMessage>,
28
29 #[serde(skip_serializing_if = "Option::is_none")]
31 pub model_preferences: Option<ModelPreferences>,
32
33 #[serde(skip_serializing_if = "Option::is_none")]
35 pub system_prompt: Option<String>,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
39 pub include_context: Option<String>,
40
41 #[serde(skip_serializing_if = "Option::is_none")]
43 pub temperature: Option<f64>,
44
45 #[serde(skip_serializing_if = "Option::is_none")]
47 pub max_tokens: Option<i32>,
48
49 #[serde(skip_serializing_if = "Option::is_none")]
51 pub stop_sequences: Option<Vec<String>>,
52
53 #[serde(flatten)]
55 pub metadata: HashMap<String, Value>,
56}
57
58impl CompletionArgument {
59 pub fn new(messages: Vec<SamplingMessage>) -> Self {
61 Self {
62 messages,
63 model_preferences: None,
64 system_prompt: None,
65 include_context: None,
66 temperature: None,
67 max_tokens: None,
68 stop_sequences: None,
69 metadata: HashMap::new(),
70 }
71 }
72
73 pub fn with_model_preferences(mut self, preferences: ModelPreferences) -> Self {
75 self.model_preferences = Some(preferences);
76 self
77 }
78
79 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81 self.system_prompt = Some(prompt.into());
82 self
83 }
84
85 pub fn with_temperature(mut self, temperature: f64) -> Self {
87 self.temperature = Some(temperature);
88 self
89 }
90
91 pub fn with_max_tokens(mut self, max_tokens: i32) -> Self {
93 self.max_tokens = Some(max_tokens);
94 self
95 }
96
97 pub fn with_stop_sequences(mut self, sequences: Vec<String>) -> Self {
99 self.stop_sequences = Some(sequences);
100 self
101 }
102
103 pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
105 self.metadata.insert(key.into(), value);
106 self
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub struct ModelPreferences {
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub models: Option<Vec<String>>,
116
117 #[serde(skip_serializing_if = "Option::is_none")]
119 pub cost_priority: Option<CostPriority>,
120
121 #[serde(skip_serializing_if = "Option::is_none")]
123 pub speed_priority: Option<SpeedPriority>,
124
125 #[serde(skip_serializing_if = "Option::is_none")]
127 pub intelligence_priority: Option<IntelligencePriority>,
128}
129
130impl ModelPreferences {
131 pub fn new() -> Self {
133 Self {
134 models: None,
135 cost_priority: None,
136 speed_priority: None,
137 intelligence_priority: None,
138 }
139 }
140
141 pub fn with_models(mut self, models: Vec<String>) -> Self {
143 self.models = Some(models);
144 self
145 }
146
147 pub fn with_cost_priority(mut self, priority: CostPriority) -> Self {
149 self.cost_priority = Some(priority);
150 self
151 }
152
153 pub fn with_speed_priority(mut self, priority: SpeedPriority) -> Self {
155 self.speed_priority = Some(priority);
156 self
157 }
158
159 pub fn with_intelligence_priority(mut self, priority: IntelligencePriority) -> Self {
161 self.intelligence_priority = Some(priority);
162 self
163 }
164}
165
166impl Default for ModelPreferences {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
174#[serde(rename_all = "lowercase")]
175pub enum CostPriority {
176 Low,
178 Medium,
180 High,
182}
183
184#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
186#[serde(rename_all = "lowercase")]
187pub enum SpeedPriority {
188 Low,
190 Medium,
192 High,
194}
195
196#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
198#[serde(rename_all = "lowercase")]
199pub enum IntelligencePriority {
200 Low,
202 Medium,
204 High,
206}
207
208#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
210pub struct SamplingMessage {
211 pub role: MessageRole,
213
214 pub content: SamplingContent,
216}
217
218impl SamplingMessage {
219 pub fn new(role: MessageRole, content: SamplingContent) -> Self {
221 Self { role, content }
222 }
223
224 pub fn system(content: impl Into<String>) -> Self {
226 Self::new(MessageRole::System, SamplingContent::text(content))
227 }
228
229 pub fn user(content: impl Into<String>) -> Self {
231 Self::new(MessageRole::User, SamplingContent::text(content))
232 }
233
234 pub fn assistant(content: impl Into<String>) -> Self {
236 Self::new(MessageRole::Assistant, SamplingContent::text(content))
237 }
238}
239
240#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
242#[serde(rename_all = "lowercase")]
243pub enum MessageRole {
244 System,
246 User,
248 Assistant,
250}
251
252#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
254#[serde(tag = "type")]
255pub enum SamplingContent {
256 #[serde(rename = "text")]
258 Text {
259 text: String,
261 },
262
263 #[serde(rename = "image")]
265 Image {
266 data: String,
268
269 #[serde(rename = "mimeType")]
271 mime_type: String,
272 },
273}
274
275impl SamplingContent {
276 pub fn text(text: impl Into<String>) -> Self {
278 Self::Text { text: text.into() }
279 }
280
281 pub fn image(data: impl Into<String>, mime_type: impl Into<String>) -> Self {
283 Self::Image {
284 data: data.into(),
285 mime_type: mime_type.into(),
286 }
287 }
288}
289
290#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
292pub struct CompleteResponse {
293 pub completion: CompletionResult,
295
296 #[serde(skip_serializing_if = "Option::is_none")]
298 pub model: Option<String>,
299
300 #[serde(skip_serializing_if = "Option::is_none")]
302 pub stop_reason: Option<StopReason>,
303}
304
305#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
307#[serde(tag = "type")]
308pub enum CompletionResult {
309 #[serde(rename = "text")]
311 Text {
312 text: String,
314 },
315}
316
317impl CompletionResult {
318 pub fn text(text: impl Into<String>) -> Self {
320 Self::Text { text: text.into() }
321 }
322}
323
324#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
326#[serde(rename_all = "snake_case")]
327pub enum StopReason {
328 EndTurn,
330 MaxTokens,
332 StopSequence,
334 ToolUse,
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_completion_argument_creation() {
        let messages = vec![
            SamplingMessage::system("You are a helpful assistant"),
            SamplingMessage::user("Hello, how are you?"),
        ];

        let arg = CompletionArgument::new(messages)
            .with_temperature(0.7)
            .with_max_tokens(1000)
            .with_system_prompt("Be helpful")
            .with_metadata("priority", json!("high"));

        assert_eq!(arg.temperature, Some(0.7));
        assert_eq!(arg.max_tokens, Some(1000));
        assert_eq!(arg.system_prompt, Some("Be helpful".to_string()));
        assert_eq!(arg.metadata.get("priority"), Some(&json!("high")));
    }

    #[test]
    fn test_model_preferences() {
        let prefs = ModelPreferences::new()
            .with_models(vec!["gpt-4".to_string(), "claude-3".to_string()])
            .with_cost_priority(CostPriority::Medium)
            .with_speed_priority(SpeedPriority::High)
            .with_intelligence_priority(IntelligencePriority::High);

        assert_eq!(
            prefs.models,
            Some(vec!["gpt-4".to_string(), "claude-3".to_string()])
        );
        assert_eq!(prefs.cost_priority, Some(CostPriority::Medium));
        assert_eq!(prefs.speed_priority, Some(SpeedPriority::High));
        assert_eq!(
            prefs.intelligence_priority,
            Some(IntelligencePriority::High)
        );
    }

    #[test]
    fn test_sampling_message_creation() {
        let system_msg = SamplingMessage::system("You are helpful");
        let user_msg = SamplingMessage::user("Hello");
        let assistant_msg = SamplingMessage::assistant("Hi there!");

        assert_eq!(system_msg.role, MessageRole::System);
        assert_eq!(user_msg.role, MessageRole::User);
        assert_eq!(assistant_msg.role, MessageRole::Assistant);
    }

    #[test]
    fn test_sampling_content_text() {
        let content = SamplingContent::text("Hello world");
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "text");
        assert_eq!(json["text"], "Hello world");
    }

    #[test]
    fn test_sampling_content_image() {
        let content = SamplingContent::image("base64data", "image/png");
        let json = serde_json::to_value(&content).unwrap();
        assert_eq!(json["type"], "image");
        assert_eq!(json["data"], "base64data");
        assert_eq!(json["mimeType"], "image/png");
    }

    #[test]
    fn test_completion_result() {
        let result = CompletionResult::text("Generated response");
        let json = serde_json::to_value(&result).unwrap();
        assert_eq!(json["type"], "text");
        assert_eq!(json["text"], "Generated response");
    }

    #[test]
    fn test_priority_serialization() {
        let cost = CostPriority::Low;
        let speed = SpeedPriority::Medium;
        let intel = IntelligencePriority::High;

        assert_eq!(serde_json::to_string(&cost).unwrap(), "\"low\"");
        assert_eq!(serde_json::to_string(&speed).unwrap(), "\"medium\"");
        assert_eq!(serde_json::to_string(&intel).unwrap(), "\"high\"");
    }

    #[test]
    fn test_stop_reason_serialization() {
        let reasons = [
            StopReason::EndTurn,
            StopReason::MaxTokens,
            StopReason::StopSequence,
            StopReason::ToolUse,
        ];

        let expected = [
            "\"end_turn\"",
            "\"max_tokens\"",
            "\"stop_sequence\"",
            "\"tool_use\"",
        ];

        for (reason, expected) in reasons.iter().zip(expected.iter()) {
            assert_eq!(serde_json::to_string(reason).unwrap(), *expected);
        }
    }

    /// Unset options must not appear in the output, and metadata must be flattened
    /// into the top-level object rather than nested under a `metadata` key.
    #[test]
    fn test_completion_argument_omits_unset_fields_and_flattens_metadata() {
        let arg = CompletionArgument::new(vec![SamplingMessage::user("hi")])
            .with_metadata("trace_id", json!("abc"));

        let value = serde_json::to_value(&arg).unwrap();
        assert_eq!(
            value,
            json!({
                "messages": [
                    {"role": "user", "content": {"type": "text", "text": "hi"}}
                ],
                "trace_id": "abc"
            })
        );
    }

    /// Known keys populate the named fields; every other key is collected into
    /// `metadata` through the flattened map.
    #[test]
    fn test_completion_argument_deserialization_collects_unknown_fields() {
        let raw = json!({
            "messages": [
                {
                    "role": "assistant",
                    "content": {"type": "image", "data": "AAA", "mimeType": "image/jpeg"}
                }
            ],
            "temperature": 0.5,
            "max_tokens": 64,
            "custom": {"nested": true}
        });

        let arg: CompletionArgument = serde_json::from_value(raw).unwrap();

        assert_eq!(
            arg.messages,
            vec![SamplingMessage::new(
                MessageRole::Assistant,
                SamplingContent::image("AAA", "image/jpeg"),
            )]
        );
        assert_eq!(arg.temperature, Some(0.5));
        assert_eq!(arg.max_tokens, Some(64));
        assert_eq!(arg.system_prompt, None);
        assert_eq!(arg.metadata.len(), 1);
        assert_eq!(arg.metadata.get("custom"), Some(&json!({"nested": true})));
    }

    #[test]
    fn test_complete_response_round_trip() {
        let response = CompleteResponse {
            completion: CompletionResult::text("done"),
            model: Some("claude-3".to_string()),
            stop_reason: Some(StopReason::EndTurn),
        };

        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["stop_reason"], "end_turn");
        assert_eq!(value["model"], "claude-3");

        let decoded: CompleteResponse = serde_json::from_value(value).unwrap();
        assert_eq!(decoded, response);
    }

    #[test]
    fn test_unset_options_are_omitted() {
        let minimal = CompleteResponse {
            completion: CompletionResult::text("x"),
            model: None,
            stop_reason: None,
        };
        assert_eq!(
            serde_json::to_value(&minimal).unwrap(),
            json!({"completion": {"type": "text", "text": "x"}})
        );
        assert_eq!(
            serde_json::to_value(ModelPreferences::default()).unwrap(),
            json!({})
        );
    }
}