ceres-client 0.4.0

HTTP clients for Ceres portal harvesters and embedding providers
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
//! Ollama embeddings client.
//!
//! Generates embeddings locally via [Ollama](https://ollama.com), enabling
//! fully offline operation with zero API costs.
//!
//! # Supported models
//!
//! | Model | Dimensions | Notes |
//! |-------|-----------|-------|
//! | `nomic-embed-text` | 768 | Default, matches Gemini dimension |
//! | `mxbai-embed-large` | 1024 | Higher quality |
//! | `snowflake-arctic-embed` | 1024 | Strong retrieval performance; 1024 is the default `335m` tag (`22m`/`33m` emit 384, `110m`/`137m` emit 768) |
//! | `all-minilm` | 384 | Smallest/fastest |
//!
//! # Examples
//!
//! ```no_run
//! use ceres_client::OllamaClient;
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let client = OllamaClient::new()?;
//! let embedding = client.get_embeddings("Hello, world!").await?;
//! println!("Embedding dimension: {}", embedding.len()); // 768
//! # Ok(())
//! # }
//! ```

use ceres_core::HttpConfig;
use ceres_core::error::AppError;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;

/// Base URL of a stock local Ollama server, used when no endpoint is given.
const DEFAULT_ENDPOINT: &str = "http://localhost:11434";

/// Embedding model used when the caller does not choose one.
const DEFAULT_MODEL: &str = "nomic-embed-text";

/// Floor for the Ollama request timeout, in seconds (two minutes).
///
/// Local, often CPU-only, inference can be far slower than a hosted
/// embedding API, so this value replaces the shared `HttpConfig` timeout
/// whenever that default is shorter.
const OLLAMA_TIMEOUT_SECS: u64 = 2 * 60;

/// Returns the base model name, stripping any trailing `:tag` suffix.
///
/// Ollama model identifiers commonly include tags (e.g., `nomic-embed-text:latest`,
/// `snowflake-arctic-embed:335m`). We match on the base name only.
///
/// Only the *last* `:` is treated as a tag separator, and only when the text after it
/// contains no `/`. Registry-qualified references such as
/// `localhost:5000/library/all-minilm:l6-v2` therefore keep their `host:port` prefix
/// instead of being truncated at the port colon. The result is always a prefix of `model`.
fn normalize_model_name(model: &str) -> &str {
    match model.rsplit_once(':') {
        // A real tag never contains a path separator; a `/` means the colon
        // belonged to a registry `host:port` component.
        Some((base, tag)) if !tag.contains('/') => base,
        _ => model,
    }
}

/// Returns the embedding dimension for a known Ollama model.
///
/// Handles tagged model identifiers (e.g., `snowflake-arctic-embed:335m`). The lookup is
/// keyed on the base model name; the tag is consulted only for `snowflake-arctic-embed`,
/// whose published sizes differ in output width:
///
/// * `22m` / `33m` (aliases `xs` / `s`) produce 384-dimensional vectors,
/// * `110m` / `137m` (aliases `m` / `m-long`) produce 768-dimensional vectors,
/// * `335m` (alias `l`, and the untagged / `latest` default) produces 1024.
///
/// Quantization suffixes such as `22m-xs-fp16` are reduced to their leading size component.
///
/// For unknown models, logs a warning and returns 768 (the most common dimension).
pub fn model_dimension(model: &str) -> usize {
    let base = normalize_model_name(model);
    // Text after the `:` separating base and tag; empty when the model is untagged.
    let tag = model
        .strip_prefix(base)
        .and_then(|rest| rest.strip_prefix(':'))
        .unwrap_or("");
    // Leading size component of the tag, e.g. `22m` from `22m-xs-fp16`, `m` from `m-long`.
    let size = tag.split_once('-').map_or(tag, |(head, _)| head);

    match base {
        "nomic-embed-text" => 768,
        "mxbai-embed-large" => 1024,
        "snowflake-arctic-embed" => match size {
            "22m" | "33m" | "xs" | "s" => 384,
            "110m" | "137m" | "m" => 768,
            // `335m`, `l`, `latest`, or untagged: the large default variant.
            _ => 1024,
        },
        "all-minilm" => 384,
        _ => {
            tracing::warn!(
                model,
                "Unknown Ollama model dimension, defaulting to 768. \
                 Set EMBEDDING_MODEL to a known model or verify dimension matches your database."
            );
            768
        }
    }
}

