Skip to main content

oai_sdk/
embed.rs

1// Copyright 2026 Cloudflavor GmbH
2
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6
7// http://www.apache.org/licenses/LICENSE-2.0
8
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use crate::client::ModelClient;
16use crate::error::{OllamaError, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
20/// Request for embeddings.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct EmbedRequest {
23    pub model: String,
24    pub input: EmbedInput,
25    #[serde(skip_serializing_if = "Option::is_none")]
26    pub truncate: Option<bool>,
27    #[serde(skip_serializing_if = "Option::is_none")]
28    pub options: Option<HashMap<String, serde_json::Value>>,
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub keep_alive: Option<String>,
31    #[serde(skip_serializing_if = "Option::is_none")]
32    pub dimensions: Option<u32>,
33}
34
35/// Input for embeddings.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(untagged)]
38pub enum EmbedInput {
39    Single(String),
40    Multiple(Vec<String>),
41}
42
43/// Response for embeddings.
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct EmbedResponse {
46    pub model: String,
47    pub embeddings: Vec<Vec<f32>>,
48    #[serde(default)]
49    pub total_duration: u64,
50    #[serde(default)]
51    pub load_duration: u64,
52    #[serde(default)]
53    pub prompt_eval_count: u32,
54}
55
56/// Request for legacy embeddings.
57#[derive(Debug, Clone, Serialize, Deserialize, Default)]
58pub struct EmbeddingsRequest {
59    pub model: String,
60    pub prompt: String,
61    #[serde(skip_serializing_if = "Option::is_none")]
62    pub options: Option<HashMap<String, serde_json::Value>>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub keep_alive: Option<String>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub truncate: Option<bool>,
67}
68
69/// Response for legacy embeddings.
70#[derive(Debug, Clone, Serialize, Deserialize)]
71pub struct EmbeddingsResponse {
72    pub embedding: Vec<f32>,
73}
74
75impl ModelClient {
76    /// Generate embeddings from text.
77    pub async fn embed(&self, request: EmbedRequest) -> Result<EmbedResponse> {
78        let url = self
79            .base_url
80            .join("api/embed")
81            .map_err(OllamaError::UrlError)?;
82        let response = self
83            .client
84            .post(url)
85            .json(&request)
86            .send()
87            .await
88            .map_err(OllamaError::RequestError)?;
89
90        self.handle_response(response, Some(&request.model)).await
91    }
92
93    /// Generate legacy embeddings from text.
94    pub async fn embeddings(&self, request: EmbeddingsRequest) -> Result<EmbeddingsResponse> {
95        let url = self
96            .base_url
97            .join("api/embeddings")
98            .map_err(OllamaError::UrlError)?;
99        let response = self
100            .client
101            .post(url)
102            .json(&request)
103            .send()
104            .await
105            .map_err(OllamaError::RequestError)?;
106
107        self.handle_response(response, Some(&request.model)).await
108    }
109}
110
111impl Default for EmbedRequest {
112    fn default() -> Self {
113        Self {
114            model: String::new(),
115            input: EmbedInput::Single(String::new()),
116            truncate: None,
117            options: None,
118            keep_alive: None,
119            dimensions: None,
120        }
121    }
122}