Skip to main content

fierros_core/
providers.rs

1use crate::http_client::{JsonHttpClient, JsonHttpRequest, ReqwestJsonHttpClient};
2use crate::{
3    CompletionRequest, CompletionResponse, Embedder, FierrosError, FierrosResult, Llm, MessageRole,
4    TokenUsage,
5};
6use async_trait::async_trait;
7use serde_json::{json, Value};
8
/// Connection settings for an OpenAI-compatible chat-completions endpoint.
///
/// `Debug` is implemented by hand so that `api_key` is never printed; it is
/// shown as `Some("<redacted>")` when set and `None` otherwise.
#[derive(Clone, PartialEq, Eq)]
pub struct OpenAiCompatibleLlmConfig {
    /// Server root; `/v1/chat/completions` is appended to it
    /// (e.g. `https://api.example.com`).
    pub base_url: String,
    /// Model identifier sent as the request's `model` field.
    pub model: String,
    /// Optional bearer token; a missing or blank key sends no auth header.
    pub api_key: Option<String>,
}

/// Connection settings for an OpenAI-compatible embeddings endpoint.
///
/// `Debug` is implemented by hand so that `api_key` is never printed; it is
/// shown as `Some("<redacted>")` when set and `None` otherwise.
#[derive(Clone, PartialEq, Eq)]
pub struct OpenAiCompatibleEmbedderConfig {
    /// Server root; `/v1/embeddings` is appended to it.
    pub base_url: String,
    /// Embedding model identifier sent as the request's `model` field.
    pub model: String,
    /// Optional bearer token; a missing or blank key sends no auth header.
    pub api_key: Option<String>,
}

/// Maps a configured API key to a log-safe placeholder for `Debug` output.
fn redacted_api_key(api_key: Option<&str>) -> Option<&'static str> {
    api_key.map(|_| "<redacted>")
}

impl std::fmt::Debug for OpenAiCompatibleLlmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatibleLlmConfig")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            // Never emit the secret itself: configs end up in logs via `{:?}`.
            .field("api_key", &redacted_api_key(self.api_key.as_deref()))
            .finish()
    }
}

impl std::fmt::Debug for OpenAiCompatibleEmbedderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatibleEmbedderConfig")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            // Never emit the secret itself: configs end up in logs via `{:?}`.
            .field("api_key", &redacted_api_key(self.api_key.as_deref()))
            .finish()
    }
}
22
/// Connection settings for an Ollama-style chat endpoint.
///
/// Requests carry no authentication headers, so there is no API key field.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OllamaCompatibleLlmConfig {
    /// Server root; `/api/chat` is appended to it (e.g. `http://localhost:11434`).
    pub base_url: String,
    /// Model name sent as the request's `model` field.
    pub model: String,
}

