use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::shared::{Headers, ProviderMetadata, ProviderOptions, RequestInfo, ResponseInfo};
/// Backend abstraction for producing embeddings.
///
/// Implementors must be `Send + Sync + Debug`. Only [`provider`](Self::provider),
/// [`model_id`](Self::model_id) and [`do_embed`](Self::do_embed) must be written;
/// every other method has a default body.
#[async_trait]
pub trait EmbeddingModel: Send + Sync + std::fmt::Debug {
    // ---- required methods ----

    /// Provider identifier string for this model.
    fn provider(&self) -> &str;

    /// Model identifier string for this model.
    fn model_id(&self) -> &str;

    /// Executes one embedding call described by `options`.
    ///
    /// # Errors
    /// Returns the crate's error type when the implementor's call fails.
    async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult>;

    // ---- provided methods ----

    /// Version tag of this interface; the default returns `"v4"`.
    fn specification_version(&self) -> &'static str {
        "v4"
    }

    /// Maximum number of values accepted per `do_embed` call.
    ///
    /// The default returns `None` (NOTE(review): presumably meaning "no limit" —
    /// confirm against callers).
    async fn max_embeddings_per_call(&self) -> Option<u32> {
        None
    }

    /// Whether the model supports parallel calls; the default returns `true`.
    async fn supports_parallel_calls(&self) -> bool {
        true
    }
}
/// Input for a single [`EmbeddingModel::do_embed`] call.
///
/// Field names use camelCase on the wire (`values`, `headers`, `providerOptions`).
/// The optional fields are omitted when `None` during serialization and default
/// to `None` when absent during deserialization.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct EmbedOptions {
    /// The text values to embed.
    pub values: Vec<String>,

    /// Optional headers (presumably forwarded with the provider request — confirm).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub headers: Option<Headers>,

    /// Optional provider-specific options.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_options: Option<ProviderOptions>,
}
/// A single embedding vector made of `f32` components.
pub type Embedding = std::vec::Vec<f32>;
/// Output of a single [`EmbeddingModel::do_embed`] call.
#[derive(Clone, Debug)]
pub struct EmbedResult {
    /// Embedding vectors returned by the model (presumably one per input value,
    /// in input order — confirm with implementors).
    pub embeddings: Vec<Embedding>,

    /// Usage information, when the implementor reports it.
    pub usage: Option<EmbeddingUsage>,

    /// Provider-specific metadata, when present.
    pub provider_metadata: Option<ProviderMetadata>,

    /// Details of the outgoing request, when captured.
    pub request: Option<RequestInfo>,

    /// Details of the provider response, when captured.
    pub response: Option<ResponseInfo>,
}
/// Usage accounting attached to an embedding result.
///
/// No `skip_serializing_if` is set, so a `None` token count serializes as `null`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Token count, if one was reported.
    pub tokens: Option<u64>,
}