use std::fmt;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use rand::Rng;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::{OnceCell, RwLock};
use tonic::transport::{Channel, Endpoint};
/// Generated gRPC client/message types for the `tei.v1` protobuf package
/// (pulled in at build time by `tonic::include_proto!`).
pub mod tei_proto {
tonic::include_proto!("tei.v1");
}
use tei_proto::{
embed_client::EmbedClient, info_client::InfoClient, rerank_client::RerankClient,
tokenize_client::TokenizeClient, DecodeRequest, EmbedRequest, EncodeRequest, InfoRequest,
RerankRequest, TruncationDirection,
};
/// Errors returned by [`TeiClient`] operations.
#[derive(Error, Debug)]
pub enum TeiError {
/// Establishing the gRPC channel to the server failed.
#[error("Connection failed: {0}")]
Connection(String),
/// A gRPC call returned a non-OK status.
#[error("gRPC error: {0}")]
Grpc(#[from] tonic::Status),
/// Error from tonic's transport layer.
#[error("Transport error: {0}")]
Transport(#[from] tonic::transport::Error),
/// The server returned no result where exactly one was expected.
#[error("Empty response from server")]
EmptyResponse,
/// Every attempt of a retried operation failed; carries the last error's message.
#[error("All retry attempts failed: {0}")]
RetryExhausted(String),
/// Invalid configuration or argument (e.g. an unparsable endpoint URI or an empty rerank query).
#[error("Invalid configuration: {0}")]
Config(String),
/// A streaming RPC yielded a different number of results than requests were sent.
#[error("Partial response: expected {expected} results, received {received}")]
PartialResponse {
expected: usize,
received: usize,
},
/// The server returned a zero-length embedding for the input at `index`.
#[error("Empty embedding received for text at index {index}")]
EmptyEmbedding {
index: usize,
},
/// An embedding's length differs from the first embedding in the same stream.
#[error("Dimension mismatch at index {index}: expected {expected}, got {actual}")]
DimensionMismatch {
expected: usize,
actual: usize,
index: usize,
},
/// The client was closed via `TeiClient::close`.
#[error("Client has been closed")]
ClientClosed,
}
/// Result alias used throughout this module, with [`TeiError`] as the error type.
pub type Result<T> = std::result::Result<T, TeiError>;
/// Server/model description as returned by the TEI `Info` RPC.
///
/// Optional string fields are `None` when the server omitted them or sent an empty string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
// Server version string.
pub version: String,
// Server build `sha`, if reported.
pub sha: Option<String>,
// Docker label, if reported.
pub docker_label: Option<String>,
// Identifier of the served model (e.g. a hub repo id).
pub model_id: String,
// Model revision `sha`, if reported.
pub model_sha: Option<String>,
// Free-form dtype string (inspected by `is_fp16` / `is_bf16`).
pub model_dtype: String,
// Kind of model being served.
pub model_type: ModelType,
pub max_concurrent_requests: u32,
// NOTE(review): presumably measured in tokens — confirm against the TEI API.
pub max_input_length: u32,
pub max_batch_tokens: u32,
// Optional server-side cap on requests per batch; `None` means no cap was reported.
pub max_batch_requests: Option<u32>,
// Largest batch a client may send in one call; used by `TeiClient::embed_batch`.
pub max_client_batch_size: u32,
pub tokenization_workers: u32,
}
impl ServerInfo {
    /// Returns `true` when `model_dtype` names IEEE half precision
    /// (`float16`, `fp16` or `half`, case-insensitive).
    ///
    /// bfloat16 names are excluded (note that `"bfloat16"` itself contains `"float16"`),
    /// so `is_fp16` and [`ServerInfo::is_bf16`] never both return `true`.
    #[must_use]
    pub fn is_fp16(&self) -> bool {
        let dtype = self.model_dtype.to_lowercase();
        let names_half =
            dtype.contains("float16") || dtype.contains("fp16") || dtype.contains("half");
        names_half && !Self::names_bf16(&dtype)
    }

    /// Returns `true` when `model_dtype` names bfloat16 (`bf16` or `bfloat`, case-insensitive).
    #[must_use]
    pub fn is_bf16(&self) -> bool {
        Self::names_bf16(&self.model_dtype.to_lowercase())
    }

    /// Shared bf16 predicate over an already-lowercased dtype string.
    fn names_bf16(lower_dtype: &str) -> bool {
        lower_dtype.contains("bf16") || lower_dtype.contains("bfloat")
    }

    /// Human-readable summary: `"<model_id> v<version> (<dtype>, <model_type>)"`.
    #[must_use]
    pub fn description(&self) -> String {
        format!(
            "{} v{} ({}, {})",
            self.model_id, self.version, self.model_dtype, self.model_type
        )
    }

    /// Client batch size after applying the optional `max_batch_requests` cap.
    #[must_use]
    pub fn effective_batch_size(&self) -> u32 {
        match self.max_batch_requests {
            Some(max_requests) => self.max_client_batch_size.min(max_requests),
            None => self.max_client_batch_size,
        }
    }
}
/// Kind of model served, decoded from the integer in the server's info response.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelType {
/// Wire value 0.
Embedding,
/// Wire value 1.
Classifier,
/// Wire value 2.
Reranker,
/// Any other wire value.
Unknown,
}
impl From<i32> for ModelType {
fn from(val: i32) -> Self {
match val {
0 => Self::Embedding,
1 => Self::Classifier,
2 => Self::Reranker,
_ => Self::Unknown,
}
}
}
/// Renders the model type as a lowercase label (`"embedding"`, `"classifier"`, ...).
impl fmt::Display for ModelType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Embedding => "embedding",
            Self::Classifier => "classifier",
            Self::Reranker => "reranker",
            Self::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
/// One token produced by `TeiClient::tokenize`.
#[derive(Debug, Clone)]
pub struct Token {
// Vocabulary id of the token.
pub id: u32,
// Token text as returned by the server.
pub text: String,
// Whether the server flagged this as a special token.
pub special: bool,
// NOTE(review): presumably start/stop offsets of the token in the input text — confirm units (bytes vs chars).
pub start: Option<u32>,
pub stop: Option<u32>,
}
/// One scored entry returned by `TeiClient::rerank` / `rerank_with_options`.
#[derive(Debug, Clone)]
pub struct RerankResult {
// Index reported by the server; presumably the position in the `texts` slice that was sent.
pub index: usize,
// Relevance score; raw or post-processed depending on the `raw_scores` request flag.
pub score: f32,
// The text itself, when the request asked for `return_text`.
pub text: Option<String>,
}
/// A single stored entry of a sparse embedding: feature `index` and its weight.
#[derive(Debug, Clone, PartialEq)]
pub struct SparseValue {
    pub index: u32,
    pub value: f32,
}

/// Sparse embedding stored as a list of `(index, value)` entries.
///
/// [`SparseEmbedding::dot`] expects `values` sorted by ascending `index`;
/// call [`SparseEmbedding::sort_by_index`] first when that is not guaranteed.
#[derive(Debug, Clone)]
pub struct SparseEmbedding {
    pub values: Vec<SparseValue>,
}

impl SparseEmbedding {
    /// Number of stored entries.
    #[inline]
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// `true` when no entries are stored.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Sparse dot product via a single linear merge of both entry lists.
    ///
    /// Both operands must be sorted by ascending `index`; otherwise matching entries
    /// can be silently skipped. Runs in `O(self.len() + other.len())`.
    pub fn dot(&self, other: &SparseEmbedding) -> f32 {
        let (lhs, rhs) = (&self.values, &other.values);
        let (mut a, mut b) = (0usize, 0usize);
        let mut acc = 0.0f32;
        while let (Some(x), Some(y)) = (lhs.get(a), rhs.get(b)) {
            if x.index < y.index {
                a += 1;
            } else if x.index > y.index {
                b += 1;
            } else {
                acc += x.value * y.value;
                a += 1;
                b += 1;
            }
        }
        acc
    }