/// Connection settings for an Ollama-style embeddings endpoint.
///
/// Requests carry no authentication headers, so there is no API key field.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OllamaCompatibleEmbedderConfig {
    /// Server root; `/api/embed` is appended to it.
    pub base_url: String,
    /// Embedding model name sent as the request's `model` field.
    pub model: String,
}
34
35#[derive(Debug, Clone)]
36pub struct OpenAiCompatibleLlm<C = ReqwestJsonHttpClient> {
37    config: OpenAiCompatibleLlmConfig,
38    client: C,
39}
40
41#[derive(Debug, Clone)]
42pub struct OpenAiCompatibleEmbedder<C = ReqwestJsonHttpClient> {
43    config: OpenAiCompatibleEmbedderConfig,
44    client: C,
45}
46
47#[derive(Debug, Clone)]
48pub struct OllamaCompatibleLlm<C = ReqwestJsonHttpClient> {
49    config: OllamaCompatibleLlmConfig,
50    client: C,
51}
52
53#[derive(Debug, Clone)]
54pub struct OllamaCompatibleEmbedder<C = ReqwestJsonHttpClient> {
55    config: OllamaCompatibleEmbedderConfig,
56    client: C,
57}
58
59impl OpenAiCompatibleLlm<ReqwestJsonHttpClient> {
60    pub fn new(config: OpenAiCompatibleLlmConfig) -> Self {
61        Self::with_client(config, ReqwestJsonHttpClient::default())
62    }
63}
64
65impl<C> OpenAiCompatibleLlm<C> {
66    pub fn with_client(config: OpenAiCompatibleLlmConfig, client: C) -> Self {
67        Self { config, client }
68    }
69}
70
71impl OpenAiCompatibleEmbedder<ReqwestJsonHttpClient> {
72    pub fn new(config: OpenAiCompatibleEmbedderConfig) -> Self {
73        Self::with_client(config, ReqwestJsonHttpClient::default())
74    }
75}
76
77impl<C> OpenAiCompatibleEmbedder<C> {
78    pub fn with_client(config: OpenAiCompatibleEmbedderConfig, client: C) -> Self {
79        Self { config, client }
80    }
81}
82
83impl OllamaCompatibleLlm<ReqwestJsonHttpClient> {
84    pub fn new(config: OllamaCompatibleLlmConfig) -> Self {
85        Self::with_client(config, ReqwestJsonHttpClient::default())
86    }
87}
88
89impl<C> OllamaCompatibleLlm<C> {
90    pub fn with_client(config: OllamaCompatibleLlmConfig, client: C) -> Self {
91        Self { config, client }
92    }
93}
94
95impl OllamaCompatibleEmbedder<ReqwestJsonHttpClient> {
96    pub fn new(config: OllamaCompatibleEmbedderConfig) -> Self {
97        Self::with_client(config, ReqwestJsonHttpClient::default())
98    }
99}
100
101impl<C> OllamaCompatibleEmbedder<C> {
102    pub fn with_client(config: OllamaCompatibleEmbedderConfig, client: C) -> Self {
103        Self { config, client }
104    }
105}
106
107#[async_trait]
108impl<C> Llm for OpenAiCompatibleLlm<C>
109where
110    C: JsonHttpClient,
111{
112    async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse> {
113        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
114
115        let body = json!({
116            "model": self.config.model,
117            "messages": request
118                .messages
119                .into_iter()
120                .map(|message| {
121                    json!({
122                        "role": message_role_to_wire(&message.role),
123                        "content": message.content
124                    })
125                })
126                .collect::<Vec<_>>(),
127            "temperature": request.temperature,
128            "max_tokens": request.max_tokens,
129        });
130        let response = self
131            .client
132            .post_json(JsonHttpRequest {
133                url: provider_url(&self.config.base_url, "/v1/chat/completions"),
134                headers: bearer_auth_headers(self.config.api_key.as_deref()),
135                body,
136            })
137            .await?;
138
139        if let Some(error_message) = extract_provider_error(&response) {
140            return Err(FierrosError::Provider(error_message));
141        }
142
143        let content = response
144            .get("choices")
145            .and_then(Value::as_array)
146            .and_then(|choices| choices.first())
147            .and_then(|choice| choice.get("message"))
148            .and_then(|message| message.get("content"))
149            .and_then(Value::as_str)
150            .ok_or_else(|| FierrosError::Provider("missing 'choices[0].message.content'".into()))?
151            .to_string();
152
153        Ok(CompletionResponse {
154            content,
155            usage: parse_openai_usage(&response),
156        })
157    }
158}
159
160#[async_trait]
161impl<C> Embedder for OpenAiCompatibleEmbedder<C>
162where
163    C: JsonHttpClient,
164{
165    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
166        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
167        if inputs.is_empty() {
168            return Err(FierrosError::InvalidInput(
169                "embedding inputs must not be empty".into(),
170            ));
171        }
172
173        let response = self
174            .client
175            .post_json(JsonHttpRequest {
176                url: provider_url(&self.config.base_url, "/v1/embeddings"),
177                headers: bearer_auth_headers(self.config.api_key.as_deref()),
178                body: json!({
179                    "model": self.config.model,
180                    "input": inputs,
181                }),
182            })
183            .await?;
184
185        if let Some(error_message) = extract_provider_error(&response) {
186            return Err(FierrosError::Provider(error_message));
187        }
188
189        let data = response
190            .get("data")
191            .and_then(Value::as_array)
192            .ok_or_else(|| {
193                FierrosError::Provider("missing 'data' array in embeddings response".into())
194            })?;
195
196        let embeddings = data
197            .iter()
198            .map(|item| {
199                parse_embedding_array(item.get("embedding").ok_or_else(|| {
200                    FierrosError::Provider(
201                        "missing 'data[*].embedding' in embeddings response".into(),
202                    )
203                })?)
204            })
205            .collect::<FierrosResult<Vec<_>>>()?;
206
207        if embeddings.len() != inputs.len() {
208            return Err(FierrosError::Provider(format!(
209                "embedder returned {} embeddings for {} inputs",
210                embeddings.len(),
211                inputs.len()
212            )));
213        }
214
215        Ok(embeddings)
216    }
217}
218
219#[async_trait]
220impl<C> Llm for OllamaCompatibleLlm<C>
221where
222    C: JsonHttpClient,
223{
224    async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse> {
225        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
226
227        let response = self
228            .client
229            .post_json(JsonHttpRequest {
230                url: provider_url(&self.config.base_url, "/api/chat"),
231                headers: Vec::new(),
232                body: json!({
233                    "model": self.config.model,
234                    "stream": false,
235                    "messages": request.messages.into_iter().map(|message| {
236                        json!({
237                            "role": message_role_to_wire(&message.role),
238                            "content": message.content
239                        })
240                    }).collect::<Vec<_>>(),
241                    "options": {
242                        "temperature": request.temperature,
243                        "num_predict": request.max_tokens
244                    }
245                }),
246            })
247            .await?;
248
249        if let Some(error_message) = extract_provider_error(&response) {
250            return Err(FierrosError::Provider(error_message));
251        }
252
253        let content = response
254            .get("message")
255            .and_then(|message| message.get("content"))
256            .and_then(Value::as_str)
257            .ok_or_else(|| FierrosError::Provider("missing 'message.content'".into()))?
258            .to_string();
259
260        let usage = match (
261            response.get("prompt_eval_count").and_then(Value::as_u64),
262            response.get("eval_count").and_then(Value::as_u64),
263        ) {
264            (Some(input_tokens), Some(output_tokens)) => Some(TokenUsage {
265                input_tokens: input_tokens as u32,
266                output_tokens: output_tokens as u32,
267            }),
268            _ => None,
269        };
270
271        Ok(CompletionResponse { content, usage })
272    }
273}
274
275#[async_trait]
276impl<C> Embedder for OllamaCompatibleEmbedder<C>
277where
278    C: JsonHttpClient,
279{
280    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
281        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
282        if inputs.is_empty() {
283            return Err(FierrosError::InvalidInput(
284                "embedding inputs must not be empty".into(),
285            ));
286        }
287
288        let response = self
289            .client
290            .post_json(JsonHttpRequest {
291                url: provider_url(&self.config.base_url, "/api/embed"),
292                headers: Vec::new(),
293                body: json!({
294                    "model": self.config.model,
295                    "input": inputs,
296                }),
297            })
298            .await?;
299
300        if let Some(error_message) = extract_provider_error(&response) {
301            return Err(FierrosError::Provider(error_message));
302        }
303
304        if let Some(embeddings) = response.get("embeddings").and_then(Value::as_array) {
305            let parsed = embeddings
306                .iter()
307                .map(parse_embedding_array)
308                .collect::<FierrosResult<Vec<_>>>()?;
309
310            if parsed.len() != inputs.len() {
311                return Err(FierrosError::Provider(format!(
312                    "embedder returned {} embeddings for {} inputs",
313                    parsed.len(),
314                    inputs.len()
315                )));
316            }
317
318            return Ok(parsed);
319        }
320
321        if let Some(embedding) = response.get("embedding") {
322            if inputs.len() != 1 {
323                return Err(FierrosError::Provider(
324                    "single 'embedding' response shape is only valid for one input".into(),
325                ));
326            }
327            return Ok(vec![parse_embedding_array(embedding)?]);
328        }
329
330        Err(FierrosError::Provider(
331            "missing 'embeddings' or 'embedding' in Ollama response".into(),
332        ))
333    }
334}
335
336fn validate_model_and_base_url(model: &str, base_url: &str) -> FierrosResult<()> {
337    if model.trim().is_empty() {
338        return Err(FierrosError::Configuration(
339            "provider model must not be empty".into(),
340        ));
341    }
342    if base_url.trim().is_empty() {
343        return Err(FierrosError::Configuration(
344            "provider base URL must not be empty".into(),
345        ));
346    }
347    Ok(())
348}
349
/// Joins a provider base URL and an endpoint path with exactly one `/`.
///
/// Surrounding whitespace and trailing slashes are stripped from `base_url`;
/// the validator only rejects blank URLs, so padded values such as
/// `" http://host/ "` must still produce a usable URL. Leading slashes on
/// `path` are optional. An empty `path` returns the normalized base.
fn provider_url(base_url: &str, path: &str) -> String {
    let base = base_url.trim().trim_end_matches('/');
    let path = path.trim_start_matches('/');
    if path.is_empty() {
        base.to_string()
    } else {
        format!("{base}/{path}")
    }
}
353
/// Builds the `Authorization: Bearer <key>` header list for a provider request.
///
/// The key is trimmed first. Keys loaded from files or environment variables
/// often carry a trailing newline, which is not a legal HTTP header value. A
/// missing or blank key yields no headers.
fn bearer_auth_headers(api_key: Option<&str>) -> Vec<(String, String)> {
    match api_key.map(str::trim).filter(|key| !key.is_empty()) {
        Some(key) => vec![("Authorization".into(), format!("Bearer {key}"))],
        None => Vec::new(),
    }
}
360
361fn message_role_to_wire(role: &MessageRole) -> &'static str {
362    match role {
363        MessageRole::System => "system",
364        MessageRole::User => "user",
365        MessageRole::Assistant => "assistant",
366        MessageRole::Tool => "tool",
367    }
368}
369
370fn parse_openai_usage(response: &Value) -> Option<TokenUsage> {
371    let usage = response.get("usage")?;
372    let input_tokens = usage.get("prompt_tokens")?.as_u64()?;
373    let output_tokens = usage.get("completion_tokens")?.as_u64()?;
374    Some(TokenUsage {
375        input_tokens: input_tokens as u32,
376        output_tokens: output_tokens as u32,
377    })
378}
379
380fn parse_embedding_array(value: &Value) -> FierrosResult<Vec<f32>> {
381    let values = value
382        .as_array()
383        .ok_or_else(|| FierrosError::Provider("embedding field must be an array".into()))?;
384
385    values
386        .iter()
387        .map(|item| {
388            item.as_f64().map(|number| number as f32).ok_or_else(|| {
389                FierrosError::Provider("embedding vector must contain numeric values".into())
390            })
391        })
392        .collect()
393}
394
395fn extract_provider_error(response: &Value) -> Option<String> {
396    response
397        .get("error")
398        .and_then(|error| {
399            error
400                .get("message")
401                .and_then(Value::as_str)
402                .or_else(|| error.as_str())
403        })
404        .map(std::string::ToString::to_string)
405}
406
#[cfg(test)]
mod tests {
    // Unit tests for the provider adapters. All HTTP traffic goes through
    // `StubHttpClient`, so no network access is needed.
    use super::{
        OllamaCompatibleEmbedder, OllamaCompatibleEmbedderConfig, OllamaCompatibleLlm,
        OllamaCompatibleLlmConfig, OpenAiCompatibleEmbedder, OpenAiCompatibleEmbedderConfig,
        OpenAiCompatibleLlm, OpenAiCompatibleLlmConfig,
    };
    use crate::http_client::{JsonHttpClient, JsonHttpRequest};
    use crate::{CompletionRequest, Embedder, FierrosError, FierrosResult, Llm};
    use serde_json::{json, Value};
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    /// Snapshot of one `JsonHttpRequest` seen by the stub client.
    #[derive(Debug, Clone, PartialEq)]
    struct CapturedRequest {
        url: String,
        headers: Vec<(String, String)>,
        body: Value,
    }

