use crate::RequestOptions;
/// Description of an embedding call: which model to use, the texts to embed,
/// an optional `dimensions` value, and a bag of typed extra options.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; use `EmbeddingRequest::new` and the chained
/// setter methods instead. All fields are public and can be read directly.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct EmbeddingRequest {
/// Model identifier, stored exactly as passed to `EmbeddingRequest::new`.
pub model: String,
/// Input strings, in the order they were supplied. Empty after `new`.
pub inputs: Vec<String>,
/// Optional dimensions value; `None` until set through the `dimensions`
/// setter.
// NOTE(review): presumably the requested size of each output embedding
// vector — confirm against the code that consumes this request.
pub dimensions: Option<u32>,
/// Typed extra options, written by `with_option` and read by `option`.
pub options: RequestOptions,
}
impl EmbeddingRequest {
    /// Creates a request for `model` with no inputs, no `dimensions` value,
    /// and a fresh options container from `RequestOptions::new()`.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            inputs: Vec::new(),
            dimensions: None,
            options: RequestOptions::new(),
        }
    }

    /// Replaces the request's inputs with the items yielded by `inputs`.
    ///
    /// Any inputs added earlier (through this method or [`input`](Self::input))
    /// are discarded; use [`input`](Self::input) to append instead.
    #[must_use]
    pub fn inputs<I, S>(mut self, inputs: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        // Assignment (not `extend`) is deliberate: this setter overwrites.
        self.inputs = inputs.into_iter().map(Into::into).collect();
        self
    }

    /// Appends a single input after any inputs already present.
    #[must_use]
    pub fn input(mut self, input: impl Into<String>) -> Self {
        self.inputs.push(input.into());
        self
    }

    /// Sets the `dimensions` field to `Some(dimensions)`.
    #[must_use]
    pub fn dimensions(mut self, dimensions: u32) -> Self {
        self.dimensions = Some(dimensions);
        self
    }

    /// Stores `option` in the request's [`RequestOptions`] via its `insert`
    /// method and returns the updated request.
    ///
    /// The value can later be read back with [`option`](Self::option) using
    /// the same type `T`.
    // NOTE(review): what happens when a value of the same type was already
    // stored is decided by `RequestOptions::insert` — confirm there.
    #[must_use]
    pub fn with_option<T>(mut self, option: T) -> Self
    where
        T: Clone + Send + Sync + 'static,
    {
        self.options.insert(option);
        self
    }

    /// Returns a reference to the stored option of type `T`, as reported by
    /// `RequestOptions::get`, or `None` when that lookup finds nothing.
    // `#[must_use]` matches every other method here: this is a pure lookup,
    // so dropping the result means the call did nothing useful.
    #[must_use]
    pub fn option<T>(&self) -> Option<&T>
    where
        T: Send + Sync + 'static,
    {
        self.options.get::<T>()
    }
}
#[cfg(test)]
mod request_tests {
    use super::*;

    /// Arbitrary caller-defined type used to exercise the typed option store.
    #[derive(Debug, Clone, PartialEq, Eq)]
    struct ProviderHint {
        tag: &'static str,
    }

    /// `new` records the model, `input` appends in call order, and
    /// `dimensions` stays unset unless requested.
    #[test]
    fn builder_sets_model_and_inputs() {
        let built = EmbeddingRequest::new("text-embedding-3-small")
            .input("hello")
            .input("world");

        assert_eq!(built.model, "text-embedding-3-small");
        assert_eq!(built.inputs, ["hello", "world"]);
        assert!(built.dimensions.is_none());
    }

    /// `inputs` overwrites whatever `input` added before it.
    #[test]
    fn inputs_replaces_existing() {
        let seeded = EmbeddingRequest::new("model").input("a");
        let replaced = seeded.inputs(["b", "c"]);

        assert_eq!(replaced.inputs, ["b", "c"]);
    }

    /// The `dimensions` setter wraps its argument in `Some`.
    #[test]
    fn dimensions_setter_stores_value() {
        let built = EmbeddingRequest::new("model").dimensions(512);

        assert_eq!(built.dimensions, Some(512));
    }

    /// A value stored with `with_option` is returned by `option` for the
    /// same type.
    #[test]
    fn typed_options_round_trip() {
        let hint = ProviderHint { tag: "retrieval" };
        let built = EmbeddingRequest::new("model").with_option(hint.clone());

        assert_eq!(built.option::<ProviderHint>(), Some(&hint));
    }
}