use serde::{Deserialize, Serialize};
/// Request carrying the text input(s) to embed: either a single string or a
/// batch of strings, as described by [`EmbedInput`].
///
/// The type derives serde's `Serialize`/`Deserialize`. The wire shape of
/// `input` is therefore whatever serde representation [`EmbedInput`] uses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedRequest {
    /// The text(s) this request asks to embed.
pub input: EmbedInput,
}
impl EmbedRequest {
    /// Builds a request holding exactly one text.
    ///
    /// Accepts anything convertible into an owned `String` (`&str`, `String`, ...).
    pub fn new(input: impl Into<String>) -> Self {
        let text: String = input.into();
        Self {
            input: EmbedInput::Single(text),
        }
    }

    /// Builds a batch request that takes ownership of `inputs` as-is.
    ///
    /// No validation is performed: an empty vector yields an empty batch, and a
    /// one-element vector still counts as a batch.
    pub fn new_batch(inputs: Vec<String>) -> Self {
        let input = EmbedInput::Batch(inputs);
        Self { input }
    }

    /// Alias for [`EmbedRequest::new`].
    pub fn from_text(text: impl Into<String>) -> Self {
        Self::new(text)
    }

    /// Alias for [`EmbedRequest::new_batch`].
    pub fn from_texts(texts: Vec<String>) -> Self {
        Self::new_batch(texts)
    }
}
impl EmbedRequest {
    /// Returns the text if this request holds a single input.
    ///
    /// Returns `None` for any batch, including a one-element one.
    pub fn single_input(&self) -> Option<&str> {
        match &self.input {
            EmbedInput::Batch(_) => None,
            EmbedInput::Single(text) => Some(text.as_str()),
        }
    }

    /// Borrows every input text in order.
    ///
    /// A single input yields a one-element vector; an empty batch yields an
    /// empty vector.
    pub fn inputs(&self) -> Vec<&str> {
        // View both variants as a slice so one iterator chain handles them.
        let texts: &[String] = match &self.input {
            EmbedInput::Single(text) => std::slice::from_ref(text),
            EmbedInput::Batch(texts) => texts,
        };
        texts.iter().map(String::as_str).collect()
    }

    /// Reports whether the request was built as a batch, regardless of its length.
    pub fn is_batch(&self) -> bool {
        match self.input {
            EmbedInput::Single(_) => false,
            EmbedInput::Batch(_) => true,
        }
    }

    /// Number of texts carried: always 1 for a single input, the vector length for a batch.
    pub fn input_count(&self) -> usize {
        match &self.input {
            EmbedInput::Batch(texts) => texts.len(),
            EmbedInput::Single(_) => 1,
        }
    }
}
/// The text payload of an [`EmbedRequest`]: one string or a list of strings.
///
/// NOTE(review): there is no `#[serde(...)]` attribute on this enum, so serde
/// uses its default *externally tagged* representation. In JSON that looks like
/// `{"Single":"text"}` or `{"Batch":["a","b"]}`. If this value is sent to an API
/// that expects `input` to be a bare string or array, `#[serde(untagged)]` may
/// have been intended. Confirm against the consumer before changing, because
/// that alters the wire format. Variant order also matters for index-based
/// serde formats, so do not reorder the variants casually.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum EmbedInput {
    /// Exactly one text.
Single(String),
    /// Zero or more texts, in the given order (an empty vector is not rejected here).
Batch(Vec<String>),
}
impl From<String> for EmbedInput {
fn from(text: String) -> Self {
EmbedInput::Single(text)
}
}
impl From<&str> for EmbedInput {
fn from(text: &str) -> Self {
EmbedInput::Single(text.to_string())
}
}
impl From<Vec<String>> for EmbedInput {
fn from(texts: Vec<String>) -> Self {
EmbedInput::Batch(texts)
}
}
impl From<Vec<&str>> for EmbedInput {
fn from(texts: Vec<&str>) -> Self {
EmbedInput::Batch(texts.into_iter().map(|s| s.to_string()).collect())
}
}