    /// Test double for `JsonHttpClient`.
    ///
    /// Records every request and replays the queued responses in FIFO order.
    /// Once the queue is empty it returns a `FierrosError::Provider`. State sits
    /// behind `Arc<Mutex<..>>`, so a clone handed to an adapter shares it with
    /// the original the test inspects.
    #[derive(Clone, Default)]
    struct StubHttpClient {
        captured: Arc<Mutex<Vec<CapturedRequest>>>,
        responses: Arc<Mutex<VecDeque<FierrosResult<Value>>>>,
    }

    impl StubHttpClient {
        /// Creates a stub that will answer with `responses`, in order.
        fn with_responses(responses: Vec<FierrosResult<Value>>) -> Self {
            Self {
                captured: Arc::new(Mutex::new(Vec::new())),
                responses: Arc::new(Mutex::new(responses.into())),
            }
        }

        /// Returns a copy of every request recorded so far.
        fn captured(&self) -> Vec<CapturedRequest> {
            self.captured.lock().expect("captured lock").clone()
        }
    }

    #[async_trait::async_trait]
    impl JsonHttpClient for StubHttpClient {
        /// Records `request`, then pops the next queued response.
        async fn post_json(&self, request: JsonHttpRequest) -> FierrosResult<Value> {
            self.captured
                .lock()
                .expect("captured lock")
                .push(CapturedRequest {
                    url: request.url,
                    headers: request.headers,
                    body: request.body,
                });

            // An exhausted queue surfaces as an error, so an unexpected extra
            // request fails the test instead of hanging or panicking here.
            self.responses
                .lock()
                .expect("responses lock")
                .pop_front()
                .unwrap_or_else(|| {
                    Err(FierrosError::Provider(
                        "stub client exhausted responses".into(),
                    ))
                })
        }
    }

