1use serde::{Deserialize, Serialize};
4
5use crate::sampling::SamplingParams;
6use crate::types::{Message, ModelId, RequestId};
7
/// The prompt of a [`GenerateRequest`], in one of three accepted shapes.
///
/// Deserialized as an untagged enum: serde tries the variants in declaration
/// order and keeps the first one that matches, so the order below is
/// load-bearing. A JSON string becomes [`PromptInput::Text`]; an array is tried
/// as [`PromptInput::Messages`] first and only then as [`PromptInput::Tokens`].
/// Consequently an empty array `[]` always deserializes as `Messages(vec![])`.
// NOTE(review): an array of integers only reaches `Tokens` if it fails to parse
// as `Vec<Message>` — presumably true as long as `Message` cannot deserialize
// from a number; confirm against `Message`'s serde impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PromptInput {
    /// A raw text prompt.
    Text(String),
    /// A chat-style conversation (what [`GenerateRequest::chat`] produces).
    Messages(Vec<Message>),
    /// A list of `u32` values.
    // presumably pre-tokenized token ids — confirm against the tokenizer.
    Tokens(Vec<u32>),
}
19
20impl From<String> for PromptInput {
21 fn from(s: String) -> Self {
22 Self::Text(s)
23 }
24}
25
26impl From<&str> for PromptInput {
27 fn from(s: &str) -> Self {
28 Self::Text(s.to_string())
29 }
30}
31
32impl From<Vec<Message>> for PromptInput {
33 fn from(messages: Vec<Message>) -> Self {
34 Self::Messages(messages)
35 }
36}
37
/// A text-generation request.
///
/// Only `prompt` is required when deserializing; every other field falls back
/// to a default. In code, build one with [`GenerateRequest::new`] or
/// [`GenerateRequest::chat`] plus the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// Identifier for this request; `RequestId::default()` when absent from input.
    // NOTE(review): `GenerateRequest::new` uses `RequestId::new()` instead of
    // `RequestId::default()`; confirm the two produce equivalent (e.g. freshly
    // generated) ids, otherwise deserialized requests may share an id.
    #[serde(default)]
    pub request_id: RequestId,

    /// Model to run; `None` when unspecified.
    // presumably the server then selects a default model — confirm.
    #[serde(default)]
    pub model: Option<ModelId>,

    /// The prompt to generate from. Required when deserializing.
    pub prompt: PromptInput,

    /// Sampling settings; `SamplingParams::default()` when absent.
    #[serde(default)]
    pub sampling: SamplingParams,

    /// Streaming flag; `false` when absent.
    // presumably requests incremental output — confirm with the handler.
    #[serde(default)]
    pub stream: bool,

    /// Echo flag; `false` when absent.
    // presumably includes the prompt in the output — confirm with the handler.
    #[serde(default)]
    pub echo: bool,

    /// Number of completions requested; `1` when absent (via `default_n`).
    #[serde(default = "default_n")]
    pub n: u32,

    /// Number of top log-probabilities requested, or `None` for no logprobs.
    // Interpretation inferred from `with_logprobs(top_logprobs)`; confirm.
    #[serde(default)]
    pub logprobs: Option<u32>,
}
72
/// Serde default for `GenerateRequest::n`: a single completion.
const fn default_n() -> u32 {
    1_u32
}
76
77impl GenerateRequest {
78 #[must_use]
80 pub fn new(prompt: impl Into<PromptInput>) -> Self {
81 Self {
82 request_id: RequestId::new(),
83 model: None,
84 prompt: prompt.into(),
85 sampling: SamplingParams::default(),
86 stream: false,
87 echo: false,
88 n: 1,
89 logprobs: None,
90 }
91 }
92
93 #[must_use]
95 pub fn chat(messages: Vec<Message>) -> Self {
96 Self::new(PromptInput::Messages(messages))
97 }
98
99 #[must_use]
101 pub fn with_model(mut self, model: impl Into<ModelId>) -> Self {
102 self.model = Some(model.into());
103 self
104 }
105
106 #[must_use]
108 pub fn with_sampling(mut self, sampling: SamplingParams) -> Self {
109 self.sampling = sampling;
110 self
111 }
112
113 #[must_use]
115 pub fn with_stream(mut self) -> Self {
116 self.stream = true;
117 self
118 }
119
120 #[must_use]
122 pub fn with_n(mut self, n: u32) -> Self {
123 self.n = n;
124 self
125 }
126
127 #[must_use]
129 pub fn with_logprobs(mut self, top_logprobs: u32) -> Self {
130 self.logprobs = Some(top_logprobs);
131 self
132 }
133}
134
/// An embedding request.
///
/// Only `input` is required when deserializing; every other field falls back
/// to its default. In code, build one with [`EmbedRequest::new`] plus the
/// `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedRequest {
    /// Identifier for this request; `RequestId::default()` when absent from input.
    // NOTE(review): `EmbedRequest::new` uses `RequestId::new()` instead; confirm
    // the two produce equivalent ids.
    #[serde(default)]
    pub request_id: RequestId,

    /// Model to run; `None` when unspecified.
    #[serde(default)]
    pub model: Option<ModelId>,

    /// The text or texts to embed. Required when deserializing.
    pub input: EmbedInput,

    /// Output encoding; `EncodingFormat::default()` (i.e. `Float`) when absent.
    #[serde(default)]
    pub encoding_format: EncodingFormat,

    /// Requested embedding dimensionality, or `None` to leave it unspecified.
    // presumably the model's native dimensionality is used when `None` — confirm.
    #[serde(default)]
    pub dimensions: Option<u32>,
}
157
/// The input of an [`EmbedRequest`]: one text or a batch of texts.
///
/// Deserialized as an untagged enum: a JSON string becomes
/// [`EmbedInput::Single`] and an array of strings becomes
/// [`EmbedInput::Multiple`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbedInput {
    /// A single text.
    Single(String),
    /// A batch of texts, kept in the given order.
    Multiple(Vec<String>),
}
167
168impl From<String> for EmbedInput {
169 fn from(s: String) -> Self {
170 Self::Single(s)
171 }
172}
173
174impl From<&str> for EmbedInput {
175 fn from(s: &str) -> Self {
176 Self::Single(s.to_string())
177 }
178}
179
180impl From<Vec<String>> for EmbedInput {
181 fn from(v: Vec<String>) -> Self {
182 Self::Multiple(v)
183 }
184}
185
186impl EmbedInput {
187 #[must_use]
189 pub fn as_texts(&self) -> Vec<&str> {
190 match self {
191 Self::Single(s) => vec![s.as_str()],
192 Self::Multiple(v) => v.iter().map(String::as_str).collect(),
193 }
194 }
195}
196
/// How embedding vectors are encoded in the response.
///
/// Serialized in snake_case, i.e. as `"float"` or `"base64"`. The default is
/// [`EncodingFormat::Float`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EncodingFormat {
    /// Serialized as `"float"`; the default.
    // presumably a plain array of floats — confirm with the response encoder.
    #[default]
    Float,
    /// Serialized as `"base64"`.
    // presumably base64-encoded raw vector bytes — confirm with the response encoder.
    Base64,
}
207
208impl EmbedRequest {
209 #[must_use]
211 pub fn new(input: impl Into<EmbedInput>) -> Self {
212 Self {
213 request_id: RequestId::new(),
214 model: None,
215 input: input.into(),
216 encoding_format: EncodingFormat::Float,
217 dimensions: None,
218 }
219 }
220
221 #[must_use]
223 pub fn with_model(mut self, model: impl Into<ModelId>) -> Self {
224 self.model = Some(model.into());
225 self
226 }
227
228 #[must_use]
230 pub fn with_encoding_format(mut self, format: EncodingFormat) -> Self {
231 self.encoding_format = format;
232 self
233 }
234
235 #[must_use]
237 pub fn with_dimensions(mut self, dims: u32) -> Self {
238 self.dimensions = Some(dims);
239 self
240 }
241}