1use crate::client::ModelClient;
16use crate::error::{OllamaError, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
/// Request body for `ModelClient::embed`, POSTed as JSON to `api/embed`.
///
/// Every `Option` field is left out of the serialized JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedRequest {
    /// Model name; `ModelClient::embed` also passes it to `handle_response`.
    pub model: String,
    /// Text to embed: either a single string or a batch (see [`EmbedInput`]).
    pub input: EmbedInput,
    /// Truncation flag forwarded to the server. Presumably it controls whether
    /// over-long input is truncated rather than rejected; confirm against the
    /// server API docs.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<bool>,
    /// Extra model options, passed through as arbitrary JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<HashMap<String, serde_json::Value>>,
    /// Keep-alive value, sent as a string. The accepted format is defined by the
    /// server. NOTE(review): confirm (e.g. duration strings).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Requested embedding length. Presumably only honored by models that
    /// support it; confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<u32>,
}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(untagged)]
38pub enum EmbedInput {
39 Single(String),
40 Multiple(Vec<String>),
41}
42
/// Response body decoded by `ModelClient::embed`.
///
/// The timing and count fields use `#[serde(default)]`, so they read as `0`
/// when the server omits them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbedResponse {
    /// Model name reported in the response.
    pub model: String,
    /// Embedding vectors, one `Vec<f32>` each. Presumably they are in the same
    /// order as the request's inputs; confirm against the server API.
    pub embeddings: Vec<Vec<f32>>,
    /// Total request time reported by the server; `0` if absent.
    /// Units are presumably nanoseconds; confirm.
    #[serde(default)]
    pub total_duration: u64,
    /// Model load time reported by the server; `0` if absent (same units as
    /// `total_duration`).
    #[serde(default)]
    pub load_duration: u64,
    /// Prompt evaluation count reported by the server (presumably tokens);
    /// `0` if absent.
    #[serde(default)]
    pub prompt_eval_count: u32,
}
55
/// Request body for `ModelClient::embeddings`, POSTed as JSON to `api/embeddings`.
///
/// Unlike [`EmbedRequest`], this carries exactly one `prompt` string. The derived
/// `Default` gives an empty model and prompt with every option unset. Each
/// `Option` field is omitted from the JSON when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct EmbeddingsRequest {
    /// Model name; `ModelClient::embeddings` also passes it to `handle_response`.
    pub model: String,
    /// The text to embed.
    pub prompt: String,
    /// Extra model options, passed through as arbitrary JSON values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<HashMap<String, serde_json::Value>>,
    /// Keep-alive value, sent as a string; the format is server-defined (confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub keep_alive: Option<String>,
    /// Truncation flag forwarded to the server.
    /// NOTE(review): confirm that this endpoint actually honors it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<bool>,
}
68
/// Response body decoded by `ModelClient::embeddings`: a single embedding vector.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingsResponse {
    /// The embedding vector returned for the request's `prompt`.
    pub embedding: Vec<f32>,
}
74
75impl ModelClient {
76 pub async fn embed(&self, request: EmbedRequest) -> Result<EmbedResponse> {
78 let url = self
79 .base_url
80 .join("api/embed")
81 .map_err(OllamaError::UrlError)?;
82 let response = self
83 .client
84 .post(url)
85 .json(&request)
86 .send()
87 .await
88 .map_err(OllamaError::RequestError)?;
89
90 self.handle_response(response, Some(&request.model)).await
91 }
92
93 pub async fn embeddings(&self, request: EmbeddingsRequest) -> Result<EmbeddingsResponse> {
95 let url = self
96 .base_url
97 .join("api/embeddings")
98 .map_err(OllamaError::UrlError)?;
99 let response = self
100 .client
101 .post(url)
102 .json(&request)
103 .send()
104 .await
105 .map_err(OllamaError::RequestError)?;
106
107 self.handle_response(response, Some(&request.model)).await
108 }
109}
110
111impl Default for EmbedRequest {
112 fn default() -> Self {
113 Self {
114 model: String::new(),
115 input: EmbedInput::Single(String::new()),
116 truncate: None,
117 options: None,
118 keep_alive: None,
119 dimensions: None,
120 }
121 }
122}