//! Request types for the OpenAI completion backend (`alith_interface/llms/api/openai/completion/req.rs`).
use crate::requests::{completion::*, stop_sequence::StopSequences};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
4
5#[derive(Clone, Serialize, Default, Debug, Deserialize)]
6pub struct OpenAICompletionRequest {
7 pub model: String,
10
11 pub messages: Vec<CompletionRequestMessage>, #[serde(skip_serializing_if = "Option::is_none")]
21 pub logit_bias: Option<HashMap<String, serde_json::Value>>, #[serde(skip_serializing_if = "Option::is_none")]
25 pub logprobs: Option<bool>,
26
27 #[serde(skip_serializing_if = "Option::is_none")]
29 pub top_logprobs: Option<u8>,
30
31 #[serde(skip_serializing_if = "Option::is_none")]
35 pub max_tokens: Option<u64>,
36
37 #[serde(skip_serializing_if = "Option::is_none")]
39 pub temperature: Option<f32>,
40
41 #[serde(skip_serializing_if = "Option::is_none")]
45 pub frequency_penalty: Option<f32>,
46
47 #[serde(skip_serializing_if = "Option::is_none")]
51 pub presence_penalty: Option<f32>,
52
53 #[serde(skip_serializing_if = "Option::is_none")]
55 pub stop: Option<Stop>,
56
57 #[serde(skip_serializing_if = "Option::is_none")]
59 pub top_p: Option<f32>,
60
61 #[serde(skip_serializing_if = "Option::is_none")]
63 pub tools: Option<Vec<OpenAIToolDefinition>>,
64
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub tool_choice: Option<String>,
68}
69
70#[derive(Clone, Serialize, Debug, Deserialize)]
71pub struct OpenAIToolDefinition {
72 pub r#type: String,
73 pub function: ToolDefinition,
74}
75
76impl OpenAICompletionRequest {
77 pub fn new(req: &CompletionRequest) -> crate::Result<Self, CompletionError> {
78 let mut messages = Vec::new();
79 match &req.prompt.get_built_prompt_messages() {
80 Ok(prompt_message) => {
81 for m in prompt_message {
82 messages.push(CompletionRequestMessage::new(m)?);
83 }
84 }
85 Err(e) => return Err(CompletionError::RequestBuilderError(e.to_string())),
86 }
87
88 Ok(OpenAICompletionRequest {
89 messages,
90 model: req.backend.model_id().to_owned(),
91 logit_bias: req.logit_bias.as_ref().and_then(|lb| lb.get_openai()),
92 frequency_penalty: req.config.frequency_penalty,
93 logprobs: None,
94 top_logprobs: None,
95 max_tokens: req.config.actual_request_tokens,
96 presence_penalty: Some(req.config.presence_penalty),
97 stop: Stop::new(&req.stop_sequences)?,
98 temperature: Some(req.config.temperature),
99 top_p: req.config.top_p,
100 tools: if !req.tools.is_empty() {
101 Some(
102 req.tools
103 .iter()
104 .map(|tool| OpenAIToolDefinition {
105 r#type: "function".to_string(),
106 function: tool.clone(),
107 })
108 .collect(),
109 )
110 } else {
111 None
112 },
113 tool_choice: if !req.tools.is_empty() {
114 Some("auto".to_string())
115 } else {
116 None
117 },
118 })
119 }
120}
121
122#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
123pub struct CompletionRequestMessage {
124 pub role: String,
125 pub content: String,
126}
127
128impl CompletionRequestMessage {
129 pub fn new(
130 message: &std::collections::HashMap<String, String>,
131 ) -> crate::Result<Self, CompletionError> {
132 let role = message
133 .get("role")
134 .ok_or_else(|| CompletionError::RequestBuilderError("Role not found".to_string()))?;
135 let content = message
136 .get("content")
137 .ok_or_else(|| CompletionError::RequestBuilderError("Content not found".to_string()))?;
138
139 match role.as_str() {
140 "system" | "user" | "assistant" => Ok(CompletionRequestMessage {
141 role: role.to_string(),
142 content: content.to_string(),
143 }),
144 _ => Err(CompletionError::RequestBuilderError(format!(
145 "Role {} not supported",
146 role
147 ))),
148 }
149 }
150}
151
152#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
153#[serde(untagged)]
154pub enum Stop {
155 String(String), StringArray(Vec<String>), }
158
159impl Stop {
160 pub fn new(stop_sequences: &StopSequences) -> crate::Result<Option<Self>, CompletionError> {
161 match stop_sequences.sequences.len() {
162 0 => Ok(None),
163 1..=4 => Ok(Some(Stop::StringArray(stop_sequences.to_vec()))),
164 _ => Err(CompletionError::RequestBuilderError(
165 "OpenAI Stop array cannot have more than 4 elements".to_string(),
166 )),
167 }
168 }
169}