/// HTTP client for generating embeddings via a local Ollama instance.
///
/// Ollama runs embedding models locally with zero per-request cost,
/// making it ideal for bulk embedding, development, and self-hosted deployments.
///
/// Build one with [`OllamaClient::new`], [`OllamaClient::with_model`], or
/// [`OllamaClient::with_config`].
#[derive(Clone, Debug)]
pub struct OllamaClient {
    /// Underlying HTTP client, built with the Ollama request timeout.
    client: Client,
    /// Model identifier exactly as supplied (including any `:tag`); sent with each request.
    model: String,
    /// Base URL of the Ollama server, stored without a trailing `/`.
    endpoint: String,
    /// Embedding dimension inferred from `model` when the client was constructed.
    dim: usize,
    /// Request timeout in seconds; also the value reported by timeout errors.
    timeout_secs: u64,
}

/// Request body for Ollama embed API.
///
/// Serialized with serde's default field names, i.e. `{"model": ..., "input": [...]}`,
/// and POSTed to `{endpoint}/api/embed`.
#[derive(Serialize)]
struct EmbedRequest<'a> {
    /// Model identifier, borrowed from the client.
    model: &'a str,
    /// Texts to embed; sending a list lets one request cover a whole batch.
    input: Vec<&'a str>,
}

/// Response from Ollama embed API.
#[derive(Deserialize)]
struct EmbedResponse {
    /// Embedding vectors. Per the Ollama API docs, one per input, in input order.
    embeddings: Vec<Vec<f32>>,
}

/// Error response from Ollama API.
///
/// Parsed from the body of non-success responses. When the body does not have this
/// shape, callers fall back to the raw response text.
#[derive(Deserialize)]
struct OllamaErrorResponse {
    /// Human-readable error message reported by Ollama.
    error: String,
}

impl OllamaClient {
    /// Creates a new Ollama client with default settings.
    ///
    /// Uses `nomic-embed-text` model at `http://localhost:11434`.
    pub fn new() -> Result<Self, AppError> {
        Self::with_config(DEFAULT_MODEL, None)
    }

    /// Creates a new Ollama client with a specific model.
    pub fn with_model(model: &str) -> Result<Self, AppError> {
        Self::with_config(model, None)
    }

    /// Creates a new Ollama client with full configuration.
    ///
    /// # Arguments
    ///
    /// * `model` - Ollama model name (e.g., `nomic-embed-text`)
    /// * `endpoint` - Custom Ollama API endpoint (default: `http://localhost:11434`)
    pub fn with_config(model: &str, endpoint: Option<&str>) -> Result<Self, AppError> {
        let endpoint = endpoint.unwrap_or(DEFAULT_ENDPOINT);

        // Validate endpoint URL to ensure it uses an allowed scheme
        let parsed = reqwest::Url::parse(endpoint).map_err(|e| {
            AppError::ConfigError(format!("Invalid Ollama endpoint '{}': {}", endpoint, e))
        })?;
        match parsed.scheme() {
            "http" | "https" => {}
            scheme => {
                return Err(AppError::ConfigError(format!(
                    "Invalid Ollama endpoint scheme '{}'. Only http and https are allowed.",
                    scheme
                )));
            }
        }

        // Use HttpConfig as base, override timeout for local inference
        let http_config = HttpConfig::default();
        let timeout_secs = if http_config.timeout.as_secs() < OLLAMA_TIMEOUT_SECS {
            OLLAMA_TIMEOUT_SECS
        } else {
            http_config.timeout.as_secs()
        };

        let client = Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .build()
            .map_err(|e| AppError::ClientError(e.to_string()))?;

        // Normalize endpoint by trimming trailing slash
        let endpoint = endpoint.trim_end_matches('/').to_string();
        let dim = model_dimension(model);

        Ok(Self {
            client,
            model: model.to_string(),
            endpoint,
            dim,
            timeout_secs,
        })
    }

