infernum_core/
request.rs

1//! Request types for inference operations.
2
3use serde::{Deserialize, Serialize};
4
5use crate::sampling::SamplingParams;
6use crate::types::{Message, ModelId, RequestId};
7
8/// Input format for prompts.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10#[serde(untagged)]
11pub enum PromptInput {
12    /// Raw text prompt.
13    Text(String),
14    /// Chat messages (will be formatted according to model's chat template).
15    Messages(Vec<Message>),
16    /// Pre-tokenized input.
17    Tokens(Vec<u32>),
18}
19
20impl From<String> for PromptInput {
21    fn from(s: String) -> Self {
22        Self::Text(s)
23    }
24}
25
26impl From<&str> for PromptInput {
27    fn from(s: &str) -> Self {
28        Self::Text(s.to_string())
29    }
30}
31
32impl From<Vec<Message>> for PromptInput {
33    fn from(messages: Vec<Message>) -> Self {
34        Self::Messages(messages)
35    }
36}
37
38/// Request for text generation.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct GenerateRequest {
41    /// Unique request identifier.
42    #[serde(default)]
43    pub request_id: RequestId,
44
45    /// Model to use for generation.
46    #[serde(default)]
47    pub model: Option<ModelId>,
48
49    /// Input prompt.
50    pub prompt: PromptInput,
51
52    /// Sampling parameters.
53    #[serde(default)]
54    pub sampling: SamplingParams,
55
56    /// Whether to stream the response.
57    #[serde(default)]
58    pub stream: bool,
59
60    /// Whether to echo the prompt in the response.
61    #[serde(default)]
62    pub echo: bool,
63
64    /// Number of completions to generate.
65    #[serde(default = "default_n")]
66    pub n: u32,
67
68    /// Include log probabilities for top tokens.
69    #[serde(default)]
70    pub logprobs: Option<u32>,
71}
72
/// Serde fallback for `GenerateRequest::n`: a single completion.
fn default_n() -> u32 {
    const SINGLE_COMPLETION: u32 = 1;
    SINGLE_COMPLETION
}
76
77impl GenerateRequest {
78    /// Creates a new generation request with the given prompt.
79    #[must_use]
80    pub fn new(prompt: impl Into<PromptInput>) -> Self {
81        Self {
82            request_id: RequestId::new(),
83            model: None,
84            prompt: prompt.into(),
85            sampling: SamplingParams::default(),
86            stream: false,
87            echo: false,
88            n: 1,
89            logprobs: None,
90        }
91    }
92
93    /// Creates a chat completion request.
94    #[must_use]
95    pub fn chat(messages: Vec<Message>) -> Self {
96        Self::new(PromptInput::Messages(messages))
97    }
98
99    /// Sets the model to use.
100    #[must_use]
101    pub fn with_model(mut self, model: impl Into<ModelId>) -> Self {
102        self.model = Some(model.into());
103        self
104    }
105
106    /// Sets the sampling parameters.
107    #[must_use]
108    pub fn with_sampling(mut self, sampling: SamplingParams) -> Self {
109        self.sampling = sampling;
110        self
111    }
112
113    /// Enables streaming.
114    #[must_use]
115    pub fn with_stream(mut self) -> Self {
116        self.stream = true;
117        self
118    }
119
120    /// Sets the number of completions.
121    #[must_use]
122    pub fn with_n(mut self, n: u32) -> Self {
123        self.n = n;
124        self
125    }
126
127    /// Enables log probabilities.
128    #[must_use]
129    pub fn with_logprobs(mut self, top_logprobs: u32) -> Self {
130        self.logprobs = Some(top_logprobs);
131        self
132    }
133}
134
135/// Request for generating embeddings.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct EmbedRequest {
138    /// Unique request identifier.
139    #[serde(default)]
140    pub request_id: RequestId,
141
142    /// Model to use for embeddings.
143    #[serde(default)]
144    pub model: Option<ModelId>,
145
146    /// Input texts to embed.
147    pub input: EmbedInput,
148
149    /// Encoding format for the embeddings.
150    #[serde(default)]
151    pub encoding_format: EncodingFormat,
152
153    /// Dimensionality for the embeddings (if model supports it).
154    #[serde(default)]
155    pub dimensions: Option<u32>,
156}
157
158/// Input format for embeddings.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160#[serde(untagged)]
161pub enum EmbedInput {
162    /// Single text.
163    Single(String),
164    /// Multiple texts.
165    Multiple(Vec<String>),
166}
167
168impl From<String> for EmbedInput {
169    fn from(s: String) -> Self {
170        Self::Single(s)
171    }
172}
173
174impl From<&str> for EmbedInput {
175    fn from(s: &str) -> Self {
176        Self::Single(s.to_string())
177    }
178}
179
180impl From<Vec<String>> for EmbedInput {
181    fn from(v: Vec<String>) -> Self {
182        Self::Multiple(v)
183    }
184}
185
186impl EmbedInput {
187    /// Returns the inputs as a slice of strings.
188    #[must_use]
189    pub fn as_texts(&self) -> Vec<&str> {
190        match self {
191            Self::Single(s) => vec![s.as_str()],
192            Self::Multiple(v) => v.iter().map(String::as_str).collect(),
193        }
194    }
195}
196
197/// Encoding format for embeddings.
198#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
199#[serde(rename_all = "snake_case")]
200pub enum EncodingFormat {
201    /// 32-bit floating point.
202    #[default]
203    Float,
204    /// Base64 encoded binary.
205    Base64,
206}
207
208impl EmbedRequest {
209    /// Creates a new embedding request.
210    #[must_use]
211    pub fn new(input: impl Into<EmbedInput>) -> Self {
212        Self {
213            request_id: RequestId::new(),
214            model: None,
215            input: input.into(),
216            encoding_format: EncodingFormat::Float,
217            dimensions: None,
218        }
219    }
220
221    /// Sets the model to use.
222    #[must_use]
223    pub fn with_model(mut self, model: impl Into<ModelId>) -> Self {
224        self.model = Some(model.into());
225        self
226    }
227
228    /// Sets the encoding format.
229    #[must_use]
230    pub fn with_encoding_format(mut self, format: EncodingFormat) -> Self {
231        self.encoding_format = format;
232        self
233    }
234
235    /// Sets the output dimensions.
236    #[must_use]
237    pub fn with_dimensions(mut self, dims: u32) -> Self {
238        self.dimensions = Some(dims);
239        self
240    }
241}