    /// Sorts entries by ascending `index` (unstable: equal indices may be reordered).
    pub fn sort_by_index(&mut self) {
        self.values.sort_unstable_by(|l, r| l.index.cmp(&r.index));
    }
}
/// Per-request statistics reported by the server alongside an embedding.
///
/// All fields are zero (the `Default`) when the server response carries no metadata.
#[derive(Debug, Clone, Default)]
pub struct EmbedMetadata {
pub compute_chars: u32,
pub compute_tokens: u32,
// Timing fields below are presumably nanoseconds, per the `_ns` suffix — confirm against the TEI API.
pub total_time_ns: u64,
pub tokenization_time_ns: u64,
pub queue_time_ns: u64,
pub inference_time_ns: u64,
}
/// Connection, timeout and retry settings for `TeiClient`.
#[derive(Debug, Clone)]
pub struct TeiClientConfig {
    /// Server URI, e.g. `http://localhost:18080`.
    pub endpoint: String,
    /// Request timeout applied to the channel, in seconds.
    pub timeout_secs: u64,
    /// Total attempts per operation (the first try plus retries).
    pub max_attempts: u32,
    /// Base delay of the exponential retry backoff, in milliseconds.
    pub retry_base_delay_ms: u64,
    /// Cap on the exponential part of the retry backoff, in milliseconds
    /// (jitter of up to 25% is added on top).
    pub retry_max_delay_ms: u64,
    /// Maximum summed token count per request in `TeiClient::embed_batch`.
    pub batch_token_budget: usize,
    /// TCP keepalive and HTTP/2 keepalive-ping interval, in seconds.
    pub keepalive_secs: u64,
}

impl Default for TeiClientConfig {
    fn default() -> Self {
        Self {
            endpoint: "http://localhost:18080".to_string(),
            timeout_secs: 120,
            max_attempts: 3,
            retry_base_delay_ms: 100,
            retry_max_delay_ms: 5000,
            batch_token_budget: 8192,
            keepalive_secs: 30,
        }
    }
}

