use crate::http_client::{JsonHttpClient, JsonHttpRequest, ReqwestJsonHttpClient};
use crate::{
CompletionRequest, CompletionResponse, Embedder, FierrosError, FierrosResult, Llm, MessageRole,
TokenUsage,
};
use async_trait::async_trait;
use serde_json::{json, Value};
/// Connection settings for an OpenAI-compatible chat-completions provider.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs or panic messages; only its presence is reported.
#[derive(Clone, PartialEq, Eq)]
pub struct OpenAiCompatibleLlmConfig {
    /// Root URL of the provider, e.g. `https://api.example.com`.
    pub base_url: String,
    /// Model identifier sent in the `model` field of every request.
    pub model: String,
    /// Optional bearer token sent in the `Authorization` header.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for OpenAiCompatibleLlmConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatibleLlmConfig")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            // Show whether a key is configured without revealing its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Connection settings for an OpenAI-compatible embeddings provider.
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// never written to logs or panic messages; only its presence is reported.
#[derive(Clone, PartialEq, Eq)]
pub struct OpenAiCompatibleEmbedderConfig {
    /// Root URL of the provider, e.g. `https://api.example.com`.
    pub base_url: String,
    /// Embedding model identifier sent in the `model` field of every request.
    pub model: String,
    /// Optional bearer token sent in the `Authorization` header.
    pub api_key: Option<String>,
}

impl std::fmt::Debug for OpenAiCompatibleEmbedderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiCompatibleEmbedderConfig")
            .field("base_url", &self.base_url)
            .field("model", &self.model)
            // Show whether a key is configured without revealing its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Connection settings for an Ollama-compatible chat provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OllamaCompatibleLlmConfig {
/// Root URL of the Ollama server, e.g. `http://localhost:11434`.
pub base_url: String,
/// Model name sent in the `model` field of every request.
pub model: String,
}
/// Connection settings for an Ollama-compatible embeddings provider.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OllamaCompatibleEmbedderConfig {
/// Root URL of the Ollama server, e.g. `http://localhost:11434`.
pub base_url: String,
/// Embedding model name sent in the `model` field of every request.
pub model: String,
}
/// Chat-completion adapter for OpenAI-compatible providers.
///
/// Generic over the HTTP transport `C`, which defaults to
/// `ReqwestJsonHttpClient`; a different client can be supplied through
/// `with_client`.
#[derive(Debug, Clone)]
pub struct OpenAiCompatibleLlm<C = ReqwestJsonHttpClient> {
// Provider endpoint, model and optional API key.
config: OpenAiCompatibleLlmConfig,
// Transport used to POST JSON requests.
client: C,
}
/// Embedding adapter for OpenAI-compatible providers.
///
/// Generic over the HTTP transport `C`, which defaults to
/// `ReqwestJsonHttpClient`; a different client can be supplied through
/// `with_client`.
#[derive(Debug, Clone)]
pub struct OpenAiCompatibleEmbedder<C = ReqwestJsonHttpClient> {
// Provider endpoint, model and optional API key.
config: OpenAiCompatibleEmbedderConfig,
// Transport used to POST JSON requests.
client: C,
}
/// Chat adapter for Ollama-compatible providers.
///
/// Generic over the HTTP transport `C`, which defaults to
/// `ReqwestJsonHttpClient`; a different client can be supplied through
/// `with_client`.
#[derive(Debug, Clone)]
pub struct OllamaCompatibleLlm<C = ReqwestJsonHttpClient> {
// Provider endpoint and model.
config: OllamaCompatibleLlmConfig,
// Transport used to POST JSON requests.
client: C,
}
/// Embedding adapter for Ollama-compatible providers.
///
/// Generic over the HTTP transport `C`, which defaults to
/// `ReqwestJsonHttpClient`; a different client can be supplied through
/// `with_client`.
#[derive(Debug, Clone)]
pub struct OllamaCompatibleEmbedder<C = ReqwestJsonHttpClient> {
// Provider endpoint and model.
config: OllamaCompatibleEmbedderConfig,
// Transport used to POST JSON requests.
client: C,
}
impl OpenAiCompatibleLlm<ReqwestJsonHttpClient> {
pub fn new(config: OpenAiCompatibleLlmConfig) -> Self {
Self::with_client(config, ReqwestJsonHttpClient::default())
}
}
impl<C> OpenAiCompatibleLlm<C> {
pub fn with_client(config: OpenAiCompatibleLlmConfig, client: C) -> Self {
Self { config, client }
}
}
impl OpenAiCompatibleEmbedder<ReqwestJsonHttpClient> {
pub fn new(config: OpenAiCompatibleEmbedderConfig) -> Self {
Self::with_client(config, ReqwestJsonHttpClient::default())
}
}
impl<C> OpenAiCompatibleEmbedder<C> {
pub fn with_client(config: OpenAiCompatibleEmbedderConfig, client: C) -> Self {
Self { config, client }
}
}
impl OllamaCompatibleLlm<ReqwestJsonHttpClient> {
pub fn new(config: OllamaCompatibleLlmConfig) -> Self {
Self::with_client(config, ReqwestJsonHttpClient::default())
}
}
impl<C> OllamaCompatibleLlm<C> {
pub fn with_client(config: OllamaCompatibleLlmConfig, client: C) -> Self {
Self { config, client }
}
}
impl OllamaCompatibleEmbedder<ReqwestJsonHttpClient> {
pub fn new(config: OllamaCompatibleEmbedderConfig) -> Self {
Self::with_client(config, ReqwestJsonHttpClient::default())
}
}
impl<C> OllamaCompatibleEmbedder<C> {
pub fn with_client(config: OllamaCompatibleEmbedderConfig, client: C) -> Self {
Self { config, client }
}
}
#[async_trait]
impl<C> Llm for OpenAiCompatibleLlm<C>
where
C: JsonHttpClient,
{
async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse> {
validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
let body = json!({
"model": self.config.model,
"messages": request
.messages
.into_iter()
.map(|message| {
json!({
"role": message_role_to_wire(&message.role),
"content": message.content
})
})
.collect::<Vec<_>>(),
"temperature": request.temperature,
"max_tokens": request.max_tokens,
});
let response = self
.client
.post_json(JsonHttpRequest {
url: provider_url(&self.config.base_url, "/v1/chat/completions"),
headers: bearer_auth_headers(self.config.api_key.as_deref()),
body,
})
.await?;
if let Some(error_message) = extract_provider_error(&response) {
return Err(FierrosError::Provider(error_message));
}
let content = response
.get("choices")
.and_then(Value::as_array)
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("message"))
.and_then(|message| message.get("content"))
.and_then(Value::as_str)
.ok_or_else(|| FierrosError::Provider("missing 'choices[0].message.content'".into()))?
.to_string();
Ok(CompletionResponse {
content,
usage: parse_openai_usage(&response),
})
}
}
#[async_trait]
impl<C> Embedder for OpenAiCompatibleEmbedder<C>
where
    C: JsonHttpClient,
{
    /// Requests one embedding per input from `<base_url>/v1/embeddings`.
    ///
    /// # Errors
    ///
    /// Fails when the configuration is rejected by validation, when `inputs`
    /// is empty, when the HTTP client fails, when the response carries an
    /// `error` payload, when `data` or any `data[*].embedding` is missing or
    /// malformed, or when the number of vectors differs from the number of
    /// inputs.
    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
        if inputs.is_empty() {
            return Err(FierrosError::InvalidInput(
                "embedding inputs must not be empty".into(),
            ));
        }

        let outgoing = JsonHttpRequest {
            url: provider_url(&self.config.base_url, "/v1/embeddings"),
            headers: bearer_auth_headers(self.config.api_key.as_deref()),
            body: json!({
                "model": self.config.model,
                "input": inputs,
            }),
        };
        let response = self.client.post_json(outgoing).await?;

        if let Some(reason) = extract_provider_error(&response) {
            return Err(FierrosError::Provider(reason));
        }

        let items = match response.get("data").and_then(Value::as_array) {
            Some(items) => items,
            None => {
                return Err(FierrosError::Provider(
                    "missing 'data' array in embeddings response".into(),
                ))
            }
        };

        // Stop at the first malformed entry, in response order.
        let mut vectors = Vec::with_capacity(items.len());
        for item in items {
            let raw = match item.get("embedding") {
                Some(raw) => raw,
                None => {
                    return Err(FierrosError::Provider(
                        "missing 'data[*].embedding' in embeddings response".into(),
                    ))
                }
            };
            vectors.push(parse_embedding_array(raw)?);
        }

        if vectors.len() == inputs.len() {
            Ok(vectors)
        } else {
            Err(FierrosError::Provider(format!(
                "embedder returned {} embeddings for {} inputs",
                vectors.len(),
                inputs.len()
            )))
        }
    }
}
#[async_trait]
impl<C> Llm for OllamaCompatibleLlm<C>
where
    C: JsonHttpClient,
{
    /// Sends the conversation to `<base_url>/api/chat` (non-streaming) and
    /// returns the text found at `message.content`.
    ///
    /// Token usage is reported only when both `prompt_eval_count` and
    /// `eval_count` are present as unsigned integers. Counts larger than
    /// `u32::MAX` saturate instead of silently wrapping around.
    ///
    /// # Errors
    ///
    /// Fails when the configuration is rejected by validation, when the HTTP
    /// client fails, when the response carries an `error` payload, or when
    /// `message.content` is absent or not a string.
    async fn complete(&self, request: CompletionRequest) -> FierrosResult<CompletionResponse> {
        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
        let response = self
            .client
            .post_json(JsonHttpRequest {
                url: provider_url(&self.config.base_url, "/api/chat"),
                headers: Vec::new(),
                body: json!({
                    "model": self.config.model,
                    "stream": false,
                    "messages": request.messages.into_iter().map(|message| {
                        json!({
                            "role": message_role_to_wire(&message.role),
                            "content": message.content
                        })
                    }).collect::<Vec<_>>(),
                    "options": {
                        "temperature": request.temperature,
                        "num_predict": request.max_tokens
                    }
                }),
            })
            .await?;
        if let Some(error_message) = extract_provider_error(&response) {
            return Err(FierrosError::Provider(error_message));
        }
        let content = response
            .get("message")
            .and_then(|message| message.get("content"))
            .and_then(Value::as_str)
            .ok_or_else(|| FierrosError::Provider("missing 'message.content'".into()))?
            .to_string();

        // A plain `as u32` cast would wrap values above `u32::MAX`; clamp
        // first so an oversized count is reported as the maximum instead.
        let clamp = |count: u64| count.min(u64::from(u32::MAX)) as u32;
        let usage = match (
            response.get("prompt_eval_count").and_then(Value::as_u64),
            response.get("eval_count").and_then(Value::as_u64),
        ) {
            (Some(input_tokens), Some(output_tokens)) => Some(TokenUsage {
                input_tokens: clamp(input_tokens),
                output_tokens: clamp(output_tokens),
            }),
            _ => None,
        };
        Ok(CompletionResponse { content, usage })
    }
}
#[async_trait]
impl<C> Embedder for OllamaCompatibleEmbedder<C>
where
    C: JsonHttpClient,
{
    /// Requests embeddings from `<base_url>/api/embed`.
    ///
    /// Two response shapes are accepted: an `embeddings` array holding one
    /// vector per input, or a single `embedding` vector, which is only valid
    /// when exactly one input was sent. The `embeddings` array wins when both
    /// are present.
    ///
    /// # Errors
    ///
    /// Fails when the configuration is rejected by validation, when `inputs`
    /// is empty, when the HTTP client fails, when the response carries an
    /// `error` payload, when a vector is malformed, when the vector count
    /// does not match the input count, or when neither shape is present.
    async fn embed(&self, inputs: &[String]) -> FierrosResult<Vec<Vec<f32>>> {
        validate_model_and_base_url(&self.config.model, &self.config.base_url)?;
        if inputs.is_empty() {
            return Err(FierrosError::InvalidInput(
                "embedding inputs must not be empty".into(),
            ));
        }

        let outgoing = JsonHttpRequest {
            url: provider_url(&self.config.base_url, "/api/embed"),
            headers: Vec::new(),
            body: json!({
                "model": self.config.model,
                "input": inputs,
            }),
        };
        let response = self.client.post_json(outgoing).await?;

        if let Some(reason) = extract_provider_error(&response) {
            return Err(FierrosError::Provider(reason));
        }

        let batch = response.get("embeddings").and_then(Value::as_array);
        let single = response.get("embedding");
        match (batch, single) {
            // Batch shape: one vector per input.
            (Some(rows), _) => {
                let mut vectors = Vec::with_capacity(rows.len());
                for row in rows {
                    vectors.push(parse_embedding_array(row)?);
                }
                if vectors.len() != inputs.len() {
                    return Err(FierrosError::Provider(format!(
                        "embedder returned {} embeddings for {} inputs",
                        vectors.len(),
                        inputs.len()
                    )));
                }
                Ok(vectors)
            }
            // Single-vector shape: only meaningful for a single input.
            (None, Some(row)) => {
                if inputs.len() != 1 {
                    return Err(FierrosError::Provider(
                        "single 'embedding' response shape is only valid for one input".into(),
                    ));
                }
                Ok(vec![parse_embedding_array(row)?])
            }
            (None, None) => Err(FierrosError::Provider(
                "missing 'embeddings' or 'embedding' in Ollama response".into(),
            )),
        }
    }
}
/// Rejects provider configurations whose model or base URL is blank.
///
/// The model is checked first, so its message wins when both are blank.
/// Whitespace-only values count as blank.
fn validate_model_and_base_url(model: &str, base_url: &str) -> FierrosResult<()> {
    let problem = if model.trim().is_empty() {
        Some("provider model must not be empty")
    } else if base_url.trim().is_empty() {
        Some("provider base URL must not be empty")
    } else {
        None
    };
    match problem {
        Some(message) => Err(FierrosError::Configuration(message.into())),
        None => Ok(()),
    }
}
/// Joins a provider base URL and an endpoint path.
///
/// Surrounding whitespace and any trailing slashes are removed from
/// `base_url` before `path` is appended. `path` is expected to start with
/// `/`. Trimming whitespace keeps this in line with configuration
/// validation, which also ignores it.
fn provider_url(base_url: &str, path: &str) -> String {
    let root = base_url.trim().trim_end_matches('/');
    format!("{root}{path}")
}
/// Builds the `Authorization: Bearer …` header list for an optional API key.
///
/// Returns an empty list when no key is given or the key is blank. The key
/// is trimmed before use so stray whitespace, such as a trailing newline
/// from a file or environment variable, never ends up in the header value.
fn bearer_auth_headers(api_key: Option<&str>) -> Vec<(String, String)> {
    match api_key.map(str::trim).filter(|key| !key.is_empty()) {
        Some(key) => vec![("Authorization".to_owned(), format!("Bearer {key}"))],
        None => Vec::new(),
    }
}
/// Maps a message role to the lowercase role string used on the wire.
fn message_role_to_wire(role: &MessageRole) -> &'static str {
    match *role {
        MessageRole::Tool => "tool",
        MessageRole::Assistant => "assistant",
        MessageRole::User => "user",
        MessageRole::System => "system",
    }
}
/// Extracts token usage from an OpenAI-style response.
///
/// Returns `None` unless `usage.prompt_tokens` and `usage.completion_tokens`
/// are both present as unsigned integers. Counts larger than `u32::MAX`
/// saturate instead of silently wrapping around.
fn parse_openai_usage(response: &Value) -> Option<TokenUsage> {
    let usage = response.get("usage")?;
    let input_tokens = usage.get("prompt_tokens")?.as_u64()?;
    let output_tokens = usage.get("completion_tokens")?.as_u64()?;
    // A plain `as u32` cast would wrap values above `u32::MAX`; clamp first.
    let clamp = |count: u64| count.min(u64::from(u32::MAX)) as u32;
    Some(TokenUsage {
        input_tokens: clamp(input_tokens),
        output_tokens: clamp(output_tokens),
    })
}
/// Converts a JSON array of numbers into an `f32` embedding vector.
///
/// # Errors
///
/// Returns a provider error when `value` is not an array, or at the first
/// element that is not numeric.
fn parse_embedding_array(value: &Value) -> FierrosResult<Vec<f32>> {
    let items = match value.as_array() {
        Some(items) => items,
        None => {
            return Err(FierrosError::Provider(
                "embedding field must be an array".into(),
            ))
        }
    };
    let mut vector = Vec::with_capacity(items.len());
    for item in items {
        match item.as_f64() {
            // Components arrive as f64 and are narrowed to f32.
            Some(number) => vector.push(number as f32),
            None => {
                return Err(FierrosError::Provider(
                    "embedding vector must contain numeric values".into(),
                ))
            }
        }
    }
    Ok(vector)
}
fn extract_provider_error(response: &Value) -> Option<String> {
response
.get("error")
.and_then(|error| {
error
.get("message")
.and_then(Value::as_str)
.or_else(|| error.as_str())
})
.map(std::string::ToString::to_string)
}
// Unit tests for the provider adapters. Each test drives an adapter through a
// stub HTTP client, so no network access is involved.
#[cfg(test)]
mod tests {
use super::{
OllamaCompatibleEmbedder, OllamaCompatibleEmbedderConfig, OllamaCompatibleLlm,
OllamaCompatibleLlmConfig, OpenAiCompatibleEmbedder, OpenAiCompatibleEmbedderConfig,
OpenAiCompatibleLlm, OpenAiCompatibleLlmConfig,
};
use crate::http_client::{JsonHttpClient, JsonHttpRequest};
use crate::{CompletionRequest, Embedder, FierrosError, FierrosResult, Llm};
use serde_json::{json, Value};
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};
/// Snapshot of one request handed to the stub client.
#[derive(Debug, Clone, PartialEq)]
struct CapturedRequest {
url: String,
headers: Vec<(String, String)>,
body: Value,
}
/// Test double for `JsonHttpClient`: records every request and replays
/// pre-queued responses in FIFO order. Clones share the same state.
#[derive(Clone, Default)]
struct StubHttpClient {
captured: Arc<Mutex<Vec<CapturedRequest>>>,
responses: Arc<Mutex<VecDeque<FierrosResult<Value>>>>,
}
impl StubHttpClient {
/// Creates a stub that will answer successive calls with `responses`.
fn with_responses(responses: Vec<FierrosResult<Value>>) -> Self {
Self {
captured: Arc::new(Mutex::new(Vec::new())),
responses: Arc::new(Mutex::new(responses.into())),
}
}
/// Returns a copy of all requests recorded so far.
fn captured(&self) -> Vec<CapturedRequest> {
self.captured.lock().expect("captured lock").clone()
}
}
#[async_trait::async_trait]
impl JsonHttpClient for StubHttpClient {
async fn post_json(&self, request: JsonHttpRequest) -> FierrosResult<Value> {
// Record the request before answering so tests can inspect it.
self.captured
.lock()
.expect("captured lock")
.push(CapturedRequest {
url: request.url,
headers: request.headers,
body: request.body,
});
// Replay the next queued response; an empty queue is reported as a
// provider error rather than a panic.
self.responses
.lock()
.expect("responses lock")
.pop_front()
.unwrap_or_else(|| {
Err(FierrosError::Provider(
"stub client exhausted responses".into(),
))
})
}
}
// Happy path: content and usage are mapped, the trailing slash of the base
// URL is dropped, and the API key becomes a bearer header.
#[tokio::test]
async fn openai_llm_maps_completion_response_and_usage() {
let client = StubHttpClient::with_responses(vec![Ok(json!({
"choices": [{ "message": { "content": "answer text" } }],
"usage": { "prompt_tokens": 11, "completion_tokens": 4 }
}))]);
let llm = OpenAiCompatibleLlm::with_client(
OpenAiCompatibleLlmConfig {
base_url: "https://api.example.com/".into(),
model: "gpt-x".into(),
api_key: Some("secret".into()),
},
client.clone(),
);
let response = llm
.complete(CompletionRequest::from_user("What is new?"))
.await
.unwrap();
assert_eq!(response.content, "answer text");
assert_eq!(response.usage.unwrap().input_tokens, 11);
let captured = client.captured();
assert_eq!(captured.len(), 1);
assert_eq!(
captured[0].url,
"https://api.example.com/v1/chat/completions"
);
assert_eq!(
captured[0].headers,
vec![("Authorization".into(), "Bearer secret".into())]
);
assert_eq!(captured[0].body["model"], "gpt-x");
}
// An `error.message` payload in the response body is surfaced as an error.
#[tokio::test]
async fn openai_llm_surfaces_provider_errors() {
let llm = OpenAiCompatibleLlm::with_client(
OpenAiCompatibleLlmConfig {
base_url: "https://api.example.com".into(),
model: "gpt-x".into(),
api_key: None,
},
StubHttpClient::with_responses(vec![Ok(json!({
"error": { "message": "invalid_api_key" }
}))]),
);
let error = llm
.complete(CompletionRequest::from_user("question"))
.await
.unwrap_err();
assert!(format!("{error}").contains("invalid_api_key"));
}
// Each `data[*].embedding` entry becomes one output vector, in order.
#[tokio::test]
async fn openai_embedder_maps_embedding_vectors() {
let client = StubHttpClient::with_responses(vec![Ok(json!({
"data": [
{ "embedding": [0.1, 0.2] },
{ "embedding": [0.3, 0.4] }
]
}))]);
let embedder = OpenAiCompatibleEmbedder::with_client(
OpenAiCompatibleEmbedderConfig {
base_url: "https://api.example.com".into(),
model: "text-embedding-3-small".into(),
api_key: Some("secret".into()),
},
client.clone(),
);
let vectors = embedder
.embed(&["a".to_string(), "b".to_string()])
.await
.unwrap();
assert_eq!(vectors.len(), 2);
assert_eq!(vectors[0], vec![0.1_f32, 0.2_f32]);
assert_eq!(vectors[1], vec![0.3_f32, 0.4_f32]);
let captured = client.captured();
assert_eq!(captured[0].url, "https://api.example.com/v1/embeddings");
}
// One vector returned for two inputs must be rejected.
#[tokio::test]
async fn openai_embedder_detects_embedding_count_mismatch() {
let embedder = OpenAiCompatibleEmbedder::with_client(
OpenAiCompatibleEmbedderConfig {
base_url: "https://api.example.com".into(),
model: "text-embedding-3-small".into(),
api_key: None,
},
StubHttpClient::with_responses(vec![Ok(json!({
"data": [{ "embedding": [0.1, 0.2] }]
}))]),
);
let error = embedder
.embed(&["a".to_string(), "b".to_string()])
.await
.unwrap_err();
assert!(format!("{error}").contains("returned 1 embeddings for 2 inputs"));
}
// Ollama reports usage as `prompt_eval_count` / `eval_count`.
#[tokio::test]
async fn ollama_llm_maps_message_and_usage() {
let llm = OllamaCompatibleLlm::with_client(
OllamaCompatibleLlmConfig {
base_url: "http://localhost:11434".into(),
model: "qwen2.5-coder".into(),
},
StubHttpClient::with_responses(vec![Ok(json!({
"message": { "content": "local answer" },
"prompt_eval_count": 6,
"eval_count": 3
}))]),
);
let response = llm
.complete(CompletionRequest::from_user("question"))
.await
.unwrap();
assert_eq!(response.content, "local answer");
assert_eq!(response.usage.unwrap().output_tokens, 3);
}
// Batch response shape: `embeddings` is an array of vectors.
#[tokio::test]
async fn ollama_embedder_supports_embeddings_array_response() {
let embedder = OllamaCompatibleEmbedder::with_client(
OllamaCompatibleEmbedderConfig {
base_url: "http://localhost:11434".into(),
model: "nomic-embed-text".into(),
},
StubHttpClient::with_responses(vec![Ok(json!({
"embeddings": [[0.1, 0.2], [0.3, 0.4]]
}))]),
);
let vectors = embedder
.embed(&["a".to_string(), "b".to_string()])
.await
.unwrap();
assert_eq!(vectors[0], vec![0.1_f32, 0.2_f32]);
assert_eq!(vectors[1], vec![0.3_f32, 0.4_f32]);
}
// Single-vector response shape: `embedding` is accepted for one input.
#[tokio::test]
async fn ollama_embedder_supports_single_embedding_shape_for_one_input() {
let embedder = OllamaCompatibleEmbedder::with_client(
OllamaCompatibleEmbedderConfig {
base_url: "http://localhost:11434".into(),
model: "nomic-embed-text".into(),
},
StubHttpClient::with_responses(vec![Ok(json!({
"embedding": [0.1, 0.2, 0.3]
}))]),
);
let vectors = embedder.embed(&["a".to_string()]).await.unwrap();
assert_eq!(vectors, vec![vec![0.1_f32, 0.2_f32, 0.3_f32]]);
}
// The stub has no queued responses, so seeing the empty-input message (and
// not the stub's exhaustion error) shows the check runs before any request.
#[tokio::test]
async fn ollama_embedder_rejects_empty_inputs() {
let embedder = OllamaCompatibleEmbedder::with_client(
OllamaCompatibleEmbedderConfig {
base_url: "http://localhost:11434".into(),
model: "nomic-embed-text".into(),
},
StubHttpClient::with_responses(vec![]),
);
let error = embedder.embed(&[]).await.unwrap_err();
assert!(format!("{error}").contains("inputs must not be empty"));
}
// Helper: runs a completion through a `dyn Llm` and returns its content.
async fn complete_with_trait(llm: &dyn Llm) -> String {
llm.complete(CompletionRequest::from_user("question"))
.await
.expect("llm response")
.content
}
// Helper: embeds a single input through a `dyn Embedder`.
async fn embed_with_trait(embedder: &dyn Embedder) -> Vec<Vec<f32>> {
embedder
.embed(&["a".to_string()])
.await
.expect("embedder response")
}
// Both LLM adapters can be used through the same `&dyn Llm` reference.
#[tokio::test]
async fn llm_adapters_are_interchangeable_behind_trait_object() {
let openai = OpenAiCompatibleLlm::with_client(
OpenAiCompatibleLlmConfig {
base_url: "https://api.example.com".into(),
model: "gpt-x".into(),
api_key: None,
},
StubHttpClient::with_responses(vec![Ok(json!({
"choices": [{ "message": { "content": "openai response" } }]
}))]),
);
let ollama = OllamaCompatibleLlm::with_client(
OllamaCompatibleLlmConfig {
base_url: "http://localhost:11434".into(),
model: "qwen2.5".into(),
},
StubHttpClient::with_responses(vec![Ok(json!({
"message": { "content": "ollama response" }
}))]),
);
assert_eq!(complete_with_trait(&openai).await, "openai response");
assert_eq!(complete_with_trait(&ollama).await, "ollama response");
}
// Both embedder adapters can be used through the same `&dyn Embedder`.
#[tokio::test]
async fn embedder_adapters_are_interchangeable_behind_trait_object() {
let openai = OpenAiCompatibleEmbedder::with_client(
OpenAiCompatibleEmbedderConfig {
base_url: "https://api.example.com".into(),
model: "text-embedding-3-small".into(),
api_key: None,
},
StubHttpClient::with_responses(vec![Ok(json!({
"data": [{ "embedding": [0.4, 0.8] }]
}))]),
);
let ollama = OllamaCompatibleEmbedder::with_client(
OllamaCompatibleEmbedderConfig {
base_url: "http://localhost:11434".into(),
model: "nomic-embed-text".into(),
},
StubHttpClient::with_responses(vec![Ok(json!({
"embeddings": [[0.4, 0.8]]
}))]),
);
assert_eq!(embed_with_trait(&openai).await[0], vec![0.4_f32, 0.8_f32]);
assert_eq!(embed_with_trait(&ollama).await[0], vec![0.4_f32, 0.8_f32]);
}
}