    /// Checks content/usage mapping and the outgoing URL, auth header and model.
    /// Also checks that a trailing slash on the base URL is not doubled.
    #[tokio::test]
    async fn openai_llm_maps_completion_response_and_usage() {
        let client = StubHttpClient::with_responses(vec![Ok(json!({
            "choices": [{ "message": { "content": "answer text" } }],
            "usage": { "prompt_tokens": 11, "completion_tokens": 4 }
        }))]);
        let llm = OpenAiCompatibleLlm::with_client(
            OpenAiCompatibleLlmConfig {
                base_url: "https://api.example.com/".into(),
                model: "gpt-x".into(),
                api_key: Some("secret".into()),
            },
            client.clone(),
        );

        let response = llm
            .complete(CompletionRequest::from_user("What is new?"))
            .await
            .unwrap();
        assert_eq!(response.content, "answer text");
        assert_eq!(response.usage.unwrap().input_tokens, 11);

        let captured = client.captured();
        assert_eq!(captured.len(), 1);
        assert_eq!(
            captured[0].url,
            "https://api.example.com/v1/chat/completions"
        );
        assert_eq!(
            captured[0].headers,
            vec![("Authorization".into(), "Bearer secret".into())]
        );
        assert_eq!(captured[0].body["model"], "gpt-x");
    }