    /// Returns the model being used.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Generates an embedding for a single text.
    pub async fn get_embeddings(&self, text: &str) -> Result<Vec<f32>, AppError> {
        let embeddings = self.get_embeddings_batch(&[text]).await?;
        embeddings.into_iter().next().ok_or(AppError::EmptyResponse)
    }

    /// Generates embeddings for multiple texts in a single API call.
    ///
    /// Ollama's `/api/embed` endpoint supports native batching.
    pub async fn get_embeddings_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, AppError> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let url = format!("{}/api/embed", self.endpoint);
        let request_body = EmbedRequest {
            model: &self.model,
            input: texts.to_vec(),
        };

        let response = self
            .client
            .post(&url)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| self.map_connection_error(e))?;

        let status = response.status();

        if !status.is_success() {
            let status_code = status.as_u16();
            let error_text = response.text().await.unwrap_or_default();
            return Err(self.map_api_error(status_code, &error_text));
        }

        let embed_response: EmbedResponse = response.json().await.map_err(|e| {
            AppError::ClientError(format!("Failed to parse Ollama response: {}", e))
        })?;

        Ok(embed_response.embeddings)
    }

    /// Maps reqwest connection/transport errors to AppError.
    fn map_connection_error(&self, err: reqwest::Error) -> AppError {
        if err.is_timeout() {
            AppError::Timeout(self.timeout_secs)
        } else if err.is_connect() {
            AppError::NetworkError(format!(
                "Cannot connect to Ollama at {}. Is it running? Try: ollama serve",
                self.endpoint
            ))
        } else {
            AppError::ClientError(format!("Ollama request failed: {}", err))
        }
    }

    /// Maps Ollama HTTP error responses to AppError.
    fn map_api_error(&self, status_code: u16, error_text: &str) -> AppError {
        let message = serde_json::from_str::<OllamaErrorResponse>(error_text)
            .map(|e| e.error)
            .unwrap_or_else(|_| format!("HTTP {}: {}", status_code, error_text));

        let lower_message = message.to_lowercase();
        let lower_model = self.model.to_lowercase();

        // Treat as "model not found" only when status is 404 and the message
        // clearly refers to a missing model, to avoid misclassifying generic 404s.
        let is_model_not_found = status_code == 404
            && (lower_message.contains("not found")
                && (lower_message.contains("model") || lower_message.contains(&lower_model)));

        if is_model_not_found {
            return AppError::ClientError(format!(
                "Ollama model '{}' not found. Try: ollama pull {}",
                self.model, self.model
            ));
        }

        if status_code == 404 {
            return AppError::ClientError(format!(
                "Received 404 from Ollama at {}: {}. Check that the endpoint is correct.",
                self.endpoint, message
            ));
        }

        AppError::ClientError(format!("Ollama error: {}", message))
    }
}

// =============================================================================
// Trait Implementation: EmbeddingProvider
// =============================================================================

