use actix_web::{HttpRequest, HttpResponse, Result as ActixResult, web};
use tracing::info;

use crate::core::models::openai::{
    EmbeddingObject, EmbeddingRequest, EmbeddingResponse, EmbeddingUsage,
};
use crate::core::providers::ProviderError;
use crate::core::router::UnifiedRouter;
use crate::core::types::{
    context::RequestContext,
    embedding::{EmbeddingInput, EmbeddingRequest as CoreEmbeddingRequest},
    model::ProviderCapability,
};
use crate::server::state::AppState;
use crate::utils::error::gateway_error::GatewayError;

use super::context::handle_ai_request;
use super::provider_selection::select_provider_for_model;
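
/// Actix handler for OpenAI-compatible embedding requests (typically
/// mounted at `POST /v1/embeddings`, depending on server routing).
///
/// Accepts the standard OpenAI payload shape, e.g.:
///
/// ```json
/// { "model": "text-embedding-3-small", "input": ["hello", "world"] }
/// ```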
pub async fn embeddings(
state: web::Data<AppState>,
req: HttpRequest,
request: web::Json<EmbeddingRequest>,
) -> ActixResult<HttpResponse> {
info!("Embedding request for model: {}", request.model);
handle_ai_request(
&req,
request.into_inner(),
"Embedding",
|request, context| handle_embedding_with_state(state.get_ref(), request, context),
)
.await
}
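
/// Embedding entry point for callers that already hold an [`AppState`]:
/// resolves the unified router and delegates to the internal pipeline.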
pub async fn handle_embedding_with_state(
state: &AppState,
request: EmbeddingRequest,
context: RequestContext,
) -> Result<EmbeddingResponse, GatewayError> {
let unified_router = &state.unified_router;
handle_embedding_internal(unified_router, request, context).await
}
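
/// Core embedding flow: normalise the OpenAI-style input, verify that a
/// capable provider serves the requested model, then execute the request
/// against a selected deployment using the router's retry logic.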
async fn handle_embedding_internal(
    unified_router: &UnifiedRouter,
request: EmbeddingRequest,
context: RequestContext,
) -> Result<EmbeddingResponse, GatewayError> {
    // Normalise the OpenAI-style `input`, which may be a single string or an
    // array of strings, into the gateway's `EmbeddingInput` type.
    let input = match &request.input {
        serde_json::Value::String(s) => EmbeddingInput::Text(s.clone()),
        serde_json::Value::Array(arr) => {
            // Reject arrays containing non-string elements instead of
            // silently dropping them.
            let texts: Vec<String> = arr
                .iter()
                .map(|v| v.as_str().map(String::from))
                .collect::<Option<Vec<String>>>()
                .ok_or_else(|| {
                    GatewayError::validation(
                        "Invalid input: expected string or array of strings",
                    )
                })?;
            EmbeddingInput::Array(texts)
        }
        _ => {
            return Err(GatewayError::validation(
                "Invalid input: expected string or array of strings",
            ));
        }
    };
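
    // Fail fast if no registered provider advertises embedding support for
    // the requested model; the selection result itself is not needed here.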
select_provider_for_model(
unified_router,
&request.model,
ProviderCapability::Embeddings,
)?;
    let requested_model = request.model.clone();
    let core_request = CoreEmbeddingRequest {
        model: requested_model.clone(),
        input,
        user: request.user,
        encoding_format: None,
        dimensions: None,
        task_type: None,
    };
let context_for_execution = context.clone();
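    // `execute_with_retry` may run this closure once per attempted
    // deployment, so the request and context are cloned on each try.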
let execution = unified_router
.execute_with_retry(&requested_model, move |deployment_id| {
let core_request = core_request.clone();
let context = context_for_execution.clone();
async move {
                let deployment = unified_router
                    .get_deployment(&deployment_id)
                    .ok_or_else(|| {
                        ProviderError::other("router", "Selected deployment not found")
                    })?;
                let provider = deployment.provider.clone();
                let selected_model = deployment.model.clone();
                // Release the deployment handle before the provider call so
                // it is not held across the `.await` below.
                drop(deployment);

                // Rewrite the model field to the deployment's concrete model
                // name; the requested name may only be an alias for it. The
                // per-attempt clone above already owns `core_request`, so no
                // second clone is needed here.
                let mut request_for_provider = core_request;
                request_for_provider.model = selected_model;
let response = provider
.create_embeddings(request_for_provider, context)
.await?;
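                // Report total token usage back to the retry helper alongside
                // the response; providers that omit usage count as zero.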
let tokens = response
.usage
.as_ref()
.map(|usage| u64::from(usage.total_tokens))
.unwrap_or_default();
Ok((response, tokens))
}
})
.await
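        // Keep only the error from the failure tuple and wrap it as a
        // gateway-level provider error.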
.map_err(|(e, _)| GatewayError::Provider(e))?;
    let (core_response, _tokens) = execution;
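
    // Map the core response back onto the OpenAI wire format, widening each
    // embedding value to `f64` and defaulting missing usage counts to zero.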
let response = EmbeddingResponse {
object: core_response.object,
data: core_response
.data
.into_iter()
            .map(|d| EmbeddingObject {
object: d.object,
embedding: d.embedding.into_iter().map(|f| f as f64).collect(),
index: d.index,
})
.collect(),
model: core_response.model,
        usage: EmbeddingUsage {
prompt_tokens: core_response
.usage
.as_ref()
.map(|u| u.prompt_tokens)
.unwrap_or(0),
total_tokens: core_response
.usage
.as_ref()
.map(|u| u.total_tokens)
.unwrap_or(0),
},
};
Ok(response)
}