    /// An `error.message` payload must become the returned error's text.
    #[tokio::test]
    async fn openai_llm_surfaces_provider_errors() {
        let llm = OpenAiCompatibleLlm::with_client(
            OpenAiCompatibleLlmConfig {
                base_url: "https://api.example.com".into(),
                model: "gpt-x".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "error": { "message": "invalid_api_key" }
            }))]),
        );

        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert!(format!("{error}").contains("invalid_api_key"));
    }

    /// Vectors in `data[*].embedding` are returned in response order and the
    /// request goes to `/v1/embeddings`.
    #[tokio::test]
    async fn openai_embedder_maps_embedding_vectors() {
        let client = StubHttpClient::with_responses(vec![Ok(json!({
            "data": [
                { "embedding": [0.1, 0.2] },
                { "embedding": [0.3, 0.4] }
            ]
        }))]);
        let embedder = OpenAiCompatibleEmbedder::with_client(
            OpenAiCompatibleEmbedderConfig {
                base_url: "https://api.example.com".into(),
                model: "text-embedding-3-small".into(),
                api_key: Some("secret".into()),
            },
            client.clone(),
        );

        let vectors = embedder
            .embed(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![0.1_f32, 0.2_f32]);
        assert_eq!(vectors[1], vec![0.3_f32, 0.4_f32]);

        let captured = client.captured();
        assert_eq!(captured[0].url, "https://api.example.com/v1/embeddings");
    }

    /// Fewer vectors than inputs must be reported, not silently accepted.
    #[tokio::test]
    async fn openai_embedder_detects_embedding_count_mismatch() {
        let embedder = OpenAiCompatibleEmbedder::with_client(
            OpenAiCompatibleEmbedderConfig {
                base_url: "https://api.example.com".into(),
                model: "text-embedding-3-small".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "data": [{ "embedding": [0.1, 0.2] }]
            }))]),
        );

        let error = embedder
            .embed(&["a".to_string(), "b".to_string()])
            .await
            .unwrap_err();
        assert!(format!("{error}").contains("returned 1 embeddings for 2 inputs"));
    }

    /// `message.content` becomes the reply text; `eval_count` becomes output tokens.
    #[tokio::test]
    async fn ollama_llm_maps_message_and_usage() {
        let llm = OllamaCompatibleLlm::with_client(
            OllamaCompatibleLlmConfig {
                base_url: "http://localhost:11434".into(),
                model: "qwen2.5-coder".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "message": { "content": "local answer" },
                "prompt_eval_count": 6,
                "eval_count": 3
            }))]),
        );

        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();
        assert_eq!(response.content, "local answer");
        assert_eq!(response.usage.unwrap().output_tokens, 3);
    }

    /// The batch `embeddings` shape maps one vector per input, in order.
    #[tokio::test]
    async fn ollama_embedder_supports_embeddings_array_response() {
        let embedder = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "embeddings": [[0.1, 0.2], [0.3, 0.4]]
            }))]),
        );

        let vectors = embedder
            .embed(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(vectors[0], vec![0.1_f32, 0.2_f32]);
        assert_eq!(vectors[1], vec![0.3_f32, 0.4_f32]);
    }

    /// The single `embedding` shape is accepted for a one-element input.
    #[tokio::test]
    async fn ollama_embedder_supports_single_embedding_shape_for_one_input() {
        let embedder = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "embedding": [0.1, 0.2, 0.3]
            }))]),
        );

        let vectors = embedder.embed(&["a".to_string()]).await.unwrap();
        assert_eq!(vectors, vec![vec![0.1_f32, 0.2_f32, 0.3_f32]]);
    }

    /// Empty input is rejected before any request; the stub has no queued
    /// responses, so a request would also surface as an error.
    #[tokio::test]
    async fn ollama_embedder_rejects_empty_inputs() {
        let embedder = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![]),
        );

        let error = embedder.embed(&[]).await.unwrap_err();
        assert!(format!("{error}").contains("inputs must not be empty"));
    }

    /// Drives any `Llm` through a trait object, proving object safety.
    async fn complete_with_trait(llm: &dyn Llm) -> String {
        llm.complete(CompletionRequest::from_user("question"))
            .await
            .expect("llm response")
            .content
    }

    /// Drives any `Embedder` through a trait object with a single input.
    async fn embed_with_trait(embedder: &dyn Embedder) -> Vec<Vec<f32>> {
        embedder
            .embed(&["a".to_string()])
            .await
            .expect("embedder response")
    }

    /// Both LLM adapters work behind `&dyn Llm`.
    #[tokio::test]
    async fn llm_adapters_are_interchangeable_behind_trait_object() {
        let openai = OpenAiCompatibleLlm::with_client(
            OpenAiCompatibleLlmConfig {
                base_url: "https://api.example.com".into(),
                model: "gpt-x".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "choices": [{ "message": { "content": "openai response" } }]
            }))]),
        );
        let ollama = OllamaCompatibleLlm::with_client(
            OllamaCompatibleLlmConfig {
                base_url: "http://localhost:11434".into(),
                model: "qwen2.5".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "message": { "content": "ollama response" }
            }))]),
        );

        assert_eq!(complete_with_trait(&openai).await, "openai response");
        assert_eq!(complete_with_trait(&ollama).await, "ollama response");
    }

    /// Both embedder adapters work behind `&dyn Embedder`.
    #[tokio::test]
    async fn embedder_adapters_are_interchangeable_behind_trait_object() {
        let openai = OpenAiCompatibleEmbedder::with_client(
            OpenAiCompatibleEmbedderConfig {
                base_url: "https://api.example.com".into(),
                model: "text-embedding-3-small".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "data": [{ "embedding": [0.4, 0.8] }]
            }))]),
        );
        let ollama = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "embeddings": [[0.4, 0.8]]
            }))]),
        );

        assert_eq!(embed_with_trait(&openai).await[0], vec![0.4_f32, 0.8_f32]);
        assert_eq!(embed_with_trait(&ollama).await[0], vec![0.4_f32, 0.8_f32]);
    }
}