impl TeiClientConfig {
    /// Builds a config whose endpoint comes from `BRRR_TEI_HOST` (default `localhost`)
    /// and `BRRR_TEI_PORT` (default `18080`); every other field uses [`Default`].
    ///
    /// A missing, empty or whitespace-only host falls back to `localhost`. A port that is
    /// missing or does not parse as `u16` falls back to `18080`. A bare IPv6 literal such
    /// as `::1` is wrapped in brackets so the resulting URI is valid; already-bracketed
    /// hosts are left unchanged.
    pub fn from_env() -> Self {
        let host = std::env::var("BRRR_TEI_HOST")
            .ok()
            .map(|h| h.trim().to_string())
            .filter(|h| !h.is_empty())
            .unwrap_or_else(|| "localhost".to_string());
        let port = std::env::var("BRRR_TEI_PORT")
            .ok()
            .and_then(|p| p.parse::<u16>().ok())
            .unwrap_or(18080u16);
        // `http://::1:18080` is not a valid URI: IPv6 literals must be bracketed.
        let host = if host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{host}]")
        } else {
            host
        };
        Self {
            endpoint: format!("http://{host}:{port}"),
            ..Default::default()
        }
    }

    /// Number of retries after the first attempt (`max_attempts - 1`, never negative).
    #[must_use]
    pub const fn retry_count(&self) -> u32 {
        self.max_attempts.saturating_sub(1)
    }
}
/// Exponential backoff delay for retry number `attempt`, plus random jitter.
///
/// The exponential part is `base_delay_ms * 2^min(attempt, 6)` (saturating), capped
/// at `max_delay_ms`. A uniformly random jitter in `0..=capped/4` ms is then added,
/// so the returned delay may exceed `max_delay_ms` by up to 25%.
fn calculate_backoff_with_jitter(base_delay_ms: u64, max_delay_ms: u64, attempt: u32) -> Duration {
    let multiplier: u64 = 1u64 << attempt.min(6);
    let delay_ms = std::cmp::min(base_delay_ms.saturating_mul(multiplier), max_delay_ms);
    let jitter_ms = match delay_ms / 4 {
        0 => 0,
        jitter_cap => rand::thread_rng().gen_range(0..=jitter_cap),
    };
    Duration::from_millis(delay_ms + jitter_ms)
}
/// Async gRPC client for a TEI server (the `tei.v1` API: embed, rerank, tokenize, info).
///
/// The channel lives behind a `RwLock` so it can be swapped by reconnects and by `close`.
/// Server info is cached after the first successful fetch.
pub struct TeiClient {
// Endpoint, timeout and retry settings this client was built with.
config: TeiClientConfig,
// Current shared gRPC channel; replaced by `reconnect` and `close`.
channel: RwLock<Channel>,
// Lazily-populated cache of the server's `info` response.
server_info: OnceCell<ServerInfo>,
// Set once by `close`; checked at the start of RPC-issuing methods.
closed: AtomicBool,
}
impl TeiClient {
pub async fn new(endpoint: &str) -> Result<Self> {
let config = TeiClientConfig {
endpoint: endpoint.to_string(),
..Default::default()
};
Self::with_config(config).await
}
pub async fn with_config(config: TeiClientConfig) -> Result<Self> {
let channel = Self::create_channel(&config).await?;
Ok(Self {
config,
channel: RwLock::new(channel),
server_info: OnceCell::const_new(),
closed: AtomicBool::new(false),
})
}
pub async fn from_env() -> Result<Self> {
Self::with_config(TeiClientConfig::from_env()).await
}
pub async fn close(&self) -> Result<()> {
if self.closed.swap(true, Ordering::SeqCst) {
return Ok(());
}
let mut channel = self.channel.write().await;
*channel = Channel::from_static("http://[::1]:1").connect_lazy();
Ok(())
}
#[inline]
pub fn is_closed(&self) -> bool {
self.closed.load(Ordering::SeqCst)
}
#[inline]
fn ensure_not_closed(&self) -> Result<()> {
if self.is_closed() {
return Err(TeiError::ClientClosed);
}
Ok(())
}
async fn create_channel(config: &TeiClientConfig) -> Result<Channel> {
let endpoint = Endpoint::from_shared(config.endpoint.clone())
.map_err(|e| TeiError::Config(e.to_string()))?
.timeout(Duration::from_secs(config.timeout_secs))
.tcp_keepalive(Some(Duration::from_secs(config.keepalive_secs)))
.http2_keep_alive_interval(Duration::from_secs(config.keepalive_secs))
.keep_alive_timeout(Duration::from_secs(10))
.connect_timeout(Duration::from_secs(10));
endpoint
.connect()
.await
.map_err(|e| TeiError::Connection(e.to_string()))
}
async fn get_channel(&self) -> Channel {
self.channel.read().await.clone()
}
async fn reconnect(&self) -> Result<Channel> {
let new_channel = Self::create_channel(&self.config).await?;
let mut channel = self.channel.write().await;
*channel = new_channel.clone();
Ok(new_channel)
}
fn is_channel_error(&self, error: &TeiError) -> bool {
match error {
TeiError::Connection(_) | TeiError::Transport(_) => true,
TeiError::Grpc(status) => matches!(
status.code(),
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted
),
_ => false,
}
}
pub async fn info(&self) -> Result<ServerInfo> {
self.ensure_not_closed()?;
let info = self
.with_retry_reconnect(|channel| async move {
let mut client = InfoClient::new(channel);
let response = client.info(InfoRequest {}).await?;
Ok(response.into_inner())
})
.await?;
let to_option = |s: String| if s.is_empty() { None } else { Some(s) };
Ok(ServerInfo {
version: info.version,
sha: info.sha.and_then(|s| to_option(s)),
docker_label: info.docker_label.and_then(|s| to_option(s)),
model_id: info.model_id,
model_sha: info.model_sha.and_then(|s| to_option(s)),
model_dtype: info.model_dtype,
model_type: ModelType::from(info.model_type),
max_concurrent_requests: info.max_concurrent_requests,
max_input_length: info.max_input_length,
max_batch_tokens: info.max_batch_tokens,
max_batch_requests: info.max_batch_requests,
max_client_batch_size: info.max_client_batch_size,
tokenization_workers: info.tokenization_workers,
})
}
pub async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
self.embed_with_options(texts, true, true, None).await
}
pub async fn embed_with_options(
&self,
texts: &[&str],
normalize: bool,
truncate: bool,
dimensions: Option<u32>,
) -> Result<Vec<Vec<f32>>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
self.embed_stream_with_retry(texts, normalize, truncate, dimensions)
.await
}
pub async fn embed_single(
&self,
text: &str,
normalize: bool,
truncate: bool,
dimensions: Option<u32>,
) -> Result<(Vec<f32>, EmbedMetadata)> {
self.ensure_not_closed()?;
let client = EmbedClient::new(self.get_channel().await);
let request = EmbedRequest {
inputs: text.to_string(),
truncate,
normalize,
truncation_direction: TruncationDirection::Right.into(),
prompt_name: None,
dimensions,
};
let response = self
.with_retry(|| async {
let mut client = client.clone();
client.embed(request.clone()).await
})
.await?;
let inner = response.into_inner();
let metadata = inner
.metadata
.map(|m| EmbedMetadata {
compute_chars: m.compute_chars,
compute_tokens: m.compute_tokens,
total_time_ns: m.total_time_ns,
tokenization_time_ns: m.tokenization_time_ns,
queue_time_ns: m.queue_time_ns,
inference_time_ns: m.inference_time_ns,
})
.unwrap_or_default();
Ok((inner.embeddings, metadata))
}
async fn embed_stream_with_retry(
&self,
texts: &[&str],
normalize: bool,
truncate: bool,
dimensions: Option<u32>,
) -> Result<Vec<Vec<f32>>> {
let mut attempt = 0;
let mut last_error = None;
let mut should_reconnect = false;
while attempt < self.config.max_attempts {
if should_reconnect {
let _ = self.reconnect().await;
should_reconnect = false;
}
match self
.embed_stream_once(texts, normalize, truncate, dimensions)
.await
{
Ok(results) => return Ok(results),
Err(TeiError::Grpc(status)) => {
match status.code() {
tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::PermissionDenied
| tonic::Code::Unauthenticated => {
return Err(TeiError::Grpc(status));
}
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted => {
should_reconnect = true;
last_error = Some(status.to_string());
}
_ => {
last_error = Some(status.to_string());
}
}
}
Err(e) => {
if self.is_channel_error(&e) {
should_reconnect = true;
}
last_error = Some(e.to_string());
}
}
attempt += 1;
if attempt < self.config.max_attempts {
let backoff = calculate_backoff_with_jitter(
self.config.retry_base_delay_ms,
self.config.retry_max_delay_ms,
attempt,
);
tokio::time::sleep(backoff).await;
}
}
Err(TeiError::RetryExhausted(
last_error.unwrap_or_else(|| "Unknown error".to_string()),
))
}
async fn embed_stream_once(
&self,
texts: &[&str],
normalize: bool,
truncate: bool,
dimensions: Option<u32>,
) -> Result<Vec<Vec<f32>>> {
let mut client = EmbedClient::new(self.get_channel().await);
let expected_count = texts.len();
let requests: Vec<EmbedRequest> = texts
.iter()
.map(|text| EmbedRequest {
inputs: (*text).to_string(),
truncate,
normalize,
truncation_direction: TruncationDirection::Right.into(),
prompt_name: None,
dimensions,
})
.collect();
let request_stream = tokio_stream::iter(requests);
let response = client.embed_stream(request_stream).await?;
let mut results = Vec::with_capacity(expected_count);
let mut stream = response.into_inner();
let mut expected_dims: Option<usize> = None;
use tokio_stream::StreamExt;
let mut index = 0usize;
while let Some(resp) = stream.next().await {
let resp = resp?;
if resp.embeddings.is_empty() {
return Err(TeiError::EmptyEmbedding { index });
}
let actual_dims = resp.embeddings.len();
match expected_dims {
None => {
expected_dims = Some(actual_dims);
}
Some(expected) if actual_dims != expected => {
return Err(TeiError::DimensionMismatch {
expected,
actual: actual_dims,
index,
});
}
Some(_) => {
}
}
results.push(resp.embeddings);
index += 1;
}
if results.len() != expected_count {
return Err(TeiError::PartialResponse {
expected: expected_count,
received: results.len(),
});
}
Ok(results)
}
pub async fn embed_batch(
&self,
texts: &[&str],
batch_size: Option<usize>,
) -> Result<Vec<Vec<f32>>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
let server_info = self.get_cached_server_info().await?;
let max_batch = batch_size
.unwrap_or(server_info.max_client_batch_size as usize)
.min(server_info.max_client_batch_size as usize);
let token_counts = self.count_tokens_batch(texts).await?;
let token_budget = self.config.batch_token_budget;
let batches = Self::build_token_batches(texts, &token_counts, token_budget, max_batch);
let mut all_results = Vec::with_capacity(texts.len());
for batch in batches {
let batch_results = self.embed(&batch).await?;
all_results.extend(batch_results);
}
Ok(all_results)
}
fn build_token_batches<'a>(
texts: &[&'a str],
token_counts: &[usize],
token_budget: usize,
max_batch: usize,
) -> Vec<Vec<&'a str>> {
let mut batches = Vec::new();
let mut current_batch = Vec::new();
let mut current_tokens = 0usize;
for (text, &count) in texts.iter().zip(token_counts.iter()) {
if !current_batch.is_empty()
&& (current_tokens + count > token_budget || current_batch.len() >= max_batch)
{
batches.push(std::mem::take(&mut current_batch));
current_tokens = 0;
}
current_batch.push(*text);
current_tokens += count;
}
if !current_batch.is_empty() {
batches.push(current_batch);
}
batches
}
pub async fn embed_sparse(&self, text: &str) -> Result<SparseEmbedding> {
let results = self.embed_sparse_batch(&[text]).await?;
results
.into_iter()
.next()
.ok_or(TeiError::EmptyResponse)
}
pub async fn embed_sparse_single(
&self,
text: &str,
truncate: bool,
) -> Result<(SparseEmbedding, EmbedMetadata)> {
self.ensure_not_closed()?;
let request = tei_proto::EmbedSparseRequest {
inputs: text.to_string(),
truncate,
truncation_direction: TruncationDirection::Right.into(),
prompt_name: None,
};
let response = self
.with_retry_reconnect(|channel| {
let req = request.clone();
async move {
let mut client = EmbedClient::new(channel);
client.embed_sparse(req).await
}
})
.await?;
let inner = response.into_inner();
let values: Vec<SparseValue> = inner
.sparse_embeddings
.into_iter()
.map(|sv| SparseValue {
index: sv.index,
value: sv.value,
})
.collect();
let metadata = inner
.metadata
.map(|m| EmbedMetadata {
compute_chars: m.compute_chars,
compute_tokens: m.compute_tokens,
total_time_ns: m.total_time_ns,
tokenization_time_ns: m.tokenization_time_ns,
queue_time_ns: m.queue_time_ns,
inference_time_ns: m.inference_time_ns,
})
.unwrap_or_default();
Ok((SparseEmbedding { values }, metadata))
}
pub async fn embed_sparse_batch(&self, texts: &[&str]) -> Result<Vec<SparseEmbedding>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
self.embed_sparse_stream_with_retry(texts, true).await
}
pub async fn embed_sparse_with_options(
&self,
texts: &[&str],
truncate: bool,
) -> Result<Vec<SparseEmbedding>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
self.embed_sparse_stream_with_retry(texts, truncate).await
}
async fn embed_sparse_stream_with_retry(
&self,
texts: &[&str],
truncate: bool,
) -> Result<Vec<SparseEmbedding>> {
let mut attempt = 0;
let mut last_error = None;
let mut should_reconnect = false;
while attempt < self.config.max_attempts {
if should_reconnect {
let _ = self.reconnect().await;
should_reconnect = false;
}
match self.embed_sparse_stream_once(texts, truncate).await {
Ok(results) => return Ok(results),
Err(TeiError::Grpc(status)) => {
match status.code() {
tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::PermissionDenied
| tonic::Code::Unauthenticated => {
return Err(TeiError::Grpc(status));
}
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted => {
should_reconnect = true;
last_error = Some(status.to_string());
}
_ => {
last_error = Some(status.to_string());
}
}
}
Err(e) => {
if self.is_channel_error(&e) {
should_reconnect = true;
}
last_error = Some(e.to_string());
}
}
attempt += 1;
if attempt < self.config.max_attempts {
let backoff = calculate_backoff_with_jitter(
self.config.retry_base_delay_ms,
self.config.retry_max_delay_ms,
attempt,
);
tokio::time::sleep(backoff).await;
}
}
Err(TeiError::RetryExhausted(
last_error.unwrap_or_else(|| "Unknown error".to_string()),
))
}
async fn embed_sparse_stream_once(
&self,
texts: &[&str],
truncate: bool,
) -> Result<Vec<SparseEmbedding>> {
let mut client = EmbedClient::new(self.get_channel().await);
let expected_count = texts.len();
let requests: Vec<tei_proto::EmbedSparseRequest> = texts
.iter()
.map(|text| tei_proto::EmbedSparseRequest {
inputs: (*text).to_string(),
truncate,
truncation_direction: TruncationDirection::Right.into(),
prompt_name: None,
})
.collect();
let request_stream = tokio_stream::iter(requests);
let response = client.embed_sparse_stream(request_stream).await?;
let mut results = Vec::with_capacity(expected_count);
let mut stream = response.into_inner();
use tokio_stream::StreamExt;
while let Some(resp) = stream.next().await {
let resp = resp?;
let values: Vec<SparseValue> = resp
.sparse_embeddings
.into_iter()
.map(|sv| SparseValue {
index: sv.index,
value: sv.value,
})
.collect();
results.push(SparseEmbedding { values });
}
if results.len() != expected_count {
return Err(TeiError::PartialResponse {
expected: expected_count,
received: results.len(),
});
}
Ok(results)
}
pub async fn tokenize(&self, text: &str, add_special_tokens: bool) -> Result<Vec<Token>> {
self.ensure_not_closed()?;
let client = TokenizeClient::new(self.get_channel().await);
let request = EncodeRequest {
inputs: text.to_string(),
add_special_tokens,
prompt_name: None,
};
let response = self
.with_retry(|| async {
let mut client = client.clone();
client.tokenize(request.clone()).await
})
.await?;
let tokens = response
.into_inner()
.tokens
.into_iter()
.map(|t| Token {
id: t.id,
text: t.text,
special: t.special,
start: t.start,
stop: t.stop,
})
.collect();
Ok(tokens)
}
pub async fn count_tokens(&self, text: &str) -> Result<usize> {
let tokens = self.tokenize(text, false).await?;
Ok(tokens.len())
}
pub async fn count_tokens_batch(&self, texts: &[&str]) -> Result<Vec<usize>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
self.count_tokens_batch_with_retry(texts).await
}
async fn count_tokens_batch_with_retry(&self, texts: &[&str]) -> Result<Vec<usize>> {
let mut attempt = 0;
let mut last_error = None;
let mut should_reconnect = false;
while attempt < self.config.max_attempts {
if should_reconnect {
let _ = self.reconnect().await;
should_reconnect = false;
}
match self.count_tokens_batch_once(texts).await {
Ok(counts) => return Ok(counts),
Err(TeiError::Grpc(status)) => match status.code() {
tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::PermissionDenied
| tonic::Code::Unauthenticated => {
return Err(TeiError::Grpc(status));
}
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted => {
should_reconnect = true;
last_error = Some(status.to_string());
}
_ => {
last_error = Some(status.to_string());
}
},
Err(e) => {
if self.is_channel_error(&e) {
should_reconnect = true;
}
last_error = Some(e.to_string());
}
}
attempt += 1;
if attempt < self.config.max_attempts {
let backoff = calculate_backoff_with_jitter(
self.config.retry_base_delay_ms,
self.config.retry_max_delay_ms,
attempt,
);
tokio::time::sleep(backoff).await;
}
}
Err(TeiError::RetryExhausted(
last_error.unwrap_or_else(|| "Unknown error".to_string()),
))
}
async fn count_tokens_batch_once(&self, texts: &[&str]) -> Result<Vec<usize>> {
let mut client = TokenizeClient::new(self.get_channel().await);
let expected_count = texts.len();
let requests: Vec<EncodeRequest> = texts
.iter()
.map(|text| EncodeRequest {
inputs: (*text).to_string(),
add_special_tokens: false,
prompt_name: None,
})
.collect();
let request_stream = tokio_stream::iter(requests);
let response = client.tokenize_stream(request_stream).await?;
let mut counts = Vec::with_capacity(expected_count);
let mut stream = response.into_inner();
use tokio_stream::StreamExt;
while let Some(resp) = stream.next().await {
let resp = resp?;
counts.push(resp.tokens.len());
}
if counts.len() != expected_count {
return Err(TeiError::PartialResponse {
expected: expected_count,
received: counts.len(),
});
}
Ok(counts)
}
pub async fn decode(&self, token_ids: &[u32]) -> Result<String> {
self.decode_with_options(token_ids, false).await
}
pub async fn decode_with_options(
&self,
token_ids: &[u32],
skip_special_tokens: bool,
) -> Result<String> {
self.ensure_not_closed()?;
let client = TokenizeClient::new(self.get_channel().await);
let request = DecodeRequest {
ids: token_ids.to_vec(),
skip_special_tokens,
};
let response = self
.with_retry(|| async {
let mut client = client.clone();
client.decode(request.clone()).await
})
.await?;
Ok(response.into_inner().text)
}
pub async fn decode_batch(
&self,
token_id_batches: &[&[u32]],
skip_special_tokens: bool,
) -> Result<Vec<String>> {
self.ensure_not_closed()?;
if token_id_batches.is_empty() {
return Ok(Vec::new());
}
self.decode_batch_with_retry(token_id_batches, skip_special_tokens)
.await
}
async fn decode_batch_with_retry(
&self,
token_id_batches: &[&[u32]],
skip_special_tokens: bool,
) -> Result<Vec<String>> {
let mut attempt = 0;
let mut last_error = None;
let mut should_reconnect = false;
while attempt < self.config.max_attempts {
if should_reconnect {
let _ = self.reconnect().await;
should_reconnect = false;
}
match self
.decode_batch_once(token_id_batches, skip_special_tokens)
.await
{
Ok(results) => return Ok(results),
Err(TeiError::Grpc(status)) => match status.code() {
tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::PermissionDenied
| tonic::Code::Unauthenticated => {
return Err(TeiError::Grpc(status));
}
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted => {
should_reconnect = true;
last_error = Some(status.to_string());
}
_ => {
last_error = Some(status.to_string());
}
},
Err(e) => {
if self.is_channel_error(&e) {
should_reconnect = true;
}
last_error = Some(e.to_string());
}
}
attempt += 1;
if attempt < self.config.max_attempts {
let backoff = calculate_backoff_with_jitter(
self.config.retry_base_delay_ms,
self.config.retry_max_delay_ms,
attempt,
);
tokio::time::sleep(backoff).await;
}
}
Err(TeiError::RetryExhausted(
last_error.unwrap_or_else(|| "Unknown error".to_string()),
))
}
async fn decode_batch_once(
&self,
token_id_batches: &[&[u32]],
skip_special_tokens: bool,
) -> Result<Vec<String>> {
let mut client = TokenizeClient::new(self.get_channel().await);
let expected_count = token_id_batches.len();
let requests: Vec<DecodeRequest> = token_id_batches
.iter()
.map(|ids| DecodeRequest {
ids: ids.to_vec(),
skip_special_tokens,
})
.collect();
let request_stream = tokio_stream::iter(requests);
let response = client.decode_stream(request_stream).await?;
let mut results = Vec::with_capacity(expected_count);
let mut stream = response.into_inner();
use tokio_stream::StreamExt;
while let Some(resp) = stream.next().await {
let resp = resp?;
results.push(resp.text);
}
if results.len() != expected_count {
return Err(TeiError::PartialResponse {
expected: expected_count,
received: results.len(),
});
}
Ok(results)
}
pub async fn rerank(
&self,
query: &str,
texts: &[String],
truncate: bool,
return_text: bool,
) -> Result<Vec<RerankResult>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
if query.is_empty() {
return Err(TeiError::Config("query cannot be empty".to_string()));
}
let client = RerankClient::new(self.get_channel().await);
let request = RerankRequest {
query: query.to_string(),
texts: texts.to_vec(),
truncate,
raw_scores: false, return_text,
truncation_direction: TruncationDirection::Right.into(),
};
let response = self
.with_retry(|| async {
let mut client = client.clone();
client.rerank(request.clone()).await
})
.await?;
let inner = response.into_inner();
let results: Vec<RerankResult> = inner
.ranks
.into_iter()
.map(|rank| RerankResult {
index: rank.index as usize,
score: rank.score,
text: rank.text,
})
.collect();
Ok(results)
}
pub async fn rerank_with_options(
&self,
query: &str,
texts: &[String],
truncate: bool,
return_text: bool,
raw_scores: bool,
) -> Result<Vec<RerankResult>> {
self.ensure_not_closed()?;
if texts.is_empty() {
return Ok(Vec::new());
}
if query.is_empty() {
return Err(TeiError::Config("query cannot be empty".to_string()));
}
let client = RerankClient::new(self.get_channel().await);
let request = RerankRequest {
query: query.to_string(),
texts: texts.to_vec(),
truncate,
raw_scores,
return_text,
truncation_direction: TruncationDirection::Right.into(),
};
let response = self
.with_retry(|| async {
let mut client = client.clone();
client.rerank(request.clone()).await
})
.await?;
let inner = response.into_inner();
let results: Vec<RerankResult> = inner
.ranks
.into_iter()
.map(|rank| RerankResult {
index: rank.index as usize,
score: rank.score,
text: rank.text,
})
.collect();
Ok(results)
}
pub async fn is_available(&self) -> bool {
self.info().await.is_ok()
}
async fn get_cached_server_info(&self) -> Result<ServerInfo> {
self.server_info
.get_or_try_init(|| self.info())
.await
.cloned()
}
async fn with_retry<F, Fut, T>(&self, operation: F) -> Result<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, tonic::Status>>,
{
let mut attempt = 0;
let mut last_error = None;
let mut should_reconnect = false;
while attempt < self.config.max_attempts {
if should_reconnect {
let _ = self.reconnect().await;
should_reconnect = false;
}
match operation().await {
Ok(result) => return Ok(result),
Err(status) => {
match status.code() {
tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::PermissionDenied
| tonic::Code::Unauthenticated => {
return Err(TeiError::Grpc(status));
}
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted => {
should_reconnect = true;
last_error = Some(status);
}
_ => {
last_error = Some(status);
}
}
}
}
attempt += 1;
if attempt < self.config.max_attempts {
let backoff = calculate_backoff_with_jitter(
self.config.retry_base_delay_ms,
self.config.retry_max_delay_ms,
attempt,
);
tokio::time::sleep(backoff).await;
}
}
Err(TeiError::RetryExhausted(
last_error
.map(|e| e.to_string())
.unwrap_or_else(|| "Unknown error".to_string()),
))
}
async fn with_retry_reconnect<F, Fut, T>(&self, operation: F) -> Result<T>
where
F: Fn(Channel) -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, tonic::Status>>,
{
let mut attempt = 0;
let mut last_error = None;
let mut should_reconnect = false;
while attempt < self.config.max_attempts {
if should_reconnect {
let _ = self.reconnect().await;
should_reconnect = false;
}
let channel = self.get_channel().await;
match operation(channel).await {
Ok(result) => return Ok(result),
Err(status) => {
match status.code() {
tonic::Code::InvalidArgument
| tonic::Code::NotFound
| tonic::Code::AlreadyExists
| tonic::Code::PermissionDenied
| tonic::Code::Unauthenticated => {
return Err(TeiError::Grpc(status));
}
tonic::Code::Unavailable
| tonic::Code::Unknown
| tonic::Code::Internal
| tonic::Code::Aborted => {
should_reconnect = true;
last_error = Some(status);
}
_ => {
last_error = Some(status);
}
}
}
}
attempt += 1;
if attempt < self.config.max_attempts {
let backoff = calculate_backoff_with_jitter(
self.config.retry_base_delay_ms,
self.config.retry_max_delay_ms,
attempt,
);
tokio::time::sleep(backoff).await;
}
}
Err(TeiError::RetryExhausted(
last_error
.map(|e| e.to_string())
.unwrap_or_else(|| "Unknown error".to_string()),
))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_default() {
let config = TeiClientConfig::default();
assert_eq!(config.endpoint, "http://localhost:18080");
assert_eq!(config.timeout_secs, 120);
assert_eq!(config.max_attempts, 3);
}
#[test]
fn test_max_attempts_semantics() {
let config = TeiClientConfig {
max_attempts: 3,
..Default::default()
};
assert_eq!(config.max_attempts, 3);
assert_eq!(config.retry_count(), 2);
let no_retry_config = TeiClientConfig {
max_attempts: 1,
..Default::default()
};
assert_eq!(no_retry_config.retry_count(), 0);
let disabled_config = TeiClientConfig {
max_attempts: 0,
..Default::default()
};
assert_eq!(disabled_config.retry_count(), 0);
}
#[test]
fn test_model_type_from_i32() {
assert_eq!(ModelType::from(0), ModelType::Embedding);
assert_eq!(ModelType::from(1), ModelType::Classifier);
assert_eq!(ModelType::from(2), ModelType::Reranker);
assert_eq!(ModelType::from(99), ModelType::Unknown);
}
#[test]
fn test_model_type_display() {
assert_eq!(ModelType::Embedding.to_string(), "embedding");
assert_eq!(ModelType::Classifier.to_string(), "classifier");
assert_eq!(ModelType::Reranker.to_string(), "reranker");
assert_eq!(ModelType::Unknown.to_string(), "unknown");
}
#[test]
fn test_server_info_is_fp16() {
let info = ServerInfo {
version: "1.0.0".to_string(),
sha: None,
docker_label: None,
model_id: "test-model".to_string(),
model_sha: None,
model_dtype: "float16".to_string(),
model_type: ModelType::Embedding,
max_concurrent_requests: 100,
max_input_length: 512,
max_batch_tokens: 16384,
max_batch_requests: Some(32),
max_client_batch_size: 32,
tokenization_workers: 4,
};
assert!(info.is_fp16());
assert!(!info.is_bf16());
let info_f32 = ServerInfo {
model_dtype: "float32".to_string(),
..info.clone()
};
assert!(!info_f32.is_fp16());
let info_half = ServerInfo {
model_dtype: "half".to_string(),
..info.clone()
};
assert!(info_half.is_fp16());
}
#[test]
fn test_server_info_is_bf16() {
let info = ServerInfo {
version: "1.0.0".to_string(),
sha: None,
docker_label: None,
model_id: "test-model".to_string(),
model_sha: None,
model_dtype: "bfloat16".to_string(),
model_type: ModelType::Embedding,
max_concurrent_requests: 100,
max_input_length: 512,
max_batch_tokens: 16384,
max_batch_requests: None,
max_client_batch_size: 32,
tokenization_workers: 4,
};
assert!(info.is_bf16());
assert!(!info.is_fp16());
}
#[test]
fn test_server_info_description() {
let info = ServerInfo {
version: "1.2.3".to_string(),
sha: Some("abc123".to_string()),
docker_label: None,
model_id: "BAAI/bge-large-en-v1.5".to_string(),
model_sha: None,
model_dtype: "float16".to_string(),
model_type: ModelType::Embedding,
max_concurrent_requests: 100,
max_input_length: 512,
max_batch_tokens: 16384,
max_batch_requests: Some(32),
max_client_batch_size: 32,
tokenization_workers: 4,
};
let desc = info.description();
assert!(desc.contains("BAAI/bge-large-en-v1.5"));
assert!(desc.contains("1.2.3"));
assert!(desc.contains("float16"));
assert!(desc.contains("embedding"));
}
/// `effective_batch_size` prefers the server-side `max_batch_requests` limit
/// and otherwise uses `max_client_batch_size`.
#[test]
fn test_server_info_effective_batch_size() {
    let limited = ServerInfo {
        version: String::from("1.0.0"),
        sha: None,
        docker_label: None,
        model_id: String::from("test"),
        model_sha: None,
        model_dtype: String::from("float32"),
        model_type: ModelType::Embedding,
        max_concurrent_requests: 100,
        max_input_length: 512,
        max_batch_tokens: 16384,
        max_batch_requests: Some(16),
        max_client_batch_size: 32,
        tokenization_workers: 4,
    };
    assert_eq!(limited.effective_batch_size(), 16);

    let unlimited = ServerInfo {
        max_batch_requests: None,
        ..limited.clone()
    };
    assert_eq!(unlimited.effective_batch_size(), 32);
}
/// Round-trips a `ServerInfo` through JSON and checks that every field
/// survives, including `Option` fields in both the `Some` (`sha`, `model_sha`,
/// `max_batch_requests`) and `None` (`docker_label`) states.
#[test]
fn test_server_info_serde() {
    let info = ServerInfo {
        version: "1.0.0".to_string(),
        sha: Some("abc123".to_string()),
        docker_label: None,
        model_id: "test-model".to_string(),
        model_sha: Some("def456".to_string()),
        model_dtype: "float16".to_string(),
        model_type: ModelType::Embedding,
        max_concurrent_requests: 100,
        max_input_length: 512,
        max_batch_tokens: 16384,
        max_batch_requests: Some(32),
        max_client_batch_size: 32,
        tokenization_workers: 4,
    };
    let json = serde_json::to_string(&info).expect("serialize");
    let deserialized: ServerInfo = serde_json::from_str(&json).expect("deserialize");

    assert_eq!(deserialized.version, info.version);
    assert_eq!(deserialized.sha, info.sha);
    assert_eq!(deserialized.docker_label, info.docker_label);
    assert_eq!(deserialized.model_id, info.model_id);
    assert_eq!(deserialized.model_sha, info.model_sha);
    assert_eq!(deserialized.model_dtype, info.model_dtype);
    assert_eq!(deserialized.model_type, info.model_type);
    assert_eq!(deserialized.max_concurrent_requests, info.max_concurrent_requests);
    assert_eq!(deserialized.max_input_length, info.max_input_length);
    assert_eq!(deserialized.max_batch_tokens, info.max_batch_tokens);
    assert_eq!(deserialized.max_batch_requests, info.max_batch_requests);
    assert_eq!(deserialized.max_client_batch_size, info.max_client_batch_size);
    assert_eq!(deserialized.tokenization_workers, info.tokenization_workers);

    // `ServerInfo` does not derive `PartialEq`, so re-serializing serves as a
    // catch-all. A field added later that fails to round-trip changes the JSON.
    let reserialized = serde_json::to_string(&deserialized).expect("re-serialize");
    assert_eq!(reserialized, json);
}
/// Four 100-token texts under a 250-token budget split into two pairs.
#[test]
fn test_build_token_batches_simple() {
    let inputs = vec!["a", "b", "c", "d"];
    let counts = vec![100, 100, 100, 100];
    let result = TeiClient::build_token_batches(&inputs, &counts, 250, 100);
    assert_eq!(result, vec![vec!["a", "b"], vec!["c", "d"]]);
}
/// With an effectively unlimited token budget, the per-batch item cap of 2
/// drives the split, leaving a trailing singleton batch.
#[test]
fn test_build_token_batches_max_size() {
    let inputs = vec!["a", "b", "c", "d", "e"];
    let counts = vec![10; 5];
    let result = TeiClient::build_token_batches(&inputs, &counts, 10000, 2);
    assert_eq!(result, vec![vec!["a", "b"], vec!["c", "d"], vec!["e"]]);
}
/// No input texts yields no batches.
#[test]
fn test_build_token_batches_empty() {
    let no_texts: Vec<&str> = Vec::new();
    let no_counts: Vec<usize> = Vec::new();
    assert!(TeiClient::build_token_batches(&no_texts, &no_counts, 1000, 10).is_empty());
}
/// A single text larger than the token budget still gets its own batch
/// rather than being dropped.
#[test]
fn test_build_token_batches_single_large() {
    let inputs = vec!["large"];
    let counts = vec![5000];
    let result = TeiClient::build_token_batches(&inputs, &counts, 1000, 10);
    assert_eq!(result, vec![vec!["large"]]);
}
/// `DecodeRequest` stores its ids verbatim and carries the
/// `skip_special_tokens` flag in both states.
#[test]
fn test_decode_request_construction() {
    let ids: Vec<u32> = vec![101, 7592, 2088, 102];
    let make = |skip: bool| super::tei_proto::DecodeRequest {
        ids: ids.clone(),
        skip_special_tokens: skip,
    };

    let plain = make(false);
    assert_eq!(plain.ids, ids);
    assert!(!plain.skip_special_tokens);

    let skipping = make(true);
    assert!(skipping.skip_special_tokens);
}
/// A `DecodeRequest` may be built with no token ids.
#[test]
fn test_decode_request_empty_ids() {
    let empty = super::tei_proto::DecodeRequest {
        skip_special_tokens: false,
        ids: vec![],
    };
    assert_eq!(empty.ids.len(), 0);
}
/// Builds a `RerankRequest` with every field set and checks each one,
/// including the `truncation_direction`, which the original test left unchecked.
#[test]
fn test_rerank_request_construction() {
    let query = "programming languages".to_string();
    let texts = vec!["Rust is fast".to_string(), "Python is easy".to_string()];
    let request = super::tei_proto::RerankRequest {
        query: query.clone(),
        texts: texts.clone(),
        truncate: true,
        raw_scores: false,
        return_text: true,
        truncation_direction: TruncationDirection::Right.into(),
    };
    assert_eq!(request.query, query);
    assert_eq!(request.texts, texts);
    assert!(request.truncate);
    assert!(!request.raw_scores);
    assert!(request.return_text);
    // The field holds the enum's wire value, set via the same `From` conversion
    // used here.
    assert_eq!(
        request.truncation_direction,
        i32::from(TruncationDirection::Right)
    );
}
/// The inverse flag combination of `test_rerank_request_construction`:
/// no truncation, raw scores, no text, left truncation direction.
#[test]
fn test_rerank_request_with_raw_scores() {
    let request = super::tei_proto::RerankRequest {
        query: "test query".to_string(),
        texts: vec!["text1".to_string()],
        truncate: false,
        raw_scores: true,
        return_text: false,
        truncation_direction: TruncationDirection::Left.into(),
    };
    assert!(!request.truncate);
    assert!(request.raw_scores);
    assert!(!request.return_text);
    // The original test set `Left` without ever checking it.
    assert_eq!(
        request.truncation_direction,
        i32::from(TruncationDirection::Left)
    );
}
/// `RerankResult` exposes its index, score and optional text as constructed.
#[test]
fn test_rerank_result_struct() {
    let with_text = RerankResult {
        index: 5,
        score: 0.95,
        text: Some("sample text".to_string()),
    };
    assert_eq!(with_text.index, 5);
    assert!((with_text.score - 0.95).abs() < f32::EPSILON);
    assert_eq!(with_text.text, Some("sample text".to_string()));

    let without_text = RerankResult {
        index: 0,
        score: 0.5,
        text: None,
    };
    assert_eq!(without_text.text, None);
}
/// Cloning a `RerankResult` copies every field.
#[test]
fn test_rerank_result_clone() {
    let source = RerankResult {
        index: 10,
        score: 0.8,
        text: Some("cloneable".to_string()),
    };
    let RerankResult { index, score, text } = source.clone();
    assert_eq!(index, source.index);
    assert_eq!(score, source.score);
    assert_eq!(text, source.text);
}
/// Every counter and timing in a default `EmbedMetadata` is zero.
#[test]
fn test_embed_metadata_default() {
    let EmbedMetadata {
        compute_chars,
        compute_tokens,
        total_time_ns,
        tokenization_time_ns,
        queue_time_ns,
        inference_time_ns,
    } = EmbedMetadata::default();
    assert_eq!(compute_chars, 0);
    assert_eq!(compute_tokens, 0);
    assert_eq!(total_time_ns, 0);
    assert_eq!(tokenization_time_ns, 0);
    assert_eq!(queue_time_ns, 0);
    assert_eq!(inference_time_ns, 0);
}
/// `EmbedMetadata` keeps each value it was constructed with.
#[test]
fn test_embed_metadata_construction() {
    let EmbedMetadata {
        compute_chars,
        compute_tokens,
        total_time_ns,
        tokenization_time_ns,
        queue_time_ns,
        inference_time_ns,
    } = EmbedMetadata {
        compute_chars: 100,
        compute_tokens: 25,
        total_time_ns: 50_000_000,
        tokenization_time_ns: 5_000_000,
        queue_time_ns: 1_000_000,
        inference_time_ns: 44_000_000,
    };
    assert_eq!(compute_chars, 100);
    assert_eq!(compute_tokens, 25);
    assert_eq!(total_time_ns, 50_000_000);
    assert_eq!(tokenization_time_ns, 5_000_000);
    assert_eq!(queue_time_ns, 1_000_000);
    assert_eq!(inference_time_ns, 44_000_000);
}
/// Cloning `EmbedMetadata` copies every counter and timing.
#[test]
fn test_embed_metadata_clone() {
    let source = EmbedMetadata {
        compute_chars: 42,
        compute_tokens: 10,
        total_time_ns: 100_000,
        tokenization_time_ns: 10_000,
        queue_time_ns: 5_000,
        inference_time_ns: 85_000,
    };
    let EmbedMetadata {
        compute_chars,
        compute_tokens,
        total_time_ns,
        tokenization_time_ns,
        queue_time_ns,
        inference_time_ns,
    } = source.clone();
    assert_eq!(compute_chars, source.compute_chars);
    assert_eq!(compute_tokens, source.compute_tokens);
    assert_eq!(total_time_ns, source.total_time_ns);
    assert_eq!(tokenization_time_ns, source.tokenization_time_ns);
    assert_eq!(queue_time_ns, source.queue_time_ns);
    assert_eq!(inference_time_ns, source.inference_time_ns);
}
/// The `PartialResponse` display message names both counts and labels them.
#[test]
fn test_partial_response_error_construction() {
    let rendered = TeiError::PartialResponse {
        expected: 10,
        received: 7,
    }
    .to_string();
    for fragment in ["10", "7", "expected", "received"] {
        assert!(rendered.contains(fragment));
    }
}
/// The derived `Debug` output of `PartialResponse` names the variant and both
/// field values.
#[test]
fn test_partial_response_error_debug() {
    let err = TeiError::PartialResponse {
        expected: 100,
        received: 0,
    };
    let debug = format!("{err:?}");
    assert!(debug.contains("PartialResponse"));
    // Match on `field: value` pairs. A bare `contains("0")` is vacuous because
    // "100" already contains a '0', so it never checked `received`.
    assert!(debug.contains("expected: 100"), "unexpected Debug output: {debug}");
    assert!(debug.contains("received: 0"), "unexpected Debug output: {debug}");
}
/// The `EmptyEmbedding` display message mentions emptiness and the index.
#[test]
fn test_empty_embedding_error_construction() {
    let rendered = TeiError::EmptyEmbedding { index: 5 }.to_string();
    assert!(rendered.contains('5'));
    assert!(rendered.to_lowercase().contains("empty"));
    assert!(rendered.contains("index"));
}
/// The `Debug` output of `EmptyEmbedding` names the variant and its index.
#[test]
fn test_empty_embedding_error_debug() {
    let rendered = format!("{:?}", TeiError::EmptyEmbedding { index: 42 });
    assert!(["EmptyEmbedding", "42"]
        .iter()
        .all(|part| rendered.contains(part)));
}
/// The `DimensionMismatch` display message includes both dimensions, the
/// index, and the word "mismatch".
#[test]
fn test_dimension_mismatch_error_construction() {
    let rendered = TeiError::DimensionMismatch {
        expected: 768,
        actual: 512,
        index: 3,
    }
    .to_string();
    let checks = [
        ("768", "Should contain expected dimension"),
        ("512", "Should contain actual dimension"),
        ("3", "Should contain index"),
    ];
    for (fragment, reason) in checks {
        assert!(rendered.contains(fragment), "{}", reason);
    }
    assert!(
        rendered.to_lowercase().contains("mismatch"),
        "Should indicate mismatch"
    );
}
/// The `Debug` output of `DimensionMismatch` names the variant and all three
/// field values.
#[test]
fn test_dimension_mismatch_error_debug() {
    let rendered = format!(
        "{:?}",
        TeiError::DimensionMismatch {
            expected: 1024,
            actual: 256,
            index: 7,
        }
    );
    for part in ["DimensionMismatch", "1024", "256", "7"] {
        assert!(rendered.contains(part));
    }
}
/// Despite its name, this test cannot clone anything. `TeiError` derives only
/// `Error` and `Debug`, not `Clone`. Instead it checks that the variant keeps
/// the field values it was built with and that it renders via `Debug`.
#[test]
fn test_dimension_mismatch_error_clone() {
    let original = TeiError::DimensionMismatch {
        expected: 384,
        actual: 768,
        index: 0,
    };
    // Previously only the variant name was checked. Pin the fields too.
    assert!(matches!(
        original,
        TeiError::DimensionMismatch {
            expected: 384,
            actual: 768,
            index: 0,
        }
    ));
    let debug = format!("{original:?}");
    assert!(debug.contains("DimensionMismatch"));
}
/// Compile-time check that `TeiError` can cross and be shared between threads.
#[test]
fn test_error_is_send_sync() {
    fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<TeiError>();
}
/// Every jittered backoff for attempt 2 falls within
/// `[base * 2^attempt, base * 2^attempt * 1.25]`. The test only warns, and
/// does not fail, if all samples happen to be identical.
#[test]
fn test_backoff_with_jitter_produces_random_values() {
    let base_delay = 100u64;
    let max_delay = 5000u64;
    let attempt = 2u32;

    let samples: Vec<Duration> = (0..10)
        .map(|_| calculate_backoff_with_jitter(base_delay, max_delay, attempt))
        .collect();
    let has_variation = samples.windows(2).any(|pair| pair[0] != pair[1]);

    let expected_base = base_delay * (1 << attempt);
    let max_with_jitter = expected_base + expected_base / 4;

    for ms in samples.iter().map(|d| d.as_millis() as u64) {
        assert!(ms >= expected_base, "Backoff should be at least {expected_base}ms, got {ms}ms");
        assert!(ms <= max_with_jitter, "Backoff should be at most {max_with_jitter}ms, got {ms}ms");
    }

    if !has_variation {
        eprintln!("Warning: No variation found in 10 jitter samples - extremely unlikely but possible");
    }
}
/// Backoff doubles per attempt, each value being within +25% jitter of
/// `base * 2^attempt`.
#[test]
fn test_backoff_exponential_growth() {
    let base_delay = 100u64;
    let max_delay = 10000u64;
    let observed: Vec<(u32, u128)> = (1..=3u32)
        .map(|attempt| {
            let ms = calculate_backoff_with_jitter(base_delay, max_delay, attempt).as_millis();
            (attempt, ms)
        })
        .collect();
    let bounds: [(u128, u128); 3] = [(200, 250), (400, 500), (800, 1000)];
    for (&(attempt, ms), &(lo, hi)) in observed.iter().zip(bounds.iter()) {
        assert!(ms >= lo, "Attempt {} should be >= {}ms", attempt, lo);
        assert!(ms <= hi, "Attempt {} should be <= {}ms", attempt, hi);
    }
}
/// At a high attempt count the delay is capped at `max_delay`, plus at most
/// 25% jitter.
#[test]
fn test_backoff_respects_max_delay() {
    let capped_ms = calculate_backoff_with_jitter(1000u64, 2000u64, 10).as_millis();
    assert!(capped_ms <= 2500, "Backoff should be capped at max_delay + jitter");
    assert!(capped_ms >= 2000, "Backoff should be at least max_delay");
}
/// A zero base delay yields zero backoff regardless of attempt.
#[test]
fn test_backoff_handles_zero_base_delay() {
    let delay = calculate_backoff_with_jitter(0, 5000, 5);
    assert_eq!(delay.as_millis(), 0);
}
/// `SparseValue` keeps the index and value it was built with.
#[test]
fn test_sparse_value_construction() {
    let SparseValue { index, value } = SparseValue {
        index: 1234,
        value: 0.567,
    };
    assert_eq!(index, 1234);
    assert!((value - 0.567).abs() < f32::EPSILON);
}
/// `SparseValue` equality considers both index and value.
#[test]
fn test_sparse_value_equality() {
    let sv = |index, value| SparseValue { index, value };
    let base = sv(100, 0.5);
    assert_eq!(base, sv(100, 0.5));
    assert_ne!(base, sv(101, 0.5));
    assert_ne!(base, sv(100, 0.6));
}
/// A cloned `SparseValue` compares equal to its source.
#[test]
fn test_sparse_value_clone() {
    let source = SparseValue {
        index: 42,
        value: 0.99,
    };
    assert_eq!(source, source.clone());
}
/// An embedding with no values is empty and has length zero.
#[test]
fn test_sparse_embedding_empty() {
    let empty = SparseEmbedding { values: Vec::new() };
    assert!(empty.is_empty());
    assert_eq!(empty.len(), 0);
}
/// `len` and `is_empty` reflect the number of stored sparse values.
#[test]
fn test_sparse_embedding_with_values() {
    let values = [(1, 0.5), (5, 0.3), (10, 0.8)]
        .iter()
        .map(|&(index, value)| SparseValue { index, value })
        .collect();
    let embedding = SparseEmbedding { values };
    assert!(!embedding.is_empty());
    assert_eq!(embedding.len(), 3);
}
/// Self dot product is the sum of squares: 2² + 3² + 4² = 29.
#[test]
fn test_sparse_embedding_dot_product_identical() {
    let values = [(1, 2.0), (3, 3.0), (5, 4.0)]
        .iter()
        .map(|&(index, value)| SparseValue { index, value })
        .collect();
    let embedding = SparseEmbedding { values };
    let self_dot = embedding.dot(&embedding);
    assert!((self_dot - 29.0).abs() < f32::EPSILON);
}
/// Embeddings with disjoint index sets have a zero dot product.
#[test]
fn test_sparse_embedding_dot_product_orthogonal() {
    let ones_at = |indices: &[_]| SparseEmbedding {
        values: indices
            .iter()
            .map(|&index| SparseValue { index, value: 1.0 })
            .collect(),
    };
    let odd = ones_at(&[1, 3]);
    let even = ones_at(&[2, 4]);
    assert!(odd.dot(&even).abs() < f32::EPSILON);
}
/// Only shared indices contribute: 3·2 (index 3) + 4·1.5 (index 5) = 12.
#[test]
fn test_sparse_embedding_dot_product_partial_overlap() {
    let left = SparseEmbedding {
        values: [(1, 2.0), (3, 3.0), (5, 4.0)]
            .iter()
            .map(|&(index, value)| SparseValue { index, value })
            .collect(),
    };
    let right = SparseEmbedding {
        values: [(2, 1.0), (3, 2.0), (5, 1.5)]
            .iter()
            .map(|&(index, value)| SparseValue { index, value })
            .collect(),
    };
    let product = left.dot(&right);
    assert!((product - 12.0).abs() < f32::EPSILON);
}
/// Dotting with an empty embedding yields zero in either operand order.
#[test]
fn test_sparse_embedding_dot_product_empty() {
    let empty = SparseEmbedding { values: Vec::new() };
    let single = SparseEmbedding {
        values: vec![SparseValue {
            index: 1,
            value: 1.0,
        }],
    };
    for product in [empty.dot(&single), single.dot(&empty)] {
        assert!(product.abs() < f32::EPSILON);
    }
}
/// `sort_by_index` orders the stored values by ascending index.
#[test]
fn test_sparse_embedding_sort_by_index() {
    let mut embedding = SparseEmbedding {
        values: [(10, 0.1), (2, 0.2), (5, 0.5)]
            .iter()
            .map(|&(index, value)| SparseValue { index, value })
            .collect(),
    };
    embedding.sort_by_index();
    let ordered: Vec<_> = embedding.values.iter().map(|v| v.index).collect();
    assert_eq!(ordered, vec![2, 5, 10]);
}
/// A cloned `SparseEmbedding` has the same length and element-wise equal
/// values.
#[test]
fn test_sparse_embedding_clone() {
    let source = SparseEmbedding {
        values: [(1, 0.5), (2, 0.7)]
            .iter()
            .map(|&(index, value)| SparseValue { index, value })
            .collect(),
    };
    let copy = source.clone();
    assert_eq!(source.len(), copy.len());
    source
        .values
        .iter()
        .zip(&copy.values)
        .for_each(|(a, b)| assert_eq!(a, b));
}
/// `EmbedSparseRequest` keeps its input, truncation flag, truncation direction
/// and (absent) prompt name.
#[test]
fn test_embed_sparse_request_construction() {
    let request = super::tei_proto::EmbedSparseRequest {
        inputs: "test input".to_string(),
        truncate: true,
        truncation_direction: TruncationDirection::Right.into(),
        prompt_name: None,
    };
    assert_eq!(request.inputs, "test input");
    assert!(request.truncate);
    // Previously unchecked. The field holds the enum's wire value.
    assert_eq!(
        request.truncation_direction,
        i32::from(TruncationDirection::Right)
    );
    assert!(request.prompt_name.is_none());
}
/// `EmbedSparseRequest` with a prompt name, no truncation, and left
/// truncation direction keeps every field as set.
#[test]
fn test_embed_sparse_request_with_prompt_name() {
    let request = super::tei_proto::EmbedSparseRequest {
        inputs: "document text".to_string(),
        truncate: false,
        truncation_direction: TruncationDirection::Left.into(),
        prompt_name: Some("passage".to_string()),
    };
    assert_eq!(request.inputs, "document text");
    assert!(!request.truncate);
    // Previously unchecked. The field holds the enum's wire value.
    assert_eq!(
        request.truncation_direction,
        i32::from(TruncationDirection::Left)
    );
    assert_eq!(request.prompt_name, Some("passage".to_string()));
}
/// Compile-time check that the sparse embedding types are thread-safe.
#[test]
fn test_sparse_embedding_is_send_sync() {
    fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<SparseEmbedding>();
    require_send_sync::<SparseValue>();
}
/// `Debug` for `SparseEmbedding` names the type and shows contained indices.
#[test]
fn test_sparse_embedding_debug() {
    let single = SparseValue {
        index: 42,
        value: 0.123,
    };
    let rendered = format!("{:?}", SparseEmbedding { values: vec![single] });
    assert!(rendered.contains("SparseEmbedding"));
    assert!(rendered.contains("42"));
}
/// The `ClientClosed` display message mentions "closed".
#[test]
fn test_client_closed_error_construction() {
    let lowered = TeiError::ClientClosed.to_string().to_lowercase();
    assert!(lowered.contains("closed"));
}
/// The `Debug` output of `ClientClosed` is the variant name.
#[test]
fn test_client_closed_error_debug() {
    assert!(format!("{:?}", TeiError::ClientClosed).contains("ClientClosed"));
}
#[tokio::test]
async fn test_client_is_closed_initially_false() {
let config = TeiClientConfig {
endpoint: "http://localhost:1".to_string(),
..Default::default()
};
let channel = Channel::from_static("http://[::1]:1").connect_lazy();
let client = TeiClient {
config,
channel: RwLock::new(channel),
server_info: OnceCell::const_new(),
closed: AtomicBool::new(false),
};
assert!(!client.is_closed());
}
#[tokio::test]
async fn test_client_close_sets_closed_flag() {
let config = TeiClientConfig {
endpoint: "http://localhost:1".to_string(),
..Default::default()
};
let channel = Channel::from_static("http://[::1]:1").connect_lazy();
let client = TeiClient {
config,
channel: RwLock::new(channel),
server_info: OnceCell::const_new(),
closed: AtomicBool::new(false),
};
assert!(!client.is_closed());
let result = client.close().await;
assert!(result.is_ok());
assert!(client.is_closed());
}
#[tokio::test]
async fn test_client_close_is_idempotent() {
let config = TeiClientConfig {
endpoint: "http://localhost:1".to_string(),
..Default::default()
};
let channel = Channel::from_static("http://[::1]:1").connect_lazy();
let client = TeiClient {
config,
channel: RwLock::new(channel),
server_info: OnceCell::const_new(),
closed: AtomicBool::new(false),
};
let result1 = client.close().await;
let result2 = client.close().await;
let result3 = client.close().await;
assert!(result1.is_ok());
assert!(result2.is_ok());
assert!(result3.is_ok());
assert!(client.is_closed());
}
#[tokio::test]
async fn test_ensure_not_closed_returns_error_when_closed() {
let config = TeiClientConfig {
endpoint: "http://localhost:1".to_string(),
..Default::default()
};
let channel = Channel::from_static("http://[::1]:1").connect_lazy();
let client = TeiClient {
config,
channel: RwLock::new(channel),
server_info: OnceCell::const_new(),
closed: AtomicBool::new(false),
};
assert!(client.ensure_not_closed().is_ok());
client.close().await.unwrap();
let result = client.ensure_not_closed();
assert!(matches!(result, Err(TeiError::ClientClosed)));
}
/// Compile-time check that `TeiClient` can be shared across tasks/threads.
#[test]
fn test_tei_client_is_send_sync() {
    fn require_send_sync<T: Send + Sync>() {}
    require_send_sync::<TeiClient>();
}
}