use std::{collections::HashMap, fmt::Display, sync::Arc};
use super::*;
use either::Either;
use image::DynamicImage;
use indexmap::IndexMap;
use serde_json::{json, Value};
/// Common interface over anything that can be turned into an engine request.
///
/// The `take_*` methods move their value out of the implementor (leaving a
/// default behind), so each should be called at most once per request.
pub trait RequestLike {
    /// Borrow the accumulated chat messages.
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>];
    /// Borrow any images attached to the request.
    fn images_ref(&self) -> &[DynamicImage];
    /// Move the messages (and any attached media) out as a `RequestMessage`.
    fn take_messages(&mut self) -> RequestMessage;
    /// Move out custom logits processors, if any were configured.
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>>;
    /// Move out adapter names, if any were configured.
    fn take_adapters(&mut self) -> Option<Vec<String>>;
    /// Whether logprobs should be returned with the response.
    fn return_logprobs(&self) -> bool;
    /// Whether web search is enabled; `None` means "not specified".
    fn enable_search(&self) -> Option<bool>;
    /// Move out the decoding constraint.
    fn take_constraint(&mut self) -> Constraint;
    /// Move out the tool definitions and tool-choice policy, if any.
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)>;
    /// Move out the sampling parameters.
    fn take_sampling_params(&mut self) -> SamplingParams;
    /// Move out the web-search options, if any.
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions>;
    /// Whether over-long input sequences may be truncated; defaults to `false`.
    fn truncate_sequence(&self) -> bool {
        false
    }
    /// Hook to rewrite media placeholders once the model category is known.
    /// The default is a no-op, suitable for text-only request types.
    fn resolve_pending_prefixes(&mut self, _category: &ModelCategory) {}
}
/// Builder for a plain-text chat conversation.
#[derive(Debug, Clone, PartialEq)]
pub struct TextMessages {
    /// Ordered messages; each map holds "role" and "content" entries.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// Forwarded verbatim into `RequestMessage::Chat`; `None` leaves the
    /// downstream default in effect.
    enable_thinking: Option<bool>,
}
impl From<TextMessages> for Vec<IndexMap<String, MessageContent>> {
fn from(value: TextMessages) -> Self {
value.messages
}
}
/// Role tag attached to a chat message.
#[derive(Debug, Clone, PartialEq)]
pub enum TextMessageRole {
    User,
    Assistant,
    System,
    Tool,
    Custom(String),
}
impl Display for TextMessageRole {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Map each built-in variant to its wire-format name; custom roles
        // pass their string through unchanged.
        let role = match self {
            Self::User => "user",
            Self::Assistant => "assistant",
            Self::System => "system",
            Self::Tool => "tool",
            Self::Custom(name) => name.as_str(),
        };
        f.write_str(role)
    }
}
impl Default for TextMessages {
fn default() -> Self {
Self::new()
}
}
impl TextMessages {
    /// Create an empty conversation.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            enable_thinking: None,
        }
    }
    /// Append one message with the given role and text content.
    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
        let mut entry = IndexMap::new();
        entry.insert("role".to_string(), Either::Left(role.to_string()));
        entry.insert("content".to_string(), Either::Left(text.to_string()));
        self.messages.push(entry);
        self
    }
    /// Drop all messages; the `enable_thinking` setting is retained.
    pub fn clear(mut self) -> Self {
        self.messages.clear();
        self
    }
    /// Explicitly enable or disable model "thinking" for this conversation.
    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
        self.enable_thinking = Some(enable_thinking);
        self
    }
}
impl RequestLike for TextMessages {
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
        &self.messages
    }
    /// Text-only conversations never carry images.
    fn images_ref(&self) -> &[DynamicImage] {
        &[]
    }
    fn take_messages(&mut self) -> RequestMessage {
        RequestMessage::Chat {
            messages: std::mem::take(&mut self.messages),
            enable_thinking: self.enable_thinking,
            reasoning_effort: None,
        }
    }
    // None of the remaining request options are configurable on plain text
    // messages, so every accessor below reports "unset"/defaults.
    fn enable_search(&self) -> Option<bool> {
        None
    }
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
        None
    }
    fn take_adapters(&mut self) -> Option<Vec<String>> {
        None
    }
    fn return_logprobs(&self) -> bool {
        false
    }
    fn take_constraint(&mut self) -> Constraint {
        Constraint::None
    }
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
        None
    }
    fn take_sampling_params(&mut self) -> SamplingParams {
        SamplingParams::deterministic()
    }
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
        None
    }
}
impl From<TextMessages> for MultimodalMessages {
fn from(text: TextMessages) -> Self {
Self {
messages: text.messages,
images: Vec::new(),
audios: Vec::new(),
videos: Vec::new(),
enable_thinking: text.enable_thinking,
pending_prefixes: Vec::new(),
}
}
}
/// Bookkeeping for a message whose text still needs media-placeholder
/// prefixes injected once the model category (and thus the prefixer) is known.
#[derive(Debug, Clone, PartialEq)]
struct PendingMediaPrefix {
    /// Index of the message within the owning `messages` vector.
    message_index: usize,
    /// Global (request-wide) indices of images attached to this message.
    image_indices: Vec<usize>,
    /// Global indices of audio clips attached to this message.
    audio_indices: Vec<usize>,
    /// Global indices of videos attached to this message.
    video_indices: Vec<usize>,
}
/// Builder for a chat conversation that may attach images, audio, and video.
#[derive(Debug, Clone, PartialEq)]
pub struct MultimodalMessages {
    /// Ordered messages; "content" is either plain text or a structured part list.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// All images across the conversation, in attachment order.
    images: Vec<DynamicImage>,
    /// All audio clips across the conversation, in attachment order.
    audios: Vec<AudioInput>,
    /// All videos across the conversation, in attachment order.
    videos: Vec<VideoInput>,
    /// Forwarded verbatim into `RequestMessage::MultimodalChat`.
    enable_thinking: Option<bool>,
    /// Messages awaiting media-placeholder prefixing (see `resolve_pending`).
    pending_prefixes: Vec<PendingMediaPrefix>,
}
impl Default for MultimodalMessages {
fn default() -> Self {
Self::new()
}
}
impl MultimodalMessages {
    /// Create an empty multimodal conversation.
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            images: Vec::new(),
            audios: Vec::new(),
            videos: Vec::new(),
            enable_thinking: None,
            pending_prefixes: Vec::new(),
        }
    }
    /// Move `items` into `store`, returning the global indices they now occupy.
    fn stash<T>(store: &mut Vec<T>, items: Vec<T>) -> Vec<usize> {
        let start = store.len();
        store.extend(items);
        (start..store.len()).collect()
    }
    /// Append one plain-text message with the given role.
    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
        let mut entry = IndexMap::new();
        entry.insert("role".to_string(), Either::Left(role.to_string()));
        entry.insert("content".to_string(), Either::Left(text.to_string()));
        self.messages.push(entry);
        self
    }
    /// Append a message carrying only images.
    pub fn add_image_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
    ) -> Self {
        self.add_multimodal_message(role, text, images, vec![], vec![])
    }
    /// Append a message carrying only audio clips.
    pub fn add_audio_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        audios: Vec<AudioInput>,
    ) -> Self {
        self.add_multimodal_message(role, text, vec![], audios, vec![])
    }
    /// Append a message carrying only videos.
    pub fn add_video_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        videos: Vec<VideoInput>,
    ) -> Self {
        self.add_multimodal_message(role, text, vec![], vec![], videos)
    }
    /// Append a message with any mix of images, audio, and video.
    ///
    /// When media is attached, the content becomes a structured list of typed
    /// parts (one placeholder per media item, then the text part) and the
    /// message is queued for placeholder prefixing via
    /// `resolve_pending_prefixes`. With no media it degenerates to a plain
    /// text message.
    pub fn add_multimodal_message(
        mut self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
        audios: Vec<AudioInput>,
        videos: Vec<VideoInput>,
    ) -> Self {
        let image_indices = Self::stash(&mut self.images, images);
        let audio_indices = Self::stash(&mut self.audios, audios);
        let video_indices = Self::stash(&mut self.videos, videos);
        let has_media =
            !(image_indices.is_empty() && audio_indices.is_empty() && video_indices.is_empty());
        if has_media {
            // One typed placeholder part per media item, followed by the text.
            let tag = |kind: &str| {
                IndexMap::from([("type".to_string(), Value::String(kind.to_string()))])
            };
            let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
            content_vec.extend(image_indices.iter().map(|_| tag("image")));
            content_vec.extend(audio_indices.iter().map(|_| tag("audio")));
            content_vec.extend(video_indices.iter().map(|_| tag("video")));
            content_vec.push(IndexMap::from([
                ("type".to_string(), Value::String("text".to_string())),
                ("text".to_string(), Value::String(text.to_string())),
            ]));
            // Remember which message needs media prefixes injected later.
            self.pending_prefixes.push(PendingMediaPrefix {
                message_index: self.messages.len(),
                image_indices,
                audio_indices,
                video_indices,
            });
            let mut entry = IndexMap::new();
            entry.insert("role".to_string(), Either::Left(role.to_string()));
            entry.insert("content".to_string(), Either::Right(content_vec));
            self.messages.push(entry);
        } else {
            let mut entry = IndexMap::new();
            entry.insert("role".to_string(), Either::Left(role.to_string()));
            entry.insert("content".to_string(), Either::Left(text.to_string()));
            self.messages.push(entry);
        }
        self
    }
    /// Drop all messages and media; the `enable_thinking` setting is retained.
    pub fn clear(mut self) -> Self {
        self.messages.clear();
        self.images.clear();
        self.audios.clear();
        self.videos.clear();
        self.pending_prefixes.clear();
        self
    }
    /// Explicitly enable or disable model "thinking" for this conversation.
    pub fn enable_thinking(mut self, enable_thinking: bool) -> Self {
        self.enable_thinking = Some(enable_thinking);
        self
    }
}
impl RequestLike for MultimodalMessages {
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
        &self.messages
    }
    fn images_ref(&self) -> &[DynamicImage] {
        &self.images
    }
    /// Inject media placeholder prefixes now that the model category is known.
    fn resolve_pending_prefixes(&mut self, category: &ModelCategory) {
        resolve_pending(category, &mut self.messages, &mut self.pending_prefixes);
    }
    fn take_messages(&mut self) -> RequestMessage {
        RequestMessage::MultimodalChat {
            messages: std::mem::take(&mut self.messages),
            images: std::mem::take(&mut self.images),
            audios: std::mem::take(&mut self.audios),
            videos: std::mem::take(&mut self.videos),
            enable_thinking: self.enable_thinking,
            reasoning_effort: None,
        }
    }
    // None of the remaining request options are configurable on multimodal
    // messages, so every accessor below reports "unset"/defaults.
    fn enable_search(&self) -> Option<bool> {
        None
    }
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
        None
    }
    fn take_adapters(&mut self) -> Option<Vec<String>> {
        None
    }
    fn return_logprobs(&self) -> bool {
        false
    }
    fn take_constraint(&mut self) -> Constraint {
        Constraint::None
    }
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
        None
    }
    fn take_sampling_params(&mut self) -> SamplingParams {
        SamplingParams::deterministic()
    }
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
        None
    }
}
/// Builder for a fully-configurable request: messages, media, sampling
/// parameters, tools, adapters, constraints, and web-search options.
#[derive(Clone)]
pub struct RequestBuilder {
    /// Ordered chat messages.
    messages: Vec<IndexMap<String, MessageContent>>,
    /// All attached images, in attachment order.
    images: Vec<DynamicImage>,
    /// All attached audio clips, in attachment order.
    audios: Vec<AudioInput>,
    /// All attached videos, in attachment order.
    videos: Vec<VideoInput>,
    /// Custom logits processors applied during sampling.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Adapter names to activate for this request.
    adapters: Vec<String>,
    /// Whether logprobs are returned with the response.
    return_logprobs: bool,
    /// Decoding constraint; `Constraint::None` by default.
    constraint: Constraint,
    /// Tool definitions offered to the model.
    tools: Vec<Tool>,
    /// Tool-choice policy; `ToolChoice::Auto` by default.
    tool_choice: ToolChoice,
    /// Sampling parameters; deterministic by default.
    sampling_params: SamplingParams,
    /// Web-search options; also drives `enable_search`.
    web_search_options: Option<WebSearchOptions>,
    /// Explicit "thinking" override; `None` leaves the downstream default.
    enable_thinking: Option<bool>,
    /// Whether over-long input sequences may be truncated.
    truncate_sequence: bool,
    /// Messages awaiting media-placeholder prefixing (see `resolve_pending`).
    pending_prefixes: Vec<PendingMediaPrefix>,
}
impl Default for RequestBuilder {
fn default() -> Self {
Self::new()
}
}
impl From<TextMessages> for RequestBuilder {
fn from(value: TextMessages) -> Self {
Self {
messages: value.messages,
images: Vec::new(),
audios: Vec::new(),
videos: Vec::new(),
logits_processors: Vec::new(),
adapters: Vec::new(),
return_logprobs: false,
constraint: Constraint::None,
tools: Vec::new(),
tool_choice: ToolChoice::Auto,
sampling_params: SamplingParams::deterministic(),
web_search_options: None,
enable_thinking: None,
truncate_sequence: false,
pending_prefixes: Vec::new(),
}
}
}
impl From<MultimodalMessages> for RequestBuilder {
fn from(value: MultimodalMessages) -> Self {
Self {
messages: value.messages,
images: value.images,
audios: value.audios,
videos: value.videos,
logits_processors: Vec::new(),
adapters: Vec::new(),
return_logprobs: false,
constraint: Constraint::None,
tools: Vec::new(),
tool_choice: ToolChoice::Auto,
sampling_params: SamplingParams::deterministic(),
web_search_options: None,
enable_thinking: None,
truncate_sequence: false,
pending_prefixes: value.pending_prefixes,
}
}
}
impl RequestBuilder {
    /// Create a builder with no messages and default settings
    /// (deterministic sampling, automatic tool choice, no constraint).
    pub fn new() -> Self {
        Self {
            messages: Vec::new(),
            images: Vec::new(),
            audios: Vec::new(),
            videos: Vec::new(),
            logits_processors: Vec::new(),
            adapters: Vec::new(),
            return_logprobs: false,
            constraint: Constraint::None,
            tools: Vec::new(),
            tool_choice: ToolChoice::Auto,
            sampling_params: SamplingParams::deterministic(),
            web_search_options: None,
            enable_thinking: None,
            truncate_sequence: false,
            pending_prefixes: Vec::new(),
        }
    }
    /// Move `items` into `store`, returning the global indices they now occupy.
    fn stash<T>(store: &mut Vec<T>, items: Vec<T>) -> Vec<usize> {
        let start = store.len();
        store.extend(items);
        (start..store.len()).collect()
    }
    /// Attach web-search options to the request.
    pub fn with_web_search_options(self, web_search_options: WebSearchOptions) -> Self {
        Self {
            web_search_options: Some(web_search_options),
            ..self
        }
    }
    /// Append one plain-text message with the given role.
    pub fn add_message(mut self, role: TextMessageRole, text: impl ToString) -> Self {
        let mut entry = IndexMap::new();
        entry.insert("role".to_string(), Either::Left(role.to_string()));
        entry.insert("content".to_string(), Either::Left(text.to_string()));
        self.messages.push(entry);
        self
    }
    /// Append a tool-result message tied back to `tool_id`.
    pub fn add_tool_message(mut self, tool_content: impl ToString, tool_id: impl ToString) -> Self {
        let mut entry = IndexMap::new();
        entry.insert(
            "role".to_string(),
            Either::Left(TextMessageRole::Tool.to_string()),
        );
        entry.insert(
            "content".to_string(),
            Either::Left(tool_content.to_string()),
        );
        entry.insert(
            "tool_call_id".to_string(),
            Either::Left(tool_id.to_string()),
        );
        self.messages.push(entry);
        self
    }
    /// Append a message that also records tool calls made by the model.
    pub fn add_message_with_tool_call(
        mut self,
        role: TextMessageRole,
        text: impl ToString,
        tool_calls: Vec<ToolCallResponse>,
    ) -> Self {
        let tool_messages = tool_calls
            .iter()
            .map(|call| {
                let mut m = IndexMap::new();
                m.insert("id".to_string(), Value::String(call.id.clone()));
                m.insert("type".to_string(), Value::String(call.tp.to_string()));
                m.insert(
                    "function".to_string(),
                    json!({
                        "name": call.function.name,
                        "arguments": call.function.arguments,
                    }),
                );
                m
            })
            .collect();
        // NOTE(review): the tool-call list is stored under the "function" key;
        // chat templates commonly expect "tool_calls" — confirm against the
        // templates this feeds before changing the key.
        let mut entry = IndexMap::new();
        entry.insert("role".to_string(), Either::Left(role.to_string()));
        entry.insert("content".to_string(), Either::Left(text.to_string()));
        entry.insert("function".to_string(), Either::Right(tool_messages));
        self.messages.push(entry);
        self
    }
    /// Append a message carrying only images.
    pub fn add_image_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
    ) -> Self {
        self.add_multimodal_message(role, text, images, vec![], vec![])
    }
    /// Append a message carrying only audio clips.
    pub fn add_audio_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        audios: Vec<AudioInput>,
    ) -> Self {
        self.add_multimodal_message(role, text, vec![], audios, vec![])
    }
    /// Append a message carrying only videos.
    pub fn add_video_message(
        self,
        role: TextMessageRole,
        text: impl ToString,
        videos: Vec<VideoInput>,
    ) -> Self {
        self.add_multimodal_message(role, text, vec![], vec![], videos)
    }
    /// Append a message with any mix of images, audio, and video.
    ///
    /// When media is attached, the content becomes a structured list of typed
    /// parts (one placeholder per media item, then the text part) and the
    /// message is queued for placeholder prefixing via
    /// `resolve_pending_prefixes`. With no media it degenerates to a plain
    /// text message.
    pub fn add_multimodal_message(
        mut self,
        role: TextMessageRole,
        text: impl ToString,
        images: Vec<DynamicImage>,
        audios: Vec<AudioInput>,
        videos: Vec<VideoInput>,
    ) -> Self {
        let image_indices = Self::stash(&mut self.images, images);
        let audio_indices = Self::stash(&mut self.audios, audios);
        let video_indices = Self::stash(&mut self.videos, videos);
        let has_media =
            !(image_indices.is_empty() && audio_indices.is_empty() && video_indices.is_empty());
        if has_media {
            // One typed placeholder part per media item, followed by the text.
            let tag = |kind: &str| {
                IndexMap::from([("type".to_string(), Value::String(kind.to_string()))])
            };
            let mut content_vec: Vec<IndexMap<String, Value>> = Vec::new();
            content_vec.extend(image_indices.iter().map(|_| tag("image")));
            content_vec.extend(audio_indices.iter().map(|_| tag("audio")));
            content_vec.extend(video_indices.iter().map(|_| tag("video")));
            content_vec.push(IndexMap::from([
                ("type".to_string(), Value::String("text".to_string())),
                ("text".to_string(), Value::String(text.to_string())),
            ]));
            // Remember which message needs media prefixes injected later.
            self.pending_prefixes.push(PendingMediaPrefix {
                message_index: self.messages.len(),
                image_indices,
                audio_indices,
                video_indices,
            });
            let mut entry = IndexMap::new();
            entry.insert("role".to_string(), Either::Left(role.to_string()));
            entry.insert("content".to_string(), Either::Right(content_vec));
            self.messages.push(entry);
        } else {
            let mut entry = IndexMap::new();
            entry.insert("role".to_string(), Either::Left(role.to_string()));
            entry.insert("content".to_string(), Either::Left(text.to_string()));
            self.messages.push(entry);
        }
        self
    }
    /// Add a custom logits processor applied during sampling.
    pub fn add_logits_processor(mut self, processor: Arc<dyn CustomLogitsProcessor>) -> Self {
        self.logits_processors.push(processor);
        self
    }
    /// Replace the set of adapter names.
    pub fn set_adapters(self, adapters: Vec<String>) -> Self {
        Self { adapters, ..self }
    }
    /// Replace the tool definitions offered to the model.
    pub fn set_tools(self, tools: Vec<Tool>) -> Self {
        Self { tools, ..self }
    }
    /// Set the tool-choice policy.
    pub fn set_tool_choice(self, tool_choice: ToolChoice) -> Self {
        Self { tool_choice, ..self }
    }
    /// Request (or suppress) logprobs in the response.
    pub fn return_logprobs(self, return_logprobs: bool) -> Self {
        Self {
            return_logprobs,
            ..self
        }
    }
    /// Set the decoding constraint.
    pub fn set_constraint(self, constraint: Constraint) -> Self {
        Self { constraint, ..self }
    }
    /// Replace all sampling parameters at once.
    pub fn set_sampling(self, params: SamplingParams) -> Self {
        Self {
            sampling_params: params,
            ..self
        }
    }
    /// Reset sampling to the deterministic configuration.
    pub fn set_deterministic_sampler(self) -> Self {
        Self {
            sampling_params: SamplingParams::deterministic(),
            ..self
        }
    }
    /// Set the sampling temperature.
    pub fn set_sampler_temperature(mut self, temperature: f64) -> Self {
        self.sampling_params.temperature = Some(temperature);
        self
    }
    /// Set top-k sampling.
    pub fn set_sampler_topk(mut self, topk: usize) -> Self {
        self.sampling_params.top_k = Some(topk);
        self
    }
    /// Set top-p (nucleus) sampling.
    pub fn set_sampler_topp(mut self, topp: f64) -> Self {
        self.sampling_params.top_p = Some(topp);
        self
    }
    /// Set min-p sampling.
    pub fn set_sampler_minp(mut self, minp: f64) -> Self {
        self.sampling_params.min_p = Some(minp);
        self
    }
    /// Set how many top logprobs to return per token.
    pub fn set_sampler_topn_logprobs(mut self, top_n_logprobs: usize) -> Self {
        self.sampling_params.top_n_logprobs = top_n_logprobs;
        self
    }
    /// Set the frequency penalty.
    pub fn set_sampler_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
        self.sampling_params.frequency_penalty = Some(frequency_penalty);
        self
    }
    /// Set the presence penalty.
    pub fn set_sampler_presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.sampling_params.presence_penalty = Some(presence_penalty);
        self
    }
    /// Set the stop tokens.
    pub fn set_sampler_stop_toks(mut self, stop_toks: StopTokens) -> Self {
        self.sampling_params.stop_toks = Some(stop_toks);
        self
    }
    /// Cap the generated sequence length.
    pub fn set_sampler_max_len(mut self, max_len: usize) -> Self {
        self.sampling_params.max_len = Some(max_len);
        self
    }
    /// Set per-token logit biases.
    pub fn set_sampler_logits_bias(mut self, logits_bias: HashMap<u32, f32>) -> Self {
        self.sampling_params.logits_bias = Some(logits_bias);
        self
    }
    /// Set the number of completion choices to generate.
    pub fn set_sampler_n_choices(mut self, n_choices: usize) -> Self {
        self.sampling_params.n_choices = n_choices;
        self
    }
    /// Set DRY sampling parameters.
    pub fn set_sampler_dry_params(mut self, dry_params: DrySamplingParams) -> Self {
        self.sampling_params.dry_params = Some(dry_params);
        self
    }
    /// Explicitly enable or disable model "thinking".
    pub fn enable_thinking(self, enable_thinking: bool) -> Self {
        Self {
            enable_thinking: Some(enable_thinking),
            ..self
        }
    }
    /// Allow or forbid truncating over-long input sequences.
    pub fn with_truncate_sequence(self, truncate_sequence: bool) -> Self {
        Self {
            truncate_sequence,
            ..self
        }
    }
}
impl RequestLike for RequestBuilder {
    fn messages_ref(&self) -> &[IndexMap<String, MessageContent>] {
        &self.messages
    }
    fn images_ref(&self) -> &[DynamicImage] {
        &self.images
    }
    /// Inject media placeholder prefixes now that the model category is known.
    fn resolve_pending_prefixes(&mut self, category: &ModelCategory) {
        resolve_pending(category, &mut self.messages, &mut self.pending_prefixes);
    }
    fn take_messages(&mut self) -> RequestMessage {
        // A request with no attached media is a plain chat request.
        let text_only =
            self.images.is_empty() && self.audios.is_empty() && self.videos.is_empty();
        let messages = std::mem::take(&mut self.messages);
        if text_only {
            RequestMessage::Chat {
                messages,
                enable_thinking: self.enable_thinking,
                reasoning_effort: None,
            }
        } else {
            RequestMessage::MultimodalChat {
                messages,
                images: std::mem::take(&mut self.images),
                audios: std::mem::take(&mut self.audios),
                videos: std::mem::take(&mut self.videos),
                enable_thinking: self.enable_thinking,
                reasoning_effort: None,
            }
        }
    }
    /// Search is implied by the presence of web-search options.
    fn enable_search(&self) -> Option<bool> {
        if self.web_search_options.is_some() {
            Some(true)
        } else {
            None
        }
    }
    fn take_logits_processors(&mut self) -> Option<Vec<Arc<dyn CustomLogitsProcessor>>> {
        if self.logits_processors.is_empty() {
            None
        } else {
            Some(std::mem::take(&mut self.logits_processors))
        }
    }
    fn take_adapters(&mut self) -> Option<Vec<String>> {
        if self.adapters.is_empty() {
            None
        } else {
            Some(std::mem::take(&mut self.adapters))
        }
    }
    fn return_logprobs(&self) -> bool {
        self.return_logprobs
    }
    fn take_constraint(&mut self) -> Constraint {
        std::mem::replace(&mut self.constraint, Constraint::None)
    }
    fn take_tools(&mut self) -> Option<(Vec<Tool>, ToolChoice)> {
        if self.tools.is_empty() {
            None
        } else {
            let tools = std::mem::take(&mut self.tools);
            let choice = std::mem::replace(&mut self.tool_choice, ToolChoice::Auto);
            Some((tools, choice))
        }
    }
    fn take_sampling_params(&mut self) -> SamplingParams {
        std::mem::replace(&mut self.sampling_params, SamplingParams::deterministic())
    }
    fn take_web_search_options(&mut self) -> Option<WebSearchOptions> {
        self.web_search_options.take()
    }
    fn truncate_sequence(&self) -> bool {
        self.truncate_sequence
    }
}
/// Rewrite the first "text" part of each queued multimodal message so that it
/// carries the model-specific media placeholder prefixes.
///
/// For non-multimodal model categories there is no prefixer, so the queue is
/// simply discarded and the text is left untouched.
fn resolve_pending(
    category: &ModelCategory,
    messages: &mut [IndexMap<String, MessageContent>],
    pending: &mut Vec<PendingMediaPrefix>,
) {
    let ModelCategory::Multimodal { prefixer } = category else {
        pending.clear();
        return;
    };
    for entry in pending.drain(..) {
        // Skip entries whose message is missing or whose content is not the
        // structured (part-list) form.
        let Some(msg) = messages.get_mut(entry.message_index) else {
            continue;
        };
        let Some(Either::Right(content_vec)) = msg.get_mut("content") else {
            continue;
        };
        // Only the first text part receives the prefixes.
        let first_text = content_vec
            .iter_mut()
            .find(|part| part.get("type").is_some_and(|v| v.as_str() == Some("text")));
        if let Some(part) = first_text {
            if let Some(Value::String(text)) = part.get_mut("text") {
                if !entry.image_indices.is_empty() {
                    *text = prefixer.prefix_image(entry.image_indices.clone(), text);
                }
                if !entry.audio_indices.is_empty() {
                    *text = prefixer.prefix_audio(entry.audio_indices.clone(), text);
                }
                if !entry.video_indices.is_empty() {
                    *text = prefixer.prefix_video(entry.video_indices.clone(), text);
                }
            }
        }
    }
}
/// One input to an embedding request: either prompt text or token ids.
#[derive(Clone, Debug)]
pub enum EmbeddingRequestInput {
    /// Prompt text, mapped to `RequestMessage::Embedding`.
    Prompt(String),
    /// Token ids, mapped to `RequestMessage::EmbeddingTokens`.
    Tokens(Vec<u32>),
}
impl EmbeddingRequestInput {
pub fn into_request_message(self) -> RequestMessage {
match self {
Self::Prompt(prompt) => RequestMessage::Embedding { prompt },
Self::Tokens(prompt) => RequestMessage::EmbeddingTokens { prompt },
}
}
}
/// A batch embedding request.
#[derive(Clone, Debug)]
pub struct EmbeddingRequest {
    /// Inputs to embed, in submission order.
    pub inputs: Vec<EmbeddingRequestInput>,
    /// Whether over-long input sequences may be truncated.
    pub truncate_sequence: bool,
}
impl EmbeddingRequest {
    /// Start building an embedding request.
    pub fn builder() -> EmbeddingRequestBuilder {
        EmbeddingRequestBuilder::default()
    }
}
/// Builder for [`EmbeddingRequest`]; `build` rejects an empty input list.
#[derive(Clone, Debug, Default)]
pub struct EmbeddingRequestBuilder {
    /// Inputs accumulated so far, in submission order.
    inputs: Vec<EmbeddingRequestInput>,
    /// Whether over-long input sequences may be truncated.
    truncate_sequence: bool,
}
impl EmbeddingRequestBuilder {
    /// Create an empty builder.
    pub fn new() -> Self {
        Self::default()
    }
    /// Queue a single text prompt.
    pub fn add_prompt(mut self, prompt: impl Into<String>) -> Self {
        let input = EmbeddingRequestInput::Prompt(prompt.into());
        self.inputs.push(input);
        self
    }
    /// Queue several text prompts, preserving their order.
    pub fn add_prompts<I, S>(self, prompts: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        prompts
            .into_iter()
            .fold(self, |builder, prompt| builder.add_prompt(prompt))
    }
    /// Queue a single pre-tokenized input.
    pub fn add_tokens(mut self, tokens: impl Into<Vec<u32>>) -> Self {
        let input = EmbeddingRequestInput::Tokens(tokens.into());
        self.inputs.push(input);
        self
    }
    /// Queue several pre-tokenized inputs, preserving their order.
    pub fn add_tokens_batch<I>(self, batches: I) -> Self
    where
        I: IntoIterator<Item = Vec<u32>>,
    {
        batches
            .into_iter()
            .fold(self, |builder, tokens| builder.add_tokens(tokens))
    }
    /// Allow or forbid truncating over-long inputs.
    pub fn with_truncate_sequence(self, truncate: bool) -> Self {
        Self {
            truncate_sequence: truncate,
            ..self
        }
    }
    /// Finalize the request.
    ///
    /// # Errors
    /// Fails if no inputs were queued.
    pub fn build(self) -> anyhow::Result<EmbeddingRequest> {
        anyhow::ensure!(
            !self.inputs.is_empty(),
            "Embedding request must contain at least one input."
        );
        Ok(EmbeddingRequest {
            inputs: self.inputs,
            truncate_sequence: self.truncate_sequence,
        })
    }
}