use crate::client::ModelClient;
use crate::error::{OllamaError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// JSON request body for the `api/embed` endpoint, sent by [`ModelClient::embed`].
///
/// Every `Option` field is left out of the serialized JSON when it is `None`,
/// so only explicitly set parameters are sent to the server.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbedRequest {
    /// Name of the model to use; also handed to `handle_response` together with
    /// the HTTP response.
    pub model: String,
    /// Text to embed: either one string or a batch of strings.
    pub input: EmbedInput,
    /// Optional `truncate` flag (semantics defined by the server); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<bool>,
    /// Free-form model options as arbitrary JSON values; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<HashMap<String, serde_json::Value>>,
    /// Keep-alive value forwarded verbatim as a string (format defined by the
    /// server); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Presumably the requested size of each output embedding (confirm against
    /// the server API); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}

/// Input text for an [`EmbedRequest`].
///
/// Because of `#[serde(untagged)]`, no variant name is written: `Single`
/// serializes as a bare JSON string and `Multiple` as a JSON array of strings.
/// When deserializing, serde tries the variants in declaration order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum EmbedInput {
    /// A single piece of text.
    Single(String),
    /// Several pieces of text sent in one request.
    Multiple(Vec<String>),
}
/// JSON response body returned by the `api/embed` endpoint.
///
/// `total_duration`, `load_duration` and `prompt_eval_count` are marked
/// `#[serde(default)]`, so they decode as `0` when the server omits them;
/// `model` and `embeddings` are required.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbedResponse {
    /// Model name as reported by the server.
    pub model: String,
    /// Embedding vectors; presumably one per input string, in input order
    /// (confirm against the server API).
    pub embeddings: Vec<Vec<f32>>,
    /// Total time reported by the server (units not fixed here — presumably
    /// nanoseconds, TODO confirm); `0` if absent.
    #[serde(default)]
    pub total_duration: u64,
    /// Model load time reported by the server (same units as
    /// `total_duration`); `0` if absent.
    #[serde(default)]
    pub load_duration: u64,
    /// Prompt evaluation count reported by the server; `0` if absent.
    #[serde(default)]
    pub prompt_eval_count: u32,
}
/// JSON request body for the `api/embeddings` endpoint, sent by
/// [`ModelClient::embeddings`].
///
/// Unlike the `api/embed` request, this takes a single `prompt` string. It
/// derives `Default`, so callers can fill in only `model` and `prompt` and use
/// `..Default::default()` for the rest. Unset optional fields are omitted from
/// the serialized JSON.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
pub struct EmbeddingsRequest {
    /// Name of the model to use; also handed to `handle_response` together with
    /// the HTTP response.
    pub model: String,
    /// The text to embed.
    pub prompt: String,
    /// Free-form model options as arbitrary JSON values; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<HashMap<String, serde_json::Value>>,
    /// Keep-alive value forwarded verbatim as a string (format defined by the
    /// server); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Optional `truncate` flag (semantics defined by the server); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<bool>,
}
/// JSON response body returned by the `api/embeddings` endpoint, holding a
/// single embedding vector.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EmbeddingsResponse {
    /// The embedding vector returned by the server.
    pub embedding: Vec<f32>,
}
impl ModelClient {
    /// Posts `request` as JSON to `api/embed` (resolved against `base_url`) and
    /// decodes the reply into an [`EmbedResponse`].
    ///
    /// # Errors
    ///
    /// - [`OllamaError::UrlError`] if the endpoint URL cannot be joined onto `base_url`.
    /// - [`OllamaError::RequestError`] if sending the HTTP request fails.
    /// - Otherwise, whatever `handle_response` reports for the HTTP response.
    pub async fn embed(&self, request: EmbedRequest) -> Result<EmbedResponse> {
        let endpoint = self
            .base_url
            .join("api/embed")
            .map_err(OllamaError::UrlError)?;

        // The body is serialized from a borrow so `request.model` stays usable below.
        let builder = self.client.post(endpoint).json(&request);
        let http_response = builder.send().await.map_err(OllamaError::RequestError)?;

        self.handle_response(http_response, Some(&request.model))
            .await
    }

    /// Posts `request` as JSON to `api/embeddings` (resolved against `base_url`)
    /// and decodes the reply into an [`EmbeddingsResponse`].
    ///
    /// # Errors
    ///
    /// - [`OllamaError::UrlError`] if the endpoint URL cannot be joined onto `base_url`.
    /// - [`OllamaError::RequestError`] if sending the HTTP request fails.
    /// - Otherwise, whatever `handle_response` reports for the HTTP response.
    pub async fn embeddings(&self, request: EmbeddingsRequest) -> Result<EmbeddingsResponse> {
        let endpoint = self
            .base_url
            .join("api/embeddings")
            .map_err(OllamaError::UrlError)?;

        // The body is serialized from a borrow so `request.model` stays usable below.
        let builder = self.client.post(endpoint).json(&request);
        let http_response = builder.send().await.map_err(OllamaError::RequestError)?;

        self.handle_response(http_response, Some(&request.model))
            .await
    }
}
impl Default for EmbedRequest {
fn default() -> Self {
Self {
model: String::new(),
input: EmbedInput::Single(String::new()),
truncate: None,
options: None,
keep_alive: None,
dimensions: None,
}
}
}