1use crate::http_client::{JsonHttpClient, JsonHttpRequest, ReqwestJsonHttpClient};
2use crate::{
3 CompletionRequest, CompletionResponse, Embedder, FierrosError, FierrosResult, Llm, MessageRole,
4 TokenUsage,
5};
6use async_trait::async_trait;
7use serde_json::{json, Value};
8
/// Connection settings for an OpenAI-compatible chat-completions endpoint.
///
/// `Debug` is implemented by hand so the API key never ends up in logs or panic
/// messages: a configured key is rendered as `Some("<redacted>")`.
#[derive(Clone, PartialEq, Eq)]
pub struct OpenAiCompatibleLlmConfig {
    /// Base URL of the provider; endpoint paths such as `/v1/chat/completions` are
    /// appended to it.
    pub base_url: String,
    /// Model identifier sent as the `model` field of every request; must not be blank.
    pub model: String,
    /// Optional API key sent as a bearer token; `None` or a blank key sends no
    /// `Authorization` header.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for OpenAiCompatibleLlmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatibleLlmConfig")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            // Reveal only whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
15
/// Connection settings for an OpenAI-compatible embeddings endpoint.
///
/// `Debug` is implemented by hand so the API key never ends up in logs or panic
/// messages: a configured key is rendered as `Some("<redacted>")`.
#[derive(Clone, PartialEq, Eq)]
pub struct OpenAiCompatibleEmbedderConfig {
    /// Base URL of the provider; endpoint paths such as `/v1/embeddings` are appended
    /// to it.
    pub base_url: String,
    /// Embedding model identifier sent as the `model` field; must not be blank.
    pub model: String,
    /// Optional API key sent as a bearer token; `None` or a blank key sends no
    /// `Authorization` header.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for OpenAiCompatibleEmbedderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatibleEmbedderConfig")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            // Reveal only whether a key is configured, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
22
/// Connection settings for an Ollama-compatible chat endpoint (no credentials).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OllamaCompatibleLlmConfig {
    /// Server root URL; the `/api/chat` path is appended to it.
    pub base_url: String,
    /// Name of the model to chat with; must not be blank.
    pub model: String,
}
28
/// Connection settings for an Ollama-compatible embedding endpoint (no credentials).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct OllamaCompatibleEmbedderConfig {
    /// Server root URL; the `/api/embed` path is appended to it.
    pub base_url: String,
    /// Name of the embedding model; must not be blank.
    pub model: String,
}
34
35#[derive(Debug, Clone)]
36pub struct OpenAiCompatibleLlm<C = ReqwestJsonHttpClient> {
37 config: OpenAiCompatibleLlmConfig,
38 client: C,
39}
40
41#[derive(Debug, Clone)]
42pub struct OpenAiCompatibleEmbedder<C = ReqwestJsonHttpClient> {
43 config: OpenAiCompatibleEmbedderConfig,
44 client: C,
45}
46
47#[derive(Debug, Clone)]
48pub struct OllamaCompatibleLlm<C = ReqwestJsonHttpClient> {
49 config: OllamaCompatibleLlmConfig,
50 client: C,
51}
52
53#[derive(Debug, Clone)]
54pub struct OllamaCompatibleEmbedder<C = ReqwestJsonHttpClient> {
55 config: OllamaCompatibleEmbedderConfig,
56 client: C,
57}
58
59impl OpenAiCompatibleLlm<ReqwestJsonHttpClient> {
60 pub fn new(config: OpenAiCompatibleLlmConfig) -> Self {
61 Self::with_client(config, ReqwestJsonHttpClient::default())
62 }
63}
64
65impl<C> OpenAiCompatibleLlm<C> {
66 pub fn with_client(config: OpenAiCompatibleLlmConfig, client: C) -> Self {
67 Self { config, client }
68 }
69}
70
71impl OpenAiCompatibleEmbedder<ReqwestJsonHttpClient> {
72 pub fn new(config: OpenAiCompatibleEmbedderConfig) -> Self {
73 Self::with_client(config, ReqwestJsonHttpClient::default())
74 }
75}
76
77impl<C> OpenAiCompatibleEmbedder<C> {
78 pub fn with_client(config: OpenAiCompatibleEmbedderConfig, client: C) -> Self {
79 Self { config, client }
80 }
81}
82
83impl OllamaCompatibleLlm<ReqwestJsonHttpClient> {
84 pub fn new(config: OllamaCompatibleLlmConfig) -> Self {
85 Self::with_client(config, ReqwestJsonHttpClient::default())
86 }
87}
88
89impl<C> OllamaCompatibleLlm<C> {
90 pub fn with_client(config: OllamaCompatibleLlmConfig, client: C) -> Self {
91 Self { config, client }
92 }
93}
94
95impl OllamaCompatibleEmbedder<ReqwestJsonHttpClient> {
96 pub fn new(config: OllamaCompatibleEmbedderConfig) -> Self {
97 Self::with_client(config, ReqwestJsonHttpClient::default())
98 }
99}
100
101impl<C> OllamaCompatibleEmbedder<C> {
102 pub fn with_client(config: OllamaCompatibleEmbedderConfig, client: C) -> Self {
103 Self { config, client }
104 }
105}
106
107#[async_trait]
108impl<C> Llm for OpenAiCompatibleLlm<C>
109where
110 C: JsonHttpClient,
111{
112 async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse> {
113 validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
114
115 let body = json!({
116 "model": self.config.model,
117 "messages": request
118 .messages
119 .into_iter()
120 .map(|message| {
121 json!({
122 "role": message_role_to_wire(&message.role),
123 "content": message.content
124 })
125 })
126 .collect::<Vec<_>>(),
127 "temperature": request.temperature,
128 "max_tokens": request.max_tokens,
129 });
130 let response = self
131 .client
132 .post_json(JsonHttpRequest {
133 url: provider_url(&self.config.base_url, "/v1/chat/completions"),
134 headers: bearer_auth_headers(self.config.api_key.as_deref()),
135 body,
136 })
137 .await?;
138
139 if let Some(error_message) = extract_provider_error(&response) {
140 return Err(FierrosError::Provider(error_message));
141 }
142
143 let content = response
144 .get("choices")
145 .and_then(Value::as_array)
146 .and_then(|choices| choices.first())
147 .and_then(|choice| choice.get("message"))
148 .and_then(|message| message.get("content"))
149 .and_then(Value::as_str)
150 .ok_or_else(|| FierrosError::Provider("missing 'choices[0].message.content'".into()))?
151 .to_string();
152
153 Ok(CompletionResponse {
154 content,
155 usage: parse_openai_usage(&response),
156 })
157 }
158}
159
160#[async_trait]
161impl<C> Embedder for OpenAiCompatibleEmbedder<C>
162where
163 C: JsonHttpClient,
164{
165 async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
166 validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
167 if inputs.is_empty() {
168 return Err(FierrosError::InvalidInput(
169 "embedding inputs must not be empty".into(),
170 ));
171 }
172
173 let response = self
174 .client
175 .post_json(JsonHttpRequest {
176 url: provider_url(&self.config.base_url, "/v1/embeddings"),
177 headers: bearer_auth_headers(self.config.api_key.as_deref()),
178 body: json!({
179 "model": self.config.model,
180 "input": inputs,
181 }),
182 })
183 .await?;
184
185 if let Some(error_message) = extract_provider_error(&response) {
186 return Err(FierrosError::Provider(error_message));
187 }
188
189 let data = response
190 .get("data")
191 .and_then(Value::as_array)
192 .ok_or_else(|| {
193 FierrosError::Provider("missing 'data' array in embeddings response".into())
194 })?;
195
196 let embeddings = data
197 .iter()
198 .map(|item| {
199 parse_embedding_array(item.get("embedding").ok_or_else(|| {
200 FierrosError::Provider(
201 "missing 'data[*].embedding' in embeddings response".into(),
202 )
203 })?)
204 })
205 .collect::<FierrosResult<Vec<_>>>()?;
206
207 if embeddings.len() != inputs.len() {
208 return Err(FierrosError::Provider(format!(
209 "embedder returned {} embeddings for {} inputs",
210 embeddings.len(),
211 inputs.len()
212 )));
213 }
214
215 Ok(embeddings)
216 }
217}
218
219#[async_trait]
220impl<C> Llm for OllamaCompatibleLlm<C>
221where
222 C: JsonHttpClient,
223{
224 async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse> {
225 validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
226
227 let response = self
228 .client
229 .post_json(JsonHttpRequest {
230 url: provider_url(&self.config.base_url, "/api/chat"),
231 headers: Vec::new(),
232 body: json!({
233 "model": self.config.model,
234 "stream": false,
235 "messages": request.messages.into_iter().map(|message| {
236 json!({
237 "role": message_role_to_wire(&message.role),
238 "content": message.content
239 })
240 }).collect::<Vec<_>>(),
241 "options": {
242 "temperature": request.temperature,
243 "num_predict": request.max_tokens
244 }
245 }),
246 })
247 .await?;
248
249 if let Some(error_message) = extract_provider_error(&response) {
250 return Err(FierrosError::Provider(error_message));
251 }
252
253 let content = response
254 .get("message")
255 .and_then(|message| message.get("content"))
256 .and_then(Value::as_str)
257 .ok_or_else(|| FierrosError::Provider("missing 'message.content'".into()))?
258 .to_string();
259
260 let usage = match (
261 response.get("prompt_eval_count").and_then(Value::as_u64),
262 response.get("eval_count").and_then(Value::as_u64),
263 ) {
264 (Some(input_tokens), Some(output_tokens)) => Some(TokenUsage {
265 input_tokens: input_tokens as u32,
266 output_tokens: output_tokens as u32,
267 }),
268 _ => None,
269 };
270
271 Ok(CompletionResponse { content, usage })
272 }
273}
274
275#[async_trait]
276impl<C> Embedder for OllamaCompatibleEmbedder<C>
277where
278 C: JsonHttpClient,
279{
280 async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
281 validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
282 if inputs.is_empty() {
283 return Err(FierrosError::InvalidInput(
284 "embedding inputs must not be empty".into(),
285 ));
286 }
287
288 let response = self
289 .client
290 .post_json(JsonHttpRequest {
291 url: provider_url(&self.config.base_url, "/api/embed"),
292 headers: Vec::new(),
293 body: json!({
294 "model": self.config.model,
295 "input": inputs,
296 }),
297 })
298 .await?;
299
300 if let Some(error_message) = extract_provider_error(&response) {
301 return Err(FierrosError::Provider(error_message));
302 }
303
304 if let Some(embeddings) = response.get("embeddings").and_then(Value::as_array) {
305 let parsed = embeddings
306 .iter()
307 .map(parse_embedding_array)
308 .collect::<FierrosResult<Vec<_>>>()?;
309
310 if parsed.len() != inputs.len() {
311 return Err(FierrosError::Provider(format!(
312 "embedder returned {} embeddings for {} inputs",
313 parsed.len(),
314 inputs.len()
315 )));
316 }
317
318 return Ok(parsed);
319 }
320
321 if let Some(embedding) = response.get("embedding") {
322 if inputs.len() != 1 {
323 return Err(FierrosError::Provider(
324 "single 'embedding' response shape is only valid for one input".into(),
325 ));
326 }
327 return Ok(vec![parse_embedding_array(embedding)?]);
328 }
329
330 Err(FierrosError::Provider(
331 "missing 'embeddings' or 'embedding' in Ollama response".into(),
332 ))
333 }
334}
335
336fn validate_model_and_base_url(model: &str, base_url: &str) -> FierrosResult<()> {
337 if model.trim().is_empty() {
338 return Err(FierrosError::Configuration(
339 "provider model must not be empty".into(),
340 ));
341 }
342 if base_url.trim().is_empty() {
343 return Err(FierrosError::Configuration(
344 "provider base URL must not be empty".into(),
345 ));
346 }
347 Ok(())
348}
349
/// Joins a provider base URL and an endpoint path with exactly one `/` between them.
///
/// - Surrounding whitespace and trailing slashes are removed from `base_url`.
/// - Leading slashes are removed from `path`.
/// - If the base URL's path already ends with the endpoint's first segment (for example
///   base `https://host/v1` with path `/v1/embeddings`), that segment is not repeated.
///   This yields `https://host/v1/embeddings` instead of `https://host/v1/v1/embeddings`.
///   Only a real path segment counts; a host named like the segment is left alone.
fn provider_url(base_url: &str, path: &str) -> String {
    let base = base_url.trim().trim_end_matches('/');
    let relative = path.trim_start_matches('/');
    if relative.is_empty() {
        return base.to_owned();
    }

    let relative = match relative.split_once('/') {
        Some((first_segment, rest)) if last_path_segment(base) == Some(first_segment) => rest,
        _ => relative,
    };
    format!("{base}/{relative}")
}