impl ceres_core::traits::EmbeddingProvider for OllamaClient {
    /// Stable identifier for this provider.
    fn name(&self) -> &'static str {
        "ollama"
    }

    /// Vector width resolved from the model name when the client was built.
    fn dimension(&self) -> usize {
        self.dim
    }

    /// Batch-size limit advertised to callers through the trait.
    fn max_batch_size(&self) -> usize {
        512
    }

    /// Embeds one text by delegating to [`OllamaClient::get_embeddings`].
    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        self.get_embeddings(text).await
    }

    /// Embeds every text by delegating to [`OllamaClient::get_embeddings_batch`].
    async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
        // Borrow each owned `String` as `&str` to match the inherent batch API.
        let borrowed = texts.iter().map(String::as_str).collect::<Vec<_>>();
        self.get_embeddings_batch(borrowed.as_slice()).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the default client, which every test below expects to succeed.
    fn default_client() -> OllamaClient {
        OllamaClient::new().expect("default Ollama client should build")
    }

    #[test]
    fn test_model_dimension() {
        let cases = [
            ("nomic-embed-text", 768),
            ("mxbai-embed-large", 1024),
            ("snowflake-arctic-embed", 1024),
            ("all-minilm", 384),
            // Unknown names fall back to the default dimension.
            ("unknown-model", 768),
        ];
        for (model, expected) in cases {
            assert_eq!(model_dimension(model), expected, "model: {model}");
        }
    }

    #[test]
    fn test_model_dimension_with_tags() {
        let cases = [
            ("nomic-embed-text:latest", 768),
            ("snowflake-arctic-embed:335m", 1024),
            ("mxbai-embed-large:latest", 1024),
            ("all-minilm:l6-v2", 384),
        ];
        for (model, expected) in cases {
            assert_eq!(model_dimension(model), expected, "model: {model}");
        }
    }

    #[test]
    fn test_new_client() {
        let client = default_client();
        assert_eq!(client.model(), "nomic-embed-text");
        assert_eq!(client.dim, 768);
        assert_eq!(client.endpoint, "http://localhost:11434");
    }

    #[test]
    fn test_client_with_model() {
        let client = OllamaClient::with_model("mxbai-embed-large").unwrap();
        assert_eq!(client.model(), "mxbai-embed-large");
        assert_eq!(client.dim, 1024);
    }

    #[test]
    fn test_client_with_config() {
        let client =
            OllamaClient::with_config("nomic-embed-text", Some("http://myhost:11434")).unwrap();
        assert_eq!(client.endpoint, "http://myhost:11434");
        assert_eq!(client.model(), "nomic-embed-text");
    }

    #[test]
    fn test_endpoint_trailing_slash_normalized() {
        let client =
            OllamaClient::with_config("nomic-embed-text", Some("http://localhost:11434/")).unwrap();
        assert_eq!(client.endpoint, "http://localhost:11434");
    }

    #[test]
    fn test_invalid_endpoint_scheme() {
        let err = OllamaClient::with_config("nomic-embed-text", Some("ftp://localhost:11434"))
            .expect_err("ftp endpoints must be rejected");
        assert!(err.to_string().contains("scheme"));
    }

    #[test]
    fn test_invalid_endpoint_url() {
        assert!(OllamaClient::with_config("nomic-embed-text", Some("not a url")).is_err());
    }

    #[test]
    fn test_request_serialization() {
        let json = serde_json::to_string(&EmbedRequest {
            model: "nomic-embed-text",
            input: vec!["Hello world", "Test input"],
        })
        .unwrap();

        for needle in ["nomic-embed-text", "Hello world", "Test input"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_trait_implementation() {
        use ceres_core::traits::EmbeddingProvider;

        let client = default_client();
        assert_eq!(client.name(), "ollama");
        assert_eq!(client.dimension(), 768);
        assert_eq!(client.max_batch_size(), 512);
    }

    #[test]
    fn test_map_api_error_model_not_found() {
        let msg = default_client()
            .map_api_error(404, r#"{"error":"model \"nomic-embed-text\" not found"}"#)
            .to_string();
        assert!(msg.contains("not found"));
        assert!(msg.contains("ollama pull"));
    }

    #[test]
    fn test_map_api_error_generic_404() {
        // A generic 404 (wrong endpoint) should NOT suggest "ollama pull".
        let msg = default_client()
            .map_api_error(404, r#"{"error":"Not Found"}"#)
            .to_string();
        assert!(!msg.contains("ollama pull"));
        assert!(msg.contains("endpoint"));
    }

    #[test]
    fn test_map_api_error_generic() {
        let msg = default_client()
            .map_api_error(500, r#"{"error":"internal server error"}"#)
            .to_string();
        assert!(msg.contains("internal server error"));
    }
}