/// Returns the last segment of `url`'s path component (the part after the host), or
/// `None` when the URL has no path. For example, `"https://host/v1"` gives `Some("v1")`
/// and `"https://host"` gives `None`.
fn last_path_segment(url: &str) -> Option<&str> {
    let without_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    let path_start = without_scheme.find('/')?;
    without_scheme[path_start..].rsplit('/').next()
}
353
/// Builds the `Authorization: Bearer <key>` header for an optional API key.
///
/// The key is trimmed before use. A blank check on the trimmed value combined with
/// emitting the untrimmed value would let a stray newline (e.g. from a key file) produce
/// an invalid header value. `None` or a blank key yields no headers.
fn bearer_auth_headers(api_key: Option<&str>) -> Vec<(String, String)> {
    api_key
        .map(str::trim)
        .filter(|key| !key.is_empty())
        .map(|key| vec![("Authorization".to_string(), format!("Bearer {key}"))])
        .unwrap_or_default()
}
360
361fn message_role_to_wire(role: &MessageRole) -> &'static str {
362 match role {
363 MessageRole::System => "system",
364 MessageRole::User => "user",
365 MessageRole::Assistant => "assistant",
366 MessageRole::Tool => "tool",
367 }
368}
369
370fn parse_openai_usage(response: &Value) -> Option<TokenUsage> {
371 let usage = response.get("usage")?;
372 let input_tokens = usage.get("prompt_tokens")?.as_u64()?;
373 let output_tokens = usage.get("completion_tokens")?.as_u64()?;
374 Some(TokenUsage {
375 input_tokens: input_tokens as u32,
376 output_tokens: output_tokens as u32,
377 })
378}
379
380fn parse_embedding_array(value: &Value) -> FierrosResult<Vec<f32>> {
381 let values = value
382 .as_array()
383 .ok_or_else(|| FierrosError::Provider("embedding field must be an array".into()))?;
384
385 values
386 .iter()
387 .map(|item| {
388 item.as_f64().map(|number| number as f32).ok_or_else(|| {
389 FierrosError::Provider("embedding vector must contain numeric values".into())
390 })
391 })
392 .collect()
393}
394
395fn extract_provider_error(response: &Value) -> Option<String> {
396 response
397 .get("error")
398 .and_then(|error| {
399 error
400 .get("message")
401 .and_then(Value::as_str)
402 .or_else(|| error.as_str())
403 })
404 .map(std::string::ToString::to_string)
405}
406
// Unit tests for the provider adapters. Every adapter is driven through an in-memory
// `StubHttpClient`, so no network access is needed.
#[cfg(test)]
mod tests {
    use super::{
        OllamaCompatibleEmbedder, OllamaCompatibleEmbedderConfig, OllamaCompatibleLlm,
        OllamaCompatibleLlmConfig, OpenAiCompatibleEmbedder, OpenAiCompatibleEmbedderConfig,
        OpenAiCompatibleLlm, OpenAiCompatibleLlmConfig,
    };
    use crate::http_client::{JsonHttpClient, JsonHttpRequest};
    use crate::{CompletionRequest, Embedder, FierrosError, FierrosResult, Llm};
    use serde_json::{json, Value};
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex};

    // Snapshot of one outgoing request, recorded by `StubHttpClient` so tests can assert
    // on the URL, headers and JSON body that an adapter produced.
    #[derive(Debug, Clone, PartialEq)]
    struct CapturedRequest {
        url: String,
        headers: Vec<(String, String)>,
        body: Value,
    }

    // Test double for `JsonHttpClient`. It records every request and replays canned
    // responses in FIFO order. State sits behind `Arc<Mutex<..>>`, so the clone given to
    // an adapter and the clone kept by the test see the same data.
    #[derive(Clone, Default)]
    struct StubHttpClient {
        captured: Arc<Mutex<Vec<CapturedRequest>>>,
        responses: Arc<Mutex<VecDeque<FierrosResult<Value>>>>,
    }

    impl StubHttpClient {
        // Builds a stub that answers successive requests with `responses`, in order.
        fn with_responses(responses: Vec<FierrosResult<Value>>) -> Self {
            Self {
                captured: Arc::new(Mutex::new(Vec::new())),
                responses: Arc::new(Mutex::new(responses.into())),
            }
        }

        // Returns a copy of every request recorded so far.
        fn captured(&self) -> Vec<CapturedRequest> {
            self.captured.lock().expect("captured lock").clone()
        }
    }

    #[async_trait::async_trait]
    impl JsonHttpClient for StubHttpClient {
        // Records the request, then pops the next canned response. Once the queue is
        // empty, every further call fails with a provider error.
        async fn post_json(&self, request: JsonHttpRequest) -> FierrosResult<Value> {
            self.captured
                .lock()
                .expect("captured lock")
                .push(CapturedRequest {
                    url: request.url,
                    headers: request.headers,
                    body: request.body,
                });

            self.responses
                .lock()
                .expect("responses lock")
                .pop_front()
                .unwrap_or_else(|| {
                    Err(FierrosError::Provider(
                        "stub client exhausted responses".into(),
                    ))
                })
        }
    }

    // Checks content/usage mapping. Also checks that a trailing slash on the base URL is
    // not doubled, that the bearer header is sent, and that the model is in the body.
    #[tokio::test]
    async fn openai_llm_maps_completion_response_and_usage() {
        let client = StubHttpClient::with_responses(vec![Ok(json!({
            "choices": [{ "message": { "content": "answer text" } }],
            "usage": { "prompt_tokens": 11, "completion_tokens": 4 }
        }))]);
        let llm = OpenAiCompatibleLlm::with_client(
            OpenAiCompatibleLlmConfig {
                base_url: "https://api.example.com/".into(),
                model: "gpt-x".into(),
                api_key: Some("secret".into()),
            },
            client.clone(),
        );

        let response = llm
            .complete(CompletionRequest::from_user("What is new?"))
            .await
            .unwrap();
        assert_eq!(response.content, "answer text");
        assert_eq!(response.usage.unwrap().input_tokens, 11);

        let captured = client.captured();
        assert_eq!(captured.len(), 1);
        assert_eq!(
            captured[0].url,
            "https://api.example.com/v1/chat/completions"
        );
        assert_eq!(
            captured[0].headers,
            vec![("Authorization".into(), "Bearer secret".into())]
        );
        assert_eq!(captured[0].body["model"], "gpt-x");
    }

    // An `{"error": {"message": ..}}` body must become an error containing that message.
    #[tokio::test]
    async fn openai_llm_surfaces_provider_errors() {
        let llm = OpenAiCompatibleLlm::with_client(
            OpenAiCompatibleLlmConfig {
                base_url: "https://api.example.com".into(),
                model: "gpt-x".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "error": { "message": "invalid_api_key" }
            }))]),
        );

        let error = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap_err();
        assert!(format!("{error}").contains("invalid_api_key"));
    }

    // `data[*].embedding` arrays are returned in response order, as `f32` vectors.
    #[tokio::test]
    async fn openai_embedder_maps_embedding_vectors() {
        let client = StubHttpClient::with_responses(vec![Ok(json!({
            "data": [
                { "embedding": [0.1, 0.2] },
                { "embedding": [0.3, 0.4] }
            ]
        }))]);
        let embedder = OpenAiCompatibleEmbedder::with_client(
            OpenAiCompatibleEmbedderConfig {
                base_url: "https://api.example.com".into(),
                model: "text-embedding-3-small".into(),
                api_key: Some("secret".into()),
            },
            client.clone(),
        );

        let vectors = embedder
            .embed(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(vectors.len(), 2);
        assert_eq!(vectors[0], vec![0.1_f32, 0.2_f32]);
        assert_eq!(vectors[1], vec![0.3_f32, 0.4_f32]);

        let captured = client.captured();
        assert_eq!(captured[0].url, "https://api.example.com/v1/embeddings");
    }

    // Fewer vectors than inputs must be reported instead of silently returned.
    #[tokio::test]
    async fn openai_embedder_detects_embedding_count_mismatch() {
        let embedder = OpenAiCompatibleEmbedder::with_client(
            OpenAiCompatibleEmbedderConfig {
                base_url: "https://api.example.com".into(),
                model: "text-embedding-3-small".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "data": [{ "embedding": [0.1, 0.2] }]
            }))]),
        );

        let error = embedder
            .embed(&["a".to_string(), "b".to_string()])
            .await
            .unwrap_err();
        assert!(format!("{error}").contains("returned 1 embeddings for 2 inputs"));
    }

    // `message.content` becomes the reply; `prompt_eval_count`/`eval_count` become usage.
    #[tokio::test]
    async fn ollama_llm_maps_message_and_usage() {
        let llm = OllamaCompatibleLlm::with_client(
            OllamaCompatibleLlmConfig {
                base_url: "http://localhost:11434".into(),
                model: "qwen2.5-coder".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "message": { "content": "local answer" },
                "prompt_eval_count": 6,
                "eval_count": 3
            }))]),
        );

        let response = llm
            .complete(CompletionRequest::from_user("question"))
            .await
            .unwrap();
        assert_eq!(response.content, "local answer");
        assert_eq!(response.usage.unwrap().output_tokens, 3);
    }

    // Batch `embeddings` shape: one vector per input.
    #[tokio::test]
    async fn ollama_embedder_supports_embeddings_array_response() {
        let embedder = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "embeddings": [[0.1, 0.2], [0.3, 0.4]]
            }))]),
        );

        let vectors = embedder
            .embed(&["a".to_string(), "b".to_string()])
            .await
            .unwrap();
        assert_eq!(vectors[0], vec![0.1_f32, 0.2_f32]);
        assert_eq!(vectors[1], vec![0.3_f32, 0.4_f32]);
    }

    // Single `embedding` shape is accepted when exactly one input was sent.
    #[tokio::test]
    async fn ollama_embedder_supports_single_embedding_shape_for_one_input() {
        let embedder = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "embedding": [0.1, 0.2, 0.3]
            }))]),
        );

        let vectors = embedder.embed(&["a".to_string()]).await.unwrap();
        assert_eq!(vectors, vec![vec![0.1_f32, 0.2_f32, 0.3_f32]]);
    }

    // Empty input is rejected before any request is made (the stub has no responses).
    #[tokio::test]
    async fn ollama_embedder_rejects_empty_inputs() {
        let embedder = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![]),
        );

        let error = embedder.embed(&[]).await.unwrap_err();
        assert!(format!("{error}").contains("inputs must not be empty"));
    }

    // Calls an adapter only through `&dyn Llm`, proving it is usable as a trait object.
    async fn complete_with_trait(llm: &dyn Llm) -> String {
        llm.complete(CompletionRequest::from_user("question"))
            .await
            .expect("llm response")
            .content
    }

    // Calls an adapter only through `&dyn Embedder`, proving it is usable as a trait object.
    async fn embed_with_trait(embedder: &dyn Embedder) -> Vec<Vec<f32>> {
        embedder
            .embed(&["a".to_string()])
            .await
            .expect("embedder response")
    }

    // Both LLM adapters can be used interchangeably behind `dyn Llm`.
    #[tokio::test]
    async fn llm_adapters_are_interchangeable_behind_trait_object() {
        let openai = OpenAiCompatibleLlm::with_client(
            OpenAiCompatibleLlmConfig {
                base_url: "https://api.example.com".into(),
                model: "gpt-x".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "choices": [{ "message": { "content": "openai response" } }]
            }))]),
        );
        let ollama = OllamaCompatibleLlm::with_client(
            OllamaCompatibleLlmConfig {
                base_url: "http://localhost:11434".into(),
                model: "qwen2.5".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "message": { "content": "ollama response" }
            }))]),
        );

        assert_eq!(complete_with_trait(&openai).await, "openai response");
        assert_eq!(complete_with_trait(&ollama).await, "ollama response");
    }

    // Both embedding adapters can be used interchangeably behind `dyn Embedder`.
    #[tokio::test]
    async fn embedder_adapters_are_interchangeable_behind_trait_object() {
        let openai = OpenAiCompatibleEmbedder::with_client(
            OpenAiCompatibleEmbedderConfig {
                base_url: "https://api.example.com".into(),
                model: "text-embedding-3-small".into(),
                api_key: None,
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "data": [{ "embedding": [0.4, 0.8] }]
            }))]),
        );
        let ollama = OllamaCompatibleEmbedder::with_client(
            OllamaCompatibleEmbedderConfig {
                base_url: "http://localhost:11434".into(),
                model: "nomic-embed-text".into(),
            },
            StubHttpClient::with_responses(vec![Ok(json!({
                "embeddings": [[0.4, 0.8]]
            }))]),
        );

        assert_eq!(embed_with_trait(&openai).await[0], vec![0.4_f32, 0.8_f32]);
        assert_eq!(embed_with_trait(&ollama).await[0], vec![0.4_f32, 0.8_f32]);
    }
}